[XLA][TF2XLA] Support tensor list with dynamic dimension.

Previously we didn't allow a dynamic dimension to change inside an HLO while
loop. But this constraint breaks tensor lists, where the true dynamic
dimension size is only known inside the loop body.

This CL:
- Add the feature in dynamic padder to be able to change a dynamic dimension's size in the loop.
- Add a nice test to demonstrate how tensor list / stack can be handled more elegantly in xla.
- Add necessary machinery to wire this feature into tf2xla.

PiperOrigin-RevId: 307901191
Change-Id: I4d39f1d8a8c944f1e9834c39599e6cfbc99f6807
This commit is contained in:
Yunxing Dai 2020-04-22 14:34:06 -07:00 committed by TensorFlower Gardener
parent e224bfeabb
commit f86b74e27e
15 changed files with 387 additions and 148 deletions

View File

@ -460,14 +460,18 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10
// -----
// CHECK: HloModule
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
return %0 : tensor<i32>
func @main(%arg: tensor<4x2xf32>, %size: tensor<i32>) -> tensor<i32> {
%0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor<i32>) -> tensor<4x2xf32>
%1 = "xla_hlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
return %1 : tensor<i32>
}
// CHECK: ENTRY
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
// CHECK: [[SIZE:%.*]] = s32[] parameter(1)
// CHECK: [[DYNAMIC:%.*]] = f32[4,<=2] set-dimension-size(f32[4,2] [[ARG]], s32[] [[SIZE]]), dimensions={1}
// CHECK: ROOT %[[RESULT:.*]] = s32[] get-dimension-size(f32[4,<=2] [[DYNAMIC]]), dimensions={1}
// -----

View File

@ -274,10 +274,23 @@ class ZerosLikeOp : public XlaOpKernel {
auto list_shape_or = ctx->builder()->GetShape(list);
OP_REQUIRES_OK(ctx, list_shape_or.status());
const xla::Shape& list_shape = list_shape_or.ValueOrDie();
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
list_dynamic_dims.reserve(list_shape.tuple_shapes_size() - 1);
for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
// Set dynamic dimension size to 0 for initialization value.
std::vector<xla::XlaOp> dynamic_dims;
const xla::Shape& shape = list_shape.tuple_shapes(i);
auto sub_element = xla::GetTupleElement(list, i);
for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
}
list_dynamic_dims.push_back(dynamic_dims);
}
xla::XlaOp new_list;
OP_REQUIRES_OK(
ctx, CreateZerosTensorListWithShape(
ctx->builder(), list_shape_or.ValueOrDie(), &new_list));
ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
list_dynamic_dims, &new_list));
xla::XlaOp push_index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &push_index));
@ -287,10 +300,20 @@ class ZerosLikeOp : public XlaOpKernel {
SetTensorListPushIndex(new_list, push_index, &result));
ctx->SetTensorListOutput(0, result);
} else {
const TensorShape input_shape = ctx->InputShape(0);
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
xla::XlaOp input = ctx->Input(0);
auto input_shape = ctx->InputXlaShape(0).ValueOrDie();
auto result = xla::Broadcast(zero, input_shape.dimensions());
// Setting up dynamic dimensions of the broadcast.
for (int64 i = 0; i < input_shape.dimensions_size(); ++i) {
if (input_shape.is_dynamic_dimension(i)) {
xla::XlaOp input_dynamic_dim = xla::GetDimensionSize(input, i);
result = xla::SetDimensionSize(result, input_dynamic_dim, i);
}
}
ctx->SetOutput(0, result);
}
}
};

View File

