[XLA] Use gather, cumsum instead of sorts to implement dynamic reshape.

PiperOrigin-RevId: 305514267
Change-Id: I7d9ea5caec1140c4cce1f71e97e9078503c8b2d4
Yunxing Dai 2020-04-08 11:10:16 -07:00 committed by TensorFlower Gardener
parent 3b7ba8030f
commit 742b24d564
2 changed files with 181 additions and 95 deletions


@@ -209,38 +209,35 @@ HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim,
// [[a,b,P]
// [c,d,P]]
//
// The way we do this is by a 6-step double-sorting algorithm:
// The way we do this is by a 5-step cumsum-gather algorithm:
//
// 1.First we use the output shape to generate a binary 0-1 mask, which masks
// out the padded area of the output:
// [[0,0,1]
// [0,0,1]]
// [[1,1,0]
// [1,1,0]]
//
// 2.Then we do an inverse reshape to reshape it from output shape back to input
// shape [2,3]->[6]:
// [0,0,1,0,0,1]
// [1,1,0,1,1,0]
//
// 3.We then generate an iota mask using the input shape:
// [0,1,2,3,4,5]
// 3.We then do a cumsum on the mask:
// [1,2,2,3,4,4] and subtract 1 from it:
// [0,1,1,2,3,3]
//
// 4.Stable sort the iota mask using the binary mask as key:
// key [0,0,1,0,0,1]
// value[0,1,2,3,4,5]
// | Sort by key
// 4.Use the result of the cumsum as gather indices to rearrange the original
// data. Feed the original input [a,b,c,d,P,P] and the indices into gather.
//
// operand [a,b,c,d,P,P], indices [0,1,1,2,3,3]
// | |
// Gather-----------------+
// |
// v
// key [0,0,0,0,1,1]
// value[0,1,3,4,2,5]
// value[a,b,b,c,d,d], which is equivalent to [a,b,P,c,d,P] as the padding
// value doesn't matter.
//
// 5.Sort the original input [a,b,c,d,P,P] using the sorted iota mask:
// key [0,1,3,4,2,5]
// value[a,b,c,d,P,P]
// | Sort by key
// v
// key [0,1,2,3,4,5]
// value[a,b,P,c,d,P]
//
// 6.Feed the sorted input to original reshape[6]->[2,3], we can get the correct
// reshape:
// 5.Feed the gathered input to the original reshape [6]->[2,3]; we now get
// the correct result:
// [[a,b,P]
// [c,d,P]]
//
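For intuition, here is a minimal standalone C++ sketch of the five steps above on the running example, using plain vectors in place of HLO instructions. It is illustrative only and not part of this change; names such as dynamic_cols are made up.

#include <cstdio>
#include <vector>

int main() {
  // Padded input [a,b,c,d,P,P]; the minor output dimension has dynamic size 2.
  const std::vector<char> operand = {'a', 'b', 'c', 'd', 'P', 'P'};
  const int rows = 2, cols = 3;  // static output shape [2,3]
  const int dynamic_cols = 2;    // dynamic size of the minor output dimension

  // Step 1: binary mask over the output shape; 1 = valid, 0 = padded.
  std::vector<int> mask(rows * cols);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      mask[r * cols + c] = (c < dynamic_cols) ? 1 : 0;  // [1,1,0,1,1,0]

  // Step 2: the mask is stored flat here, so the inverse reshape back to the
  // input shape [6] is already done.

  // Step 3: cumulative sum, then subtract 1 -> gather indices [0,1,1,2,3,3].
  std::vector<int> indices(mask.size());
  int running = 0;
  for (size_t i = 0; i < mask.size(); ++i) {
    running += mask[i];
    indices[i] = running - 1;
  }

  // Step 4: gather the operand with those indices -> [a,b,b,c,d,d], which is
  // equivalent to [a,b,P,c,d,P] because the padded values are don't-cares.
  std::vector<char> gathered(operand.size());
  for (size_t i = 0; i < indices.size(); ++i) gathered[i] = operand[indices[i]];

  // Step 5: the static reshape [6] -> [2,3] now yields [[a,b,*],[c,d,*]].
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) std::printf("%c ", gathered[r * cols + c]);
    std::printf("\n");
  }
  return 0;
}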
@@ -248,27 +245,37 @@ Status RewriteDynamicReshapeSplitInput(
HloInstruction* reshape, int64 input_dim,
absl::Span<const int64> output_dims,
DynamicDimensionInference* dynamic_dimension_inference) {
VLOG(1) << "Reshaping input dim " << input_dim << "to "
<< VectorString(output_dims);
const Shape operand_shape = reshape->operand(0)->shape();
TF_RET_CHECK(output_dims.size() > 1);
HloComputation* comp = reshape->parent();
const Shape mask_input_shape =
ShapeUtil::ChangeElementType(operand_shape, xla::S32);
ShapeUtil::MakeShape(xla::S32, {operand_shape.dimensions(input_dim)});
std::vector<int64> reshaped_dims;
for (int64 output_dim : output_dims) {
reshaped_dims.push_back(reshape->shape().dimensions(output_dim));
}
const Shape mask_reshaped_shape =
ShapeUtil::ChangeElementType(reshape->shape(), xla::S32);
ShapeUtil::MakeShape(xla::S32, reshaped_dims);
HloInstruction* zero = comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
HloInstruction* one = comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::One(S32)));
// Step 1 -- generate binary mask.
// Mask starts with all zero, each dynamic dimension sets one dimension of the
// mask to partially one.
// Mask starts with all ones; for each dynamic dimension, the region beyond
// the dynamic size along that dimension is set to zero.
HloInstruction* binary_mask = comp->AddInstruction(
HloInstruction::CreateBroadcast(mask_reshaped_shape, zero, {}));
HloInstruction::CreateBroadcast(mask_reshaped_shape, one, {}));
bool need_rewrite = false;
// Pad the effective dimension with 1.
//
// Index starts from 1 since there is no need to rewrite a major output
// dimension.
for (int64 i = 1; i < output_dims.size(); ++i) {
@@ -278,10 +285,10 @@ Status RewriteDynamicReshapeSplitInput(
if (dynamic_size == nullptr) {
continue;
}
// If there is dynamic dimension in the output, need rewrite the input.
// If there is a dynamic dimension in the output, we need to rewrite the input.
need_rewrite = true;
binary_mask = PadWithScalar(binary_mask, output_dim, dynamic_size, one);
binary_mask = PadWithScalar(binary_mask, i, dynamic_size, zero);
}
if (!need_rewrite) {
return Status::OK();
@@ -292,90 +299,77 @@ Status RewriteDynamicReshapeSplitInput(
HloInstruction* input_shape_binary_mask = comp->AddInstruction(
HloInstruction::CreateReshape(mask_input_shape, binary_mask));
// Step 3. Generate iota mask.
HloInstruction* iota_mask = comp->AddInstruction(
HloInstruction::CreateIota(mask_input_shape, input_dim));
// Step 4. Sort iota.
// Use binary mask to sort iota mask, then use iota mask to reshape input.
HloComputation::Builder comp_builder("compare_binary_iota");
// Step 3. Do a cumsum on the binary mask.
auto embedded_builder = HloComputation::Builder("add");
{
HloInstruction* lhs_key =
comp_builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {}), "lhs_key_binary"));
HloInstruction* rhs_key =
comp_builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(S32, {}), "rhs_key_binary"));
// Values for lhs and rhs
comp_builder.AddInstruction(HloInstruction::CreateParameter(
2, ShapeUtil::MakeShape(S32, {}), "lhs_iota"));
comp_builder.AddInstruction(HloInstruction::CreateParameter(
3, ShapeUtil::MakeShape(S32, {}), "rhs_iota"));
comp_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key,
rhs_key, ComparisonDirection::kLt));
auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(operand_shape.element_type(), {}), "lhs"));
auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(operand_shape.element_type(), {}), "rhs"));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
}
HloComputation* compare_binary_iota =
comp->parent()->AddEmbeddedComputation(comp_builder.Build());
HloComputation* add =
reshape->GetModule()->AddEmbeddedComputation(embedded_builder.Build());
Window cumsum_window;
// Configure the window for a cumulative sum along the mask's only dimension.
WindowDimension* dim = cumsum_window.add_dimensions();
dim->set_size(operand_shape.dimensions(input_dim));
dim->set_stride(1);
dim->set_padding_low(operand_shape.dimensions(input_dim) - 1);
dim->set_padding_high(0);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
HloInstruction* cumsum =
comp->AddInstruction(HloInstruction::CreateReduceWindow(
mask_input_shape, input_shape_binary_mask, zero, cumsum_window, add));
HloInstruction* sorted_binary_iota =
comp->AddInstruction(HloInstruction::CreateSort(
ShapeUtil::MakeTupleShape({mask_input_shape, mask_input_shape}),
input_dim, {input_shape_binary_mask, iota_mask}, compare_binary_iota,
/*is_stable=*/true));
HloInstruction* sorted_iota_mask =
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
mask_input_shape, sorted_binary_iota, 1));
HloInstruction* broadcast_ones = comp->AddInstruction(
HloInstruction::CreateBroadcast(mask_input_shape, one, {}));
cumsum = comp->AddInstruction(HloInstruction::CreateBinary(
mask_input_shape, HloOpcode::kSubtract, cumsum, broadcast_ones));
// Step 5. Sort original input using iota mask as key.
HloComputation::Builder comp_builder_iota("compare_binary_iota");
{
HloInstruction* lhs_key =
comp_builder_iota.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {}), "lhs_key_iota"));
HloInstruction* rhs_key =
comp_builder_iota.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(S32, {}), "rhs_key_iota"));
// Values for lhs and rhs
comp_builder_iota.AddInstruction(HloInstruction::CreateParameter(
2, ShapeUtil::MakeShape(operand_shape.element_type(), {}),
"lhs_value"));
comp_builder_iota.AddInstruction(HloInstruction::CreateParameter(
3, ShapeUtil::MakeShape(operand_shape.element_type(), {}),
"rhs_value"));
comp_builder_iota.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key,
rhs_key, ComparisonDirection::kLt));
GatherDimensionNumbers gather_dim_numbers;
// We use gather to rearrange the input dim dimension. However, the current
// semantics of gather don't allow us to collapse a dimension in this case, so
// we keep it, which makes the gather go from shape [..., input_dim, ...] to
// [..., 1, input_dim, ...]
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
// Offset dims are every dimension, including the newly added size-1 dim,
// except for input_dim, which acts as a batch dim.
if (i != input_dim) {
gather_dim_numbers.add_offset_dims(i);
}
}
// The dimension to rewrite is the index dim.
gather_dim_numbers.add_start_index_map(input_dim);
gather_dim_numbers.set_index_vector_dim(1);
gather_dim_numbers.add_collapsed_slice_dims(input_dim);
HloComputation* compare_iota_value =
comp->parent()->AddEmbeddedComputation(comp_builder_iota.Build());
// Step 4. Gather.
// Temporarily removes dynamic dimension before entering sort -- we want the
// sort to ignore dynamic dimension.
// Temporarily remove the dynamic dimension before entering the gather -- we
// want the gather to ignore the dynamic dimension.
HloInstruction* operand_static_dim_size =
comp->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(operand_shape.dimensions(input_dim))));
HloInstruction* operand_static =
comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
operand_shape, reshape->mutable_operand(0), operand_static_dim_size,
input_dim));
HloInstruction* sorted_iota_value =
comp->AddInstruction(HloInstruction::CreateSort(
ShapeUtil::MakeTupleShape({mask_input_shape, operand_shape}),
input_dim, {sorted_iota_mask, operand_static}, compare_iota_value,
/*is_stable=*/true));
// Step 6: Feed sorted input to original reshape.
HloInstruction* sorted_operand =
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
operand_shape, sorted_iota_value, 1));
std::vector<int64> slice_sizes(operand_shape.dimensions().begin(),
operand_shape.dimensions().end());
slice_sizes[input_dim] = 1;
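// With these dimension numbers and slice sizes, each cumsum entry indices[k]
// selects the operand's full slice at position indices[k] along input_dim, so
// the gather below computes
//   output[..., k, ...] = operand[..., indices[k], ...]   (k at input_dim).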
HloInstruction* gather = comp->AddInstruction(HloInstruction::CreateGather(
ShapeUtil::MakeShape(operand_shape.element_type(),
operand_shape.dimensions()),
operand_static, cumsum, gather_dim_numbers, slice_sizes, true));
TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, sorted_operand));
// Step 5: Feed the gathered input to the original reshape.
TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, gather));
HloInstruction* reshape_dynamic = reshape;
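A note on step 3 above: the cumsum is expressed as a ReduceWindow whose single window dimension has size N, stride 1, and low padding N-1, so every window ends at its output position and sums everything up to and including it. A minimal standalone C++ sketch of that behavior (illustrative only, not part of this change; ReduceWindowAdd is a made-up name):

#include <cstdio>
#include <vector>

// Models a 1-D add ReduceWindow with window size n, stride 1,
// padding_low = n-1, padding_high = 0, and init value 0.
std::vector<int> ReduceWindowAdd(const std::vector<int>& in) {
  const int n = static_cast<int>(in.size());
  std::vector<int> out(n, 0);
  for (int i = 0; i < n; ++i) {
    // The window ends at element i; low-padding elements contribute the init 0.
    for (int j = i - (n - 1); j <= i; ++j)
      if (j >= 0) out[i] += in[j];
  }
  return out;  // equal to the inclusive prefix sum of `in`
}

int main() {
  // The binary mask from the running example.
  std::vector<int> mask = {1, 1, 0, 1, 1, 0};
  for (int v : ReduceWindowAdd(mask)) std::printf("%d ", v);  // 1 2 2 3 4 4
  std::printf("\n");
  return 0;
}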


@@ -682,6 +682,98 @@ ENTRY main {
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMajor) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}
ENTRY main {
param = s32[2, 6] parameter(0)
const = s32[] constant(4)
param_padded = s32[2, 6] set-dimension-size(param, const), dimensions={1}
// Third dimension is dynamic.
reshaped = s32[2, 2, 3] reshape(param_padded), inferred_dimension=2
init = s32[] constant(0)
ROOT reduce = s32[2, 2] reduce(reshaped, init),
dimensions={2},
to_apply=update_s32
}
)";
// The third dimension has an upper bound of 3; its dynamic size is 2.
Literal operand =
LiteralUtil::CreateR2<int32>({{0, 1, 2, 3, 4, 5}, {6, 7, 8, 9, 10, 11}});
auto module = GetHloModule(hlo_text);
Literal result = PadAndExecute(std::move(module), {&operand});
// After padding and reshape we have
//
// [[[0, 1, P],
// [2, 3, P]],
// [[6, 7, P],
// [8, 9, P]]]
// Reducing on the third dimension gives us
// [0+1, 2+3]
// [6+7, 8+9]
//
Literal expected = LiteralUtil::CreateR2<int32>({{1, 5}, {13, 17}});
EXPECT_EQ(result, expected);
}
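As a cross-check (illustrative only, not part of the test file), a plain C++ model of what this test computes: each row of the s32[2,6] input carries 4 valid elements, the dynamic reshape packs them into groups of dynamic size 2 along the new minor dimension, and the reduce sums each group, with padded slots contributing the init value 0.

#include <cstdio>
#include <vector>

int main() {
  const std::vector<std::vector<int>> param = {{0, 1, 2, 3, 4, 5},
                                               {6, 7, 8, 9, 10, 11}};
  const int dynamic_size = 4;   // set-dimension-size on input dimension 1
  const int out_rows = 2;       // static size of the new middle dimension
  const int dynamic_minor = dynamic_size / out_rows;  // = 2 valid per group

  for (const auto& row : param) {
    for (int r = 0; r < out_rows; ++r) {
      int sum = 0;  // padded elements contribute the init value 0
      for (int c = 0; c < dynamic_minor; ++c) sum += row[r * dynamic_minor + c];
      std::printf("%d ", sum);  // prints 1 5 13 17
    }
  }
  std::printf("\n");
  return 0;
}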
XLA_TEST_F(ExecutionTest, OutputMinorDimensionReshapeWithUnchangedDimMinor) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}
ENTRY main {
param = s32[6, 2] parameter(0)
const = s32[] constant(4)
param_padded = s32[6, 2] set-dimension-size(param, const), dimensions={0}
// Second dimension is dynamic.
reshaped = s32[2, 3, 2] reshape(param_padded), inferred_dimension=1
init = s32[] constant(0)
ROOT reduce = s32[2, 2] reduce(reshaped, init),
dimensions={1},
to_apply=update_s32
}
)";
// The second dimension has an upper bound of 3; its dynamic size is 2.
Literal operand = LiteralUtil::CreateR2<int32>(
{{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}, {10, 11}});
auto module = GetHloModule(hlo_text);
Literal result = PadAndExecute(std::move(module), {&operand});
// After padding and reshape we have
//
// [[[0, 1],
// [2, 3]
// [P, P]],
// [[4, 5],
// [6, 7],
// [P, P]]]
// Reducing on the second dimension gives us
// [0+2, 1+3]
// [4+6, 5+7]
//
Literal expected = LiteralUtil::CreateR2<int32>({{2, 4}, {10, 12}});
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicDimensionReshapeUnchanged) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1