TensorFlow 1.8 XLA Doxygen documentation — source listing for client_library.h.
(Go to the documentation of this file.)
1 
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
17 
// The "client library" instantiates a local (in-process) XLA service for
// use by this process, and connects to it with a singleton XLA local
// client. ClientLibrary::GetOrCreateLocalClient will spawn a local service,
// and return a client that's connected to it and ready to run XLA
// computations.
23 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
24 #define TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
25 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/compile_only_service.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
42 
46 namespace xla {
47 
48 // Options to configure the local client when it is created.
49 class LocalClientOptions {
50  public:
51  LocalClientOptions(perftools::gputools::Platform* platform = nullptr,
52  int number_of_replicas = 1,
53  int intra_op_parallelism_threads = -1);
54 
55  // Set the platform backing the service, or nullptr for the default platform.
56  LocalClientOptions& set_platform(perftools::gputools::Platform* platform);
57  perftools::gputools::Platform* platform() const;
58 
59  // Set the number of replicas to use when compiling replicated
60  // programs.
61  LocalClientOptions& set_number_of_replicas(int number_of_replicas);
62  int number_of_replicas() const;
63 
64  // Sets the thread pool size for parallel execution of an individual operator.
65  LocalClientOptions& set_intra_op_parallelism_threads(int num_threads);
66  int intra_op_parallelism_threads() const;
67 
68  private:
69  perftools::gputools::Platform* platform_;
70  int number_of_replicas_;
71  int intra_op_parallelism_threads_;
72 };
73 
78  public:
79  // Singleton constructor-or-accessor -- returns a client for the application
80  // to issue XLA commands on. Arguments:
81  //
82  // platform : The platform the underlying XLA service should target. If
83  // null then default platform is used.
84  static StatusOr<LocalClient*> GetOrCreateLocalClient(
85  perftools::gputools::Platform* platform = nullptr);
86  static StatusOr<LocalClient*> GetOrCreateLocalClient(
87  const LocalClientOptions& options);
88 
89  // Convenience "or-die" wrapper around the above which returns the existing
90  // client library or creates one with default platform and allocator.
91  static LocalClient* LocalClientOrDie();
92 
93  // Returns the service from the service thread. Only used in unit tests to
94  // access user computations from client.
95  static LocalService* GetXlaService(perftools::gputools::Platform* platform);
96 
97  // Singleton constructor-or-accessor for compile-only clients. Arguments:
98  //
99  // platform : The platform the underlying XLA service should target. If
100  // null then default platform is used.
101  static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
102  perftools::gputools::Platform* platform = nullptr);
103 
104  // Clears the local instance and compile only instance caches. The client
105  // pointers returned by the previous GetOrCreateLocalClient() or
106  // GetOrCreateCompileOnlyClient() invocations are not valid anymore.
107  static void DestroyLocalInstances();
108 
109  private:
110  // Returns the singleton instance of ClientLibrary.
111  static ClientLibrary& Singleton();
112 
113  ClientLibrary();
114  ~ClientLibrary();
115 
116  struct LocalInstance {
117  // Service that is wrapped by the singleton client object.
118  std::unique_ptr<LocalService> service;
119  // Singleton client object.
120  std::unique_ptr<LocalClient> client;
121  };
122 
123  struct CompileOnlyInstance {
124  // Service that is wrapped by the singleton client object.
125  std::unique_ptr<CompileOnlyService> service;
126  // Singleton client object.
127  std::unique_ptr<CompileOnlyClient> client;
128  };
129 
130  tensorflow::mutex service_mutex_; // Guards the singleton creation state.
131  std::unordered_map<perftools::gputools::Platform::Id,
132  std::unique_ptr<LocalInstance>>
133  local_instances_ GUARDED_BY(service_mutex_);
134 
135  std::unordered_map<perftools::gputools::Platform::Id,
136  std::unique_ptr<CompileOnlyInstance>>
137  compile_only_instances_ GUARDED_BY(service_mutex_);
138 
139  TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary);
140 };
141 
142 } // namespace xla
143 
144 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
Doxygen cross-references:
- namespace xla — definition: client_library.cc:26
- class ClientLibrary — definition: client_library.h:77
- static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(perftools::gputools::Platform* platform = nullptr) — definition: client_library.cc:136