@ -44,6 +44,36 @@ namespace tensorflow {
namespace {
// GetTensorListDynamicDims collects the dynamic dimension sizes that a
// tensorlist may carry and returns them as a 2D vector of size ops:
// XlaOp[ElementSize][DimSize]. For a static dimension a constant size op is
// returned instead. `list_shape` is accepted for symmetry with callers; only
// `element_shape` and input 0 are consulted here.
xla::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
    XlaOpKernelContext* ctx, const xla::Shape& element_shape,
    const xla::Shape& list_shape, int64 num_elements) {
  std::vector<int64> dynamic_sizes;
  // Allow -1 entries in the constant-folded shape input: a -1 marks a
  // dimension whose real size is dynamic and must be read at runtime.
  ctx->set_dynamic_dimension_is_minus_one(true);
  // Input 0 holds the requested element shape as an integer vector.
  TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
  std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
  // Collect one size op per dimension of the (single) list element buffer.
  std::vector<xla::XlaOp> dynamic_dims;
  // The leading (list length) dimension is always static: `num_elements`.
  dynamic_dims.push_back(xla::ConstantR0<int32>(ctx->builder(), num_elements));
  for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) {
    if (ctx->is_dynamic_dimension(dynamic_sizes[dim])) {
      // Dynamic dimension: slice its runtime size out of the shape input and
      // convert it to the scalar S32 that SetDimensionSize expects.
      auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
      dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
      dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
      dynamic_dims.push_back(dynamic_dim_size);
    } else {
      // Static dimension: emit its known size as a constant.
      dynamic_dims.push_back(
          xla::ConstantR0<int32>(ctx->builder(), dynamic_sizes[dim]));
    }
  }
  list_dynamic_dims.push_back(dynamic_dims);
  return list_dynamic_dims;
}
class TensorListLengthOp : public XlaOpKernel {
public:
explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
@ -124,10 +154,14 @@ class TensorListReserveOp : public XlaOpKernel {
xla::Shape list_shape;
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
element_shape, num_elements, &list_shape));
// Set up dynamic dimension sizes to create the zero tensor.
auto list_dynamic_dims_or = GetTensorListDynamicDims(
ctx, element_shape, list_shape, num_elements);
OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
xla::XlaOp new_list;
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
ctx->builder(), list_shape, &new_list));
ctx->builder(), list_shape,
list_dynamic_dims_or.ValueOrDie(), &new_list));
xla::XlaOp result;
OP_REQUIRES_OK(
ctx,
@ -185,10 +219,16 @@ class EmptyTensorListOp : public XlaOpKernel {
xla::Shape list_shape;
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
element_shape, max_num_elements, &list_shape));
// Set up dynamic dimension sizes to create the zero tensor.
auto list_dynamic_dims_or = GetTensorListDynamicDims(
ctx, element_shape, list_shape, max_num_elements);
OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
xla::XlaOp result;
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
ctx->builder(), list_shape, &result));
ctx->builder(), list_shape,
list_dynamic_dims_or.ValueOrDie(), &result));
ctx->SetTensorListOutput(0, result);
return;
}

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -247,19 +248,29 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
return Status::OK();
}
Status CreateZerosTensorListWithShape(xla::XlaBuilder* b,
const xla::Shape& list_shape,
xla::XlaOp* list) {
Status CreateZerosTensorListWithShape(
xla::XlaBuilder* b, const xla::Shape& list_shape,
const std::vector<std::vector<xla::XlaOp>>& dynamic_dims,
xla::XlaOp* list) {
int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
std::vector<xla::XlaOp> elements;
for (int i = 0; i < tuple_size; i++) {
TF_RET_CHECK(dynamic_dims.size() == tuple_size - 1);
for (int i = 0; i < tuple_size - 1; i++) {
const xla::Shape& shape =
xla::ShapeUtil::GetTupleElementShape(list_shape, i);
xla::XlaOp zero =
xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type()));
xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions());
TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions_size());
for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
zeros = xla::SetDimensionSize(zeros, dynamic_dims[i][dim], dim);
}
elements.push_back(zeros);
}
// List size (last item) has to be S32.
TF_RET_CHECK(xla::ShapeUtil::GetTupleElementShape(list_shape, tuple_size - 1)
.element_type() == xla::S32);
elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32)));
*list = xla::Tuple(b, elements);
return Status::OK();
}
@ -272,12 +283,12 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
xla::XlaBuilder* b = list.builder();
xla::Shape list_shape;
TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
if (element_is_tensor_list) {
TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
element_shape, leading_dim, &list_shape));
} else {
TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
element_shape, leading_dim, &list_shape));
}
@ -295,7 +306,27 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
*initialized_list = list;
return Status::OK();
} else {
return CreateZerosTensorListWithShape(b, list_shape, initialized_list);
// Prepare dynamic dimension dimensions for zero tensor list. The dynamic
// sizes are created by reading the dynamic dimension size of sub-elements.
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
std::vector<xla::XlaOp> dynamic_dims;
const xla::Shape& shape = list_shape.tuple_shapes(i);
// Leading dim is a static dimension.
dynamic_dims.push_back(xla::ConstantR0<int32>(b, leading_dim));
xla::XlaOp sub_element;
if (element_is_tensor_list) {
sub_element = xla::GetTupleElement(element, i);
} else {
sub_element = element;
}
for (int64 dim = 0; dim < shape.dimensions_size() - 1; ++dim) {
dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
}
list_dynamic_dims.push_back(dynamic_dims);
}
return CreateZerosTensorListWithShape(b, list_shape, list_dynamic_dims,
initialized_list);
}
}
@ -473,7 +504,13 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index,
xla::XlaOp list_part = xla::GetTupleElement(list, 0);
xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);
for (int64 i = 0; i < buffer_shape.dimensions_size(); ++i) {
if (buffer_shape.is_dynamic_dimension(i)) {
auto buffer = xla::GetTupleElement(list, 0);
auto gds = xla::GetDimensionSize(buffer, i);
read = xla::SetDimensionSize(read, gds, i);
}
}
slice_shape.erase(slice_shape.begin());
*result = xla::Reshape(read, slice_shape);
return Status::OK();

