[TF2XLA] Support unknown ranked tensor in fake param.

Unknown-ranked tensors are expressed as the XLA shape T[0], as is done in other places across the tf2xla bridge. I don't quite like it, but fixing that is outside the scope of this CL.
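To make the convention concrete, here is a minimal sketch (not part of the change itself) of how the rank-1, size-0 placeholder is built with xla::ShapeUtil, the same helper the new code below uses:

#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: the placeholder XLA shape used for an unknown-ranked tensor of
// element type T, here T = s32. MakeShapeWithLayout(type, dimensions,
// minor_to_major) yields s32[0]{0}: rank 1, zero elements, trivial layout.
xla::Shape placeholder =
    xla::ShapeUtil::MakeShapeWithLayout(xla::S32, /*dimensions=*/{0},
                                        /*minor_to_major=*/{0});
// placeholder.ToString() == "s32[0]{0}"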

PiperOrigin-RevId: 306476269
Change-Id: I67eb6f93f38059003549a26f123dbcc299aa57d0
Yunxing Dai, 2020-04-14 11:12:08 -07:00 (committed by TensorFlower Gardener)
parent 7c115c16e0
commit 3cfc29aa36
4 changed files with 72 additions and 1 deletion


@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/tensor_shape.h"

 namespace tensorflow {
@@ -29,7 +30,8 @@ class XlaFakeParamOp : public XlaOpKernel {
  public:
   explicit XlaFakeParamOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     DataType dtype;
-    TensorShape tensor_shape;
+    // Tensor shape can be unknown.
+    PartialTensorShape tensor_shape;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape));
     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_));

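For context beyond this hunk: shape_ is the xla::Shape member that Compile later materializes. A hedged sketch of that consumer, assuming the op emits zeros via xla::Zeros from the constants.h include above (the Compile body is not shown in this diff):

// Assumed shape of the rest of the kernel; this Compile body is a sketch,
// not a quote of the actual file.
void Compile(XlaOpKernelContext* ctx) override {
  // Emit a zero-filled value of the stored shape. For an unknown-ranked
  // FakeParam, shape_ is the T[0] placeholder described above.
  ctx->SetOutput(0, xla::Zeros(ctx->builder(), shape_));
}

 private:
  xla::Shape shape_;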

@@ -98,6 +98,43 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
   return Status::OK();
 }

+// Convert a PartialTensorShape into the equivalent XLA Shape proto.
+Status TensorShapeToXLAShape(DataType dtype,
+                             const PartialTensorShape& tensor_shape,
+                             xla::Shape* shape) {
+  xla::PrimitiveType type;
+  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
+  *shape = TensorShapeToXLAShape(type, tensor_shape);
+  return Status::OK();
+}
+
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+                                 const PartialTensorShape& tensor_shape) {
+  if (tensor_shape.unknown_rank()) {
+    // For an unknown-rank shape, create a rank-1, size-0 tensor.
+    return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
+  }
+  int rank = tensor_shape.dims();
+  std::vector<int64> dimensions(rank);
+  std::vector<bool> dynamic_dimensions(rank, false);
+  std::vector<int64> layout(rank);
+  for (int d = 0; d < rank; ++d) {
+    dimensions[d] = tensor_shape.dim_size(d);
+    if (dimensions[d] < 0) {
+      dynamic_dimensions[d] = true;
+    }
+  }
+  // XLA uses minor-to-major; TensorFlow uses major-to-minor.
+  std::iota(layout.rbegin(), layout.rend(), 0);
+  xla::Shape result =
+      xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
+  for (int64 d = 0; d < rank; ++d) {
+    result.set_dynamic_dimension(d, dynamic_dimensions[d]);
+  }
+  return result;
+}
+
 // Convert a TensorShape into the equivalent XLA Shape proto.
 Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
                              xla::Shape* shape) {

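A hedged usage sketch of the new overload (variable names are illustrative): an unknown-rank shape collapses to the T[0] placeholder, while a known-rank shape with a -1 dimension keeps its rank and has that dimension marked dynamic, exactly as the loop above sets dynamic_dimensions:

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/framework/tensor_shape.h"

// Unknown rank: a default PartialTensorShape has unknown_rank() == true,
// so the result is the rank-1, size-0 placeholder s32[0]{0}.
xla::Shape unranked =
    TensorShapeToXLAShape(xla::S32, PartialTensorShape());

// Known rank with an unknown dimension: {-1, 128} keeps rank 2, and
// dimension 0 is marked dynamic via set_dynamic_dimension.
xla::Shape dyn =
    TensorShapeToXLAShape(xla::S32, PartialTensorShape({-1, 128}));
// dyn.is_dynamic_dimension(0) == true, dyn.dimensions(1) == 128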

@@ -44,6 +44,17 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
 xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
                                  const TensorShape& tensor_shape);

+// Convert a PartialTensorShape into the equivalent XLA Shape proto. A shape
+// with unknown rank is represented as a rank-1 shape with one size-0 dimension.
+Status TensorShapeToXLAShape(DataType dtype,
+                             const PartialTensorShape& tensor_shape,
+                             xla::Shape* shape);
+
+// Convert a PartialTensorShape into the equivalent XLA Shape proto. A shape
+// with unknown rank is represented as a rank-1 shape with one size-0 dimension.
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+                                 const PartialTensorShape& tensor_shape);
+
 // Given an XLA shape with layouts, builds a layout vector in the form able to
 // be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/....
 // The returned vector is a linearized sequence of the minor-to-major values of


@@ -463,6 +463,27 @@ TEST_F(XlaCompilerTest, TransposeVariables) {
       xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
 }

+// An unranked fake param should compile to a rank-1, size-0 tensor.
+TEST_F(XlaCompilerTest, UnrankedFakeParam) {
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  PartialTensorShape shape;
+  auto a = ops::FakeParam(scope, DT_INT32, shape);
+  auto ret = ops::_Retval(scope.WithOpName("D"), a, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Compiles the graph.
+  XlaCompiler compiler(DefaultOptions());
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile",
+                                     std::move(graph), {}, &result));
+  // Check that the output is the rank-1, size-0 placeholder shape.
+  EXPECT_EQ(result.xla_output_shape,
+            xla::ShapeUtil::MakeTupleShape(
+                {xla::ShapeUtil::MakeShape(xla::S32, {0})}));
+}
+
 // Tests that the compiler doesn't reorder the parameters.
 TEST_F(XlaCompilerTest, MixedOrderArguments) {
   for (bool swap_order : {false, true}) {