Add liveness_util functions which use dataflow analysis. Also make the analysis argument (TuplePointsToAnalysis or HloDataflowAnalysis) non-optional as all callers were passing in the analysis.

PiperOrigin-RevId: 169200824
This commit is contained in:
Mark Heffernan 2017-09-18 23:34:22 -07:00 committed by TensorFlower Gardener
parent f08ec5722b
commit 23da21150d
13 changed files with 254 additions and 68 deletions

View File

@ -720,6 +720,7 @@ cc_library(
hdrs = ["liveness_util.h"],
deps = [
":hlo",
":hlo_dataflow_analysis",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@ -838,6 +839,7 @@ cc_library(
deps = [
":call_graph",
":hlo",
":hlo_dataflow_analysis",
":hlo_proto",
":hlo_value",
":liveness_util",
@ -1391,9 +1393,7 @@ cc_library(
deps = [
":call_graph",
":hlo",
":hlo_ordering",
":hlo_value",
":liveness_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
@ -1412,6 +1412,7 @@ cc_test(
":hlo_dataflow_analysis",
":hlo_graph_dumper",
":hlo_matchers",
":hlo_ordering",
":instruction_fusion",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@ -1470,6 +1471,7 @@ cc_test(
":hlo_alias_analysis",
":hlo_graph_dumper",
":hlo_matchers",
":hlo_ordering",
":instruction_fusion",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",

View File

@ -123,7 +123,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
if (b.instruction()->IsUserOf(alias.instruction()) &&
!CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
b.instruction(), b.index(),
&points_to_analysis())) {
points_to_analysis())) {
return false;
}
}

View File

@ -204,7 +204,7 @@ Status HeapSimulator::RunComputation(
buffer->instruction()->opcode() != HloOpcode::kCopy &&
CanShareOperandBufferWithUser(
operand_buffer->instruction(), operand_buffer->index(),
buffer->instruction(), buffer->index(), &points_to_analysis)) {
buffer->instruction(), buffer->index(), points_to_analysis)) {
ShareBuffer(buffer, operand_buffer, instruction);
shared = true;
break;

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@ -93,7 +94,8 @@ class HloAliasAnalysisTest : public HloTestBase {
for (const HloValue* value_a : buffer.values()) {
for (const HloValue* value_b : buffer.values()) {
if (*value_a != *value_b &&
ordering.MayInterfere(*value_a, *value_b)) {
ordering.MayInterfere(*value_a, *value_b,
analysis_->dataflow_analysis())) {
VLOG(1) << *value_a << " interferes with " << *value_b
<< " in buffer: " << buffer;
return true;

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -73,7 +74,7 @@ class HloDataflowAnalysisTest : public HloTestBase,
EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
analysis_->GetValueDefinedAt(b));
analysis_->GetValueDefinedAt(b), *analysis_);
}
std::unique_ptr<HloModule> module_;

View File

@ -123,8 +123,9 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
}
/* static */
bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use,
const HloValue& value) const {
bool HloOrdering::UseIsBeforeValueDefinition(
const HloUse& use, const HloValue& value,
const HloDataflowAnalysis& dataflow) const {
VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
<< ", value=" << value.ToShortString() << ")";
if (ExecutesBefore(use.instruction, value.defining_instruction())) {
@ -139,7 +140,7 @@ bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use,
CanShareOperandBufferWithUser(
use.instruction->mutable_operand(use.operand_number),
use.operand_index, value.defining_instruction(),
value.defining_index())) {
value.defining_index(), dataflow)) {
VLOG(4) << " use is value def, and instruction can share use buffer";
return true;
}
@ -172,12 +173,13 @@ bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use,
return true;
}
}
VLOG(4) << " use is not before while";
VLOG(4) << " use is not before value";
return false;
}
bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a,
const HloValue& b) const {
bool HloOrdering::LiveRangeStrictlyBefore(
const HloValue& a, const HloValue& b,
const HloDataflowAnalysis& dataflow) const {
VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
<< ", b = " << b.ToShortString() << ")";
if (!IsDefinedBefore(a, b)) {
@ -204,7 +206,7 @@ bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a,
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
if (!UseIsBeforeValueDefinition(use, b)) {
if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
VLOG(4) << "use of a (" << use << ") not before b is defined";
return false;
}
@ -213,9 +215,11 @@ bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a,
return true;
}
bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b) const {
bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
const HloDataflowAnalysis& dataflow) const {
// Buffers without disjoint liveness may interfere.
return !LiveRangeStrictlyBefore(a, b) && !LiveRangeStrictlyBefore(b, a);
return !LiveRangeStrictlyBefore(a, b, dataflow) &&
!LiveRangeStrictlyBefore(b, a, dataflow);
}
HloOrderingProto HloOrdering::ToProto() const {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
@ -48,15 +49,17 @@ class HloOrdering {
// Returns whether the given use is before the given value definition under
// the given ordering.
bool UseIsBeforeValueDefinition(const HloUse& use,
const HloValue& value) const;
bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
const HloDataflowAnalysis& dataflow) const;
// Returns whether the given values interfere. Two values interfere if they
// may both be simultaneously live.
bool MayInterfere(const HloValue& a, const HloValue& b) const;
bool MayInterfere(const HloValue& a, const HloValue& b,
const HloDataflowAnalysis& dataflow) const;
// Returns true if the live range of the given value 'a' is strictly before
// the live range of value 'b' using the given HLO ordering.
bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b) const;
bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b,
const HloDataflowAnalysis& dataflow) const;
// Returns the sequential instruction order for the given computation, or
// nullptr if the computation does not have a sequential ordering.

View File

@ -269,29 +269,32 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
// while because of the use of the init value in the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while), *dataflow));
EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
dataflow->GetValueDefinedAt(xla_while),
*dataflow));
// Any value defined in the body or condition is defined before the while, and
// has a live range strictly before the while.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while), *dataflow));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
dataflow->GetValueDefinedAt(xla_while),
*dataflow));
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while), *dataflow));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
dataflow->GetValueDefinedAt(xla_while),
*dataflow));
// The live range of the while should be before the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
@ -301,10 +304,10 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
EXPECT_EQ(while_use.instruction, add);
EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
while_use, dataflow->GetValueDefinedAt(add)));
EXPECT_TRUE(
ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while),
dataflow->GetValueDefinedAt(add)));
while_use, dataflow->GetValueDefinedAt(add), *dataflow));
EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
dataflow->GetValueDefinedAt(xla_while), dataflow->GetValueDefinedAt(add),
*dataflow));
}
} // namespace

