1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
---|---|
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CHANNEL_TRACKER_H_ |
17 | #define TENSORFLOW_COMPILER_XLA_SERVICE_CHANNEL_TRACKER_H_ |
18 | |
19 | #include <map> |
20 | |
21 | #include "tensorflow/compiler/xla/service/hlo_module.h" |
22 | #include "tensorflow/compiler/xla/service/session.pb.h" |
23 | #include "tensorflow/compiler/xla/service/user_computation.h" |
24 | #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" |
25 | #include "tensorflow/compiler/xla/status.h" |
26 | #include "tensorflow/compiler/xla/statusor.h" |
27 | #include "tensorflow/compiler/xla/types.h" |
28 | #include "tensorflow/compiler/xla/xla_data.pb.h" |
29 | #include "tensorflow/core/lib/gtl/array_slice.h" |
30 | #include "tensorflow/core/platform/macros.h" |
31 | #include "tensorflow/core/platform/mutex.h" |
32 | #include "tensorflow/core/platform/thread_annotations.h" |
33 | #include "tensorflow/core/platform/types.h" |
34 | |
35 | namespace xla { |
36 | |
37 | // Tracks channels between computations in the XLA service. Channels |
38 | // are associated with a unique handle and can be resolved from the handle for |
39 | // later use. |
40 | // |
41 | // TODO(b/34027823): Destruct channels when all the associated computations that |
42 | // communicate via each channel are destructed. |
43 | class ChannelTracker { |
44 | public: |
45 | ChannelTracker(); |
46 | |
47 | // A struct that keeps the current status of each channel. has_sender and |
48 | // receiver_count fields are initialized with false and 0 respectively when |
49 | // the struct is created and are updated by RegisterSend() and RegisterRecev() |
50 | // as Send or Recv instructions using the channel are requested. |
51 | struct Channel { |
52 | bool has_sender; |
53 | int64 receiver_count; |
54 | }; |
55 | |
56 | // Creates a new Channel object and returns the corresponding |
57 | // ChannelHandle for it. |
58 | ChannelHandle NewChannel(); |
59 | |
60 | // Informs that the given channel handle is used for a Send operation. |
61 | // Returns an error status if the handle is already used by another Send. |
62 | Status RegisterSend(const ChannelHandle& handle); |
63 | |
64 | // Informs that the given channel handle is used for a Recv operation. |
65 | // Returns an error status if the handle is already used by another Recv. |
66 | Status RegisterRecv(const ChannelHandle& handle); |
67 | |
68 | private: |
69 | // Bumps the next_channel_ number and returns the allocated number |
70 | // wrapped in a ChannelHandle. |
71 | ChannelHandle AllocateHandle() EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); |
72 | |
73 | Status RegisterSendInternal(const ChannelHandle& handle) |
74 | EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); |
75 | |
76 | Status RegisterRecvInternal(const ChannelHandle& handle) |
77 | EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); |
78 | |
79 | // Guards the channel mapping. |
80 | tensorflow::mutex channel_mutex_; |
81 | |
82 | // The next sequence number to assign to a channel. |
83 | int64 next_channel_ GUARDED_BY(channel_mutex_); |
84 | |
85 | // Mapping from ChannelHandle value to the corresponding registered |
86 | // Channel object. |
87 | std::map<int64, Channel> opaque_to_channel_ GUARDED_BY(channel_mutex_); |
88 | |
89 | TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker); |
90 | }; |
91 | |
92 | } // namespace xla |
93 | |
94 | #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CHANNEL_TRACKER_H_ |
95 |