tf_1.8_xla_doc
hlo_scheduling.h
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_

#include <vector>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Returns the minimum memory required to compute the given module sequence,
// assuming no fragmentation.
StatusOr<int64> MinimumMemoryForSequence(
    const SequentialHloOrdering::HloModuleSequence& module_sequence,
    const LogicalBuffer::SizeFunction& size_function);
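For orientation, a minimal usage sketch (not part of this header): it assumes an already-built HloModule, a ByteSizeOf-based size callback, and a hypothetical helper name `MinimumModuleMemory`; the sequence it measures comes from CreateMemoryMinimizingSequence, declared further down in this file.

// Hedged usage sketch, not part of hlo_scheduling.h.
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
StatusOr<int64> MinimumModuleMemory(const HloModule& module) {
  // Size callback mapping each logical buffer to its byte size; the
  // pointer_size argument only matters for tuple-shaped buffers.
  auto size_fn = [](const LogicalBuffer& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };
  // Schedule every computation in the module, then ask for the peak memory
  // of that schedule, ignoring fragmentation.
  TF_ASSIGN_OR_RETURN(SequentialHloOrdering::HloModuleSequence sequence,
                      CreateMemoryMinimizingSequence(module, size_fn));
  return MinimumMemoryForSequence(sequence, size_fn);
}
}  // namespace xla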

// A memory scheduler computes an execution sequence for the HLO instructions in
// 'computation' that minimizes peak memory, given a points-to analysis result
// that describes buffer aliasing, together with a target-specific size function
// that maps a tensor's logical size to its padded size.
typedef std::function<StatusOr<std::vector<const HloInstruction*>>(
    const HloComputation&, const TuplePointsToAnalysis&,
    const LogicalBuffer::SizeFunction&)>
    MemorySchedulerAlgorithm;
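Any callable matching this signature can be supplied as the algorithm. As an illustrative example (the name `my_scheduler` is hypothetical), a wrapper that simply defers to the DFS scheduler declared below:

// Illustrative only: a custom MemorySchedulerAlgorithm that forwards to
// DFSMemoryScheduler, which is declared later in this header.
MemorySchedulerAlgorithm my_scheduler =
    [](const HloComputation& computation,
       const TuplePointsToAnalysis& points_to_analysis,
       const LogicalBuffer::SizeFunction& size_function)
        -> StatusOr<std::vector<const HloInstruction*>> {
  return DFSMemoryScheduler(computation, points_to_analysis, size_function);
};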

// List scheduler
StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
    const HloComputation& computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function);

// DFS-order scheduler
StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
    const HloComputation& computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function);

// The default scheduling algorithm. Runs both the list scheduler
// and the DFS scheduler, and chooses whichever returns a lower min-memory,
// not accounting for fragmentation.
StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
    const HloComputation& computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function);
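A sketch of invoking the default scheduler by hand on a single computation (not part of this header): it assumes a module and size callback like those in the earlier sketch, obtains the buffer-aliasing information via TuplePointsToAnalysis::Run, and the helper name `ScheduleEntryComputation` is hypothetical.

// Sketch: schedule the entry computation directly with the default scheduler.
StatusOr<std::vector<const HloInstruction*>> ScheduleEntryComputation(
    const HloModule& module, const LogicalBuffer::SizeFunction& size_fn) {
  // Points-to analysis describes buffer aliasing for the whole module.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to,
                      TuplePointsToAnalysis::Run(&module));
  return DefaultMemoryScheduler(*module.entry_computation(), *points_to,
                                size_fn);
}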

// Computes a per-computation instruction sequence for 'module' that minimizes
// peak memory, using 'algorithm' to schedule each computation (the default
// scheduler is used when no algorithm is provided).
StatusOr<SequentialHloOrdering::HloModuleSequence>
CreateMemoryMinimizingSequence(const HloModule& module,
                               const LogicalBuffer::SizeFunction& size_function,
                               const MemorySchedulerAlgorithm& algorithm = {});
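A sketch of scheduling a whole module with an explicitly chosen algorithm and then wrapping the result in a SequentialHloOrdering (assumed here, as an illustration, to take the module and the sequence in its constructor); `module` and `size_fn` are the same assumptions as in the earlier sketches.

// Sketch, not part of this header: force the list scheduler for every
// computation, then build an ordering from the resulting module sequence.
TF_ASSIGN_OR_RETURN(
    SequentialHloOrdering::HloModuleSequence sequence,
    CreateMemoryMinimizingSequence(module, size_fn, ListMemoryScheduler));
SequentialHloOrdering ordering(&module, sequence);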

// Overload of above that computes the sequence for a single computation.
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
    const HloComputation& computation,
    const LogicalBuffer::SizeFunction& size_function,
    const MemorySchedulerAlgorithm& algorithm = {});

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_