1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/strided_slice_op.h" |
17 | |
18 | #include <array> |
19 | #include "tensorflow/core/kernels/bounds_check.h" |
20 | #include "tensorflow/core/lib/core/status.h" |
21 | |
22 | namespace tensorflow { |
23 | namespace { |
24 | |
/// Constants
// Sentinel values stored in StridedSliceDenseSpec::final_shape_gather_indices.
// A non-negative entry means "take that dimension of the processing shape";
// kShrinkAxis marks a dimension removed by single-index (shrink) slicing;
// kNewAxis marks a size-1 dimension inserted by tf.newaxis.
constexpr int32 kShrinkAxis = -1, kNewAxis = -2;
27 | |
// Sparse slicing specification
// if one does foo[3:5, ..., -3], this will have 3 length tensors
struct StridedSliceSparseSpec {
  // Number of entries in the sparse spec (the common length of the
  // begin/end/strides tensors, possibly +1 after an implicit trailing
  // ellipsis is appended by ValidateStridedSliceOp).
  int64 dims;
  // How many tf.newaxis entries appear after the ellipsis; needed to know
  // how many real input dimensions the ellipsis expands to.
  int32 num_add_axis_after_ellipsis;
  // begin/end may be nullptr when their values are unknown (e.g. during
  // shape inference when the tensors are not constants).
  const Tensor* begin_tensor;
  const Tensor* end_tensor;
  // Always present; its length defines `dims` and its dtype (int32/int64)
  // is the index dtype shared with begin/end.
  const Tensor& strides_tensor;
  // Bit i set means the begin/end of sparse dimension i is implicit
  // (i.e. the user wrote foo[:n] or foo[n:]).
  const int32 begin_mask, end_mask;
  // Bit i set means sparse dimension i is an ellipsis. Mutable because an
  // implicit trailing ellipsis may be added after construction.
  int32 ellipsis_mask;
  // Bit i set means sparse dimension i is a tf.newaxis / a single "shrink"
  // index (e.g. foo[3]) respectively.
  const int32 new_axis_mask, shrink_axis_mask;
};
40 | |
// Dense slicing specification
// all ellipses and newaxis' are expanded out. So if
// foo[3:5, ..., -3] where foo is 10 dimensional,
// each inlinedVector will have 10 entries whereas the
// sparse had 3 length tensors.
struct StridedSliceDenseSpec {
  // Rank of the input tensor being sliced (one entry below per input dim).
  const int64 dims;
  // Bit i set means begin[i]/end[i] is implicit for dense dimension i
  // (rebuilt from the sparse masks by BuildDenseSpec).
  int32 begin_mask;
  int32 end_mask;
  // Whether begin/end values were actually provided; false when the
  // corresponding tensor was nullptr (values then must not be read).
  bool begin_valid;
  bool end_valid;
  // References to caller-owned vectors that receive the densified
  // per-dimension begin/end/stride values.
  gtl::InlinedVector<int64, 4>& begin;
  gtl::InlinedVector<int64, 4>& end;
  gtl::InlinedVector<int64, 4>& strides;
  // This vector helps construct the final shape of the slice.
  // The final tensor is reduced in rank whenever a single index e.g. foo[3]
  // is called for. The final tensor increases in rank with tf.newaxis
  // entries. If an index in this array is positive, the size of the dimension
  // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
  // it will be 1. A shrunk dimension is skipped.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices;
  // The dense indexed shrink mask is which processing dimensions
  // should be shrunk. For example, if foo.shape = (10,10,10,10)
  // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
  // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
  int32 shrink_axis_mask;
};
68 | |
69 | } // namespace |
70 | |
// Expands a sparse slicing spec (one entry per index expression the user
// wrote, e.g. 3 entries for foo[3:5, ..., -3]) into a dense spec with one
// entry per input dimension. T is the index dtype of the
// begin/end/strides tensors (int32 or int64).
//
// On success dense->begin/end/strides each hold dense->dims entries, the
// dense masks are rebuilt in terms of dense dimension numbers, and
// dense->final_shape_gather_indices records how to assemble the final
// (post-shrink/post-newaxis) shape. Returns InvalidArgument if the sparse
// spec addresses more dimensions than the input has.
template <class T>
static Status TF_MUST_USE_RESULT BuildDenseSpec(
    const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
  // Build expanded begin, end, strides, begin_mask, end_mask
  // to remove any ellipsis
  dense->begin.resize(dense->dims);
  dense->end.resize(dense->dims);
  dense->strides.resize(dense->dims);
  // What indices to get the final shape from.
  dense->begin_mask = 0;
  dense->end_mask = 0;
  dense->shrink_axis_mask = 0;
  {
    // Next dense (input) dimension to be filled in.
    int full_index = 0;

    const auto& strides_flat = sparse.strides_tensor.flat<T>();
    dense->begin_valid = sparse.begin_tensor != nullptr;
    dense->end_valid = sparse.end_tensor != nullptr;

    for (int i = 0; i < sparse.dims; i++) {
      if ((1 << i) & sparse.ellipsis_mask) {
        // Expand the ellipsis into the appropriate indices
        // NOTE: this only works because we guaranteed one ellipsis
        int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
                                        sparse.num_add_axis_after_ellipsis,
                                    dense->dims);
        // Every dimension covered by the ellipsis is a full range
        // (begin=0, end implicit via masks, stride 1).
        for (; full_index < next_index; full_index++) {
          // new_axis' aren't real axis so you have to skip
          dense->begin[full_index] = dense->end[full_index] = 0;
          dense->strides[full_index] = 1;
          dense->begin_mask |= (1 << full_index);
          dense->end_mask |= (1 << full_index);
          dense->final_shape_gather_indices.push_back(full_index);
        }
      } else if ((1 << i) & sparse.new_axis_mask) {
        // tf.newaxis consumes no input dimension; just record the marker
        // so a size-1 dim is inserted in the final shape.
        dense->final_shape_gather_indices.push_back(kNewAxis);
      } else {
        if (full_index == dense->begin.size()) {
          return errors::InvalidArgument("Index out of range using input dim ",
                                         full_index, "; input has only ",
                                         dense->dims, " dims");
        }

        // Gather slicing spec into appropriate index
        if (sparse.begin_tensor != nullptr) {
          const auto& begin_flat = sparse.begin_tensor->flat<T>();
          dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i));
        }
        if (sparse.end_tensor != nullptr) {
          const auto& end_flat = sparse.end_tensor->flat<T>();
          dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i));
        }
        dense->strides[full_index] =
            internal::SubtleMustCopy<T>(strides_flat(i));
        // Translate sparse mask bits (indexed by i) into dense mask bits
        // (indexed by full_index).
        if (sparse.begin_mask & (1 << i)) {
          dense->begin_mask |= (1 << full_index);
        }
        if (sparse.end_mask & (1 << i)) {
          dense->end_mask |= (1 << full_index);
        }
        // If shrink, record where to get the dimensionality from (i.e.
        // new_axis creates a fake 1 size dimension. Also remember shrink
        // axis (now in dense form) so we can ignore dense->end below.
        if (sparse.shrink_axis_mask & (1 << i)) {
          dense->final_shape_gather_indices.push_back(kShrinkAxis);
          dense->shrink_axis_mask |= (1 << full_index);
        } else {
          dense->final_shape_gather_indices.push_back(full_index);
        }
        full_index++;
      }
    }
  }
  return Status::OK();
}
146 | |
147 | Status ValidateStridedSliceOp( |
148 | const Tensor* begin_tensor, const Tensor* end_tensor, |
149 | const Tensor& strides_tensor, const PartialTensorShape& input_shape, |
150 | int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask, |
151 | int32 new_axis_mask, int32 shrink_axis_mask, |
152 | PartialTensorShape* processing_shape, PartialTensorShape* final_shape, |
153 | bool* is_identity, bool* is_simple_slice, bool* slice_dim0, |
154 | gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end, |
155 | gtl::InlinedVector<int64, 4>* strides) { |
156 | const bool begin_is_wrong = |
157 | begin_tensor != nullptr && |
158 | !(TensorShapeUtils::IsVector(begin_tensor->shape()) && |
159 | begin_tensor->NumElements() == strides_tensor.NumElements() && |
160 | begin_tensor->NumElements() < 32 /* using 32 bit masks */); |
161 | const bool end_is_wrong = |
162 | end_tensor != nullptr && |
163 | !(TensorShapeUtils::IsVector(end_tensor->shape()) && |
164 | end_tensor->NumElements() == strides_tensor.NumElements()); |
165 | if (begin_is_wrong || end_is_wrong || |
166 | !TensorShapeUtils::IsVector(strides_tensor.shape())) { |
167 | if (begin_tensor != nullptr && end_tensor != nullptr) { |
168 | return errors::InvalidArgument( |
169 | "Expected begin, end, and strides to be 1D equal size tensors, " , |
170 | "but got shapes " , begin_tensor->shape().DebugString(), ", " , |
171 | end_tensor->shape().DebugString(), ", and " , |
172 | strides_tensor.shape().DebugString(), " instead." ); |
173 | } else { |
174 | return errors::InvalidArgument( |
175 | "Expected begin, end, and strides to be 1D equal size tensors, " , |
176 | "but got shape " , strides_tensor.shape().DebugString(), |
177 | " for strides." ); |
178 | } |
179 | } |
180 | // Use bit compares to ensure ellipsis_mask is 0 or a power of 2 |
181 | // i.e. there exists only no more than one ellipsis |
182 | if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) { |
183 | return errors::InvalidArgument( |
184 | "Multiple ellipses in slice spec not allowed" ); |
185 | } |
186 | |
187 | // Step 1: Account for ellipsis and new axis |
188 | // |
189 | // Check for ellipses and count how many non-newaxis' there are after |
190 | // TODO(aselle): Convert this to do a fast log2 followed by iteration |
191 | // counting ones in next guys |
192 | bool ellipsis_seen = false; |
193 | |
194 | StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(), |
195 | 0, |
196 | begin_tensor, |
197 | end_tensor, |
198 | strides_tensor, |
199 | begin_mask_spec, |
200 | end_mask_spec, |
201 | ellipsis_mask, |
202 | new_axis_mask, |
203 | shrink_axis_mask}; |
204 | |
205 | for (int32 i = 0; i < sparse_spec.dims; i++) { |
206 | if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { |
207 | sparse_spec.num_add_axis_after_ellipsis++; |
208 | } |
209 | if ((1 << i) & ellipsis_mask) { |
210 | ellipsis_seen = true; |
211 | } |
212 | } |
213 | // If no ellipsis insert one at the end |
214 | if (!ellipsis_seen) { |
215 | sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); |
216 | sparse_spec.dims++; // this effects loop iteration below |
217 | } |
218 | |
219 | // Step 2: Make a sparse spec into a full index spec |
220 | // |
221 | // The sparse spec does not correspond to the number of dimensions |
222 | // Make a dense spec that corresponds to the number of dimensions |
223 | // |
224 | // For example suppose foo[...,3:] on foo.shape=(2,2,3) then |
225 | // we need to produce the missing begin_mask for the first two |
226 | // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2 |
227 | // we achieve begin_mask=6, end_mask=7 |
228 | StridedSliceDenseSpec dense_spec = {input_shape.dims(), |
229 | 0 /* begin_mask */, |
230 | 0 /* end_mask */, |
231 | false /* begin_valid */, |
232 | false /* end_valid */, |
233 | *begin, |
234 | *end, |
235 | *strides}; |
236 | |
237 | if (strides_tensor.dtype() == DT_INT32) { |
238 | TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec)); |
239 | } else if (strides_tensor.dtype() == DT_INT64) { |
240 | TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec)); |
241 | } else { |
242 | LOG(FATAL) << "begin must be either int32 or int64" ; |
243 | } |
244 | |
245 | // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit |
246 | // and bounds check! |
247 | *is_identity = true; |
248 | *slice_dim0 = true; |
249 | *is_simple_slice = true; |
250 | processing_shape->Clear(); |
251 | for (int i = 0; i < input_shape.dims(); ++i) { |
252 | int64& begin_i = (*begin)[i]; |
253 | int64& end_i = (*end)[i]; |
254 | int64& stride_i = (*strides)[i]; |
255 | int64 dim_i = input_shape.dim_size(i); |
256 | if (stride_i == 0) { |
257 | return errors::InvalidArgument("strides[" , i, "] must be non-zero" ); |
258 | } |
259 | bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i)); |
260 | if (dim_i == -1) { |
261 | processing_shape->AddDim(shrink_i ? 1 : -1); |
262 | continue; |
263 | } |
264 | |
265 | const std::array<int64, 2> masks = { |
266 | {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}}; |
267 | const std::array<int64, 2> valid_range = { |
268 | {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}}; |
269 | |
270 | auto canonical = [stride_i, i, dim_i, masks, valid_range](int64 x, int c) { |
271 | if (masks[c]) { |
272 | return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; |
273 | } else { |
274 | int64 x_fwd = x < 0 ? dim_i + x : x; // make negative indices positive |
275 | return x_fwd < valid_range[0] |
276 | ? valid_range[0] |
277 | : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; |
278 | } |
279 | }; |
280 | if (shrink_i && stride_i <= 0) { |
281 | return errors::InvalidArgument( |
282 | "only stride 1 allowed on non-range indexing." ); |
283 | } |
284 | (*is_simple_slice) &= stride_i == 1; |
285 | |
286 | const bool begin_and_end_masked = |
287 | (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i)); |
288 | if (dense_spec.begin_valid && dense_spec.end_valid) { |
289 | if (shrink_i) { |
290 | // If we are shrinking, the end index is now possibly incorrect. In |
291 | // particular foo[-1] produces sparse_begin = -1, sparse_end = 0. |
292 | // and canonical puts these to n-1 and 0, which implies a degenerate |
293 | // interval. Fortunately, it is now safe to re-create end as begin+1. |
294 | int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; |
295 | begin_i = x_fwd; |
296 | end_i = begin_i + 1; |
297 | if (x_fwd < 0 || x_fwd >= dim_i) { |
298 | return errors::InvalidArgument( |
299 | "slice index " , begin_i, " of dimension " , i, " out of bounds." ); |
300 | } |
301 | } else { |
302 | begin_i = canonical(begin_i, 0); |
303 | end_i = canonical(end_i, 1); |
304 | } |
305 | // Update optimization values |
306 | bool take_all_in_dimension = |
307 | stride_i == 1 && begin_i == 0 && end_i == dim_i; |
308 | (*is_identity) &= take_all_in_dimension; |
309 | (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension; |
310 | } else { |
311 | (*is_identity) &= stride_i == 1 && begin_and_end_masked; |
312 | (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked; |
313 | } |
314 | // Compute the processing shape (the intermediate Eigen will produce) |
315 | int64 interval_length; |
316 | bool known_interval = false; |
317 | if (dense_spec.begin_valid && dense_spec.end_valid) { |
318 | interval_length = end_i - begin_i; |
319 | known_interval = true; |
320 | } else if (shrink_i) { |
321 | // The dimension is still known as 1 for the processing_shape, but will be |
322 | // discarded for the final shape. |
323 | interval_length = 1; |
324 | known_interval = true; |
325 | } else if (begin_and_end_masked) { |
326 | // Even if we don't have values for begin or end, we do know that this |
327 | // dimension covers the whole interval. If we have shape information for |
328 | // this dimension, that tells us the interval length. |
329 | if (dim_i > 0) { |
330 | if (stride_i < 0) { |
331 | interval_length = -dim_i; |
332 | } else { |
333 | interval_length = dim_i; |
334 | } |
335 | known_interval = true; |
336 | } |
337 | } |
338 | if (known_interval) { |
339 | int64 size_i; |
340 | // Hold zero if the interval is degenerate, otherwise account for |
341 | // remainder |
342 | if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) { |
343 | size_i = 0; |
344 | } else { |
345 | size_i = interval_length / stride_i + |
346 | (interval_length % stride_i != 0 ? 1 : 0); |
347 | } |
348 | processing_shape->AddDim(size_i); |
349 | } else { |
350 | processing_shape->AddDim(-1); |
351 | } |
352 | } |
353 | |
354 | // Step 4: Compute the final shape |
355 | // |
356 | // new_axis will increase dimension by 1 (with a one-size dimension) |
357 | // slices like foo[3,...] will reduce dimension by 1. |
358 | // This cannot be done earlier, because it depends on Step 3. |
359 | final_shape->Clear(); |
360 | for (auto gather_index : dense_spec.final_shape_gather_indices) { |
361 | if (gather_index >= 0) { |
362 | final_shape->AddDim(processing_shape->dim_size(gather_index)); |
363 | } else if (gather_index == kNewAxis) { |
364 | final_shape->AddDim(1); |
365 | } |
366 | } |
367 | return Status::OK(); |
368 | } |
369 | |
370 | Status ValidateStridedSliceOp( |
371 | const Tensor* begin_tensor, const Tensor* end_tensor, |
372 | const Tensor& strides_tensor, const PartialTensorShape& input_shape, |
373 | int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask, |
374 | int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape, |
375 | TensorShape* final_shape, bool* is_identity, bool* is_simple_slice, |
376 | bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin, |
377 | gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) { |
378 | // Validate with PartialTensorShape output |
379 | PartialTensorShape partial_processing_shape, partial_final_shape; |
380 | TF_RETURN_IF_ERROR(ValidateStridedSliceOp( |
381 | begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec, |
382 | end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask, |
383 | &partial_processing_shape, &partial_final_shape, is_identity, |
384 | is_simple_slice, slice_dim0, begin, end, strides)); |
385 | |
386 | // Verify that the output shapes are fully known |
387 | if (!partial_processing_shape.AsTensorShape(processing_shape) || |
388 | !partial_final_shape.AsTensorShape(final_shape)) { |
389 | return errors::Internal("ValidateStridedSliceOp returned partial shapes " , |
390 | partial_processing_shape.DebugString(), " and " , |
391 | partial_final_shape.DebugString()); |
392 | } |
393 | return Status::OK(); |
394 | } |
395 | |
396 | } // namespace tensorflow |
397 | |