[TF:XLA] Fix more int32/int64 literal bugs.

Make XlaOpKernelContext::ConstantInputReshaped private and migrate its callers
to other methods, since its use cases can be handled in other ways. Move the
code that copies a tensor to a literal into the common literal_util module and
simplify it.

PiperOrigin-RevId: 220502409

commit ee054e9826
parent 6ca494f8f8

Changes under tensorflow/compiler/tf2xla:
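The kernel changes below all follow the same migration: positional input lookups and hand-rolled literal-to-vector copies are replaced by name-based lookups and the int64-returning ConstantInput* helpers. A minimal sketch of the resulting pattern, assuming a hypothetical op (the op name "HypotheticalFillLike", its input names, and the registration line are illustrative, not part of this commit):

#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace {

// Illustrative kernel showing the post-change pattern: inputs are addressed
// by name, and compile-time-constant integer inputs are read directly into
// int64 vectors instead of going through ConstantInputReshaped.
class HypotheticalFillLikeOp : public XlaOpKernel {
 public:
  explicit HypotheticalFillLikeOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // "dims" must be a compile-time constant; it comes back as int64s
    // whether the input tensor was int32 or int64.
    std::vector<int64> dims;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dims", &dims));

    // The value input is looked up by name rather than by position.
    xla::XlaOp value = ctx->Input("value");
    ctx->SetOutput(0, xla::Broadcast(value, dims));
  }
};

// Registration shown only to illustrate marking "dims" as a compile-time
// constant; "HypotheticalFillLike" is not a real op.
REGISTER_XLA_OP(Name("HypotheticalFillLike").CompileTimeConstantInput("dims"),
                HypotheticalFillLikeOp);

}  // namespace
}  // namespace tensorflow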
@@ -33,8 +33,8 @@ class FillOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     // The output of this Op is a tensor of shape 'dims_shape' with each
     // element set to the scalar 'dims_literal'.
-    const TensorShape dims_shape = ctx->InputShape(0);
-    const TensorShape value_shape = ctx->InputShape(1);
+    const TensorShape dims_shape = ctx->InputShape("dims");
+    const TensorShape value_shape = ctx->InputShape("value");
     OP_REQUIRES(
         ctx, IsLegacyVector(dims_shape),
         errors::InvalidArgument("dims must be a vector of int32, got shape ",
@@ -42,29 +42,22 @@ class FillOp : public XlaOpKernel {
     OP_REQUIRES(ctx, IsLegacyScalar(value_shape),
                 errors::InvalidArgument("value must be a scalar, got shape ",
                                         value_shape.DebugString()));
 
-    // Evaluate the 'dims' constant input, reshaping to a vector if it
-    // was a 'legacy' vector (secretly a scalar).
-    xla::Literal dims_literal;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(
-                            0, {dims_shape.num_elements()}, &dims_literal));
+    std::vector<int64> dims;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dims", &dims));
 
-    // Convert the dims literal into a vector that we can pass to
-    // XlaBuilder.
-    std::vector<int64> broadcast;
-    broadcast.reserve(dims_literal.shape().dimensions(0));
-    for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
-      broadcast.push_back(dims_literal.Get<int>({i}));
-    }
     // Look up the value input, reshaping to a scalar if it was a
     // 'legacy' scalar (secretly a vector).
-    xla::XlaOp data = ctx->Input(1);
+    xla::XlaOp data = ctx->Input("value");
     if (value_shape.dims() > 0) {
       CHECK_EQ(value_shape.dims(), 1);
       data = xla::Reshape(data, {});
     }
     // Emit the actual computation, which broadcasts the scalar to the
     // desired shape.
-    auto result = xla::Broadcast(data, broadcast);
+    auto result = xla::Broadcast(data, dims);
 
     ctx->SetOutput(0, result);
   }
@@ -65,8 +65,8 @@ class MirrorPadOp : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    const TensorShape input_shape = ctx->InputShape(0);
-    const TensorShape pad_shape = ctx->InputShape(1);
+    const TensorShape input_shape = ctx->InputShape("input");
+    const TensorShape pad_shape = ctx->InputShape("paddings");
 
     MirrorPadMode mode;
     OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode));
@@ -93,11 +93,10 @@ class MirrorPadOp : public XlaOpKernel {
 
     // Evaluate the 'padding' constant input, reshaping to a matrix.
     xla::Literal pad_literal;
-    OP_REQUIRES_OK(
-        ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput("paddings", &pad_literal));
 
     xla::XlaBuilder* b = ctx->builder();
-    auto in0 = ctx->Input(0);
+    auto in0 = ctx->Input("input");
     xla::StatusOr<xla::Shape> in0_shape = b->GetShape(in0);
     OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
     xla::StatusOr<xla::XlaOp> accum_status =
@@ -51,14 +51,11 @@ class ReverseOp : public XlaOpKernel {
     }
     // XlaBuilder::Rev() requires concrete values for dimensions arg.
     xla::Literal lax;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax));
-    std::vector<bool> revdims(x_shape.dims());
-    std::copy(lax.data<bool>().begin(), lax.data<bool>().end(),
-              revdims.begin());
-    std::vector<int64> dimensions;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &lax));
+
+    std::vector<int64> dimensions;
 
     for (int d = 0; d < x_shape.dims(); ++d) {
-      if (revdims[d]) {
+      if (lax.Get<bool>({d})) {
         dimensions.push_back(d);
       }
     }
@@ -108,21 +108,20 @@ class ExpandDimsOp : public XlaOpKernel {
   explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
 
   void Compile(XlaOpKernelContext* ctx) override {
-    const TensorShape input_shape = ctx->InputShape(0);
-    const TensorShape dim_shape = ctx->InputShape(1);
+    const TensorShape input_shape = ctx->InputShape("input");
+    const TensorShape dim_shape = ctx->InputShape("dim");
 
+    std::vector<int64> dims;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims));
     // TODO(phawkins): the standard implementation of ExpandDimsOp seems to
     // accept legacy scalars, even when they should be forbidden by the graphdef
     // version.
-    OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
+    OP_REQUIRES(ctx, dims.size() == 1,
                 errors::InvalidArgument(absl::StrCat(
                     "dim input to ExpandDims must be a scalar; got ",
                     dim_shape.DebugString())));
-
-    xla::Literal literal;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));
-
-    int dim = literal.data<int32>()[0];
+    int dim = dims[0];
 
     OP_REQUIRES(ctx,
                 (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
@@ -148,7 +147,7 @@ class ExpandDimsOp : public XlaOpKernel {
     dim = std::min<int32>(dim, existing_dims_size);
     new_shape.emplace(new_shape.begin() + dim, 1);
 
-    ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
+    ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape));
   }
 };
 REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"),
@@ -37,8 +37,8 @@ class TransposeOp : public XlaOpKernel {
       : XlaOpKernel(ctx), conjugate_(conjugate) {}
 
   void Compile(XlaOpKernelContext* ctx) override {
-    const TensorShape input_shape = ctx->InputShape(0);
-    const TensorShape perm_tensor_shape = ctx->InputShape(1);
+    const TensorShape input_shape = ctx->InputShape("x");
+    const TensorShape perm_tensor_shape = ctx->InputShape("perm");
 
     // Preliminary validation of sizes.
     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
@@ -52,19 +52,15 @@ class TransposeOp : public XlaOpKernel {
                                 ". But input(1) is a vector of size ",
                                 perm_tensor_shape.num_elements()));
 
-    xla::Literal literal;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal));
-
-    std::vector<int32> perm(dims);
-    std::copy(literal.data<int32>().begin(), literal.data<int32>().end(),
-              perm.begin());
+    std::vector<int64> perm;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm));
 
     std::vector<int64> transposed_order;
     // Check whether permutation is a permutation of integers of [0 .. dims).
     absl::InlinedVector<bool, 8> bits(dims);
     bool is_identity = true;
     for (int i = 0; i < dims; ++i) {
-      const int32 d = perm[i];
+      const int64 d = perm[i];
       OP_REQUIRES(
           ctx, 0 <= d && d < dims,
           errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
@@ -83,9 +79,9 @@ class TransposeOp : public XlaOpKernel {
     xla::XlaOp transposed;
     // 0-D, 1-D, and identity transposes do nothing.
     if (dims <= 1 || is_identity) {
-      transposed = ctx->Input(0);
+      transposed = ctx->Input("x");
     } else {
-      transposed = xla::Transpose(ctx->Input(0), transposed_order);
+      transposed = xla::Transpose(ctx->Input("x"), transposed_order);
     }
 
     // Conjugate the transposed result if this is ConjugateTransposeOp.
@@ -32,6 +32,12 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
   return Status::OK();
 }
 
+xla::StatusOr<xla::Literal> HostTensorToLiteral(const Tensor& host_tensor) {
+  xla::BorrowingLiteral literal;
+  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(host_tensor, &literal));
+  return literal.Clone();
+}
+
 Status HostTensorToMutableBorrowingLiteral(
     Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) {
   xla::Shape xla_shape;
@@ -30,6 +30,11 @@ namespace tensorflow {
 // 'host_tensor'.
 Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
                                     xla::BorrowingLiteral* literal);
 
+// Returns a Literal with the contents of 'host_tensor', backed by its own
+// storage (i.e., not reusing 'host_tensor's buffers.)
+xla::StatusOr<xla::Literal> HostTensorToLiteral(const Tensor& host_tensor);
+
 // Returns a MutableBorrowingLiteral that utilizes the same underlying buffer
 // owned by 'host_tensor', but is mutable via the xla::Literal methods.
 Status HostTensorToMutableBorrowingLiteral(
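As a usage note for the helper declared above: unlike HostTensorToBorrowingLiteral, the returned literal owns its storage. A minimal sketch, assuming a caller that already holds a host-resident Tensor (the tensor contents and the function name are illustrative, not from this commit):

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Copies a small host tensor into an owning xla::Literal using the new
// helper. Returns a non-OK status if the conversion fails.
Status OwningLiteralExample() {
  Tensor host(DT_INT64, TensorShape({2}));
  host.vec<int64>()(0) = 3;
  host.vec<int64>()(1) = 7;

  xla::StatusOr<xla::Literal> literal_or = HostTensorToLiteral(host);
  if (!literal_or.ok()) return literal_or.status();
  xla::Literal literal = literal_or.ConsumeValueOrDie();

  // The literal owns its buffer, so it stays valid even if 'host' is
  // destroyed or mutated afterwards.
  CHECK_EQ(literal.Get<int64>({0}), 3);
  return Status::OK();
}

}  // namespace tensorflow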
@@ -136,25 +136,6 @@ Status XlaOpKernelContext::ConstantInputReshaped(
   }
   const XlaExpression* expression = CastExpressionFromTensor(tensor);
 
-  auto copy_tensor_to_literal = [](const Tensor& tensor,
-                                   xla::Literal* literal) {
-    xla::Shape literal_shape;
-    TF_RETURN_IF_ERROR(
-        TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape));
-
-    *literal = xla::Literal(literal_shape);
-
-    // memcpy over the payload ...
-    // TODO(phawkins): handle string types.
-    size_t total_bytes = tensor.TotalBytes();
-    if (total_bytes > 0) {
-      void* dst_ptr = literal->untyped_data();
-      const void* src_ptr = DMAHelper::base(&tensor);
-      memcpy(dst_ptr, src_ptr, total_bytes);
-    }
-    return Status::OK();
-  };
-
   // If the tensor has a known constant value, there is no need to invoke XLA.
   if (expression->has_constant_value()) {
     Tensor temp(tensor.dtype());
@@ -164,14 +145,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
       return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
     }
 
-    return copy_tensor_to_literal(temp, constant_literal);
+    TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
+    return Status::OK();
   }
 
   // Make sure we treat zero-element tensors as constant.
   if (new_shape.num_elements() == 0) {
     Tensor temp(tensor.dtype(), new_shape);
 
-    return copy_tensor_to_literal(temp, constant_literal);
+    TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
+    return Status::OK();
   }
 
   xla::XlaOp handle = expression->handle();
@@ -322,6 +304,15 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
   return LiteralToInt64Vector(literal, out);
 }
 
+Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
+    absl::string_view name, std::vector<int64>* out) {
+  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+  xla::Literal literal;
+  TF_RETURN_IF_ERROR(ConstantInputReshaped(
+      index, {InputShape(index).num_elements()}, &literal));
+  return LiteralToInt64Vector(literal, out);
+}
+
 Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
                                                        xla::Literal* out) {
   xla::Literal literal;
@@ -111,14 +111,6 @@ class XlaOpKernelContext {
   Status ConstantInput(int index, xla::Literal* constant_literal);
   Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);
 
-  // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
-  // InputShape(index), and stores it in `*constant_literal`. If the input
-  // cannot be evaluated, e.g., because it depends on unbound parameters,
-  // returns a non-Ok status. If InputShape(index).num_elements() !=
-  // new_shape.num_elements(), returns an error status.
-  Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
-                               xla::Literal* constant_literal);
-
   // Converts a constant scalar int32 or int64 tensor into an int64.
   Status ConstantInputAsIntScalar(int index, int64* out);
   Status ConstantInputAsIntScalar(absl::string_view name, int64* out);
@@ -134,6 +126,8 @@ class XlaOpKernelContext {
   // Reshapes and converts a constant int32 or int64 tensor into a vector of
   // int64s.
   Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
+  Status ConstantInputReshapedToIntVector(absl::string_view name,
+                                          std::vector<int64>* out);
 
   // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
   Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
@@ -260,6 +254,14 @@ class XlaOpKernelContext {
   // type to allow mapping for variant to more generic types.
   Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
 
+  // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
+  // InputShape(index), and stores it in `*constant_literal`. If the input
+  // cannot be evaluated, e.g., because it depends on unbound parameters,
+  // returns a non-Ok status. If InputShape(index).num_elements() !=
+  // new_shape.num_elements(), returns an error status.
+  Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
+                               xla::Literal* constant_literal);
+
   OpKernelContext* const context_;
 };
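With ConstantInputReshaped now private, code outside the class goes through the typed accessors documented above, which accept either int32 or int64 constant tensors and hand back int64 values. A hedged sketch of that surface (the input names "axis" and "sizes" are illustrative, not from this commit):

#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Illustrative only: the public accessors a kernel uses now that
// ConstantInputReshaped is private.
void ReadConstantInputs(XlaOpKernelContext* ctx) {
  // Accepts an int32 or int64 scalar constant; the value is widened to int64.
  int64 axis;
  OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("axis", &axis));

  // Likewise for vectors: elements are returned as int64 regardless of the
  // input tensor's integer dtype.
  std::vector<int64> sizes;
  OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("sizes", &sizes));
}

}  // namespace tensorflow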