[TF:XLA] Fix more int32/int64 literal bugs.

Make XlaOpKernelContext::ConstantInputReshaped private, and change its users to use other methods, since its use cases can be handled other ways.

Move code to copy a tensor to a literal into common literal_util module and simplify it.

PiperOrigin-RevId: 220502409
This commit is contained in:
Peter Hawkins 2018-11-07 11:47:00 -08:00 committed by TensorFlower Gardener
parent 6ca494f8f8
commit ee054e9826
9 changed files with 62 additions and 74 deletions

View File

@ -33,8 +33,8 @@ class FillOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
// The output of this Op is a tensor of shape 'dims_shape' with each
// element set to the scalar 'dims_literal'.
const TensorShape dims_shape = ctx->InputShape(0);
const TensorShape value_shape = ctx->InputShape(1);
const TensorShape dims_shape = ctx->InputShape("dims");
const TensorShape value_shape = ctx->InputShape("value");
OP_REQUIRES(
ctx, IsLegacyVector(dims_shape),
errors::InvalidArgument("dims must be a vector of int32, got shape ",
@ -42,29 +42,22 @@ class FillOp : public XlaOpKernel {
OP_REQUIRES(ctx, IsLegacyScalar(value_shape),
errors::InvalidArgument("value must be a scalar, got shape ",
value_shape.DebugString()));
// Evaluate the 'dims' constant input, reshaping to a vector if it
// was a 'legacy' vector (secretly a scalar).
xla::Literal dims_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(
0, {dims_shape.num_elements()}, &dims_literal));
std::vector<int64> dims;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dims", &dims));
// Convert the dims literal into a vector that we can pass to
// XlaBuilder.
std::vector<int64> broadcast;
broadcast.reserve(dims_literal.shape().dimensions(0));
for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
broadcast.push_back(dims_literal.Get<int>({i}));
}
// Look up the value input, reshaping to a scalar if it was a
// 'legacy' scalar (secretly a vector).
xla::XlaOp data = ctx->Input(1);
xla::XlaOp data = ctx->Input("value");
if (value_shape.dims() > 0) {
CHECK_EQ(value_shape.dims(), 1);
data = xla::Reshape(data, {});
}
// Emit the actual computation, which broadcasts the scalar to the
// desired shape.
auto result = xla::Broadcast(data, broadcast);
auto result = xla::Broadcast(data, dims);
ctx->SetOutput(0, result);
}

View File

@ -65,8 +65,8 @@ class MirrorPadOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape pad_shape = ctx->InputShape(1);
const TensorShape input_shape = ctx->InputShape("input");
const TensorShape pad_shape = ctx->InputShape("paddings");
MirrorPadMode mode;
OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode));
@ -93,11 +93,10 @@ class MirrorPadOp : public XlaOpKernel {
// Evaluate the 'padding' constant input, reshaping to a matrix.
xla::Literal pad_literal;
OP_REQUIRES_OK(
ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));
OP_REQUIRES_OK(ctx, ctx->ConstantInput("paddings", &pad_literal));
xla::XlaBuilder* b = ctx->builder();
auto in0 = ctx->Input(0);
auto in0 = ctx->Input("input");
xla::StatusOr<xla::Shape> in0_shape = b->GetShape(in0);
OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
xla::StatusOr<xla::XlaOp> accum_status =

View File

@ -51,14 +51,11 @@ class ReverseOp : public XlaOpKernel {
}
// XlaBuilder::Rev() requires concrete values for dimensions arg.
xla::Literal lax;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax));
std::vector<bool> revdims(x_shape.dims());
std::copy(lax.data<bool>().begin(), lax.data<bool>().end(),
revdims.begin());
std::vector<int64> dimensions;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &lax));
std::vector<int64> dimensions;
for (int d = 0; d < x_shape.dims(); ++d) {
if (revdims[d]) {
if (lax.Get<bool>({d})) {
dimensions.push_back(d);
}
}

View File

@ -108,21 +108,20 @@ class ExpandDimsOp : public XlaOpKernel {
explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape dim_shape = ctx->InputShape(1);
const TensorShape input_shape = ctx->InputShape("input");
const TensorShape dim_shape = ctx->InputShape("dim");
std::vector<int64> dims;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims));
// TODO(phawkins): the standard implementation of ExpandDimsOp seems to
// accept legacy scalars, even when they should be forbidden by the graphdef
// version.
OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
OP_REQUIRES(ctx, dims.size() == 1,
errors::InvalidArgument(absl::StrCat(
"dim input to ExpandDims must be a scalar; got ",
dim_shape.DebugString())));
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));
int dim = literal.data<int32>()[0];
int dim = dims[0];
OP_REQUIRES(ctx,
(dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
@ -148,7 +147,7 @@ class ExpandDimsOp : public XlaOpKernel {
dim = std::min<int32>(dim, existing_dims_size);
new_shape.emplace(new_shape.begin() + dim, 1);
ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape));
}
};
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"),

View File

