Merge pull request #46614 from trentlo:h-fusion-sharing-opnd-with-user-upstreaming
PiperOrigin-RevId: 357261369 Change-Id: Ifeb00b2f565f0946c1a347a5e5d5e8fa93e35aa9
This commit is contained in:
commit
ab9516dd9a
@ -2877,6 +2877,56 @@ ENTRY main {
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
// A slice-input fusion (the shape produced by horizontal loop fusion) whose
// outputs are aliased to entry parameters must not force copy insertion to add
// any copies.
TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
  const string module_text = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      add0 = f32[10, 20] add(p0, p1)
      sub0 = f32[10, 10] subtract(p2, p3)
      reshape0 = f32[200] reshape(add0)
      reshape1 = f32[100] reshape(sub0)
      concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
      slice0 = f32[200] slice(concat0), slice={[0:200]}
      slice1 = f32[100] slice(concat0), slice={[200:300]}
      ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
    }

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
      gte0 = f32[200] get-tuple-element(fusion), index=0
      gte1 = f32[100] get-tuple-element(fusion), index=1
      bitcast0 = f32[10,20] bitcast(gte0)
      bitcast1 = f32[10,10] bitcast(gte1)
      ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_text));

  // Alias output {0} with parameter 0 and output {1} with parameter 3.
  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{0},
                                       /*param_number=*/0,
                                       /*param_index=*/{}));
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{1},
                                       /*param_number=*/3,
                                       /*param_index=*/{}));

  InsertCopies(module.get());

  // The aliased buffers can be shared with the fusion outputs, so the pass
  // is expected to leave the module free of copies.
  EXPECT_EQ(CountCopies(*module), 0);
}
|
||||
|
||||
TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule TestModule
|
||||
|
@ -174,14 +174,6 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// We can emit DUS in-place, horizontally fusing it makes the emitter no
|
||||
// longer recognize that it can be done in-place. This creates much slower
|
||||
// code. This restriction could be lifted if buffer assignment would recognize
|
||||
// that the DUS can be done in-place even inside of a horizontal fusion.
|
||||
if (root->opcode() == HloOpcode::kDynamicUpdateSlice) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -203,6 +195,19 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns whether any operand of `instr` is a parameter instruction that
|
||||
// is shared with `fusion_instrs`.
|
||||
bool AnyOpndIsParamSharedAmongFusions(
|
||||
const HloInstruction* instr,
|
||||
const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
|
||||
return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
|
||||
return opnd->opcode() == HloOpcode::kParameter &&
|
||||
absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
|
||||
return user != instr && fusion_instrs.contains(user);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
|
||||
HloInstruction* consumer) {
|
||||
// First, find out all fusion instructions. We will filter out
|
||||
@ -230,6 +235,14 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
|
||||
} else if (!HasOnlyRowMajorLayout(*instr)) {
|
||||
VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
|
||||
continue;
|
||||
} else if (AnyOpndIsParamSharedAmongFusions(instr, fusion_instrs)) {
|
||||
// Don't fuse fusions whose operands are parameter instructions that are
|
||||
// shared among fusions because we cannot i/o alias the produced
|
||||
// horizontal fusion due to the concat insertion.
|
||||
// Every piece of the message must be streamed with `<<`. A comma before
// `instr->ToString()` would terminate the VLOG expression, so the
// instruction text would be computed but never logged.
VLOG(2) << "Reject the fusion instr because it shares parameter with"
        << " other fusion candidates, instr: " << instr->ToString();
|
||||
continue;
|
||||
} else {
|
||||
VLOG(2) << "Find a fusion candidate " << instr->ToString();
|
||||
fusion_instrs_.push_back(instr);
|
||||
|
@ -364,33 +364,33 @@ TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
|
||||
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
|
||||
}
|
||||
|
||||
TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
|
||||
TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule NegativeTestForDynamicUpdateSlice
|
||||
|
||||
fusion.1 {
|
||||
p.0 = f16[5,9,10]{2,1,0} parameter(0)
|
||||
p.1 = s32[1]{0} parameter(1)
|
||||
p.1 = s32[] parameter(1)
|
||||
p.2 = f16[1,9,10]{2,1,0} parameter(2)
|
||||
c.0 = s32[] constant(0)
|
||||
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
|
||||
ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
|
||||
ROOT %dynamic-update-slice =
|
||||
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
|
||||
}
|
||||
|
||||
fusion.2 {
|
||||
p.0 = f16[5,9,10]{2,1,0} parameter(0)
|
||||
p.1 = s32[1]{0} parameter(1)
|
||||
p.1 = s32[] parameter(1)
|
||||
p.2 = f16[1,9,10]{2,1,0} parameter(2)
|
||||
c.0 = s32[] constant(0)
|
||||
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
|
||||
ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
|
||||
ROOT %dynamic-update-slice =
|
||||
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p.00 = f16[5,9,10]{2,1,0} parameter(0)
|
||||
p.01 = f16[5,9,10]{2,1,0} parameter(1)
|
||||
p.10 = s32[1]{0} parameter(2)
|
||||
p.11 = s32[1]{0} parameter(3)
|
||||
p.10 = s32[] parameter(2)
|
||||
p.11 = s32[] parameter(3)
|
||||
p.20 = f16[1,9,10]{2,1,0} parameter(4)
|
||||
p.21 = f16[1,9,10]{2,1,0} parameter(5)
|
||||
|
||||
@ -400,6 +400,46 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
|
||||
})")
|
||||
.ValueOrDie();
|
||||
|
||||
EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
|
||||
EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());
|
||||
|
||||
VLOG(2) << "Dump after horizontal fusion:";
|
||||
VLOG(2) << module->ToString();
|
||||
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
|
||||
}
|
||||
|
||||
// Two loop fusions that read the same entry parameter must be left alone by
// horizontal loop fusion, i.e. the pass reports that nothing changed.
TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
  const char* const kHloText = R"(
 HloModule BasicTest

 fused_computation.1 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
 }

 fused_computation.2 {
   arg.1 = f16[123]{0} parameter(0)
   arg.2 = f16[123]{0} parameter(1)
   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
 }

 ENTRY entry_computation {
   arg.1 = f16[123]{0} parameter(0)
   // arg.2 is shared by fusion.1 and fusion.2
   arg.2 = f16[123]{0} parameter(1)
   arg.3 = f16[123]{0} parameter(2)
   fusion.1 = f16[123]{0}
       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
   fusion.2 = f16[123]{0}
       fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2)
 }
)";
  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  const bool module_changed =
      GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(module_changed);
}
|
||||
|
||||
|
@ -120,6 +120,175 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
|
||||
return instr->opcode() == HloOpcode::kSlice &&
|
||||
1 == instr->slice_starts().size() &&
|
||||
1 == instr->slice_limits().size() &&
|
||||
1 == instr->slice_strides().size() &&
|
||||
1 == instr->slice_strides().at(0);
|
||||
}
|
||||
|
||||
bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
|
||||
if (!unnested_hlo.IsInputFusion()) {
|
||||
return false;
|
||||
}
|
||||
const HloInstruction* root = unnested_hlo.fused_expression_root();
|
||||
if (root->opcode() != HloOpcode::kTuple) {
|
||||
return false;
|
||||
}
|
||||
return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
|
||||
return Is1dSliceWithoutStrides(instr);
|
||||
});
|
||||
}
|
||||
|
||||
struct ConcatUsageInfo {
|
||||
// Pointer to a previously seen concat. nullptr if no previously seen concat.
|
||||
const HloInstruction* prev_concat;
|
||||
// The opnd id of the seen concat.
|
||||
int64 concat_opnd_idx;
|
||||
// The slice that recovers the opnd in the concat outputs.
|
||||
const HloInstruction* slice_to_recover_opnd;
|
||||
};
|
||||
|
||||
// Returns an optional concat usage info to denote whether the concat is used in
|
||||
// an elementwise manner. A concat followed by slices is considered effectively
|
||||
// elementwise if the slices combinedly is a reverse function of the concat.
|
||||
absl::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
|
||||
const HloInstruction& concat, const HloInstruction& operand,
|
||||
const ConcatUsageInfo& info) {
|
||||
// First, check if this concat is in the below pattern. Also, we check
|
||||
// that the slices combinedly are in effect a reverse function of the concat.
|
||||
//
|
||||
// Concat
|
||||
// | |
|
||||
// v v
|
||||
// Slice Slice
|
||||
//
|
||||
std::vector<HloInstruction*> users = concat.users();
|
||||
if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
|
||||
// Limit our supported cases to 1 dimensional slices.
|
||||
return absl::optional<ConcatUsageInfo>();
|
||||
}
|
||||
// Verify that each operand to the concat is reversed by a slice.
|
||||
if (users.size() != concat.operand_count() ||
|
||||
concat.operand_count() != concat.unique_operands().size()) {
|
||||
return absl::optional<ConcatUsageInfo>();
|
||||
}
|
||||
absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
|
||||
return a->slice_starts().at(0) < b->slice_starts().at(0);
|
||||
});
|
||||
int64 prev_limit = 0;
|
||||
for (int64 i = 0; i < users.size(); ++i) {
|
||||
const HloInstruction* u = users[i];
|
||||
int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
|
||||
if (u->slice_starts().at(0) != prev_limit ||
|
||||
slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
|
||||
return absl::optional<ConcatUsageInfo>();
|
||||
}
|
||||
prev_limit = u->slice_limits().at(0);
|
||||
}
|
||||
|
||||
// If we have seen other concats, make sure they are identical. Multiple
|
||||
// concats exist because horizontal fusion inserts one concat for each output
|
||||
// of the fusion candidates. Check that all concats and operand ids are the
|
||||
// same to know that the "transitive use closure" will be computed in the same
|
||||
// iteration space.
|
||||
int64 operand_idx = concat.operand_index(&operand);
|
||||
if (info.prev_concat != nullptr) {
|
||||
bool is_concat_identical = info.prev_concat->Identical(
|
||||
concat,
|
||||
/*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
|
||||
// Operands don't need to be the same.
|
||||
return true;
|
||||
});
|
||||
if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
|
||||
return absl::optional<ConcatUsageInfo>();
|
||||
}
|
||||
}
|
||||
|
||||
const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
|
||||
return absl::optional<ConcatUsageInfo>(
|
||||
ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
|
||||
}
|
||||
|
||||
// Returns whether we can prove the transitive uses of `param` are in effect
|
||||
// elementwise. In other words, we prove that the "transitive use closure" will
|
||||
// all be computed in the same iteration space without any reorder of elements.
|
||||
// In addition, we check that the "transitive use closure" includes the output
|
||||
// in the `root_tuple`.
|
||||
// Theoretically, We can prove more patterns but our primary use case is
|
||||
// SliceInputFusion.
|
||||
bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
|
||||
const HloInstruction* root_tuple,
|
||||
const ShapeIndex& out_shape_idx) {
|
||||
CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
|
||||
CHECK_EQ(out_shape_idx.size(), 1);
|
||||
absl::flat_hash_set<const HloInstruction*> visited;
|
||||
absl::InlinedVector<const HloInstruction*, 4> stack;
|
||||
stack.push_back(param);
|
||||
ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
|
||||
bool is_output_reachable = false;
|
||||
while (!stack.empty()) {
|
||||
const HloInstruction* current = stack.back();
|
||||
stack.pop_back();
|
||||
visited.insert(current);
|
||||
for (const HloInstruction* user : current->users()) {
|
||||
VLOG(3) << "Visiting: " << user->ToString();
|
||||
switch (user->opcode()) {
|
||||
case HloOpcode::kTuple:
|
||||
if (user == root_tuple &&
|
||||
current == root_tuple->operand(out_shape_idx.back())) {
|
||||
// We need to know if the output is reachable by the `param` to make
|
||||
// sure that they will be computed in the same iteration space.
|
||||
is_output_reachable = true;
|
||||
}
|
||||
break;
|
||||
case HloOpcode::kReshape:
|
||||
if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case HloOpcode::kConcatenate: {
|
||||
absl::optional<ConcatUsageInfo> optional_concat_info =
|
||||
ConcatIsEffectivelyElementwise(*user, *current,
|
||||
concat_usage_info);
|
||||
if (!optional_concat_info) {
|
||||
return false;
|
||||
}
|
||||
concat_usage_info = *optional_concat_info;
|
||||
// Early continue as we only want to traverse through the slice that
|
||||
// recovers the operand. It is guaranteed that the operand to the
|
||||
// concat and the slice have the same iteration space. Insert the
|
||||
// slice instead of the concat.
|
||||
CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
|
||||
stack.push_back(concat_usage_info.slice_to_recover_opnd);
|
||||
continue;
|
||||
}
|
||||
default:
|
||||
for (const int64 use_index : user->OperandIndices(current)) {
|
||||
if (!user->IsElementwiseOnOperand(use_index)) {
|
||||
// Found a user that is non-elementwise on the current
|
||||
// instruction.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!LayoutUtil::Equal(current->shape().layout(),
|
||||
user->shape().layout())) {
|
||||
// Make sure the layout is not changed by the elementwise op.
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
} // end of switch
|
||||
if (!visited.contains(user)) {
|
||||
stack.push_back(user);
|
||||
}
|
||||
}
|
||||
}
|
||||
return is_output_reachable;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
|
||||
const ShapeIndex& index) const {
|
||||
const HloValueSet& value_set = GetValueSet(instruction, index);
|
||||
@ -1266,10 +1435,22 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
if (operand->opcode() == HloOpcode::kConstant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const Shape& operand_subshape =
|
||||
ShapeUtil::GetSubshape(operand->shape(), operand_index);
|
||||
const Shape& user_subshape =
|
||||
ShapeUtil::GetSubshape(user->shape(), user_index);
|
||||
if (IsSliceInputFusion(*user)) {
|
||||
HloInstruction* fusion_param =
|
||||
user->fused_parameter(user->operand_index(operand));
|
||||
// We don't require the same dimensions but only the same number of elements
|
||||
// and type (to make sure the same buffer size).
|
||||
return ShapeUtil::ElementsIn(operand_subshape) ==
|
||||
ShapeUtil::ElementsIn(user_subshape) &&
|
||||
ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
|
||||
AreTransitiveUsesEffectivelyElementwise(
|
||||
fusion_param, user->fused_expression_root(), user_index);
|
||||
}
|
||||
|
||||
// Check that operand and user emit the same shape and layout.
|
||||
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
|
||||
|
@ -2795,5 +2795,150 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
|
||||
}
|
||||
|
||||
// In a slice-input fusion where every parameter reaches exactly one output
// through elementwise ops, bitcast reshapes and a concat/slice pair, each
// parameter may share its buffer with the output it feeds.
TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceWithElementwise) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      add0 = f32[10, 20] add(p0, p1)
      sub0 = f32[10, 10] subtract(p2, p3)
      reshape0 = f32[200] reshape(add0)
      reshape1 = f32[100] reshape(sub0)
      concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
      slice0 = f32[200] slice(concat0), slice={[0:200]}
      slice1 = f32[100] slice(concat0), slice={[200:300]}
      ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
    }

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      ROOT fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kHloText));

  HloComputation* entry = module_->entry_computation();
  HloInstruction* fusion = entry->root_instruction();
  HloInstruction* operand0 = entry->parameter_instruction(0);
  HloInstruction* operand1 = entry->parameter_instruction(1);
  HloInstruction* operand2 = entry->parameter_instruction(2);
  HloInstruction* operand3 = entry->parameter_instruction(3);

  RunAnalysis();

  // Asks whether `operand` may share its buffer with fusion output `output`.
  const auto can_share = [&](HloInstruction* operand, int64 output) {
    return dataflow_analysis_->CanShareOperandBufferWithUser(
        operand, /*operand_index=*/{}, fusion, /*user_index=*/{output});
  };

  EXPECT_TRUE(can_share(operand0, 0));
  EXPECT_TRUE(can_share(operand1, 0));
  EXPECT_TRUE(can_share(operand2, 1));
  EXPECT_TRUE(can_share(operand3, 1));
  // Tensors of different sizes cannot share buffer.
  EXPECT_FALSE(can_share(operand0, 1));
}
|
||||
|
||||
// A parameter that reaches the concat through more than one path must not
// share its buffer with any output, while a parameter with a single
// elementwise path may share with the output it reaches.
TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceNegativeTest) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      // p0 has multiple transitive uses fed to concat. So, p0 cannot share
      // buffer with outputs because the aliased output could be written before
      // all the uses of p0 are finished.
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      add0 = f32[100] add(p0, p1)
      concat0 = f32[200] concatenate(p0, add0), dimensions={0}
      slice0 = f32[100] slice(concat0), slice={[0:100]}
      slice1 = f32[100] slice(concat0), slice={[100:200]}
      ROOT tuple = (f32[100], f32[100]) tuple(slice0, slice1)
    }

    ENTRY test {
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      ROOT fusion = (f32[100], f32[100]) fusion(p0, p1),
                        kind=kInput, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kHloText));

  HloComputation* entry = module_->entry_computation();
  HloInstruction* fusion = entry->root_instruction();
  HloInstruction* operand0 = entry->parameter_instruction(0);
  HloInstruction* operand1 = entry->parameter_instruction(1);

  RunAnalysis();

  // Asks whether `operand` may share its buffer with fusion output `output`.
  const auto can_share = [&](HloInstruction* operand, int64 output) {
    return dataflow_analysis_->CanShareOperandBufferWithUser(
        operand, /*operand_index=*/{}, fusion, /*user_index=*/{output});
  };

  // p0 cannot share with either fusion{0} or fusion{1}.
  EXPECT_FALSE(can_share(operand0, 0));
  EXPECT_FALSE(can_share(operand0, 1));
  // p1 cannot share with fusion{0} because we're not sure about their
  // relationship.
  EXPECT_FALSE(can_share(operand1, 0));
  // p1 can share with fusion{1} because they will be executed in an
  // elementwise manner.
  EXPECT_TRUE(can_share(operand1, 1));
}
|
||||
|
||||
// Same idea as the single-concat negative test, but with two concatenates in
// the fusion: p0 feeds both concats directly and may share with nothing,
// while p1 may share only with the outputs recovered at its operand position.
TEST_F(CanShareOperandBufferWithUserTest, MultipleConcatenates) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      add0 = f32[100] add(p0, p1)
      sub0 = f32[100] subtract(p1, p1)
      concat0 = f32[200] concatenate(p0, add0), dimensions={0}
      slice0 = f32[100] slice(concat0), slice={[0:100]}
      slice1 = f32[100] slice(concat0), slice={[100:200]}
      concat1 = f32[200] concatenate(p0, sub0), dimensions={0}
      slice2 = f32[100] slice(concat1), slice={[0:100]}
      slice3 = f32[100] slice(concat1), slice={[100:200]}
      ROOT tuple = (f32[100], f32[100], f32[100], f32[100])
                       tuple(slice0, slice1, slice2, slice3)
    }

    ENTRY test {
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      ROOT fusion = (f32[100], f32[100], f32[100], f32[100])
          fusion(p0, p1), kind=kInput, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kHloText));

  HloComputation* entry = module_->entry_computation();
  HloInstruction* fusion = entry->root_instruction();
  HloInstruction* operand0 = entry->parameter_instruction(0);
  HloInstruction* operand1 = entry->parameter_instruction(1);

  RunAnalysis();

  // Asks whether `operand` may share its buffer with fusion output `output`.
  const auto can_share = [&](HloInstruction* operand, int64 output) {
    return dataflow_analysis_->CanShareOperandBufferWithUser(
        operand, /*operand_index=*/{}, fusion, /*user_index=*/{output});
  };

  // p0 cannot share.
  EXPECT_FALSE(can_share(operand0, 0));
  EXPECT_FALSE(can_share(operand0, 1));
  EXPECT_FALSE(can_share(operand0, 2));
  EXPECT_FALSE(can_share(operand0, 3));
  // p1 can share with either fusion{1} or fusion{3}.
  EXPECT_TRUE(can_share(operand1, 1));
  EXPECT_TRUE(can_share(operand1, 3));
  EXPECT_FALSE(can_share(operand1, 0));
  EXPECT_FALSE(can_share(operand1, 2));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user