Support more cases in strided slice ops.

- When slice ranges (begin/end) are dynamic, we can lower strided slice ops into XLA dynamic slices (a sketch follows the commit metadata below).
- Optimize XLA builder dynamism inference so that it does not materialize the parameter tuple.
- Export some helper data structures from strided_slice_op.

PiperOrigin-RevId: 356839276
Change-Id: I3b5d9fcb1289496029afbf97e49df49df26ce699
Yunxing Dai 2021-02-10 15:23:45 -08:00 committed by TensorFlower Gardener
parent dcaa3c9ed9
commit 41b3e84fa9
4 changed files with 242 additions and 126 deletions
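Referenced from the first bullet above: a minimal, hand-written sketch (not part of this change) of the dynamic-slice lowering, using the XLA client builder API. It assumes a rank-1 F32 input of static size 10 and S32 begin/end indices supplied at runtime; the helper name LowerDynamicSlice1D and the fixed sizes are illustrative only.

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp LowerDynamicSlice1D(xla::XlaBuilder* b, xla::XlaOp x,
                               xla::XlaOp begin, xla::XlaOp end) {
  constexpr int64_t kDimSize = 10;  // Assumed static input size.
  xla::XlaOp zero = xla::Zero(b, xla::S32);
  xla::XlaOp dim_size = xla::ConstantR0<int32_t>(b, kDimSize);
  // Wrap negative indices around the dimension size.
  begin = xla::Select(xla::Lt(begin, zero), xla::Add(begin, dim_size), begin);
  end = xla::Select(xla::Lt(end, zero), xla::Add(end, dim_size), end);
  // Pad the input to 2x so a full-sized slice starting at `begin` stays in
  // bounds. Assumes an F32 operand for the padding value.
  xla::PaddingConfig padding;
  auto* dim = padding.add_dimensions();
  dim->set_edge_padding_low(0);
  dim->set_interior_padding(0);
  dim->set_edge_padding_high(kDimSize);
  x = xla::Pad(x, xla::Zero(b, xla::F32), padding);
  // Slice the full static size, then mark the real size as dynamic.
  xla::XlaOp slice = xla::DynamicSlice(x, {begin}, {kDimSize});
  xla::XlaOp size = xla::Max(xla::Sub(end, begin), zero);
  return xla::SetDimensionSize(slice, size, /*dimension=*/0);
}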


@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
@@ -52,6 +53,145 @@ class StridedSliceOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
}
void EmitDynamicSlice(XlaOpKernelContext* ctx,
const absl::InlinedVector<int64, 4>& strides,
TensorShape processing_shape, TensorShape final_shape,
PartialTensorShape partial_processing_shape,
PartialTensorShape partial_final_shape,
const StridedSliceShapeSpec& shape_spec,
const std::vector<bool>& begins_are_dynamic,
const std::vector<bool>& ends_are_dynamic) {
const TensorShape input_shape = ctx->InputShape(0);
xla::XlaOp slice = ctx->Input(0);
for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
OP_REQUIRES(ctx, strides[i] == 1,
errors::InvalidArgument(
"Strides have to be one when inputs are not constant."));
}
// Infer the static output shape, reconciling unknown dimensions with the
// input dimension sizes.
for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
if (partial_final_shape.dim_size(i) == -1) {
// Fall back to the corresponding input dimension size (via shape_spec).
partial_final_shape.set_dim(
i,
input_shape.dim_size(shape_spec.output_to_processing_mapping[i]));
}
}
OP_REQUIRES(
ctx, partial_final_shape.AsTensorShape(&final_shape),
InvalidArgument("XLA can't deduce compile time constant output "
"shape for strided slice: ",
partial_final_shape.DebugString(),
", output shape must be a compile-time constant"));
for (int64 i = 0; i < partial_processing_shape.dims(); ++i) {
if (partial_processing_shape.dim_size(i) == -1) {
// Fall back to the input dimension size.
partial_processing_shape.set_dim(i, input_shape.dim_size(i));
}
}
OP_REQUIRES(
ctx, partial_processing_shape.AsTensorShape(&processing_shape),
InvalidArgument("XLA can't deduce compile time constant processing "
"shape for strided slice: ",
partial_processing_shape.DebugString(),
", processing shape must be a compile-time constant"));
// When inputs are not compile-time constants, shape inference can only
// infer a size-1 slice.
std::vector<int64> slice_sizes(input_shape.dims(), 1);
// If there is dynamic begin/end (and if the dimension is not shrunk), we
// need to use dynamic shape infrastructure -- we slice the output with
// full size, then call SetDimensionSize on the output. However, if we
// slice with the full size at a non-zero dimension we may get OOB access.
// To avoid that, we first pad the input to 2x before calling slice.
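// For example, with an input dimension of size 10 and a runtime begin of 7,
// slicing the full size 10 would touch indices 7..16; padding the dimension
// to 20 first keeps that read in bounds, and SetDimensionSize below trims
// the result to end - begin elements.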
xla::PaddingConfig padding_config;
bool need_padding = false;
std::vector<bool> result_dims_are_dynamic;
for (int64 i = 0; i < input_shape.dims(); ++i) {
int64 sparse_index = shape_spec.processing_to_sparse_mapping[i];
bool shrink_axis_set = (1 << i) & shape_spec.shrink_axis_dense_mask;
auto* dims = padding_config.add_dimensions();
dims->set_edge_padding_low(0);
dims->set_interior_padding(0);
if ((begins_are_dynamic[sparse_index] ||
ends_are_dynamic[sparse_index]) &&
!shrink_axis_set) {
// Need to slice this dimension so pad first.
dims->set_edge_padding_high(input_shape.dim_size(i));
need_padding = true;
result_dims_are_dynamic.push_back(true);
} else {
dims->set_edge_padding_high(0);
result_dims_are_dynamic.push_back(false);
}
}
if (need_padding) {
// Pad input to 2x to avoid OOB access.
slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
padding_config);
}
std::vector<xla::XlaOp> start_indices;
std::vector<xla::XlaOp> slice_sizes_dynamic;
xla::Shape input_xla_shape = ctx->InputXlaShape(0).ValueOrDie();
for (int64 i = 0; i < input_shape.dims(); ++i) {
bool begin_mask = (1 << i) & shape_spec.begin_dense_mask;
bool end_mask = (1 << i) & shape_spec.end_dense_mask;
auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin"));
xla::XlaOp begin_index, end_index;
int64 sparse_index = shape_spec.processing_to_sparse_mapping[i];
bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i);
xla::XlaOp dim_size;
if (xla_input_is_dynamic) {
dim_size = xla::GetDimensionSize(ctx->Input(0), i);
OP_REQUIRES(ctx, ctx->InputXlaType("begin") == xla::S32,
errors::InvalidArgument("'begin shape has to be int32 when "
"indices to slice op are dynamic"));
} else {
dim_size =
xla::ConstantR0WithType(ctx->builder(), ctx->InputXlaType("begin"),
input_xla_shape.dimensions(i));
}
if (begin_mask) {
begin_index = zero;
} else {
begin_index = xla::Slice(ctx->Input("begin"), {sparse_index},
{sparse_index + 1}, {1});
begin_index = xla::Reshape(begin_index, {});
auto index_negative = xla::Lt(begin_index, zero);
auto wrapped_index = xla::Add(dim_size, begin_index);
// Wrap negative indices around.
begin_index = xla::Select(index_negative, wrapped_index, begin_index);
}
start_indices.push_back(begin_index);
if (end_mask) {
end_index = dim_size;
} else {
end_index = xla::Slice(ctx->Input("end"), {sparse_index},
{sparse_index + 1}, {1});
end_index = xla::Reshape(end_index, {});
auto index_negative = xla::Lt(end_index, zero);
auto wrapped_index = xla::Add(dim_size, end_index);
end_index = xla::Select(index_negative, wrapped_index, end_index);
}
slice_sizes_dynamic.push_back(
xla::Max(xla::Sub(end_index, begin_index), zero));
}
slice =
xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes());
for (int64 i = 0; i < input_shape.dims(); ++i) {
if (result_dims_are_dynamic[i]) {
slice = xla::SetDimensionSize(slice, slice_sizes_dynamic[i], i);
}
}
slice = xla::Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape begin_shape = ctx->InputShape("begin");
@@ -80,31 +220,33 @@ class StridedSliceOp : public XlaOpKernel {
}
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
&strides_tensor));
TensorShape final_shape;
PartialTensorShape dummy_processing_shape, partial_final_shape;
TensorShape processing_shape, final_shape;
PartialTensorShape partial_processing_shape, partial_final_shape;
bool dummy = false;
absl::InlinedVector<int64, 4> output_to_sparse_mapping;
absl::InlinedVector<int64, 4> output_to_processing_mapping;
StridedSliceShapeSpec shape_spec;
OP_REQUIRES_OK(
ctx,
ValidateStridedSliceOp(
begin_is_constant ? &begin_tensor : nullptr,
end_is_constant ? &end_tensor : nullptr, strides_tensor,
input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
&dummy, &dummy, &dummy, &begin, &end, &strides,
&output_to_sparse_mapping, &output_to_processing_mapping));
OP_REQUIRES(
ctx, partial_final_shape.AsTensorShape(&final_shape),
InvalidArgument("XLA can't deduce compile time constant output "
"shape for strided slice: ",
partial_final_shape.DebugString(),
", output shape must be a compile-time constant"));
shrink_axis_mask_, &partial_processing_shape, &partial_final_shape,
&dummy, &dummy, &dummy, &begin, &end, &strides, &shape_spec));
xla::XlaOp slice = ctx->Input(0);
std::vector<bool> begins_are_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
std::vector<bool> ends_are_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
if (begin_is_constant && end_is_constant) {
OP_REQUIRES(
ctx, partial_final_shape.AsTensorShape(&final_shape),
InvalidArgument("XLA can't deduce compile time constant output "
"shape for strided slice: ",
partial_final_shape.DebugString(),
", output shape must be a compile-time constant"));
absl::InlinedVector<int64, 4> dimensions_to_reverse;
absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
for (int i = 0; i < begin.size(); ++i) {
@@ -129,12 +271,7 @@ class StridedSliceOp : public XlaOpKernel {
auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
OP_REQUIRES_OK(ctx, operand_shape_or.status());
xla::Shape xla_shape = operand_shape_or.ValueOrDie();
std::vector<bool> begins_are_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
std::vector<bool> ends_are_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
bool begins_are_static = absl::c_all_of(
begins_are_dynamic, [](bool dynamic) { return !dynamic; });
OP_REQUIRES(ctx, begins_are_static,
@@ -150,13 +287,13 @@ class StridedSliceOp : public XlaOpKernel {
}
for (int64 i = 0; i < final_shape.dims(); ++i) {
int64 input_index = output_to_processing_mapping[i];
int64 input_index = shape_spec.output_to_processing_mapping[i];
if (input_index == -1) {
continue;
}
bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);
int64 sparse_index = output_to_sparse_mapping[i];
int64 sparse_index = shape_spec.output_to_sparse_mapping[i];
bool end_is_dynamic =
sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
bool backward_slice = sparse_index == -1
@@ -208,62 +345,9 @@ class StridedSliceOp : public XlaOpKernel {
ctx->SetOutput(0, slice);
return;
} else {
// When output shape is fully defined, it must be a size one slice:
//
// 1. The number of output elements has to be equal to the number of input
// elements that are sliced.
// 2. The stride of the slice dimensions must be exact one.
int64 output_elements = final_shape.num_elements();
int64 input_elements_sliced = 1;
int64 slicing_dim_size = begin_shape.dim_size(0);
// We only support slicing major dimensions, so minor dimensions after
// slicing dimension are all sliced with their full sizes.
for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
input_elements_sliced *= input_shape.dim_size(d);
}
OP_REQUIRES(ctx, output_elements == input_elements_sliced,
errors::InvalidArgument(
"Dynamic indices of strided_slice_op have to be leading "
"dimensions in the indices list."));
for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
OP_REQUIRES(
ctx, strides[i] == 1,
errors::InvalidArgument(
"Strides have to be one when inputs are not constant."));
}
// When inputs are not compile time constants, shape inference can only
// inference size 1 slice.
std::vector<int64> slice_sizes(slicing_dim_size, 1);
std::vector<xla::XlaOp> start_indices;
auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin"));
for (int64 d = 0; d < slicing_dim_size; ++d) {
auto index = xla::Slice(ctx->Input("begin"), {d}, {d + 1}, {1});
// Convert index to scalar.
index = xla::Reshape(index, {});
// Negative index: wrap it around with dimension size.
auto index_negative = xla::Lt(index, zero);
auto dim_size = xla::ConvertElementType(
xla::ConstantR0<int32>(ctx->builder(), input_shape.dim_size(d)),
ctx->InputXlaType("begin"));
auto wrapped_index = xla::Add(dim_size, index);
index = xla::Select(index_negative, wrapped_index, index);
start_indices.push_back(index);
}
for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
// For non-slice dims, naturally we get the full slice starting from 0.
slice_sizes.push_back(input_shape.dim_size(d));
start_indices.push_back(zero);
}
std::vector<int64> output_shape_dim_sizes;
slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
slice = xla::Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
EmitDynamicSlice(ctx, strides, processing_shape, final_shape,
partial_processing_shape, partial_final_shape,
shape_spec, begins_are_dynamic, ends_are_dynamic);
}
}
@@ -308,10 +392,7 @@ class StridedSliceGradOp : public XlaOpKernel {
absl::InlinedVector<int64, 4> begin;
absl::InlinedVector<int64, 4> end;
absl::InlinedVector<int64, 4> strides;
absl::InlinedVector<int64, 4> output_to_sparse_mapping;
absl::InlinedVector<int64, 4> output_to_processing_mapping;
StridedSliceShapeSpec shape_spec;
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
&strides_tensor));
OP_REQUIRES_OK(
@@ -319,8 +400,7 @@ class StridedSliceGradOp : public XlaOpKernel {
nullptr, nullptr, strides_tensor, input_shape, begin_mask_,
end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
&processing_shape, &final_shape, &dummy, &dummy, &dummy,
&begin, &end, &strides, &output_to_sparse_mapping,
&output_to_processing_mapping));
&begin, &end, &strides, &shape_spec));
for (int64 i = 0; i < processing_shape.dims(); ++i) {
OP_REQUIRES(
ctx, strides[i] == 1,
@@ -341,20 +421,20 @@ class StridedSliceGradOp : public XlaOpKernel {
// Use grad shape, which is known, to update unknown processing shape.
// Grad shape is the output of the ValidateStridedSliceOp function in
// forward pass, thus we use output_to_processing_mapping.
if (output_to_processing_mapping[i] != -1) {
processing_shape.set_dim(output_to_processing_mapping[i],
if (shape_spec.output_to_processing_mapping[i] != -1) {
processing_shape.set_dim(shape_spec.output_to_processing_mapping[i],
grad_shape.dimensions(i));
}
// Similarly, use output_to_sparse_mapping to find out corresponding
// begin dim of the output, as indices for dynamic update slice.
int64 begin_dim = output_to_sparse_mapping[i];
int64 begin_dim = shape_spec.output_to_sparse_mapping[i];
if (begin_dim != -1) {
auto begin_index =
xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1});
auto begin_index_scalar = xla::Reshape(
xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index);
begins[output_to_sparse_mapping[i]] = begin_index_scalar;
begins[shape_spec.output_to_sparse_mapping[i]] = begin_index_scalar;
}
}
VLOG(1) << "processing_shape" << processing_shape.DebugString();
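To make the mapping-driven updates in the hunk above concrete, here is a hand-worked example (not from this change, shown only as a reading aid) for a forward slice foo[2:5] on an input of shape [10]:

// grad_shape = [3]                          (shape of the incoming gradient)
// shape_spec.output_to_processing_mapping = {0}
//   -> processing_shape dim 0 is set to grad_shape.dimensions(0) == 3
// shape_spec.output_to_sparse_mapping     = {0}
//   -> begins[0] is read at runtime from element 0 of the begin tensor and
//      used as the start index of the dynamic update slice.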


@@ -3546,13 +3546,25 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
}
case HloOpcode::kTuple:
case HloOpcode::kTranspose:
case HloOpcode::kGetTupleElement:
case HloOpcode::kSlice:
case HloOpcode::kBroadcast:
case HloOpcode::kConcatenate:
case HloOpcode::kReshape:
case HloOpcode::kPad:
break;
case HloOpcode::kGetTupleElement: {
// Rewrite a parameter followed by a GetTupleElement into a constant to
// avoid materializing the tuple parameter (which could be very large).
int64 operand_handle = instr_proto->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
LookUpInstructionByHandle(operand_handle));
TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
StringToHloOpcode(operand_proto->opcode()));
if (operand_opcode == HloOpcode::kParameter) {
SetInstructionAsConstant(new_instr, id, new_shape, true);
}
break;
}
case HloOpcode::kGetDimensionSize: {
int64 dimension = instr_proto->dimensions(0);
int64 operand_handle = instr_proto->operand_ids(0);
@@ -3646,6 +3658,18 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
should_visit_operand = false;
}
if (opcode == HloOpcode::kGetTupleElement) {
int64 operand_handle = instr_proto->operand_ids(0);
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
LookUpInstructionByHandle(operand_handle));
TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
StringToHloOpcode(operand_proto->opcode()));
if (operand_opcode == HloOpcode::kParameter) {
// Don't rematerialize the whole parameter if it's followed by a GTE.
should_visit_operand = false;
}
}
if (opcode == HloOpcode::kSelect) {
TF_ASSIGN_OR_RETURN(bool constant_predicate,
operand_is_constant(instr_proto, 0));
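For illustration, a minimal builder-level sketch (not from this change; the function and instruction names are arbitrary) of the graph pattern both hunks above special-case, namely a GetTupleElement reading out of a tuple-shaped parameter:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void BuildGteOfParameterExample() {
  xla::XlaBuilder b("gte_of_parameter");
  xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {1024, 1024}),
       xla::ShapeUtil::MakeShape(xla::S32, {})});
  xla::XlaOp p = xla::Parameter(&b, 0, tuple_shape, "p");
  xla::XlaOp size = xla::GetTupleElement(p, 1);
  // When BuildDynamicInferenceGraph walks a graph that uses `size`, the GTE
  // of a parameter is now treated as a constant leaf instead of pulling the
  // whole (potentially huge) tuple parameter into the inference graph.
  (void)size;
}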


@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/core/util/strided_slice_op.h"
#include <array>
#include <iterator>
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"
@@ -64,6 +66,7 @@ struct StridedSliceDenseSpec {
// index. A -1 in this vector means the index is not from the sparse
// input.
gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
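// input_shape_gather_indices_sparse[d] is the index in the sparse
// (begin/end) inputs that dense/processing dimension d came from.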
gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
// The dense indexed shrink mask indicates which processing dimensions
// should be shrunk. For example, if foo.shape = (10,10,10,10)
// foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
@@ -81,6 +84,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
dense->begin.resize(dense->dims);
dense->end.resize(dense->dims);
dense->strides.resize(dense->dims);
dense->input_shape_gather_indices_sparse.resize(dense->dims);
// What indices to get the final shape from.
dense->begin_mask = 0;
dense->end_mask = 0;
@@ -114,6 +118,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
dense->end_mask |= (1 << full_index);
dense->final_shape_gather_indices.push_back(full_index);
dense->final_shape_gather_indices_sparse.push_back(-1);
dense->input_shape_gather_indices_sparse[full_index] = i;
}
} else if ((1 << i) & sparse.new_axis_mask) {
dense->final_shape_gather_indices.push_back(kNewAxis);
@@ -153,6 +158,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
// from.
dense->final_shape_gather_indices_sparse.push_back(i);
}
dense->input_shape_gather_indices_sparse[full_index] = i;
full_index++;
}
}
@@ -168,9 +174,7 @@ Status ValidateStridedSliceOp(
PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
gtl::InlinedVector<int64, 4>* strides,
gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
gtl::InlinedVector<int64, 4>* strides, StridedSliceShapeSpec* shape_spec) {
const bool begin_is_wrong =
begin_tensor != nullptr &&
!(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
@@ -375,13 +379,18 @@ Status ValidateStridedSliceOp(
// slices like foo[3,...] will reduce dimension by 1.
// This cannot be done earlier, because it depends on Step 3.
final_shape->Clear();
if (output_to_sparse_mapping != nullptr) {
output_to_sparse_mapping->clear();
if (shape_spec != nullptr) {
shape_spec->output_to_sparse_mapping.clear();
shape_spec->output_to_processing_mapping.clear();
shape_spec->processing_to_sparse_mapping.assign(
dense_spec.input_shape_gather_indices_sparse.begin(),
dense_spec.input_shape_gather_indices_sparse.end());
shape_spec->begin_dense_mask = dense_spec.begin_mask;
shape_spec->end_dense_mask = dense_spec.end_mask;
shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask;
}
if (output_to_processing_mapping != nullptr) {
output_to_processing_mapping->clear();
}
for (int64 dense_dim = 0;
dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
@@ -389,22 +398,19 @@ Status ValidateStridedSliceOp(
dense_spec.final_shape_gather_indices_sparse[dense_dim];
if (gather_index >= 0) {
final_shape->AddDim(processing_shape->dim_size(gather_index));
if (output_to_sparse_mapping != nullptr) {
output_to_sparse_mapping->push_back(sparse_index);
}
if (output_to_processing_mapping != nullptr) {
output_to_processing_mapping->push_back(gather_index);
if (shape_spec != nullptr) {
shape_spec->output_to_sparse_mapping.push_back(sparse_index);
shape_spec->output_to_processing_mapping.push_back(gather_index);
}
} else if (gather_index == kNewAxis) {
final_shape->AddDim(1);
if (output_to_sparse_mapping != nullptr) {
output_to_sparse_mapping->push_back(-1);
}
if (output_to_processing_mapping != nullptr) {
output_to_processing_mapping->push_back(-1);
if (shape_spec != nullptr) {
shape_spec->output_to_sparse_mapping.push_back(-1);
shape_spec->output_to_processing_mapping.push_back(-1);
}
}
}
return Status::OK();
}
@@ -416,16 +422,14 @@ Status ValidateStridedSliceOp(
TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
StridedSliceShapeSpec* shape_spec) {
// Validate with PartialTensorShape output
PartialTensorShape partial_processing_shape, partial_final_shape;
TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
&partial_processing_shape, &partial_final_shape, is_identity,
is_simple_slice, slice_dim0, begin, end, strides,
output_to_sparse_mapping, output_to_processing_mapping));
is_simple_slice, slice_dim0, begin, end, strides, shape_spec));
// Verify that the output shapes are fully known
if (!partial_processing_shape.AsTensorShape(processing_shape) ||


@@ -23,6 +26,26 @@ limitations under the License.
namespace tensorflow {
struct StridedSliceShapeSpec {
// Begin mask canonicalized in dense form.
int32 begin_dense_mask;
// End mask canonicalized in dense form.
int32 end_dense_mask;
// Shrink axis mask canonicalized in dense form.
int32 shrink_axis_dense_mask;
// output_to_sparse_mapping[i] represents output[i]'s corresponding dim
// index in the begin_tensor. If output_to_sparse_mapping[i] is -1, the
// dimension does not come from the sparse (begin/end) inputs.
gtl::InlinedVector<int64, 4> output_to_sparse_mapping;
// output_to_processing_mapping is similar to output_to_sparse_mapping, but
// for the processing shape.
gtl::InlinedVector<int64, 4> output_to_processing_mapping;
// processing_to_sparse_mapping[i] represents input_shape[i]'s corresponding
// dim index in the begin_tensor.
gtl::InlinedVector<int64, 4> processing_to_sparse_mapping;
};
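A hand-worked example (not from this change, shown only as a reading aid) of how these fields line up for the slice foo[1, :, 2:5] on an input of shape [6, 7, 8], where dimension 0 is shrunk:

// processing shape                  = [1, 7, 3]
// final (output) shape              = [7, 3]
// output_to_processing_mapping      = {1, 2}
// output_to_sparse_mapping          = {1, 2}
// processing_to_sparse_mapping      = {0, 1, 2}
// begin_dense_mask = end_dense_mask = 0b010   (the ':' on dimension 1)
// shrink_axis_dense_mask            = 0b001   (dimension 0 is shrunk)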
// Runs validation on the strided slice op parameters.
//
// Is a separate translation unit from the kernel so that:
@@ -41,16 +61,6 @@ namespace tensorflow {
// (-1). Any validation that can be done without complete information is
// performed.
//
// This function changes the orders of dimensions, output_to_sparse_mapping and
// output_to_processing_mapping are used to track the order change.
//
// output_to_sparse_mapping[i] represents output[i]'s the corresponding dim
// index in the begin_tensor. If
// output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up in
// sparse_mapping.
//
// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
// processing_shape.
Status ValidateStridedSliceOp(
const Tensor* begin_tensor, const Tensor* end_tensor,
const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -60,8 +70,7 @@ Status ValidateStridedSliceOp(
bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
gtl::InlinedVector<int64, 4>* strides,
gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
StridedSliceShapeSpec* shape_spec = nullptr);
// Same as above, but the outputs are TensorShape, not PartialTensorShape
Status ValidateStridedSliceOp(
@@ -72,8 +81,7 @@ Status ValidateStridedSliceOp(
TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
StridedSliceShapeSpec* shape_spec = nullptr);
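For reference, a caller-side sketch (not from this change; assumes the tensorflow namespace and the usual framework includes, with the strides tensor, input shape, and mask values supplied by the caller) showing how the new shape_spec out-parameter is passed, modeled on the call sites updated in strided_slice_op.cc above:

Status RunValidation(const Tensor& strides_tensor,
                     const PartialTensorShape& input_shape, int32 begin_mask,
                     int32 end_mask, int32 ellipsis_mask, int32 new_axis_mask,
                     int32 shrink_axis_mask) {
  PartialTensorShape processing_shape, final_shape;
  bool is_identity, is_simple_slice, slice_dim0;
  gtl::InlinedVector<int64, 4> begin, end, strides;
  StridedSliceShapeSpec shape_spec;
  TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
      /*begin_tensor=*/nullptr, /*end_tensor=*/nullptr, strides_tensor,
      input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
      shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
      &is_simple_slice, &slice_dim0, &begin, &end, &strides, &shape_spec));
  // shape_spec.output_to_processing_mapping etc. are now populated.
  return Status::OK();
}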
} // namespace tensorflow