@ -37,8 +37,8 @@ class TransposeOp : public XlaOpKernel {
: XlaOpKernel(ctx), conjugate_(conjugate) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape perm_tensor_shape = ctx->InputShape(1);
const TensorShape input_shape = ctx->InputShape("x");
const TensorShape perm_tensor_shape = ctx->InputShape("perm");
// Preliminary validation of sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
@ -52,19 +52,15 @@ class TransposeOp : public XlaOpKernel {
". But input(1) is a vector of size ",
perm_tensor_shape.num_elements()));
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal));
std::vector<int32> perm(dims);
std::copy(literal.data<int32>().begin(), literal.data<int32>().end(),
perm.begin());
std::vector<int64> perm;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm));
std::vector<int64> transposed_order;
// Check whether permutation is a permutation of integers of [0 .. dims).
absl::InlinedVector<bool, 8> bits(dims);
bool is_identity = true;
for (int i = 0; i < dims; ++i) {
const int32 d = perm[i];
const int64 d = perm[i];
OP_REQUIRES(
ctx, 0 <= d && d < dims,
errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
@ -83,9 +79,9 @@ class TransposeOp : public XlaOpKernel {
xla::XlaOp transposed;
// 0-D, 1-D, and identity transposes do nothing.
if (dims <= 1 || is_identity) {
transposed = ctx->Input(0);
transposed = ctx->Input("x");
} else {
transposed = xla::Transpose(ctx->Input(0), transposed_order);
transposed = xla::Transpose(ctx->Input("x"), transposed_order);
}
// Conjugate the transposed result if this is ConjugateTransposeOp.

View File

@ -32,6 +32,12 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
return Status::OK();
}
xla::StatusOr<xla::Literal> HostTensorToLiteral(const Tensor& host_tensor) {
xla::BorrowingLiteral literal;
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(host_tensor, &literal));
return literal.Clone();
}
Status HostTensorToMutableBorrowingLiteral(
Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) {
xla::Shape xla_shape;

View File

@ -30,6 +30,11 @@ namespace tensorflow {
// 'host_tensor'.
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
xla::BorrowingLiteral* literal);
// Returns a Literal with the contents of 'host_tensor', backed by its own
// storage (i.e., not reusing 'host_tensor's buffers.)
xla::StatusOr<xla::Literal> HostTensorToLiteral(const Tensor& host_tensor);
// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer
// owned by 'host_tensor', but is mutable via the xla::Literal methods.
Status HostTensorToMutableBorrowingLiteral(

View File

@ -136,25 +136,6 @@ Status XlaOpKernelContext::ConstantInputReshaped(
}
const XlaExpression* expression = CastExpressionFromTensor(tensor);
auto copy_tensor_to_literal = [](const Tensor& tensor,
xla::Literal* literal) {
xla::Shape literal_shape;
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape));
*literal = xla::Literal(literal_shape);
// memcpy over the payload ...
// TODO(phawkins): handle string types.
size_t total_bytes = tensor.TotalBytes();
if (total_bytes > 0) {
void* dst_ptr = literal->untyped_data();
const void* src_ptr = DMAHelper::base(&tensor);
memcpy(dst_ptr, src_ptr, total_bytes);
}
return Status::OK();
};
// If the tensor has a known constant value, there is no need to invoke XLA.
if (expression->has_constant_value()) {
Tensor temp(tensor.dtype());
@ -164,14 +145,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
}
return copy_tensor_to_literal(temp, constant_literal);
TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
return Status::OK();
}
// Make sure we treat zero-element tensors as constant.
if (new_shape.num_elements() == 0) {
Tensor temp(tensor.dtype(), new_shape);
return copy_tensor_to_literal(temp, constant_literal);
TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
return Status::OK();
}
xla::XlaOp handle = expression->handle();
@ -322,6 +304,15 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
return LiteralToInt64Vector(literal, out);
}
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
absl::string_view name, std::vector<int64>* out) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInputReshaped(
index, {InputShape(index).num_elements()}, &literal));
return LiteralToInt64Vector(literal, out);
}
Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
xla::Literal* out) {
xla::Literal literal;

View File

@ -111,14 +111,6 @@ class XlaOpKernelContext {
Status ConstantInput(int index, xla::Literal* constant_literal);
Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);
// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal);
// Converts a constant scalar int32 or int64 tensor into an int64.
Status ConstantInputAsIntScalar(int index, int64* out);
Status ConstantInputAsIntScalar(absl::string_view name, int64* out);
@ -134,6 +126,8 @@ class XlaOpKernelContext {
// Reshapes and converts a constant int32 or int64 tensor into a vector of
// int64s.
Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
Status ConstantInputReshapedToIntVector(absl::string_view name,
std::vector<int64>* out);
// Converts a constant int32 or int64 Tensor into an xla int64 Literal.
Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
@ -260,6 +254,14 @@ class XlaOpKernelContext {
// type to allow mapping for variant to more generic types.
Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal);
OpKernelContext* const context_;
};