View File

@ -74,9 +74,9 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
xla::Shape* tensor_list_shape);
// Returns a TensorList filled by zeros with the given shape.
Status CreateZerosTensorListWithShape(xla::XlaBuilder* b,
const xla::Shape& list_shape,
xla::XlaOp* list);
Status CreateZerosTensorListWithShape(
xla::XlaBuilder* b, const xla::Shape& list_shape,
const std::vector<std::vector<xla::XlaOp>>& dynamic_dims, xla::XlaOp* list);
// If the TensorList is initialized, check that its shape matches element shape;
// If the TensorList is uninitialized, initialize it with the element shape.

View File

@ -510,8 +510,25 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
// first compilation and the body/cond was recompiled with the updated
// shape/datatype of the list.
if (input_shape != list_shape) {
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
ctx->builder(), list_shape, &inputs[i]));
// Prepare dynamic dimensions for element shapes.
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
// Set dynamic dimension size to 0 for initialization value.
std::vector<xla::XlaOp> dynamic_dims;
const xla::Shape& shape = list_shape.tuple_shapes(i);
for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
int32 dim_size = shape.dimensions(dim);
if (shape.is_dynamic_dimension(dim)) {
dim_size = 0;
}
dynamic_dims.push_back(
xla::ConstantR0<int32>(ctx->builder(), dim_size));
}
list_dynamic_dims.push_back(dynamic_dims);
}
OP_REQUIRES_OK(
ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
list_dynamic_dims, &inputs[i]));
} else {
inputs[i] = ctx->Input(input_num);
}

View File

@ -217,6 +217,8 @@ class XlaOpKernelContext {
return dynamic_dimension_is_minus_one_;
}
bool is_dynamic_dimension(int64 dim_size) { return dim_size == -1; }
// Reads the current value of the resource variable referred to by input
// `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
// variable. Returns an error if the variable has not been initialized, or if

View File

@ -2646,6 +2646,11 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
*operand_shape, dimension));
// Calling GetDimensionSize on a static dimension returns a constant
// instruction.
if (!operand_shape->is_dynamic_dimension(dimension)) {
return ConstantR0<int32>(this, operand_shape->dimensions(dimension));
}
*instr.mutable_shape() = shape.ToProto();
instr.add_dimensions(dimension);
return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
@ -2657,8 +2662,20 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
*operand_shape, dimension));
// Setting an op's dynamic dimension to the static size is a noop.
TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
LookUpInstruction(val));
if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() ==
HloOpcode::kConstant) {
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(val_proto->literal(), true));
if (literal.Get<int32>({}) == shape.dimensions(dimension)) {
return operand;
}
}
*instr.mutable_shape() = shape.ToProto();
instr.add_dimensions(dimension);
return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,

View File

