Support more cases in strided slice ops.
- When ranges are dynamic, we can lower strided slice ops into XLA dynamic slices.
- Optimize XLA builder dynamism inference to not materialize the parameter tuple.
- Export some helper data structures from strided_slice_op.

PiperOrigin-RevId: 356839276
Change-Id: I3b5d9fcb1289496029afbf97e49df49df26ce699
parent dcaa3c9ed9
commit 41b3e84fa9
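The lowering strategy introduced below pads the sliced dimension to 2x, takes a full-size xla::DynamicSlice at the run-time start index, and then records the true extent with xla::SetDimensionSize. A minimal single-dimension sketch of that idea, assuming float data and S32 indices already normalized to [0, static_dim]; the helper name is illustrative, not the kernel code itself:

#include "tensorflow/compiler/xla/client/xla_builder.h"

// begin/end are S32 scalars already normalized to [0, static_dim].
xla::XlaOp DynamicStridedSlice1D(xla::XlaBuilder* builder, xla::XlaOp input,
                                 int static_dim, xla::XlaOp begin,
                                 xla::XlaOp end) {
  // Pad the dimension to 2 * static_dim so that a full-size slice starting at
  // any begin in [0, static_dim] stays in bounds.
  xla::PaddingConfig padding;
  auto* dim = padding.add_dimensions();
  dim->set_edge_padding_low(0);
  dim->set_edge_padding_high(static_dim);
  dim->set_interior_padding(0);
  xla::XlaOp padded =
      xla::Pad(input, xla::ConstantR0<float>(builder, 0.0f), padding);
  // Slice the full static size starting at the run-time begin index.
  xla::XlaOp sliced = xla::DynamicSlice(padded, {begin}, {static_dim});
  // Record the true (run-time) extent end - begin as a dynamic dimension size.
  xla::XlaOp size = xla::Max(xla::Sub(end, begin),
                             xla::ConstantR0<int32_t>(builder, 0));
  return xla::SetDimensionSize(sliced, size, /*dimension=*/0);
}

Padding to 2x is what keeps the full-size slice in bounds: for 0 <= begin <= static_dim, the slice touches indices up to begin + static_dim - 1 <= 2 * static_dim - 1.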
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
@@ -52,6 +53,145 @@ class StridedSliceOp : public XlaOpKernel {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
  }

  void EmitDynamicSlice(XlaOpKernelContext* ctx,
                        const absl::InlinedVector<int64, 4>& strides,
                        TensorShape processing_shape, TensorShape final_shape,
                        PartialTensorShape partial_processing_shape,
                        PartialTensorShape partial_final_shape,
                        const StridedSliceShapeSpec& shape_spec,
                        const std::vector<bool>& begins_are_dynamic,
                        const std::vector<bool>& ends_are_dynamic) {
    const TensorShape input_shape = ctx->InputShape(0);
    xla::XlaOp slice = ctx->Input(0);
    for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
      OP_REQUIRES(ctx, strides[i] == 1,
                  errors::InvalidArgument(
                      "Strides have to be one when inputs are not constant."));
    }
    // Infer the static output shape, reconciling unknown dimensions with the
    // input dimension sizes.
    for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
      if (partial_final_shape.dim_size(i) == -1) {
        // Unknown output dimension: use the input dimension it maps to via
        // shape_spec.
        partial_final_shape.set_dim(
            i,
            input_shape.dim_size(shape_spec.output_to_processing_mapping[i]));
      }
    }

    OP_REQUIRES(
        ctx, partial_final_shape.AsTensorShape(&final_shape),
        InvalidArgument("XLA can't deduce compile time constant output "
                        "shape for strided slice: ",
                        partial_final_shape.DebugString(),
                        ", output shape must be a compile-time constant"));
    for (int64 i = 0; i < partial_processing_shape.dims(); ++i) {
      if (partial_processing_shape.dim_size(i) == -1) {
        // Unknown processing dimension: use the input dimension size.
        partial_processing_shape.set_dim(i, input_shape.dim_size(i));
      }
    }
    OP_REQUIRES(
        ctx, partial_processing_shape.AsTensorShape(&processing_shape),
        InvalidArgument("XLA can't deduce compile time constant processing "
                        "shape for strided slice: ",
                        partial_processing_shape.DebugString(),
                        ", output shape must be a compile-time constant"));
    // When inputs are not compile-time constants, shape inference can only
    // infer size-1 slices.
    std::vector<int64> slice_sizes(input_shape.dims(), 1);
    // If there is a dynamic begin/end (and the dimension is not shrunk), we
    // need to use the dynamic shape infrastructure -- we slice the output with
    // full size, then call SetDimensionSize on the output. However, if we
    // slice with the full size at a non-zero start index we may get OOB
    // access. To avoid that, we first pad the input to 2x before calling
    // slice.
    xla::PaddingConfig padding_config;
    bool need_padding = false;
    std::vector<bool> result_dims_are_dynamic;
    for (int64 i = 0; i < input_shape.dims(); ++i) {
      int64 sparse_index = shape_spec.processing_to_sparse_mapping[i];
      bool shrink_axis_set = (1 << i) & shape_spec.shrink_axis_dense_mask;
      auto* dims = padding_config.add_dimensions();
      dims->set_edge_padding_low(0);
      dims->set_interior_padding(0);
      if ((begins_are_dynamic[sparse_index] ||
           ends_are_dynamic[sparse_index]) &&
          !shrink_axis_set) {
        // Need to slice this dimension so pad first.
        dims->set_edge_padding_high(input_shape.dim_size(i));
        need_padding = true;
        result_dims_are_dynamic.push_back(true);
      } else {
        dims->set_edge_padding_high(0);
        result_dims_are_dynamic.push_back(false);
      }
    }

    if (need_padding) {
      // Pad input to 2x to avoid OOB access.
      slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
                       padding_config);
    }
    std::vector<xla::XlaOp> start_indices;
    std::vector<xla::XlaOp> slice_sizes_dynamic;
    xla::Shape input_xla_shape = ctx->InputXlaShape(0).ValueOrDie();
    for (int64 i = 0; i < input_shape.dims(); ++i) {
      bool begin_mask = (1 << i) & shape_spec.begin_dense_mask;
      bool end_mask = (1 << i) & shape_spec.end_dense_mask;
      auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin"));
      xla::XlaOp begin_index, end_index;
      int64 sparse_index = shape_spec.processing_to_sparse_mapping[i];
      bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i);
      xla::XlaOp dim_size;
      if (xla_input_is_dynamic) {
        dim_size = xla::GetDimensionSize(ctx->Input(0), i);
        OP_REQUIRES(ctx, ctx->InputXlaType("begin") == xla::S32,
                    errors::InvalidArgument("'begin' shape has to be int32 when "
                                            "indices to slice op are dynamic"));
      } else {
        dim_size =
            xla::ConstantR0WithType(ctx->builder(), ctx->InputXlaType("begin"),
                                    input_xla_shape.dimensions(i));
      }
      if (begin_mask) {
        begin_index = zero;
      } else {
        begin_index = xla::Slice(ctx->Input("begin"), {sparse_index},
                                 {sparse_index + 1}, {1});
        begin_index = xla::Reshape(begin_index, {});
        auto index_negative = xla::Lt(begin_index, zero);
        auto wrapped_index = xla::Add(dim_size, begin_index);
        // Wrap negative indices around.
        begin_index = xla::Select(index_negative, wrapped_index, begin_index);
      }
      start_indices.push_back(begin_index);
      if (end_mask) {
        end_index = dim_size;
      } else {
        end_index = xla::Slice(ctx->Input("end"), {sparse_index},
                               {sparse_index + 1}, {1});
        end_index = xla::Reshape(end_index, {});
        auto index_negative = xla::Lt(end_index, zero);
        auto wrapped_index = xla::Add(dim_size, end_index);
        end_index = xla::Select(index_negative, wrapped_index, end_index);
      }
      slice_sizes_dynamic.push_back(
          xla::Max(xla::Sub(end_index, begin_index), zero));
    }

    slice =
        xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes());

    for (int64 i = 0; i < input_shape.dims(); ++i) {
      if (result_dims_are_dynamic[i]) {
        slice = xla::SetDimensionSize(slice, slice_sizes_dynamic[i], i);
      }
    }
    slice = xla::Reshape(slice, final_shape.dim_sizes());
    ctx->SetOutput(0, slice);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const TensorShape begin_shape = ctx->InputShape("begin");
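EmitDynamicSlice above normalizes negative begin/end values with xla::Select(xla::Lt(i, zero), i + dim_size, i). For reference, the scalar equivalent of that wrap-around, as a hypothetical standalone helper:

#include <cstdint>

// An index in [-dim_size, dim_size) is normalized to [0, dim_size).
int64_t WrapIndex(int64_t index, int64_t dim_size) {
  return index < 0 ? index + dim_size : index;
}
// e.g. WrapIndex(-1, 5) == 4 and WrapIndex(3, 5) == 3.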
@@ -80,31 +220,33 @@ class StridedSliceOp : public XlaOpKernel {
    }
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

    TensorShape final_shape;
    PartialTensorShape dummy_processing_shape, partial_final_shape;
    TensorShape processing_shape, final_shape;
    PartialTensorShape partial_processing_shape, partial_final_shape;
    bool dummy = false;
    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
    absl::InlinedVector<int64, 4> output_to_processing_mapping;
    StridedSliceShapeSpec shape_spec;
    OP_REQUIRES_OK(
        ctx,
        ValidateStridedSliceOp(
            begin_is_constant ? &begin_tensor : nullptr,
            end_is_constant ? &end_tensor : nullptr, strides_tensor,
            input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
            shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
            &dummy, &dummy, &dummy, &begin, &end, &strides,
            &output_to_sparse_mapping, &output_to_processing_mapping));

    OP_REQUIRES(
        ctx, partial_final_shape.AsTensorShape(&final_shape),
        InvalidArgument("XLA can't deduce compile time constant output "
                        "shape for strided slice: ",
                        partial_final_shape.DebugString(),
                        ", output shape must be a compile-time constant"));
            shrink_axis_mask_, &partial_processing_shape, &partial_final_shape,
            &dummy, &dummy, &dummy, &begin, &end, &strides, &shape_spec));

    xla::XlaOp slice = ctx->Input(0);
    std::vector<bool> begins_are_dynamic;
    OP_REQUIRES_OK(
        ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
    std::vector<bool> ends_are_dynamic;
    OP_REQUIRES_OK(
        ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
    if (begin_is_constant && end_is_constant) {
      OP_REQUIRES(
          ctx, partial_final_shape.AsTensorShape(&final_shape),
          InvalidArgument("XLA can't deduce compile time constant output "
                          "shape for strided slice: ",
                          partial_final_shape.DebugString(),
                          ", output shape must be a compile-time constant"));
      absl::InlinedVector<int64, 4> dimensions_to_reverse;
      absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
      for (int i = 0; i < begin.size(); ++i) {
@@ -129,12 +271,7 @@ class StridedSliceOp : public XlaOpKernel {
      auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
      OP_REQUIRES_OK(ctx, operand_shape_or.status());
      xla::Shape xla_shape = operand_shape_or.ValueOrDie();
      std::vector<bool> begins_are_dynamic;
      OP_REQUIRES_OK(
          ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
      std::vector<bool> ends_are_dynamic;
      OP_REQUIRES_OK(
          ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));

      bool begins_are_static = absl::c_all_of(
          begins_are_dynamic, [](bool dynamic) { return !dynamic; });
      OP_REQUIRES(ctx, begins_are_static,
@@ -150,13 +287,13 @@ class StridedSliceOp : public XlaOpKernel {
      }

      for (int64 i = 0; i < final_shape.dims(); ++i) {
        int64 input_index = output_to_processing_mapping[i];
        int64 input_index = shape_spec.output_to_processing_mapping[i];
        if (input_index == -1) {
          continue;
        }
        bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);

        int64 sparse_index = output_to_sparse_mapping[i];
        int64 sparse_index = shape_spec.output_to_sparse_mapping[i];
        bool end_is_dynamic =
            sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
        bool backward_slice = sparse_index == -1
@@ -208,62 +345,9 @@ class StridedSliceOp : public XlaOpKernel {
      ctx->SetOutput(0, slice);
      return;
    } else {
      // When the output shape is fully defined, it must be a size-one slice:
      //
      // 1. The number of output elements has to be equal to the number of
      //    input elements that are sliced.
      // 2. The stride of the slice dimensions must be exactly one.
      int64 output_elements = final_shape.num_elements();

      int64 input_elements_sliced = 1;
      int64 slicing_dim_size = begin_shape.dim_size(0);
      // We only support slicing major dimensions, so minor dimensions after
      // the slicing dimension are all sliced with their full sizes.
      for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
        input_elements_sliced *= input_shape.dim_size(d);
      }

      OP_REQUIRES(ctx, output_elements == input_elements_sliced,
                  errors::InvalidArgument(
                      "Dynamic indices of strided_slice_op have to be leading "
                      "dimensions in the indices list."));

      for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
        OP_REQUIRES(
            ctx, strides[i] == 1,
            errors::InvalidArgument(
                "Strides have to be one when inputs are not constant."));
      }

      // When inputs are not compile-time constants, shape inference can only
      // infer size-1 slices.
      std::vector<int64> slice_sizes(slicing_dim_size, 1);
      std::vector<xla::XlaOp> start_indices;
      auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin"));
      for (int64 d = 0; d < slicing_dim_size; ++d) {
        auto index = xla::Slice(ctx->Input("begin"), {d}, {d + 1}, {1});
        // Convert index to scalar.
        index = xla::Reshape(index, {});
        // Negative index: wrap it around with dimension size.
        auto index_negative = xla::Lt(index, zero);
        auto dim_size = xla::ConvertElementType(
            xla::ConstantR0<int32>(ctx->builder(), input_shape.dim_size(d)),
            ctx->InputXlaType("begin"));
        auto wrapped_index = xla::Add(dim_size, index);
        index = xla::Select(index_negative, wrapped_index, index);
        start_indices.push_back(index);
      }

      for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
        // For non-slice dims, naturally we get the full slice starting from 0.
        slice_sizes.push_back(input_shape.dim_size(d));
        start_indices.push_back(zero);
      }

      std::vector<int64> output_shape_dim_sizes;
      slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
      slice = xla::Reshape(slice, final_shape.dim_sizes());
      ctx->SetOutput(0, slice);
      EmitDynamicSlice(ctx, strides, processing_shape, final_shape,
                       partial_processing_shape, partial_final_shape,
                       shape_spec, begins_are_dynamic, ends_are_dynamic);
    }
  }
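For context, the fallback removed in the hunk above only handled dynamic slices whose dynamic indices were leading dimensions, each producing a size-1 slice there. A restatement of that old constraint as a hypothetical standalone check, for illustration only:

#include <cstdint>
#include <vector>

// input_dims: full input shape; num_dynamic_leading_dims: how many leading
// dimensions had run-time begin indices (each sliced to size 1);
// output_elements: number of elements in the final shape.
bool OldPathSupported(const std::vector<int64_t>& input_dims,
                      size_t num_dynamic_leading_dims,
                      int64_t output_elements) {
  int64_t input_elements_sliced = 1;
  for (size_t d = num_dynamic_leading_dims; d < input_dims.size(); ++d) {
    input_elements_sliced *= input_dims[d];
  }
  return output_elements == input_elements_sliced;
}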
@@ -308,10 +392,7 @@ class StridedSliceGradOp : public XlaOpKernel {
    absl::InlinedVector<int64, 4> begin;
    absl::InlinedVector<int64, 4> end;
    absl::InlinedVector<int64, 4> strides;

    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
    absl::InlinedVector<int64, 4> output_to_processing_mapping;

    StridedSliceShapeSpec shape_spec;
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));
    OP_REQUIRES_OK(
@@ -319,8 +400,7 @@ class StridedSliceGradOp : public XlaOpKernel {
            nullptr, nullptr, strides_tensor, input_shape, begin_mask_,
            end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
            &processing_shape, &final_shape, &dummy, &dummy, &dummy,
            &begin, &end, &strides, &output_to_sparse_mapping,
            &output_to_processing_mapping));
            &begin, &end, &strides, &shape_spec));
    for (int64 i = 0; i < processing_shape.dims(); ++i) {
      OP_REQUIRES(
          ctx, strides[i] == 1,
@@ -341,20 +421,20 @@ class StridedSliceGradOp : public XlaOpKernel {
      // Use grad shape, which is known, to update the unknown processing
      // shape. Grad shape is the output of the ValidateStridedSliceOp function
      // in the forward pass, thus we use output_to_processing_mapping.
      if (output_to_processing_mapping[i] != -1) {
        processing_shape.set_dim(output_to_processing_mapping[i],
      if (shape_spec.output_to_processing_mapping[i] != -1) {
        processing_shape.set_dim(shape_spec.output_to_processing_mapping[i],
                                 grad_shape.dimensions(i));
      }

      // Similarly, use output_to_sparse_mapping to find the corresponding
      // begin dim of the output, used as indices for dynamic update slice.
      int64 begin_dim = output_to_sparse_mapping[i];
      int64 begin_dim = shape_spec.output_to_sparse_mapping[i];
      if (begin_dim != -1) {
        auto begin_index =
            xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1});
        auto begin_index_scalar = xla::Reshape(
            xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index);
        begins[output_to_sparse_mapping[i]] = begin_index_scalar;
        begins[shape_spec.output_to_sparse_mapping[i]] = begin_index_scalar;
      }
    }
    VLOG(1) << "processing_shape" << processing_shape.DebugString();
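The gradient kernel feeds those per-dimension begin indices into a dynamic update slice of a zero tensor shaped like the original input. A minimal one-dimensional sketch of that pattern, assuming stride-1 slicing (as the kernel requires), float data, and an S32 begin scalar; the function name is illustrative:

#include "tensorflow/compiler/xla/client/xla_builder.h"

// grad: incoming gradient of the slice; begin: S32 scalar start index.
xla::XlaOp StridedSliceGrad1D(xla::XlaBuilder* builder, xla::XlaOp grad,
                              int input_dim, xla::XlaOp begin) {
  // Scatter the gradient back into a zero tensor of the input shape at the
  // (possibly run-time) begin index.
  xla::XlaOp zeros =
      xla::Broadcast(xla::ConstantR0<float>(builder, 0.0f), {input_dim});
  return xla::DynamicUpdateSlice(zeros, grad, {begin});
}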
@@ -3546,13 +3546,25 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
        }
        case HloOpcode::kTuple:
        case HloOpcode::kTranspose:
        case HloOpcode::kGetTupleElement:
        case HloOpcode::kSlice:
        case HloOpcode::kBroadcast:
        case HloOpcode::kConcatenate:
        case HloOpcode::kReshape:
        case HloOpcode::kPad:
          break;
        case HloOpcode::kGetTupleElement: {
          // Rewrite parameter followed by gte into constants to avoid
          // rematerializing the tuple parameter (could be very large).
          int64 operand_handle = instr_proto->operand_ids(0);
          TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                              LookUpInstructionByHandle(operand_handle));
          TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                              StringToHloOpcode(operand_proto->opcode()));
          if (operand_opcode == HloOpcode::kParameter) {
            SetInstructionAsConstant(new_instr, id, new_shape, true);
          }
          break;
        }
        case HloOpcode::kGetDimensionSize: {
          int64 dimension = instr_proto->dimensions(0);
          int64 operand_handle = instr_proto->operand_ids(0);
@@ -3646,6 +3658,18 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
          should_visit_operand = false;
        }

        if (opcode == HloOpcode::kGetTupleElement) {
          int64 operand_handle = instr_proto->operand_ids(0);
          TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                              LookUpInstructionByHandle(operand_handle));
          TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                              StringToHloOpcode(operand_proto->opcode()));
          if (operand_opcode == HloOpcode::kParameter) {
            // Don't rematerialize the whole parameter if it's followed by a GTE.
            should_visit_operand = false;
          }
        }

        if (opcode == HloOpcode::kSelect) {
          TF_ASSIGN_OR_RETURN(bool constant_predicate,
                              operand_is_constant(instr_proto, 0));
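Both xla_builder.cc hunks implement the same shortcut: in the dynamism-inference graph, a get-tuple-element whose operand is a parameter is already fully dynamic, so it can be answered directly instead of visiting and materializing the whole tuple parameter. A sketch of what that amounts to; the helper names and the [2, 3] shape are illustrative assumptions, not code from this change:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

// True when a GetTupleElement reads directly from a parameter; dynamism
// inference can then mark the GTE's value as all-dynamic without visiting
// the tuple parameter.
bool IsGteOfParameter(xla::HloOpcode opcode, xla::HloOpcode operand_opcode) {
  return opcode == xla::HloOpcode::kGetTupleElement &&
         operand_opcode == xla::HloOpcode::kParameter;
}

// What the inference graph effectively produces for such a GTE: every element
// marked dynamic (pred `true`), broadcast to the GTE's shape, e.g. an f32[2,3]
// value becomes a pred[2,3] constant of `true`.
xla::XlaOp AllDynamicExample(xla::XlaBuilder* builder) {
  return xla::Broadcast(xla::ConstantR0<bool>(builder, true), {2, 3});
}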
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/core/util/strided_slice_op.h"

#include <array>
#include <iterator>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"
@@ -64,6 +66,7 @@ struct StridedSliceDenseSpec {
  // index. A -1 in this vector means the index is not from the sparse
  // input.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
  gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
  // The dense indexed shrink mask indicates which processing dimensions
  // should be shrunk. For example, if foo.shape = (10, 10, 10, 10),
  // foo[3, ..., 5] has a sparse_shrink_axis_mask of 0x5 and
@@ -81,6 +84,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
  dense->begin.resize(dense->dims);
  dense->end.resize(dense->dims);
  dense->strides.resize(dense->dims);
  dense->input_shape_gather_indices_sparse.resize(dense->dims);
  // What indices to get the final shape from.
  dense->begin_mask = 0;
  dense->end_mask = 0;
@@ -114,6 +118,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
        dense->end_mask |= (1 << full_index);
        dense->final_shape_gather_indices.push_back(full_index);
        dense->final_shape_gather_indices_sparse.push_back(-1);
        dense->input_shape_gather_indices_sparse[full_index] = i;
      }
    } else if ((1 << i) & sparse.new_axis_mask) {
      dense->final_shape_gather_indices.push_back(kNewAxis);
@@ -153,6 +158,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
        // from.
        dense->final_shape_gather_indices_sparse.push_back(i);
      }
      dense->input_shape_gather_indices_sparse[full_index] = i;
      full_index++;
    }
  }
@@ -168,9 +174,7 @@ Status ValidateStridedSliceOp(
    PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
    bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
    gtl::InlinedVector<int64, 4>* strides,
    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
    gtl::InlinedVector<int64, 4>* strides, StridedSliceShapeSpec* shape_spec) {
  const bool begin_is_wrong =
      begin_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
@@ -375,13 +379,18 @@ Status ValidateStridedSliceOp(
  // slices like foo[3,...] will reduce dimension by 1.
  // This cannot be done earlier, because it depends on Step 3.
  final_shape->Clear();
  if (output_to_sparse_mapping != nullptr) {
    output_to_sparse_mapping->clear();
  if (shape_spec != nullptr) {
    shape_spec->output_to_sparse_mapping.clear();
    shape_spec->output_to_processing_mapping.clear();
    shape_spec->processing_to_sparse_mapping.assign(
        dense_spec.input_shape_gather_indices_sparse.begin(),
        dense_spec.input_shape_gather_indices_sparse.end());

    shape_spec->begin_dense_mask = dense_spec.begin_mask;
    shape_spec->end_dense_mask = dense_spec.end_mask;
    shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask;
  }

  if (output_to_processing_mapping != nullptr) {
    output_to_processing_mapping->clear();
  }
  for (int64 dense_dim = 0;
       dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
    int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
@@ -389,22 +398,19 @@ Status ValidateStridedSliceOp(
        dense_spec.final_shape_gather_indices_sparse[dense_dim];
    if (gather_index >= 0) {
      final_shape->AddDim(processing_shape->dim_size(gather_index));
      if (output_to_sparse_mapping != nullptr) {
        output_to_sparse_mapping->push_back(sparse_index);
      }
      if (output_to_processing_mapping != nullptr) {
        output_to_processing_mapping->push_back(gather_index);
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(sparse_index);
        shape_spec->output_to_processing_mapping.push_back(gather_index);
      }
    } else if (gather_index == kNewAxis) {
      final_shape->AddDim(1);
      if (output_to_sparse_mapping != nullptr) {
        output_to_sparse_mapping->push_back(-1);
      }
      if (output_to_processing_mapping != nullptr) {
        output_to_processing_mapping->push_back(-1);
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(-1);
        shape_spec->output_to_processing_mapping.push_back(-1);
      }
    }
  }

  return Status::OK();
}
@@ -416,16 +422,14 @@ Status ValidateStridedSliceOp(
    TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
    bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
    StridedSliceShapeSpec* shape_spec) {
  // Validate with PartialTensorShape output
  PartialTensorShape partial_processing_shape, partial_final_shape;
  TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
      begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
      end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
      &partial_processing_shape, &partial_final_shape, is_identity,
      is_simple_slice, slice_dim0, begin, end, strides,
      output_to_sparse_mapping, output_to_processing_mapping));
      is_simple_slice, slice_dim0, begin, end, strides, shape_spec));

  // Verify that the output shapes are fully known
  if (!partial_processing_shape.AsTensorShape(processing_shape) ||
@@ -23,6 +23,26 @@ limitations under the License.

namespace tensorflow {

struct StridedSliceShapeSpec {
  // Begin mask canonicalized in dense form.
  int32 begin_dense_mask;
  // End mask canonicalized in dense form.
  int32 end_dense_mask;
  // Shrink axis mask canonicalized in dense form.
  int32 shrink_axis_dense_mask;
  // output_to_sparse_mapping[i] is the dim index in the begin_tensor that
  // corresponds to output[i]. If output_to_sparse_mapping[i] is -1, the
  // dimension doesn't show up in sparse_mapping.
  gtl::InlinedVector<int64, 4> output_to_sparse_mapping;
  // output_to_processing_mapping is similar to output_to_sparse_mapping, but
  // for the processing shape.
  gtl::InlinedVector<int64, 4> output_to_processing_mapping;
  // processing_to_sparse_mapping[i] is input_shape[i]'s corresponding dim
  // index in the begin_tensor.
  gtl::InlinedVector<int64, 4> processing_to_sparse_mapping;
};

// Runs validation on the strided slice op parameters.
//
// Is a separate translation unit from the kernel so that:
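A sketch of how a consumer might read the exported spec, mirroring the way the XLA kernel above uses it; the function name and loop are illustrative, not part of this change:

#include "tensorflow/core/util/strided_slice_op.h"

// spec is assumed to have been filled in by ValidateStridedSliceOp;
// processing_dims is the rank of the dense/processing shape.
void InspectSpec(const tensorflow::StridedSliceShapeSpec& spec,
                 int processing_dims) {
  for (int i = 0; i < processing_dims; ++i) {
    // Dense masks are indexed by processing (dense) dimension.
    const bool begin_defaulted = (1 << i) & spec.begin_dense_mask;
    const bool end_defaulted = (1 << i) & spec.end_dense_mask;
    const bool shrunk = (1 << i) & spec.shrink_axis_dense_mask;
    // Index into the user-provided begin/end/strides tensors that this
    // processing dimension was derived from.
    const tensorflow::int64 sparse_dim = spec.processing_to_sparse_mapping[i];
    (void)begin_defaulted;
    (void)end_defaulted;
    (void)shrunk;
    (void)sparse_dim;
  }
}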
@@ -41,16 +61,6 @@ namespace tensorflow {
// (-1). Any validation that can be done without complete information is
// performed.
//
// This function changes the order of dimensions; output_to_sparse_mapping and
// output_to_processing_mapping are used to track the order change.
//
// output_to_sparse_mapping[i] represents output[i]'s corresponding dim
// index in the begin_tensor. If
// output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up in
// sparse_mapping.
//
// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
// processing_shape.
Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -60,8 +70,7 @@ Status ValidateStridedSliceOp(
    bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
    gtl::InlinedVector<int64, 4>* strides,
    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
    StridedSliceShapeSpec* shape_spec = nullptr);

// Same as above, but the outputs are TensorShape, not PartialTensorShape
Status ValidateStridedSliceOp(
@@ -72,8 +81,7 @@ Status ValidateStridedSliceOp(
    TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
    bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
    StridedSliceShapeSpec* shape_spec = nullptr);

}  // namespace tensorflow
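A usage sketch for the updated declarations: callers that previously passed output_to_sparse_mapping / output_to_processing_mapping now pass a single StridedSliceShapeSpec (or rely on the nullptr default). The wrapper name and zeroed masks below are illustrative assumptions:

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/strided_slice_op.h"

namespace tensorflow {

// All masks zero for brevity; real callers pass the op's attributes.
Status ValidateAndGetSpec(const Tensor& begin_t, const Tensor& end_t,
                          const Tensor& strides_t,
                          const PartialTensorShape& input_shape,
                          PartialTensorShape* processing_shape,
                          PartialTensorShape* final_shape,
                          StridedSliceShapeSpec* spec) {
  bool is_identity, is_simple_slice, slice_dim0;
  gtl::InlinedVector<int64, 4> begin, end, strides;
  return ValidateStridedSliceOp(
      &begin_t, &end_t, strides_t, input_shape, /*begin_mask_spec=*/0,
      /*end_mask_spec=*/0, /*ellipsis_mask=*/0, /*new_axis_mask=*/0,
      /*shrink_axis_mask=*/0, processing_shape, final_shape, &is_identity,
      &is_simple_slice, &slice_dim0, &begin, &end, &strides, spec);
}

}  // namespace tensorflow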