1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
17#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
18
19#include <cassert>
20#include <string>
21
22#include "tensorflow/compiler/xla/executable_run_options.h"
23#include "tensorflow/core/platform/types.h"
24
25// Forward-declare, rather than include, to reduce code size for users that
26// never use this functionality.
namespace xla {
class HloProfilePrinterData;
class ProgramShape;
}  // namespace xla
31
32namespace tensorflow {
33
34// Represents a function compiled by XLA, produced via either JIT or AOT.
35//
36// The Run method invokes the actual computation, with inputs read from arg
37// buffers, and outputs written to result buffers. Each Run call may also use a
38// set of temporary buffers for the computation.
39//
40// By default each instance of this class manages its own arg, result and temp
41// buffers. The AllocMode constructor parameter may be used to modify the buffer
42// allocation strategy.
43//
44// Under the default allocation strategy, this class is thread-compatible:
45// o Calls to non-const methods require exclusive access to the object.
46// o Concurrent calls to const methods are OK, if those calls are made while it
47// is guaranteed that no thread may call a non-const method.
48class XlaCompiledCpuFunction {
49 public:
50 // Type of the raw function, produced by either JIT or AOT.
51 using RawFunction = void (*)(void* result,
52 const xla::ExecutableRunOptions* run_options,
53 const void** args, void** temps,
54 int64* profile_counters);
55
56 // StaticData represents the state necessary to run an XLA-compiled
57 // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
58 // AOT this is backed by data compiled into the object file.
59 struct StaticData {
60 // The raw function to call.
61 RawFunction raw_function;
62
63 // Cardinality and sizes of arg and temp buffers.
64 const intptr_t* arg_sizes = nullptr;
65 size_t num_args = 0;
66 const intptr_t* temp_sizes = nullptr;
67 size_t num_temps = 0;
68
69 // The 0-based index of the result tuple, in the temp buffers.
70 size_t result_index = 0;
71
72 // [Optional] Arrays of arg and result names. These are arrays of C-style
73 // strings, where the array is terminated by nullptr.
74 const char** arg_names = nullptr;
75 const char** result_names = nullptr;
76
77 // [Optional] Arg and result shapes.
78 const xla::ProgramShape* program_shape = nullptr;
79
80 // [Optional] Profile printer data. Null if profiling is disabled.
81 const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr;
82
83 // [Optional] The number of profile counters expected in the profile counter
84 // buffer by the generated code and hlo_profile_printer. 0 if profiling is
85 // disabled. This information is already present in
86 // hlo_profile_printer_data but xla::HloProfilePrinterData is forward
87 // declared so we don't have access to that information here.
88 int64 profile_counters_size = 0;
89 };
90
91 // AllocMode controls the buffer allocation mode.
92 enum class AllocMode {
93 // Allocate all buffers - args, results, profile and temps.
94 ARGS_RESULTS_PROFILES_AND_TEMPS,
95
96 // Only allocate result, profile and temp buffers.
97 // Use set_arg_data to set argument buffers before Run is called.
98 RESULTS_PROFILES_AND_TEMPS_ONLY,
99 };
100
101 XlaCompiledCpuFunction(
102 const StaticData& static_data,
103 AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS);
104 virtual ~XlaCompiledCpuFunction();
105
106 XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete;
107 XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete;
108
109 // Sets the intra-op thread pool used to run individual ops concurrently.
110 void set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
111 run_options_.set_intra_op_thread_pool(pool);
112 }
113
114 // Runs the computation, with inputs read from arg buffers, and outputs
115 // written to result buffers. Returns true on success and false on failure.
116 bool Run() {
117 raw_function_(temps_[result_index_], &run_options_,
118 const_cast<const void**>(args_), temps_, profile_counters_);
119 return true;
120 }
121
122 // Returns the error message from the previous failed Run call.
123 //
124 // TODO(fschneider): For now this always returns an empty string because there
125 // is no support for error reporting in XLA. Remove this once all callers are
126 // updated.
127 string error_msg() const { return {}; }
128
129 // ------------------------------
130 // Arg methods for managing input buffers. Buffers are in row-major order.
131
132 // Returns the underlying array of argument buffers, where args()[I] is the
133 // buffer for the positional argument at index I.
134 void** args() { return args_; }
135 const void* const* args() const { return args_; }
136
137 // Returns the buffer for the positional argument at the given `index`.
138 void* arg_data(size_t index) { return args_[index]; }
139 const void* arg_data(size_t index) const { return args_[index]; }
140
141 // Sets the buffer for the positional argument at the given `index` to `data`.
142 // Must be called before Run to have an effect. May be called under any
143 // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be
144 // called for each positional argument, in order to set the argument buffers.
145 //
146 // Allocated memory must be aligned to the size specified by
147 // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in
148 // tensorflow/compiler/aot/runtime.h to ensure correct alignment.
149 //
150 // Aliasing of argument and result buffers is not allowed, and results in
151 // undefined behavior.
152 void set_arg_data(size_t index, void* data) { args_[index] = data; }
153
154 // ------------------------------
155 // Result methods for managing output buffers. Buffers are in row-major order.
156 // Must only be called after a successful Run call. Unlike the arg methods,
157 // there is no set_resultN_data method. The result buffers are managed
158 // internally, and may change after each call to Run.
159
160 // Returns the underlying array of result buffers, where results()[I] is the
161 // buffer for the positional result at index I.
162 void** results() { return static_cast<void**>(temps_[result_index_]); }
163 const void* const* results() const {
164 return static_cast<const void* const*>(temps_[result_index_]);
165 }
166
167 // Profile counters for this XLA computation.
168 //
169 // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in
170 // this case) these counters are non-null and are automatically populated by
171 // `Run`. The counters can then be pretty-printed using
172 // `hlo_profile_printer()`.
173 //
174 // When Hlo profiling is disabled, this accessor returns null.
175 const int64* profile_counters() const { return profile_counters_; }
176
177 // Returns the buffer for the positional result at the given `index`.
178 void* result_data(size_t index) { return results()[index]; }
179 const void* result_data(size_t index) const { return results()[index]; }
180
181 // ------------------------------
182 // Methods for extracting optional metadata.
183
184 // Returns true iff data is available for the Lookup{Arg,Result}Index methods.
185 // E.g. the data might not be compiled into the binary for AOT.
186 bool HasNameIndices() const {
187 return arg_names_ != nullptr && result_names_ != nullptr;
188 }
189
190 // Returns the 0-based index for the argument with the given `name`.
191 // Returns -1 if the name wasn't found, or data isn't available.
192 //
193 // The index remains constant for every instance of XlaCompiledCpuFunction
194 // generated from the same static data, and might not be cheap to determine.
195 // Recommended usage is to capture this in a variable for re-use.
196 int LookupArgIndex(const string& name) const;
197
198 // Returns the 0-based index for the result with the given `name`.
199 // Returns -1 if the name wasn't found, or data isn't available.
200 //
201 // The index remains constant for every instance of XlaCompiledCpuFunction
202 // generated from the same static data, and might not be cheap to determine.
203 // Recommended usage is to capture this in a variable for re-use.
204 int LookupResultIndex(const string& name) const;
205
206 // Returns the shape of the args and results. May return nullptr if the
207 // program shape isn't available.
208 const xla::ProgramShape* ProgramShape() const { return program_shape_; }
209
210 bool hlo_profiling_enabled() const {
211 return hlo_profile_printer_data_ != nullptr;
212 }
213 const xla::HloProfilePrinterData& hlo_profile_printer_data() const {
214 assert(hlo_profiling_enabled());
215 return *hlo_profile_printer_data_;
216 }
217
218 private:
219 const RawFunction raw_function_;
220 const size_t result_index_;
221
222 // Arrays of argument and temp buffers; entries in args_ may be overwritten by
223 // the user.
224 void** args_ = nullptr;
225 void** temps_ = nullptr;
226
227 // Backing memory for individual arg and temp buffers.
228 void* alloc_args_ = nullptr;
229 void* alloc_temps_ = nullptr;
230
231 // Backing memory for profiling counters.
232 int64* profile_counters_ = nullptr;
233
234 // Options and context passed to the compiled function.
235 xla::ExecutableRunOptions run_options_;
236
237 // Optional metadata.
238 const char** arg_names_ = nullptr;
239 const char** result_names_ = nullptr;
240 const xla::ProgramShape* program_shape_ = nullptr;
241 const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
242};
243
244} // namespace tensorflow
245
246#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
247