[TF2XLA] Support dynamic slice size in strided slice op.

- Add two side outputs to ValidateStridedSliceOp to help analyze dynamic dimensions.
- Correctly set the strided slice op's dynamic dimension size when the slice size (the slice end) is dynamic.
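
For intuition, the lowering this enables can be sketched against the XLA client
API (SliceWithDynamicEnd is a hypothetical helper, not part of this change): to
compute t[:end] where end is only known at run time, slice to the static upper
bound, then mark the dimension's dynamic size as min(size(t), end), which is the
same rule the kernel change below applies.

  #include "tensorflow/compiler/xla/client/xla_builder.h"

  // Hypothetical sketch: lower t[:end] along `dimno` for a runtime `end`.
  // `end` must be an S32 scalar, matching xla::GetDimensionSize's result.
  xla::XlaOp SliceWithDynamicEnd(xla::XlaOp operand, xla::XlaOp end,
                                 int64_t static_bound, int64_t dimno) {
    // Slice up to the static bound; the real size is attached below.
    xla::XlaOp slice = xla::SliceInDim(operand, /*start_index=*/0,
                                       /*limit_index=*/static_bound,
                                       /*stride=*/1, dimno);
    // The result's dynamic size is min(size(t), end).
    xla::XlaOp operand_size = xla::GetDimensionSize(operand, dimno);
    return xla::SetDimensionSize(slice, xla::Min(operand_size, end), dimno);
  }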

PiperOrigin-RevId: 327893466
Change-Id: I7a5061ed1b92006c6f1f15c71661265b5bdfec25
Yunxing Dai, 2020-08-21 16:26:53 -07:00 (committed by TensorFlower Gardener)
parent f8066f9dc4, commit 1d654624e4
4 changed files with 164 additions and 72 deletions

tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
@@ -28,13 +29,26 @@ class BroadcastToOp : public XlaOpKernel {
       : XlaOpKernel(context) {}

   void Compile(XlaOpKernelContext* context) override {
     const TensorShape input_shape = context->InputShape(0);
     TensorShape output_shape;
     OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
-    auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
-    OP_REQUIRES_OK(context, output.status());
-    context->SetOutput(0, output.ValueOrDie());
+    auto output_status_or =
+        BroadcastTo(context->Input(0), output_shape.dim_sizes());
+    OP_REQUIRES_OK(context, output_status_or.status());
+    auto output = output_status_or.ValueOrDie();
+    std::vector<bool> dynamic_dims;
+    OP_REQUIRES_OK(
+        context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
+    for (int64 dim = 0; dim < dynamic_dims.size(); ++dim) {
+      if (dynamic_dims[dim]) {
+        output = xla::SetDimensionSize(
+            output,
+            xla::Reshape(xla::Slice(context->Input(1), {dim}, {dim + 1}, {1}),
+                         {}),
+            dim);
+      }
+    }
+    context->SetOutput(0, output);
   }
 };
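
The xla::Slice + xla::Reshape pair above is the idiom for reading element `dim`
of the rank-1 output-shape tensor as a scalar so it can feed
xla::SetDimensionSize. A minimal standalone sketch (hypothetical helper name,
not part of the change):

  // Extract element `dim` of a rank-1 S32 tensor as a scalar XlaOp.
  xla::XlaOp DynamicDimFromShapeTensor(xla::XlaOp shape_tensor, int64_t dim) {
    xla::XlaOp elem = xla::Slice(shape_tensor, /*start_indices=*/{dim},
                                 /*limit_indices=*/{dim + 1}, /*strides=*/{1});
    return xla::Reshape(elem, /*new_sizes=*/{});  // size-1 vector -> scalar
  }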

tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc

@@ -15,6 +15,9 @@ limitations under the License.
 #include "tensorflow/core/util/strided_slice_op.h"

 #include <vector>

+#include "absl/algorithm/container.h"
+#include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
@@ -23,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/ops_util.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -33,6 +37,7 @@ limitations under the License.
 namespace tensorflow {
 namespace {

+using errors::InvalidArgument;

 class StridedSliceOp : public XlaOpKernel {
  public:
@@ -48,7 +53,7 @@ class StridedSliceOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     const TensorShape input_shape = ctx->InputShape(0);
     const TensorShape begin_shape = ctx->InputShape("begin");
-    VLOG(0) << "strided slice";
     OP_REQUIRES(
         ctx, begin_shape.dims() == 1,
         errors::InvalidArgument("'begin' input has to be a rank 1 vector"));
@@ -78,20 +83,24 @@ class StridedSliceOp : public XlaOpKernel {
     TensorShape final_shape;
     PartialTensorShape dummy_processing_shape, partial_final_shape;
     bool dummy = false;
-    OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
-                            begin_is_constant ? &begin_tensor : nullptr,
-                            end_is_constant ? &end_tensor : nullptr,
-                            strides_tensor, input_shape, begin_mask_, end_mask_,
-                            ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
-                            &dummy_processing_shape, &partial_final_shape,
-                            &dummy, &dummy, &dummy, &begin, &end, &strides));
+    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
+    absl::InlinedVector<int64, 4> output_to_processing_mapping;
+    OP_REQUIRES_OK(
+        ctx,
+        ValidateStridedSliceOp(
+            begin_is_constant ? &begin_tensor : nullptr,
+            end_is_constant ? &end_tensor : nullptr, strides_tensor,
+            input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
+            shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
+            &dummy, &dummy, &dummy, &begin, &end, &strides,
+            &output_to_sparse_mapping, &output_to_processing_mapping));

-    OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape),
-                errors::InvalidArgument(
-                    "XLA can't deduce compile time constant output "
-                    "shape for strided slice: ",
-                    partial_final_shape.DebugString(),
-                    ", output shape must be a compile-time constant"));
+    OP_REQUIRES(
+        ctx, partial_final_shape.AsTensorShape(&final_shape),
+        InvalidArgument("XLA can't deduce compile time constant output "
+                        "shape for strided slice: ",
+                        partial_final_shape.DebugString(),
+                        ", output shape must be a compile-time constant"));

     xla::XlaOp slice = ctx->Input(0);
     if (begin_is_constant && end_is_constant) {
@@ -119,69 +128,84 @@ class StridedSliceOp : public XlaOpKernel {
       auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
       OP_REQUIRES_OK(ctx, operand_shape_or.status());
       xla::Shape xla_shape = operand_shape_or.ValueOrDie();
-      if (xla_shape.is_static()) {
-        // Static output shape, return a static slice.
-        slice = xla::Reshape(slice, final_shape.dim_sizes());
+      std::vector<bool> begins_are_dynamic;
+      OP_REQUIRES_OK(
+          ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
+      std::vector<bool> ends_are_dynamic;
+      OP_REQUIRES_OK(
+          ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
+      bool begins_are_static = absl::c_all_of(
+          begins_are_dynamic, [](bool dynamic) { return !dynamic; });
+      OP_REQUIRES(ctx, begins_are_static,
+                  errors::InvalidArgument(
+                      "XLA can't use dynamic begin values for slice."));
+      bool ends_are_static = absl::c_all_of(
+          ends_are_dynamic, [](bool dynamic) { return !dynamic; });
+      // Static output shape, return a static slice.
+      slice = xla::Reshape(slice, final_shape.dim_sizes());
+      if (xla_shape.is_static() && ends_are_static) {
         ctx->SetOutput(0, slice);
         return;
       }
-      auto input_dim_sizes = input_shape.dim_sizes();
-      for (int64 i = 0; i < xla_shape.rank(); ++i) {
-        if (xla_shape.is_dynamic_dimension(i)) {
-          input_dim_sizes[i] = -1;
-        }
-      }
-      PartialTensorShape input_partial_shape(input_dim_sizes);
-      partial_final_shape.Clear();
-      end.clear();
-      strides.clear();
-      begin.clear();
-      // Run shape inferenference again with partial shape.
-      OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
-                              &begin_tensor, &end_tensor, strides_tensor,
-                              input_partial_shape, begin_mask_, end_mask_,
-                              ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
-                              &dummy_processing_shape, &partial_final_shape,
-                              &dummy, &dummy, &dummy, &begin, &end, &strides));
-      if (partial_final_shape.AsTensorShape(&final_shape)) {
-        // Static output shape, return a static slice.
-        slice = xla::Reshape(slice, final_shape.dim_sizes());
-        ctx->SetOutput(0, slice);
-        return;
-      }
-      // We consider slicing a dynamic tensor t with negative indices as a
-      // dynamic sized slice. E.g., t[: -n], the result length is shape(t) - n
-      for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
-        bool dynamic_dim = partial_final_shape.dim_size(i) - 1;
-        bool backward_slice = end[i] < 0;
-        if (dynamic_dim && backward_slice) {
+      for (int64 i = 0; i < final_shape.dims(); ++i) {
+        int64 input_index = output_to_processing_mapping[i];
+        if (input_index == -1) {
+          continue;
+        }
+        bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);
+        int64 sparse_index = output_to_sparse_mapping[i];
+        bool end_is_dynamic =
+            sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
+        bool backward_slice = sparse_index == -1
+                                  ? false
+                                  : end_literal.Get<int32>({sparse_index}) < 0;
+        if ((input_is_dynamic && backward_slice) || end_is_dynamic) {
           OP_REQUIRES(
-              ctx, strides[i] == 1,
+              ctx, strides[input_index] == 1,
               errors::InvalidArgument("XLA has not implemented dynamic "
                                       "sized slice with non-trival stride yet. "
                                       "Please file a bug against XLA"));
-          OP_REQUIRES(ctx, begin[i] >= 0,
-                      errors::InvalidArgument(
-                          "XLA has not implemented dynamic "
-                          "sized slice with negative begin index %lld. "
-                          "Please file a bug against XLA",
-                          begin[i]));
           // If there is a dynamic dimension, properly set dimension size of
           // the result.
-          auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);
-          operand_size = xla::Add(
-              operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
+          auto operand_size = xla::GetDimensionSize(ctx->Input(0), input_index);
+          if (backward_slice) {
+            // We consider slicing a dynamic tensor t with negative indices as
+            // a dynamic sized slice. E.g., t[: -n], the result length is
+            // shape(t) - n.
+            OP_REQUIRES(ctx, !end_is_dynamic,
+                        errors::InvalidArgument(
+                            "XLA has not implemented dynamic "
+                            "sized slice with dynamic negative index %lld. "));
+            operand_size = xla::Add(
+                operand_size,
+                xla::ConstantR0<int32>(ctx->builder(),
+                                       end_literal.Get<int32>({sparse_index})));
+          } else {
+            // The end of slice with dynamic slice size is the min of operand
+            // shape and slice size. E.g., t[:end_size], result size is
+            // min(shape(t), end_size).
+            xla::XlaOp end_size;
+            if (end_is_dynamic) {
+              end_size = xla::Reshape(xla::Slice(ctx->Input(2), {sparse_index},
+                                                 {sparse_index + 1}, {1}),
+                                      {});
+            } else {
+              end_size =
+                  xla::ConstantR0<int32>(ctx->builder(), end[input_index]);
+            }
+            operand_size = xla::Min(operand_size, end_size);
+          }
           slice = xla::SetDimensionSize(
               slice,
-              xla::Sub(operand_size,
-                       xla::ConstantR0<int32>(ctx->builder(), begin[i])),
+              xla::Sub(operand_size, xla::ConstantR0<int32>(
+                                         ctx->builder(), begin[input_index])),
               i);
         }
       }
       ctx->SetOutput(0, slice);
       return;
     } else {
       // When output shape is fully defined, it must be a size one slice:
       //
@@ -239,9 +263,9 @@ class StridedSliceOp : public XlaOpKernel {
       std::vector<int64> output_shape_dim_sizes;
       slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
-      slice = xla::Reshape(slice, final_shape.dim_sizes());
-      ctx->SetOutput(0, slice);
     }
+    slice = xla::Reshape(slice, final_shape.dim_sizes());
+    ctx->SetOutput(0, slice);
   }

  private:
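
The backward-slice branch above computes the result's dynamic size
arithmetically. As a hedged sketch (hypothetical helper; assumes stride 1 and a
constant, non-negative begin): for t[begin:-n] on a dynamic dimension, the
result size is (size(t) - n) - begin.

  // Hypothetical sketch of the backward-slice size: for t[begin:-n] with
  // stride 1, the dynamic result size is (size(t) + (-n)) - begin.
  xla::XlaOp BackwardSliceSize(xla::XlaBuilder* b, xla::XlaOp operand,
                               int64_t dim, int32_t begin,
                               int32_t negative_end) {
    xla::XlaOp size = xla::GetDimensionSize(operand, dim);  // dynamic size(t)
    size = xla::Add(size, xla::ConstantR0<int32_t>(b, negative_end));
    return xla::Sub(size, xla::ConstantR0<int32_t>(b, begin));
  }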

tensorflow/core/util/strided_slice_op.cc

@@ -59,6 +59,11 @@ struct StridedSliceDenseSpec {
   // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
   // it will be 1. A shrunk dimension is skipped.
   gtl::InlinedVector<int32, 4> final_shape_gather_indices;
+  // This vector has the same size as final_shape_gather_indices, but it
+  // remembers the sparse index that a dimension comes from, instead of the
+  // dense index. A -1 in this vector means the index is not from the sparse
+  // input.
+  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
   // The dense indexed shrink mask is which processing dimensions
   // should be shrunk. For example, if foo.shape = (10,10,10,10)
   // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
@@ -108,9 +113,11 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
         dense->begin_mask |= (1 << full_index);
         dense->end_mask |= (1 << full_index);
         dense->final_shape_gather_indices.push_back(full_index);
+        dense->final_shape_gather_indices_sparse.push_back(-1);
       }
     } else if ((1 << i) & sparse.new_axis_mask) {
       dense->final_shape_gather_indices.push_back(kNewAxis);
+      dense->final_shape_gather_indices_sparse.push_back(-1);
     } else {
       if (full_index == dense->begin.size()) {
         return errors::InvalidArgument("Index out of range using input dim ",
@@ -138,9 +145,13 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
       // axis (now in dense form) so we can ignore dense->end below.
       if (sparse.shrink_axis_mask & (1 << i)) {
         dense->final_shape_gather_indices.push_back(kShrinkAxis);
+        dense->final_shape_gather_indices_sparse.push_back(-1);
         dense->shrink_axis_mask |= (1 << full_index);
       } else {
         dense->final_shape_gather_indices.push_back(full_index);
+        // Remember where in the sparse shape the dense dim comes from.
+        dense->final_shape_gather_indices_sparse.push_back(i);
       }
       full_index++;
     }
@@ -157,7 +168,9 @@ Status ValidateStridedSliceOp(
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides) {
+    gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
   const bool begin_is_wrong =
       begin_tensor != nullptr &&
       !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
@@ -362,11 +375,34 @@ Status ValidateStridedSliceOp(
   // slices like foo[3,...] will reduce dimension by 1.
   // This cannot be done earlier, because it depends on Step 3.
   final_shape->Clear();
-  for (auto gather_index : dense_spec.final_shape_gather_indices) {
+  if (output_to_sparse_mapping != nullptr) {
+    output_to_sparse_mapping->clear();
+  }
+  if (output_to_processing_mapping != nullptr) {
+    output_to_processing_mapping->clear();
+  }
+  for (int64 dense_dim = 0;
+       dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
+    int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
+    int64 sparse_index =
+        dense_spec.final_shape_gather_indices_sparse[dense_dim];
     if (gather_index >= 0) {
       final_shape->AddDim(processing_shape->dim_size(gather_index));
+      if (output_to_sparse_mapping != nullptr) {
+        output_to_sparse_mapping->push_back(sparse_index);
+      }
+      if (output_to_processing_mapping != nullptr) {
+        output_to_processing_mapping->push_back(gather_index);
+      }
     } else if (gather_index == kNewAxis) {
       final_shape->AddDim(1);
+      if (output_to_sparse_mapping != nullptr) {
+        output_to_sparse_mapping->push_back(-1);
+      }
+      if (output_to_processing_mapping != nullptr) {
+        output_to_processing_mapping->push_back(-1);
+      }
     }
   }
   return Status::OK();
@@ -379,14 +415,17 @@ Status ValidateStridedSliceOp(
     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
-    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
+    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
   // Validate with PartialTensorShape output
   PartialTensorShape partial_processing_shape, partial_final_shape;
   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
       begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
       end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
       &partial_processing_shape, &partial_final_shape, is_identity,
-      is_simple_slice, slice_dim0, begin, end, strides));
+      is_simple_slice, slice_dim0, begin, end, strides,
+      output_to_sparse_mapping, output_to_processing_mapping));

   // Verify that the output shapes are fully known
   if (!partial_processing_shape.AsTensorShape(processing_shape) ||

tensorflow/core/util/strided_slice_op.h

@@ -40,6 +40,17 @@ namespace tensorflow {
 // some dimensions of <processing_shape> and/or <final_shape> may be unknown
 // (-1). Any validation that can be done without complete information is
 // performed.
+//
+// This function changes the order of dimensions; output_to_sparse_mapping
+// and output_to_processing_mapping are used to track that reordering.
+//
+// output_to_sparse_mapping[i] is the dim index in the sparse begin/end
+// tensors that output[i] corresponds to. If output_to_sparse_mapping[i]
+// is -1, output dimension i doesn't come from the sparse input.
+//
+// output_to_processing_mapping is similar to output_to_sparse_mapping, but
+// maps output dimensions to dims of processing_shape.
 Status ValidateStridedSliceOp(
     const Tensor* begin_tensor, const Tensor* end_tensor,
     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -48,7 +59,9 @@ Status ValidateStridedSliceOp(
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides);
+    gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
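
For intuition about the two new out-parameters, a hypothetical worked example
(values reasoned from the comment above, not taken from a test in this change).
With an input of shape (10, 10) and the Python-style slice
foo[tf.newaxis, 2:, 3], final_shape is (1, 8) and:

  // output dim 0 (the new axis) maps to -1 in both vectors:
  //   output_to_sparse_mapping     == {-1, 1}
  //   output_to_processing_mapping == {-1, 0}
  // Output dim 1 comes from entry 1 of the sparse begin/end tensors and
  // from dim 0 of processing_shape; the shrunk index (sparse entry 2)
  // produces no output dimension, so it appears in neither vector.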
// Same as above, but the outputs are TensorShape, not PartialTensorShape
Status ValidateStridedSliceOp(
@@ -58,7 +71,9 @@ Status ValidateStridedSliceOp(
     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
-    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides);
+    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
} // namespace tensorflow