[XLA][TF2XLA] Support tensor list with dynamic dimension.
Previously we didn't allow a dynamic dimension to change in an HLO while loop, but this constraint breaks tensor lists, where the true dynamic dimension size is only known inside the loop body. This CL: - Adds the feature to the dynamic padder to be able to change a dynamic dimension's size in the loop. - Adds a test to demonstrate how tensor list / stack can be handled more elegantly in XLA. - Adds the necessary machinery to wire this feature into tf2xla. PiperOrigin-RevId: 307901191 Change-Id: I4d39f1d8a8c944f1e9834c39599e6cfbc99f6807
This commit is contained in:
parent
e224bfeabb
commit
f86b74e27e
@ -460,14 +460,18 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK: HloModule
|
// CHECK: HloModule
|
||||||
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
|
func @main(%arg: tensor<4x2xf32>, %size: tensor<i32>) -> tensor<i32> {
|
||||||
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
|
%0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor<i32>) -> tensor<4x2xf32>
|
||||||
return %0 : tensor<i32>
|
%1 = "xla_hlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
|
||||||
|
return %1 : tensor<i32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: ENTRY
|
// CHECK: ENTRY
|
||||||
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
|
// CHECK: [[ARG:%.*]] = f32[4,2] parameter(0)
|
||||||
// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1}
|
// CHECK: [[SIZE:%.*]] = s32[] parameter(1)
|
||||||
|
// CHECK: [[DYNAMIC:%.*]] = f32[4,<=2] set-dimension-size(f32[4,2] [[ARG]], s32[] [[SIZE]]), dimensions={1}
|
||||||
|
// CHECK: ROOT %[[RESULT:.*]] = s32[] get-dimension-size(f32[4,<=2] [[DYNAMIC]]), dimensions={1}
|
||||||
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
@ -274,10 +274,23 @@ class ZerosLikeOp : public XlaOpKernel {
|
|||||||
|
|
||||||
auto list_shape_or = ctx->builder()->GetShape(list);
|
auto list_shape_or = ctx->builder()->GetShape(list);
|
||||||
OP_REQUIRES_OK(ctx, list_shape_or.status());
|
OP_REQUIRES_OK(ctx, list_shape_or.status());
|
||||||
|
const xla::Shape& list_shape = list_shape_or.ValueOrDie();
|
||||||
|
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
|
||||||
|
list_dynamic_dims.reserve(list_shape.tuple_shapes_size() - 1);
|
||||||
|
for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
|
||||||
|
// Carry over every dimension size from the input list's sub-element.
|
||||||
|
std::vector<xla::XlaOp> dynamic_dims;
|
||||||
|
const xla::Shape& shape = list_shape.tuple_shapes(i);
|
||||||
|
auto sub_element = xla::GetTupleElement(list, i);
|
||||||
|
for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
|
||||||
|
dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
|
||||||
|
}
|
||||||
|
list_dynamic_dims.push_back(dynamic_dims);
|
||||||
|
}
|
||||||
xla::XlaOp new_list;
|
xla::XlaOp new_list;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, CreateZerosTensorListWithShape(
|
ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
|
||||||
ctx->builder(), list_shape_or.ValueOrDie(), &new_list));
|
list_dynamic_dims, &new_list));
|
||||||
|
|
||||||
xla::XlaOp push_index;
|
xla::XlaOp push_index;
|
||||||
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &push_index));
|
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &push_index));
|
||||||
@ -287,10 +300,20 @@ class ZerosLikeOp : public XlaOpKernel {
|
|||||||
SetTensorListPushIndex(new_list, push_index, &result));
|
SetTensorListPushIndex(new_list, push_index, &result));
|
||||||
ctx->SetTensorListOutput(0, result);
|
ctx->SetTensorListOutput(0, result);
|
||||||
} else {
|
} else {
|
||||||
const TensorShape input_shape = ctx->InputShape(0);
|
|
||||||
|
|
||||||
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
|
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
|
||||||
ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
|
xla::XlaOp input = ctx->Input(0);
|
||||||
|
auto input_shape = ctx->InputXlaShape(0).ValueOrDie();
|
||||||
|
auto result = xla::Broadcast(zero, input_shape.dimensions());
|
||||||
|
|
||||||
|
// Setting up dynamic dimensions of the broadcast.
|
||||||
|
for (int64 i = 0; i < input_shape.dimensions_size(); ++i) {
|
||||||
|
if (input_shape.is_dynamic_dimension(i)) {
|
||||||
|
xla::XlaOp input_dynamic_dim = xla::GetDimensionSize(input, i);
|
||||||
|
result = xla::SetDimensionSize(result, input_dynamic_dim, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->SetOutput(0, result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -44,6 +44,36 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist
|
||||||
|
// may carry and returns them in a 2D vector: int64[ElementSize][DimSize]. If a
|
||||||
|
// dimension is static, a constant dimension is returned.
|
||||||
|
xla::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
|
||||||
|
XlaOpKernelContext* ctx, const xla::Shape& element_shape,
|
||||||
|
const xla::Shape& list_shape, int64 num_elements) {
|
||||||
|
std::vector<int64> dynamic_sizes;
|
||||||
|
ctx->set_dynamic_dimension_is_minus_one(true);
|
||||||
|
// Read input 0 as a constant int vector; dynamic dimensions come back as -1.
|
||||||
|
TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
|
||||||
|
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
|
||||||
|
// Per-dimension sizes of the list buffer: constants or dynamic size ops.
|
||||||
|
std::vector<xla::XlaOp> dynamic_dims;
|
||||||
|
// Leading dim is a static dimension.
|
||||||
|
dynamic_dims.push_back(xla::ConstantR0<int32>(ctx->builder(), num_elements));
|
||||||
|
for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) {
|
||||||
|
if (ctx->is_dynamic_dimension(dynamic_sizes[dim])) {
|
||||||
|
auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
|
||||||
|
dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
|
||||||
|
dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
|
||||||
|
dynamic_dims.push_back(dynamic_dim_size);
|
||||||
|
} else {
|
||||||
|
dynamic_dims.push_back(
|
||||||
|
xla::ConstantR0<int32>(ctx->builder(), dynamic_sizes[dim]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
list_dynamic_dims.push_back(dynamic_dims);
|
||||||
|
return list_dynamic_dims;
|
||||||
|
}
|
||||||
|
|
||||||
class TensorListLengthOp : public XlaOpKernel {
|
class TensorListLengthOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||||
@ -124,10 +154,14 @@ class TensorListReserveOp : public XlaOpKernel {
|
|||||||
xla::Shape list_shape;
|
xla::Shape list_shape;
|
||||||
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
|
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
|
||||||
element_shape, num_elements, &list_shape));
|
element_shape, num_elements, &list_shape));
|
||||||
|
// Set up dynamic dimension sizes to create the zero tensor.
|
||||||
|
auto list_dynamic_dims_or = GetTensorListDynamicDims(
|
||||||
|
ctx, element_shape, list_shape, num_elements);
|
||||||
|
OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
|
||||||
xla::XlaOp new_list;
|
xla::XlaOp new_list;
|
||||||
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
|
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
|
||||||
ctx->builder(), list_shape, &new_list));
|
ctx->builder(), list_shape,
|
||||||
|
list_dynamic_dims_or.ValueOrDie(), &new_list));
|
||||||
xla::XlaOp result;
|
xla::XlaOp result;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx,
|
ctx,
|
||||||
@ -185,10 +219,16 @@ class EmptyTensorListOp : public XlaOpKernel {
|
|||||||
xla::Shape list_shape;
|
xla::Shape list_shape;
|
||||||
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
|
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
|
||||||
element_shape, max_num_elements, &list_shape));
|
element_shape, max_num_elements, &list_shape));
|
||||||
|
// Set up dynamic dimension sizes to create the zero tensor.
|
||||||
|
auto list_dynamic_dims_or = GetTensorListDynamicDims(
|
||||||
|
ctx, element_shape, list_shape, max_num_elements);
|
||||||
|
OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
|
||||||
|
|
||||||
xla::XlaOp result;
|
xla::XlaOp result;
|
||||||
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
|
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
|
||||||
ctx->builder(), list_shape, &result));
|
ctx->builder(), list_shape,
|
||||||
|
list_dynamic_dims_or.ValueOrDie(), &result));
|
||||||
|
|
||||||
ctx->SetTensorListOutput(0, result);
|
ctx->SetTensorListOutput(0, result);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/shape.h"
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
@ -247,19 +248,29 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CreateZerosTensorListWithShape(xla::XlaBuilder* b,
|
Status CreateZerosTensorListWithShape(
|
||||||
const xla::Shape& list_shape,
|
xla::XlaBuilder* b, const xla::Shape& list_shape,
|
||||||
xla::XlaOp* list) {
|
const std::vector<std::vector<xla::XlaOp>>& dynamic_dims,
|
||||||
|
xla::XlaOp* list) {
|
||||||
int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
|
int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
|
||||||
std::vector<xla::XlaOp> elements;
|
std::vector<xla::XlaOp> elements;
|
||||||
for (int i = 0; i < tuple_size; i++) {
|
TF_RET_CHECK(dynamic_dims.size() == tuple_size - 1);
|
||||||
|
for (int i = 0; i < tuple_size - 1; i++) {
|
||||||
const xla::Shape& shape =
|
const xla::Shape& shape =
|
||||||
xla::ShapeUtil::GetTupleElementShape(list_shape, i);
|
xla::ShapeUtil::GetTupleElementShape(list_shape, i);
|
||||||
xla::XlaOp zero =
|
xla::XlaOp zero =
|
||||||
xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type()));
|
xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type()));
|
||||||
xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions());
|
xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions());
|
||||||
|
TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions_size());
|
||||||
|
for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
|
||||||
|
zeros = xla::SetDimensionSize(zeros, dynamic_dims[i][dim], dim);
|
||||||
|
}
|
||||||
elements.push_back(zeros);
|
elements.push_back(zeros);
|
||||||
}
|
}
|
||||||
|
// List size (last item) has to be S32.
|
||||||
|
TF_RET_CHECK(xla::ShapeUtil::GetTupleElementShape(list_shape, tuple_size - 1)
|
||||||
|
.element_type() == xla::S32);
|
||||||
|
elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32)));
|
||||||
*list = xla::Tuple(b, elements);
|
*list = xla::Tuple(b, elements);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -272,12 +283,12 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
|
|||||||
|
|
||||||
xla::XlaBuilder* b = list.builder();
|
xla::XlaBuilder* b = list.builder();
|
||||||
xla::Shape list_shape;
|
xla::Shape list_shape;
|
||||||
|
TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
|
||||||
|
|
||||||
if (element_is_tensor_list) {
|
if (element_is_tensor_list) {
|
||||||
TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
|
|
||||||
TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
|
TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
|
||||||
element_shape, leading_dim, &list_shape));
|
element_shape, leading_dim, &list_shape));
|
||||||
} else {
|
} else {
|
||||||
TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
|
|
||||||
TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
|
TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
|
||||||
element_shape, leading_dim, &list_shape));
|
element_shape, leading_dim, &list_shape));
|
||||||
}
|
}
|
||||||
@ -295,7 +306,27 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
|
|||||||
*initialized_list = list;
|
*initialized_list = list;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else {
|
} else {
|
||||||
return CreateZerosTensorListWithShape(b, list_shape, initialized_list);
|
// Prepare dynamic dimension sizes for the zero tensor list. The dynamic
|
||||||
|
// sizes are created by reading the dynamic dimension size of sub-elements.
|
||||||
|
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
|
||||||
|
for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
|
||||||
|
std::vector<xla::XlaOp> dynamic_dims;
|
||||||
|
const xla::Shape& shape = list_shape.tuple_shapes(i);
|
||||||
|
// Leading dim is a static dimension.
|
||||||
|
dynamic_dims.push_back(xla::ConstantR0<int32>(b, leading_dim));
|
||||||
|
xla::XlaOp sub_element;
|
||||||
|
if (element_is_tensor_list) {
|
||||||
|
sub_element = xla::GetTupleElement(element, i);
|
||||||
|
} else {
|
||||||
|
sub_element = element;
|
||||||
|
}
|
||||||
|
for (int64 dim = 0; dim < shape.dimensions_size() - 1; ++dim) {
|
||||||
|
dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
|
||||||
|
}
|
||||||
|
list_dynamic_dims.push_back(dynamic_dims);
|
||||||
|
}
|
||||||
|
return CreateZerosTensorListWithShape(b, list_shape, list_dynamic_dims,
|
||||||
|
initialized_list);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -473,7 +504,13 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index,
|
|||||||
|
|
||||||
xla::XlaOp list_part = xla::GetTupleElement(list, 0);
|
xla::XlaOp list_part = xla::GetTupleElement(list, 0);
|
||||||
xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);
|
xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);
|
||||||
|
for (int64 i = 0; i < buffer_shape.dimensions_size(); ++i) {
|
||||||
|
if (buffer_shape.is_dynamic_dimension(i)) {
|
||||||
|
auto buffer = xla::GetTupleElement(list, 0);
|
||||||
|
auto gds = xla::GetDimensionSize(buffer, i);
|
||||||
|
read = xla::SetDimensionSize(read, gds, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
slice_shape.erase(slice_shape.begin());
|
slice_shape.erase(slice_shape.begin());
|
||||||
*result = xla::Reshape(read, slice_shape);
|
*result = xla::Reshape(read, slice_shape);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -74,9 +74,9 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
|
|||||||
xla::Shape* tensor_list_shape);
|
xla::Shape* tensor_list_shape);
|
||||||
|
|
||||||
// Returns a TensorList filled by zeros with the given shape.
|
// Returns a TensorList filled by zeros with the given shape.
|
||||||
Status CreateZerosTensorListWithShape(xla::XlaBuilder* b,
|
Status CreateZerosTensorListWithShape(
|
||||||
const xla::Shape& list_shape,
|
xla::XlaBuilder* b, const xla::Shape& list_shape,
|
||||||
xla::XlaOp* list);
|
const std::vector<std::vector<xla::XlaOp>>& dynamic_dims, xla::XlaOp* list);
|
||||||
|
|
||||||
// If the TensorList is initialized, check that its shape matches element shape;
|
// If the TensorList is initialized, check that its shape matches element shape;
|
||||||
// If the TensorList is uninitialized, initialize it with the element shape.
|
// If the TensorList is uninitialized, initialize it with the element shape.
|
||||||
|
@ -510,8 +510,25 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
|
|||||||
// first compilation and the body/cond was recompiled with the updated
|
// first compilation and the body/cond was recompiled with the updated
|
||||||
// shape/datatype of the list.
|
// shape/datatype of the list.
|
||||||
if (input_shape != list_shape) {
|
if (input_shape != list_shape) {
|
||||||
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
|
// Prepare dynamic dimensions for element shapes.
|
||||||
ctx->builder(), list_shape, &inputs[i]));
|
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
|
||||||
|
for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
|
||||||
|
// Set dynamic dimension size to 0 for initialization value.
|
||||||
|
std::vector<xla::XlaOp> dynamic_dims;
|
||||||
|
const xla::Shape& shape = list_shape.tuple_shapes(i);
|
||||||
|
for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
|
||||||
|
int32 dim_size = shape.dimensions(dim);
|
||||||
|
if (shape.is_dynamic_dimension(dim)) {
|
||||||
|
dim_size = 0;
|
||||||
|
}
|
||||||
|
dynamic_dims.push_back(
|
||||||
|
xla::ConstantR0<int32>(ctx->builder(), dim_size));
|
||||||
|
}
|
||||||
|
list_dynamic_dims.push_back(dynamic_dims);
|
||||||
|
}
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
|
||||||
|
list_dynamic_dims, &inputs[i]));
|
||||||
} else {
|
} else {
|
||||||
inputs[i] = ctx->Input(input_num);
|
inputs[i] = ctx->Input(input_num);
|
||||||
}
|
}
|
||||||
|
@ -217,6 +217,8 @@ class XlaOpKernelContext {
|
|||||||
return dynamic_dimension_is_minus_one_;
|
return dynamic_dimension_is_minus_one_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_dynamic_dimension(int64 dim_size) { return dim_size == -1; }
|
||||||
|
|
||||||
// Reads the current value of the resource variable referred to by input
|
// Reads the current value of the resource variable referred to by input
|
||||||
// `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
|
// `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
|
||||||
// variable. Returns an error if the variable has not been initialized, or if
|
// variable. Returns an error if the variable has not been initialized, or if
|
||||||
|
@ -2646,6 +2646,11 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
|
|||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
|
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
|
||||||
*operand_shape, dimension));
|
*operand_shape, dimension));
|
||||||
|
// Calling GetDimensionSize on a static dimension returns a constant
|
||||||
|
// instruction.
|
||||||
|
if (!operand_shape->is_dynamic_dimension(dimension)) {
|
||||||
|
return ConstantR0<int32>(this, operand_shape->dimensions(dimension));
|
||||||
|
}
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
instr.add_dimensions(dimension);
|
instr.add_dimensions(dimension);
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
|
return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
|
||||||
@ -2657,8 +2662,20 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
|
|||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
|
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape(
|
||||||
*operand_shape, dimension));
|
*operand_shape, dimension));
|
||||||
|
// Setting an op's dynamic dimension to the static size is a noop.
|
||||||
|
TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto,
|
||||||
|
LookUpInstruction(val));
|
||||||
|
if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() ==
|
||||||
|
HloOpcode::kConstant) {
|
||||||
|
TF_ASSIGN_OR_RETURN(auto literal,
|
||||||
|
Literal::CreateFromProto(val_proto->literal(), true));
|
||||||
|
if (literal.Get<int32>({}) == shape.dimensions(dimension)) {
|
||||||
|
return operand;
|
||||||
|
}
|
||||||
|
}
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
instr.add_dimensions(dimension);
|
instr.add_dimensions(dimension);
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
|
return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
|
||||||
|
@ -407,13 +407,25 @@ TEST_F(XlaBuilderTest, CollectivePermute) {
|
|||||||
|
|
||||||
TEST_F(XlaBuilderTest, GetDimensionSize) {
|
TEST_F(XlaBuilderTest, GetDimensionSize) {
|
||||||
XlaBuilder b(TestName());
|
XlaBuilder b(TestName());
|
||||||
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
|
auto x =
|
||||||
|
Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x");
|
||||||
GetDimensionSize(x, 1);
|
GetDimensionSize(x, 1);
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||||
auto root = module->entry_computation()->root_instruction();
|
auto root = module->entry_computation()->root_instruction();
|
||||||
EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize);
|
EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(XlaBuilderTest, GetDimensionSizeConstant) {
|
||||||
|
XlaBuilder b(TestName());
|
||||||
|
auto x =
|
||||||
|
Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x");
|
||||||
|
// Getting the dimension size of a static dimension gives us a constant.
|
||||||
|
GetDimensionSize(x, 0);
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||||
|
auto root = module->entry_computation()->root_instruction();
|
||||||
|
EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(XlaBuilderTest, ReportError) {
|
TEST_F(XlaBuilderTest, ReportError) {
|
||||||
XlaBuilder b(TestName());
|
XlaBuilder b(TestName());
|
||||||
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
|
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
|
||||||
|
@ -1369,77 +1369,27 @@ Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
|
Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
|
||||||
// While loop is handled by passing dynamic size hlos as parameters into the
|
// If the while loop carries dynamic dimensions, we pass each dynamic
|
||||||
// hlo while loop. This is done by replacing the original while with a new
|
// dimension size through the loop as an additional tuple element. A mapping
|
||||||
// one.
|
// from the loop's dynamic dimension (represented by a shape index as output
|
||||||
//
|
// index and an int64 dimension number) to the tuple index of its size
|
||||||
// Before:
|
// (represented by an int64) is tracked for the while instruction.
|
||||||
//
|
// The same mapping is used to rewrite the body's root tuple.
|
||||||
// op1 = ...
|
ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
|
||||||
// op2 = ...
|
hlo->shape());
|
||||||
// op1_x = ... // dynamic dimension size of op1
|
|
||||||
// while = while(op1, op2)
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// After:
|
|
||||||
//
|
|
||||||
// op1 = ...
|
|
||||||
// op2 = ...
|
|
||||||
// op1_x = ... // dynamic dimension size of op1
|
|
||||||
// while = while(op1, op2, op1_x)
|
|
||||||
//
|
|
||||||
// In the above graph, op_x is the bound of the dynamic dimension size of op1
|
|
||||||
// and is wired into the while loop as new parameter.
|
|
||||||
//
|
|
||||||
// TODO(b/119843103): Once we implement dynamic bounds in XLA backend, dynamic
|
|
||||||
// bound can be propagated through native xla values instead of relying on
|
|
||||||
// additional parameter.
|
|
||||||
|
|
||||||
// dynamic_size_to_operand_id_index_map keeps track of dynamic size operations
|
|
||||||
// to their operand ids in the new while loop.
|
|
||||||
absl::flat_hash_map<HloInstruction*, int64>
|
|
||||||
dynamic_size_to_operand_id_index_map;
|
|
||||||
|
|
||||||
// operands_to_add collects dynamic sizes that need to be added to the while
|
|
||||||
// loop as parameters. Note that a dynamic size is ignored if it is already
|
|
||||||
// part of the parameter. i.e.:
|
|
||||||
//
|
|
||||||
// We don't do:
|
|
||||||
//
|
|
||||||
// op1 = ...
|
|
||||||
// op2 = ...
|
|
||||||
// op_x = ... // dynamic dimension size of both op1 and op2
|
|
||||||
// while = while(op1, op2, op_x, op_x) // 4 parameters
|
|
||||||
//
|
|
||||||
// But we do:
|
|
||||||
//
|
|
||||||
// op1 = ...
|
|
||||||
// op2 = ...
|
|
||||||
// op_x = ... // dynamic dimension size of both op1 and op2
|
|
||||||
// while = while(op1, op2, op_x)
|
|
||||||
//
|
|
||||||
// An alternative is to do this in a while loop CSE pass.
|
|
||||||
//
|
|
||||||
std::vector<HloInstruction*> operands_to_add;
|
std::vector<HloInstruction*> operands_to_add;
|
||||||
int64 operand_count = hlo->shape().tuple_shapes_size();
|
const int64 original_tuple_count = hlo->shape().tuple_shapes_size();
|
||||||
|
int64 operand_count = original_tuple_count;
|
||||||
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
|
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
|
||||||
hlo, [&](HloInstruction*, ShapeIndex, int64, int64,
|
hlo, [&](HloInstruction*, ShapeIndex index, int64 dim, int64,
|
||||||
HloInstruction* dynamic_size, DimensionConstraint constraint) {
|
HloInstruction* dynamic_size, DimensionConstraint constraint) {
|
||||||
const HloInstruction* tuple_operand = hlo->operand(0);
|
operands_to_add.push_back(dynamic_size);
|
||||||
for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
|
dynamic_output_mapping.mutable_element(index)->emplace(dim,
|
||||||
if (dynamic_size == tuple_operand->operand(i)) {
|
operand_count++);
|
||||||
dynamic_size_to_operand_id_index_map[dynamic_size] = i;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
|
|
||||||
if (iter == dynamic_size_to_operand_id_index_map.end()) {
|
|
||||||
operands_to_add.push_back(dynamic_size);
|
|
||||||
dynamic_size_to_operand_id_index_map[dynamic_size] = operand_count++;
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
DynamicParameterBinding binding_for_while;
|
||||||
if (!operands_to_add.empty()) {
|
if (!operands_to_add.empty()) {
|
||||||
// Only replace the while loop if there are new parameters to add.
|
// Only replace the while loop if there are new parameters to add.
|
||||||
HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
|
HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
|
||||||
@ -1453,37 +1403,78 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
|
|||||||
parent_->CopyMapping(/*from=*/old_tuple_operand,
|
parent_->CopyMapping(/*from=*/old_tuple_operand,
|
||||||
/*to=*/new_tuple_operand);
|
/*to=*/new_tuple_operand);
|
||||||
hlo = result.new_while_instr;
|
hlo = result.new_while_instr;
|
||||||
|
// We have replaced the while loop, now set the dynamic dimensions for the
|
||||||
|
// newly created while loop so that the hlos that consumes the while loop
|
||||||
|
// can see the dynamic dimensions. Also sets the dynamic parameter binding
|
||||||
|
// for running inference in the while loop.
|
||||||
|
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
|
||||||
|
hlo,
|
||||||
|
[&](HloInstruction*, ShapeIndex index, int64 dimension,
|
||||||
|
int64 operand_index, HloInstruction* dynamic_size,
|
||||||
|
DimensionConstraint constraint) -> Status {
|
||||||
|
TF_RET_CHECK(!operands_to_add.empty());
|
||||||
|
const int64 output_dynamic_size_index =
|
||||||
|
dynamic_output_mapping.element(index).at(dimension);
|
||||||
|
DynamicParameterBinding::DynamicParameter dynamic_parameter{
|
||||||
|
operand_index, {output_dynamic_size_index}};
|
||||||
|
DynamicParameterBinding::DynamicDimension dynamic_dimension{
|
||||||
|
operand_index, index, dimension};
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
|
||||||
|
// This is the updated output dynamic size coming out of hlo while
|
||||||
|
// loop.
|
||||||
|
HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(
|
||||||
|
ShapeUtil::MakeScalarShape(S32), hlo,
|
||||||
|
output_dynamic_size_index));
|
||||||
|
parent_->SetDynamicSize(result.replacement_instr, index, dimension,
|
||||||
|
output_dynamic_size, constraint);
|
||||||
|
return Status::OK();
|
||||||
|
}));
|
||||||
|
// Set the replacement instruction as visited to avoid visiting it again.
|
||||||
|
SetVisited(*result.replacement_instr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// We have replaced the while loop, now set the dynamic dimensions for the
|
|
||||||
// newly created while loop so that the hlos that consumes the while loop can
|
|
||||||
// see the dynamic dimensions. Also sets the dynamic parameter binding for
|
|
||||||
// running inference in the while loop.
|
|
||||||
DynamicParameterBinding binding_for_while;
|
|
||||||
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
|
|
||||||
hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
|
|
||||||
int64 operand_index, HloInstruction* dynamic_size,
|
|
||||||
DimensionConstraint constraint) {
|
|
||||||
DynamicParameterBinding::DynamicParameter dynamic_parameter{
|
|
||||||
operand_index,
|
|
||||||
{dynamic_size_to_operand_id_index_map[dynamic_size]}};
|
|
||||||
DynamicParameterBinding::DynamicDimension dynamic_dimension{
|
|
||||||
operand_index, index, dimension};
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
|
|
||||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
|
|
||||||
constraint);
|
|
||||||
return Status::OK();
|
|
||||||
}));
|
|
||||||
|
|
||||||
// Run inference in while body and condition.
|
// Run inference in while body and condition.
|
||||||
TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
|
TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
|
||||||
hlo->while_body(), binding_for_while, parent_));
|
hlo->while_body(), binding_for_while, parent_));
|
||||||
TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
|
TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
|
||||||
hlo->while_condition(), binding_for_while, parent_));
|
hlo->while_condition(), binding_for_while, parent_));
|
||||||
|
|
||||||
// Set the replacement while loop as visited to avoid visiting it again.
|
if (operands_to_add.empty()) {
|
||||||
SetVisited(*hlo);
|
// No dynamic dimension in the inputs and outputs.
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The dynamic dimension size could have been changed in the loop body (e.g, A
|
||||||
|
// loop that inserts items in a stack, the stack size increases with each
|
||||||
|
// iteration). Rewrite the dynamic dimension size at the root.
|
||||||
|
HloInstruction* body_root = hlo->while_body()->root_instruction();
|
||||||
|
std::vector<HloInstruction*> new_root_operands(body_root->operand_count(),
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
// Original non-dynamic-dim operands of root are pass-through.
|
||||||
|
for (int64 i = 0; i < original_tuple_count; ++i) {
|
||||||
|
new_root_operands[i] =
|
||||||
|
hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||||
|
body_root->shape().tuple_shapes(i), body_root, i));
|
||||||
|
}
|
||||||
|
// Add the dynamic dimension sizes as extra operands of the new root tuple.
|
||||||
|
TF_RETURN_IF_ERROR(ForEachDynamicDimension(
|
||||||
|
hlo->while_body()->root_instruction(),
|
||||||
|
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
|
||||||
|
DimensionConstraint) -> Status {
|
||||||
|
const int64 output_index =
|
||||||
|
dynamic_output_mapping.element(index).at(dim);
|
||||||
|
new_root_operands[output_index] = dynamic_size;
|
||||||
|
return Status::OK();
|
||||||
|
}));
|
||||||
|
for (auto operand : new_root_operands) {
|
||||||
|
TF_RET_CHECK(operand != nullptr);
|
||||||
|
}
|
||||||
|
HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
|
||||||
|
HloInstruction::CreateTuple(new_root_operands));
|
||||||
|
hlo->while_body()->set_root_instruction(new_body_root);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -767,7 +767,7 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) {
|
|||||||
// While
|
// While
|
||||||
auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
/*parameter_number=*/0, tuple_shape, "A"));
|
/*parameter_number=*/0, tuple_shape, "A"));
|
||||||
auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
/*parameter_number=*/1, scalar_shape_, "size_param"));
|
/*parameter_number=*/1, scalar_shape_, "size_param"));
|
||||||
builder.AddInstruction(
|
builder.AddInstruction(
|
||||||
HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));
|
HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));
|
||||||
@ -782,37 +782,32 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) {
|
|||||||
DynamicParameterBinding::DynamicParameter{1, {}},
|
DynamicParameterBinding::DynamicParameter{1, {}},
|
||||||
DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
|
DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
|
||||||
|
|
||||||
// Test that dynamic dimension inference does the right thing. A lambda is
|
|
||||||
// used here since we want to test twice by running inference again
|
|
||||||
// (idempotency).
|
|
||||||
auto test_dynamic_dimension = [&]() {
|
|
||||||
HloInstruction* while_hlo = nullptr;
|
|
||||||
// The while hlo has been replaced, find the new one.
|
|
||||||
for (HloInstruction* inst : module_->entry_computation()->instructions()) {
|
|
||||||
if (inst->opcode() == HloOpcode::kWhile) {
|
|
||||||
while_hlo = inst;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ASSERT_NE(while_hlo, nullptr);
|
|
||||||
// The original while shape has 2 parameters. With dynamic size passed in
|
|
||||||
// as an extra parameter, the tuple should have 3 elements.
|
|
||||||
EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 3);
|
|
||||||
HloInstruction* add = nullptr;
|
|
||||||
for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
|
|
||||||
if (inst->opcode() == HloOpcode::kAdd) {
|
|
||||||
add = inst;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
EXPECT_NE(add, nullptr);
|
|
||||||
EXPECT_NE(inference_->GetDynamicSize(add, {}, 0), nullptr);
|
|
||||||
EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {0}, 0), size_param);
|
|
||||||
EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {1}, 0), size_param);
|
|
||||||
};
|
|
||||||
|
|
||||||
TF_ASSERT_OK(RunInference());
|
TF_ASSERT_OK(RunInference());
|
||||||
test_dynamic_dimension();
|
HloInstruction* while_hlo = nullptr;
|
||||||
TF_ASSERT_OK(RunInference());
|
// The while hlo has been replaced, find the new one.
|
||||||
test_dynamic_dimension();
|
for (HloInstruction* inst : module_->entry_computation()->instructions()) {
|
||||||
|
if (inst->opcode() == HloOpcode::kWhile) {
|
||||||
|
while_hlo = inst;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_NE(while_hlo, nullptr);
|
||||||
|
// The original while shape has 2 parameters. With dynamic size, the tuple
|
||||||
|
// should have 4 elements (We don't deduplicate the arguments).
|
||||||
|
EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 4);
|
||||||
|
HloInstruction* add_inst = nullptr;
|
||||||
|
for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
|
||||||
|
if (inst->opcode() == HloOpcode::kAdd) {
|
||||||
|
add_inst = inst;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_NE(add_inst, nullptr);
|
||||||
|
EXPECT_NE(inference_->GetDynamicSize(add_inst, {}, 0), nullptr);
|
||||||
|
EXPECT_NE(inference_->GetDynamicSize(
|
||||||
|
module_->entry_computation()->root_instruction(), {0}, 0),
|
||||||
|
nullptr);
|
||||||
|
EXPECT_NE(inference_->GetDynamicSize(
|
||||||
|
module_->entry_computation()->root_instruction(), {1}, 0),
|
||||||
|
nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
|
TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
|
||||||
|
@ -903,6 +903,90 @@ ENTRY main {
|
|||||||
EXPECT_EQ(result, expected);
|
EXPECT_EQ(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ExecutionTest, WhileLoopStack) {
|
||||||
|
// Push the iteration number onto a dynamically sized stack:
|
||||||
|
// init:
|
||||||
|
// [[P, P],
|
||||||
|
// [P, P],
|
||||||
|
// [P, P],
|
||||||
|
// [P, P]]
|
||||||
|
// First iteration i = 0:
|
||||||
|
// [[0, 0],
|
||||||
|
// [P, P],
|
||||||
|
// [P, P],
|
||||||
|
// [P, P]]
|
||||||
|
// Second iteration i = 1:
|
||||||
|
// [[0, 0],
|
||||||
|
// [1, 1],
|
||||||
|
// [P, P],
|
||||||
|
// [P, P]]
|
||||||
|
// Third iteration i = 2:
|
||||||
|
// [[0, 0],
|
||||||
|
// [1, 1],
|
||||||
|
// [2, 2],
|
||||||
|
// [P, P]]
|
||||||
|
|
||||||
|
const string hlo_text = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
|
||||||
|
lhs = s32[] parameter(0)
|
||||||
|
rhs = s32[] parameter(1)
|
||||||
|
ROOT add = s32[] add(lhs, rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
stack = (s32[<=4,2]) parameter(0)
|
||||||
|
stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
|
||||||
|
stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
|
||||||
|
zero = s32[] constant(0)
|
||||||
|
one = s32[] constant(1)
|
||||||
|
// content of the stack is the stack index broadcasted.
|
||||||
|
new_data = s32[1, 2] broadcast(s32[] stack_size), dimensions={}
|
||||||
|
new_stack_buffer = s32[<=4, 2] dynamic-update-slice(stack_buffer, new_data, stack_size, zero)
|
||||||
|
new_stack_size = s32[] add(stack_size, one)
|
||||||
|
new_stack_buffer_dynamic = s32[<=4, 2]set-dimension-size(new_stack_buffer, new_stack_size), dimensions={0}
|
||||||
|
ROOT new_stack = (s32[<=4,2]) tuple(new_stack_buffer_dynamic)
|
||||||
|
}
|
||||||
|
|
||||||
|
condition {
|
||||||
|
stack = (s32[<=4,2]) parameter(0)
|
||||||
|
stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0
|
||||||
|
stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0}
|
||||||
|
three = s32[] constant(3)
|
||||||
|
ROOT less-than = pred[] compare(s32[] stack_size, s32[] three), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
zero = s32[] constant(0)
|
||||||
|
pad = s32[] constant(-1)
|
||||||
|
stack_buffer_input = s32[4, 2] broadcast(s32[] pad), dimensions={}
|
||||||
|
stack_buffer_input_dynamic = s32[<=4, 2] set-dimension-size(stack_buffer_input, zero), dimensions={0}
|
||||||
|
input_tuple = (s32[<=4 ,2]) tuple(stack_buffer_input_dynamic)
|
||||||
|
while = (s32[<=4, 2]) while(input_tuple), body=body, condition=condition
|
||||||
|
stack_buffer = s32[<=4, 2] get-tuple-element(while), index=0
|
||||||
|
ROOT reduce = s32[2] reduce(stack_buffer, zero),
|
||||||
|
dimensions={0},
|
||||||
|
to_apply=update_s32
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
auto module = GetHloModule(hlo_text);
|
||||||
|
|
||||||
|
Literal result = PadAndExecute(std::move(module), {});
|
||||||
|
|
||||||
|
// Stack has three valid items in it:
|
||||||
|
// [[0, 0],
|
||||||
|
// [1, 1],
|
||||||
|
// [2, 2],
|
||||||
|
// [P, P]]
|
||||||
|
//
|
||||||
|
// Reducing along major dimension gives us [3, 3]
|
||||||
|
Literal expected = LiteralUtil::CreateR1<int32>({{3, 3}});
|
||||||
|
|
||||||
|
EXPECT_EQ(result, expected);
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(ExecutionTest, DoubleDynamicDimension) {
|
XLA_TEST_F(ExecutionTest, DoubleDynamicDimension) {
|
||||||
const string hlo_text = R"(
|
const string hlo_text = R"(
|
||||||
HloModule TensorFlowScatterV1
|
HloModule TensorFlowScatterV1
|
||||||
|
@ -2596,7 +2596,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
|||||||
VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
|
VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
return operand_shape;
|
auto result_shape = operand_shape;
|
||||||
|
|
||||||
|
// If any of the operand shape and update shape is dynamic, update the result
|
||||||
|
// dimension to dynamic.
|
||||||
|
for (int64 i = 0; i < update_shape.rank(); ++i) {
|
||||||
|
if (update_shape.is_dynamic_dimension(i) ||
|
||||||
|
operand_shape.is_dynamic_dimension(i)) {
|
||||||
|
result_shape.set_dynamic_dimension(i, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
|
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
|
||||||
|
@ -125,8 +125,9 @@ WhileUtil::MakeInstructionsLiveIn(
|
|||||||
// We want to get rid of the old while instruction even if it has side
|
// We want to get rid of the old while instruction even if it has side
|
||||||
// effecting operations so we do a manual HloComputation::RemoveInstruction
|
// effecting operations so we do a manual HloComputation::RemoveInstruction
|
||||||
// instead of relying on HloComputation::ReplaceInstruction.
|
// instead of relying on HloComputation::ReplaceInstruction.
|
||||||
TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix(
|
HloInstruction* replacement_instr = TupleUtil::ExtractPrefix(
|
||||||
new_while, while_instr->shape().tuple_shapes_size())));
|
new_while, while_instr->shape().tuple_shapes_size());
|
||||||
|
TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr));
|
||||||
TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
|
TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
|
||||||
|
|
||||||
HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
|
HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
|
||||||
@ -142,6 +143,7 @@ WhileUtil::MakeInstructionsLiveIn(
|
|||||||
WhileUtil::MakeInstructionsLiveInResult result;
|
WhileUtil::MakeInstructionsLiveInResult result;
|
||||||
|
|
||||||
result.new_while_instr = new_while;
|
result.new_while_instr = new_while;
|
||||||
|
result.replacement_instr = replacement_instr;
|
||||||
result.while_body_live_in_values = std::move(live_in_instructions);
|
result.while_body_live_in_values = std::move(live_in_instructions);
|
||||||
result.while_body_instruction_map = std::move(inlined_instructions_map);
|
result.while_body_instruction_map = std::move(inlined_instructions_map);
|
||||||
|
|
||||||
|
@ -29,6 +29,10 @@ class WhileUtil {
|
|||||||
// The new while operation that has the requested values live in.
|
// The new while operation that has the requested values live in.
|
||||||
HloInstruction* new_while_instr;
|
HloInstruction* new_while_instr;
|
||||||
|
|
||||||
|
// The new tuple instruction that replaced the original while instruction
|
||||||
|
// with the same shape.
|
||||||
|
HloInstruction* replacement_instr;
|
||||||
|
|
||||||
// The i'th element of `while_body_live_in_values` is an instruction in the
|
// The i'th element of `while_body_live_in_values` is an instruction in the
|
||||||
// while body that holds the i'th *newly added* live in value at runtime.
|
// while body that holds the i'th *newly added* live in value at runtime.
|
||||||
std::vector<HloInstruction*> while_body_live_in_values;
|
std::vector<HloInstruction*> while_body_live_in_values;
|
||||||
|
Loading…
Reference in New Issue
Block a user