[TF2XLA] Support dynamic slice size in strided slice op.

- Add two side outputs in ValidateStridedSliceOp to help analyze dynamic
  dimensions.
- Correctly set the strided slice op's dynamic size when the slice size
  (slice end) is dynamic.

PiperOrigin-RevId: 327893466
Change-Id: I7a5061ed1b92006c6f1f15c71661265b5bdfec25
parent f8066f9dc4
commit 1d654624e4
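The heart of the change, stated once up front: when the end of a slice is only known at run time, the kernel now sizes the output dimension as min(shape(t), end_size) - begin and records that size on the XLA value. Below is a minimal sketch of that computation, not code from this commit; DynamicSizedSlice is a hypothetical helper name, while the xla builder calls (GetDimensionSize, Min, Sub, SetDimensionSize) are the same ones the diff below uses.

    #include <cstdint>
    #include "tensorflow/compiler/xla/client/xla_builder.h"

    // Sketch: mark `dim` of `slice` (taken from `operand`) as dynamic, with
    // runtime size min(operand_size, end_size) - begin.
    xla::XlaOp DynamicSizedSlice(xla::XlaBuilder* builder, xla::XlaOp slice,
                                 xla::XlaOp operand, xla::XlaOp end_size,
                                 int32_t begin, int64_t dim) {
      // Runtime size of the sliced dimension of the (possibly dynamic) operand.
      xla::XlaOp operand_size = xla::GetDimensionSize(operand, dim);
      // A slice can never be longer than the operand itself.
      xla::XlaOp clamped_end = xla::Min(operand_size, end_size);
      return xla::SetDimensionSize(
          slice,
          xla::Sub(clamped_end, xla::ConstantR0<int32_t>(builder, begin)), dim);
    }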
tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -28,13 +29,26 @@ class BroadcastToOp : public XlaOpKernel {
       : XlaOpKernel(context) {}
 
   void Compile(XlaOpKernelContext* context) override {
     const TensorShape input_shape = context->InputShape(0);
     TensorShape output_shape;
     OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+    auto output_status_or =
+        BroadcastTo(context->Input(0), output_shape.dim_sizes());
+    OP_REQUIRES_OK(context, output_status_or.status());
+    auto output = output_status_or.ValueOrDie();
+    std::vector<bool> dynamic_dims;
+    OP_REQUIRES_OK(
+        context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
+    for (int64 dim = 0; dim < dynamic_dims.size(); ++dim) {
+      if (dynamic_dims[dim]) {
+        output = xla::SetDimensionSize(
+            output,
+            xla::Reshape(xla::Slice(context->Input(1), {dim}, {dim + 1}, {1}),
+                         {}),
+            dim);
+      }
+    }
 
-    auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
-    OP_REQUIRES_OK(context, output.status());
-    context->SetOutput(0, output.ValueOrDie());
+    context->SetOutput(0, output);
   }
 };
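A note on the idiom above, with a hedged sketch: the kernel reads each dynamic output size straight out of the rank-1 shape operand by slicing out one element and reshaping it to rank 0, which yields the scalar that xla::SetDimensionSize expects. DimSizeFromShapeOperand is our name, not the kernel's.

    // Sketch: extract element `dim` of a rank-1 S32 shape operand as a scalar.
    xla::XlaOp DimSizeFromShapeOperand(xla::XlaOp shape_operand, int64_t dim) {
      // Slice [dim, dim + 1) with stride 1 -> a length-1 vector.
      xla::XlaOp size_vec = xla::Slice(shape_operand, {dim}, {dim + 1}, {1});
      // Reshape to rank 0 -> the scalar SetDimensionSize takes.
      return xla::Reshape(size_vec, {});
    }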
tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc

@@ -15,6 +15,9 @@ limitations under the License.
 
 #include "tensorflow/core/util/strided_slice_op.h"
 
+#include <vector>
+
+#include "absl/algorithm/container.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
@@ -23,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/ops_util.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -33,6 +37,7 @@ limitations under the License.
 
 namespace tensorflow {
 namespace {
+using errors::InvalidArgument;
 
 class StridedSliceOp : public XlaOpKernel {
  public:
@@ -48,7 +53,7 @@ class StridedSliceOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     const TensorShape input_shape = ctx->InputShape(0);
     const TensorShape begin_shape = ctx->InputShape("begin");
-
+    VLOG(0) << "strided slice";
     OP_REQUIRES(
         ctx, begin_shape.dims() == 1,
         errors::InvalidArgument("'begin' input has to be a rank 1 vector"));
@@ -78,20 +83,24 @@ class StridedSliceOp : public XlaOpKernel {
     TensorShape final_shape;
     PartialTensorShape dummy_processing_shape, partial_final_shape;
     bool dummy = false;
-    OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
-                            begin_is_constant ? &begin_tensor : nullptr,
-                            end_is_constant ? &end_tensor : nullptr,
-                            strides_tensor, input_shape, begin_mask_, end_mask_,
-                            ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
-                            &dummy_processing_shape, &partial_final_shape,
-                            &dummy, &dummy, &dummy, &begin, &end, &strides));
+    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
+    absl::InlinedVector<int64, 4> output_to_processing_mapping;
+    OP_REQUIRES_OK(
+        ctx,
+        ValidateStridedSliceOp(
+            begin_is_constant ? &begin_tensor : nullptr,
+            end_is_constant ? &end_tensor : nullptr, strides_tensor,
+            input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
+            shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
+            &dummy, &dummy, &dummy, &begin, &end, &strides,
+            &output_to_sparse_mapping, &output_to_processing_mapping));
 
-    OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape),
-                errors::InvalidArgument(
-                    "XLA can't deduce compile time constant output "
-                    "shape for strided slice: ",
-                    partial_final_shape.DebugString(),
-                    ", output shape must be a compile-time constant"));
+    OP_REQUIRES(
+        ctx, partial_final_shape.AsTensorShape(&final_shape),
+        InvalidArgument("XLA can't deduce compile time constant output "
+                        "shape for strided slice: ",
+                        partial_final_shape.DebugString(),
+                        ", output shape must be a compile-time constant"));
 
     xla::XlaOp slice = ctx->Input(0);
     if (begin_is_constant && end_is_constant) {
@@ -119,69 +128,84 @@ class StridedSliceOp : public XlaOpKernel {
       auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
       OP_REQUIRES_OK(ctx, operand_shape_or.status());
       xla::Shape xla_shape = operand_shape_or.ValueOrDie();
-      if (xla_shape.is_static()) {
-        // Static output shape, return a static slice.
-        slice = xla::Reshape(slice, final_shape.dim_sizes());
+      std::vector<bool> begins_are_dynamic;
+      OP_REQUIRES_OK(
+          ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
+      std::vector<bool> ends_are_dynamic;
+      OP_REQUIRES_OK(
+          ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
+      bool begins_are_static = absl::c_all_of(
+          begins_are_dynamic, [](bool dynamic) { return !dynamic; });
+      OP_REQUIRES(ctx, begins_are_static,
+                  errors::InvalidArgument(
+                      "XLA can't use dynamic begin values for slice."));
+      bool ends_are_static = absl::c_all_of(
+          ends_are_dynamic, [](bool dynamic) { return !dynamic; });
+      // Static output shape, return a static slice.
+      slice = xla::Reshape(slice, final_shape.dim_sizes());
+      if (xla_shape.is_static() && ends_are_static) {
         ctx->SetOutput(0, slice);
         return;
       }
-      auto input_dim_sizes = input_shape.dim_sizes();
-
-      for (int64 i = 0; i < xla_shape.rank(); ++i) {
-        if (xla_shape.is_dynamic_dimension(i)) {
-          input_dim_sizes[i] = -1;
+
+      for (int64 i = 0; i < final_shape.dims(); ++i) {
+        int64 input_index = output_to_processing_mapping[i];
+        if (input_index == -1) {
+          continue;
         }
-      }
-      PartialTensorShape input_partial_shape(input_dim_sizes);
-      partial_final_shape.Clear();
-      end.clear();
-      strides.clear();
-      begin.clear();
-      // Run shape inferenference again with partial shape.
-      OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
-                              &begin_tensor, &end_tensor, strides_tensor,
-                              input_partial_shape, begin_mask_, end_mask_,
-                              ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
-                              &dummy_processing_shape, &partial_final_shape,
-                              &dummy, &dummy, &dummy, &begin, &end, &strides));
-      if (partial_final_shape.AsTensorShape(&final_shape)) {
-        // Static output shape, return a static slice.
-        slice = xla::Reshape(slice, final_shape.dim_sizes());
-        ctx->SetOutput(0, slice);
-        return;
-      }
-
-      // We consider slicing a dynamic tensor t with negative indices as a
-      // dynamic sized slice. E.g., t[: -n], the result length is shape(t) - n
-      for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
-        bool dynamic_dim = partial_final_shape.dim_size(i) - 1;
-        bool backward_slice = end[i] < 0;
-        if (dynamic_dim && backward_slice) {
+        bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);
+
+        int64 sparse_index = output_to_sparse_mapping[i];
+        bool end_is_dynamic =
+            sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
+        bool backward_slice = sparse_index == -1
+                                  ? false
+                                  : end_literal.Get<int32>({sparse_index}) < 0;
+        if ((input_is_dynamic && backward_slice) || end_is_dynamic) {
           OP_REQUIRES(
-              ctx, strides[i] == 1,
+              ctx, strides[input_index] == 1,
               errors::InvalidArgument("XLA has not implemented dynamic "
                                       "sized slice with non-trival stride yet. "
                                       "Please file a bug against XLA"));
-
-          OP_REQUIRES(ctx, begin[i] >= 0,
-                      errors::InvalidArgument(
-                          "XLA has not implemented dynamic "
-                          "sized slice with negative begin index %lld. "
-                          "Please file a bug against XLA",
-                          begin[i]));
           // If there is a dynamic dimension, properly set dimension size of
           // the result.
-          auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);
-
-          operand_size = xla::Add(
-              operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
+          auto operand_size = xla::GetDimensionSize(ctx->Input(0), input_index);
+          if (backward_slice) {
+            // We consider slicing a dynamic tensor t with negative indices as
+            // a dynamic sized slice. E.g., t[: -n], the result length is
+            // shape(t) - n.
+            OP_REQUIRES(ctx, !end_is_dynamic,
+                        errors::InvalidArgument(
+                            "XLA has not implemented dynamic "
+                            "sized slice with dynamic negative index %lld. "));
+            operand_size = xla::Add(
+                operand_size,
+                xla::ConstantR0<int32>(ctx->builder(),
+                                       end_literal.Get<int32>({sparse_index})));
+          } else {
+            // The end of slice with dynamic slice size is the min of operand
+            // shape and slice size. E.g., t[:end_size], result size is
+            // min(shape(t), end_size).
+            xla::XlaOp end_size;
+            if (end_is_dynamic) {
+              end_size = xla::Reshape(xla::Slice(ctx->Input(2), {sparse_index},
+                                                 {sparse_index + 1}, {1}),
+                                      {});
+            } else {
+              end_size =
+                  xla::ConstantR0<int32>(ctx->builder(), end[input_index]);
+            }
+            operand_size = xla::Min(operand_size, end_size);
+          }
           slice = xla::SetDimensionSize(
               slice,
-              xla::Sub(operand_size,
-                       xla::ConstantR0<int32>(ctx->builder(), begin[i])),
+              xla::Sub(operand_size, xla::ConstantR0<int32>(
+                                         ctx->builder(), begin[input_index])),
               i);
         }
       }
       ctx->SetOutput(0, slice);
       return;
     } else {
       // When output shape is fully defined, it must be a size one slice:
       //
@@ -239,9 +263,9 @@ class StridedSliceOp : public XlaOpKernel {
 
       std::vector<int64> output_shape_dim_sizes;
       slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
-      slice = xla::Reshape(slice, final_shape.dim_sizes());
-      ctx->SetOutput(0, slice);
     }
+    slice = xla::Reshape(slice, final_shape.dim_sizes());
+    ctx->SetOutput(0, slice);
   }
 
  private:
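The loop above distinguishes two dynamic cases, and it is worth restating the arithmetic in one place. For a backward slice t[begin:-n] of a dynamically sized t, the result length is shape(t) - n - begin; for a dynamic end t[begin:end], it is min(shape(t), end) - begin. A compact sketch of that rule, ours rather than library code:

    #include <algorithm>
    #include <cstdint>

    // Sketch: the size the kernel encodes via SetDimensionSize for an output
    // dimension. `end` is negative in the backward case (t[:-n] has end == -n).
    int64_t ExpectedSliceSize(int64_t operand_size, int64_t begin, int64_t end,
                              bool backward_slice) {
      if (backward_slice) {
        return operand_size + end - begin;  // shape(t) - n - begin
      }
      // Dynamic end: the end can never exceed the operand's runtime size.
      return std::min(operand_size, end) - begin;
    }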
tensorflow/core/util/strided_slice_op.cc

@@ -59,6 +59,11 @@ struct StridedSliceDenseSpec {
   // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
   // it will be 1. A shrunk dimension is skipped.
   gtl::InlinedVector<int32, 4> final_shape_gather_indices;
+  // This vector has the same size as final_shape_gather_indices, but it
+  // remembers the sparse index that a dimension comes from, instead of the
+  // dense index. A -1 in this vector means the index is not from the sparse
+  // input.
+  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
   // The dense indexed shrink mask is which processing dimensions
   // should be shrunk. For example, if foo.shape = (10,10,10,10)
   // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
@@ -108,9 +113,11 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
         dense->begin_mask |= (1 << full_index);
         dense->end_mask |= (1 << full_index);
         dense->final_shape_gather_indices.push_back(full_index);
+        dense->final_shape_gather_indices_sparse.push_back(-1);
       }
     } else if ((1 << i) & sparse.new_axis_mask) {
       dense->final_shape_gather_indices.push_back(kNewAxis);
+      dense->final_shape_gather_indices_sparse.push_back(-1);
     } else {
       if (full_index == dense->begin.size()) {
         return errors::InvalidArgument("Index out of range using input dim ",
@@ -138,9 +145,13 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
       // axis (now in dense form) so we can ignore dense->end below.
       if (sparse.shrink_axis_mask & (1 << i)) {
         dense->final_shape_gather_indices.push_back(kShrinkAxis);
+        dense->final_shape_gather_indices_sparse.push_back(-1);
         dense->shrink_axis_mask |= (1 << full_index);
       } else {
         dense->final_shape_gather_indices.push_back(full_index);
+        // Remember where in the sparse shape the dense dim
+        // comes from.
+        dense->final_shape_gather_indices_sparse.push_back(i);
       }
       full_index++;
     }
@@ -157,7 +168,9 @@ Status ValidateStridedSliceOp(
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides) {
+    gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
   const bool begin_is_wrong =
       begin_tensor != nullptr &&
       !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
@@ -362,11 +375,34 @@ Status ValidateStridedSliceOp(
   // slices like foo[3,...] will reduce dimension by 1.
   // This cannot be done earlier, because it depends on Step 3.
   final_shape->Clear();
-  for (auto gather_index : dense_spec.final_shape_gather_indices) {
+  if (output_to_sparse_mapping != nullptr) {
+    output_to_sparse_mapping->clear();
+  }
+
+  if (output_to_processing_mapping != nullptr) {
+    output_to_processing_mapping->clear();
+  }
+  for (int64 dense_dim = 0;
+       dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
+    int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
+    int64 sparse_index =
+        dense_spec.final_shape_gather_indices_sparse[dense_dim];
     if (gather_index >= 0) {
       final_shape->AddDim(processing_shape->dim_size(gather_index));
+      if (output_to_sparse_mapping != nullptr) {
+        output_to_sparse_mapping->push_back(sparse_index);
+      }
+      if (output_to_processing_mapping != nullptr) {
+        output_to_processing_mapping->push_back(gather_index);
+      }
     } else if (gather_index == kNewAxis) {
       final_shape->AddDim(1);
+      if (output_to_sparse_mapping != nullptr) {
+        output_to_sparse_mapping->push_back(-1);
+      }
+      if (output_to_processing_mapping != nullptr) {
+        output_to_processing_mapping->push_back(-1);
+      }
     }
   }
   return Status::OK();
@@ -379,14 +415,17 @@ Status ValidateStridedSliceOp(
     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
-    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
+    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
   // Validate with PartialTensorShape output
   PartialTensorShape partial_processing_shape, partial_final_shape;
   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
       begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
       end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
       &partial_processing_shape, &partial_final_shape, is_identity,
-      is_simple_slice, slice_dim0, begin, end, strides));
+      is_simple_slice, slice_dim0, begin, end, strides,
+      output_to_sparse_mapping, output_to_processing_mapping));
 
   // Verify that the output shapes are fully known
   if (!partial_processing_shape.AsTensorShape(processing_shape) ||
tensorflow/core/util/strided_slice_op.h

@@ -40,6 +40,17 @@ namespace tensorflow {
 // some dimensions of <processing_shape> and/or <final_shape> may be unknown
 // (-1). Any validation that can be done without complete information is
 // performed.
+//
+// This function changes the order of dimensions; output_to_sparse_mapping and
+// output_to_processing_mapping are used to track the order change.
+//
+// output_to_sparse_mapping[i] represents output[i]'s corresponding dim
+// index in the begin_tensor. If
+// output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up in
+// sparse_mapping.
+//
+// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
+// processing_shape.
 Status ValidateStridedSliceOp(
     const Tensor* begin_tensor, const Tensor* end_tensor,
     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -48,7 +59,9 @@ Status ValidateStridedSliceOp(
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides);
+    gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
 
 // Same as above, but the outputs are TensorShape, not PartialTensorShape
 Status ValidateStridedSliceOp(
@@ -58,7 +71,9 @@ Status ValidateStridedSliceOp(
     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
-    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides);
+    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
 
 }  // namespace tensorflow
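To make the two mappings concrete, a small worked example of our own (not from the commit): for foo of shape (10, 10), the slice foo[tf.newaxis, 2:7, :] has three sparse entries and a final shape of (1, 5, 10).

    // Output dim 0 comes from tf.newaxis: no sparse or processing source.
    // Output dim 1 comes from sparse entry 1 (2:7), i.e. processing dim 0.
    // Output dim 2 comes from sparse entry 2 (:), i.e. processing dim 1.
    //
    // output_to_sparse_mapping     = {-1, 1, 2}
    // output_to_processing_mapping = {-1, 0, 1}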