1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
17#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
18
#include <functional>
#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
42
43namespace xla {
44
45// This class abstracts an allocation of contiguous memory which can hold the
46// values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
47// of the allocation, represented by a Slice. A single BufferAllocation may hold
48// LogicalBuffers with disjoint liveness, which may have overlapping Slices. A
49// single BufferAllocation may also hold LogicalBuffers with overlapping
50// liveness, which must have disjoint Slices.
51//
52// The abstraction includes information required by the backends for allocation,
53// use, and deallocation of the buffer. This includes the LogicalBuffers which
54// are held in this allocation through the execution of the computation.
55class BufferAllocation {
56 public:
57 // Holds a unique identifier for each allocation. Values are assigned
58 // contiguously and can be used as array indexes.
59 using Index = int64;
60
61 BufferAllocation(Index index, int64 size, bool is_thread_local,
62 bool is_reusable, LogicalBuffer::Color color)
63 : index_(index),
64 size_(size),
65 is_thread_local_(is_thread_local),
66 is_reusable_(is_reusable),
67 color_(color) {}
68 ~BufferAllocation() {}
69
70 // Returns the index of this allocation.
71 Index index() const { return index_; }
72
73 // Whether this allocation is used in a parallel calling context such as
74 // inside of a map or reduce computation. Such allocations need to be thread
75 // local.
76 bool is_thread_local() const { return is_thread_local_; }
77
78 // Whether this allocation can be used by more than one logical buffer.
79 bool is_reusable() const { return is_reusable_; }
80
81 // Whether this allocation holds a LogicalBuffer from a parameter of the entry
82 // computation. These buffers have lifetimes which may be longer than the
83 // XLA computation.
84 bool is_entry_computation_parameter() const {
85 return is_entry_computation_parameter_;
86 }
87 // If this allocation holds a Buffer from a parameter of the entry
88 // computation, this methods returns the parameter number. CHECKs otherwise.
89 int64 parameter_number() const {
90 CHECK(is_entry_computation_parameter_);
91 return parameter_number_;
92 }
93
94 // If this allocation is for a parameter of the entry computation, this
95 // function returns which subshape of the parameter the allocation is for.
96 const ShapeIndex& param_shape_index() const {
97 CHECK(is_entry_computation_parameter_);
98 return param_shape_index_;
99 }
100
101 // Returns whether this allocation is assigned a LogicalBuffer which may
102 // be live out of the entry computation.
103 bool maybe_live_out() const { return maybe_live_out_; }
104
105 // Returns the size of the allocation. Necessarily this must be at least as
106 // large as any LogicalBuffer assigned to this allocation.
107 int64 size() const { return size_; }
108
109 // Returns the color of the allocation. Only logical buffers with a matching
110 // color can reside in this allocation.
111 LogicalBuffer::Color color() const { return color_; }
112
113 struct OffsetSize {
114 int64 offset = 0;
115 int64 size = 0;
116 };
117
118 // Access to the logical buffers assigned to this allocation, and their
119 // associated logical offsets and sizes.
120 const tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize>&
121 assigned_buffers() const {
122 return assigned_buffers_;
123 }
124
125 // A Slice represents a contiguous portion of a memory allocation. It is used
126 // to identify the memory range that a LogicalBuffer corresponds to.
127 class Slice {
128 public:
129 Slice() {}
130 Slice(const BufferAllocation* allocation, int64 offset, int64 size)
131 : allocation_(allocation), offset_(offset), size_(size) {}
132
133 const BufferAllocation* allocation() const { return allocation_; }
134 Index index() const { return allocation_->index(); }
135 int64 offset() const { return offset_; }
136 int64 size() const { return size_; }
137
138 bool operator==(const Slice& other) const {
139 return index() == other.index() && offset_ == other.offset_ &&
140 size_ == other.size_;
141 }
142 bool operator!=(const Slice& other) const { return !(*this == other); }
143 bool operator<(const Slice& other) const {
144 if (index() != other.index()) return index() < other.index();
145 if (offset_ != other.offset_) return offset_ < other.offset_;
146 return size_ < other.size_;
147 }
148
149 // Returns true iff this slice's memory range has a non-empty intersection
150 // with the other slice's memory range.
151 bool OverlapsWith(const Slice& other) const {
152 const int64 end = offset_ + size_;
153 const int64 other_end = other.offset_ + other.size_;
154 return index() == other.index() && offset_ < other_end &&
155 end > other.offset_;
156 }
157
158 struct Hasher {
159 size_t operator()(Slice s) const;
160 };
161
162 string ToString() const;
163
164 private:
165 const BufferAllocation* allocation_ = nullptr;
166 int64 offset_ = 0;
167 int64 size_ = 0;
168 };
169
170 // GetSlice returns the Slice of contiguous memory that holds the value
171 // described by the given 'buffer'.
172 // REQUIRES: 'buffer' must be assigned to this allocation.
173 Slice GetSlice(const LogicalBuffer& buffer) const;
174
175 string ToString() const;
176 BufferAllocationProto ToProto() const;
177
178 // Whether the buffer is a parameter to or live out of the entry computation.
179 bool IsInputOrOutput() const {
180 return is_entry_computation_parameter() || maybe_live_out();
181 }
182
183 // Whether the buffer is a temporary buffer allocated before
184 // Executable::ExecuteOnStream.
185 bool IsPreallocatedTempBuffer() const {
186 // Parameters do not need temporary buffers.
187 return !is_entry_computation_parameter() &&
188 // LogicalBuffers that maybe pointed to by the output should live out
189 // of the computation.
190 !maybe_live_out() &&
191 // Thread-local buffers are allocated using `alloca`s.
192 !is_thread_local();
193 }
194
195 // Add a heap trace which was used to assign slices to logical buffers in this
196 // allocation. A single BufferAllocation may include multiple heap traces
197 // in the case of the temporary block where there is a heap trace per
198 // computation.
199 void AddHeapTrace(const HeapSimulatorTrace& heap_trace) {
200 heap_traces_.push_back(heap_trace);
201 }
202
203 // Return the set of heap traces used to assign slices to logical buffers in
204 // this allocation.
205 const std::vector<HeapSimulatorTrace> HeapTraces() const {
206 return heap_traces_;
207 }
208
209 // Compute and return the LogicalBuffers which are live at the point of peak
210 // memory usage for the given allocation. The point of peak memory usage is
211 // the point at which the total size of all live logical buffers is
212 // maximal. If peak memory is reached at multiple points, the set of logical
213 // buffers live at the earliest maximal point is returned. The vector is
214 // stabily asserted by LogicalBuffer::Index.
215 //
216 // The return value is a pair of total size of the logical buffers at peak,
217 // and the buffers themselves.
218 std::pair<int64, std::vector<const LogicalBuffer*>>
219 ComputePeakMemoryLogicalBuffers() const;
220
221 // Get the number of bytes lost to fragmentation. This is equal to the
222 // difference between the size of the allocation and the size of the maximal
223 // live set.
224 int64 fragmentation_bytes() const { return fragmentation_bytes_; }
225
226 bool operator==(const BufferAllocation& other) const {
227 return index_ == other.index_;
228 }
229 bool operator!=(const BufferAllocation& other) const {
230 return !(*this == other);
231 }
232 bool operator<(const BufferAllocation& other) const {
233 return index() < other.index();
234 }
235
236 private:
237 // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
238 friend class BufferAssigner;
239 friend class BufferAssignment;
240
241 // Adds a LogicalBuffer to the set assigned to this buffer.
242 void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size);
243
244 void set_entry_computation_parameter(int64 parameter_number,
245 ShapeIndex param_shape_index) {
246 is_entry_computation_parameter_ = true;
247 parameter_number_ = parameter_number;
248 param_shape_index_ = std::move(param_shape_index);
249 }
250 void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
251 void set_index(Index index) { index_ = index; }
252 void set_size(int64 size) { size_ = size; }
253
254 // The index of the allocation in the BufferAssignment.
255 Index index_;
256
257 // Size of the allocation in bytes.
258 int64 size_;
259
260 // Whether this buffer needs to be thread-local.
261 bool is_thread_local_;
262
263 // Whether this buffer is usable by more than one logical buffer.
264 bool is_reusable_;
265
266 // Color of the allocation.
267 LogicalBuffer::Color color_;
268
269 // Whether this allocation holds an entry computation parameter. Entry
270 // computation parameters are special be cause they have lifetimes which may
271 // outlast the computation.
272 bool is_entry_computation_parameter_ = false;
273
274 // If this allocation holds an entry computation parameter, this field
275 // indicates the index (starting from 0) of the parameter.
276 int64 parameter_number_ = 0;
277
278 // If this buffer is for an entry computation parameter, which subshape of the
279 // parameter is it for?
280 ShapeIndex param_shape_index_;
281
282 // Whether the allocation contains a LogicalBuffer which may be live-out of
283 // the entry computation. Note that this flag is conservatively computed by
284 // TuplePointsToAnalysis. That is, an allocation marked `maybe_live_out_`
285 // might not actually escape.
286 bool maybe_live_out_ = false;
287
288 // Mapping from the set of buffers assigned to this allocation to their
289 // logical offsets and sizes.
290 tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize> assigned_buffers_;
291
292 int64 fragmentation_bytes_ = 0;
293 std::vector<HeapSimulatorTrace> heap_traces_;
294};
295
296// Add stream operators for nicer output of CHECK/RET_CHECK failures.
297std::ostream& operator<<(std::ostream& out, const BufferAllocation& s);
298std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s);
299
300// This class encapsulates an assignment of the LogicalBuffers in an XLA
301// module to a set of BufferAllocations.
302class BufferAssignment {
303 public:
304 // Returns the vector containing all buffer allocations in this assignment.
305 const std::vector<BufferAllocation>& Allocations() const {
306 return allocations_;
307 }
308
309 // Returns the total size allocation holding all temporary buffers.
310 int64 temp_allocation_total_size() const {
311 return temp_allocation_total_size_;
312 }
313
314 // Returns whether the given buffer has been assigned an allocation.
315 bool HasAllocation(const LogicalBuffer& buffer) const;
316
317 // Returns the allocation that a particular LogicalBuffer has been assigned
318 // to. CHECKs if buffer has not been assigned an allocation.
319 const BufferAllocation& GetAssignedAllocation(
320 const LogicalBuffer& buffer) const;
321
322 // Returns the allocation with the given index. CHECKs if no allocation exists
323 // with the given index.
324 const BufferAllocation& GetAllocation(BufferAllocation::Index index) const;
325
326 // Builds and returns a vector containing the slices which might contain the
327 // subvalue at the given index of given instruction.
328 std::set<BufferAllocation::Slice> GetAllSlices(
329 const HloInstruction* instruction, const ShapeIndex& index) const;
330
331 // Convenience function which returns whether the buffer of the
332 // instruction at the given index is assigned an allocation.
333 bool HasAllocationAt(const HloInstruction* instruction,
334 const ShapeIndex& index) const;
335
336 // Convenience function which returns whether the top-level buffer of the
337 // instruction (index == {}) is assigned an allocation.
338 bool HasTopLevelAllocation(const HloInstruction* instruction) const;
339
340 // Convenience function which returns the unique slice containing the buffer
341 // at the given index of the given instruction. If a slice is not assigned or
342 // the slice cannot be determined at compile time then an error is returned.
343 StatusOr<BufferAllocation::Slice> GetUniqueSlice(
344 const HloInstruction* instruction, const ShapeIndex& index) const;
345 // Like GetUniqueSlice but fixes the index to the top-level of the shape
346 // (index = {}).
347 StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice(
348 const HloInstruction* instruction) const;
349 // Like GetUniqueTopLevelSlice but returns the slice for the output of the
350 // entry computation of the HLO module (ie, the result of the XLA
351 // computation).
352 StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const;
353
354 // Returns the set LogicalBuffers which may be the source of the value at the
355 // given index and instruction.
356 const PointsToSet::BufferList& GetSourceBuffers(
357 const HloInstruction* instruction, const ShapeIndex& index) const {
358 return GetPointsToSet(instruction).element(index);
359 }
360
361 // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
362 // share the same BufferAllocation::Slice.
363 // Returns false otherwise.
364 // REQUIRES: BufferAssignment assigned allocations to both instructions.
365 bool SharesSliceAtIndex(const HloInstruction* hlo_a,
366 const ShapeIndex& shape_index_a,
367 const HloInstruction* hlo_b,
368 const ShapeIndex& shape_index_b) const;
369
370 // Returns true if the top-level buffers of hlo_a and hlo_b are the same.
371 // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b).
372 bool SharesTopLevelSlice(const HloInstruction* hlo_a,
373 const HloInstruction* hlo_b) const {
374 return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
375 }
376
377 // Returns true if hlo_a and hlo_b both have at least one buffer assigned for
378 // their top-level and each of their nested shape indices, and if hlo_a's
379 // buffers are all different from hlo_b's buffers.
380 bool HaveDisjointSlices(const HloInstruction* hlo_a,
381 const HloInstruction* hlo_b) const;
382
383 // Returns the underlying points-to analysis used for this assignment.
384 const TuplePointsToAnalysis& points_to_analysis() const {
385 return liveness_->points_to_analysis();
386 }
387
388 // Returns the BufferLiveness object used to construct this assignment.
389 const BufferLiveness& liveness() const { return *liveness_; }
390
391 string ToString() const;
392 BufferAssignmentProto ToProto() const;
393
394 // Statistics for the assignment. Values initialized to -1 are not always
395 // collected; fragmentation is only collected for instructions that have a
396 // sequential total ordering.
397 struct Stats {
398 int64 parameter_allocation_count = 0;
399 int64 parameter_allocation_bytes = 0;
400 int64 maybe_live_out_allocation_count = 0;
401 int64 maybe_live_out_allocation_bytes = 0;
402 int64 preallocated_temp_allocation_count = 0;
403 int64 preallocated_temp_allocation_bytes = 0;
404 int64 preallocated_temp_fragmentation_bytes = -1;
405 int64 total_allocation_count = 0;
406 int64 total_allocation_bytes = 0;
407 int64 total_fragmentation_bytes = -1;
408
409 string ToString() const;
410 };
411 const Stats& GetStats() const { return stats_; }
412
413 private:
414 // Only BufferAssigner can build or modify BufferAssignments.
415 friend class BufferAssigner;
416
417 explicit BufferAssignment(const HloModule* module,
418 std::unique_ptr<BufferLiveness> liveness,
419 LogicalBuffer::SizeFunction buffer_size,
420 LogicalBuffer::AlignmentFunction color_alignment)
421 : module_(module),
422 liveness_(std::move(liveness)),
423 buffer_size_(std::move(buffer_size)),
424 color_alignment_(std::move(color_alignment)) {}
425
426 // Creates and returns a new BufferAllocation, with no assigned
427 // LogicalBuffers. Ownership is maintained internally.
428 BufferAllocation* NewEmptyAllocation(int64 size, bool is_thread_local,
429 bool is_reusable,
430 LogicalBuffer::Color color);
431
432 // Helper that calls NewEmptyAllocation and AddAssignment in one call,
433 // creating an allocation containing a single LogicalBuffer.
434 BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size,
435 bool is_thread_local, bool is_reusable);
436
437 // Adds a LogicalBuffer to the set assigned to the given allocation.
438 void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer,
439 int64 offset, int64 size);
440
441 // Returns the HloModule used to construct this assignment.
442 const HloModule& module() const { return *module_; }
443
444 // Convenience function which returns the PointsToSet for the given
445 // instruction. Extracted from the liveness object.
446 const PointsToSet& GetPointsToSet(const HloInstruction* instruction) const;
447
448 // Mutable accessors for allocations.
449 BufferAllocation* GetMutableAssignedAllocation(const LogicalBuffer& buffer);
450 BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
451
452 // Combines allocations of temporary buffers into one big BufferAllocation.
453 void CombineTempAllocations();
454
455 // Computes stats for the assignment, to be retrieved by GetStats.
456 Status ComputeSummaryStats();
457
458 // The vector of buffer allocations. Indexed by BufferAllocation::Index.
459 std::vector<BufferAllocation> allocations_;
460
461 // The total size of all temporary buffers.
462 int64 temp_allocation_total_size_ = 0;
463
464 // Maps Buffers to the index of the BufferAllocation which holds the buffer.
465 tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferAllocation::Index>
466 allocation_index_for_buffer_;
467
468 const HloModule* module_;
469 const std::unique_ptr<BufferLiveness> liveness_;
470
471 // Function which returns the buffer size for a given logical buffer (shape).
472 LogicalBuffer::SizeFunction buffer_size_;
473
474 // Function which returns the alignment for a given logical buffer color.
475 LogicalBuffer::AlignmentFunction color_alignment_;
476
477 Stats stats_;
478
479 TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
480};
481
482// A class which constructs a buffer assignment.
483class BufferAssigner {
484 public:
485 // Build and return a BufferAssignment for the given module. The given
486 // HloOrdering is used to determine buffer liveness. buffer_size and
487 // color_alignment are functions which returns the size and alignment of a
488 // LogicalBuffer. allow_input_output_aliasing specifies whether input buffer
489 // are allowed to be reused as outbut buffers by the client code.
490 static StatusOr<std::unique_ptr<BufferAssignment>> Run(
491 const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
492 LogicalBuffer::SizeFunction buffer_size,
493 LogicalBuffer::AlignmentFunction color_alignment,
494 bool allow_input_output_aliasing = false,
495 BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer());
496
497 private:
498 BufferAssigner(bool allow_input_output_aliasing,
499 BufferLiveness::Colorer colorer)
500 : allow_input_output_aliasing_(allow_input_output_aliasing),
501 colorer_(colorer) {}
502 virtual ~BufferAssigner() = default;
503
504 // Create a buffer assignment.
505 StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment(
506 const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
507 LogicalBuffer::SizeFunction buffer_size,
508 LogicalBuffer::AlignmentFunction color_alignment);
509
510 // Assigns buffers to the instructions in the given computation. "assignment"
511 // is modified to reflect the new buffer assignments. If is_thread_local is
512 // true, then all assigned buffers have the is_thread_local flag set to
513 // true.
514 Status AssignBuffersForComputation(
515 const HloComputation* computation, const DebugOptions& debug_options,
516 bool is_thread_local,
517 const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
518 const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
519 colocated_allocations,
520 tensorflow::gtl::FlatMap<const HloComputation*,
521 tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
522 buffers_to_assign_sequentially,
523 BufferAssignment* assignment);
524
525 // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
526 // the HLO instructions will be executed in the sequential order given by
527 // assignment->liveness().hlo_ordering().SequentialOrder. If
528 // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
529 // assuming all global computations are sequentially ordered.
530 Status AssignBuffersWithSequentialOrdering(
531 const tensorflow::gtl::FlatMap<
532 const HloComputation*,
533 tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
534 buffers_to_assign_sequentially,
535 bool run_whole_module_heap_simulation, BufferAssignment* assignment);
536
537 // Uses the results of the heap simulator to create a single allocation, with
538 // LogicalBuffers packed to specific offsets.
539 void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result,
540 BufferAssignment* assignment,
541 LogicalBuffer::Color color);
542
543 // Tries to assign the given instruction to the given buffer. Returns if the
544 // assignment was successful.
545 bool MaybeAssignBuffer(BufferAllocation* allocation,
546 const LogicalBuffer& buffer,
547 BufferAssignment* assignment);
548
549 // Colocated buffers are logical buffers from different computations which
550 // alias. Explicitly handling these colocated buffers is necessary because
551 // points-to analysis is computation level scope and does not recognize
552 // aliasing across computations (b/32491382).
553 using ColocatedBufferSet = tensorflow::gtl::FlatSet<const LogicalBuffer*>;
554
555 // Returns a vector of ColocatedBufferSet objects, where each
556 // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
557 // which should be colocated in the same buffer allocation.
558 void BuildColocatedBufferSets(
559 const HloModule* module, const BufferLiveness& buffer_liveness,
560 const LogicalBuffer::SizeFunction& buffer_size,
561 std::vector<ColocatedBufferSet>* colocated_buffer_sets);
562
563 // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the
564 // same set to the same buffer allocation in 'assignment'.
565 void AssignColocatedBufferSets(
566 const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
567 BufferAssignment* assignment,
568 tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
569 tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations);
570
571 // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
572 // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
573 void AddSetToColocatedBufferSets(
574 const std::vector<const LogicalBuffer*>& colocated_set,
575 std::vector<ColocatedBufferSet>* colocated_buffer_sets);
576
577 // Given a list of colocated buffer sets (each colocated buffer set represents
578 // the logical buffers that would be assigned to the same physical buffer),
579 // try to merge the sets if the buffers can be shared. Returns the merged set.
580 std::vector<ColocatedBufferSet> MergeColocatedBufferSets(
581 const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
582 const BufferLiveness& buffer_liveness,
583 const LogicalBuffer::SizeFunction& buffer_size);
584
585 // Split a set of buffers into several sets, each of which contains buffers
586 // colored with the same color.
587 tensorflow::gtl::FlatMap<LogicalBuffer::Color,
588 tensorflow::gtl::FlatSet<const LogicalBuffer*>,
589 LogicalBuffer::Color::Hasher>
590 SplitBuffersByColor(
591 const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers);
592
593 // If true, buffer assignments assumes that input parameter buffers and output
594 // buffers can be shared if their sizes match.
595 bool allow_input_output_aliasing_;
596
597 // Functor used to assign colors to newly allocated logical buffers.
598 BufferLiveness::Colorer colorer_;
599
600 TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
601};
602
603} // namespace xla
604
605#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
606