@ -407,13 +407,25 @@ TEST_F(XlaBuilderTest, CollectivePermute) {
TEST_F(XlaBuilderTest, GetDimensionSize) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
auto x =
Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x");
GetDimensionSize(x, 1);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize);
}
TEST_F(XlaBuilderTest, GetDimensionSizeConstant) {
  XlaBuilder b(TestName());
  // Parameter with a static dim 0 (size 5) and a dynamic dim 1 (bound 7).
  auto x =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x");
  // Getting the dimension size of a static dimension folds to a constant
  // instruction instead of emitting a get-dimension-size op.
  GetDimensionSize(x, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
}
TEST_F(XlaBuilderTest, ReportError) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");

View File

@ -1369,77 +1369,27 @@ Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
}
Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
// While loop is handled by passing dynamic size hlos as parameters into the
// hlo while loop. This is done by replacing the original while with a new
// one.
//
// Before:
//
// op1 = ...
// op2 = ...
// op1_x = ... // dynamic dimension size of op1
// while = while(op1, op2)
//
//
// After:
//
// op1 = ...
// op2 = ...
// op1_x = ... // dynamic dimension size of op1
// while = while(op1, op2, op1_x)
//
// In the above graph, op_x is the bound of the dynamic dimension size of op1
// and is wired into the while loop as new parameter.
//
// TODO(b/119843103): Once we implement dynamic bounds in XLA backend, dynamic
// bound can be propagated through native xla values instead of relying on
// additional parameter.
// dynamic_size_to_operand_id_index_map keeps track of dynamic size operations
// to their operand ids in the new while loop.
absl::flat_hash_map<HloInstruction*, int64>
dynamic_size_to_operand_id_index_map;
// operands_to_add collects dynamic sizes that need to be added to the while
// loop as parameters. Note that a dynamic size is ignored if it is already
// part of the parameter. i.e.:
//
// We don't do:
//
// op1 = ...
// op2 = ...
// op_x = ... // dynamic dimension size of both op1 and op2
// while = while(op1, op2, op_x, op_x) // 4 parameters
//
// But we do:
//
// op1 = ...
// op2 = ...
// op_x = ... // dynamic dimension size of both op1 and op2
// while = while(op1, op2, op_x)
//
// An alternative is to do this in a while loop CSE pass.
//
// If the output of the conditional contains dynamic dimension. We send
// dynamic dimension size out by adding additional root element. A mapping
// from the root instruction's dynamic dimension index (represented by a shape
// index as output index and a int64 dimension number) to output index
// (represented by an int64) is tracked for the conditional instruction (all
// branches should have the same mapping).
ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
hlo->shape());
std::vector<HloInstruction*> operands_to_add;
int64 operand_count = hlo->shape().tuple_shapes_size();
const int64 original_tuple_count = hlo->shape().tuple_shapes_size();
int64 operand_count = original_tuple_count;
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
hlo, [&](HloInstruction*, ShapeIndex, int64, int64,
hlo, [&](HloInstruction*, ShapeIndex index, int64 dim, int64,
HloInstruction* dynamic_size, DimensionConstraint constraint) {
const HloInstruction* tuple_operand = hlo->operand(0);
for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
if (dynamic_size == tuple_operand->operand(i)) {
dynamic_size_to_operand_id_index_map[dynamic_size] = i;
return Status::OK();
}
}
auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
if (iter == dynamic_size_to_operand_id_index_map.end()) {
operands_to_add.push_back(dynamic_size);
dynamic_size_to_operand_id_index_map[dynamic_size] = operand_count++;
}
operands_to_add.push_back(dynamic_size);
dynamic_output_mapping.mutable_element(index)->emplace(dim,
operand_count++);
return Status::OK();
}));
DynamicParameterBinding binding_for_while;
if (!operands_to_add.empty()) {
// Only replace the while loop if there are new parameters to add.
HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
@ -1453,37 +1403,78 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
parent_->CopyMapping(/*from=*/old_tuple_operand,
/*to=*/new_tuple_operand);
hlo = result.new_while_instr;
// We have replaced the while loop, now set the dynamic dimensions for the
// newly created while loop so that the hlos that consumes the while loop
// can see the dynamic dimensions. Also sets the dynamic parameter binding
// for running inference in the while loop.
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
hlo,
[&](HloInstruction*, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) -> Status {
TF_RET_CHECK(!operands_to_add.empty());
const int64 output_dynamic_size_index =
dynamic_output_mapping.element(index).at(dimension);
DynamicParameterBinding::DynamicParameter dynamic_parameter{
operand_index, {output_dynamic_size_index}};
DynamicParameterBinding::DynamicDimension dynamic_dimension{
operand_index, index, dimension};
TF_RETURN_IF_ERROR(
binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
// This is the updated output dynamic size coming out of hlo while
// loop.
HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
HloInstruction::CreateGetTupleElement(
ShapeUtil::MakeScalarShape(S32), hlo,
output_dynamic_size_index));
parent_->SetDynamicSize(result.replacement_instr, index, dimension,
output_dynamic_size, constraint);
return Status::OK();
}));
// Set the replacement instruction as visited to avoid visiting it again.
SetVisited(*result.replacement_instr);
}
// We have replaced the while loop, now set the dynamic dimensions for the
// newly created while loop so that the hlos that consumes the while loop can
// see the dynamic dimensions. Also sets the dynamic parameter binding for
// running inference in the while loop.
DynamicParameterBinding binding_for_while;
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
DynamicParameterBinding::DynamicParameter dynamic_parameter{
operand_index,
{dynamic_size_to_operand_id_index_map[dynamic_size]}};
DynamicParameterBinding::DynamicDimension dynamic_dimension{
operand_index, index, dimension};
TF_RETURN_IF_ERROR(
binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
constraint);
return Status::OK();
}));
// Run inference in while body and condition.
TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
hlo->while_body(), binding_for_while, parent_));
TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
hlo->while_condition(), binding_for_while, parent_));
// Set the replacement while loop as visited to avoid visiting it again.
SetVisited(*hlo);
if (operands_to_add.empty()) {
// No dynamic dimension in the inputs and outputs.
return Status::OK();
}
// The dynamic dimension size could have been changed in the loop body (e.g, A
// loop that inserts items in a stack, the stack size increases with each
// iteration). Rewrite the dynamic dimension size at the root.
HloInstruction* body_root = hlo->while_body()->root_instruction();
std::vector<HloInstruction*> new_root_operands(body_root->operand_count(),
nullptr);
// Original non-dynamic-dim operands of root are pass-through.
for (int64 i = 0; i < original_tuple_count; ++i) {
new_root_operands[i] =
hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
body_root->shape().tuple_shapes(i), body_root, i));
}
// Add dynamic dimension size as new parameters.
TF_RETURN_IF_ERROR(ForEachDynamicDimension(
hlo->while_body()->root_instruction(),
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
DimensionConstraint) -> Status {
const int64 output_index =
dynamic_output_mapping.element(index).at(dim);
new_root_operands[output_index] = dynamic_size;
return Status::OK();
}));
for (auto operand : new_root_operands) {
TF_RET_CHECK(operand != nullptr);
}
HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
HloInstruction::CreateTuple(new_root_operands));
hlo->while_body()->set_root_instruction(new_body_root);
return Status::OK();
}

