[XLA] Add implementation support for variadic reduce window, including HLO, cost analysis, etc.

PiperOrigin-RevId: 347742450
Change-Id: I06b02b9407013322f3e72865fec487385a47abec
A. Unique TensorFlower 2020-12-15 19:55:55 -08:00 committed by TensorFlower Gardener
parent 152cd6c3b7
commit 2e61843949
12 changed files with 224 additions and 11 deletions
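For orientation, a minimal sketch of the variadic form this change enables, modeled on the new HloCostAnalysis test below (builder names and shapes are illustrative): a reduce-window over N inputs takes N matching init values and a reduction computation mapping 2N scalars to an N-tuple, and yields a tuple of N same-dimensioned outputs.

#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void BuildVariadicReduceWindow() {
  // 2N -> N-tuple reduction: element-wise Min over two paired streams.
  xla::XlaBuilder sub("min2");
  auto scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  auto a0 = xla::Parameter(&sub, 0, scalar, "a0");  // accumulator, stream 0
  auto a1 = xla::Parameter(&sub, 1, scalar, "a1");  // accumulator, stream 1
  auto b0 = xla::Parameter(&sub, 2, scalar, "b0");  // window element, stream 0
  auto b1 = xla::Parameter(&sub, 3, scalar, "b1");  // window element, stream 1
  xla::Tuple(&sub, {xla::Min(a0, b0), xla::Min(a1, b1)});
  xla::XlaComputation min2 = sub.Build().ValueOrDie();

  xla::XlaBuilder b("variadic_reduce_window");
  auto shape = xla::ShapeUtil::MakeShape(xla::F32, {10, 20});
  auto in0 = xla::Parameter(&b, 0, shape, "in0");
  auto in1 = xla::Parameter(&b, 1, shape, "in1");
  auto init = xla::ConstantR0<float>(&b, 0);
  // Two operands, two init values; the result is a 2-tuple of f32[2,4]
  // (window 4x5 at stride 4x5 with valid padding over f32[10,20]).
  xla::ReduceWindow({in0, in1}, {init, init}, min2,
                    /*window_dimensions=*/{4, 5},
                    /*window_strides=*/{4, 5}, xla::Padding::kValid);
}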

View File

@@ -4696,6 +4696,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
Status AlgebraicSimplifierVisitor::HandleReduceWindow(
HloInstruction* reduce_window) {
// TODO(b/73062247) Variadic reduce window is not yet supported in simplifier.
if (reduce_window->shape().IsTuple()) {
return Status::OK();
}
if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
return ReplaceWithNewInstruction(
reduce_window,

View File

@@ -93,6 +93,9 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
return inst->mutable_operand(init_value_index);
}
case HloOpcode::kReduceWindow: {
if (inst->shape().IsTuple()) {
return Unimplemented("Variadic reduce window not yet supported.");
}
// Because of the way we do reduce, we already require the `init`
// operand of hlo reduce instruction to be identity value. Here we reuse
// the operand.
@@ -1015,6 +1018,10 @@ StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
StatusOr<bool> RewriteDynamicReduceWindowSamePadding(
HloInstruction* hlo,
DynamicDimensionInference* dynamic_dimension_inference) {
if (hlo->shape().IsTuple()) {
// TODO(b/73062247) Variadic reduce window is not yet supported here.
return Unimplemented("Variadic reduce window not yet supported.");
}
HloInstruction* input = hlo->mutable_operand(0);
HloInstruction* init = hlo->mutable_operand(1);
HloComputation* comp = hlo->parent();

View File

@@ -396,8 +396,11 @@ Status HloCostAnalysis::HandleReduceWindow(
for (const auto& dimension : window.dimensions()) {
window_element_count *= dimension.size();
}
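// For the variadic form the result is a tuple whose elements all share the
// same dimensions, so tuple element 0 stands in for every output's count.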
const int64 output_element_count =
ShapeUtil::ElementsIn(reduce_window->shape());
ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
? reduce_window->shape()
: reduce_window->shape().tuple_shapes(0));
const int64 reduction_count =
(window_element_count - 1) * output_element_count;
for (const auto& property : sub_properties) {
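Worked through with the numbers from the ReduceWindowVariadic test below: a 4x5 window holds 20 elements and each tuple output is f32[2,4] (8 elements), so reduction_count = (20 - 1) * 8 = 152 applications of the reduction computation; at 2 flops per application (two Mins), that is 304 flops, matching the test's expected 2 * 4 * 2 * (4 * 5 - 1).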

View File

@@ -449,6 +449,41 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) {
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 2 * 4);
}
TEST_F(HloCostAnalysisTest, ReduceWindowVariadic) {
XlaBuilder builder("reduce_window_variadic");
auto elem_shape = ShapeUtil::MakeShape(F32, {});
auto p2 = Parameter(&builder, 0, elem_shape, "x0");
auto p3 = Parameter(&builder, 1, elem_shape, "x1");
auto p4 = Parameter(&builder, 2, elem_shape, "y0");
auto p5 = Parameter(&builder, 3, elem_shape, "y1");
absl::InlinedVector<XlaOp, 2> compute_vec = {Min(p2, p4), Min(p3, p5)};
Tuple(&builder, compute_vec);
TF_ASSERT_OK_AND_ASSIGN(auto compute_tuple, builder.Build());
auto input1 =
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input1");
auto input2 =
Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {10, 20}), "input2");
auto init = ConstantR0<float>(&builder, 0);
ReduceWindow({input1, input2}, {init, init}, compute_tuple, {4, 5}, {4, 5},
Padding::kValid);
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
HloCostAnalysis analysis(ShapeSize);
ASSERT_IS_OK(
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
// Each of the [2x4] output elements is generated by reducing [4x5] elements.
EXPECT_EQ(analysis.flop_count(), 2 * 4 * 2 * (4 * 5 - 1));
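// bytes_accessed: two f32[10,20] inputs (10 * 20 * 2 floats), two scalar
// inits (2 floats), plus the tuple root, which the test's ShapeSize counts
// as two 8-byte pointers (4 floats), hence the 2 * 3 term.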
EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (10 * 20 * 2 + 2 * 3));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 10 * 20);
EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 4);
}
TEST_F(HloCostAnalysisTest, SelectAndScatter) {
XlaBuilder builder("select_and_scatter");
auto operand =

View File

@@ -2285,9 +2285,13 @@ std::unique_ptr<HloInstruction>
HloReduceWindowInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
CHECK_EQ(new_operands.size() % 2, 0);
int64 num_operands = new_operands.size() / 2;
return absl::make_unique<HloReduceWindowInstruction>(
shape, new_operands[0], new_operands[1], window(), to_apply());
shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
absl::MakeSpan(new_operands)
.subspan(num_operands, new_operands.size() / 2),
window(), to_apply());
}
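The clone path above relies on the flattened operand convention [inputs..., init_values...]. A standalone sketch of that split, with hypothetical values rather than real HLO operands:

#include <vector>
#include "absl/types/span.h"

std::vector<int> flat = {7, 8, 70, 80};         // {input0, input1, init0, init1}
absl::Span<int> all = absl::MakeSpan(flat);
auto inputs = all.subspan(0, all.size() / 2);   // {7, 8}
auto inits = all.subspan(all.size() / 2);       // {70, 80}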
HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(

View File

@@ -1366,16 +1366,25 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (!window) {
window.emplace();
}
if (operands.size() % 2) {
auto loc = lexer_.GetLoc();
return Error(loc, StrCat("expects an even number of operands, but has ",
operands.size(), " operands"));
}
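// By this convention the first half of the flattened operand list are the
// reduced inputs and the second half their matching init values.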
instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
*reduce_computation));
shape, /*operands=*/
absl::Span<HloInstruction* const>(operands).subspan(
0, operands.size() / 2),
/*init_values=*/
absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
2),
*window, *reduce_computation));
break;
}
case HloOpcode::kConvolution: {
@@ -3585,7 +3594,7 @@ bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
}
// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
// The string looks like "dim_labels=0bf_0io->0bf".
bool HloParserImpl::ParseConvolutionDimensionNumbers(
ConvolutionDimensionNumbers* dnums) {
if (lexer_.GetKind() != TokKind::kDimLabels) {

View File

@@ -437,6 +437,29 @@ ENTRY %R4UnitWindowScalar () -> f32[] {
ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
}
)"
},
// variadic reduce window
{
"ReduceWindowVariadic",
R"(HloModule reduce_window_variadic
%add_F32.v3 (lhs1: f32[], lhs2: f32[], rhs1: f32[], rhs2: f32[]) -> (f32[], f32[]) {
%lhs1 = f32[] parameter(0)
%rhs1 = f32[] parameter(2)
%add1 = f32[] add(f32[] %lhs1, f32[] %rhs1)
%lhs2 = f32[] parameter(1)
%rhs2 = f32[] parameter(3)
%add2 = f32[] add(f32[] %lhs2, f32[] %rhs2)
ROOT %tuple1 = (f32[], f32[]) tuple(f32[] %add1, f32[] %add2)
}
ENTRY %R4UnitWindowScalar () -> (f32[], f32[]) {
%constant = f32[] constant(42)
%constant.1 = f32[] constant(1)
ROOT %reduce-window = (f32[], f32[]) reduce-window(f32[] %constant, f32[] %constant, f32[] %constant.1, f32[] %constant.1), to_apply=%add_F32.v3
}
)"
},
// convolution

View File

@@ -1086,6 +1086,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReducePrecision:
case HloOpcode::kReduceWindow:
case HloOpcode::kTupleSelect:
case HloOpcode::kSend:
case HloOpcode::kSendDone:

View File

@@ -666,12 +666,13 @@ bool InferShardingFromOperands(HloInstruction* instruction,
return false;
}
// Propagate manual sharding. Avoid tuple shaped HLOs that group independent operands
// together. Reduce and Sort can be tuples but the elements are correlated, so
// we propagate manual sharding through them.
// together. Reduce, ReduceWindow, and Sort can be tuples but the elements
// are correlated, so we propagate manual sharding through them.
if (!instruction->has_sharding() &&
(instruction->shape().IsArray() ||
instruction->opcode() == HloOpcode::kReduce ||
instruction->opcode() == HloOpcode::kSort) &&
instruction->opcode() == HloOpcode::kSort ||
instruction->opcode() == HloOpcode::kReduceWindow) &&
absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
return op->has_sharding() && op->sharding().IsManual();
})) {
@@ -868,6 +869,10 @@ bool InferShardingFromOperands(HloInstruction* instruction,
may_combine_partial_sharding);
}
case HloOpcode::kReduceWindow: {
if (instruction->shape().IsTuple()) {
// TODO(b/73062247) Variadic reduce window is not yet supported here.
return false;
}
const HloInstruction* lhs = instruction->operand(0);
if (!IsSpatiallyPartitioned(lhs)) {
return false;
@@ -1292,6 +1297,10 @@ absl::optional<HloSharding> GetShardingFromUser(
return user.sharding();
}
case HloOpcode::kReduceWindow: {
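// For a tuple-shaped (variadic) reduce-window user, the sharding an operand
// sees is the sub-sharding of the tuple element at that operand's index.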
if (user.shape().IsTuple()) {
return user.sharding().GetSubSharding(
user.shape(), {user.operand_index(&instruction)});
}
if (&instruction != user.operand(0)) {
return absl::nullopt;
}

View File

@@ -922,6 +922,11 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
absl::c_linear_search(reduce_dims, space_dim);
}
if (consumer->opcode() == HloOpcode::kReduceWindow &&
consumer->shape().IsTuple()) {
// TODO(b/73062247) Variadic reduce window is not yet supported.
return false;
}
if (consumer->opcode() == HloOpcode::kReduceWindow ||
consumer->opcode() == HloOpcode::kSelectAndScatter) {
auto first_operand = consumer->mutable_operand(0);

View File

@@ -3415,6 +3415,10 @@ Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
}
Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
// TODO(b/73062247) Variadic reduce window not yet supported in partitioner.
if (hlo->shape().IsTuple()) {
return DefaultAction(hlo);
}
auto& operand = GetPartitionedHlo(hlo->operand(0));
if (hlo->sharding().IsTileMaximal()) {
return DefaultAction(hlo);

View File

@@ -1704,5 +1704,114 @@ ENTRY R4OnlyDilation {
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
XLA_TEST_F(HloTestBase,
DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport))) {
const char* const hlo_string = R"(
HloModule module
sum {
a0 = f32[] parameter(0)
a1 = f32[] parameter(1)
b0 = f32[] parameter(2)
b1 = f32[] parameter(3)
add0 = f32[] add(a0, b0)
add1 = f32[] add(a1, b1)
ROOT sum2 = (f32[], f32[]) tuple(add0, add1)
}
ENTRY entry {
constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
constant.1 = f32[] constant(0)
constant.2 = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
constant.3 = f32[] constant(0)
reduce-window = (f32[2,2]{1,0}, f32[2,2]{1,0})
reduce-window(constant, constant.2, constant.1, constant.3),
window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
ROOT copy = (f32[2,2]{1,0}, f32[2,2]{1,0}) copy(reduce-window)
})";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
}
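Worked by hand under standard reduce-window semantics (a sanity check, not an assertion in the test): with pad 2_2 on the first dimension, the two size-5 stride-3 windows cover data rows {0,1,2} and {1,2,3}, so each element of the result tuple should be {{4,6},{6,6}}.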
XLA_TEST_F(HloTestBase,
DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport2))) {
const char* const hlo_string = R"(
HloModule module
sum {
a0 = f32[] parameter(0)
a1 = s32[] parameter(1)
b0 = f32[] parameter(2)
b1 = s32[] parameter(3)
add0 = f32[] add(a0, b0)
add1 = s32[] add(a1, b1)
ROOT sum2 = (f32[], s32[]) tuple(add0, add1)
}
ENTRY entry {
constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
constant.1 = f32[] constant(0)
constant.2 = s32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
constant.3 = s32[] constant(0)
ROOT reduce-window = (f32[2,2]{1,0}, s32[2,2]{1,0})
reduce-window(constant, constant.2, constant.1, constant.3),
window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
})";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
}
XLA_TEST_F(HloTestBase,
DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport3))) {
const char* const hlo_string = R"(
HloModule module
sum {
a0 = f32[] parameter(0)
a1 = bf16[] parameter(1)
b0 = f32[] parameter(2)
b1 = bf16[] parameter(3)
add0 = f32[] add(a0, b0)
add1 = bf16[] add(a1, b1)
ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
}
ENTRY entry {
constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
constant.1 = f32[] constant(0)
constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
constant.3 = bf16[] constant(0)
ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
reduce-window(constant, constant.2, constant.1, constant.3),
window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
})";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
}
XLA_TEST_F(HloTestBase,
DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport4))) {
const char* const hlo_string = R"(
HloModule module
sum {
a0 = f32[] parameter(0)
a1 = bf16[] parameter(1)
b0 = f32[] parameter(2)
b1 = bf16[] parameter(3)
add0 = f32[] add(a0, b0)
add1 = bf16[] multiply(a1, b1)
ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
}
ENTRY entry {
constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
constant.1 = f32[] constant(0)
constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
constant.3 = bf16[] constant(1)
ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
reduce-window(constant, constant.2, constant.1, constant.3),
window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
})";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
}
} // namespace
} // namespace xla