Support dynamic sized slice.
Dynamic sized slice is supported by using input size as upper-bound shape, then call xla::SetSimensionSize on it. PiperOrigin-RevId: 349487191 Change-Id: If71bcf8715598f3e6c5db62d7fa297a5f8006e91
This commit is contained in:
parent
919c8f2738
commit
1c4d31960b
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -56,9 +57,11 @@ class SliceOp : public XlaOpKernel {
|
||||
|
||||
std::vector<int64> begin;
|
||||
std::vector<int64> size;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
|
||||
std::vector<int64> wrapped_size(size.size());
|
||||
if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
|
||||
const bool begin_is_constant =
|
||||
ctx->ConstantInputAsIntVector(1, &begin).ok();
|
||||
const bool size_is_constant = ctx->ConstantInputAsIntVector(2, &size).ok();
|
||||
if (begin_is_constant && size_is_constant) {
|
||||
std::vector<int64> wrapped_size(size.size());
|
||||
// `begin` is a compile-time constant.
|
||||
for (int i = 0; i < input_dims; ++i) {
|
||||
if (size[i] == -1) {
|
||||
@ -116,26 +119,58 @@ class SliceOp : public XlaOpKernel {
|
||||
}
|
||||
ctx->SetOutput(0, slice);
|
||||
} else {
|
||||
// `begin` is not a compile-time constant.
|
||||
for (int i = 0; i < input_dims; ++i) {
|
||||
OP_REQUIRES(ctx, 0 <= size[i],
|
||||
errors::InvalidArgument(
|
||||
"XLA compilation of Slice operator with negative sizes "
|
||||
"requires that 'begin' is a compile-time constant."));
|
||||
OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i),
|
||||
errors::InvalidArgument("Expected size[", i, "] in [0, ",
|
||||
input_shape.dim_size(i), "], but ",
|
||||
"got ", size[i]));
|
||||
// `begin` or `size` is not a compile-time constant.
|
||||
if (size_is_constant) {
|
||||
for (int i = 0; i < input_dims; ++i) {
|
||||
OP_REQUIRES(
|
||||
ctx, 0 <= size[i],
|
||||
errors::InvalidArgument(
|
||||
"XLA compilation of Slice operator with negative sizes "
|
||||
"requires that 'begin' is a compile-time constant."));
|
||||
OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i),
|
||||
errors::InvalidArgument("Expected size[", i, "] in [0, ",
|
||||
input_shape.dim_size(i),
|
||||
"], but ", "got ", size[i]));
|
||||
}
|
||||
}
|
||||
|
||||
absl::InlinedVector<xla::XlaOp, 4> scalar_indices;
|
||||
scalar_indices.reserve(input_dims);
|
||||
xla::XlaOp begin = ctx->Input("begin");
|
||||
for (int i = 0; i < input_dims; i++)
|
||||
for (int i = 0; i < input_dims; i++) {
|
||||
scalar_indices.push_back(
|
||||
xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {}));
|
||||
}
|
||||
if (size_is_constant) {
|
||||
ctx->SetOutput(0,
|
||||
xla::DynamicSlice(ctx->Input(0), scalar_indices, size));
|
||||
} else {
|
||||
// Size is not constant, use input size as upperbound and then set
|
||||
// dimension size on it.
|
||||
|
||||
ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), scalar_indices, size));
|
||||
// First pad input with input size to avoid OOB -- dynamic slice with
|
||||
// OOB slice produces undesired results.
|
||||
xla::PaddingConfig padding_config;
|
||||
for (xla::int64 i = 0; i < input_dims; ++i) {
|
||||
auto* dims = padding_config.add_dimensions();
|
||||
dims->set_edge_padding_low(0);
|
||||
dims->set_edge_padding_high(input_shape.dim_size(i));
|
||||
dims->set_interior_padding(0);
|
||||
}
|
||||
auto padded_input = xla::Pad(
|
||||
ctx->Input(0), xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
|
||||
padding_config);
|
||||
|
||||
// Slice full size out of the input starting from the offsets.
|
||||
auto sliced = xla::DynamicSlice(padded_input, scalar_indices,
|
||||
input_shape.dim_sizes());
|
||||
for (int i = 0; i < input_dims; i++) {
|
||||
auto dynamic_size =
|
||||
xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {});
|
||||
sliced = xla::SetDimensionSize(sliced, dynamic_size, i);
|
||||
}
|
||||
ctx->SetOutput(0, sliced);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user