[XLA] Add implementation support for variadic reduce window, including HLO, cost analysis, etc.
PiperOrigin-RevId: 347742450
Change-Id: I06b02b9407013322f3e72865fec487385a47abec
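As context for the hunks below, here is a minimal sketch of how a variadic reduce-window is built from the XLA client API, following the usage in the new ReduceWindowVariadic cost-analysis test later in this diff (names are illustrative and not part of the commit): N inputs are reduced in lockstep with N init values by a computation that maps 2N scalars to an N-tuple.

  // Sketch only, mirroring the cost-analysis test below (xla/client API).
  // Reduction computation: (a0, a1, b0, b1) -> (a0 + b0, a1 + b1).
  XlaBuilder sub("sum2");
  auto scalar = ShapeUtil::MakeShape(F32, {});
  auto a0 = Parameter(&sub, 0, scalar, "a0");  // accumulators
  auto a1 = Parameter(&sub, 1, scalar, "a1");
  auto b0 = Parameter(&sub, 2, scalar, "b0");  // window elements
  auto b1 = Parameter(&sub, 3, scalar, "b1");
  Tuple(&sub, {Add(a0, b0), Add(a1, b1)});     // 2N scalars in, N-tuple out
  XlaComputation sum2 = sub.Build().ValueOrDie();

  XlaBuilder b("variadic_reduce_window");
  auto in0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 20}), "in0");
  auto in1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {10, 20}), "in1");
  auto zero = ConstantR0<float>(&b, 0);
  // N inputs, N init values, one reduction computation; the result is an
  // N-tuple of arrays, here (f32[2,4], f32[2,4]) for a 4x5 valid window.
  ReduceWindow({in0, in1}, {zero, zero}, sum2,
               /*window_dimensions=*/{4, 5}, /*window_strides=*/{4, 5},
               Padding::kValid);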
@@ -4696,6 +4696,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
 Status AlgebraicSimplifierVisitor::HandleReduceWindow(
     HloInstruction* reduce_window) {
+  // TODO(b/73062247) Variadic reduce window is not yet supported in simplifier.
+  if (reduce_window->shape().IsTuple()) {
+    return Status::OK();
+  }
   if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
     return ReplaceWithNewInstruction(
         reduce_window,
@@ -93,6 +93,9 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
       return inst->mutable_operand(init_value_index);
     }
     case HloOpcode::kReduceWindow: {
+      if (inst->shape().IsTuple()) {
+        return Unimplemented("Variadic reduce window not yet supported.");
+      }
       // Because of the way we do reduce, we already require the `init`
       // operand of hlo reduce instruction to be identity value. Here we reuse
       // the operand.
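(For intuition, an identity value is one that leaves the reduction unchanged: 0 for add, 1 for multiply, +inf for min, -inf for max; that is what makes reusing reduce's init operand safe here.)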
@@ -1015,6 +1018,10 @@ StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
 StatusOr<bool> RewriteDynamicReduceWindowSamePadding(
     HloInstruction* hlo,
     DynamicDimensionInference* dynamic_dimension_inference) {
+  if (hlo->shape().IsTuple()) {
+    // TODO(b/73062247) variadic reduce window is not yet supported here.
+    return Unimplemented("Variadic reduce window not yet supported.");
+  }
   HloInstruction* input = hlo->mutable_operand(0);
   HloInstruction* init = hlo->mutable_operand(1);
   HloComputation* comp = hlo->parent();
@@ -396,8 +396,11 @@ Status HloCostAnalysis::HandleReduceWindow(
   for (const auto& dimension : window.dimensions()) {
     window_element_count *= dimension.size();
   }
+
   const int64 output_element_count =
-      ShapeUtil::ElementsIn(reduce_window->shape());
+      ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
+                                ? reduce_window->shape()
+                                : reduce_window->shape().tuple_shapes(0));
   const int64 reduction_count =
       (window_element_count - 1) * output_element_count;
   for (const auto& property : sub_properties) {
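A quick check of the updated formula against the ReduceWindowVariadic test in the next hunk, assuming the two-op reduction computation costs two flops per application: window_element_count = 4 * 5 = 20; the tuple-shaped output's leading element is f32[2,4], so output_element_count = 8; reduction_count = (20 - 1) * 8 = 152; flop_count = 2 * 152 = 304, which matches the test's expected 2 * 4 * 2 * (4 * 5 - 1).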
@@ -449,6 +449,41 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) {
   EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 2 * 4);
 }
 
+TEST_F(HloCostAnalysisTest, ReduceWindowVariadic) {
+  XlaBuilder builder("reduce_window_variadic");
+  auto elem_shape = ShapeUtil::MakeShape(F32, {});
+  auto p2 = Parameter(&builder, 0, elem_shape, "x0");
+  auto p3 = Parameter(&builder, 1, elem_shape, "x1");
+  auto p4 = Parameter(&builder, 2, elem_shape, "y0");
+  auto p5 = Parameter(&builder, 3, elem_shape, "y1");
+  absl::InlinedVector<XlaOp, 2> compute_vec = {Min(p2, p4), Min(p3, p5)};
+  Tuple(&builder, compute_vec);
+  TF_ASSERT_OK_AND_ASSIGN(auto compute_tuple, builder.Build());
+  auto input1 =
+      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input1");
+  auto input2 =
+      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {10, 20}), "input2");
+  auto init = ConstantR0<float>(&builder, 0);
+  ReduceWindow({input1, input2}, {init, init}, compute_tuple, {4, 5}, {4, 5},
+               Padding::kValid);
+
+  // Run HLO cost analysis.
+  auto hlo_module = BuildHloGraph(&builder);
+  HloCostAnalysis analysis(ShapeSize);
+  ASSERT_IS_OK(
+      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+  // Each of the [2x4] output elements is generated from reducing [4x5] elements.
+  EXPECT_EQ(analysis.flop_count(), 2 * 4 * 2 * (4 * 5 - 1));
+
+  EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (10 * 20 * 2 + 2 * 3));
+
+  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
+  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 10 * 20);
+  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
+  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 4);
+}
+
 TEST_F(HloCostAnalysisTest, SelectAndScatter) {
   XlaBuilder builder("select_and_scatter");
   auto operand =
@@ -2285,9 +2285,13 @@ std::unique_ptr<HloInstruction>
 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* context) const {
-  CHECK_EQ(new_operands.size(), 2);
+  CHECK_EQ(new_operands.size() % 2, 0);
+  int64 num_operands = new_operands.size() / 2;
   return absl::make_unique<HloReduceWindowInstruction>(
-      shape, new_operands[0], new_operands[1], window(), to_apply());
+      shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
+      absl::MakeSpan(new_operands)
+          .subspan(num_operands, new_operands.size() / 2),
+      window(), to_apply());
 }
 
 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
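The clone path above encodes the operand convention for variadic reduce-window: the flat operand list is laid out as [inputs..., init_values...] and split down the middle. A self-contained sketch of that split, under that assumption (SplitReduceWindowOperands is a hypothetical helper, not part of the commit):

#include <cstdint>
#include <utility>

#include "absl/types/span.h"

class HloInstruction;  // pointer-only use, so a forward declaration suffices

// Hypothetical helper mirroring the subspan arithmetic in
// CloneWithNewOperandsImpl: the first half of the flat list holds the inputs,
// the second half the matching init values.
inline std::pair<absl::Span<HloInstruction* const>,
                 absl::Span<HloInstruction* const>>
SplitReduceWindowOperands(absl::Span<HloInstruction* const> flat) {
  const int64_t n = flat.size() / 2;  // size is CHECK'd to be even upstream
  return {flat.subspan(0, n), flat.subspan(n)};
}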
@@ -1366,16 +1366,25 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                            &reduce_computation};
-      if (!ParseOperands(&operands, /*expected_size=*/2) ||
-          !ParseAttributes(attrs)) {
+      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
       if (!window) {
         window.emplace();
       }
+      if (operands.size() % 2) {
+        auto loc = lexer_.GetLoc();
+        return Error(loc, StrCat("expects an even number of operands, but has ",
+                                 operands.size(), " operands"));
+      }
       instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
-          shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
-          *reduce_computation));
+          shape, /*operands=*/
+          absl::Span<HloInstruction* const>(operands).subspan(
+              0, operands.size() / 2),
+          /*init_values=*/
+          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
+                                                              2),
+          *window, *reduce_computation));
       break;
     }
     case HloOpcode::kConvolution: {
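With this change the text parser accepts reduce-window with any even number of operands, the first half being inputs and the second half init values (e.g. reduce-window(%a, %b, %init_a, %init_b)); an odd count now fails with "expects an even number of operands". The ReduceWindowVariadic case added to the parser tests below round-trips exactly such a module.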
@@ -3585,7 +3594,7 @@ bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
 }
 
 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
 // The string looks like "dim_labels=0bf_0io->0bf".
 bool HloParserImpl::ParseConvolutionDimensionNumbers(
     ConvolutionDimensionNumbers* dnums) {
   if (lexer_.GetKind() != TokKind::kDimLabels) {
@@ -437,6 +437,29 @@ ENTRY %R4UnitWindowScalar () -> f32[] {
   ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
 }
 
 )"
 },
+// reduce window on scalar
+{
+"ReduceWindowVariadic",
+R"(HloModule reduce_window_variadic
+
+%add_F32.v3 (lhs1: f32[], lhs2: f32[], rhs1: f32[], rhs2: f32[]) -> (f32[], f32[]) {
+  %lhs1 = f32[] parameter(0)
+  %rhs1 = f32[] parameter(2)
+  %add1 = f32[] add(f32[] %lhs1, f32[] %rhs1)
+  %lhs2 = f32[] parameter(1)
+  %rhs2 = f32[] parameter(3)
+  %add2 = f32[] add(f32[] %lhs2, f32[] %rhs2)
+  ROOT %tuple1 = (f32[], f32[]) tuple(f32[] %add1, f32[] %add2)
+}
+
+ENTRY %R4UnitWindowScalar () -> (f32[], f32[]) {
+  %constant = f32[] constant(42)
+  %constant.1 = f32[] constant(1)
+  ROOT %reduce-window = (f32[], f32[]) reduce-window(f32[] %constant, f32[] %constant, f32[] %constant.1, f32[] %constant.1), to_apply=%add_F32.v3
+}
+
+)"
+},
 // convolution
@@ -1086,6 +1086,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
     case HloOpcode::kRecv:
     case HloOpcode::kRecvDone:
     case HloOpcode::kReducePrecision:
+    case HloOpcode::kReduceWindow:
     case HloOpcode::kTupleSelect:
     case HloOpcode::kSend:
     case HloOpcode::kSendDone:
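This allowlists reduce-window for the verifier's mixed-precision operand check, which the f32/s32 and f32/bf16 execution tests at the end of this diff appear to rely on.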
@@ -666,12 +666,13 @@ bool InferShardingFromOperands(HloInstruction* instruction,
     return false;
   }
   // Propagate manual sharding. Avoid tuple shaped HLOs that group independent
-  // together. Reduce and Sort can be tuples but the elements are correlated, so
-  // we propagate manual sharding through them.
+  // together. Reduce, ReduceWindow, and Sort can be tuples but the elements
+  // are correlated, so we propagate manual sharding through them.
   if (!instruction->has_sharding() &&
       (instruction->shape().IsArray() ||
        instruction->opcode() == HloOpcode::kReduce ||
-       instruction->opcode() == HloOpcode::kSort) &&
+       instruction->opcode() == HloOpcode::kSort ||
+       instruction->opcode() == HloOpcode::kReduceWindow) &&
       absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
        return op->has_sharding() && op->sharding().IsManual();
      })) {
@@ -868,6 +869,10 @@ bool InferShardingFromOperands(HloInstruction* instruction,
                                          may_combine_partial_sharding);
     }
     case HloOpcode::kReduceWindow: {
+      if (instruction->shape().IsTuple()) {
+        // TODO(b/73062247) variadic reduce window is not yet supported here.
+        return false;
+      }
       const HloInstruction* lhs = instruction->operand(0);
       if (!IsSpatiallyPartitioned(lhs)) {
         return false;
@@ -1292,6 +1297,10 @@ absl::optional<HloSharding> GetShardingFromUser(
       return user.sharding();
     }
     case HloOpcode::kReduceWindow: {
+      if (user.shape().IsTuple()) {
+        return user.sharding().GetSubSharding(
+            user.shape(), {user.operand_index(&instruction)});
+      }
       if (&instruction != user.operand(0)) {
         return absl::nullopt;
       }
@@ -922,6 +922,11 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
            absl::c_linear_search(reduce_dims, space_dim);
   }
 
+  if (consumer->opcode() == HloOpcode::kReduceWindow &&
+      consumer->shape().IsTuple()) {
+    // TODO(b/73062247) variadic reduce window is not yet supported.
+    return false;
+  }
   if (consumer->opcode() == HloOpcode::kReduceWindow ||
       consumer->opcode() == HloOpcode::kSelectAndScatter) {
     auto first_operand = consumer->mutable_operand(0);
@@ -3415,6 +3415,10 @@ Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
 }
 
 Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
+  // TODO(b/73062247) Variadic reduce window not yet supported in partitioner.
+  if (hlo->shape().IsTuple()) {
+    return DefaultAction(hlo);
+  }
   auto& operand = GetPartitionedHlo(hlo->operand(0));
   if (hlo->sharding().IsTileMaximal()) {
     return DefaultAction(hlo);
@@ -1704,5 +1704,114 @@ ENTRY R4OnlyDilation {
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
 }
 
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = f32[] parameter(1)
+  b0 = f32[] parameter(2)
+  b1 = f32[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = f32[] add(a1, b1)
+  ROOT sum2 = (f32[], f32[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = f32[] constant(0)
+  reduce-window = (f32[2,2]{1,0}, f32[2,2]{1,0})
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+  ROOT copy = (f32[2,2]{1,0}, f32[2,2]{1,0}) copy(reduce-window)
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport2))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = s32[] parameter(1)
+  b0 = f32[] parameter(2)
+  b1 = s32[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = s32[] add(a1, b1)
+  ROOT sum2 = (f32[], s32[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = s32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = s32[] constant(0)
+  ROOT reduce-window = (f32[2,2]{1,0}, s32[2,2]{1,0})
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport3))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = bf16[] parameter(1)
+  b0 = f32[] parameter(2)
+  b1 = bf16[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = bf16[] add(a1, b1)
+  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = bf16[] constant(0)
+  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport4))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = bf16[] parameter(1)
+  b0 = f32[] parameter(2)
+  b1 = bf16[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = bf16[] multiply(a1, b1)
+  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = bf16[] constant(1)
+  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
 }  // namespace
 }  // namespace xla
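Taken together, the four execution tests exercise same-typed variadic operands (f32/f32), mixed element types across streams (f32 with s32, then f32 with bf16), and a reduction that applies a different operation per stream (add for the f32 pair, multiply for the bf16 pair).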