tf_1.8_xla_doc
computation_tracker.h
Go to the documentation of this file.
1 
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7  http://www.apache.org/licenses/LICENSE-2.0
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
15 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
16 #include <list>
17 #include <map>
18 #include <memory>
19 #include <set>
20 #include <string>
22 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
23 #include "tensorflow/compiler/xla/service/session.pb.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/core/platform/types.h"
33 
37 namespace xla {
47  public:
49  // Creates a new UserComputation object and returns the corresponding
50  // ComputationHandle for it.
51  //
52  // Precondition: user_computation is not already present in the map.
53  ComputationHandle NewComputation(const string& computation_name);
54  // Restores session data for a computation that has been serialized, and
55  // allocates a new computation handle for it.
56  StatusOr<ComputationHandle> LoadSessionModule(
57  const SessionModule& session_module);
58  // Snapshots a computation (referenced by the provided handle) at its latest
59  // version, returning a module where it is the entry, and any referred-to
60  // computations are entrained as "embedded" (non-entry) computations.
61  StatusOr<std::unique_ptr<SessionModule>> SnapshotComputation(
62  const ComputationHandle& computation);
63  // Resolves a ComputationHandle to a UserComputation that is present in the
64  // map.
65  StatusOr<UserComputation*> Resolve(
66  const ComputationHandle& computation) const;
67  // Builds an HLO module using the specified computation as the entry. The
68  // module will include the entry computation as well as all computations which
69  // are called directly or indirectly from the entry computation via operations
70  // like "map". config is the HLO module configuration to use for the
71  // constructed module.
72  // If include_unreachable_instructions is true, then instructions
73  // which are not reachable from the root are lowered into HloInstructions
74  // including unreachable parameters. This ensures the entry HloComputation has
75  // the same program shape (ProgramShape) as the entry UserComputation.
76  StatusOr<std::unique_ptr<HloModule>> BuildHloModule(
77  const VersionedComputationHandle& entry_handle,
78  const HloModuleConfig& config,
79  bool include_unreachable_instructions = true) const;
80  string ToString() const;
81  private:
82  // Bumps the next_computation_ number and returns the allocated number wrapped
83  // in a ComputationHandle.
84  ComputationHandle AllocateHandle()
85  EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
86  // Loads a session computation into a UserComputation, registers it, and
87  // returns the computation handle of the registered computation. old_to_new
88  // is used for remapping references to computations present in
89  // session_computation.
90  //
91  // old_to_new will be updated with the mapping from session_computation's old
92  // handle to the returned handle value, and may not be null.
93  StatusOr<ComputationHandle> LoadSessionComputation(
94  const SessionComputation& session_computation,
95  std::map<int64, ComputationHandle>* old_to_new)
96  EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
97  // Internal implementation of the Resolve method, which requires, but does not
98  // acquire, the mutex.
99  StatusOr<UserComputation*> ResolveInternal(
100  const ComputationHandle& computation) const
101  EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
102  // Builds a post order sort of a computation ("entry") and all of its embedded
103  // computations including all transitively embedded computations. An embedded
104  // computation (the callee) will always appear in the sort before the
105  // computation which calls the embedded computation (the caller). Necessarily,
106  // the entry computation is the last element in the sort. visited and
107  // post_order should be empty when calling. post_order contains the post order
108  // sort when the function returns.
109  void ComputeComputationPostOrder(
110  const VersionedComputationHandle& versioned_handle,
111  std::set<VersionedComputationHandle>* visited,
112  std::list<VersionedComputationHandle>* post_order) const
113  EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
114  string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
115  // Guards the computation mapping. Marked mutable so that the Resolve method
116  // can remain const; Resolve doesn't really modify the tracker in any way, but
117  // it has to lock the mutex for safety.
118  mutable tensorflow::mutex computation_mutex_;
119  // The next sequence number to assign to a computation, guarded by the same
120  // mutex as the mapping, since the two are mutated at the same time.
121  int64 next_computation_ GUARDED_BY(computation_mutex_);
122  // Mapping from ComputationHandle value to the corresponding registered
123  // UserComputation object.
124  std::map<int64, std::unique_ptr<UserComputation>> opaque_to_computation_
125  GUARDED_BY(computation_mutex_);
126  TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker);
127 };
128 } // namespace xla
129 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
130 
ComputationHandle NewComputation(const string &computation_name)
Create a new UserComputation object and return the corresponding ComputationHandle for it...
Definition: computation_tracker.cc:32
StatusOr< std::unique_ptr< HloModule > > BuildHloModule(const VersionedComputationHandle &entry_handle, const HloModuleConfig &config, bool include_unreachable_instructions=true) const
Build an HLO module (which is basically a set of HLO instructions) using the specified computation as ...
Definition: computation_tracker.cc:169
Definition: versioned_computation_handle.h:37
StatusOr< std::unique_ptr< SessionModule > > SnapshotComputation(const ComputationHandle &computation)
Definition: computation_tracker.cc:84
StatusOr< UserComputation * > Resolve(const ComputationHandle &computation) const
Definition: computation_tracker.cc:110
Tracks computations for the XLA service. Registered with a xla::UserComputation instance and can be r...
Definition: computation_tracker.h:46
namespace for xla
Definition: client_library.cc:26