Support dynamic sized slice.

Dynamic sized slice is supported by using input size as upper-bound shape, then call xla::SetSimensionSize on it.

PiperOrigin-RevId: 349487191
Change-Id: If71bcf8715598f3e6c5db62d7fa297a5f8006e91
This commit is contained in:
Yunxing Dai 2020-12-29 16:58:29 -08:00 committed by TensorFlower Gardener
parent 919c8f2738
commit 1c4d31960b

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -56,9 +57,11 @@ class SliceOp : public XlaOpKernel {
std::vector<int64> begin;
std::vector<int64> size;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
std::vector<int64> wrapped_size(size.size());
if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
const bool begin_is_constant =
ctx->ConstantInputAsIntVector(1, &begin).ok();
const bool size_is_constant = ctx->ConstantInputAsIntVector(2, &size).ok();
if (begin_is_constant && size_is_constant) {
std::vector<int64> wrapped_size(size.size());
// `begin` is a compile-time constant.
for (int i = 0; i < input_dims; ++i) {
if (size[i] == -1) {
@ -116,26 +119,58 @@ class SliceOp : public XlaOpKernel {
}
ctx->SetOutput(0, slice);
} else {
// `begin` is not a compile-time constant.
for (int i = 0; i < input_dims; ++i) {
OP_REQUIRES(ctx, 0 <= size[i],
errors::InvalidArgument(
"XLA compilation of Slice operator with negative sizes "
"requires that 'begin' is a compile-time constant."));
OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i),
errors::InvalidArgument("Expected size[", i, "] in [0, ",
input_shape.dim_size(i), "], but ",
"got ", size[i]));
// `begin` or `size` is not a compile-time constant.
if (size_is_constant) {
for (int i = 0; i < input_dims; ++i) {
OP_REQUIRES(
ctx, 0 <= size[i],
errors::InvalidArgument(
"XLA compilation of Slice operator with negative sizes "
"requires that 'begin' is a compile-time constant."));
OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i),
errors::InvalidArgument("Expected size[", i, "] in [0, ",
input_shape.dim_size(i),
"], but ", "got ", size[i]));
}
}
absl::InlinedVector<xla::XlaOp, 4> scalar_indices;
scalar_indices.reserve(input_dims);
xla::XlaOp begin = ctx->Input("begin");
for (int i = 0; i < input_dims; i++)
for (int i = 0; i < input_dims; i++) {
scalar_indices.push_back(
xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {}));
}
if (size_is_constant) {
ctx->SetOutput(0,
xla::DynamicSlice(ctx->Input(0), scalar_indices, size));
} else {
// Size is not constant, use input size as upperbound and then set
// dimension size on it.
ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), scalar_indices, size));
// First pad input with input size to avoid OOB -- dynamic slice with
// OOB slice produces undesired results.
xla::PaddingConfig padding_config;
for (xla::int64 i = 0; i < input_dims; ++i) {
auto* dims = padding_config.add_dimensions();
dims->set_edge_padding_low(0);
dims->set_edge_padding_high(input_shape.dim_size(i));
dims->set_interior_padding(0);
}
auto padded_input = xla::Pad(
ctx->Input(0), xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
padding_config);
// Slice full size out of the input starting from the offsets.
auto sliced = xla::DynamicSlice(padded_input, scalar_indices,
input_shape.dim_sizes());
for (int i = 0; i < input_dims; i++) {
auto dynamic_size =
xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {});
sliced = xla::SetDimensionSize(sliced, dynamic_size, i);
}
ctx->SetOutput(0, sliced);
}
}
}
};