STT-tensorflow/tensorflow/compiler/tf2xla/kernels/slice_op.cc
Smit Hinsu 296993a42c Remove deprecated variants of DynamicSlice and DynamicUpdateSlice builders
Upgraded existing users by converting 1-D start_slices to a list of scalars. I expect this to be performance-neutral, since these tensors are expected to be small. I decided against having the XlaBuilder do this conversion internally, because we presumably want to discourage the use of vector indices.

PiperOrigin-RevId: 311261628
Change-Id: I4b779a58cfca1699bdf5104c236bc6453fd419bc
2020-05-12 21:33:39 -07:00

150 lines
6.0 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// XLA-specific Slice Op.
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
namespace tensorflow {
namespace {
class SliceOp : public XlaOpKernel {
public:
explicit SliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape begin_tensor_shape = ctx->InputShape(1);
const TensorShape size_tensor_shape = ctx->InputShape(2);
const int input_dims = input_shape.dims();
OP_REQUIRES(
ctx,
TensorShapeUtils::IsVector(begin_tensor_shape) &&
TensorShapeUtils::IsVector(size_tensor_shape) &&
begin_tensor_shape.num_elements() == input_dims &&
size_tensor_shape.num_elements() == input_dims,
errors::InvalidArgument(
"Expected begin and size arguments to be 1-D tensors of size ",
input_dims, ", but got shapes ", begin_tensor_shape.DebugString(),
" and ", size_tensor_shape.DebugString(), " instead."));
std::vector<int64> begin;
std::vector<int64> size;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size));
std::vector<int64> wrapped_size(size.size());
if (ctx->ConstantInputAsIntVector(1, &begin).ok()) {
// `begin` is a compile-time constant.
for (int i = 0; i < input_dims; ++i) {
if (size[i] == -1) {
// A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
wrapped_size[i] = input_shape.dim_size(i) - begin[i];
} else {
wrapped_size[i] = size[i];
}
}
for (int i = 0; i < input_dims; ++i) {
int64 b = begin[i];
int64 s = wrapped_size[i];
if (input_shape.dim_size(i) == 0) {
OP_REQUIRES(ctx, b == 0 && s == 0,
errors::InvalidArgument(
"Expected begin[", i, "] == 0 (got ", b,
") and size[", i, "] == 0 ", "(got ", s, ") when ",
"input_shape.dim_size(", i, ") == 0"));
} else {
OP_REQUIRES(ctx, 0 <= b && b <= input_shape.dim_size(i),
errors::InvalidArgument("Expected begin[", i, "] in [0, ",
input_shape.dim_size(i),
"], but got ", b));
OP_REQUIRES(ctx, 0 <= s && b + s <= input_shape.dim_size(i),
errors::InvalidArgument("Expected size[", i, "] in [0, ",
input_shape.dim_size(i) - b,
"], but ", "got ", s));
}
}
std::vector<int64> limits;
limits.reserve(begin.size());
for (int i = 0; i < begin.size(); ++i) {
limits.push_back(begin[i] + wrapped_size[i]);
}
std::vector<int64> strides(begin.size(), 1);
auto slice = xla::Slice(ctx->Input(0), begin, limits, strides);
// Check for slice on dynamic dimensions.
ctx->set_dynamic_dimension_is_minus_one(true);
std::vector<int64> dynamic_size;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &dynamic_size));
for (int64 i = 0; i < size.size(); ++i) {
if (dynamic_size[i] == -1) {
if (size[i] != -1) {
// If there is a dynamic dimension, properly set dimension size of
// the slice.
auto dynamic_size =
xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {});
slice = xla::SetDimensionSize(slice, dynamic_size, i);
}
}
}
ctx->SetOutput(0, slice);
} else {
// `begin` is not a compile-time constant.
for (int i = 0; i < input_dims; ++i) {
OP_REQUIRES(ctx, 0 <= size[i],
errors::InvalidArgument(
"XLA compilation of Slice operator with negative sizes "
"requires that 'begin' is a compile-time constant."));
OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i),
errors::InvalidArgument("Expected size[", i, "] in [0, ",
input_shape.dim_size(i), "], but ",
"got ", size[i]));
}
absl::InlinedVector<xla::XlaOp, 4> scalar_indices;
scalar_indices.reserve(input_dims);
xla::XlaOp begin = ctx->Input("begin");
for (int i = 0; i < input_dims; i++)
scalar_indices.push_back(
xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {}));
ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), scalar_indices, size));
}
}
};
// Registers SliceOp as the XLA kernel for the TensorFlow "Slice" op.
// `begin` and `size` are both declared as compile-time-constant inputs.
// SliceOp::Compile nevertheless falls back to an xla::DynamicSlice when
// `begin` cannot be read as a constant. `size` is always required to be
// constant.
REGISTER_XLA_OP(Name("Slice")
.CompileTimeConstantInput("begin")
.CompileTimeConstantInput("size"),
SliceOp);
} // namespace
} // namespace tensorflow