[TF2XLA] Support unknown ranked tensor in fake param.
Unknown ranked tensor are expressed as xla shape T[0] as used in other places across the tf2xla bridge. I don't quite like it but fixing that is outside the scope of this cl. PiperOrigin-RevId: 306476269 Change-Id: I67eb6f93f38059003549a26f123dbcc299aa57d0
This commit is contained in:
parent
7c115c16e0
commit
3cfc29aa36
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
@ -29,7 +30,8 @@ class XlaFakeParamOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
explicit XlaFakeParamOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
explicit XlaFakeParamOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
TensorShape tensor_shape;
|
// Tensor shape can be unknown.
|
||||||
|
PartialTensorShape tensor_shape;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype));
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape));
|
||||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_));
|
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_));
|
||||||
|
|
|
@ -98,6 +98,43 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert a TensorShape into the equivalent XLA Shape proto.
|
||||||
|
Status TensorShapeToXLAShape(DataType dtype,
|
||||||
|
const PartialTensorShape& tensor_shape,
|
||||||
|
xla::Shape* shape) {
|
||||||
|
xla::PrimitiveType type;
|
||||||
|
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
|
||||||
|
*shape = TensorShapeToXLAShape(type, tensor_shape);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
|
||||||
|
const PartialTensorShape& tensor_shape) {
|
||||||
|
if (tensor_shape.unknown_rank()) {
|
||||||
|
// For unknown shape, create a rank 1 size 0 tensor.
|
||||||
|
return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
|
||||||
|
}
|
||||||
|
int rank = tensor_shape.dims();
|
||||||
|
std::vector<int64> dimensions(rank);
|
||||||
|
std::vector<bool> dynamic_dimensions(rank, false);
|
||||||
|
std::vector<int64> layout(rank);
|
||||||
|
for (int d = 0; d < rank; ++d) {
|
||||||
|
dimensions[d] = tensor_shape.dim_size(d);
|
||||||
|
if (dimensions[d] < 0) {
|
||||||
|
dynamic_dimensions[d] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// XLA uses minor-to-major; Tensorflow uses major-to-minor.
|
||||||
|
std::iota(layout.rbegin(), layout.rend(), 0);
|
||||||
|
xla::Shape result =
|
||||||
|
xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
|
||||||
|
|
||||||
|
for (int64 d = 0; d < rank; ++d) {
|
||||||
|
result.set_dynamic_dimension(d, dynamic_dimensions[d]);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// Convert a TensorShape into the equivalent XLA Shape proto.
|
// Convert a TensorShape into the equivalent XLA Shape proto.
|
||||||
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
|
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
|
||||||
xla::Shape* shape) {
|
xla::Shape* shape) {
|
||||||
|
|
|
@ -44,6 +44,17 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
|
||||||
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
|
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
|
||||||
const TensorShape& tensor_shape);
|
const TensorShape& tensor_shape);
|
||||||
|
|
||||||
|
// Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape
|
||||||
|
// with unknown rank is represented by an r1 with empty dimension.
|
||||||
|
Status TensorShapeToXLAShape(DataType dtype,
|
||||||
|
const PartialTensorShape& tensor_shape,
|
||||||
|
xla::Shape* shape);
|
||||||
|
|
||||||
|
// Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape
|
||||||
|
// with unknown rank is represented by an r1 with empty dimension.
|
||||||
|
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
|
||||||
|
const PartialTensorShape& tensor_shape);
|
||||||
|
|
||||||
// Given an XLA shape with layouts, builds a layout vector in the form able to
|
// Given an XLA shape with layouts, builds a layout vector in the form able to
|
||||||
// be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/....
|
// be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/....
|
||||||
// THe returned vector is a linearized sequence of the minor-to-major values of
|
// THe returned vector is a linearized sequence of the minor-to-major values of
|
||||||
|
|
|
@ -463,6 +463,27 @@ TEST_F(XlaCompilerTest, TransposeVariables) {
|
||||||
xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
|
xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unranked fake param returns a 0 shaped tensor.
|
||||||
|
TEST_F(XlaCompilerTest, UnrankedFakeParam) {
|
||||||
|
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||||
|
PartialTensorShape shape;
|
||||||
|
auto a = ops::FakeParam(scope, DT_INT32, shape);
|
||||||
|
auto ret = ops::_Retval(scope.WithOpName("D"), a, 0);
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
// Compiles the graph.
|
||||||
|
XlaCompiler compiler(DefaultOptions());
|
||||||
|
|
||||||
|
XlaCompiler::CompilationResult result;
|
||||||
|
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile",
|
||||||
|
std::move(graph), {}, &result));
|
||||||
|
// Check that the return shapes are correctly tranposed.
|
||||||
|
EXPECT_EQ(result.xla_output_shape,
|
||||||
|
xla::ShapeUtil::MakeTupleShape(
|
||||||
|
{xla::ShapeUtil::MakeShape(xla::S32, {0})}));
|
||||||
|
}
|
||||||
|
|
||||||
// Tests that the compiler doesn't reorder the parameters.
|
// Tests that the compiler doesn't reorder the parameters.
|
||||||
TEST_F(XlaCompilerTest, MixedOrderArguments) {
|
TEST_F(XlaCompilerTest, MixedOrderArguments) {
|
||||||
for (bool swap_order : {false, true}) {
|
for (bool swap_order : {false, true}) {
|
||||||
|
|
Loading…
Reference in New Issue