Support dynamic leading dimension for tensorlist.
PiperOrigin-RevId: 326118256 Change-Id: Icea13ee03e23aaac56f17dc5f58bd9182a4ba02a
This commit is contained in:
parent
c604d85364
commit
29784fade2
@ -45,22 +45,32 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist
|
||||
// may carry and returns them in a 2D vector: int64[ElementSize][DimSize]. If a
|
||||
// dimension is static, a constant dimension is returned.
|
||||
// may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. If a
|
||||
// dimension is static, a constant dimension is returned. If a dim is dynamic, a
|
||||
// dynamic XlaOp representing the dynamic size is returned.
|
||||
xla::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
|
||||
XlaOpKernelContext* ctx, const xla::Shape& element_shape,
|
||||
const xla::Shape& list_shape, int64 num_elements) {
|
||||
std::vector<int64> dynamic_sizes;
|
||||
ctx->set_dynamic_dimension_is_minus_one(true);
|
||||
// The multiplier can be a dynamic value.
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
|
||||
std::vector<bool> dims_are_dynamic;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic));
|
||||
bool leading_dim_is_dynamic;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic));
|
||||
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
|
||||
// Set dynamic dimension size to 0 for initialization value.
|
||||
std::vector<xla::XlaOp> dynamic_dims;
|
||||
// Leading dim is a static dimension.
|
||||
dynamic_dims.push_back(xla::ConstantR0<int32>(ctx->builder(), num_elements));
|
||||
if (leading_dim_is_dynamic) {
|
||||
dynamic_dims.push_back(ctx->Input(1));
|
||||
} else {
|
||||
dynamic_dims.push_back(
|
||||
xla::ConstantR0<int32>(ctx->builder(), num_elements));
|
||||
}
|
||||
for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) {
|
||||
if (ctx->is_dynamic_dimension(dynamic_sizes[dim])) {
|
||||
if (dims_are_dynamic[dim]) {
|
||||
auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
|
||||
dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
|
||||
dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
|
||||
@ -80,11 +90,12 @@ class TensorListLengthOp : public XlaOpKernel {
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
int64 leading_dim;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
GetLeadingDimForTensorList(ctx->Input(0), &leading_dim));
|
||||
Tensor length_tensor(DT_INT32, {});
|
||||
length_tensor.scalar<int32>()() = static_cast<int32>(leading_dim);
|
||||
ctx->SetConstantOutput(0, length_tensor);
|
||||
xla::XlaOp leading_dim_size;
|
||||
bool leading_dim_is_dynamic;
|
||||
OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim,
|
||||
&leading_dim_is_dynamic,
|
||||
&leading_dim_size));
|
||||
ctx->SetOutput(0, leading_dim_size);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -134,6 +145,9 @@ class TensorListReserveOp : public XlaOpKernel {
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
int64 num_elements;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
|
||||
bool num_element_is_dynamic;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
|
||||
OP_REQUIRES(
|
||||
ctx, num_elements >= 0,
|
||||
errors::InvalidArgument(
|
||||
@ -156,7 +170,8 @@ class TensorListReserveOp : public XlaOpKernel {
|
||||
if (got_shape) {
|
||||
xla::Shape list_shape;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
|
||||
element_shape, num_elements, &list_shape));
|
||||
element_shape, num_elements,
|
||||
num_element_is_dynamic, &list_shape));
|
||||
// Set up dynamic dimension sizes to create the zero tensor.
|
||||
auto list_dynamic_dims_or = GetTensorListDynamicDims(
|
||||
ctx, element_shape, list_shape, num_elements);
|
||||
@ -175,8 +190,8 @@ class TensorListReserveOp : public XlaOpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
xla::XlaOp result =
|
||||
BuildUninitializedTensorList(ctx->builder(), num_elements);
|
||||
xla::XlaOp result = BuildUninitializedTensorList(
|
||||
ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1));
|
||||
ctx->SetTensorListOutput(0, result);
|
||||
}
|
||||
|
||||
@ -200,6 +215,9 @@ class EmptyTensorListOp : public XlaOpKernel {
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
int64 max_num_elements;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
|
||||
bool num_element_is_dynamic;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
|
||||
OP_REQUIRES(ctx, max_num_elements >= 0,
|
||||
errors::InvalidArgument(
|
||||
"XLA compilation requires a fixed tensor list size. Set "
|
||||
@ -210,9 +228,9 @@ class EmptyTensorListOp : public XlaOpKernel {
|
||||
|
||||
if (dtype_ != DT_VARIANT) {
|
||||
// We are creating a non-nested TensorList.
|
||||
// If element shape is compile time constant and it's not "unknown rank"
|
||||
// shape (-1), create an initialized TensorList. Otherwise create an
|
||||
// uninitialized TensorList.
|
||||
// If element shape is compile time constant and it's not "unknown
|
||||
// rank" shape (-1), create an initialized TensorList. Otherwise
|
||||
// create an uninitialized TensorList.
|
||||
xla::XlaOp element_shape_handle = ctx->Input(0);
|
||||
xla::PrimitiveType type;
|
||||
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
|
||||
@ -224,7 +242,8 @@ class EmptyTensorListOp : public XlaOpKernel {
|
||||
if (got_shape) {
|
||||
xla::Shape list_shape;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
|
||||
element_shape, max_num_elements, &list_shape));
|
||||
element_shape, max_num_elements,
|
||||
num_element_is_dynamic, &list_shape));
|
||||
// Set up dynamic dimension sizes to create the zero tensor.
|
||||
auto list_dynamic_dims_or = GetTensorListDynamicDims(
|
||||
ctx, element_shape, list_shape, max_num_elements);
|
||||
@ -243,7 +262,8 @@ class EmptyTensorListOp : public XlaOpKernel {
|
||||
// We are creating a nested TensorList or a non-nested TensorList with
|
||||
// unknown shape. Just create an uninitialized TensorList.
|
||||
xla::XlaOp result =
|
||||
BuildUninitializedTensorList(ctx->builder(), max_num_elements);
|
||||
BuildUninitializedTensorList(ctx->builder(), max_num_elements,
|
||||
num_element_is_dynamic, ctx->Input(1));
|
||||
ctx->SetTensorListOutput(0, result);
|
||||
}
|
||||
|
||||
|
@ -189,28 +189,42 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
|
||||
}
|
||||
|
||||
xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
|
||||
int64 leading_dimension) {
|
||||
int64 leading_dimension,
|
||||
bool leading_size_is_dynamic,
|
||||
xla::XlaOp leading_dim_size) {
|
||||
auto zero =
|
||||
xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
|
||||
return xla::Broadcast(zero, std::vector<int64>{leading_dimension});
|
||||
auto broadcast = xla::Broadcast(zero, std::vector<int64>{leading_dimension});
|
||||
if (leading_size_is_dynamic) {
|
||||
return xla::SetDimensionSize(broadcast, leading_dim_size, 0);
|
||||
} else {
|
||||
return broadcast;
|
||||
}
|
||||
}
|
||||
|
||||
Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim) {
|
||||
Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
|
||||
bool* leading_dim_is_dynamic,
|
||||
xla::XlaOp* leading_dim_dynamic_size) {
|
||||
bool is_initialized;
|
||||
TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
|
||||
if (is_initialized) {
|
||||
auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
|
||||
*leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0);
|
||||
auto buffer = xla::GetTupleElement(list, 0);
|
||||
*leading_dim = buffer_shape.dimensions(0);
|
||||
*leading_dim_dynamic_size = xla::GetDimensionSize(buffer, 0);
|
||||
} else {
|
||||
*leading_dim_is_dynamic = list_shape.is_dynamic_dimension(0);
|
||||
*leading_dim = list_shape.dimensions(0);
|
||||
*leading_dim_dynamic_size = xla::GetDimensionSize(list, 0);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetTensorListShapeFromElementTensorListShape(
|
||||
const xla::Shape& element_tensor_list_shape, int64 leading_dim,
|
||||
xla::Shape* tensor_list_shape) {
|
||||
bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) {
|
||||
std::vector<xla::Shape> shapes;
|
||||
int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape);
|
||||
for (int i = 0; i < tuple_size; i++) {
|
||||
@ -220,6 +234,9 @@ Status GetTensorListShapeFromElementTensorListShape(
|
||||
dimensions.insert(dimensions.begin(), leading_dim);
|
||||
shapes.push_back(
|
||||
xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
|
||||
if (leading_dim_is_dynamic) {
|
||||
shapes.back().set_dynamic_dimension(0, true);
|
||||
}
|
||||
}
|
||||
shapes.push_back(
|
||||
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
|
||||
@ -229,6 +246,7 @@ Status GetTensorListShapeFromElementTensorListShape(
|
||||
|
||||
Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
|
||||
int64 leading_dim,
|
||||
bool leading_dim_is_dynamic,
|
||||
xla::Shape* tensor_list_shape) {
|
||||
if (!element_shape.IsArray()) {
|
||||
return errors::InvalidArgument(
|
||||
@ -236,12 +254,12 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
|
||||
"shape. But element shape is ",
|
||||
element_shape.DebugString());
|
||||
}
|
||||
|
||||
std::vector<xla::Shape> shapes;
|
||||
std::vector<int64> dimensions = xla::SpanToVector(element_shape.dimensions());
|
||||
dimensions.insert(dimensions.begin(), leading_dim);
|
||||
shapes.push_back(
|
||||
xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
|
||||
shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
|
||||
shapes.push_back(
|
||||
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
|
||||
*tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
|
||||
@ -279,7 +297,10 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
|
||||
bool element_is_tensor_list,
|
||||
xla::XlaOp* initialized_list) {
|
||||
int64 leading_dim;
|
||||
TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(list, &leading_dim));
|
||||
xla::XlaOp leading_dim_dynamic_size;
|
||||
bool leading_dim_is_dynamic;
|
||||
TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(
|
||||
list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_dynamic_size));
|
||||
|
||||
xla::XlaBuilder* b = list.builder();
|
||||
xla::Shape list_shape;
|
||||
@ -287,12 +308,11 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
|
||||
|
||||
if (element_is_tensor_list) {
|
||||
TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
|
||||
element_shape, leading_dim, &list_shape));
|
||||
element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
|
||||
element_shape, leading_dim, &list_shape));
|
||||
element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
|
||||
}
|
||||
|
||||
bool is_initialized;
|
||||
TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
|
||||
if (is_initialized) {
|
||||
@ -312,8 +332,7 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
|
||||
for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
|
||||
std::vector<xla::XlaOp> dynamic_dims;
|
||||
const xla::Shape& shape = list_shape.tuple_shapes(i);
|
||||
// Leading dim is a static dimension.
|
||||
dynamic_dims.push_back(xla::ConstantR0<int32>(b, leading_dim));
|
||||
dynamic_dims.push_back(leading_dim_dynamic_size);
|
||||
xla::XlaOp sub_element;
|
||||
if (element_is_tensor_list) {
|
||||
sub_element = xla::GetTupleElement(element, i);
|
||||
|
@ -60,17 +60,22 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
|
||||
|
||||
// Returns an uninitialized TensorList.
|
||||
xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
|
||||
int64 leading_dimension);
|
||||
int64 leading_dimension,
|
||||
bool leading_size_is_dynamic,
|
||||
xla::XlaOp leading_dim_size);
|
||||
|
||||
// Returns leading dimension for the TensorList.
|
||||
// Input can be initialized or uninitialized TensorList.
|
||||
// Non-nested and nested TensorLists are both supported.
|
||||
Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim);
|
||||
// Returns leading dimension for the TensorList as well as a dynamic op
|
||||
// representing the dynamic size. Input can be initialized or uninitialized
|
||||
// TensorList. Non-nested and nested TensorLists are both supported.
|
||||
Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
|
||||
bool* leading_dim_is_dynamic,
|
||||
xla::XlaOp* leading_dim_dynamic_size);
|
||||
|
||||
// Returns TensorList shape for the element shape.
|
||||
// Element shape must be a normal tensor shape.
|
||||
Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
|
||||
int64 leading_dim,
|
||||
bool leading_dim_is_dynamic,
|
||||
xla::Shape* tensor_list_shape);
|
||||
|
||||
// Returns a TensorList filled by zeros with the given shape.
|
||||
|
@ -513,10 +513,26 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
|
||||
// Prepare dynamic dimensions for element shapes.
|
||||
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
|
||||
for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
|
||||
// Set dynamic dimension size to 0 for initilization value.
|
||||
std::vector<xla::XlaOp> dynamic_dims;
|
||||
|
||||
const xla::Shape& shape = list_shape.tuple_shapes(i);
|
||||
for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
|
||||
|
||||
// We already have the dynamic size of leading dimension outside of
|
||||
// the while loop without initializing the TensorList inside the while
|
||||
// loop.
|
||||
if (shape.is_dynamic_dimension(0)) {
|
||||
xla::XlaOp leading_dim_size = xla::GetDimensionSize(input, 0);
|
||||
dynamic_dims.push_back(leading_dim_size);
|
||||
} else {
|
||||
int32 dim_size = shape.dimensions(0);
|
||||
dynamic_dims.push_back(
|
||||
xla::ConstantR0<int32>(ctx->builder(), dim_size));
|
||||
}
|
||||
|
||||
// Set dynamic dimension size to 0 for element value. Inside the while
|
||||
// loop, TensorlistSetItem will properly set the element shape's
|
||||
// dynamic diemnsion.
|
||||
for (int64 dim = 1; dim < shape.dimensions_size(); ++dim) {
|
||||
int32 dim_size = shape.dimensions(dim);
|
||||
if (shape.is_dynamic_dimension(dim)) {
|
||||
dim_size = 0;
|
||||
|
@ -259,6 +259,32 @@ static Status LiteralToPredVector(const xla::LiteralSlice& literal,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
|
||||
xla::Literal literal;
|
||||
XlaExpression e = InputExpression(index);
|
||||
auto* client = compiler() ? compiler()->client() : nullptr;
|
||||
xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
|
||||
if (!dynamism_or_status.ok()) {
|
||||
Status status = dynamism_or_status.status();
|
||||
errors::AppendToMessage(&status, "while evaluating input dynamism", index,
|
||||
" of ", context_->op_kernel().type_string());
|
||||
return status;
|
||||
}
|
||||
Tensor dynamism = dynamism_or_status.ValueOrDie();
|
||||
|
||||
Tensor temp(dynamism.dtype());
|
||||
TensorShape tensor_shape({});
|
||||
if (!temp.CopyFrom(dynamism, tensor_shape)) {
|
||||
return errors::InvalidArgument(
|
||||
context_->op_kernel().name(), " input ", index, " has shape ",
|
||||
dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
|
||||
*out = literal.Get<bool>({});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
|
||||
int index, std::vector<bool>* out) {
|
||||
xla::Literal literal;
|
||||
|
@ -119,7 +119,7 @@ class XlaOpKernelContext {
|
||||
// Evaluates input and returns their dynamism vector in a vector of
|
||||
// predicates.
|
||||
Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
|
||||
|
||||
Status ResolveInputDynamismIntoPred(int index, bool* out);
|
||||
// Helper methods for constant inputs.
|
||||
|
||||
// Evaluates input `index` and stores it in `*constant_literal`. If the
|
||||
|
Loading…
Reference in New Issue
Block a user