23 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_ 24 #define TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_ 32 #include "tensorflow/compiler/xla/client/local_client.h" 34 #include "tensorflow/compiler/xla/service/device_memory_allocator.h" 35 #include "tensorflow/compiler/xla/service/local_service.h" 36 #include "tensorflow/compiler/xla/statusor.h" 37 #include "tensorflow/compiler/xla/types.h" 38 #include "tensorflow/core/platform/macros.h" 39 #include "tensorflow/core/platform/mutex.h" 40 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 41 #include "tensorflow/core/platform/thread_annotations.h" 49 class LocalClientOptions {
51 LocalClientOptions(perftools::gputools::Platform* platform =
nullptr,
52 int number_of_replicas = 1,
53 int intra_op_parallelism_threads = -1);
56 LocalClientOptions& set_platform(perftools::gputools::Platform* platform);
57 perftools::gputools::Platform* platform()
const;
61 LocalClientOptions& set_number_of_replicas(
int number_of_replicas);
62 int number_of_replicas()
const;
65 LocalClientOptions& set_intra_op_parallelism_threads(
int num_threads);
66 int intra_op_parallelism_threads()
const;
69 perftools::gputools::Platform* platform_;
70 int number_of_replicas_;
71 int intra_op_parallelism_threads_;
84 static StatusOr<LocalClient*> GetOrCreateLocalClient(
85 perftools::gputools::Platform* platform =
nullptr);
86 static StatusOr<LocalClient*> GetOrCreateLocalClient(
87 const LocalClientOptions& options);
91 static LocalClient* LocalClientOrDie();
95 static LocalService* GetXlaService(perftools::gputools::Platform* platform);
102 perftools::gputools::Platform* platform =
nullptr);
107 static void DestroyLocalInstances();
116 struct LocalInstance {
118 std::unique_ptr<LocalService> service;
120 std::unique_ptr<LocalClient> client;
123 struct CompileOnlyInstance {
125 std::unique_ptr<CompileOnlyService> service;
127 std::unique_ptr<CompileOnlyClient> client;
130 tensorflow::mutex service_mutex_;
131 std::unordered_map<perftools::gputools::Platform::Id,
132 std::unique_ptr<LocalInstance>>
133 local_instances_ GUARDED_BY(service_mutex_);
135 std::unordered_map<perftools::gputools::Platform::Id,
136 std::unique_ptr<CompileOnlyInstance>>
137 compile_only_instances_ GUARDED_BY(service_mutex_);
144 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
Namespace for XLA.
Definition: client_library.cc:26
Definition: client_library.h:77
static StatusOr< CompileOnlyClient * > GetOrCreateCompileOnlyClient(perftools::gputools::Platform *platform=nullptr)
Definition: client_library.cc:136