[TF2XLA] Support unknown ranked tensor in fake param.
Unknown ranked tensor are expressed as xla shape T[0] as used in other places across the tf2xla bridge. I don't quite like it but fixing that is outside the scope of this cl. PiperOrigin-RevId: 306476269 Change-Id: I67eb6f93f38059003549a26f123dbcc299aa57d0
This commit is contained in:
		
							parent
							
								
									7c115c16e0
								
							
						
					
					
						commit
						3cfc29aa36
					
				@ -20,6 +20,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
 | 
			
		||||
#include "tensorflow/core/framework/kernel_def_builder.h"
 | 
			
		||||
#include "tensorflow/core/framework/tensor_shape.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
@ -29,7 +30,8 @@ class XlaFakeParamOp : public XlaOpKernel {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit XlaFakeParamOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
 | 
			
		||||
    DataType dtype;
 | 
			
		||||
    TensorShape tensor_shape;
 | 
			
		||||
    // Tensor shape can be unknown.
 | 
			
		||||
    PartialTensorShape tensor_shape;
 | 
			
		||||
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype));
 | 
			
		||||
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape));
 | 
			
		||||
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_));
 | 
			
		||||
 | 
			
		||||
@ -98,6 +98,43 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Convert a TensorShape into the equivalent XLA Shape proto.
 | 
			
		||||
Status TensorShapeToXLAShape(DataType dtype,
 | 
			
		||||
                             const PartialTensorShape& tensor_shape,
 | 
			
		||||
                             xla::Shape* shape) {
 | 
			
		||||
  xla::PrimitiveType type;
 | 
			
		||||
  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
 | 
			
		||||
  *shape = TensorShapeToXLAShape(type, tensor_shape);
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
 | 
			
		||||
                                 const PartialTensorShape& tensor_shape) {
 | 
			
		||||
  if (tensor_shape.unknown_rank()) {
 | 
			
		||||
    // For unknown shape, create a rank 1 size 0 tensor.
 | 
			
		||||
    return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0});
 | 
			
		||||
  }
 | 
			
		||||
  int rank = tensor_shape.dims();
 | 
			
		||||
  std::vector<int64> dimensions(rank);
 | 
			
		||||
  std::vector<bool> dynamic_dimensions(rank, false);
 | 
			
		||||
  std::vector<int64> layout(rank);
 | 
			
		||||
  for (int d = 0; d < rank; ++d) {
 | 
			
		||||
    dimensions[d] = tensor_shape.dim_size(d);
 | 
			
		||||
    if (dimensions[d] < 0) {
 | 
			
		||||
      dynamic_dimensions[d] = true;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  // XLA uses minor-to-major; Tensorflow uses major-to-minor.
 | 
			
		||||
  std::iota(layout.rbegin(), layout.rend(), 0);
 | 
			
		||||
  xla::Shape result =
 | 
			
		||||
      xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
 | 
			
		||||
 | 
			
		||||
  for (int64 d = 0; d < rank; ++d) {
 | 
			
		||||
    result.set_dynamic_dimension(d, dynamic_dimensions[d]);
 | 
			
		||||
  }
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Convert a TensorShape into the equivalent XLA Shape proto.
 | 
			
		||||
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
 | 
			
		||||
                             xla::Shape* shape) {
 | 
			
		||||
 | 
			
		||||
@ -44,6 +44,17 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
 | 
			
		||||
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
 | 
			
		||||
                                 const TensorShape& tensor_shape);
 | 
			
		||||
 | 
			
		||||
// Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape
 | 
			
		||||
// with unknown rank is represented by an r1 with empty dimension.
 | 
			
		||||
Status TensorShapeToXLAShape(DataType dtype,
 | 
			
		||||
                             const PartialTensorShape& tensor_shape,
 | 
			
		||||
                             xla::Shape* shape);
 | 
			
		||||
 | 
			
		||||
// Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape
 | 
			
		||||
// with unknown rank is represented by an r1 with empty dimension.
 | 
			
		||||
xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
 | 
			
		||||
                                 const PartialTensorShape& tensor_shape);
 | 
			
		||||
 | 
			
		||||
// Given an XLA shape with layouts, builds a layout vector in the form able to
 | 
			
		||||
// be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/....
 | 
			
		||||
// THe returned vector is a linearized sequence of the minor-to-major values of
 | 
			
		||||
 | 
			
		||||
@ -463,6 +463,27 @@ TEST_F(XlaCompilerTest, TransposeVariables) {
 | 
			
		||||
            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Unranked fake param returns a 0 shaped tensor.
 | 
			
		||||
TEST_F(XlaCompilerTest, UnrankedFakeParam) {
 | 
			
		||||
  Scope scope = Scope::NewRootScope().ExitOnError();
 | 
			
		||||
  PartialTensorShape shape;
 | 
			
		||||
  auto a = ops::FakeParam(scope, DT_INT32, shape);
 | 
			
		||||
  auto ret = ops::_Retval(scope.WithOpName("D"), a, 0);
 | 
			
		||||
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
 | 
			
		||||
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
 | 
			
		||||
 | 
			
		||||
  // Compiles the graph.
 | 
			
		||||
  XlaCompiler compiler(DefaultOptions());
 | 
			
		||||
 | 
			
		||||
  XlaCompiler::CompilationResult result;
 | 
			
		||||
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile",
 | 
			
		||||
                                     std::move(graph), {}, &result));
 | 
			
		||||
  // Check that the return shapes are correctly tranposed.
 | 
			
		||||
  EXPECT_EQ(result.xla_output_shape,
 | 
			
		||||
            xla::ShapeUtil::MakeTupleShape(
 | 
			
		||||
                {xla::ShapeUtil::MakeShape(xla::S32, {0})}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests that the compiler doesn't reorder the parameters.
 | 
			
		||||
TEST_F(XlaCompilerTest, MixedOrderArguments) {
 | 
			
		||||
  for (bool swap_order : {false, true}) {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user