View File

@ -69,6 +69,36 @@ bool DoesNotUseOperandBuffer(const HloInstruction* operand,
return false;
}
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
const ShapeIndex& index,
const HloInstruction* user,
const HloDataflowAnalysis& dataflow) {
CHECK(user->IsUserOf(operand))
<< "user: " << user->ToString() << " operand: " << operand->ToString();
if (user->opcode() == HloOpcode::kFusion &&
user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
// Find fusion parameter associated with 'operand'.
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
// Iterate through all users of all uses of the fusion parameter value.
// Return false if any uses are detected, returns true otherwise.
const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index);
return value.uses().empty();
} else {
// Return false if no value at 'operand' and 'index' is used at 'user'.
for (const HloValue* value :
dataflow.GetValueSet(operand, index).values()) {
for (const HloUse& use : value->uses()) {
if (use.instruction == user) {
return false;
}
}
}
}
return true;
}
namespace {
// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
@ -153,7 +183,7 @@ bool HasUniqueFusedUseOfOperandAt(
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
const TuplePointsToAnalysis* points_to_analysis) {
const TuplePointsToAnalysis& points_to_analysis) {
CHECK(user->IsUserOf(operand))
<< "user: " << user->ToString() << " operand: " << operand->ToString();
const Shape& operand_subshape =
@ -164,7 +194,7 @@ bool CanShareOperandBufferWithUser(
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
return false;
}
if (points_to_analysis != nullptr && user->opcode() == HloOpcode::kFusion) {
if (user->opcode() == HloOpcode::kFusion) {
if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
user->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
@ -174,7 +204,7 @@ bool CanShareOperandBufferWithUser(
// 'operand_index', and this singleton use is the fused root at operand
// index 0.
return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0,
*points_to_analysis);
points_to_analysis);
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
@ -202,7 +232,85 @@ bool CanShareOperandBufferWithUser(
// index 'other_add_operand_index').
return HasUniqueFusedUseOfOperandAt(operand, operand_index, user,
other_add_operand_index,
*points_to_analysis);
points_to_analysis);
}
}
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
// so here we just need to check that the use is at operand index 0.
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
// Check if 'user' is element-wise.
return user->IsElementwise();
}
bool CanShareOperandBufferWithUser(HloInstruction* operand,
const ShapeIndex& operand_index,
HloInstruction* user,
const ShapeIndex& user_index,
const HloDataflowAnalysis& dataflow) {
CHECK(user->IsUserOf(operand))
<< "user: " << user->ToString() << " operand: " << operand->ToString();
const Shape& operand_subshape =
ShapeUtil::GetSubshape(operand->shape(), operand_index);
const Shape& user_subshape =
ShapeUtil::GetSubshape(user->shape(), user_index);
// Check that operand and user emit the same shape and layout.
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
return false;
}
if (user->opcode() == HloOpcode::kFusion) {
// Get the parameter associated with 'operand';
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
const HloValue& value =
dataflow.GetValueDefinedAt(fusion_param, operand_index);
if (value.uses().size() != 1) {
return false;
}
const HloUse& use = value.uses()[0];
if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
user->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
// Loop fusion with kDynamicUpdateSlice fused root.
//
// Returns true iff there is exactly one use of 'operand' at shape index
// 'operand_index', and this singleton use is the fused root at operand
// index 0.
return use.instruction == user->fused_expression_root() &&
use.operand_number == 0;
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
// Check if one operand of kAdd fused root is either kDot, or nested
// kFusion of kind kTransposeDot.
auto* add = user->fused_expression_root();
auto add_operand_it =
std::find_if(add->operands().begin(), add->operands().end(),
[&](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kDot ||
(operand->opcode() == HloOpcode::kFusion &&
operand->fusion_kind() ==
HloInstruction::FusionKind::kTransposeDot);
});
if (add_operand_it == add->operands().end()) {
return false;
}
auto* matched_add_operand = *add_operand_it;
// Calculate operand index of 'add' operand which was not matched above.
const int64 other_add_operand_index =
matched_add_operand == add->operand(0) ? 1 : 0;
// Returns true iff there is exactly one use of 'operand' at shape index
// 'operand_index', and this singleton use is the fused root (at operand
// index 'other_add_operand_index').
return use.instruction == user->fused_expression_root() &&
use.operand_number == other_add_operand_index;
}
}
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||

