[XLA] Use gather, cumsum instead of sorts to implement dynamic reshape.
PiperOrigin-RevId: 305514267
Change-Id: I7d9ea5caec1140c4cce1f71e97e9078503c8b2d4
@@ -209,38 +209,35 @@ HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim,
// [[a,b,P]
// [c,d,P]]
//
// The way we do this is by a 6-steps double-sorting algorithm:
// The way we do this is by a 5-step cumsum-gather algorithm:
//
// 1.First we use the output shape to generate a binary 0-1 masking, which masks
// out the padded area of the output:
// [[0,0,1]
// [0,0,1]]
// [[1,1,0]
// [1,1,0]]
//
// 2.Then we do an inverse reshape to reshape it from output shape back to input
// shape [2,3]->[6]:
// [0,0,1,0,0,1]
// [1,1,0,1,1,0]
//
// 3.We then generate an iota mask using the input shape:
// [0,1,2,3,4,5]
// 3.We then do a cumsum with the mask:
// [1,2,2,3,4,4] and subtract 1 from it:
// [0,1,1,2,3,3]
//
// 4.Stable sort the iota mask using the binary mask as key:
// key [0,0,1,0,0,1]
// value[0,1,2,3,4,5]
//      | Sort by key
// 4.Use the result of the cumsum as gather indices to rearrange the original
// data. Feed the original input [a,b,c,d,P,P] and the indices into gather.
//
// operand [a,b,c,d,P,P], indices [0,1,1,2,3,3]
//      |                       |
//      Gather-----------------+
//      |
//      v
// key [0,0,0,0,1,1]
// value[0,1,3,4,2,5]
// value[a,b,b,c,d,d], which is equivalent to [a,b,P,c,d,P] as the padding value
// doesn't matter.
//
// 5.Sort the original input [a,b,c,d,P,P] using the sorted iota mask:
// key [0,1,3,4,2,5]
// value[a,b,c,d,P,P]
//      | Sort by key
//      v
// key [0,1,2,3,4,5]
// value[a,b,P,c,d,P]
//
// 6.Feed the sorted input to original reshape[6]->[2,3], we can get the correct
// reshape:
// 5.Feed the gathered input to the original reshape[6]->[2,3]; we now get the
// correct result:
// [[a,b,P]
// [c,d,P]]
//
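To make the index arithmetic above easy to check, here is a small standalone sketch of steps 1-4 on the running example. It is plain C++ rather than XLA HLO, and the arrays and printout are purely illustrative, not part of this change:

#include <cstdio>
#include <vector>

int main() {
  // Running example: the input has 6 elements, of which the first 4 (a,b,c,d)
  // are real data and the last 2 are padding.  The output shape is [2,3] with
  // a dynamic minor dimension of size 2: each output row is 2 real values
  // followed by 1 padding value.
  std::vector<char> input = {'a', 'b', 'c', 'd', 'P', 'P'};

  // Steps 1+2: binary mask [[1,1,0],[1,1,0]] in the output shape, reshaped
  // back to the input shape.
  std::vector<int> mask = {1, 1, 0, 1, 1, 0};

  // Step 3: inclusive cumsum of the mask, then subtract 1.
  std::vector<int> indices(mask.size());
  int running = 0;
  for (size_t i = 0; i < mask.size(); ++i) {
    running += mask[i];
    indices[i] = running - 1;  // yields [0,1,1,2,3,3]
  }

  // Step 4: gather the input with those indices.
  std::vector<char> gathered(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    gathered[i] = input[indices[i]];  // yields [a,b,b,c,d,d]
  }

  // Step 5: reshaping [6]->[2,3] now yields [[a,b,b],[c,d,d]], which is the
  // desired [[a,b,P],[c,d,P]] because the padded positions are ignored.
  for (char c : gathered) std::printf("%c ", c);
  std::printf("\n");
}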
@@ -248,27 +245,37 @@ Status RewriteDynamicReshapeSplitInput(
HloInstruction* reshape, int64 input_dim,
absl::Span<const int64> output_dims,
DynamicDimensionInference* dynamic_dimension_inference) {
VLOG(1) << "Reshaping input dim " << input_dim << "to "
<< VectorString(output_dims);
const Shape operand_shape = reshape->operand(0)->shape();
TF_RET_CHECK(output_dims.size() > 1);

HloComputation* comp = reshape->parent();
const Shape mask_input_shape =
ShapeUtil::ChangeElementType(operand_shape, xla::S32);
ShapeUtil::MakeShape(xla::S32, {operand_shape.dimensions(input_dim)});

std::vector<int64> reshaped_dims;
for (int64 output_dim : output_dims) {
reshaped_dims.push_back(reshape->shape().dimensions(output_dim));
}

const Shape mask_reshaped_shape =
ShapeUtil::ChangeElementType(reshape->shape(), xla::S32);
ShapeUtil::MakeShape(xla::S32, reshaped_dims);

HloInstruction* zero = comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
HloInstruction* one = comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::One(S32)));
// Step 1 -- generate binary mask.
// Mask starts with all zero, each dynamic dimension sets one dimension of the
// mask to partially one.
// Mask starts with all one, each dynamic dimension sets that dimension of the
// mask to partially zero in the end.
HloInstruction* binary_mask = comp->AddInstruction(
HloInstruction::CreateBroadcast(mask_reshaped_shape, zero, {}));
HloInstruction::CreateBroadcast(mask_reshaped_shape, one, {}));

bool need_rewrite = false;

// Pad the effective dimension with 1.
//
// Index starts from 1 since there is no need to rewrite a major output
// dimension.
for (int64 i = 1; i < output_dims.size(); ++i) {
@@ -278,10 +285,10 @@ Status RewriteDynamicReshapeSplitInput(
if (dynamic_size == nullptr) {
continue;
}
// If there is dynamic dimension in the output, need rewrite the input.
// If there is a dynamic dimension in the output, we need to rewrite the input.
need_rewrite = true;

binary_mask = PadWithScalar(binary_mask, output_dim, dynamic_size, one);
binary_mask = PadWithScalar(binary_mask, i, dynamic_size, zero);
}
if (!need_rewrite) {
return Status::OK();
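For the running example, this loop is what produces the step-1 mask. The body of PadWithScalar is not part of this diff, so the helper below only models the behavior assumed here (overwrite everything at or beyond the dynamic size along one dimension with a scalar); the function name and arrays are illustrative only:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for PadWithScalar on a rank-2 mask: overwrite every
// element whose index along `dim` is >= dynamic_size with `value`.
void PadDimWithScalar(std::vector<std::vector<int>>& m, int dim,
                      int dynamic_size, int value) {
  for (size_t r = 0; r < m.size(); ++r) {
    for (size_t c = 0; c < m[r].size(); ++c) {
      int idx = (dim == 0) ? static_cast<int>(r) : static_cast<int>(c);
      if (idx >= dynamic_size) m[r][c] = value;
    }
  }
}

int main() {
  // Step 1 for the running example: a [2,3] mask of all ones; the minor output
  // dimension is dynamic with size 2, so positions >= 2 along that dimension
  // are zeroed out, giving [[1,1,0],[1,1,0]].
  std::vector<std::vector<int>> mask(2, std::vector<int>(3, 1));
  PadDimWithScalar(mask, /*dim=*/1, /*dynamic_size=*/2, /*value=*/0);
  for (const auto& row : mask) {
    for (int v : row) std::printf("%d ", v);
    std::printf("\n");
  }
}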
@@ -292,90 +299,77 @@ Status RewriteDynamicReshapeSplitInput(
HloInstruction* input_shape_binary_mask = comp->AddInstruction(
HloInstruction::CreateReshape(mask_input_shape, binary_mask));

// Step 3. Generate iota mask.
HloInstruction* iota_mask = comp->AddInstruction(
HloInstruction::CreateIota(mask_input_shape, input_dim));

// Step 4. Sort iota.
// Use binary mark to sort iota mask, then use iota mask to reshape input.
HloComputation::Builder comp_builder("compare_binary_iota");
// Step 3. Do a cumsum on the binary mask.
auto embedded_builder = HloComputation::Builder("add");
{
HloInstruction* lhs_key =
comp_builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {}), "lhs_key_binary"));
HloInstruction* rhs_key =
comp_builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(S32, {}), "rhs_key_binary"));

// Values for lhs and rhs
comp_builder.AddInstruction(HloInstruction::CreateParameter(
2, ShapeUtil::MakeShape(S32, {}), "lhs_iota"));
comp_builder.AddInstruction(HloInstruction::CreateParameter(
3, ShapeUtil::MakeShape(S32, {}), "rhs_iota"));
comp_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key,
rhs_key, ComparisonDirection::kLt));
auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(operand_shape.element_type(), {}), "lhs"));
auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(operand_shape.element_type(), {}), "rhs"));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
}

HloComputation* compare_binary_iota =
comp->parent()->AddEmbeddedComputation(comp_builder.Build());
HloComputation* add =
reshape->GetModule()->AddEmbeddedComputation(embedded_builder.Build());
Window cumsum_window;
// First dimension is unchanged.
WindowDimension* dim = cumsum_window.add_dimensions();
dim->set_size(operand_shape.dimensions(input_dim));
dim->set_stride(1);
dim->set_padding_low(operand_shape.dimensions(input_dim) - 1);
dim->set_padding_high(0);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
HloInstruction* cumsum =
comp->AddInstruction(HloInstruction::CreateReduceWindow(
mask_input_shape, input_shape_binary_mask, zero, cumsum_window, add));

HloInstruction* sorted_binary_iota =
comp->AddInstruction(HloInstruction::CreateSort(
ShapeUtil::MakeTupleShape({mask_input_shape, mask_input_shape}),
input_dim, {input_shape_binary_mask, iota_mask}, compare_binary_iota,
/*is_stable=*/true));
HloInstruction* sorted_iota_mask =
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
mask_input_shape, sorted_binary_iota, 1));
HloInstruction* broadcast_ones = comp->AddInstruction(
HloInstruction::CreateBroadcast(mask_input_shape, one, {}));
cumsum = comp->AddInstruction(HloInstruction::CreateBinary(
mask_input_shape, HloOpcode::kSubtract, cumsum, broadcast_ones));
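The Window above is what turns ReduceWindow into a prefix sum: with window size N, stride 1, and padding_low = N - 1, the window that produces output element i covers exactly elements 0..i of the (padded) input. A minimal sketch of that arithmetic on a plain array (not XLA code; the function name is illustrative):

#include <cstdio>
#include <vector>

// Emulates a 1-D ReduceWindow(add) with window size n, stride 1,
// padding_low = n - 1, padding_high = 0, and init value 0: the window ending
// at position i covers (padded) elements i - n + 1 .. i, so the result is an
// inclusive prefix sum.
std::vector<int> ReduceWindowCumsum(const std::vector<int>& in) {
  const int n = static_cast<int>(in.size());
  std::vector<int> out(n, 0);
  for (int i = 0; i < n; ++i) {
    int sum = 0;  // init value (the `zero` constant above)
    for (int j = i - n + 1; j <= i; ++j) {
      if (j >= 0) sum += in[j];  // j < 0 falls into the low padding
    }
    out[i] = sum;
  }
  return out;
}

int main() {
  // Binary mask from the running example.
  std::vector<int> mask = {1, 1, 0, 1, 1, 0};
  std::vector<int> cumsum = ReduceWindowCumsum(mask);
  for (int v : cumsum) std::printf("%d ", v);  // prints: 1 2 2 3 4 4
  std::printf("\n");
}

Subtracting the broadcast ones afterwards turns this inclusive prefix sum into the zero-based gather indices [0,1,1,2,3,3].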
// Step 5. Sort original input using iota mask as key.
HloComputation::Builder comp_builder_iota("compare_binary_iota");
{
HloInstruction* lhs_key =
comp_builder_iota.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {}), "lhs_key_iota"));
HloInstruction* rhs_key =
comp_builder_iota.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(S32, {}), "rhs_key_iota"));

// Values for lhs and rhs
comp_builder_iota.AddInstruction(HloInstruction::CreateParameter(
2, ShapeUtil::MakeShape(operand_shape.element_type(), {}),
"lhs_value"));
comp_builder_iota.AddInstruction(HloInstruction::CreateParameter(
3, ShapeUtil::MakeShape(operand_shape.element_type(), {}),
"rhs_value"));
comp_builder_iota.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key,
rhs_key, ComparisonDirection::kLt));
GatherDimensionNumbers gather_dim_numbers;
// We use gather to rearrange the input dim dimension. However, the current
// semantics of gather don't allow us to collapse a dimension in this case, so
// we keep it, which makes the gather go from shape [..., input_dim, ...] to
// [..., 1, input_dim, ...].
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
// Offset dim is every dimension including the newly added size-1 dim, except
// for input_dim, which acts as a batch_dim.
if (i != input_dim) {
gather_dim_numbers.add_offset_dims(i);
}
}
// The dimension to rewrite is the index dim.
gather_dim_numbers.add_start_index_map(input_dim);
gather_dim_numbers.set_index_vector_dim(1);
gather_dim_numbers.add_collapsed_slice_dims(input_dim);
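In effect, this gather is a selection along input_dim: for every position j along input_dim it picks the operand slice at position indices[j], leaving the other dimensions untouched. A simplified plain-C++ model for a rank-2 operand with input_dim == 0 (it ignores the index_vector_dim mechanics and is not how XLA evaluates gather internally):

#include <cstdio>
#include <vector>

// Plain-C++ model of the gather configured above for a rank-2 operand with
// input_dim == 0: start_index_map = {0} and collapsed_slice_dims = {0} mean
// indices[j] selects a full row, and the batch dimension of the result takes
// input_dim's place, so result[j][c] == operand[indices[j]][c].
std::vector<std::vector<int>> GatherRows(
    const std::vector<std::vector<int>>& operand,
    const std::vector<int>& indices) {
  std::vector<std::vector<int>> result;
  result.reserve(indices.size());
  for (int idx : indices) {
    result.push_back(operand[idx]);  // slice_sizes[input_dim] == 1
  }
  return result;
}

int main() {
  // Rows play the role of elements along input_dim; the second column is an
  // offset dimension that is carried along unchanged.
  std::vector<std::vector<int>> operand = {{10, 11}, {20, 21}, {30, 31},
                                           {40, 41}, {0, 0},   {0, 0}};
  std::vector<int> indices = {0, 1, 1, 2, 3, 3};  // the cumsum - 1 from above
  for (const auto& row : GatherRows(operand, indices)) {
    std::printf("[%d, %d] ", row[0], row[1]);
  }
  std::printf("\n");
}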
HloComputation* compare_iota_value =
comp->parent()->AddEmbeddedComputation(comp_builder_iota.Build());
// Step 4. Gather.

// Temporarily removes dynamic dimension before entering sort -- we want the
// sort to ignore dynamic dimension.
// Temporarily removes dynamic dimension before entering gather -- we want the
// gather to ignore dynamic dimension.
HloInstruction* operand_static_dim_size =
comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(operand_shape.dimensions(input_dim))));

HloInstruction* operand_static =
comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
operand_shape, reshape->mutable_operand(0), operand_static_dim_size,
input_dim));

HloInstruction* sorted_iota_value =
comp->AddInstruction(HloInstruction::CreateSort(
ShapeUtil::MakeTupleShape({mask_input_shape, operand_shape}),
input_dim, {sorted_iota_mask, operand_static}, compare_iota_value,
/*is_stable=*/true));
// Step 6: Feed sorted input to original reshape.
HloInstruction* sorted_operand =
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
operand_shape, sorted_iota_value, 1));
std::vector<int64> slice_sizes(operand_shape.dimensions().begin(),
operand_shape.dimensions().end());
slice_sizes[input_dim] = 1;
HloInstruction* gather = comp->AddInstruction(HloInstruction::CreateGather(
ShapeUtil::MakeShape(operand_shape.element_type(),
operand_shape.dimensions()),
operand_static, cumsum, gather_dim_numbers, slice_sizes, true));

TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, sorted_operand));
// Step 5: Feed the gathered input to the original reshape.

TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, gather));

HloInstruction* reshape_dynamic = reshape;
@@ -682,6 +682,98 @@ ENTRY main {
EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMajor) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
param = s32[2, 6] parameter(0)
const = s32[] constant(4)
param_padded = s32[2, 6] set-dimension-size(param, const), dimensions={1}
// Third dimension is dynamic.
reshaped = s32[2, 2, 3] reshape(param_padded), inferred_dimension=2
init = s32[] constant(0)
ROOT reduce = s32[2, 2] reduce(reshaped, init),
dimensions={2},
to_apply=update_s32
}
)";

// The second dimension has an upper bound of 6 and a dynamic size of 4.
Literal operand =
LiteralUtil::CreateR2<int32>({{0, 1, 2, 3, 4, 5}, {6, 7, 8, 9, 10, 11}});
auto module = GetHloModule(hlo_text);

Literal result = PadAndExecute(std::move(module), {&operand});

// After padding and reshape we have
//
// [[[0, 1, P],
// [2, 3, P]],
// [[6, 7, P],
// [8, 9, P]]]
// Reducing on the third dimension gives us
// [0+1, 2+3]
// [6+7, 8+9]
//
Literal expected = LiteralUtil::CreateR2<int32>({{1, 5}, {13, 17}});

EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMinor) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}

ENTRY main {
param = s32[6, 2] parameter(0)
const = s32[] constant(4)
param_padded = s32[6, 2] set-dimension-size(param, const), dimensions={0}
// Second dimension is dynamic.
reshaped = s32[2, 3, 2] reshape(param_padded), inferred_dimension=1
init = s32[] constant(0)
ROOT reduce = s32[2, 2] reduce(reshaped, init),
dimensions={1},
to_apply=update_s32
}
)";

// The first dimension has an upper bound of 6 and a dynamic size of 4.
Literal operand = LiteralUtil::CreateR2<int32>(
{{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}, {10, 11}});
auto module = GetHloModule(hlo_text);

Literal result = PadAndExecute(std::move(module), {&operand});

// After padding and reshape we have
//
// [[[0, 1],
// [2, 3],
// [P, P]],
// [[4, 5],
// [6, 7],
// [P, P]]]
// Reducing on the second dimension gives us
// [0+2, 1+3]
// [4+6, 5+7]
//
Literal expected = LiteralUtil::CreateR2<int32>({{2, 4}, {10, 12}});

EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, DynamicDimensionReshapeUnchanged) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1