[XLA][TF2XLA] Support tensor list with dynamic dimension.
Previously we did not allow a dynamic dimension to change inside an HLO while loop, but this constraint breaks tensor lists, where the true dynamic dimension is only known inside the loop body. This CL:
- Adds support in the dynamic padder for changing a dynamic dimension's size inside the loop.
- Adds a test demonstrating how tensor list / stack can be handled more elegantly in XLA.
- Adds the necessary machinery to wire this feature into tf2xla.

PiperOrigin-RevId: 307901191
Change-Id: I4d39f1d8a8c944f1e9834c39599e6cfbc99f6807
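As background for the change below, the client-side building blocks are xla::SetDimensionSize and xla::GetDimensionSize. A minimal sketch (the wrapper function, names, and shapes are illustrative only, not taken from this CL) of tagging a bounded buffer with a runtime size and reading it back:

    // Sketch only: builds a logical s32[<=4, 2] from a static s32[4, 2] buffer.
    xla::XlaOp BuildDynamicBufferSketch(xla::XlaBuilder* b) {
      // A buffer with a static bound of 4 rows and 2 columns.
      xla::XlaOp buffer = xla::Parameter(
          b, 0, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), "buffer");
      // The number of valid rows, supplied at runtime as an s32 scalar.
      xla::XlaOp valid_rows = xla::Parameter(
          b, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "valid_rows");
      // Mark dimension 0 as dynamic; the logical shape becomes s32[<=4, 2].
      xla::XlaOp dynamic_buffer = xla::SetDimensionSize(buffer, valid_rows, 0);
      // Read the runtime size back. Inside a while body this value may now
      // change from one iteration to the next, which is what this CL enables.
      return xla::GetDimensionSize(dynamic_buffer, 0);
    }

Inside a loop body, the grown size is fed back through another set-dimension-size, as the WhileLoopStack test below shows in HLO form.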
parent e224bfeabb
commit f86b74e27e
@@ -460,14 +460,18 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10

// -----

// CHECK: HloModule
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
  %0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
  return %0 : tensor<i32>
func @main(%arg: tensor<4x2xf32>, %size: tensor<i32>) -> tensor<i32> {
  %0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor<i32>) -> tensor<4x2xf32>
  %1 = "xla_hlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
  return %1 : tensor<i32>
}

// CHECK: ENTRY
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
// CHECK: [[SIZE:%.*]] = s32[] parameter(1)
// CHECK: [[DYNAMIC:%.*]] = f32[4,<=2] set-dimension-size(f32[4,2] [[ARG]], s32[] [[SIZE]]), dimensions={1}
// CHECK: ROOT %[[RESULT:.*]] = s32[] get-dimension-size(f32[4,<=2] [[DYNAMIC]]), dimensions={1}


// -----
@@ -274,10 +274,23 @@ class ZerosLikeOp : public XlaOpKernel {

      auto list_shape_or = ctx->builder()->GetShape(list);
      OP_REQUIRES_OK(ctx, list_shape_or.status());
      const xla::Shape& list_shape = list_shape_or.ValueOrDie();
      std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
      list_dynamic_dims.reserve(list_shape.tuple_shapes_size() - 1);
      for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
        // Set dynamic dimension size to 0 for initialization value.
        std::vector<xla::XlaOp> dynamic_dims;
        const xla::Shape& shape = list_shape.tuple_shapes(i);
        auto sub_element = xla::GetTupleElement(list, i);
        for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
          dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
        }
        list_dynamic_dims.push_back(dynamic_dims);
      }
      xla::XlaOp new_list;
      OP_REQUIRES_OK(
          ctx, CreateZerosTensorListWithShape(
                   ctx->builder(), list_shape_or.ValueOrDie(), &new_list));
          ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
                                              list_dynamic_dims, &new_list));

      xla::XlaOp push_index;
      OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &push_index));
@@ -287,10 +300,20 @@ class ZerosLikeOp : public XlaOpKernel {
                          SetTensorListPushIndex(new_list, push_index, &result));
      ctx->SetTensorListOutput(0, result);
    } else {
      const TensorShape input_shape = ctx->InputShape(0);

      auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
      ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
      xla::XlaOp input = ctx->Input(0);
      auto input_shape = ctx->InputXlaShape(0).ValueOrDie();
      auto result = xla::Broadcast(zero, input_shape.dimensions());

      // Setting up dynamic dimensions of the broadcast.
      for (int64 i = 0; i < input_shape.dimensions_size(); ++i) {
        if (input_shape.is_dynamic_dimension(i)) {
          xla::XlaOp input_dynamic_dim = xla::GetDimensionSize(input, i);
          result = xla::SetDimensionSize(result, input_dynamic_dim, i);
        }
      }

      ctx->SetOutput(0, result);
    }
  }
};
@@ -44,6 +44,36 @@ namespace tensorflow {

namespace {

// GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist
// may carry and returns them in a 2D vector: int64[ElementSize][DimSize]. If a
// dimension is static, a constant dimension is returned.
xla::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
    XlaOpKernelContext* ctx, const xla::Shape& element_shape,
    const xla::Shape& list_shape, int64 num_elements) {
  std::vector<int64> dynamic_sizes;
  ctx->set_dynamic_dimension_is_minus_one(true);
  // The multiplier can be a dynamic value.
  TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
  std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
  // Set dynamic dimension size to 0 for initialization value.
  std::vector<xla::XlaOp> dynamic_dims;
  // Leading dim is a static dimension.
  dynamic_dims.push_back(xla::ConstantR0<int32>(ctx->builder(), num_elements));
  for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) {
    if (ctx->is_dynamic_dimension(dynamic_sizes[dim])) {
      auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
      dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
      dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
      dynamic_dims.push_back(dynamic_dim_size);
    } else {
      dynamic_dims.push_back(
          xla::ConstantR0<int32>(ctx->builder(), dynamic_sizes[dim]));
    }
  }
  list_dynamic_dims.push_back(dynamic_dims);
  return list_dynamic_dims;
}

class TensorListLengthOp : public XlaOpKernel {
 public:
  explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
@@ -124,10 +154,14 @@ class TensorListReserveOp : public XlaOpKernel {
    xla::Shape list_shape;
    OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
                            element_shape, num_elements, &list_shape));

    // Set up dynamic dimension sizes to create the zero tensor.
    auto list_dynamic_dims_or = GetTensorListDynamicDims(
        ctx, element_shape, list_shape, num_elements);
    OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
    xla::XlaOp new_list;
    OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
                            ctx->builder(), list_shape, &new_list));
                            ctx->builder(), list_shape,
                            list_dynamic_dims_or.ValueOrDie(), &new_list));
    xla::XlaOp result;
    OP_REQUIRES_OK(
        ctx,
@@ -185,10 +219,16 @@ class EmptyTensorListOp : public XlaOpKernel {
    xla::Shape list_shape;
    OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
                            element_shape, max_num_elements, &list_shape));
    // Set up dynamic dimension sizes to create the zero tensor.
    auto list_dynamic_dims_or = GetTensorListDynamicDims(
        ctx, element_shape, list_shape, max_num_elements);
    OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());

    xla::XlaOp result;
    OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
                            ctx->builder(), list_shape, &result));
                            ctx->builder(), list_shape,
                            list_dynamic_dims_or.ValueOrDie(), &result));

    ctx->SetTensorListOutput(0, result);
    return;
  }
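To make the layout above concrete: the outer vector returned by GetTensorListDynamicDims has one entry per list buffer (the trailing push-index tuple element is excluded), and each inner vector holds one s32 scalar per dimension of that buffer, leading dimension first. A hedged, standalone sketch of what a caller assembles for a hypothetical element shape of [-1, 128] reserved for 16 elements (the builder and the runtime size op are placeholders, not from this CL):

    // Sketch only: per-dimension sizes for one list buffer whose element shape
    // was requested as [-1, 128] with 16 reserved elements.
    std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
    std::vector<xla::XlaOp> dynamic_dims;
    // Leading (num_elements) dimension is static.
    dynamic_dims.push_back(xla::ConstantR0<int32>(builder, 16));
    // Element dimension 0 was requested as -1, so its size comes from a
    // runtime value (e.g. sliced out of the element_shape input).
    dynamic_dims.push_back(runtime_row_count);
    // Element dimension 1 is static.
    dynamic_dims.push_back(xla::ConstantR0<int32>(builder, 128));
    list_dynamic_dims.push_back(dynamic_dims);

This is the structure that CreateZerosTensorListWithShape (below) consumes when it stamps set-dimension-size onto each zero buffer.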
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -247,19 +248,29 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
  return Status::OK();
}

Status CreateZerosTensorListWithShape(xla::XlaBuilder* b,
                                      const xla::Shape& list_shape,
                                      xla::XlaOp* list) {
Status CreateZerosTensorListWithShape(
    xla::XlaBuilder* b, const xla::Shape& list_shape,
    const std::vector<std::vector<xla::XlaOp>>& dynamic_dims,
    xla::XlaOp* list) {
  int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  std::vector<xla::XlaOp> elements;
  for (int i = 0; i < tuple_size; i++) {
  TF_RET_CHECK(dynamic_dims.size() == tuple_size - 1);
  for (int i = 0; i < tuple_size - 1; i++) {
    const xla::Shape& shape =
        xla::ShapeUtil::GetTupleElementShape(list_shape, i);
    xla::XlaOp zero =
        xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type()));
    xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions());
    TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions_size());
    for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
      zeros = xla::SetDimensionSize(zeros, dynamic_dims[i][dim], dim);
    }
    elements.push_back(zeros);
  }
  // List size (last item) has to be S32.
  TF_RET_CHECK(xla::ShapeUtil::GetTupleElementShape(list_shape, tuple_size - 1)
                   .element_type() == xla::S32);
  elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32)));
  *list = xla::Tuple(b, elements);
  return Status::OK();
}
@@ -272,12 +283,12 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,

  xla::XlaBuilder* b = list.builder();
  xla::Shape list_shape;
  TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));

  if (element_is_tensor_list) {
    TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
    TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
        element_shape, leading_dim, &list_shape));
  } else {
    TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
    TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
        element_shape, leading_dim, &list_shape));
  }
@@ -295,7 +306,27 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
    *initialized_list = list;
    return Status::OK();
  } else {
    return CreateZerosTensorListWithShape(b, list_shape, initialized_list);
    // Prepare dynamic dimensions for the zero tensor list. The dynamic
    // sizes are created by reading the dynamic dimension size of sub-elements.
    std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
    for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
      std::vector<xla::XlaOp> dynamic_dims;
      const xla::Shape& shape = list_shape.tuple_shapes(i);
      // Leading dim is a static dimension.
      dynamic_dims.push_back(xla::ConstantR0<int32>(b, leading_dim));
      xla::XlaOp sub_element;
      if (element_is_tensor_list) {
        sub_element = xla::GetTupleElement(element, i);
      } else {
        sub_element = element;
      }
      for (int64 dim = 0; dim < shape.dimensions_size() - 1; ++dim) {
        dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
      }
      list_dynamic_dims.push_back(dynamic_dims);
    }
    return CreateZerosTensorListWithShape(b, list_shape, list_dynamic_dims,
                                          initialized_list);
  }
}

@@ -473,7 +504,13 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index,

  xla::XlaOp list_part = xla::GetTupleElement(list, 0);
  xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);

  for (int64 i = 0; i < buffer_shape.dimensions_size(); ++i) {
    if (buffer_shape.is_dynamic_dimension(i)) {
      auto buffer = xla::GetTupleElement(list, 0);
      auto gds = xla::GetDimensionSize(buffer, i);
      read = xla::SetDimensionSize(read, gds, i);
    }
  }
  slice_shape.erase(slice_shape.begin());
  *result = xla::Reshape(read, slice_shape);
  return Status::OK();
@@ -74,9 +74,9 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
                                          xla::Shape* tensor_list_shape);

// Returns a TensorList filled by zeros with the given shape.
Status CreateZerosTensorListWithShape(xla::XlaBuilder* b,
                                      const xla::Shape& list_shape,
                                      xla::XlaOp* list);
Status CreateZerosTensorListWithShape(
    xla::XlaBuilder* b, const xla::Shape& list_shape,
    const std::vector<std::vector<xla::XlaOp>>& dynamic_dims, xla::XlaOp* list);

// If the TensorList is initialized, check that its shape matches element shape;
// If the TensorList is uninitialized, initialize it with the element shape.
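A hedged usage sketch of the signature declared above (builder, list_shape, and list_dynamic_dims are placeholder names carried over from the earlier sketch, not code from this CL):

    // Sketch only: zero-initialize a tensor list with per-dimension sizes.
    xla::XlaOp zeros_list;
    TF_RETURN_IF_ERROR(CreateZerosTensorListWithShape(
        builder, list_shape, list_dynamic_dims, &zeros_list));
    // zeros_list is a tuple: one zero buffer per element shape, each tagged
    // with set-dimension-size, plus a trailing s32 push index set to zero.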
@@ -510,8 +510,25 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
        // first compilation and the body/cond was recompiled with the updated
        // shape/datatype of the list.
        if (input_shape != list_shape) {
          OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
                                  ctx->builder(), list_shape, &inputs[i]));
          // Prepare dynamic dimensions for element shapes.
          std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
          for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
            // Set dynamic dimension size to 0 for initialization value.
            std::vector<xla::XlaOp> dynamic_dims;
            const xla::Shape& shape = list_shape.tuple_shapes(i);
            for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
              int32 dim_size = shape.dimensions(dim);
              if (shape.is_dynamic_dimension(dim)) {
                dim_size = 0;
              }
              dynamic_dims.push_back(
                  xla::ConstantR0<int32>(ctx->builder(), dim_size));
            }
            list_dynamic_dims.push_back(dynamic_dims);
          }
          OP_REQUIRES_OK(
              ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
                                                  list_dynamic_dims, &inputs[i]));
        } else {
          inputs[i] = ctx->Input(input_num);
        }
@@ -217,6 +217,8 @@ class XlaOpKernelContext {
    return dynamic_dimension_is_minus_one_;
  }

  bool is_dynamic_dimension(int64 dim_size) { return dim_size == -1; }

  // Reads the current value of the resource variable referred to by input
  // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
  // variable. Returns an error if the variable has not been initialized, or if
@@ -2646,6 +2646,11 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
                                         *operand_shape, dimension));
    // Calling GetDimensionSize on a static dimension returns a constant
    // instruction.
    if (!operand_shape->is_dynamic_dimension(dimension)) {
      return ConstantR0<int32>(this, operand_shape->dimensions(dimension));
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(dimension);
    return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
@@ -2657,8 +2662,20 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
                                         *operand_shape, dimension));
    // Setting an op's dynamic dimension to the static size is a noop.
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
                        LookUpInstruction(val));
    if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() ==
        HloOpcode::kConstant) {
      TF_ASSIGN_OR_RETURN(auto literal,
                          Literal::CreateFromProto(val_proto->literal(), true));
      if (literal.Get<int32>({}) == shape.dimensions(dimension)) {
        return operand;
      }
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.add_dimensions(dimension);
    return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
@@ -407,13 +407,25 @@ TEST_F(XlaBuilderTest, CollectivePermute) {

TEST_F(XlaBuilderTest, GetDimensionSize) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  auto x =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x");
  GetDimensionSize(x, 1);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize);
}

TEST_F(XlaBuilderTest, GetDimensionSizeConstant) {
  XlaBuilder b(TestName());
  auto x =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x");
  // Getting the dimension size of a static dimension gives us a constant.
  GetDimensionSize(x, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
}

TEST_F(XlaBuilderTest, ReportError) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
@@ -1369,77 +1369,27 @@ Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
}

Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
  // While loop is handled by passing dynamic size hlos as parameters into the
  // hlo while loop. This is done by replacing the original while with a new
  // one.
  //
  // Before:
  //
  // op1 = ...
  // op2 = ...
  // op1_x = ... // dynamic dimension size of op1
  // while = while(op1, op2)
  //
  //
  // After:
  //
  // op1 = ...
  // op2 = ...
  // op1_x = ... // dynamic dimension size of op1
  // while = while(op1, op2, op1_x)
  //
  // In the above graph, op1_x is the bound of the dynamic dimension size of op1
  // and is wired into the while loop as a new parameter.
  //
  // TODO(b/119843103): Once we implement dynamic bounds in XLA backend, dynamic
  // bound can be propagated through native xla values instead of relying on
  // additional parameter.
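  // When the dynamic size can change inside the body (e.g. a stack that grows
  // by one item per iteration), the updated size is also threaded back out
  // through the body root, roughly (illustrative names, not literal HLO):
  //
  //   body {
  //     ...
  //     new_size = ...                       // updated dynamic size of op1'
  //     ROOT tuple = tuple(op1', op2', new_size)
  //   }
  //
  // The extra tuple element lives at the index recorded in
  // dynamic_output_mapping below, and consumers outside the loop read it back
  // with a get-tuple-element on the new while instruction.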

  // dynamic_size_to_operand_id_index_map maps dynamic size operations
  // to their operand ids in the new while loop.
  absl::flat_hash_map<HloInstruction*, int64>
      dynamic_size_to_operand_id_index_map;

  // operands_to_add collects dynamic sizes that need to be added to the while
  // loop as parameters. Note that a dynamic size is ignored if it is already
  // part of the parameter. i.e.:
  //
  // We don't do:
  //
  // op1 = ...
  // op2 = ...
  // op_x = ... // dynamic dimension size of both op1 and op2
  // while = while(op1, op2, op_x, op_x) // 4 parameters
  //
  // But we do:
  //
  // op1 = ...
  // op2 = ...
  // op_x = ... // dynamic dimension size of both op1 and op2
  // while = while(op1, op2, op_x)
  //
  // An alternative is to do this in a while loop CSE pass.
  //
  // If the output of the conditional contains a dynamic dimension, we send the
  // dynamic dimension size out by adding an additional root element. A mapping
  // from the root instruction's dynamic dimension index (represented by a shape
  // index as output index and an int64 dimension number) to output index
  // (represented by an int64) is tracked for the conditional instruction (all
  // branches should have the same mapping).
  ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
      hlo->shape());
  std::vector<HloInstruction*> operands_to_add;
  int64 operand_count = hlo->shape().tuple_shapes_size();
  const int64 original_tuple_count = hlo->shape().tuple_shapes_size();
  int64 operand_count = original_tuple_count;
  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction*, ShapeIndex, int64, int64,
      hlo, [&](HloInstruction*, ShapeIndex index, int64 dim, int64,
               HloInstruction* dynamic_size, DimensionConstraint constraint) {
        const HloInstruction* tuple_operand = hlo->operand(0);
        for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
          if (dynamic_size == tuple_operand->operand(i)) {
            dynamic_size_to_operand_id_index_map[dynamic_size] = i;
            return Status::OK();
          }
        }
        auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
        if (iter == dynamic_size_to_operand_id_index_map.end()) {
          operands_to_add.push_back(dynamic_size);
          dynamic_size_to_operand_id_index_map[dynamic_size] = operand_count++;
        }
        operands_to_add.push_back(dynamic_size);
        dynamic_output_mapping.mutable_element(index)->emplace(dim,
                                                               operand_count++);
        return Status::OK();
      }));

  DynamicParameterBinding binding_for_while;
  if (!operands_to_add.empty()) {
    // Only replace the while loop if there are new parameters to add.
    HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
@@ -1453,37 +1403,78 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
    parent_->CopyMapping(/*from=*/old_tuple_operand,
                         /*to=*/new_tuple_operand);
    hlo = result.new_while_instr;
    // We have replaced the while loop, now set the dynamic dimensions for the
    // newly created while loop so that the hlos that consume the while loop
    // can see the dynamic dimensions. Also sets the dynamic parameter binding
    // for running inference in the while loop.
    TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
        hlo,
        [&](HloInstruction*, ShapeIndex index, int64 dimension,
            int64 operand_index, HloInstruction* dynamic_size,
            DimensionConstraint constraint) -> Status {
          TF_RET_CHECK(!operands_to_add.empty());
          const int64 output_dynamic_size_index =
              dynamic_output_mapping.element(index).at(dimension);
          DynamicParameterBinding::DynamicParameter dynamic_parameter{
              operand_index, {output_dynamic_size_index}};
          DynamicParameterBinding::DynamicDimension dynamic_dimension{
              operand_index, index, dimension};
          TF_RETURN_IF_ERROR(
              binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
          // This is the updated output dynamic size coming out of hlo while
          // loop.
          HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
              HloInstruction::CreateGetTupleElement(
                  ShapeUtil::MakeScalarShape(S32), hlo,
                  output_dynamic_size_index));
          parent_->SetDynamicSize(result.replacement_instr, index, dimension,
                                  output_dynamic_size, constraint);
          return Status::OK();
        }));
    // Set the replacement instruction as visited to avoid visiting it again.
    SetVisited(*result.replacement_instr);
  }

  // We have replaced the while loop, now set the dynamic dimensions for the
  // newly created while loop so that the hlos that consume the while loop can
  // see the dynamic dimensions. Also sets the dynamic parameter binding for
  // running inference in the while loop.
  DynamicParameterBinding binding_for_while;
  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size,
               DimensionConstraint constraint) {
        DynamicParameterBinding::DynamicParameter dynamic_parameter{
            operand_index,
            {dynamic_size_to_operand_id_index_map[dynamic_size]}};
        DynamicParameterBinding::DynamicDimension dynamic_dimension{
            operand_index, index, dimension};
        TF_RETURN_IF_ERROR(
            binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
        parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
                                constraint);
        return Status::OK();
      }));

  // Run inference in while body and condition.
  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
      hlo->while_body(), binding_for_while, parent_));
  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
      hlo->while_condition(), binding_for_while, parent_));

  // Set the replacement while loop as visited to avoid visiting it again.
  SetVisited(*hlo);
  if (operands_to_add.empty()) {
    // No dynamic dimension in the inputs and outputs.
    return Status::OK();
  }

  // The dynamic dimension size could have been changed in the loop body (e.g.,
  // a loop that inserts items into a stack: the stack size increases with each
  // iteration). Rewrite the dynamic dimension size at the root.
  HloInstruction* body_root = hlo->while_body()->root_instruction();
  std::vector<HloInstruction*> new_root_operands(body_root->operand_count(),
                                                 nullptr);

  // Original non-dynamic-dim operands of root are pass-through.
  for (int64 i = 0; i < original_tuple_count; ++i) {
    new_root_operands[i] =
        hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
            body_root->shape().tuple_shapes(i), body_root, i));
  }
  // Add dynamic dimension size as new parameters.
  TF_RETURN_IF_ERROR(ForEachDynamicDimension(
      hlo->while_body()->root_instruction(),
      [&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
          DimensionConstraint) -> Status {
        const int64 output_index =
            dynamic_output_mapping.element(index).at(dim);
        new_root_operands[output_index] = dynamic_size;
        return Status::OK();
      }));
  for (auto operand : new_root_operands) {
    TF_RET_CHECK(operand != nullptr);
  }
  HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
      HloInstruction::CreateTuple(new_root_operands));
  hlo->while_body()->set_root_instruction(new_body_root);
  return Status::OK();
}
@@ -767,7 +767,7 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) {
  // While
  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, tuple_shape, "A"));
  auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));
@@ -782,37 +782,32 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) {
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {1}, 0}));

  // Test that dynamic dimension inference does the right thing. A lambda is
  // used here since we want to test twice by running inference again
  // (idempotency).
  auto test_dynamic_dimension = [&]() {
    HloInstruction* while_hlo = nullptr;
    // The while hlo has been replaced, find the new one.
    for (HloInstruction* inst : module_->entry_computation()->instructions()) {
      if (inst->opcode() == HloOpcode::kWhile) {
        while_hlo = inst;
      }
    }
    ASSERT_NE(while_hlo, nullptr);
    // The original while shape has 2 parameters. With dynamic size passed in
    // as an extra parameter, the tuple should have 3 elements.
    EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 3);
    HloInstruction* add = nullptr;
    for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
      if (inst->opcode() == HloOpcode::kAdd) {
        add = inst;
      }
    }
    EXPECT_NE(add, nullptr);
    EXPECT_NE(inference_->GetDynamicSize(add, {}, 0), nullptr);
    EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {0}, 0), size_param);
    EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {1}, 0), size_param);
  };

  TF_ASSERT_OK(RunInference());
  test_dynamic_dimension();
  TF_ASSERT_OK(RunInference());
  test_dynamic_dimension();
  HloInstruction* while_hlo = nullptr;
  // The while hlo has been replaced, find the new one.
  for (HloInstruction* inst : module_->entry_computation()->instructions()) {
    if (inst->opcode() == HloOpcode::kWhile) {
      while_hlo = inst;
    }
  }
  ASSERT_NE(while_hlo, nullptr);
  // The original while shape has 2 parameters. With dynamic size, the tuple
  // should have 4 elements (We don't deduplicate the arguments).
  EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 4);
  HloInstruction* add_inst = nullptr;
  for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
    if (inst->opcode() == HloOpcode::kAdd) {
      add_inst = inst;
    }
  }
  EXPECT_NE(add_inst, nullptr);
  EXPECT_NE(inference_->GetDynamicSize(add_inst, {}, 0), nullptr);
  EXPECT_NE(inference_->GetDynamicSize(
                module_->entry_computation()->root_instruction(), {0}, 0),
            nullptr);
  EXPECT_NE(inference_->GetDynamicSize(
                module_->entry_computation()->root_instruction(), {1}, 0),
            nullptr);
}

TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
@@ -903,6 +903,90 @@ ENTRY main {
  EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, WhileLoopStack) {
  // Push the iteration number into a dynamically sized stack:
  // init:
  // [[P, P],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // First iteration i = 0:
  // [[0, 0],
  //  [P, P],
  //  [P, P],
  //  [P, P]]
  // Second iteration i = 1:
  // [[0, 0],
  //  [1, 1],
  //  [P, P],
  //  [P, P]]
  // Third iteration i = 2:
  // [[0, 0],
  //  [1, 1],
  //  [2, 2],
  //  [P, P]]

  const string hlo_text = R"(
HloModule module

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

body {
  stack = (s32[<=4,2]) parameter(0)
  stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
  stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
  zero = s32[] constant(0)
  one = s32[] constant(1)
  // content of the stack is the stack index broadcasted.
  new_data = s32[1, 2] broadcast(s32[] stack_size), dimensions={}
  new_stack_buffer = s32[<=4, 2] dynamic-update-slice(stack_buffer, new_data, stack_size, zero)
  new_stack_size = s32[] add(stack_size, one)
  new_stack_buffer_dynamic = s32[<=4, 2] set-dimension-size(new_stack_buffer, new_stack_size), dimensions={0}
  ROOT new_stack = (s32[<=4,2]) tuple(new_stack_buffer_dynamic)
}

condition {
  stack = (s32[<=4,2]) parameter(0)
  stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
  stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
  three = s32[] constant(3)
  ROOT less-than = pred[] compare(s32[] stack_size, s32[] three), direction=LT
}

ENTRY entry {
  zero = s32[] constant(0)
  pad = s32[] constant(-1)
  stack_buffer_input = s32[4, 2] broadcast(s32[] pad), dimensions={}
  stack_buffer_input_dynamic = s32[<=4, 2] set-dimension-size(stack_buffer_input, zero), dimensions={0}
  input_tuple = (s32[<=4, 2]) tuple(stack_buffer_input_dynamic)
  while = (s32[<=4, 2]) while(input_tuple), body=body, condition=condition
  stack_buffer = s32[<=4, 2] get-tuple-element(while), index=0
  ROOT reduce = s32[2] reduce(stack_buffer, zero),
    dimensions={0},
    to_apply=update_s32
}
)";

  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {});

  // Stack has three valid items in it:
  // [[0, 0],
  //  [1, 1],
  //  [2, 2],
  //  [P, P]]
  //
  // Reducing along the major dimension gives us [0 + 1 + 2, 0 + 1 + 2] = [3, 3].
  Literal expected = LiteralUtil::CreateR1<int32>({{3, 3}});

  EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, DoubleDynamicDimension) {
  const string hlo_text = R"(
HloModule TensorFlowScatterV1
@@ -2596,7 +2596,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
    VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
  }

  return operand_shape;
  auto result_shape = operand_shape;

  // If any dimension of the operand shape or the update shape is dynamic, mark
  // the corresponding result dimension as dynamic.
  for (int64 i = 0; i < update_shape.rank(); ++i) {
    if (update_shape.is_dynamic_dimension(i) ||
        operand_shape.is_dynamic_dimension(i)) {
      result_shape.set_dynamic_dimension(i, true);
    }
  }

  return result_shape;
}

/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
@@ -125,8 +125,9 @@ WhileUtil::MakeInstructionsLiveIn(
  // We want to get rid of the old while instruction even if it has side
  // effecting operations so we do a manual HloComputation::RemoveInstruction
  // instead of relying on HloComputation::ReplaceInstruction.
  TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix(
      new_while, while_instr->shape().tuple_shapes_size())));
  HloInstruction* replacement_instr = TupleUtil::ExtractPrefix(
      new_while, while_instr->shape().tuple_shapes_size());
  TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr));
  TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));

  HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
@@ -142,6 +143,7 @@ WhileUtil::MakeInstructionsLiveIn(
  WhileUtil::MakeInstructionsLiveInResult result;

  result.new_while_instr = new_while;
  result.replacement_instr = replacement_instr;
  result.while_body_live_in_values = std::move(live_in_instructions);
  result.while_body_instruction_map = std::move(inlined_instructions_map);
@@ -29,6 +29,10 @@ class WhileUtil {
    // The new while operation that has the requested values live in.
    HloInstruction* new_while_instr;

    // The new tuple instruction that replaced the original while instruction
    // with the same shape.
    HloInstruction* replacement_instr;

    // The i'th element of `while_body_live_in_values` is an instruction in the
    // while body that holds the i'th *newly added* live in value at runtime.
    std::vector<HloInstruction*> while_body_live_in_values;