18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ 28 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 31 #include "tensorflow/compiler/xla/service/logical_buffer.h" 32 #include "tensorflow/compiler/xla/service/logical_buffer_analysis.h" 33 #include "tensorflow/compiler/xla/shape_tree.h" 34 #include "tensorflow/compiler/xla/statusor.h" 35 #include "tensorflow/compiler/xla/types.h" 36 #include "tensorflow/compiler/xla/xla_data.pb.h" 37 #include "tensorflow/core/lib/core/status.h" 38 #include "tensorflow/core/lib/gtl/array_slice.h" 39 #include "tensorflow/core/lib/gtl/compactptrset.h" 40 #include "tensorflow/core/lib/gtl/flatmap.h" 41 #include "tensorflow/core/lib/gtl/flatset.h" 42 #include "tensorflow/core/platform/macros.h" 43 #include "tensorflow/core/platform/types.h" 58 explicit PointsToSet(
const Shape* shape) : tree_(shape) {}
62 bool IsAmbiguous()
const;
66 bool IsDistinct()
const;
74 using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>;
75 BufferSet CreateFlattenedSet()
const;
79 bool ContainsBufferAtIndex(
const LogicalBuffer& buffer,
80 const ShapeIndex& index)
const;
83 bool ContainsBuffer(
const LogicalBuffer& buffer)
const;
87 void AddPointedToBuffer(
const LogicalBuffer& buffer,
const ShapeIndex& index);
108 using SourceSet = tensorflow::gtl::CompactPointerSet<HloInstruction*>;
109 const SourceSet& tuple_sources(
const ShapeIndex& index)
const;
112 void add_tuple_source(
const ShapeIndex& index, HloInstruction* tuple);
114 using BufferList = tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>;
117 const BufferList& element(
const ShapeIndex& index)
const {
118 return tree_.element(index).buffers;
120 BufferList* mutable_element(
const ShapeIndex& index) {
121 return &tree_.mutable_element(index)->buffers;
125 template <
typename Fn>
126 void ForEachElement(
const Fn& fn)
const {
127 tree_.ForEachElement([&fn](
const ShapeIndex& index,
const Elem& elem) {
128 fn(index, elem.buffers);
131 template <
typename Fn>
132 void ForEachMutableElement(
const Fn& fn) {
133 tree_.ForEachMutableElement([&fn](
const ShapeIndex& index, Elem* elem) {
134 fn(index, &elem->buffers);
137 template <
typename Fn>
138 Status ForEachElementWithStatus(
const Fn& fn)
const {
139 return tree_.ForEachElementWithStatus(
140 [&fn](
const ShapeIndex& index,
const Elem& elem) {
141 return fn(index, elem.buffers);
148 SourceSet tuple_sources;
150 ShapeTree<Elem> tree_;
154 TF_DISALLOW_COPY_AND_ASSIGN(PointsToSet);
162 BufferAlias(HloInstruction* instruction,
const ShapeIndex& index)
163 : instruction_(instruction), index_(index) {}
166 HloInstruction* instruction()
const {
return instruction_; }
167 const ShapeIndex& index()
const {
return index_; }
169 bool operator==(
const BufferAlias& other)
const {
170 return instruction_ == other.instruction_ && index_ == other.index_;
172 bool operator!=(
const BufferAlias& other)
const {
return !(*
this == other); }
174 string ToString()
const;
177 HloInstruction* instruction_;
181 std::ostream& operator<<(std::ostream& out,
const BufferAlias& buffer_alias);
191 static StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
Run(
196 const PointsToSet& GetPointsToSet(
200 const LogicalBuffer& GetBuffer(LogicalBuffer::Id
id)
const;
204 StatusOr<const LogicalBuffer*> GetBufferDefinedAt(
205 const HloInstruction* instruction,
const ShapeIndex& index)
const;
211 using BufferAliasVector = tensorflow::gtl::InlinedVector<BufferAlias, 1>;
212 const BufferAliasVector& GetBufferAliases(
const LogicalBuffer& buffer)
const;
215 LogicalBuffer::Id num_logical_buffers()
const {
216 return logical_buffer_analysis_->num_logical_buffers();
226 LogicalBuffer& logical_buffer(LogicalBuffer::Id
id)
const {
227 return logical_buffer_analysis_->GetBuffer(
id);
234 using BufferDefinitionVector =
235 tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>;
236 const BufferDefinitionVector& GetBuffersDefinedByInstruction(
240 bool InstructionDefinesBufferAtIndex(
const HloInstruction* instruction,
241 const ShapeIndex& index)
const;
250 Status VerifyBuffer(
const LogicalBuffer& buffer)
const;
254 Status HandleGetTupleElement(
HloInstruction* get_tuple_element)
override;
262 string ToString()
const;
267 std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis)
269 logical_buffer_analysis_(std::move(logical_buffer_analysis)) {}
277 Status PopulateDefinedBuffersAndAliases(
const decltype(
278 std::declval<HloComputation>().instructions())& instructions);
282 PointsToSet& CreateEmptyPointsToSet(
const HloInstruction* instruction);
286 PointsToSet& CreateCopiedPointsToSet(
const HloInstruction* instruction,
290 Status GatherBuffersDefinedByInstruction(
const HloInstruction* instruction,
291 BufferDefinitionVector* buffers);
295 string* output)
const;
298 struct PerInstruction {
299 std::unique_ptr<PointsToSet> points_to_set;
302 BufferDefinitionVector instruction_defined_buffers;
306 int id = inst->unique_id();
308 DCHECK_LT(
id, per_instruction_.size());
309 return &per_instruction_[id];
312 int id = inst->unique_id();
314 DCHECK_LT(
id, per_instruction_.size());
315 return &per_instruction_[id];
322 const std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis_;
325 std::vector<PerInstruction> per_instruction_;
329 std::vector<BufferAliasVector> logical_buffer_aliases_;
336 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ static StatusOr< std::unique_ptr< TuplePointsToAnalysis > > Run(const HloModule *module)
Definition: tuple_points_to_analysis.cc:149
Definition: tuple_points_to_analysis.h:188
Definition: hlo_instruction.h:165
namespace for xla
Definition: client_library.cc:26
Definition: hlo_module.h:52