Previously we didn't allow a dynamic dimension to change in an HLO while loop. But this constraint breaks tensor lists, where the true dynamic dimension is only known inside the loop body. This CL: - Adds the feature in dynamic padder to be able to change a dynamic dimension's size in the loop. - Adds a nice test to demonstrate how tensor list / stack can be handled more elegantly in XLA. - Adds the necessary machinery to wire this feature into tf2xla. PiperOrigin-RevId: 307901191 Change-Id: I4d39f1d8a8c944f1e9834c39599e6cfbc99f6807
339 lines
12 KiB
C++
339 lines
12 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
// XLA-specific Shape Ops.
|
|
|
|
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
|
|
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
|
#include "tensorflow/core/framework/bounds_check.h"
|
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
|
#include "tensorflow/core/framework/op_kernel.h"
|
|
#include "tensorflow/core/framework/tensor_shape.h"
|
|
|
|
namespace tensorflow {
|
|
namespace {
|
|
|
|
class ShapeOp : public XlaOpKernel {
|
|
public:
|
|
explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
|
|
}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
const TensorShape input_shape = ctx->InputShape(0);
|
|
std::vector<xla::XlaOp> operands;
|
|
const int rank = input_shape.dims();
|
|
if (rank != 0) {
|
|
for (int64 i = 0; i < rank; ++i) {
|
|
operands.push_back(xla::Broadcast(
|
|
xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i),
|
|
ctx->output_xla_type(0)),
|
|
{1}));
|
|
}
|
|
|
|
ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), operands, 0));
|
|
} else {
|
|
// Rank 0 won't have dynamic size dimension, use constant output.
|
|
Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
|
|
OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
|
|
ctx->SetConstantOutput(0, shape_constant);
|
|
}
|
|
}
|
|
|
|
private:
|
|
DataType out_dtype_;
|
|
};
|
|
|
|
REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
|
|
|
|
class ShapeNOp : public XlaOpKernel {
|
|
public:
|
|
explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
|
|
}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
|
const TensorShape input_shape = ctx->InputShape(i);
|
|
std::vector<xla::XlaOp> operands;
|
|
|
|
const int rank = input_shape.dims();
|
|
if (rank != 0) {
|
|
// Each dimension can be dynamic, so use GetDimensionSize to get the
|
|
// runtime dimension.
|
|
for (int64 dim = 0; dim < rank; ++dim) {
|
|
operands.push_back(xla::Broadcast(
|
|
xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(i), dim),
|
|
ctx->output_xla_type(i)),
|
|
{1}));
|
|
}
|
|
|
|
ctx->SetOutput(i, xla::ConcatInDim(ctx->builder(), operands, 0));
|
|
} else {
|
|
// Rank 0 won't have dynamic size dimension, use constant output.
|
|
Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
|
|
OP_REQUIRES_OK(ctx,
|
|
TensorShapeToConstant(input_shape, &shape_constant));
|
|
ctx->SetConstantOutput(i, shape_constant);
|
|
}
|
|
}
|
|
}
|
|
|
|
bool IsExpensive() override { return false; }
|
|
|
|
private:
|
|
DataType out_dtype_;
|
|
};
|
|
REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp);
|
|
|
|
class RankOp : public XlaOpKernel {
|
|
public:
|
|
explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
const TensorShape input_shape = ctx->InputShape(0);
|
|
const int rank = input_shape.dims();
|
|
Tensor rank_constant(DT_INT32, TensorShape({}));
|
|
rank_constant.scalar<int32>()() = rank;
|
|
|
|
ctx->SetConstantOutput(0, rank_constant);
|
|
}
|
|
};
|
|
|
|
REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp);
|
|
|
|
class SizeOp : public XlaOpKernel {
|
|
public:
|
|
explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
const TensorShape input_shape = ctx->InputShape(0);
|
|
OP_REQUIRES(ctx,
|
|
FastBoundsCheck(input_shape.num_elements(),
|
|
std::numeric_limits<int32>::max()),
|
|
errors::InvalidArgument("Size does not work for tensors > "
|
|
"int32 max."));
|
|
Tensor size_constant(DT_INT32, TensorShape({}));
|
|
const int rank = input_shape.dims();
|
|
xla::XlaBuilder* builder = ctx->builder();
|
|
auto size = xla::One(builder, xla::U32);
|
|
for (int64 i = 0; i < rank; ++i) {
|
|
size = xla::Mul(
|
|
size, xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i),
|
|
xla::U32));
|
|
}
|
|
size = xla::ConvertElementType(size, ctx->output_xla_type(0));
|
|
ctx->SetOutput(0, size);
|
|
}
|
|
};
|
|
|
|
REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp);
|
|
|
|
class ExpandDimsOp : public XlaOpKernel {
|
|
public:
|
|
explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
const TensorShape input_shape = ctx->InputShape("input");
|
|
const TensorShape dim_shape = ctx->InputShape("dim");
|
|
|
|
std::vector<int64> dims;
|
|
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims));
|
|
OP_REQUIRES(ctx, dims.size() == 1,
|
|
errors::InvalidArgument(absl::StrCat(
|
|
"dim input to ExpandDims must be a scalar; got ",
|
|
dim_shape.DebugString())));
|
|
int dim = dims[0];
|
|
|
|
OP_REQUIRES(ctx,
|
|
(dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
|
|
errors::InvalidArgument("Tried to expand dim index ", dim,
|
|
" for tensor with ", input_shape.dims(),
|
|
" dimensions."));
|
|
|
|
auto existing_dims = input_shape.dim_sizes();
|
|
// Safe - # elements in tensor dims bounded.
|
|
const int existing_dims_size = static_cast<int>(existing_dims.size());
|
|
std::vector<int64> new_shape(existing_dims_size);
|
|
for (size_t i = 0; i < new_shape.size(); ++i) {
|
|
new_shape[i] = existing_dims[i];
|
|
}
|
|
|
|
// We emulate numpy's interpretation of the dim axis when
|
|
// -input.dims() >= dim <= input.dims().
|
|
if (dim < 0) {
|
|
dim += existing_dims.size() + 1;
|
|
}
|
|
|
|
// Clamp to the end if needed.
|
|
dim = std::min<int32>(dim, existing_dims_size);
|
|
new_shape.emplace(new_shape.begin() + dim, 1);
|
|
|
|
ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape));
|
|
}
|
|
};
|
|
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"),
|
|
ExpandDimsOp);
|
|
|
|
class SqueezeOp : public XlaOpKernel {
|
|
public:
|
|
explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
|
std::vector<int32> squeeze_dims;
|
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
|
|
squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
|
|
}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
const TensorShape input_shape = ctx->InputShape(0);
|
|
auto existing_dims = input_shape.dim_sizes();
|
|
int existing_dims_size = input_shape.dims();
|
|
std::vector<int64> new_shape;
|
|
|
|
std::unordered_set<int32> wrapped_squeeze_dims;
|
|
wrapped_squeeze_dims.reserve(squeeze_dims_.size());
|
|
// Validate squeeze dims against the input.
|
|
for (int32 dim : squeeze_dims_) {
|
|
OP_REQUIRES(ctx, (dim >= -input_shape.dims() && dim < input_shape.dims()),
|
|
errors::InvalidArgument("Tried to squeeze dim index ", dim,
|
|
" for tensor with ",
|
|
input_shape.dims(), " dimensions."));
|
|
// If dim is < 0, we wrap around (-1 means the last element).
|
|
if (dim < 0) {
|
|
dim = existing_dims_size + dim;
|
|
}
|
|
|
|
wrapped_squeeze_dims.insert(dim);
|
|
}
|
|
|
|
for (int i = 0; i < existing_dims_size; ++i) {
|
|
auto existing_dim = existing_dims[i];
|
|
|
|
// If squeeze_set is non-empty, only squeeze those dimensions.
|
|
if (!wrapped_squeeze_dims.empty()) {
|
|
if (wrapped_squeeze_dims.count(i) > 0) {
|
|
OP_REQUIRES(ctx, existing_dim == 1,
|
|
errors::InvalidArgument(
|
|
"Tried to explicitly squeeze dimension ", i,
|
|
" but dimension was not 1: ", existing_dim));
|
|
} else {
|
|
// This dimension is not being squeezed.
|
|
new_shape.push_back(existing_dim);
|
|
}
|
|
} else {
|
|
// Copy over all non-1-length dimensions.
|
|
if (existing_dim != 1) {
|
|
new_shape.push_back(existing_dim);
|
|
}
|
|
}
|
|
}
|
|
|
|
ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
|
|
}
|
|
|
|
private:
|
|
std::unordered_set<int32> squeeze_dims_;
|
|
};
|
|
|
|
REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp);
|
|
|
|
class ZerosLikeOp : public XlaOpKernel {
 public:
  explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  // Emits a zero-filled value with the same shape as input 0. Two cases:
  //  1. TensorList (variant) input: build a zero-filled list with the same
  //     static shape and the same runtime dynamic dimension sizes.
  //  2. Dense tensor input: broadcast a zero scalar and re-attach the
  //     input's dynamic dimension sizes.
  void Compile(XlaOpKernelContext* ctx) override {
    if (IsTensorListInput(ctx, 0)) {
      // Input is a TensorList.

      // Check the TensorList input is initialized.
      xla::XlaOp list = ctx->Input(0);
      bool is_initialized;
      OP_REQUIRES_OK(ctx, IsTensorListInitialized(list, &is_initialized));
      OP_REQUIRES(
          ctx, is_initialized,
          errors::InvalidArgument(
              "TensorList input for ZerosLike op is an uninitialized list"));

      auto list_shape_or = ctx->builder()->GetShape(list);
      OP_REQUIRES_OK(ctx, list_shape_or.status());
      const xla::Shape& list_shape = list_shape_or.ValueOrDie();
      // Collect the runtime dynamic dimension sizes for every buffer in the
      // list tuple. The loop iterates over tuple_shapes_size() - 1 elements;
      // the last tuple element is skipped (it is handled separately below via
      // GetTensorListPushIndex, so presumably it is the push index rather
      // than a buffer — confirm against tensor_list_utils).
      std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
      list_dynamic_dims.reserve(list_shape.tuple_shapes_size() - 1);
      for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
        // Set dynamic dimension size to 0 for initialization value.
        std::vector<xla::XlaOp> dynamic_dims;
        const xla::Shape& shape = list_shape.tuple_shapes(i);
        auto sub_element = xla::GetTupleElement(list, i);
        for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
          dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
        }
        list_dynamic_dims.push_back(dynamic_dims);
      }
      // Create the all-zeros list with identical static shape and the
      // dynamic dimension sizes gathered above.
      xla::XlaOp new_list;
      OP_REQUIRES_OK(
          ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
                                              list_dynamic_dims, &new_list));

      // Carry the input list's push index over to the result so the output
      // list is structurally interchangeable with the input.
      xla::XlaOp push_index;
      OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &push_index));

      xla::XlaOp result;
      OP_REQUIRES_OK(ctx,
                     SetTensorListPushIndex(new_list, push_index, &result));
      ctx->SetTensorListOutput(0, result);
    } else {
      auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
      xla::XlaOp input = ctx->Input(0);
      auto input_shape = ctx->InputXlaShape(0).ValueOrDie();
      // Broadcast over the static (upper-bound) dimensions first.
      auto result = xla::Broadcast(zero, input_shape.dimensions());

      // Setting up dynamic dimensions of the broadcast: copy each dynamic
      // dimension's runtime size from the input onto the result.
      for (int64 i = 0; i < input_shape.dimensions_size(); ++i) {
        if (input_shape.is_dynamic_dimension(i)) {
          xla::XlaOp input_dynamic_dim = xla::GetDimensionSize(input, i);
          result = xla::SetDimensionSize(result, input_dynamic_dim, i);
        }
      }

      ctx->SetOutput(0, result);
    }
  }
};

REGISTER_XLA_OP(Name("ZerosLike").AllowVariantTypes(), ZerosLikeOp);
|
|
|
|
class OnesLikeOp : public XlaOpKernel {
|
|
public:
|
|
explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
|
|
|
void Compile(XlaOpKernelContext* ctx) override {
|
|
const TensorShape input_shape = ctx->InputShape(0);
|
|
|
|
auto one = XlaHelpers::One(ctx->builder(), input_type(0));
|
|
ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes()));
|
|
}
|
|
};
|
|
|
|
REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);
|
|
|
|
} // namespace
|
|
} // namespace tensorflow
|