diff --git a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc
index ec3463bd58f..ba9e406312d 100644
--- a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 
 namespace tensorflow {
 
@@ -29,7 +30,8 @@ class XlaFakeParamOp : public XlaOpKernel {
  public:
   explicit XlaFakeParamOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     DataType dtype;
-    TensorShape tensor_shape;
+    // Tensor shape can be unknown.
+    PartialTensorShape tensor_shape;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape));
     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_));
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc
index 8997b2f5c68..2fce6e7f0c7 100644
--- a/tensorflow/compiler/tf2xla/shape_util.cc
+++ b/tensorflow/compiler/tf2xla/shape_util.cc
@@ -98,6 +98,43 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
   return Status::OK();
 }
 
+// Convert a PartialTensorShape into the equivalent XLA Shape proto.
+Status TensorShapeToXLAShape(DataType dtype,
+                             const PartialTensorShape& tensor_shape,
+                             xla::Shape* shape) {
+  xla::PrimitiveType type;
+  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
+  *shape = TensorShapeToXLAShape(type, tensor_shape);
+  return Status::OK();
+}
+
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+                                 const PartialTensorShape& tensor_shape) {
+  if (tensor_shape.unknown_rank()) {
+    // For an unknown-rank shape, create a rank 1 size 0 tensor.
+    return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
+  }
+  int rank = tensor_shape.dims();
+  std::vector<int64> dimensions(rank);
+  std::vector<bool> dynamic_dimensions(rank, false);
+  std::vector<int64> layout(rank);
+  for (int d = 0; d < rank; ++d) {
+    dimensions[d] = tensor_shape.dim_size(d);
+    if (dimensions[d] < 0) {
+      dynamic_dimensions[d] = true;
+    }
+  }
+  // XLA uses minor-to-major; TensorFlow uses major-to-minor.
+  std::iota(layout.rbegin(), layout.rend(), 0);
+  xla::Shape result =
+      xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
+
+  for (int64 d = 0; d < rank; ++d) {
+    result.set_dynamic_dimension(d, dynamic_dimensions[d]);
+  }
+  return result;
+}
+
 // Convert a TensorShape into the equivalent XLA Shape proto.
 Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
                              xla::Shape* shape) {
diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h
index 331cfa38c1d..438df7ecb18 100644
--- a/tensorflow/compiler/tf2xla/shape_util.h
+++ b/tensorflow/compiler/tf2xla/shape_util.h
@@ -44,6 +44,17 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
 xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
                                  const TensorShape& tensor_shape);
 
+// Convert a PartialTensorShape into the equivalent XLA Shape proto. A shape
+// with unknown rank is represented by a rank-1 shape of size 0.
+Status TensorShapeToXLAShape(DataType dtype,
+                             const PartialTensorShape& tensor_shape,
+                             xla::Shape* shape);
+
+// Convert a PartialTensorShape into the equivalent XLA Shape proto. A shape
+// with unknown rank is represented by a rank-1 shape of size 0.
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+                                 const PartialTensorShape& tensor_shape);
+
 // Given an XLA shape with layouts, builds a layout vector in the form able to
 // be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/....
 // The returned vector is a linearized sequence of the minor-to-major values of
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 76780167187..09f34bde64b 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -463,6 +463,27 @@ TEST_F(XlaCompilerTest, TransposeVariables) {
       xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
 }
 
+// An unranked FakeParam yields a rank-1, zero-sized tensor.
+TEST_F(XlaCompilerTest, UnrankedFakeParam) {
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  PartialTensorShape shape;
+  auto a = ops::FakeParam(scope, DT_INT32, shape);
+  auto ret = ops::_Retval(scope.WithOpName("D"), a, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Compiles the graph.
+  XlaCompiler compiler(DefaultOptions());
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile",
+                                     std::move(graph), {}, &result));
+  // Check that the output is a tuple holding a rank-1, zero-sized S32 tensor.
+  EXPECT_EQ(result.xla_output_shape,
+            xla::ShapeUtil::MakeTupleShape(
+                {xla::ShapeUtil::MakeShape(xla::S32, {0})}));
+}
+
 // Tests that the compiler doesn't reorder the parameters.
 TEST_F(XlaCompilerTest, MixedOrderArguments) {
   for (bool swap_order : {false, true}) {
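
For reference, a minimal standalone sketch (not part of the change) of how the new PartialTensorShape overload behaves in its three cases: fully known, partially known, and unknown rank. It assumes a TensorFlow build environment with the tf2xla and XLA headers on the include path; the expected outputs in the comments are inferred from the code above, not from a verified run.

// Usage sketch for the new TensorShapeToXLAShape(PartialTensorShape) overload.
// Illustrative only; assumes a TensorFlow/XLA build environment.
#include <iostream>

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/tensor_shape.h"

int main() {
  // Fully known shape: [2, 3] becomes a static s32[2,3].
  tensorflow::PartialTensorShape known({2, 3});
  xla::Shape s = tensorflow::TensorShapeToXLAShape(xla::S32, known);
  std::cout << xla::ShapeUtil::HumanString(s) << std::endl;  // s32[2,3]

  // Partially known shape: dimension 0 is -1 (unknown), so the resulting
  // XLA shape marks that dimension as dynamic.
  tensorflow::PartialTensorShape partial({-1, 3});
  s = tensorflow::TensorShapeToXLAShape(xla::S32, partial);
  std::cout << s.is_dynamic_dimension(0) << std::endl;  // 1

  // Unknown rank: represented by a rank-1 shape with a zero-sized dimension.
  tensorflow::PartialTensorShape unranked;
  s = tensorflow::TensorShapeToXLAShape(xla::S32, unranked);
  std::cout << xla::ShapeUtil::HumanString(s) << std::endl;  // s32[0]
  return 0;
}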