Move shared validation for StridedSliceOp to a separate header, in preparation
for calling it from the shape fn. Change: 132732667
parent af02c57ade
commit f51e1964b5
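The motivation is reuse from shape inference: once the validation returns a Status and lives in tensorflow/core/util, a C++ shape function can call it whenever begin/end/strides are graph constants. A minimal sketch of that intended call pattern follows (illustrative only, not part of this commit; the name StridedSliceShapeFn and the unknown-shape fallback are assumptions):

Status StridedSliceShapeFn(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle input = c->input(0);
  const Tensor* begin = c->input_tensor(1);
  const Tensor* end = c->input_tensor(2);
  const Tensor* strides = c->input_tensor(3);
  // Without constant begin/end/strides and a fully known input shape we
  // cannot run the shared validation; fall back to an unknown output shape.
  if (begin == nullptr || end == nullptr || strides == nullptr ||
      !c->FullyDefined(input)) {
    return shape_inference::UnknownShape(c);
  }
  TensorShape input_shape;
  for (int i = 0; i < c->Rank(input); ++i) {
    input_shape.AddDim(c->Value(c->Dim(input, i)));
  }
  int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(c->GetAttr("begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(c->GetAttr("end_mask", &end_mask));
  TF_RETURN_IF_ERROR(c->GetAttr("ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(c->GetAttr("new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(c->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  TensorShape processing_shape, final_shape;
  bool is_identity, is_simple_slice, slice_dim0;
  gtl::InlinedVector<int64, 4> b, e, s;
  TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
      *begin, *end, *strides, input_shape, begin_mask, end_mask, ellipsis_mask,
      new_axis_mask, shrink_axis_mask, &processing_shape, &final_shape,
      &is_identity, &is_simple_slice, &slice_dim0, &b, &e, &s));
  shape_inference::ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(final_shape, &out));
  c->set_output(0, out);
  return Status::OK();
}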
tensorflow/core/BUILD
@@ -287,6 +287,7 @@ tf_cuda_library(
         "util/sparse/sparse_tensor.h",
         "util/stat_summarizer.h",
         "util/stream_executor_util.h",
+        "util/strided_slice_op.h",
         "util/tensor_format.h",
         "util/tensor_slice_reader.h",
         "util/tensor_slice_reader_cache.h",
tensorflow/core/kernels/strided_slice_op.cc
@@ -35,286 +35,10 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/util/strided_slice_op.h"
 
 namespace tensorflow {
 
-namespace {
-
-/// Constants
-constexpr int32 kShrinkAxis = -1, kNewAxis = -2;
-
-// Sparse slicing specification
-// if one does foo[3:5, ..., -3], the begin, end, and strides tensors
-// will each have length 3
-struct StridedSliceSparseSpec {
-  int64 dims;
-  int32 num_add_axis_after_ellipsis;
-  const Tensor& begin_tensor;
-  const Tensor& end_tensor;
-  const Tensor& strides_tensor;
-  const int32 begin_mask, end_mask;
-  int32 ellipsis_mask;
-  const int32 new_axis_mask, shrink_axis_mask;
-};
-
-// Dense slicing specification
-// all ellipses and newaxis entries are expanded out. So for
-// foo[3:5, ..., -3] where foo is 10-dimensional,
-// each InlinedVector will have 10 entries, whereas the
-// sparse spec had length-3 tensors.
-struct StridedSliceDenseSpec {
-  const int64 dims;
-  int32 begin_mask;
-  int32 end_mask;
-  gtl::InlinedVector<int64, 4>& begin;
-  gtl::InlinedVector<int64, 4>& end;
-  gtl::InlinedVector<int64, 4>& strides;
-  // This vector helps construct the final shape of the slice.
-  // The final tensor is reduced in rank whenever a single index e.g. foo[3]
-  // is called for. The final tensor increases in rank with tf.newaxis
-  // entries. If an index in this array is non-negative, the size of the
-  // dimension is obtained from canonical end-begin. Otherwise, if it is
-  // kNewAxis, it will be 1. A shrunk dimension is skipped.
-  gtl::InlinedVector<int32, 4> final_shape_gather_indices;
-  // The dense indexed shrink mask says which processing dimensions
-  // should be shrunk. For example, if foo.shape = (10,10,10,10)
-  // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
-  // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
-  int32 shrink_axis_mask;
-};
-
-}  // namespace
-
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
-
-template <class T>
-static void BuildDenseSpec(const StridedSliceSparseSpec& sparse,
-                           StridedSliceDenseSpec* dense) {
-  // Build expanded begin, end, strides, begin_mask, end_mask
-  // to remove any ellipsis
-  dense->begin.resize(dense->dims);
-  dense->end.resize(dense->dims);
-  dense->strides.resize(dense->dims);
-  // What indices to get the final shape from.
-  dense->begin_mask = 0;
-  dense->end_mask = 0;
-  dense->shrink_axis_mask = 0;
-  {
-    int full_index = 0;
-
-    const auto& begin_flat = sparse.begin_tensor.flat<T>();
-    const auto& end_flat = sparse.end_tensor.flat<T>();
-    const auto& strides_flat = sparse.strides_tensor.flat<T>();
-
-    for (int i = 0; i < sparse.dims; i++) {
-      if ((1 << i) & sparse.ellipsis_mask) {
-        // Expand the ellipsis into the appropriate indices
-        // NOTE: this only works because we have guaranteed exactly one
-        // ellipsis
-        int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
-                                        sparse.num_add_axis_after_ellipsis,
-                                    dense->dims);
-        for (; full_index < next_index; full_index++) {
-          // new_axis entries aren't real axes, so you have to skip them
-          dense->begin[full_index] = dense->end[full_index] = 0;
-          dense->strides[full_index] = 1;
-          dense->begin_mask |= (1 << full_index);
-          dense->end_mask |= (1 << full_index);
-          dense->final_shape_gather_indices.push_back(full_index);
-        }
-      } else if ((1 << i) & sparse.new_axis_mask) {
-        dense->final_shape_gather_indices.push_back(kNewAxis);
-      } else {
-        // Gather slicing spec into appropriate index
-        dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i));
-        dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i));
-        dense->strides[full_index] =
-            internal::SubtleMustCopy<T>(strides_flat(i));
-        if (sparse.begin_mask & (1 << i)) {
-          dense->begin_mask |= (1 << full_index);
-        }
-        if (sparse.end_mask & (1 << i)) {
-          dense->end_mask |= (1 << full_index);
-        }
-        // If shrink, record where to get the dimensionality from (i.e.
-        // new_axis creates a fake one-size dimension). Also remember the
-        // shrink axis (now in dense form) so we can ignore dense->end below.
-        if (sparse.shrink_axis_mask & (1 << i)) {
-          dense->final_shape_gather_indices.push_back(kShrinkAxis);
-          dense->shrink_axis_mask |= (1 << full_index);
-        } else {
-          dense->final_shape_gather_indices.push_back(full_index);
-        }
-        full_index++;
-      }
-    }
-  }
-}
-
-// Shared code that is not dependent on the type of T. We do this to reduce
-// code size by not duplicating all this for all T (float, double, int32, etc.)
-static void SharedValidation(
-    OpKernelContext* context, const TensorShape& input_shape,
-    int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
-    int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
-    TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
-    bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
-    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
-  const Tensor& begin_tensor = context->input(1);
-  const Tensor& end_tensor = context->input(2);
-  const Tensor& strides_tensor = context->input(3);
-  OP_REQUIRES(
-      context, TensorShapeUtils::IsVector(begin_tensor.shape()) &&
-                   TensorShapeUtils::IsVector(end_tensor.shape()) &&
-                   TensorShapeUtils::IsVector(strides_tensor.shape()) &&
-                   strides_tensor.dims() == 1 &&
-                   strides_tensor.dims() == begin_tensor.dims() &&
-                   strides_tensor.dims() == end_tensor.dims() &&
-                   begin_tensor.dim_size(0) == end_tensor.dim_size(0) &&
-                   begin_tensor.dim_size(0) == strides_tensor.dim_size(0) &&
-                   begin_tensor.dim_size(0) < 32,  // using 32 bit masks
-      errors::InvalidArgument(
-          "Expected begin, end, and strides to be 1D equal size tensors, ",
-          "but got shapes ", begin_tensor.shape().DebugString(), ", ",
-          end_tensor.shape().DebugString(), ", and ",
-          strides_tensor.shape().DebugString(), " instead."));
-  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2
-  // i.e. there is at most one ellipsis
-  OP_REQUIRES(context,
-              !ellipsis_mask || (ellipsis_mask & (ellipsis_mask - 1)) == 0,
-              errors::InvalidArgument("Multiple ellipsis' in slice "
-                                      "spec not allowed"));
-
-  // Step 1: Account for ellipsis and new axis
-  //
-  // Check for an ellipsis and count how many newaxis entries come after it
-  // TODO(aselle): Convert this to do a fast log2 followed by iteration
-  // counting ones in the remaining positions
-  bool ellipsis_seen = false;
-
-  StridedSliceSparseSpec sparse_spec = {begin_tensor.NumElements(),
-                                        0,
-                                        begin_tensor,
-                                        end_tensor,
-                                        strides_tensor,
-                                        begin_mask_spec,
-                                        end_mask_spec,
-                                        ellipsis_mask,
-                                        new_axis_mask,
-                                        shrink_axis_mask};
-
-  for (int32 i = 0; i < sparse_spec.dims; i++) {
-    if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
-      sparse_spec.num_add_axis_after_ellipsis++;
-    }
-    if ((1 << i) & ellipsis_mask) {
-      ellipsis_seen = true;
-    }
-  }
-  // If no ellipsis was given, insert one at the end
-  if (!ellipsis_seen) {
-    sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
-    sparse_spec.dims++;  // this affects loop iteration below
-  }
-
-  // Step 2: Make the sparse spec into a full index spec
-  //
-  // The sparse spec does not correspond to the number of dimensions;
-  // make a dense spec that corresponds to the number of dimensions.
-  //
-  // For example, suppose foo[...,3:] on foo.shape=(2,2,3); then
-  // we need to produce the missing begin_mask for the first two
-  // dimensions, i.e. from begin_mask_spec=0, end_mask_spec=2
-  // we achieve begin_mask=6, end_mask=7
-  StridedSliceDenseSpec dense_spec = {
-      input_shape.dims(), 0, 0, *begin, *end, *strides};
-
-  if (begin_tensor.dtype() == DT_INT32) {
-    BuildDenseSpec<int32>(sparse_spec, &dense_spec);
-  } else if (begin_tensor.dtype() == DT_INT64) {
-    BuildDenseSpec<int64>(sparse_spec, &dense_spec);
-  } else {
-    LOG(FATAL) << "begin must be either int32 or int64";
-  }
-
-  // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
-  // and bounds check!
-  *is_identity = true;
-  *slice_dim0 = true;
-  *is_simple_slice = true;
-  for (int i = 0; i < dense_spec.dims; ++i) {
-    int64& begin_i = (*begin)[i];
-    int64& end_i = (*end)[i];
-    int64& stride_i = (*strides)[i];
-    int64 dim_i = input_shape.dim_size(i);
-    OP_REQUIRES(context, stride_i != 0,
-                errors::InvalidArgument("strides[", i, "] must be non-zero"));
-
-    int64 masks[] = {dense_spec.begin_mask & (1 << i),
-                     dense_spec.end_mask & (1 << i)};
-    int64 valid_range[] = {stride_i > 0 ? 0 : -1,
-                           stride_i > 0 ? dim_i : dim_i - 1};
-
-    auto canonical = [stride_i, i, dim_i, masks, valid_range](int64 x, int c) {
-      if (masks[c]) {
-        return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
-      } else {
-        int64 x_fwd = x < 0 ? dim_i + x : x;  // make negative indices positive
-        return x_fwd < valid_range[0]
-                   ? valid_range[0]
-                   : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
-      }
-    };
-    if (dense_spec.shrink_axis_mask & (1 << i)) {
-      // If we are shrinking, the end index is now possibly incorrect. In
-      // particular foo[-1] produces sparse_begin = -1, sparse_end = 0,
-      // and canonical puts these to n-1 and 0, which implies a degenerate
-      // interval. Fortunately, it is now safe to re-create end as begin + 1.
-      int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
-      begin_i = x_fwd;
-      end_i = begin_i + 1;
-      OP_REQUIRES(context, stride_i > 0,
-                  errors::InvalidArgument("only stride 1 allowed on"
-                                          " non-range indexing."));
-      OP_REQUIRES(
-          context, x_fwd >= 0 && x_fwd < dim_i,
-          errors::InvalidArgument("slice index ", begin_i, " of dimension ", i,
-                                  " out of bounds."));
-    } else {
-      begin_i = canonical(begin_i, 0);
-      end_i = canonical(end_i, 1);
-    }
-    // Update optimization values
-    (*is_simple_slice) &= stride_i == 1;
-    bool take_all_in_dimension =
-        stride_i == 1 && begin_i == 0 && end_i == input_shape.dim_size(i);
-    (*is_identity) &= take_all_in_dimension;
-    (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
-
-    // Compute the processing shape (the intermediate shape Eigen will produce)
-    int64 interval_length = end_i - begin_i;
-    int64 size_i;
-    // Hold zero if the interval is degenerate, otherwise account for remainder
-    if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0)))
-      size_i = 0;
-    else
-      size_i = interval_length / stride_i +
-               (interval_length % stride_i != 0 ? 1 : 0);
-    processing_shape->AddDim(size_i);
-  }
-
-  // Step 4: Compute the final shape
-  //
-  // new_axis will increase dimension by 1 (with a one-size dimension);
-  // slices like foo[3,...] will reduce dimension by 1.
-  // This cannot be done earlier, because it depends on Step 3.
-  for (auto gather_index : dense_spec.final_shape_gather_indices) {
-    if (gather_index >= 0)
-      final_shape->AddDim(processing_shape->dim_size(gather_index));
-    else if (gather_index == kNewAxis)
-      final_shape->AddDim(1);
-  }
-}
-
 template <typename Device, typename T>
 class StridedSliceOp : public OpKernel {
  public:
@@ -336,11 +60,13 @@ class StridedSliceOp : public OpKernel {
     gtl::InlinedVector<int64, 4> end;
     gtl::InlinedVector<int64, 4> strides;
 
-    SharedValidation(context, context->input(0).shape(), begin_mask, end_mask,
-                     ellipsis_mask, new_axis_mask, shrink_axis_mask,
-                     &processing_shape, &final_shape, &is_identity,
-                     &is_simple_slice, &slice_dim0, &begin, &end, &strides);
-    if (!context->status().ok()) return;
+    OP_REQUIRES_OK(context,
+                   ValidateStridedSliceOp(
+                       context->input(1), context->input(2), context->input(3),
+                       context->input(0).shape(), begin_mask, end_mask,
+                       ellipsis_mask, new_axis_mask, shrink_axis_mask,
+                       &processing_shape, &final_shape, &is_identity,
+                       &is_simple_slice, &slice_dim0, &begin, &end, &strides));
 
     const Tensor& input = context->input(0);
 
@@ -460,10 +186,13 @@ class StridedSliceGradOp : public OpKernel {
       LOG(FATAL) << "shape must have type int32 or int64.";
     }
 
-    SharedValidation(context, input_shape, begin_mask, end_mask, ellipsis_mask,
-                     new_axis_mask, shrink_axis_mask, &processing_shape,
-                     &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
-                     &begin, &end, &strides);
+    OP_REQUIRES_OK(
+        context,
+        ValidateStridedSliceOp(
+            context->input(1), context->input(2), context->input(3),
+            input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
+            shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
+            &is_simple_slice, &slice_dim0, &begin, &end, &strides));
 
     // Check to make sure dy is consistent with the original slice
     TensorShape dy_shape = context->input(4).shape();
@@ -527,11 +256,13 @@ class StridedSliceAssignOp : public OpKernel {
     context->forward_ref_input_to_ref_output(0, 0);
     Tensor old_lhs = context->mutable_input(0, true);
 
-    SharedValidation(context, old_lhs.shape(), begin_mask, end_mask,
-                     ellipsis_mask, new_axis_mask, shrink_axis_mask,
-                     &processing_shape, &final_shape, &is_identity,
-                     &is_simple_slice, &slice_dim0, &begin, &end, &strides);
-    if (!context->status().ok()) return;
+    OP_REQUIRES_OK(
+        context,
+        ValidateStridedSliceOp(
+            context->input(1), context->input(2), context->input(3),
+            old_lhs.shape(), begin_mask, end_mask, ellipsis_mask,
+            new_axis_mask, shrink_axis_mask, &processing_shape, &final_shape,
+            &is_identity, &is_simple_slice, &slice_dim0, &begin, &end,
+            &strides));
 
     if (processing_shape.num_elements()) {
       const Tensor& input = context->input(4);
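The call sites above also pick up a cleaner error-handling style: the old SharedValidation reported failures by mutating the context, which forced the explicit "if (!context->status().ok()) return;" checks, while ValidateStridedSliceOp returns a Status. OP_REQUIRES_OK behaves roughly like the following hand-written check (an illustrative expansion, not the literal macro definition):

Status s = ValidateStridedSliceOp(
    context->input(1), context->input(2), context->input(3),
    context->input(0).shape(), begin_mask, end_mask, ellipsis_mask,
    new_axis_mask, shrink_axis_mask, &processing_shape, &final_shape,
    &is_identity, &is_simple_slice, &slice_dim0, &begin, &end, &strides);
if (!s.ok()) {
  context->SetStatus(s);  // record the error on the OpKernelContext
  return;                 // abort Compute()
}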
tensorflow/core/util/strided_slice_op.cc (new file, 296 lines)
@@ -0,0 +1,296 @@
/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/strided_slice_op.h"

#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

namespace {

/// Constants
constexpr int32 kShrinkAxis = -1, kNewAxis = -2;

// Sparse slicing specification
// if one does foo[3:5, ..., -3], the begin, end, and strides tensors
// will each have length 3
struct StridedSliceSparseSpec {
  int64 dims;
  int32 num_add_axis_after_ellipsis;
  const Tensor& begin_tensor;
  const Tensor& end_tensor;
  const Tensor& strides_tensor;
  const int32 begin_mask, end_mask;
  int32 ellipsis_mask;
  const int32 new_axis_mask, shrink_axis_mask;
};

// Dense slicing specification
// all ellipses and newaxis entries are expanded out. So for
// foo[3:5, ..., -3] where foo is 10-dimensional,
// each InlinedVector will have 10 entries, whereas the
// sparse spec had length-3 tensors.
struct StridedSliceDenseSpec {
  const int64 dims;
  int32 begin_mask;
  int32 end_mask;
  gtl::InlinedVector<int64, 4>& begin;
  gtl::InlinedVector<int64, 4>& end;
  gtl::InlinedVector<int64, 4>& strides;
  // This vector helps construct the final shape of the slice.
  // The final tensor is reduced in rank whenever a single index e.g. foo[3]
  // is called for. The final tensor increases in rank with tf.newaxis
  // entries. If an index in this array is non-negative, the size of the
  // dimension is obtained from canonical end-begin. Otherwise, if it is
  // kNewAxis, it will be 1. A shrunk dimension is skipped.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices;
  // The dense indexed shrink mask says which processing dimensions
  // should be shrunk. For example, if foo.shape = (10,10,10,10)
  // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
  // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
  int32 shrink_axis_mask;
};

}  // namespace

template <class T>
static void BuildDenseSpec(const StridedSliceSparseSpec& sparse,
                           StridedSliceDenseSpec* dense) {
  // Build expanded begin, end, strides, begin_mask, end_mask
  // to remove any ellipsis
  dense->begin.resize(dense->dims);
  dense->end.resize(dense->dims);
  dense->strides.resize(dense->dims);
  // What indices to get the final shape from.
  dense->begin_mask = 0;
  dense->end_mask = 0;
  dense->shrink_axis_mask = 0;
  {
    int full_index = 0;

    const auto& begin_flat = sparse.begin_tensor.flat<T>();
    const auto& end_flat = sparse.end_tensor.flat<T>();
    const auto& strides_flat = sparse.strides_tensor.flat<T>();

    for (int i = 0; i < sparse.dims; i++) {
      if ((1 << i) & sparse.ellipsis_mask) {
        // Expand the ellipsis into the appropriate indices
        // NOTE: this only works because we have guaranteed exactly one
        // ellipsis
        int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
                                        sparse.num_add_axis_after_ellipsis,
                                    dense->dims);
        for (; full_index < next_index; full_index++) {
          // new_axis entries aren't real axes, so you have to skip them
          dense->begin[full_index] = dense->end[full_index] = 0;
          dense->strides[full_index] = 1;
          dense->begin_mask |= (1 << full_index);
          dense->end_mask |= (1 << full_index);
          dense->final_shape_gather_indices.push_back(full_index);
        }
      } else if ((1 << i) & sparse.new_axis_mask) {
        dense->final_shape_gather_indices.push_back(kNewAxis);
      } else {
        // Gather slicing spec into appropriate index
        dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i));
        dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i));
        dense->strides[full_index] =
            internal::SubtleMustCopy<T>(strides_flat(i));
        if (sparse.begin_mask & (1 << i)) {
          dense->begin_mask |= (1 << full_index);
        }
        if (sparse.end_mask & (1 << i)) {
          dense->end_mask |= (1 << full_index);
        }
        // If shrink, record where to get the dimensionality from (i.e.
        // new_axis creates a fake one-size dimension). Also remember the
        // shrink axis (now in dense form) so we can ignore dense->end below.
        if (sparse.shrink_axis_mask & (1 << i)) {
          dense->final_shape_gather_indices.push_back(kShrinkAxis);
          dense->shrink_axis_mask |= (1 << full_index);
        } else {
          dense->final_shape_gather_indices.push_back(full_index);
        }
        full_index++;
      }
    }
  }
}

Status ValidateStridedSliceOp(
    const Tensor& begin_tensor, const Tensor& end_tensor,
    const Tensor& strides_tensor, const TensorShape& input_shape,
    int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
    int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
    TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
    bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
  if (!(TensorShapeUtils::IsVector(begin_tensor.shape()) &&
        TensorShapeUtils::IsVector(end_tensor.shape()) &&
        TensorShapeUtils::IsVector(strides_tensor.shape()) &&
        strides_tensor.dims() == 1 &&
        strides_tensor.dims() == begin_tensor.dims() &&
        strides_tensor.dims() == end_tensor.dims() &&
        begin_tensor.dim_size(0) == end_tensor.dim_size(0) &&
        begin_tensor.dim_size(0) == strides_tensor.dim_size(0) &&
        begin_tensor.dim_size(0) < 32 /* using 32 bit masks */)) {
    return errors::InvalidArgument(
        "Expected begin, end, and strides to be 1D equal size tensors, ",
        "but got shapes ", begin_tensor.shape().DebugString(), ", ",
        end_tensor.shape().DebugString(), ", and ",
        strides_tensor.shape().DebugString(), " instead.");
  }
  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2
  // i.e. there is at most one ellipsis
  if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
    return errors::InvalidArgument(
        "Multiple ellipsis' in slice spec not allowed");
  }

  // Step 1: Account for ellipsis and new axis
  //
  // Check for an ellipsis and count how many newaxis entries come after it
  // TODO(aselle): Convert this to do a fast log2 followed by iteration
  // counting ones in the remaining positions
  bool ellipsis_seen = false;

  StridedSliceSparseSpec sparse_spec = {begin_tensor.NumElements(),
                                        0,
                                        begin_tensor,
                                        end_tensor,
                                        strides_tensor,
                                        begin_mask_spec,
                                        end_mask_spec,
                                        ellipsis_mask,
                                        new_axis_mask,
                                        shrink_axis_mask};

  for (int32 i = 0; i < sparse_spec.dims; i++) {
    if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
      sparse_spec.num_add_axis_after_ellipsis++;
    }
    if ((1 << i) & ellipsis_mask) {
      ellipsis_seen = true;
    }
  }
  // If no ellipsis was given, insert one at the end
  if (!ellipsis_seen) {
    sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
    sparse_spec.dims++;  // this affects loop iteration below
  }

  // Step 2: Make the sparse spec into a full index spec
  //
  // The sparse spec does not correspond to the number of dimensions;
  // make a dense spec that corresponds to the number of dimensions.
  //
  // For example, suppose foo[...,3:] on foo.shape=(2,2,3); then
  // we need to produce the missing begin_mask for the first two
  // dimensions, i.e. from begin_mask_spec=0, end_mask_spec=2
  // we achieve begin_mask=6, end_mask=7
  StridedSliceDenseSpec dense_spec = {
      input_shape.dims(), 0, 0, *begin, *end, *strides};

  if (begin_tensor.dtype() == DT_INT32) {
    BuildDenseSpec<int32>(sparse_spec, &dense_spec);
  } else if (begin_tensor.dtype() == DT_INT64) {
    BuildDenseSpec<int64>(sparse_spec, &dense_spec);
  } else {
    LOG(FATAL) << "begin must be either int32 or int64";
  }

  // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
  // and bounds check!
  *is_identity = true;
  *slice_dim0 = true;
  *is_simple_slice = true;
  for (int i = 0; i < dense_spec.dims; ++i) {
    int64& begin_i = (*begin)[i];
    int64& end_i = (*end)[i];
    int64& stride_i = (*strides)[i];
    int64 dim_i = input_shape.dim_size(i);
    if (stride_i == 0) {
      return errors::InvalidArgument("strides[", i, "] must be non-zero");
    }

    int64 masks[] = {dense_spec.begin_mask & (1 << i),
                     dense_spec.end_mask & (1 << i)};
    int64 valid_range[] = {stride_i > 0 ? 0 : -1,
                           stride_i > 0 ? dim_i : dim_i - 1};

    auto canonical = [stride_i, i, dim_i, masks, valid_range](int64 x, int c) {
      if (masks[c]) {
        return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
      } else {
        int64 x_fwd = x < 0 ? dim_i + x : x;  // make negative indices positive
        return x_fwd < valid_range[0]
                   ? valid_range[0]
                   : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
      }
    };
    if (dense_spec.shrink_axis_mask & (1 << i)) {
      // If we are shrinking, the end index is now possibly incorrect. In
      // particular foo[-1] produces sparse_begin = -1, sparse_end = 0,
      // and canonical puts these to n-1 and 0, which implies a degenerate
      // interval. Fortunately, it is now safe to re-create end as begin + 1.
      int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
      begin_i = x_fwd;
      end_i = begin_i + 1;
      if (stride_i <= 0) {
        return errors::InvalidArgument(
            "only stride 1 allowed on non-range indexing.");
      }
      if (x_fwd < 0 || x_fwd >= dim_i) {
        return errors::InvalidArgument("slice index ", begin_i,
                                       " of dimension ", i, " out of bounds.");
      }
    } else {
      begin_i = canonical(begin_i, 0);
      end_i = canonical(end_i, 1);
    }
    // Update optimization values
    (*is_simple_slice) &= stride_i == 1;
    bool take_all_in_dimension =
        stride_i == 1 && begin_i == 0 && end_i == input_shape.dim_size(i);
    (*is_identity) &= take_all_in_dimension;
    (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;

    // Compute the processing shape (the intermediate shape Eigen will produce)
    int64 interval_length = end_i - begin_i;
    int64 size_i;
    // Hold zero if the interval is degenerate, otherwise account for remainder
    if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0)))
      size_i = 0;
    else
      size_i = interval_length / stride_i +
               (interval_length % stride_i != 0 ? 1 : 0);
    processing_shape->AddDim(size_i);
  }

  // Step 4: Compute the final shape
  //
  // new_axis will increase dimension by 1 (with a one-size dimension);
  // slices like foo[3,...] will reduce dimension by 1.
  // This cannot be done earlier, because it depends on Step 3.
  for (auto gather_index : dense_spec.final_shape_gather_indices) {
    if (gather_index >= 0)
      final_shape->AddDim(processing_shape->dim_size(gather_index));
    else if (gather_index == kNewAxis)
      final_shape->AddDim(1);
  }
  return Status::OK();
}

}  // namespace tensorflow
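To make the ellipsis expansion in BuildDenseSpec concrete, here is a worked illustration (not part of the commit) of the example from the comments: foo[3, ..., 5] on a tensor of shape (10,10,10,10). The sparse spec has three entries and the ellipsis is expanded over dimensions 1 and 2:

// sparse spec (3 entries): begin = [3, _, 5] (the ellipsis entry is ignored),
//                          ellipsis_mask = 0b010, shrink_axis_mask = 0b101 (0x5)
// dense spec  (4 entries): begin = [3, 0, 0, 5], strides = [1, 1, 1, 1],
//                          begin_mask = end_mask = 0b0110 (the dims "..." covers),
//                          shrink_axis_mask = 0b1001 (0x9)
// final_shape_gather_indices = {kShrinkAxis, 1, 2, kShrinkAxis}
// => processing shape (1, 10, 10, 1), final shape (10, 10)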
tensorflow/core/util/strided_slice_op.h (new file, 42 lines)
@@ -0,0 +1,42 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace tensorflow {

// Runs validation on the strided slice op parameters.
//
// This is in a separate translation unit from the kernel so that:
// 1. The op's shape function can use it.
// 2. The code size is reduced vs templating this on the kernel's type.
Status ValidateStridedSliceOp(
    const Tensor& begin_tensor, const Tensor& end_tensor,
    const Tensor& strides_tensor, const TensorShape& input_shape,
    int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
    int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
    TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
    bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides);

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
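For completeness, a self-contained usage sketch of the new entry point (hypothetical test-style code, not part of the commit): validate slicing a (2,3) tensor as t[0:2, 1:3] with unit strides and no masks.

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/strided_slice_op.h"

namespace tensorflow {

void ValidateStridedSliceExample() {
  Tensor begin(DT_INT32, TensorShape({2}));
  begin.vec<int32>()(0) = 0;
  begin.vec<int32>()(1) = 1;
  Tensor end(DT_INT32, TensorShape({2}));
  end.vec<int32>()(0) = 2;
  end.vec<int32>()(1) = 3;
  Tensor strides(DT_INT32, TensorShape({2}));
  strides.vec<int32>()(0) = 1;
  strides.vec<int32>()(1) = 1;

  TensorShape processing_shape, final_shape;
  bool is_identity, is_simple_slice, slice_dim0;
  gtl::InlinedVector<int64, 4> b, e, s;
  TF_CHECK_OK(ValidateStridedSliceOp(
      begin, end, strides, TensorShape({2, 3}), /*begin_mask_spec=*/0,
      /*end_mask_spec=*/0, /*ellipsis_mask=*/0, /*new_axis_mask=*/0,
      /*shrink_axis_mask=*/0, &processing_shape, &final_shape, &is_identity,
      &is_simple_slice, &slice_dim0, &b, &e, &s));
  // processing_shape and final_shape are both (2, 2); is_simple_slice is
  // true, while is_identity is false since dimension 1 is only partly taken.
}

}  // namespace tensorflow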