1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ |
17 | #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ |
18 | |
19 | #include <cassert> |
20 | #include <string> |
21 | |
22 | #include "tensorflow/compiler/xla/executable_run_options.h" |
23 | #include "tensorflow/core/platform/types.h" |
24 | |
// XLA types that this header only refers to through pointers and references.
// They are forward-declared instead of having their headers included, which
// keeps code size down for users that never use this functionality.
namespace xla {
class HloProfilePrinterData;
class ProgramShape;
}  // namespace xla
31 | |
32 | namespace tensorflow { |
33 | |
34 | // Represents a function compiled by XLA, produced via either JIT or AOT. |
35 | // |
36 | // The Run method invokes the actual computation, with inputs read from arg |
37 | // buffers, and outputs written to result buffers. Each Run call may also use a |
38 | // set of temporary buffers for the computation. |
39 | // |
40 | // By default each instance of this class manages its own arg, result and temp |
41 | // buffers. The AllocMode constructor parameter may be used to modify the buffer |
42 | // allocation strategy. |
43 | // |
44 | // Under the default allocation strategy, this class is thread-compatible: |
45 | // o Calls to non-const methods require exclusive access to the object. |
46 | // o Concurrent calls to const methods are OK, if those calls are made while it |
47 | // is guaranteed that no thread may call a non-const method. |
48 | class XlaCompiledCpuFunction { |
49 | public: |
50 | // Type of the raw function, produced by either JIT or AOT. |
51 | using RawFunction = void (*)(void* result, |
52 | const xla::ExecutableRunOptions* run_options, |
53 | const void** args, void** temps, |
54 | int64* profile_counters); |
55 | |
56 | // StaticData represents the state necessary to run an XLA-compiled |
57 | // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for |
58 | // AOT this is backed by data compiled into the object file. |
59 | struct StaticData { |
60 | // The raw function to call. |
61 | RawFunction raw_function; |
62 | |
63 | // Cardinality and sizes of arg and temp buffers. |
64 | const intptr_t* arg_sizes = nullptr; |
65 | size_t num_args = 0; |
66 | const intptr_t* temp_sizes = nullptr; |
67 | size_t num_temps = 0; |
68 | |
69 | // The 0-based index of the result tuple, in the temp buffers. |
70 | size_t result_index = 0; |
71 | |
72 | // [Optional] Arrays of arg and result names. These are arrays of C-style |
73 | // strings, where the array is terminated by nullptr. |
74 | const char** arg_names = nullptr; |
75 | const char** result_names = nullptr; |
76 | |
77 | // [Optional] Arg and result shapes. |
78 | const xla::ProgramShape* program_shape = nullptr; |
79 | |
80 | // [Optional] Profile printer data. Null if profiling is disabled. |
81 | const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr; |
82 | |
83 | // [Optional] The number of profile counters expected in the profile counter |
84 | // buffer by the generated code and hlo_profile_printer. 0 if profiling is |
85 | // disabled. This information is already present in |
86 | // hlo_profile_printer_data but xla::HloProfilePrinterData is forward |
87 | // declared so we don't have access to that information here. |
88 | int64 profile_counters_size = 0; |
89 | }; |
90 | |
91 | // AllocMode controls the buffer allocation mode. |
92 | enum class AllocMode { |
93 | // Allocate all buffers - args, results, profile and temps. |
94 | ARGS_RESULTS_PROFILES_AND_TEMPS, |
95 | |
96 | // Only allocate result, profile and temp buffers. |
97 | // Use set_arg_data to set argument buffers before Run is called. |
98 | RESULTS_PROFILES_AND_TEMPS_ONLY, |
99 | }; |
100 | |
101 | XlaCompiledCpuFunction( |
102 | const StaticData& static_data, |
103 | AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); |
104 | virtual ~XlaCompiledCpuFunction(); |
105 | |
106 | XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; |
107 | XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete; |
108 | |
109 | // Sets the intra-op thread pool used to run individual ops concurrently. |
110 | void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { |
111 | run_options_.set_intra_op_thread_pool(pool); |
112 | } |
113 | |
114 | // Runs the computation, with inputs read from arg buffers, and outputs |
115 | // written to result buffers. Returns true on success and false on failure. |
116 | bool Run() { |
117 | raw_function_(temps_[result_index_], &run_options_, |
118 | const_cast<const void**>(args_), temps_, profile_counters_); |
119 | return true; |
120 | } |
121 | |
122 | // Returns the error message from the previous failed Run call. |
123 | // |
124 | // TODO(fschneider): For now this always returns an empty string because there |
125 | // is no support for error reporting in XLA. Remove this once all callers are |
126 | // updated. |
127 | string error_msg() const { return {}; } |
128 | |
129 | // ------------------------------ |
130 | // Arg methods for managing input buffers. Buffers are in row-major order. |
131 | |
132 | // Returns the underlying array of argument buffers, where args()[I] is the |
133 | // buffer for the positional argument at index I. |
134 | void** args() { return args_; } |
135 | const void* const* args() const { return args_; } |
136 | |
137 | // Returns the buffer for the positional argument at the given `index`. |
138 | void* arg_data(size_t index) { return args_[index]; } |
139 | const void* arg_data(size_t index) const { return args_[index]; } |
140 | |
141 | // Sets the buffer for the positional argument at the given `index` to `data`. |
142 | // Must be called before Run to have an effect. May be called under any |
143 | // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be |
144 | // called for each positional argument, in order to set the argument buffers. |
145 | // |
146 | // Allocated memory must be aligned to the size specified by |
147 | // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in |
148 | // tensorflow/compiler/aot/runtime.h to ensure correct alignment. |
149 | // |
150 | // Aliasing of argument and result buffers is not allowed, and results in |
151 | // undefined behavior. |
152 | void set_arg_data(size_t index, void* data) { args_[index] = data; } |
153 | |
154 | // ------------------------------ |
155 | // Result methods for managing output buffers. Buffers are in row-major order. |
156 | // Must only be called after a successful Run call. Unlike the arg methods, |
157 | // there is no set_resultN_data method. The result buffers are managed |
158 | // internally, and may change after each call to Run. |
159 | |
160 | // Returns the underlying array of result buffers, where results()[I] is the |
161 | // buffer for the positional result at index I. |
162 | void** results() { return static_cast<void**>(temps_[result_index_]); } |
163 | const void* const* results() const { |
164 | return static_cast<const void* const*>(temps_[result_index_]); |
165 | } |
166 | |
167 | // Profile counters for this XLA computation. |
168 | // |
169 | // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in |
170 | // this case) these counters are non-null and are automatically populated by |
171 | // `Run`. The counters can then be pretty-printed using |
172 | // `hlo_profile_printer()`. |
173 | // |
174 | // When Hlo profiling is disabled, this accessor returns null. |
175 | const int64* profile_counters() const { return profile_counters_; } |
176 | |
177 | // Returns the buffer for the positional result at the given `index`. |
178 | void* result_data(size_t index) { return results()[index]; } |
179 | const void* result_data(size_t index) const { return results()[index]; } |
180 | |
181 | // ------------------------------ |
182 | // Methods for extracting optional metadata. |
183 | |
184 | // Returns true iff data is available for the Lookup{Arg,Result}Index methods. |
185 | // E.g. the data might not be compiled into the binary for AOT. |
186 | bool HasNameIndices() const { |
187 | return arg_names_ != nullptr && result_names_ != nullptr; |
188 | } |
189 | |
190 | // Returns the 0-based index for the argument with the given `name`. |
191 | // Returns -1 if the name wasn't found, or data isn't available. |
192 | // |
193 | // The index remains constant for every instance of XlaCompiledCpuFunction |
194 | // generated from the same static data, and might not be cheap to determine. |
195 | // Recommended usage is to capture this in a variable for re-use. |
196 | int LookupArgIndex(const string& name) const; |
197 | |
198 | // Returns the 0-based index for the result with the given `name`. |
199 | // Returns -1 if the name wasn't found, or data isn't available. |
200 | // |
201 | // The index remains constant for every instance of XlaCompiledCpuFunction |
202 | // generated from the same static data, and might not be cheap to determine. |
203 | // Recommended usage is to capture this in a variable for re-use. |
204 | int LookupResultIndex(const string& name) const; |
205 | |
206 | // Returns the shape of the args and results. May return nullptr if the |
207 | // program shape isn't available. |
208 | const xla::ProgramShape* ProgramShape() const { return program_shape_; } |
209 | |
210 | bool hlo_profiling_enabled() const { |
211 | return hlo_profile_printer_data_ != nullptr; |
212 | } |
213 | const xla::HloProfilePrinterData& hlo_profile_printer_data() const { |
214 | assert(hlo_profiling_enabled()); |
215 | return *hlo_profile_printer_data_; |
216 | } |
217 | |
218 | private: |
219 | const RawFunction raw_function_; |
220 | const size_t result_index_; |
221 | |
222 | // Arrays of argument and temp buffers; entries in args_ may be overwritten by |
223 | // the user. |
224 | void** args_ = nullptr; |
225 | void** temps_ = nullptr; |
226 | |
227 | // Backing memory for individual arg and temp buffers. |
228 | void* alloc_args_ = nullptr; |
229 | void* alloc_temps_ = nullptr; |
230 | |
231 | // Backing memory for profiling counters. |
232 | int64* profile_counters_ = nullptr; |
233 | |
234 | // Options and context passed to the compiled function. |
235 | xla::ExecutableRunOptions run_options_; |
236 | |
237 | // Optional metadata. |
238 | const char** arg_names_ = nullptr; |
239 | const char** result_names_ = nullptr; |
240 | const xla::ProgramShape* program_shape_ = nullptr; |
241 | const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; |
242 | }; |
243 | |
244 | } // namespace tensorflow |
245 | |
246 | #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ |
247 | |