View File

@ -767,7 +767,7 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) {
// While
auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, tuple_shape, "A"));
auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, scalar_shape_, "size_param"));
builder.AddInstruction(
HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));
@ -782,37 +782,32 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) {
DynamicParameterBinding::DynamicParameter{1, {}},
DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
// Test that dynamic dimension inference does the right thing. A lambda is
// used here since we want to test twice by running inference again
// (idempotency).
auto test_dynamic_dimension = [&]() {
HloInstruction* while_hlo = nullptr;
// The while hlo has been replaced, find the new one.
for (HloInstruction* inst : module_->entry_computation()->instructions()) {
if (inst->opcode() == HloOpcode::kWhile) {
while_hlo = inst;
}
}
ASSERT_NE(while_hlo, nullptr);
// The original while shape has 2 parameters. With dynamic size passed in
// as an extra parameter, the tuple should have 3 elements.
EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 3);
HloInstruction* add = nullptr;
for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
if (inst->opcode() == HloOpcode::kAdd) {
add = inst;
}
}
EXPECT_NE(add, nullptr);
EXPECT_NE(inference_->GetDynamicSize(add, {}, 0), nullptr);
EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {0}, 0), size_param);
EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {1}, 0), size_param);
};
TF_ASSERT_OK(RunInference());
test_dynamic_dimension();
TF_ASSERT_OK(RunInference());
test_dynamic_dimension();
HloInstruction* while_hlo = nullptr;
// The while hlo has been replaced, find the new one.
for (HloInstruction* inst : module_->entry_computation()->instructions()) {
if (inst->opcode() == HloOpcode::kWhile) {
while_hlo = inst;
}
}
ASSERT_NE(while_hlo, nullptr);
// The original while shape has 2 parameters. With dynamic size, the tuple
// should have 4 elements (We don't deduplicate the arguments).
EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 4);
HloInstruction* add_inst = nullptr;
for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
if (inst->opcode() == HloOpcode::kAdd) {
add_inst = inst;
}
}
EXPECT_NE(add_inst, nullptr);
EXPECT_NE(inference_->GetDynamicSize(add_inst, {}, 0), nullptr);
EXPECT_NE(inference_->GetDynamicSize(
module_->entry_computation()->root_instruction(), {0}, 0),
nullptr);
EXPECT_NE(inference_->GetDynamicSize(
module_->entry_computation()->root_instruction(), {1}, 0),
nullptr);
}
TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {

View File

@ -903,6 +903,90 @@ ENTRY main {
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, WhileLoopStack) {
  // Push into a dynamic sized stack with iteration number:
  // init:
  // [[P, P],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // First iteration i = 0:
  // [[0, 0],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // Second iteration i = 1:
  // [[0, 0],
  //  [1, 1],
  //  [P, P],
  //  [P, P]]
  // Third iteration i = 2:
  // [[0, 0],
  //  [1, 1],
  //  [2, 2],
  //  [P, P]]
  const string hlo_text = R"(
HloModule module
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}
body {
stack = (s32[<=4,2]) parameter(0)
stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
zero = s32[] constant(0)
one = s32[] constant(1)
// content of the stack is the stack index broadcasted.
new_data = s32[1, 2] broadcast(s32[] stack_size), dimensions={}
new_stack_buffer = s32[<=4, 2] dynamic-update-slice(stack_buffer, new_data, stack_size, zero)
new_stack_size = s32[] add(stack_size, one)
new_stack_buffer_dynamic = s32[<=4, 2]set-dimension-size(new_stack_buffer, new_stack_size), dimensions={0}
ROOT new_stack = (s32[<=4,2]) tuple(new_stack_buffer_dynamic)
}
condition {
stack = (s32[<=4,2]) parameter(0)
stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
three = s32[] constant(3)
ROOT less-than = pred[] compare(s32[] stack_size, s32[] three), direction=LT
}
ENTRY entry {
zero = s32[] constant(0)
pad = s32[] constant(-1)
stack_buffer_input = s32[4, 2] broadcast(s32[] pad), dimensions={}
stack_buffer_input_dynamic = s32[<=4, 2] set-dimension-size(stack_buffer_input, zero), dimensions={0}
input_tuple = (s32[<=4 ,2]) tuple(stack_buffer_input_dynamic)
while = (s32[<=4, 2]) while(input_tuple), body=body, condition=condition
stack_buffer = s32[<=4, 2] get-tuple-element(while), index=0
ROOT reduce = s32[2] reduce(stack_buffer, zero),
dimensions={0},
to_apply=update_s32
}
)";
  auto module = GetHloModule(hlo_text);
  Literal result = PadAndExecute(std::move(module), {});
  // Stack has three valid items in it:
  // [[0, 0],
  //  [1, 1],
  //  [2, 2],
  //  [P, P]]
  //
  // Reducing along the major dimension gives us [3, 3].
  // NOTE(review): CreateR1 builds a rank-1 literal from a flat
  // absl::Span<const int32>; the original double-braced `{{3, 3}}` nests an
  // extra initializer level that does not form a valid int32 element.
  Literal expected = LiteralUtil::CreateR1<int32>({3, 3});
  EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DoubleDynamicDimension) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1

