From 41b3e84fa920ba51e7bf321471462d5e6076776f Mon Sep 17 00:00:00 2001
From: Yunxing Dai <yunxing@google.com>
Date: Wed, 10 Feb 2021 15:23:45 -0800
Subject: [PATCH] Support more cases in strided slice ops.

- When ranges are dynamic, we can lower strided slice ops into xla dynamic slices.
- Optimize xla builder dynamism inference to not materialize the parameter tuple.
- Export some helper data structures from strided_slice_op.

PiperOrigin-RevId: 356839276
Change-Id: I3b5d9fcb1289496029afbf97e49df49df26ce699
---
 .../tf2xla/kernels/strided_slice_op.cc        | 258 ++++++++++++------
 tensorflow/compiler/xla/client/xla_builder.cc |  26 +-
 tensorflow/core/util/strided_slice_op.cc      |  48 ++--
 tensorflow/core/util/strided_slice_op.h       |  36 ++-
 4 files changed, 242 insertions(+), 126 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 8bb22c6e7b0..d5e7577862b 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/ops_util.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -52,6 +53,145 @@ class StridedSliceOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
   }
 
+  void EmitDynamicSlice(XlaOpKernelContext* ctx,
+                        const absl::InlinedVector<int64, 4>& strides,
+                        TensorShape processing_shape, TensorShape final_shape,
+                        PartialTensorShape partial_processing_shape,
+                        PartialTensorShape partial_final_shape,
+                        const StridedSliceShapeSpec& shape_spec,
+                        const std::vector<bool>& begins_are_dynamic,
+                        const std::vector<bool>& ends_are_dynamic) {
+    const TensorShape input_shape = ctx->InputShape(0);
+    xla::XlaOp slice = ctx->Input(0);
+    for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
+      OP_REQUIRES(ctx, strides[i] == 1,
+                  errors::InvalidArgument(
+                      "Strides have to be one when inputs are not constant."));
+    }
+    // Infer the static output shape, reconciling unknown dimensions with the
+    // input dimension sizes.
+    for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
+      if (partial_final_shape.dim_size(i) == -1) {
+        // Use the input dimension size indicated by shape_spec.
+        partial_final_shape.set_dim(
+            i,
+            input_shape.dim_size(shape_spec.output_to_processing_mapping[i]));
+      }
+    }
+
+    OP_REQUIRES(
+        ctx, partial_final_shape.AsTensorShape(&final_shape),
+        InvalidArgument("XLA can't deduce compile time constant output "
+                        "shape for strided slice: ",
+                        partial_final_shape.DebugString(),
+                        ", output shape must be a compile-time constant"));
+    for (int64 i = 0; i < partial_processing_shape.dims(); ++i) {
+      if (partial_processing_shape.dim_size(i) == -1) {
+        // Use the input dimension size.
+        partial_processing_shape.set_dim(i, input_shape.dim_size(i));
+      }
+    }
+    OP_REQUIRES(
+        ctx, partial_processing_shape.AsTensorShape(&processing_shape),
+        InvalidArgument("XLA can't deduce compile time constant processing "
+                        "shape for strided slice: ",
+                        partial_processing_shape.DebugString(),
+                        ", processing shape must be a compile-time constant"));
+    // When inputs are not compile-time constants, shape inference can only
+    // infer a size-1 slice.
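+    // The shapes computed above act as static upper bounds: any dimension
+    // whose extent depends on a dynamic begin/end has been widened to the
+    // full input dimension size, and its real size is reattached with
+    // SetDimensionSize further below.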
+ std::vector<int64> slice_sizes(input_shape.dims(), 1); + // If there is dynamic begin/end (and if the dimension is not shrunk), we + // need to use dynamic shape infrastructure -- we slice the output with + // full size, then call SetDimensionSize on the output. However, if we + // slice with the full size at a non-zero dimension we may get OOB access. + // To avoid that, we first pad the input to 2x before calling slice. + xla::PaddingConfig padding_config; + bool need_padding = false; + std::vector<bool> result_dims_are_dynamic; + for (int64 i = 0; i < input_shape.dims(); ++i) { + int64 sparse_index = shape_spec.processing_to_sparse_mapping[i]; + bool shrink_axis_set = (1 << i) & shape_spec.shrink_axis_dense_mask; + auto* dims = padding_config.add_dimensions(); + dims->set_edge_padding_low(0); + + dims->set_interior_padding(0); + if ((begins_are_dynamic[sparse_index] || + ends_are_dynamic[sparse_index]) && + !shrink_axis_set) { + // Need to slice this dimension so pad first. + dims->set_edge_padding_high(input_shape.dim_size(i)); + need_padding = true; + result_dims_are_dynamic.push_back(true); + } else { + dims->set_edge_padding_high(0); + result_dims_are_dynamic.push_back(false); + } + } + + if (need_padding) { + // Pad input to 2x to avoid OOB access. + slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)), + padding_config); + } + std::vector<xla::XlaOp> start_indices; + std::vector<xla::XlaOp> slice_sizes_dynamic; + xla::Shape input_xla_shape = ctx->InputXlaShape(0).ValueOrDie(); + for (int64 i = 0; i < input_shape.dims(); ++i) { + bool begin_mask = (1 << i) & shape_spec.begin_dense_mask; + bool end_mask = (1 << i) & shape_spec.end_dense_mask; + auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin")); + xla::XlaOp begin_index, end_index; + int64 sparse_index = shape_spec.processing_to_sparse_mapping[i]; + bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i); + xla::XlaOp dim_size; + if (xla_input_is_dynamic) { + dim_size = xla::GetDimensionSize(ctx->Input(0), i); + OP_REQUIRES(ctx, ctx->InputXlaType("begin") == xla::S32, + errors::InvalidArgument("'begin shape has to be int32 when " + "indices to slice op are dynamic")); + } else { + dim_size = + xla::ConstantR0WithType(ctx->builder(), ctx->InputXlaType("begin"), + input_xla_shape.dimensions(i)); + } + if (begin_mask) { + begin_index = zero; + } else { + begin_index = xla::Slice(ctx->Input("begin"), {sparse_index}, + {sparse_index + 1}, {1}); + begin_index = xla::Reshape(begin_index, {}); + auto index_negative = xla::Lt(begin_index, zero); + auto wrapped_index = xla::Add(dim_size, begin_index); + // Wrap negative indices around. 
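+        // e.g. with dim_size == 10, a begin index of -2 wraps around to 8.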
+ begin_index = xla::Select(index_negative, wrapped_index, begin_index); + } + start_indices.push_back(begin_index); + if (end_mask) { + end_index = dim_size; + } else { + end_index = xla::Slice(ctx->Input("end"), {sparse_index}, + {sparse_index + 1}, {1}); + end_index = xla::Reshape(end_index, {}); + auto index_negative = xla::Lt(end_index, zero); + auto wrapped_index = xla::Add(dim_size, end_index); + end_index = xla::Select(index_negative, wrapped_index, end_index); + } + slice_sizes_dynamic.push_back( + xla::Max(xla::Sub(end_index, begin_index), zero)); + } + + slice = + xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes()); + + for (int64 i = 0; i < input_shape.dims(); ++i) { + if (result_dims_are_dynamic[i]) { + slice = xla::SetDimensionSize(slice, slice_sizes_dynamic[i], i); + } + } + slice = xla::Reshape(slice, final_shape.dim_sizes()); + ctx->SetOutput(0, slice); + } + void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); const TensorShape begin_shape = ctx->InputShape("begin"); @@ -80,31 +220,33 @@ class StridedSliceOp : public XlaOpKernel { } OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, &strides_tensor)); - - TensorShape final_shape; - PartialTensorShape dummy_processing_shape, partial_final_shape; + TensorShape processing_shape, final_shape; + PartialTensorShape partial_processing_shape, partial_final_shape; bool dummy = false; - absl::InlinedVector<int64, 4> output_to_sparse_mapping; - absl::InlinedVector<int64, 4> output_to_processing_mapping; + StridedSliceShapeSpec shape_spec; OP_REQUIRES_OK( ctx, ValidateStridedSliceOp( begin_is_constant ? &begin_tensor : nullptr, end_is_constant ? &end_tensor : nullptr, strides_tensor, input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, - shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides, - &output_to_sparse_mapping, &output_to_processing_mapping)); - - OP_REQUIRES( - ctx, partial_final_shape.AsTensorShape(&final_shape), - InvalidArgument("XLA can't deduce compile time constant output " - "shape for strided slice: ", - partial_final_shape.DebugString(), - ", output shape must be a compile-time constant")); + shrink_axis_mask_, &partial_processing_shape, &partial_final_shape, + &dummy, &dummy, &dummy, &begin, &end, &strides, &shape_spec)); xla::XlaOp slice = ctx->Input(0); + std::vector<bool> begins_are_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic)); + std::vector<bool> ends_are_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic)); if (begin_is_constant && end_is_constant) { + OP_REQUIRES( + ctx, partial_final_shape.AsTensorShape(&final_shape), + InvalidArgument("XLA can't deduce compile time constant output " + "shape for strided slice: ", + partial_final_shape.DebugString(), + ", output shape must be a compile-time constant")); absl::InlinedVector<int64, 4> dimensions_to_reverse; absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides; for (int i = 0; i < begin.size(); ++i) { @@ -129,12 +271,7 @@ class StridedSliceOp : public XlaOpKernel { auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0)); OP_REQUIRES_OK(ctx, operand_shape_or.status()); xla::Shape xla_shape = operand_shape_or.ValueOrDie(); - std::vector<bool> begins_are_dynamic; - OP_REQUIRES_OK( - ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic)); - std::vector<bool> 
ends_are_dynamic; - OP_REQUIRES_OK( - ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic)); + bool begins_are_static = absl::c_all_of( begins_are_dynamic, [](bool dynamic) { return !dynamic; }); OP_REQUIRES(ctx, begins_are_static, @@ -150,13 +287,13 @@ class StridedSliceOp : public XlaOpKernel { } for (int64 i = 0; i < final_shape.dims(); ++i) { - int64 input_index = output_to_processing_mapping[i]; + int64 input_index = shape_spec.output_to_processing_mapping[i]; if (input_index == -1) { continue; } bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index); - int64 sparse_index = output_to_sparse_mapping[i]; + int64 sparse_index = shape_spec.output_to_sparse_mapping[i]; bool end_is_dynamic = sparse_index == -1 ? false : ends_are_dynamic[sparse_index]; bool backward_slice = sparse_index == -1 @@ -208,62 +345,9 @@ class StridedSliceOp : public XlaOpKernel { ctx->SetOutput(0, slice); return; } else { - // When output shape is fully defined, it must be a size one slice: - // - // 1. The number of output elements has to be equal to the number of input - // elements that are sliced. - // 2. The stride of the slice dimensions must be exact one. - int64 output_elements = final_shape.num_elements(); - - int64 input_elements_sliced = 1; - int64 slicing_dim_size = begin_shape.dim_size(0); - // We only support slicing major dimensions, so minor dimensions after - // slicing dimension are all sliced with their full sizes. - for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) { - input_elements_sliced *= input_shape.dim_size(d); - } - - OP_REQUIRES(ctx, output_elements == input_elements_sliced, - errors::InvalidArgument( - "Dynamic indices of strided_slice_op have to be leading " - "dimensions in the indices list.")); - - for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) { - OP_REQUIRES( - ctx, strides[i] == 1, - errors::InvalidArgument( - "Strides have to be one when inputs are not constant.")); - } - - // When inputs are not compile time constants, shape inference can only - // inference size 1 slice. - std::vector<int64> slice_sizes(slicing_dim_size, 1); - std::vector<xla::XlaOp> start_indices; - auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin")); - for (int64 d = 0; d < slicing_dim_size; ++d) { - auto index = xla::Slice(ctx->Input("begin"), {d}, {d + 1}, {1}); - // Convert index to scalar. - index = xla::Reshape(index, {}); - // Negative index: wrap it around with dimension size. - auto index_negative = xla::Lt(index, zero); - auto dim_size = xla::ConvertElementType( - xla::ConstantR0<int32>(ctx->builder(), input_shape.dim_size(d)), - ctx->InputXlaType("begin")); - auto wrapped_index = xla::Add(dim_size, index); - index = xla::Select(index_negative, wrapped_index, index); - start_indices.push_back(index); - } - - for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) { - // For non-slice dims, naturally we get the full slice starting from 0. 
- slice_sizes.push_back(input_shape.dim_size(d)); - start_indices.push_back(zero); - } - - std::vector<int64> output_shape_dim_sizes; - slice = xla::DynamicSlice(slice, start_indices, slice_sizes); - slice = xla::Reshape(slice, final_shape.dim_sizes()); - ctx->SetOutput(0, slice); + EmitDynamicSlice(ctx, strides, processing_shape, final_shape, + partial_processing_shape, partial_final_shape, + shape_spec, begins_are_dynamic, ends_are_dynamic); } } @@ -308,10 +392,7 @@ class StridedSliceGradOp : public XlaOpKernel { absl::InlinedVector<int64, 4> begin; absl::InlinedVector<int64, 4> end; absl::InlinedVector<int64, 4> strides; - - absl::InlinedVector<int64, 4> output_to_sparse_mapping; - absl::InlinedVector<int64, 4> output_to_processing_mapping; - + StridedSliceShapeSpec shape_spec; OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, &strides_tensor)); OP_REQUIRES_OK( @@ -319,8 +400,7 @@ class StridedSliceGradOp : public XlaOpKernel { nullptr, nullptr, strides_tensor, input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, &processing_shape, &final_shape, &dummy, &dummy, &dummy, - &begin, &end, &strides, &output_to_sparse_mapping, - &output_to_processing_mapping)); + &begin, &end, &strides, &shape_spec)); for (int64 i = 0; i < processing_shape.dims(); ++i) { OP_REQUIRES( ctx, strides[i] == 1, @@ -341,20 +421,20 @@ class StridedSliceGradOp : public XlaOpKernel { // Use grad shape, which is known, to update unknown processing shape. // Grad shape is the output of the ValidateStridedSliceOp function in // forward pass, thus we use output_to_processing_mapping. - if (output_to_processing_mapping[i] != -1) { - processing_shape.set_dim(output_to_processing_mapping[i], + if (shape_spec.output_to_processing_mapping[i] != -1) { + processing_shape.set_dim(shape_spec.output_to_processing_mapping[i], grad_shape.dimensions(i)); } // Similarly, use output_to_sparse_mapping to find out corresponding // begin dim of the output, as indices for dynamic update slice. - int64 begin_dim = output_to_sparse_mapping[i]; + int64 begin_dim = shape_spec.output_to_sparse_mapping[i]; if (begin_dim != -1) { auto begin_index = xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1}); auto begin_index_scalar = xla::Reshape( xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index); - begins[output_to_sparse_mapping[i]] = begin_index_scalar; + begins[shape_spec.output_to_sparse_mapping[i]] = begin_index_scalar; } } VLOG(1) << "processing_shape" << processing_shape.DebugString(); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index ef31735563d..35cd1c25b7d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -3546,13 +3546,25 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { } case HloOpcode::kTuple: case HloOpcode::kTranspose: - case HloOpcode::kGetTupleElement: case HloOpcode::kSlice: case HloOpcode::kBroadcast: case HloOpcode::kConcatenate: case HloOpcode::kReshape: case HloOpcode::kPad: break; + case HloOpcode::kGetTupleElement: { + // Rewrite parameter followed by gte into constants to avoid + // rematerializing the tuple parameter (could be very large). 
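+        // The tuple element is handled like a plain parameter: its values
+        // are treated as dynamic, so it is rewritten below into a constant
+        // predicate without visiting the (potentially huge) tuple operand.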
+ int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); + TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode, + StringToHloOpcode(operand_proto->opcode())); + if (operand_opcode == HloOpcode::kParameter) { + SetInstructionAsConstant(new_instr, id, new_shape, true); + } + break; + } case HloOpcode::kGetDimensionSize: { int64 dimension = instr_proto->dimensions(0); int64 operand_handle = instr_proto->operand_ids(0); @@ -3646,6 +3658,18 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { should_visit_operand = false; } + if (opcode == HloOpcode::kGetTupleElement) { + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); + TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode, + StringToHloOpcode(operand_proto->opcode())); + if (operand_opcode == HloOpcode::kParameter) { + // Don't rematerialize the whole parameter if it's followed by a GTE. + should_visit_operand = false; + } + } + if (opcode == HloOpcode::kSelect) { TF_ASSIGN_OR_RETURN(bool constant_predicate, operand_is_constant(instr_proto, 0)); diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index 1cf9a8cd013..126b684b8c7 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ b/tensorflow/core/util/strided_slice_op.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/core/util/strided_slice_op.h" #include <array> +#include <iterator> + #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/lib/core/status.h" @@ -64,6 +66,7 @@ struct StridedSliceDenseSpec { // index. A -1 in this vector means there the index is not from the sparse // input. gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse; + gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse; // The dense indexed shrink mask is which processing dimensions // should be shrunk. For example, if foo.shape = (10,10,10,10) // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and @@ -81,6 +84,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec( dense->begin.resize(dense->dims); dense->end.resize(dense->dims); dense->strides.resize(dense->dims); + dense->input_shape_gather_indices_sparse.resize(dense->dims); // What indices to get the final shape from. dense->begin_mask = 0; dense->end_mask = 0; @@ -114,6 +118,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec( dense->end_mask |= (1 << full_index); dense->final_shape_gather_indices.push_back(full_index); dense->final_shape_gather_indices_sparse.push_back(-1); + dense->input_shape_gather_indices_sparse[full_index] = i; } } else if ((1 << i) & sparse.new_axis_mask) { dense->final_shape_gather_indices.push_back(kNewAxis); @@ -153,6 +158,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec( // from. 
dense->final_shape_gather_indices_sparse.push_back(i); } + dense->input_shape_gather_indices_sparse[full_index] = i; full_index++; } } @@ -168,9 +174,7 @@ Status ValidateStridedSliceOp( PartialTensorShape* processing_shape, PartialTensorShape* final_shape, bool* is_identity, bool* is_simple_slice, bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end, - gtl::InlinedVector<int64, 4>* strides, - gtl::InlinedVector<int64, 4>* output_to_sparse_mapping, - gtl::InlinedVector<int64, 4>* output_to_processing_mapping) { + gtl::InlinedVector<int64, 4>* strides, StridedSliceShapeSpec* shape_spec) { const bool begin_is_wrong = begin_tensor != nullptr && !(TensorShapeUtils::IsVector(begin_tensor->shape()) && @@ -375,13 +379,18 @@ Status ValidateStridedSliceOp( // slices like foo[3,...] will reduce dimension by 1. // This cannot be done earlier, because it depends on Step 3. final_shape->Clear(); - if (output_to_sparse_mapping != nullptr) { - output_to_sparse_mapping->clear(); + if (shape_spec != nullptr) { + shape_spec->output_to_sparse_mapping.clear(); + shape_spec->output_to_processing_mapping.clear(); + shape_spec->processing_to_sparse_mapping.assign( + dense_spec.input_shape_gather_indices_sparse.begin(), + dense_spec.input_shape_gather_indices_sparse.end()); + + shape_spec->begin_dense_mask = dense_spec.begin_mask; + shape_spec->end_dense_mask = dense_spec.end_mask; + shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask; } - if (output_to_processing_mapping != nullptr) { - output_to_processing_mapping->clear(); - } for (int64 dense_dim = 0; dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) { int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim]; @@ -389,22 +398,19 @@ Status ValidateStridedSliceOp( dense_spec.final_shape_gather_indices_sparse[dense_dim]; if (gather_index >= 0) { final_shape->AddDim(processing_shape->dim_size(gather_index)); - if (output_to_sparse_mapping != nullptr) { - output_to_sparse_mapping->push_back(sparse_index); - } - if (output_to_processing_mapping != nullptr) { - output_to_processing_mapping->push_back(gather_index); + if (shape_spec != nullptr) { + shape_spec->output_to_sparse_mapping.push_back(sparse_index); + shape_spec->output_to_processing_mapping.push_back(gather_index); } } else if (gather_index == kNewAxis) { final_shape->AddDim(1); - if (output_to_sparse_mapping != nullptr) { - output_to_sparse_mapping->push_back(-1); - } - if (output_to_processing_mapping != nullptr) { - output_to_processing_mapping->push_back(-1); + if (shape_spec != nullptr) { + shape_spec->output_to_sparse_mapping.push_back(-1); + shape_spec->output_to_processing_mapping.push_back(-1); } } } + return Status::OK(); } @@ -416,16 +422,14 @@ Status ValidateStridedSliceOp( TensorShape* final_shape, bool* is_identity, bool* is_simple_slice, bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides, - gtl::InlinedVector<int64, 4>* output_to_sparse_mapping, - gtl::InlinedVector<int64, 4>* output_to_processing_mapping) { + StridedSliceShapeSpec* shape_spec) { // Validate with PartialTensorShape output PartialTensorShape partial_processing_shape, partial_final_shape; TF_RETURN_IF_ERROR(ValidateStridedSliceOp( begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec, end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask, &partial_processing_shape, &partial_final_shape, is_identity, - is_simple_slice, slice_dim0, begin, end, 
strides,
-      output_to_sparse_mapping, output_to_processing_mapping));
+      is_simple_slice, slice_dim0, begin, end, strides, shape_spec));
 
   // Verify that the output shapes are fully known
   if (!partial_processing_shape.AsTensorShape(processing_shape) ||
diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h
index 9e49477a9c3..dfb411a6065 100644
--- a/tensorflow/core/util/strided_slice_op.h
+++ b/tensorflow/core/util/strided_slice_op.h
@@ -23,6 +23,26 @@ limitations under the License.
 
 namespace tensorflow {
 
+struct StridedSliceShapeSpec {
+  // Begin mask canonicalized in dense form.
+  int32 begin_dense_mask;
+  // End mask canonicalized in dense form.
+  int32 end_dense_mask;
+  // Shrink axis mask canonicalized in dense form.
+  int32 shrink_axis_dense_mask;
+  // output_to_sparse_mapping[i] is the dim index in the begin_tensor that
+  // output[i] corresponds to. If output_to_sparse_mapping[i] is -1, the
+  // output dimension doesn't map to any index in the begin_tensor.
+  gtl::InlinedVector<int64, 4> output_to_sparse_mapping;
+  // output_to_processing_mapping is similar to output_to_sparse_mapping, but
+  // maps into the processing shape instead.
+  gtl::InlinedVector<int64, 4> output_to_processing_mapping;
+  // processing_to_sparse_mapping[i] is the dim index in the begin_tensor that
+  // input_shape[i] corresponds to.
+  gtl::InlinedVector<int64, 4> processing_to_sparse_mapping;
+};
+
 // Runs validation on the strided slice op parameters.
 //
 // Is a separate translation unit from the kernel so that:
@@ -41,16 +61,6 @@ namespace tensorflow {
 // (-1). Any validation that can be done without complete information is
 // performed.
 //
-// This function changes the orders of dimensions, output_to_sparse_mapping and
-// output_to_processing_mapping are used to track the order change.
-//
-// output_to_sparse_mapping[i] represents output[i]'s the corresponding dim
-// index in the begin_tensor. If
-// output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up in
-// sparse_mapping.
-//
-// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
-// processing_shape.
 Status ValidateStridedSliceOp(
     const Tensor* begin_tensor, const Tensor* end_tensor,
     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -60,8 +70,7 @@ Status ValidateStridedSliceOp(
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
     gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+    StridedSliceShapeSpec* shape_spec = nullptr);
 
 // Same as above, but the outputs are TensorShape, not PartialTensorShape
 Status ValidateStridedSliceOp(
@@ -72,8 +81,7 @@ Status ValidateStridedSliceOp(
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
     gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+    StridedSliceShapeSpec* shape_spec = nullptr);
 
 }  // namespace tensorflow
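
For intuition, here is a minimal sketch (not part of the diff above) of the per-dimension lowering pattern that EmitDynamicSlice produces for a stride-1 slice x[begin:end] with dynamic int32 bounds. The helper name BuildDynamicSlice1D and the rank-1 F32 input are assumptions made only for illustration; the builder calls (Pad, DynamicSlice, SetDimensionSize, etc.) are the same ones the kernel uses.

#include <cstdint>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Lowers x[begin:end] (stride 1, dynamic int32 begin/end, rank-1 F32 input)
// the way EmitDynamicSlice handles each dynamic dimension: pad to 2x, slice
// the full static extent, then reattach the runtime size.
xla::XlaOp BuildDynamicSlice1D(xla::XlaBuilder* b, xla::XlaOp x,
                               int64_t static_dim_size, xla::XlaOp begin,
                               xla::XlaOp end) {
  xla::XlaOp zero = xla::Zero(b, xla::S32);
  xla::XlaOp dim_size = xla::ConstantR0WithType(b, xla::S32, static_dim_size);

  // Wrap negative indices around, e.g. begin == -2 with dim_size == 10 -> 8.
  begin = xla::Select(xla::Lt(begin, zero), xla::Add(dim_size, begin), begin);
  end = xla::Select(xla::Lt(end, zero), xla::Add(dim_size, end), end);

  // Pad the dimension to twice its size so that slicing the full static
  // extent starting at `begin` cannot read out of bounds.
  xla::PaddingConfig padding_config;
  auto* dims = padding_config.add_dimensions();
  dims->set_edge_padding_low(0);
  dims->set_interior_padding(0);
  dims->set_edge_padding_high(static_dim_size);
  xla::XlaOp padded = xla::Pad(x, xla::Zero(b, xla::F32), padding_config);

  // Slice with the static upper-bound size, then mark the dimension as
  // dynamic with runtime size max(end - begin, 0).
  std::vector<xla::XlaOp> starts = {begin};
  xla::XlaOp sliced = xla::DynamicSlice(padded, starts, {static_dim_size});
  xla::XlaOp runtime_size = xla::Max(xla::Sub(end, begin), zero);
  return xla::SetDimensionSize(sliced, runtime_size, /*dimension=*/0);
}

In the kernel itself, dim_size comes from xla::GetDimensionSize when the input dimension is itself dynamic, and this pattern is applied across all dimensions of the input before the final reshape to the output shape.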