Support dynamic leading dimension for tensorlist.

PiperOrigin-RevId: 326118256
Change-Id: Icea13ee03e23aaac56f17dc5f58bd9182a4ba02a
Author: Yunxing Dai (2020-08-11 15:49:12 -07:00), committed by TensorFlower Gardener
Commit: 29784fade2, parent: c604d85364
6 changed files with 124 additions and 38 deletions

==== File 1 of 6 ====

@@ -45,22 +45,32 @@ namespace tensorflow {
 namespace {
 // GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist
-// may carry and returns them in a 2D vector: int64[ElementSize][DimSize]. If a
-// dimension is static, a constant dimension is returned.
+// may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. If a
+// dimension is static, a constant dimension is returned. If a dim is dynamic, a
+// dynamic XlaOp representing the dynamic size is returned.
 xla::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
     XlaOpKernelContext* ctx, const xla::Shape& element_shape,
     const xla::Shape& list_shape, int64 num_elements) {
   std::vector<int64> dynamic_sizes;
-  ctx->set_dynamic_dimension_is_minus_one(true);
   // The multiplier can be a dynamic value.
   TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
+  std::vector<bool> dims_are_dynamic;
+  TF_RETURN_IF_ERROR(
+      ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic));
+  bool leading_dim_is_dynamic;
+  TF_RETURN_IF_ERROR(
+      ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic));
   std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
   // Set dynamic dimension size to 0 for initialization value.
   std::vector<xla::XlaOp> dynamic_dims;
-  // Leading dim is a static dimension.
-  dynamic_dims.push_back(xla::ConstantR0<int32>(ctx->builder(), num_elements));
+  if (leading_dim_is_dynamic) {
+    dynamic_dims.push_back(ctx->Input(1));
+  } else {
+    dynamic_dims.push_back(
+        xla::ConstantR0<int32>(ctx->builder(), num_elements));
+  }
   for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) {
-    if (ctx->is_dynamic_dimension(dynamic_sizes[dim])) {
+    if (dims_are_dynamic[dim]) {
       auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
       dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
       dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
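
The last three lines of this hunk are the recurring recipe for per-dimension
dynamism: slice the requested entry out of the shape tensor, collapse it to a
scalar, and normalize it to S32. A minimal standalone sketch of that recipe
(the helper name DynamicDimSize and its parameters are illustrative, not part
of this commit):

    #include <cstdint>

    #include "tensorflow/compiler/xla/client/xla_builder.h"

    // Reads entry `dim` of a rank-1 integer shape tensor as a scalar S32 op.
    xla::XlaOp DynamicDimSize(xla::XlaOp shape_tensor, int64_t dim) {
      xla::XlaOp size = xla::Slice(shape_tensor, {dim}, {dim + 1}, {1});
      size = xla::Reshape(size, {});                   // rank-1 -> scalar
      return xla::ConvertElementType(size, xla::S32);  // normalize to S32
    }
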
@@ -80,11 +90,12 @@ class TensorListLengthOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     int64 leading_dim;
-    OP_REQUIRES_OK(ctx,
-                   GetLeadingDimForTensorList(ctx->Input(0), &leading_dim));
-    Tensor length_tensor(DT_INT32, {});
-    length_tensor.scalar<int32>()() = static_cast<int32>(leading_dim);
-    ctx->SetConstantOutput(0, length_tensor);
+    xla::XlaOp leading_dim_size;
+    bool leading_dim_is_dynamic;
+    OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim,
+                                                   &leading_dim_is_dynamic,
+                                                   &leading_dim_size));
+    ctx->SetOutput(0, leading_dim_size);
   }

  private:
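
The substantive change in TensorListLengthOp is the output mechanism: the
length used to be baked into the program as a host-side constant via
SetConstantOutput; it is now emitted as an XlaOp via SetOutput, so a length
recorded with SetDimensionSize is observed at run time. A hedged sketch of
where such an op can come from (free-standing helper, names assumed):

    #include "tensorflow/compiler/xla/client/xla_builder.h"

    // Returns the size of dimension 0 as a scalar S32 op. If the dimension
    // was annotated with SetDimensionSize, this yields the runtime value;
    // otherwise XLA can fold it to the static bound.
    xla::XlaOp ListLength(xla::XlaOp list_buffer) {
      return xla::GetDimensionSize(list_buffer, 0);
    }
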
@@ -134,6 +145,9 @@ class TensorListReserveOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     int64 num_elements;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
+    bool num_element_is_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
     OP_REQUIRES(
         ctx, num_elements >= 0,
         errors::InvalidArgument(
@@ -156,7 +170,8 @@ class TensorListReserveOp : public XlaOpKernel {
     if (got_shape) {
       xla::Shape list_shape;
       OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
-                              element_shape, num_elements, &list_shape));
+                              element_shape, num_elements,
+                              num_element_is_dynamic, &list_shape));
       // Set up dynamic dimension sizes to create the zero tensor.
       auto list_dynamic_dims_or = GetTensorListDynamicDims(
           ctx, element_shape, list_shape, num_elements);
@@ -175,8 +190,8 @@ class TensorListReserveOp : public XlaOpKernel {
       return;
     }
-    xla::XlaOp result =
-        BuildUninitializedTensorList(ctx->builder(), num_elements);
+    xla::XlaOp result = BuildUninitializedTensorList(
+        ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1));
     ctx->SetTensorListOutput(0, result);
   }
@@ -200,6 +215,9 @@ class EmptyTensorListOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     int64 max_num_elements;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
+    bool num_element_is_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
     OP_REQUIRES(ctx, max_num_elements >= 0,
                 errors::InvalidArgument(
                     "XLA compilation requires a fixed tensor list size. Set "
@@ -210,9 +228,9 @@ class EmptyTensorListOp : public XlaOpKernel {
     if (dtype_ != DT_VARIANT) {
       // We are creating a non-nested TensorList.
-      // If element shape is compile time constant and it's not "unknown rank"
-      // shape (-1), create an initialized TensorList. Otherwise create an
-      // uninitialized TensorList.
+      // If element shape is compile time constant and it's not "unknown
+      // rank" shape (-1), create an initialized TensorList. Otherwise
+      // create an uninitialized TensorList.
       xla::XlaOp element_shape_handle = ctx->Input(0);
       xla::PrimitiveType type;
       OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
@@ -224,7 +242,8 @@ class EmptyTensorListOp : public XlaOpKernel {
       if (got_shape) {
         xla::Shape list_shape;
         OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
-                                element_shape, max_num_elements, &list_shape));
+                                element_shape, max_num_elements,
+                                num_element_is_dynamic, &list_shape));
         // Set up dynamic dimension sizes to create the zero tensor.
         auto list_dynamic_dims_or = GetTensorListDynamicDims(
             ctx, element_shape, list_shape, max_num_elements);
@@ -243,7 +262,8 @@ class EmptyTensorListOp : public XlaOpKernel {
     // We are creating a nested TensorList or a non-nested TensorList with
     // unknown shape. Just create an uninitialized TensorList.
     xla::XlaOp result =
-        BuildUninitializedTensorList(ctx->builder(), max_num_elements);
+        BuildUninitializedTensorList(ctx->builder(), max_num_elements,
+                                     num_element_is_dynamic, ctx->Input(1));
     ctx->SetTensorListOutput(0, result);
   }

==== File 2 of 6 ====

@@ -189,28 +189,42 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
 }

 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
-                                        int64 leading_dimension) {
+                                        int64 leading_dimension,
+                                        bool leading_size_is_dynamic,
+                                        xla::XlaOp leading_dim_size) {
   auto zero =
       xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
-  return xla::Broadcast(zero, std::vector<int64>{leading_dimension});
+  auto broadcast = xla::Broadcast(zero, std::vector<int64>{leading_dimension});
+  if (leading_size_is_dynamic) {
+    return xla::SetDimensionSize(broadcast, leading_dim_size, 0);
+  } else {
+    return broadcast;
+  }
 }

-Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim) {
+Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
+                                  bool* leading_dim_is_dynamic,
+                                  xla::XlaOp* leading_dim_dynamic_size) {
   bool is_initialized;
   TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
   TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
   if (is_initialized) {
     auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
+    *leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0);
+    auto buffer = xla::GetTupleElement(list, 0);
     *leading_dim = buffer_shape.dimensions(0);
+    *leading_dim_dynamic_size = xla::GetDimensionSize(buffer, 0);
   } else {
+    *leading_dim_is_dynamic = list_shape.is_dynamic_dimension(0);
     *leading_dim = list_shape.dimensions(0);
+    *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0);
   }
   return Status::OK();
 }

 Status GetTensorListShapeFromElementTensorListShape(
     const xla::Shape& element_tensor_list_shape, int64 leading_dim,
-    xla::Shape* tensor_list_shape) {
+    bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) {
   std::vector<xla::Shape> shapes;
   int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape);
   for (int i = 0; i < tuple_size; i++) {
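
BuildUninitializedTensorList now encodes the runtime length in the buffer
itself: allocate to the static bound, then annotate dimension 0 with the
dynamic size. A minimal sketch of that allocate-then-annotate pattern,
assuming a hypothetical DynamicZeros helper:

    #include <cstdint>

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/literal_util.h"

    // Builds an S32 zero buffer padded to `bound`, whose dimension 0 reports
    // `dynamic_size` (a scalar S32 op) at run time.
    xla::XlaOp DynamicZeros(xla::XlaBuilder* b, int64_t bound,
                            xla::XlaOp dynamic_size) {
      auto zero = xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32));
      auto buffer = xla::Broadcast(zero, {bound});  // padded to the bound
      return xla::SetDimensionSize(buffer, dynamic_size, 0);
    }
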
@@ -220,6 +234,9 @@ Status GetTensorListShapeFromElementTensorListShape(
     dimensions.insert(dimensions.begin(), leading_dim);
     shapes.push_back(
         xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
+    if (leading_dim_is_dynamic) {
+      shapes.back().set_dynamic_dimension(0, true);
+    }
   }
   shapes.push_back(
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
@@ -229,6 +246,7 @@ Status GetTensorListShapeFromElementTensorListShape(
 Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
                                           int64 leading_dim,
+                                          bool leading_dim_is_dynamic,
                                           xla::Shape* tensor_list_shape) {
   if (!element_shape.IsArray()) {
     return errors::InvalidArgument(
@@ -236,12 +254,12 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
         "shape. But element shape is ",
         element_shape.DebugString());
   }
   std::vector<xla::Shape> shapes;
   std::vector<int64> dimensions = xla::SpanToVector(element_shape.dimensions());
   dimensions.insert(dimensions.begin(), leading_dim);
   shapes.push_back(
       xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
+  shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
   shapes.push_back(
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
   *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
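
Both shape helpers follow the same recipe: prepend the leading dimension to
the element dimensions, then flag dimension 0 of the freshly built shape as
dynamic. A standalone sketch under assumed names (MakeListBufferShape is not
a function from this commit):

    #include <cstdint>
    #include <vector>

    #include "tensorflow/compiler/xla/shape_util.h"

    xla::Shape MakeListBufferShape(const xla::Shape& element_shape,
                                   int64_t leading_dim,
                                   bool leading_is_dynamic) {
      std::vector<int64_t> dims(element_shape.dimensions().begin(),
                                element_shape.dimensions().end());
      dims.insert(dims.begin(), leading_dim);
      xla::Shape s =
          xla::ShapeUtil::MakeShape(element_shape.element_type(), dims);
      // A dynamic dimension keeps its bound but is flagged so XLA tracks the
      // runtime size alongside the data.
      s.set_dynamic_dimension(0, leading_is_dynamic);
      return s;
    }
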
@@ -279,7 +297,10 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
                                           bool element_is_tensor_list,
                                           xla::XlaOp* initialized_list) {
   int64 leading_dim;
-  TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(list, &leading_dim));
+  xla::XlaOp leading_dim_dynamic_size;
+  bool leading_dim_is_dynamic;
+  TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(
+      list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_dynamic_size));

   xla::XlaBuilder* b = list.builder();
   xla::Shape list_shape;
@@ -287,12 +308,11 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
   if (element_is_tensor_list) {
     TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
-        element_shape, leading_dim, &list_shape));
+        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
   } else {
     TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
-        element_shape, leading_dim, &list_shape));
+        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
   }
   bool is_initialized;
   TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
   if (is_initialized) {
@@ -312,8 +332,7 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
     for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
       std::vector<xla::XlaOp> dynamic_dims;
       const xla::Shape& shape = list_shape.tuple_shapes(i);
-      // Leading dim is a static dimension.
-      dynamic_dims.push_back(xla::ConstantR0<int32>(b, leading_dim));
+      dynamic_dims.push_back(leading_dim_dynamic_size);
       xla::XlaOp sub_element;
       if (element_is_tensor_list) {
         sub_element = xla::GetTupleElement(element, i);

==== File 3 of 6 ====

@@ -60,17 +60,22 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,

 // Returns an uninitialized TensorList.
 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
-                                        int64 leading_dimension);
+                                        int64 leading_dimension,
+                                        bool leading_size_is_dynamic,
+                                        xla::XlaOp leading_dim_size);

-// Returns leading dimension for the TensorList.
-// Input can be initialized or uninitialized TensorList.
-// Non-nested and nested TensorLists are both supported.
-Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim);
+// Returns leading dimension for the TensorList as well as a dynamic op
+// representing the dynamic size. Input can be initialized or uninitialized
+// TensorList. Non-nested and nested TensorLists are both supported.
+Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
+                                  bool* leading_dim_is_dynamic,
+                                  xla::XlaOp* leading_dim_dynamic_size);

 // Returns TensorList shape for the element shape.
 // Element shape must be a normal tensor shape.
 Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
                                           int64 leading_dim,
+                                          bool leading_dim_is_dynamic,
                                           xla::Shape* tensor_list_shape);

 // Returns a TensorList filled by zeros with the given shape.

==== File 4 of 6 ====

@@ -513,10 +513,26 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
           // Prepare dynamic dimensions for element shapes.
           std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
           for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
             // Set dynamic dimension size to 0 for initialization value.
             std::vector<xla::XlaOp> dynamic_dims;
             const xla::Shape& shape = list_shape.tuple_shapes(i);
-            for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
+            // We already have the dynamic size of the leading dimension
+            // outside of the while loop, without initializing the TensorList
+            // inside the while loop.
+            if (shape.is_dynamic_dimension(0)) {
+              xla::XlaOp leading_dim_size = xla::GetDimensionSize(input, 0);
+              dynamic_dims.push_back(leading_dim_size);
+            } else {
+              int32 dim_size = shape.dimensions(0);
+              dynamic_dims.push_back(
+                  xla::ConstantR0<int32>(ctx->builder(), dim_size));
+            }
+            // Set dynamic dimension size to 0 for element value. Inside the
+            // while loop, TensorListSetItem will properly set the element
+            // shape's dynamic dimension.
+            for (int64 dim = 1; dim < shape.dimensions_size(); ++dim) {
              int32 dim_size = shape.dimensions(dim);
              if (shape.is_dynamic_dimension(dim)) {
                dim_size = 0;
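
The rewritten loop body splits dimension handling in two: the leading
dimension reuses the dynamic size already known outside the while loop
(GetDimensionSize on the loop input), while the element dimensions are seeded
with size 0 and later restored by TensorListSetItem. A condensed sketch of
that decision, with a hypothetical helper name:

    #include <cstdint>
    #include <vector>

    #include "tensorflow/compiler/xla/client/xla_builder.h"

    // Chooses one size op per dimension of `shape` when building a zero
    // TensorList as a while-loop initialization value. `input` is the loop
    // operand that already carries the dynamic leading size.
    std::vector<xla::XlaOp> LoopInitDynamicDims(xla::XlaBuilder* b,
                                                xla::XlaOp input,
                                                const xla::Shape& shape) {
      std::vector<xla::XlaOp> dims;
      // Leading dim: reuse the runtime size if dynamic, else its bound.
      dims.push_back(shape.is_dynamic_dimension(0)
                         ? xla::GetDimensionSize(input, 0)
                         : xla::ConstantR0<int32_t>(
                               b, static_cast<int32_t>(shape.dimensions(0))));
      for (int64_t d = 1; d < shape.dimensions_size(); ++d) {
        // Element dims start at 0; TensorListSetItem fixes them up later.
        int32_t sz = shape.is_dynamic_dimension(d)
                         ? 0
                         : static_cast<int32_t>(shape.dimensions(d));
        dims.push_back(xla::ConstantR0<int32_t>(b, sz));
      }
      return dims;
    }
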

==== File 5 of 6 ====

@@ -259,6 +259,32 @@ static Status LiteralToPredVector(const xla::LiteralSlice& literal,
   return Status::OK();
 }

+Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
+  xla::Literal literal;
+  XlaExpression e = InputExpression(index);
+  auto* client = compiler() ? compiler()->client() : nullptr;
+  xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
+  if (!dynamism_or_status.ok()) {
+    Status status = dynamism_or_status.status();
+    errors::AppendToMessage(&status, "while evaluating input dynamism ", index,
+                            " of ", context_->op_kernel().type_string());
+    return status;
+  }
+  Tensor dynamism = dynamism_or_status.ValueOrDie();
+
+  Tensor temp(dynamism.dtype());
+  TensorShape tensor_shape({});
+  if (!temp.CopyFrom(dynamism, tensor_shape)) {
+    return errors::InvalidArgument(
+        context_->op_kernel().name(), " input ", index, " has shape ",
+        dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
+  }
+
+  TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
+  *out = literal.Get<bool>({});
+  return Status::OK();
+}
+
 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
     int index, std::vector<bool>* out) {
   xla::Literal literal;
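
ResolveInputDynamismIntoPred answers, at compile time, whether an input's
value can vary at run time. A hedged sketch of a typical call site, a
hypothetical kernel body (not code from this commit) following the pattern
TensorListReserveOp uses above:

    // Pick a constant or a runtime op for input 1 depending on its dynamism.
    void CompileSketch(XlaOpKernelContext* ctx) {
      bool size_is_dynamic;
      OP_REQUIRES_OK(ctx,
                     ctx->ResolveInputDynamismIntoPred(1, &size_is_dynamic));
      xla::XlaOp size_op;
      if (size_is_dynamic) {
        size_op = ctx->Input(1);  // keep the runtime value as an op
      } else {
        int64 size;
        OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &size));
        size_op =
            xla::ConstantR0<int32>(ctx->builder(), static_cast<int32>(size));
      }
      ctx->SetOutput(0, size_op);
    }
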

==== File 6 of 6 ====

@@ -119,7 +119,7 @@ class XlaOpKernelContext {
   // Evaluates the input and returns its dynamism as a vector of
   // predicates.
   Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
+  Status ResolveInputDynamismIntoPred(int index, bool* out);

   // Helper methods for constant inputs.
   // Evaluates input `index` and stores it in `*constant_literal`. If the