diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
index 976ff91f6ce..1ea0e797675 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
@@ -45,22 +45,32 @@ namespace tensorflow {
 namespace {
 
 // GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist
-// may carry and returns them in a 2D vector: int64[ElementSize][DimSize]. If a
-// dimension is static, a constant dimension is returned.
+// may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. If
+// a dimension is static, a constant dimension is returned. If a dimension is
+// dynamic, an XlaOp representing its runtime size is returned.
 xla::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
     XlaOpKernelContext* ctx, const xla::Shape& element_shape,
     const xla::Shape& list_shape, int64 num_elements) {
   std::vector<int64> dynamic_sizes;
-  ctx->set_dynamic_dimension_is_minus_one(true);
   // The multiplier can be a dynamic value.
   TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
+  std::vector<bool> dims_are_dynamic;
+  TF_RETURN_IF_ERROR(
+      ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic));
+  bool leading_dim_is_dynamic;
+  TF_RETURN_IF_ERROR(
+      ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic));
   std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
   // Set dynamic dimension size to 0 for initialization value.
   std::vector<xla::XlaOp> dynamic_dims;
-  // Leading dim is a static dimension.
-  dynamic_dims.push_back(xla::ConstantR0<int32>(ctx->builder(), num_elements));
+  if (leading_dim_is_dynamic) {
+    dynamic_dims.push_back(ctx->Input(1));
+  } else {
+    dynamic_dims.push_back(
+        xla::ConstantR0<int32>(ctx->builder(), num_elements));
+  }
   for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) {
-    if (ctx->is_dynamic_dimension(dynamic_sizes[dim])) {
+    if (dims_are_dynamic[dim]) {
       auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
       dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
       dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
@@ -80,11 +90,12 @@ class TensorListLengthOp : public XlaOpKernel {
 
   void Compile(XlaOpKernelContext* ctx) override {
     int64 leading_dim;
-    OP_REQUIRES_OK(ctx,
-                   GetLeadingDimForTensorList(ctx->Input(0), &leading_dim));
-    Tensor length_tensor(DT_INT32, {});
-    length_tensor.scalar<int32>()() = static_cast<int32>(leading_dim);
-    ctx->SetConstantOutput(0, length_tensor);
+    xla::XlaOp leading_dim_size;
+    bool leading_dim_is_dynamic;
+    OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim,
+                                                   &leading_dim_is_dynamic,
+                                                   &leading_dim_size));
+    ctx->SetOutput(0, leading_dim_size);
   }
 
  private:
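Reviewer note: the per-dimension branch in GetTensorListDynamicDims can be read in isolation. Below is a minimal sketch, not part of the patch, that mirrors the static-vs-dynamic choice; `CollectDimSizes`, `shape_input`, and `static_sizes` are hypothetical names (`shape_input` stands in for `ctx->Input(0)`, and `is_dynamic` for the resolved dynamism vector).

```c++
#include <cstdint>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// For each dimension, emit either a compile-time constant size or an XlaOp
// sliced out of the runtime shape tensor, as GetTensorListDynamicDims does.
std::vector<xla::XlaOp> CollectDimSizes(xla::XlaBuilder* b,
                                        xla::XlaOp shape_input,
                                        absl::Span<const int64_t> static_sizes,
                                        const std::vector<bool>& is_dynamic) {
  std::vector<xla::XlaOp> sizes;
  for (int64_t dim = 0; dim < static_cast<int64_t>(static_sizes.size());
       ++dim) {
    if (is_dynamic[dim]) {
      // Extract the dim-th runtime size and reshape it to an s32 scalar.
      xla::XlaOp size = xla::Slice(shape_input, {dim}, {dim + 1}, {1});
      size = xla::Reshape(size, {});
      sizes.push_back(xla::ConvertElementType(size, xla::S32));
    } else {
      sizes.push_back(
          xla::ConstantR0<int32_t>(b, static_cast<int32_t>(static_sizes[dim])));
    }
  }
  return sizes;
}
```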
@@ -134,6 +145,9 @@ class TensorListReserveOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     int64 num_elements;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
+    bool num_element_is_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
     OP_REQUIRES(
         ctx, num_elements >= 0,
         errors::InvalidArgument(
@@ -156,7 +170,8 @@ class TensorListReserveOp : public XlaOpKernel {
     if (got_shape) {
       xla::Shape list_shape;
       OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
-                              element_shape, num_elements, &list_shape));
+                              element_shape, num_elements,
+                              num_element_is_dynamic, &list_shape));
 
       // Set up dynamic dimension sizes to create the zero tensor.
       auto list_dynamic_dims_or = GetTensorListDynamicDims(
           ctx, element_shape, list_shape, num_elements);
@@ -175,8 +190,8 @@ class TensorListReserveOp : public XlaOpKernel {
       return;
     }
 
-    xla::XlaOp result =
-        BuildUninitializedTensorList(ctx->builder(), num_elements);
+    xla::XlaOp result = BuildUninitializedTensorList(
+        ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1));
     ctx->SetTensorListOutput(0, result);
   }
 
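The contract here, as I read it: `ConstantInputAsIntScalar` still folds input 1 to a static value that serves as the buffer's upper bound, while the new `ResolveInputDynamismIntoPred` records whether that value is really a runtime quantity. A standalone sketch (all names and sizes illustrative, plain XLA client API) of what the reserve path then builds for element shape f32[2] with a dynamic element count bounded by 8:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaBuilder b("reserve_sketch");
// Runtime number of elements; a parameter in practice, a constant here.
xla::XlaOp num = xla::ConstantR0<int32_t>(&b, 5);
// Static upper bound: a zero-filled f32[8,2] buffer.
xla::XlaOp zeros = xla::Broadcast(xla::ConstantR0<float>(&b, 0.0f), {8, 2});
// Mark dimension 0 as dynamic with runtime size `num`: shape f32[<=8,2].
xla::XlaOp buffer = xla::SetDimensionSize(zeros, num, 0);
// A TensorList is the tuple (buffer, push index), per the shapes built in
// GetTensorListShapeFromElementShape below.
xla::XlaOp list = xla::Tuple(&b, {buffer, xla::ConstantR0<int32_t>(&b, 0)});
```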
@@ -200,6 +215,9 @@ class EmptyTensorListOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     int64 max_num_elements;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
+    bool num_element_is_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
     OP_REQUIRES(ctx, max_num_elements >= 0,
                 errors::InvalidArgument(
                     "XLA compilation requires a fixed tensor list size. Set "
@@ -210,9 +228,9 @@ class EmptyTensorListOp : public XlaOpKernel {
 
     if (dtype_ != DT_VARIANT) {
       // We are creating a non-nested TensorList.
-      // If element shape is compile time constant and it's not "unknown rank"
-      // shape (-1), create an initialized TensorList. Otherwise create an
-      // uninitialized TensorList.
+      // If element shape is compile time constant and it's not "unknown
+      // rank" shape (-1), create an initialized TensorList. Otherwise
+      // create an uninitialized TensorList.
       xla::XlaOp element_shape_handle = ctx->Input(0);
       xla::PrimitiveType type;
       OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
@@ -224,7 +242,8 @@ class EmptyTensorListOp : public XlaOpKernel {
       if (got_shape) {
         xla::Shape list_shape;
         OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
-                                element_shape, max_num_elements, &list_shape));
+                                element_shape, max_num_elements,
+                                num_element_is_dynamic, &list_shape));
 
         // Set up dynamic dimension sizes to create the zero tensor.
         auto list_dynamic_dims_or = GetTensorListDynamicDims(
             ctx, element_shape, list_shape, max_num_elements);
@@ -243,7 +262,8 @@ class EmptyTensorListOp : public XlaOpKernel {
 
       // We are creating a nested TensorList or a non-nested TensorList with
       // unknown shape. Just create an uninitialized TensorList.
       xla::XlaOp result =
-          BuildUninitializedTensorList(ctx->builder(), max_num_elements);
+          BuildUninitializedTensorList(ctx->builder(), max_num_elements,
+                                       num_element_is_dynamic, ctx->Input(1));
       ctx->SetTensorListOutput(0, result);
     }
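For reference, an uninitialized TensorList is represented as a flat s32 zero vector of length `leading_dimension`; with this change the same representation can carry a dynamic leading size. A standalone sketch of the new behavior (builder name and sizes are illustrative, not from the patch):

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"

xla::XlaBuilder b("uninitialized_list_sketch");
// Upper bound of 16 elements, represented as s32[16] zeros.
xla::XlaOp zero = xla::ConstantLiteral(
    &b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
xla::XlaOp list = xla::Broadcast(zero, {16});
// If input 1 was resolved as dynamic, attach its runtime value, yielding
// s32[<=16]; otherwise the static shape is kept as-is.
xla::XlaOp runtime_size = xla::ConstantR0<int32_t>(&b, 3);
list = xla::SetDimensionSize(list, runtime_size, 0);
```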
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
index 0e367e10ec4..156f9bfea40 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
@@ -189,28 +189,42 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
 }
 
 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
-                                        int64 leading_dimension) {
+                                        int64 leading_dimension,
+                                        bool leading_size_is_dynamic,
+                                        xla::XlaOp leading_dim_size) {
   auto zero =
       xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
-  return xla::Broadcast(zero, std::vector<int64>{leading_dimension});
+  auto broadcast = xla::Broadcast(zero, std::vector<int64>{leading_dimension});
+  if (leading_size_is_dynamic) {
+    return xla::SetDimensionSize(broadcast, leading_dim_size, 0);
+  } else {
+    return broadcast;
+  }
 }
 
-Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim) {
+Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
+                                  bool* leading_dim_is_dynamic,
+                                  xla::XlaOp* leading_dim_dynamic_size) {
   bool is_initialized;
   TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
   TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
   if (is_initialized) {
     auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
+    *leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0);
+    auto buffer = xla::GetTupleElement(list, 0);
     *leading_dim = buffer_shape.dimensions(0);
+    *leading_dim_dynamic_size = xla::GetDimensionSize(buffer, 0);
   } else {
+    *leading_dim_is_dynamic = list_shape.is_dynamic_dimension(0);
     *leading_dim = list_shape.dimensions(0);
+    *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0);
   }
   return Status::OK();
 }
 
 Status GetTensorListShapeFromElementTensorListShape(
     const xla::Shape& element_tensor_list_shape, int64 leading_dim,
-    xla::Shape* tensor_list_shape) {
+    bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) {
   std::vector<xla::Shape> shapes;
   int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape);
   for (int i = 0; i < tuple_size; i++) {
@@ -220,6 +234,9 @@ Status GetTensorListShapeFromElementTensorListShape(
     dimensions.insert(dimensions.begin(), leading_dim);
     shapes.push_back(
         xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
+    if (leading_dim_is_dynamic) {
+      shapes.back().set_dynamic_dimension(0, true);
+    }
   }
   shapes.push_back(
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
@@ -229,6 +246,7 @@ Status GetTensorListShapeFromElementTensorListShape(
 
 Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
                                           int64 leading_dim,
+                                          bool leading_dim_is_dynamic,
                                           xla::Shape* tensor_list_shape) {
   if (!element_shape.IsArray()) {
     return errors::InvalidArgument(
@@ -236,12 +254,12 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
         "shape. But element shape is ",
         element_shape.DebugString());
   }
-
   std::vector<xla::Shape> shapes;
   std::vector<int64> dimensions =
       xla::SpanToVector(element_shape.dimensions());
   dimensions.insert(dimensions.begin(), leading_dim);
   shapes.push_back(
       xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
+  shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
   shapes.push_back(
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
   *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
@@ -279,7 +297,10 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
                                           bool element_is_tensor_list,
                                           xla::XlaOp* initialized_list) {
   int64 leading_dim;
-  TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(list, &leading_dim));
+  xla::XlaOp leading_dim_dynamic_size;
+  bool leading_dim_is_dynamic;
+  TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(
+      list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_dynamic_size));
 
   xla::XlaBuilder* b = list.builder();
   xla::Shape list_shape;
@@ -287,12 +308,11 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
 
   if (element_is_tensor_list) {
     TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
-        element_shape, leading_dim, &list_shape));
+        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
   } else {
     TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
-        element_shape, leading_dim, &list_shape));
+        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
   }
-
   bool is_initialized;
   TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
   if (is_initialized) {
@@ -312,8 +332,7 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
   for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
     std::vector<xla::XlaOp> dynamic_dims;
     const xla::Shape& shape = list_shape.tuple_shapes(i);
-    // Leading dim is a static dimension.
-    dynamic_dims.push_back(xla::ConstantR0<int32>(b, leading_dim));
+    dynamic_dims.push_back(leading_dim_dynamic_size);
     xla::XlaOp sub_element;
     if (element_is_tensor_list) {
       sub_element = xla::GetTupleElement(element, i);
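Hedged usage sketch for the extended `GetLeadingDimForTensorList` (the surrounding function and `list` are hypothetical; TF's `Status`, `int64`, and `TF_RETURN_IF_ERROR` are assumed in scope): one call now yields the static bound, a dynamism flag, and an XlaOp holding the runtime size.

```c++
// Sketch only, not patch code.
Status InspectList(xla::XlaOp list) {
  int64 leading_dim;
  bool leading_dim_is_dynamic;
  xla::XlaOp leading_dim_size;
  TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(
      list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_size));
  if (leading_dim_is_dynamic) {
    // leading_dim is only an upper bound; leading_dim_size (a
    // GetDimensionSize on the buffer) carries the runtime value.
  } else {
    // leading_dim_size is still usable and equals leading_dim.
  }
  return Status::OK();
}
```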
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h
index ef3c8badf71..549ccd5aece 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h
@@ -60,17 +60,22 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
 
 // Returns an uninitialized TensorList.
 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
-                                        int64 leading_dimension);
+                                        int64 leading_dimension,
+                                        bool leading_size_is_dynamic,
+                                        xla::XlaOp leading_dim_size);
 
-// Returns leading dimension for the TensorList.
-// Input can be initialized or uninitialized TensorList.
-// Non-nested and nested TensorLists are both supported.
-Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim);
+// Returns the leading dimension of the TensorList, as well as an XlaOp
+// representing its dynamic size. The input can be an initialized or
+// uninitialized TensorList. Non-nested and nested TensorLists are both
+// supported.
+Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
+                                  bool* leading_dim_is_dynamic,
+                                  xla::XlaOp* leading_dim_dynamic_size);
 
 // Returns TensorList shape for the element shape.
 // Element shape must be a normal tensor shape.
 Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
                                           int64 leading_dim,
+                                          bool leading_dim_is_dynamic,
                                           xla::Shape* tensor_list_shape);
 
 // Returns a TensorList filled by zeros with the given shape.
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index fe7a5898011..a94411f1b30 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -513,10 +513,26 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
             // Prepare dynamic dimensions for element shapes.
            std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
            for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
-             // Set dynamic dimension size to 0 for initilization value.
              std::vector<xla::XlaOp> dynamic_dims;
+
              const xla::Shape& shape = list_shape.tuple_shapes(i);
-             for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
+
+             // The dynamic size of the leading dimension is already known
+             // outside of the while loop, so we can use it here without
+             // initializing the TensorList inside the while loop.
+             if (shape.is_dynamic_dimension(0)) {
+               xla::XlaOp leading_dim_size = xla::GetDimensionSize(input, 0);
+               dynamic_dims.push_back(leading_dim_size);
+             } else {
+               int32 dim_size = shape.dimensions(0);
+               dynamic_dims.push_back(
+                   xla::ConstantR0<int32>(ctx->builder(), dim_size));
+             }
+
+             // Set dynamic dimension size to 0 for the element value. Inside
+             // the while loop, TensorListSetItem will set the element shape's
+             // dynamic dimension properly.
+             for (int64 dim = 1; dim < shape.dimensions_size(); ++dim) {
                int32 dim_size = shape.dimensions(dim);
                if (shape.is_dynamic_dimension(dim)) {
                  dim_size = 0;
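Standalone sketch (not patch code; builder and shapes are illustrative) of the primitive the while-loop change leans on: `xla::GetDimensionSize` turns a dynamic dimension into an s32 scalar, so the size computed outside the loop can seed the initialization value instead of re-deriving it inside the loop body.

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaBuilder b("leading_dim_probe");
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {10});
shape.set_dynamic_dimension(0, true);  // f32[<=10]
xla::XlaOp input = xla::Parameter(&b, 0, shape, "input");
// s32[] scalar equal to the logical size of dimension 0 at run time.
xla::XlaOp leading_dim_size = xla::GetDimensionSize(input, 0);
```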
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 07537546d52..c2d1906e47a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -259,6 +259,32 @@ static Status LiteralToPredVector(const xla::LiteralSlice& literal,
   return Status::OK();
 }
 
+Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
+  xla::Literal literal;
+  XlaExpression e = InputExpression(index);
+  auto* client = compiler() ? compiler()->client() : nullptr;
+  xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
+  if (!dynamism_or_status.ok()) {
+    Status status = dynamism_or_status.status();
+    errors::AppendToMessage(&status, "while evaluating input dynamism ", index,
+                            " of ", context_->op_kernel().type_string());
+    return status;
+  }
+  Tensor dynamism = dynamism_or_status.ValueOrDie();
+
+  Tensor temp(dynamism.dtype());
+  TensorShape tensor_shape({});
+  if (!temp.CopyFrom(dynamism, tensor_shape)) {
+    return errors::InvalidArgument(
+        context_->op_kernel().name(), " input ", index, " has shape ",
+        dynamism.shape().DebugString(), " which is not an R0 ", tensor_shape);
+  }
+
+  TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
+  *out = literal.Get<bool>({});
+  return Status::OK();
+}
+
 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
     int index, std::vector<bool>* out) {
   xla::Literal literal;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 75c3e60171a..1ed343ba20f 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -119,7 +119,7 @@ class XlaOpKernelContext {
   // Evaluates input and returns their dynamism vector in a vector of
   // predicates.
   Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
-
+  Status ResolveInputDynamismIntoPred(int index, bool* out);
   // Helper methods for constant inputs.
   // Evaluates input `index` and stores it in `*constant_literal`. If the
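Hedged sketch of how a kernel would call the new context helper (the free function and input index are hypothetical; TF's `Status`, `int64`, and error macros are assumed in scope):

```c++
// Query whether scalar input `index` is a compile-time constant or a
// runtime value, and branch accordingly. Sketch only, not patch code.
Status HandleMaybeDynamicScalar(XlaOpKernelContext* ctx, int index) {
  bool is_dynamic;
  TF_RETURN_IF_ERROR(ctx->ResolveInputDynamismIntoPred(index, &is_dynamic));
  if (is_dynamic) {
    // Only known at run time: keep ctx->Input(index) as an XlaOp and
    // propagate it through ops like SetDimensionSize.
  } else {
    // Compile-time constant: safe to fold on the host.
    int64 value;
    TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntScalar(index, &value));
  }
  return Status::OK();
}
```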