Support dynamic sized slice.
Dynamic sized slice is supported by using input size as upper-bound shape, then call xla::SetSimensionSize on it. PiperOrigin-RevId: 349487191 Change-Id: If71bcf8715598f3e6c5db62d7fa297a5f8006e91
This commit is contained in:
parent
919c8f2738
commit
1c4d31960b
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
@ -56,9 +57,11 @@ class SliceOp : public XlaOpKernel {
|
|||||||
|
|
||||||
std::vector<int64> begin;
|
std::vector<int64> begin;
|
||||||
std::vector<int64> size;
|
std::vector<int64> size;
|
||||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
|
const bool begin_is_constant =
|
||||||
|
ctx->ConstantInputAsIntVector(1, &begin).ok();
|
||||||
|
const bool size_is_constant = ctx->ConstantInputAsIntVector(2, &size).ok();
|
||||||
|
if (begin_is_constant && size_is_constant) {
|
||||||
std::vector<int64> wrapped_size(size.size());
|
std::vector<int64> wrapped_size(size.size());
|
||||||
if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
|
|
||||||
// `begin` is a compile-time constant.
|
// `begin` is a compile-time constant.
|
||||||
for (int i = 0; i < input_dims; ++i) {
|
for (int i = 0; i < input_dims; ++i) {
|
||||||
if (size[i] == -1) {
|
if (size[i] == -1) {
|
||||||
@ -116,26 +119,58 @@ class SliceOp : public XlaOpKernel {
|
|||||||
}
|
}
|
||||||
ctx->SetOutput(0, slice);
|
ctx->SetOutput(0, slice);
|
||||||
} else {
|
} else {
|
||||||
// `begin` is not a compile-time constant.
|
// `begin` or `size` is not a compile-time constant.
|
||||||
|
if (size_is_constant) {
|
||||||
for (int i = 0; i < input_dims; ++i) {
|
for (int i = 0; i < input_dims; ++i) {
|
||||||
OP_REQUIRES(ctx, 0 <= size[i],
|
OP_REQUIRES(
|
||||||
|
ctx, 0 <= size[i],
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"XLA compilation of Slice operator with negative sizes "
|
"XLA compilation of Slice operator with negative sizes "
|
||||||
"requires that 'begin' is a compile-time constant."));
|
"requires that 'begin' is a compile-time constant."));
|
||||||
OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i),
|
OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i),
|
||||||
errors::InvalidArgument("Expected size[", i, "] in [0, ",
|
errors::InvalidArgument("Expected size[", i, "] in [0, ",
|
||||||
input_shape.dim_size(i), "], but ",
|
input_shape.dim_size(i),
|
||||||
"got ", size[i]));
|
"], but ", "got ", size[i]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::InlinedVector<xla::XlaOp, 4> scalar_indices;
|
absl::InlinedVector<xla::XlaOp, 4> scalar_indices;
|
||||||
scalar_indices.reserve(input_dims);
|
scalar_indices.reserve(input_dims);
|
||||||
xla::XlaOp begin = ctx->Input("begin");
|
xla::XlaOp begin = ctx->Input("begin");
|
||||||
for (int i = 0; i < input_dims; i++)
|
for (int i = 0; i < input_dims; i++) {
|
||||||
scalar_indices.push_back(
|
scalar_indices.push_back(
|
||||||
xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {}));
|
xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {}));
|
||||||
|
}
|
||||||
|
if (size_is_constant) {
|
||||||
|
ctx->SetOutput(0,
|
||||||
|
xla::DynamicSlice(ctx->Input(0), scalar_indices, size));
|
||||||
|
} else {
|
||||||
|
// Size is not constant, use input size as upperbound and then set
|
||||||
|
// dimension size on it.
|
||||||
|
|
||||||
ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), scalar_indices, size));
|
// First pad input with input size to avoid OOB -- dynamic slice with
|
||||||
|
// OOB slice produces undesired results.
|
||||||
|
xla::PaddingConfig padding_config;
|
||||||
|
for (xla::int64 i = 0; i < input_dims; ++i) {
|
||||||
|
auto* dims = padding_config.add_dimensions();
|
||||||
|
dims->set_edge_padding_low(0);
|
||||||
|
dims->set_edge_padding_high(input_shape.dim_size(i));
|
||||||
|
dims->set_interior_padding(0);
|
||||||
|
}
|
||||||
|
auto padded_input = xla::Pad(
|
||||||
|
ctx->Input(0), xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
|
||||||
|
padding_config);
|
||||||
|
|
||||||
|
// Slice full size out of the input starting from the offsets.
|
||||||
|
auto sliced = xla::DynamicSlice(padded_input, scalar_indices,
|
||||||
|
input_shape.dim_sizes());
|
||||||
|
for (int i = 0; i < input_dims; i++) {
|
||||||
|
auto dynamic_size =
|
||||||
|
xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {});
|
||||||
|
sliced = xla::SetDimensionSize(sliced, dynamic_size, i);
|
||||||
|
}
|
||||||
|
ctx->SetOutput(0, sliced);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user