View File

@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -29,21 +30,34 @@ namespace xla {
// 'operand'. Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
//
// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have
// moved over to the dataflow overload.
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
const ShapeIndex& index,
const HloInstruction* user,
const TuplePointsToAnalysis& points_to_analysis);
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
const ShapeIndex& index,
const HloInstruction* user,
const HloDataflowAnalysis& dataflow);
// Returns true if 'user' (at 'user_index') can share a buffer with its operand
// 'operand' (at 'operand_index'). Returns false otherwise. Optionally takes a
// points-to analysis argument. Without the analysis, the result is more
// conservative (returns false more often).
// 'operand' (at 'operand_index'). Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
//
// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have
// moved over to the dataflow overload.
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
const TuplePointsToAnalysis* points_to_analysis = nullptr);
const TuplePointsToAnalysis& points_to_analysis);
bool CanShareOperandBufferWithUser(HloInstruction* operand,
const ShapeIndex& operand_index,
HloInstruction* user,
const ShapeIndex& user_index,
const HloDataflowAnalysis& dataflow);
} // namespace xla

View File

@ -35,6 +35,8 @@ class PointsToAnalysisTestBase : public HloTestBase {
CHECK_NOTNULL(module_.get());
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
dataflow_analysis_ =
HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie();
}
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
@ -45,6 +47,7 @@ class PointsToAnalysisTestBase : public HloTestBase {
std::unique_ptr<HloModule> module_;
HloComputation* computation_ = nullptr;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
};
class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
@ -70,6 +73,11 @@ TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_));
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_));
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_));
EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_));
EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_));
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_));
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_));
}
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
@ -105,6 +113,10 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_));
EXPECT_FALSE(
DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_));
EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_));
EXPECT_FALSE(
DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_));
}
class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
@ -122,10 +134,15 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {},
points_to_analysis_.get()));
EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, log, {},
points_to_analysis_.get()));
EXPECT_TRUE(
CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
EXPECT_TRUE(
CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_));
EXPECT_TRUE(
CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_));
EXPECT_TRUE(
CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
@ -143,9 +160,14 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
points_to_analysis_.get()));
*points_to_analysis_));
EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
points_to_analysis_.get()));
*points_to_analysis_));
EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
*dataflow_analysis_));
EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
*dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
@ -161,10 +183,15 @@ TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {},
points_to_analysis_.get()));
EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, copy, {},
points_to_analysis_.get()));
EXPECT_TRUE(
CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
EXPECT_TRUE(
CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_));
EXPECT_TRUE(
CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_));
EXPECT_TRUE(
CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
@ -197,9 +224,14 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
// The fusion instruction can share with tuple element 1.
EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
points_to_analysis_.get()));
*points_to_analysis_));
EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
points_to_analysis_.get()));
*points_to_analysis_));
EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
*dataflow_analysis_));
EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
*dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
@ -221,12 +253,19 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
// The DynamicUpdateSlice instruction can share with the data operand, but not
// with update or starts.
EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, dus, {},
points_to_analysis_.get()));
EXPECT_FALSE(CanShareOperandBufferWithUser(update, {}, dus, {},
points_to_analysis_.get()));
EXPECT_FALSE(CanShareOperandBufferWithUser(starts, {}, dus, {},
points_to_analysis_.get()));
EXPECT_TRUE(
CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_));
EXPECT_FALSE(
CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_));
EXPECT_FALSE(
CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_));
EXPECT_TRUE(
CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_));
EXPECT_FALSE(
CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_));
EXPECT_FALSE(
CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
@ -256,7 +295,10 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
// Output fused dot add should be able to share buffer with 'add_operand'.
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
points_to_analysis_.get()));
*points_to_analysis_));
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
*dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) {
@ -292,7 +334,10 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) {
// Output fused transpose-dot-add should be share buffer with 'add_operand'.
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
points_to_analysis_.get()));
*points_to_analysis_));
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
*dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
@ -320,7 +365,10 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
// Output fused operand->reverse->add cannot alias operand buffer 'operand'.
EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {},
points_to_analysis_.get()));
*points_to_analysis_));
EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {},
*dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
@ -360,8 +408,11 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
RunAnalysis();
// The While instruction can share with the data operand.
EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, whil, {},
points_to_analysis_.get()));
EXPECT_TRUE(
CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_));
EXPECT_TRUE(
CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_));
}
} // namespace