View File

@ -2596,7 +2596,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
}
return operand_shape;
auto result_shape = operand_shape;
// If any of the operand shape and update shape is dynamic, update the result
// dimension to dynamic.
for (int64 i = 0; i < update_shape.rank(); ++i) {
if (update_shape.is_dynamic_dimension(i) ||
operand_shape.is_dynamic_dimension(i)) {
result_shape.set_dynamic_dimension(i, true);
}
}
return result_shape;
}
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(

View File

@ -125,8 +125,9 @@ WhileUtil::MakeInstructionsLiveIn(
// We want to get rid of the old while instruction even if it has side
// effecting operations so we do a manual HloComputation::RemoveInstruction
// instead of relying on HloComputation::ReplaceInstruction.
TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix(
new_while, while_instr->shape().tuple_shapes_size())));
HloInstruction* replacement_instr = TupleUtil::ExtractPrefix(
new_while, while_instr->shape().tuple_shapes_size());
TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr));
TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
@ -142,6 +143,7 @@ WhileUtil::MakeInstructionsLiveIn(
WhileUtil::MakeInstructionsLiveInResult result;
result.new_while_instr = new_while;
result.replacement_instr = replacement_instr;
result.while_body_live_in_values = std::move(live_in_instructions);
result.while_body_instruction_map = std::move(inlined_instructions_map);

View File

@ -29,6 +29,10 @@ class WhileUtil {
// The new while operation that has the requested values live in.
HloInstruction* new_while_instr;
// The new tuple instruction that replaced the original while instruction
// with the same shape.
HloInstruction* replacement_instr;
// The i'th element of `while_body_live_in_values` is an instruction in the
// while body that holds the i'th *newly added* live in value at runtime.
std::vector<HloInstruction*> while_body_live_in_values;