[TF2XLA] Support unknown ranked tensor in fake param.
Unknown ranked tensor are expressed as xla shape T[0] as used in other places across the tf2xla bridge. I don't quite like it but fixing that is outside the scope of this cl. PiperOrigin-RevId: 306476269 Change-Id: I67eb6f93f38059003549a26f123dbcc299aa57d0
This commit is contained in:
parent
7c115c16e0
commit
3cfc29aa36
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -29,7 +30,8 @@ class XlaFakeParamOp : public XlaOpKernel {
|
|||
public:
|
||||
explicit XlaFakeParamOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
DataType dtype;
|
||||
TensorShape tensor_shape;
|
||||
// Tensor shape can be unknown.
|
||||
PartialTensorShape tensor_shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape));
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_));
|
||||
|
|
|
@ -98,6 +98,43 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Convert a TensorShape into the equivalent XLA Shape proto.
|
||||
Status TensorShapeToXLAShape(DataType dtype,
|
||||
const PartialTensorShape& tensor_shape,
|
||||
xla::Shape* shape) {
|
||||
xla::PrimitiveType type;
|
||||
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
|
||||
*shape = TensorShapeToXLAShape(type, tensor_shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
|
||||
const PartialTensorShape& tensor_shape) {
|
||||
if (tensor_shape.unknown_rank()) {
|
||||
// For unknown shape, create a rank 1 size 0 tensor.
|
||||
return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
|
||||
}
|
||||
int rank = tensor_shape.dims();
|
||||
std::vector<int64> dimensions(rank);
|
||||
std::vector<bool> dynamic_dimensions(rank, false);
|
||||
std::vector<int64> layout(rank);
|
||||
for (int d = 0; d < rank; ++d) {
|
||||
dimensions[d] = tensor_shape.dim_size(d);
|
||||
if (dimensions[d] < 0) {
|
||||
dynamic_dimensions[d] = true;
|
||||
}
|
||||
}
|
||||
// XLA uses minor-to-major; Tensorflow uses major-to-minor.
|
||||
std::iota(layout.rbegin(), layout.rend(), 0);
|
||||
xla::Shape result =
|
||||
xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
|
||||
|
||||
for (int64 d = 0; d < rank; ++d) {
|
||||
result.set_dynamic_dimension(d, dynamic_dimensions[d]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Convert a TensorShape into the equivalent XLA Shape proto.
|
||||
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
|
||||
xla::Shape* shape) {
|
||||
|
|
|
@ -44,6 +44,17 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
|
|||
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
|
||||
const TensorShape& tensor_shape);
|
||||
|
||||
// Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape
|
||||
// with unknown rank is represented by an r1 with empty dimension.
|
||||
Status TensorShapeToXLAShape(DataType dtype,
|
||||
const PartialTensorShape& tensor_shape,
|
||||
xla::Shape* shape);
|
||||
|
||||
// Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape
|
||||
// with unknown rank is represented by an r1 with empty dimension.
|
||||
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
|
||||
const PartialTensorShape& tensor_shape);
|
||||
|
||||
// Given an XLA shape with layouts, builds a layout vector in the form able to
|
||||
// be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/....
|
||||
// THe returned vector is a linearized sequence of the minor-to-major values of
|
||||
|
|
|
@ -463,6 +463,27 @@ TEST_F(XlaCompilerTest, TransposeVariables) {
|
|||
xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
|
||||
}
|
||||
|
||||
// Unranked fake param returns a 0 shaped tensor.
|
||||
TEST_F(XlaCompilerTest, UnrankedFakeParam) {
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
PartialTensorShape shape;
|
||||
auto a = ops::FakeParam(scope, DT_INT32, shape);
|
||||
auto ret = ops::_Retval(scope.WithOpName("D"), a, 0);
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
|
||||
// Compiles the graph.
|
||||
XlaCompiler compiler(DefaultOptions());
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile",
|
||||
std::move(graph), {}, &result));
|
||||
// Check that the return shapes are correctly tranposed.
|
||||
EXPECT_EQ(result.xla_output_shape,
|
||||
xla::ShapeUtil::MakeTupleShape(
|
||||
{xla::ShapeUtil::MakeShape(xla::S32, {0})}));
|
||||
}
|
||||
|
||||
// Tests that the compiler doesn't reorder the parameters.
|
||||
TEST_F(XlaCompilerTest, MixedOrderArguments) {
|
||||
for (bool swap_order : {false, true}) {
|
||||
|
|
Loading…
Reference in New Issue