Add liveness_util functions that use dataflow analysis. Also make the analysis argument (TuplePointsToAnalysis or HloDataflowAnalysis) non-optional, since all callers were already passing in the analysis.
PiperOrigin-RevId: 169200824
parent f08ec5722b
commit 23da21150d
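For illustration, a minimal sketch of how a call site changes once the analysis argument is required; the wrapper function and variable names are hypothetical, and only the CanShareOperandBufferWithUser signature comes from this change:

    #include "tensorflow/compiler/xla/service/liveness_util.h"
    #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

    // Hypothetical call site: the analysis used to be an optional pointer and
    // could be omitted; it is now a required const reference.
    bool OperandAndUserMayShare(xla::HloInstruction* operand,
                                xla::HloInstruction* user,
                                const xla::TuplePointsToAnalysis& points_to) {
      // Before: CanShareOperandBufferWithUser(operand, {}, user, {}, &points_to)
      //         or with the analysis argument omitted entirely (nullptr default).
      // After:
      return xla::CanShareOperandBufferWithUser(operand, /*operand_index=*/{},
                                                user, /*user_index=*/{},
                                                points_to);
    }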
@@ -720,6 +720,7 @@ cc_library(
    hdrs = ["liveness_util.h"],
    deps = [
        ":hlo",
        ":hlo_dataflow_analysis",
        ":logical_buffer",
        ":tuple_points_to_analysis",
        "//tensorflow/compiler/xla:shape_util",
@@ -838,6 +839,7 @@ cc_library(
    deps = [
        ":call_graph",
        ":hlo",
        ":hlo_dataflow_analysis",
        ":hlo_proto",
        ":hlo_value",
        ":liveness_util",
@@ -1391,9 +1393,7 @@ cc_library(
    deps = [
        ":call_graph",
        ":hlo",
        ":hlo_ordering",
        ":hlo_value",
        ":liveness_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:statusor",
@@ -1412,6 +1412,7 @@ cc_test(
        ":hlo_dataflow_analysis",
        ":hlo_graph_dumper",
        ":hlo_matchers",
        ":hlo_ordering",
        ":instruction_fusion",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
@@ -1470,6 +1471,7 @@ cc_test(
        ":hlo_alias_analysis",
        ":hlo_graph_dumper",
        ":hlo_matchers",
        ":hlo_ordering",
        ":instruction_fusion",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
@@ -123,7 +123,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
    if (b.instruction()->IsUserOf(alias.instruction()) &&
        !CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
                                       b.instruction(), b.index(),
                                       &points_to_analysis())) {
                                       points_to_analysis())) {
      return false;
    }
  }
@@ -204,7 +204,7 @@ Status HeapSimulator::RunComputation(
        buffer->instruction()->opcode() != HloOpcode::kCopy &&
        CanShareOperandBufferWithUser(
            operand_buffer->instruction(), operand_buffer->index(),
            buffer->instruction(), buffer->index(), &points_to_analysis)) {
            buffer->instruction(), buffer->index(), points_to_analysis)) {
      ShareBuffer(buffer, operand_buffer, instruction);
      shared = true;
      break;
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -93,7 +94,8 @@ class HloAliasAnalysisTest : public HloTestBase {
    for (const HloValue* value_a : buffer.values()) {
      for (const HloValue* value_b : buffer.values()) {
        if (*value_a != *value_b &&
            ordering.MayInterfere(*value_a, *value_b)) {
            ordering.MayInterfere(*value_a, *value_b,
                                  analysis_->dataflow_analysis())) {
          VLOG(1) << *value_a << " interferes with " << *value_b
                  << " in buffer: " << buffer;
          return true;
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
@@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -73,7 +74,7 @@ class HloDataflowAnalysisTest : public HloTestBase,
    EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
    EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
    return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
                                 analysis_->GetValueDefinedAt(b));
                                 analysis_->GetValueDefinedAt(b), *analysis_);
  }

  std::unique_ptr<HloModule> module_;
@@ -123,8 +123,9 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
}

/* static */
bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use,
                                             const HloValue& value) const {
bool HloOrdering::UseIsBeforeValueDefinition(
    const HloUse& use, const HloValue& value,
    const HloDataflowAnalysis& dataflow) const {
  VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
          << ", value=" << value.ToShortString() << ")";
  if (ExecutesBefore(use.instruction, value.defining_instruction())) {
@@ -139,7 +140,7 @@ bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use,
      CanShareOperandBufferWithUser(
          use.instruction->mutable_operand(use.operand_number),
          use.operand_index, value.defining_instruction(),
          value.defining_index())) {
          value.defining_index(), dataflow)) {
    VLOG(4) << " use is value def, and instruction can share use buffer";
    return true;
  }
@@ -172,12 +173,13 @@ bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use,
      return true;
    }
  }
  VLOG(4) << " use is not before while";
  VLOG(4) << " use is not before value";
  return false;
}

bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a,
                                          const HloValue& b) const {
bool HloOrdering::LiveRangeStrictlyBefore(
    const HloValue& a, const HloValue& b,
    const HloDataflowAnalysis& dataflow) const {
  VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
          << ", b = " << b.ToShortString() << ")";
  if (!IsDefinedBefore(a, b)) {
@@ -204,7 +206,7 @@ bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a,

  // All uses of 'a' must be before 'b' is defined.
  for (const HloUse& use : a.uses()) {
    if (!UseIsBeforeValueDefinition(use, b)) {
    if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
      VLOG(4) << "use of a (" << use << ") not before b is defined";
      return false;
    }
@@ -213,9 +215,11 @@ bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a,
  return true;
}

bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b) const {
bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
                               const HloDataflowAnalysis& dataflow) const {
  // Buffers without disjoint liveness may interfere.
  return !LiveRangeStrictlyBefore(a, b) && !LiveRangeStrictlyBefore(b, a);
  return !LiveRangeStrictlyBefore(a, b, dataflow) &&
         !LiveRangeStrictlyBefore(b, a, dataflow);
}

HloOrderingProto HloOrdering::ToProto() const {
@@ -22,6 +22,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
@@ -48,15 +49,17 @@ class HloOrdering {

  // Returns whether the given use is before the given value definition under
  // the given ordering.
  bool UseIsBeforeValueDefinition(const HloUse& use,
                                  const HloValue& value) const;
  bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
                                  const HloDataflowAnalysis& dataflow) const;
  // Returns whether the given values interfere. Two values interfere if they
  // may both be simultaneously live.
  bool MayInterfere(const HloValue& a, const HloValue& b) const;
  bool MayInterfere(const HloValue& a, const HloValue& b,
                    const HloDataflowAnalysis& dataflow) const;

  // Returns true if the live range of the given value 'a' is strictly before
  // the live range of value 'b' using the given HLO ordering.
  bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b) const;
  bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b,
                               const HloDataflowAnalysis& dataflow) const;

  // Returns the sequential instruction order for the given computation, or
  // nullptr if the computation does not have a sequential ordering.
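The ordering queries above now take the HloDataflowAnalysis that produced the HloValues. A hedged sketch of a caller, modeled on the test updates below; the helper name and the choice of DependencyHloOrdering are illustrative assumptions:

    #include <memory>

    #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
    #include "tensorflow/compiler/xla/service/hlo_ordering.h"

    // Hypothetical helper: decide whether the values defined by two
    // instructions may be live at the same time under a dependency ordering.
    bool ValuesMayInterfere(xla::HloModule* module, xla::HloInstruction* x,
                            xla::HloInstruction* y) {
      std::unique_ptr<xla::HloDataflowAnalysis> dataflow =
          xla::HloDataflowAnalysis::Run(module).ConsumeValueOrDie();
      xla::DependencyHloOrdering ordering(module);
      // The dataflow analysis is now passed explicitly instead of being optional.
      return ordering.MayInterfere(dataflow->GetValueDefinedAt(x),
                                   dataflow->GetValueDefinedAt(y), *dataflow);
    }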
@@ -269,29 +269,32 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
  // while because of the use of the init value in the add.
  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
                                       dataflow->GetValueDefinedAt(xla_while)));
  EXPECT_FALSE(
      ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant),
                                       dataflow->GetValueDefinedAt(xla_while)));
  EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
      dataflow->GetValueDefinedAt(constant),
      dataflow->GetValueDefinedAt(xla_while), *dataflow));
  EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
                                    dataflow->GetValueDefinedAt(xla_while)));
                                    dataflow->GetValueDefinedAt(xla_while),
                                    *dataflow));

  // Any value defined in the body or condition is defined before the while, and
  // has a live range strictly before the while.
  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
                                       dataflow->GetValueDefinedAt(xla_while)));
  EXPECT_TRUE(
      ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate),
                                       dataflow->GetValueDefinedAt(xla_while)));
  EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
      dataflow->GetValueDefinedAt(negate),
      dataflow->GetValueDefinedAt(xla_while), *dataflow));
  EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
                                     dataflow->GetValueDefinedAt(xla_while)));
                                     dataflow->GetValueDefinedAt(xla_while),
                                     *dataflow));

  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
                                       dataflow->GetValueDefinedAt(xla_while)));
  EXPECT_TRUE(
      ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert),
                                       dataflow->GetValueDefinedAt(xla_while)));
  EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
      dataflow->GetValueDefinedAt(convert),
      dataflow->GetValueDefinedAt(xla_while), *dataflow));
  EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
                                     dataflow->GetValueDefinedAt(xla_while)));
                                     dataflow->GetValueDefinedAt(xla_while),
                                     *dataflow));

  // The live range of the while should be before the add.
  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
@@ -301,10 +304,10 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
  const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
  EXPECT_EQ(while_use.instruction, add);
  EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
      while_use, dataflow->GetValueDefinedAt(add)));
  EXPECT_TRUE(
      ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while),
                                       dataflow->GetValueDefinedAt(add)));
      while_use, dataflow->GetValueDefinedAt(add), *dataflow));
  EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
      dataflow->GetValueDefinedAt(xla_while), dataflow->GetValueDefinedAt(add),
      *dataflow));
}

} // namespace
@@ -69,6 +69,36 @@ bool DoesNotUseOperandBuffer(const HloInstruction* operand,
  return false;
}

bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                             const ShapeIndex& index,
                             const HloInstruction* user,
                             const HloDataflowAnalysis& dataflow) {
  CHECK(user->IsUserOf(operand))
      << "user: " << user->ToString() << " operand: " << operand->ToString();
  if (user->opcode() == HloOpcode::kFusion &&
      user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
    // Find fusion parameter associated with 'operand'.
    HloInstruction* fusion_param =
        user->fused_parameter(user->operand_index(operand));
    // Iterate through all users of all uses of the fusion parameter value.
    // Return false if any uses are detected, returns true otherwise.
    const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index);
    return value.uses().empty();
  } else {
    // Return false if no value at 'operand' and 'index' is used at 'user'.
    for (const HloValue* value :
         dataflow.GetValueSet(operand, index).values()) {
      for (const HloUse& use : value->uses()) {
        if (use.instruction == user) {
          return false;
        }
      }
    }
  }

  return true;
}

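A short usage sketch of the dataflow-based overload above, mirroring the GetTupleElement test near the end of this change; the helper name and instruction variables are placeholders:

    #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
    #include "tensorflow/compiler/xla/service/liveness_util.h"

    // Hypothetical check: a GetTupleElement only reads the tuple's top-level
    // index table (shape index {}), not the element buffer itself, so the
    // buffer at shape index {0} counts as unused by gte0 and this is expected
    // to return true.
    bool TupleElementBufferUnused(xla::HloModule* module,
                                  xla::HloInstruction* tuple,
                                  xla::HloInstruction* gte0) {
      auto dataflow = xla::HloDataflowAnalysis::Run(module).ConsumeValueOrDie();
      return xla::DoesNotUseOperandBuffer(tuple, /*index=*/{0}, gte0, *dataflow);
    }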
namespace {

// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
@@ -153,7 +183,7 @@ bool HasUniqueFusedUseOfOperandAt(
bool CanShareOperandBufferWithUser(
    HloInstruction* operand, const ShapeIndex& operand_index,
    HloInstruction* user, const ShapeIndex& user_index,
    const TuplePointsToAnalysis* points_to_analysis) {
    const TuplePointsToAnalysis& points_to_analysis) {
  CHECK(user->IsUserOf(operand))
      << "user: " << user->ToString() << " operand: " << operand->ToString();
  const Shape& operand_subshape =
@@ -164,7 +194,7 @@ bool CanShareOperandBufferWithUser(
  if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
    return false;
  }
  if (points_to_analysis != nullptr && user->opcode() == HloOpcode::kFusion) {
  if (user->opcode() == HloOpcode::kFusion) {
    if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
        user->fused_expression_root()->opcode() ==
            HloOpcode::kDynamicUpdateSlice) {
@@ -174,7 +204,7 @@ bool CanShareOperandBufferWithUser(
      // 'operand_index', and this singleton use is the fused root at operand
      // index 0.
      return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0,
                                          *points_to_analysis);
                                          points_to_analysis);
    } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
               user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
      // Output fusion with kAdd fused root.
@@ -202,7 +232,85 @@ bool CanShareOperandBufferWithUser(
      // index 'other_add_operand_index').
      return HasUniqueFusedUseOfOperandAt(operand, operand_index, user,
                                          other_add_operand_index,
                                          *points_to_analysis);
                                          points_to_analysis);
    }
  }
  if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
      user->opcode() == HloOpcode::kWhile) {
    // We eliminated other users in BufferLiveness::live_range_strictly_before,
    // so here we just need to check that the use is at operand index 0.
    std::vector<int64> operand_indices = user->OperandIndices(operand);
    return operand_indices.size() == 1 && operand_indices[0] == 0;
  }
  // Check if 'user' is element-wise.
  return user->IsElementwise();
}

bool CanShareOperandBufferWithUser(HloInstruction* operand,
                                   const ShapeIndex& operand_index,
                                   HloInstruction* user,
                                   const ShapeIndex& user_index,
                                   const HloDataflowAnalysis& dataflow) {
  CHECK(user->IsUserOf(operand))
      << "user: " << user->ToString() << " operand: " << operand->ToString();
  const Shape& operand_subshape =
      ShapeUtil::GetSubshape(operand->shape(), operand_index);
  const Shape& user_subshape =
      ShapeUtil::GetSubshape(user->shape(), user_index);
  // Check that operand and user emit the same shape and layout.
  if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
    return false;
  }

  if (user->opcode() == HloOpcode::kFusion) {
    // Get the parameter associated with 'operand';
    HloInstruction* fusion_param =
        user->fused_parameter(user->operand_index(operand));

    const HloValue& value =
        dataflow.GetValueDefinedAt(fusion_param, operand_index);
    if (value.uses().size() != 1) {
      return false;
    }
    const HloUse& use = value.uses()[0];

    if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
        user->fused_expression_root()->opcode() ==
            HloOpcode::kDynamicUpdateSlice) {
      // Loop fusion with kDynamicUpdateSlice fused root.
      //
      // Returns true iff there is exactly one use of 'operand' at shape index
      // 'operand_index', and this singleton use is the fused root at operand
      // index 0.
      return use.instruction == user->fused_expression_root() &&
             use.operand_number == 0;
    } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
               user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
      // Output fusion with kAdd fused root.

      // Check if one operand of kAdd fused root is either kDot, or nested
      // kFusion of kind kTransposeDot.
      auto* add = user->fused_expression_root();
      auto add_operand_it =
          std::find_if(add->operands().begin(), add->operands().end(),
                       [&](HloInstruction* operand) {
                         return operand->opcode() == HloOpcode::kDot ||
                                (operand->opcode() == HloOpcode::kFusion &&
                                 operand->fusion_kind() ==
                                     HloInstruction::FusionKind::kTransposeDot);
                       });
      if (add_operand_it == add->operands().end()) {
        return false;
      }
      auto* matched_add_operand = *add_operand_it;
      // Calculate operand index of 'add' operand which was not matched above.
      const int64 other_add_operand_index =
          matched_add_operand == add->operand(0) ? 1 : 0;
      // Returns true iff there is exactly one use of 'operand' at shape index
      // 'operand_index', and this singleton use is the fused root (at operand
      // index 'other_add_operand_index').
      return use.instruction == user->fused_expression_root() &&
             use.operand_number == other_add_operand_index;
    }
  }
  if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
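Analogously, a usage sketch of the new dataflow overload of CanShareOperandBufferWithUser, modeled on the DynamicUpdateSliceCanShare test further down; the function and variable names are placeholders:

    #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
    #include "tensorflow/compiler/xla/service/liveness_util.h"

    // Hypothetical check: a DynamicUpdateSlice may update its data operand
    // (operand 0) in place, so sharing with 'data' is allowed; the same query
    // for the 'update' operand would return false.
    bool DusCanUpdateDataInPlace(xla::HloModule* module,
                                 xla::HloInstruction* data,
                                 xla::HloInstruction* dus) {
      auto dataflow = xla::HloDataflowAnalysis::Run(module).ConsumeValueOrDie();
      return xla::CanShareOperandBufferWithUser(data, /*operand_index=*/{}, dus,
                                                /*user_index=*/{}, *dataflow);
    }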
@@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_

#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -29,21 +30,34 @@ namespace xla {
// 'operand'. Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
//
// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have
// moved over to the dataflow overload.
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                             const ShapeIndex& index,
                             const HloInstruction* user,
                             const TuplePointsToAnalysis& points_to_analysis);
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                             const ShapeIndex& index,
                             const HloInstruction* user,
                             const HloDataflowAnalysis& dataflow);

// Returns true if 'user' (at 'user_index') can share a buffer with its operand
// 'operand' (at 'operand_index'). Returns false otherwise. Optionally takes a
// points-to analysis argument. Without the analysis, the result is more
// conservative (returns false more often).
// 'operand' (at 'operand_index'). Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
//
// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have
// moved over to the dataflow overload.
bool CanShareOperandBufferWithUser(
    HloInstruction* operand, const ShapeIndex& operand_index,
    HloInstruction* user, const ShapeIndex& user_index,
    const TuplePointsToAnalysis* points_to_analysis = nullptr);
    const TuplePointsToAnalysis& points_to_analysis);
bool CanShareOperandBufferWithUser(HloInstruction* operand,
                                   const ShapeIndex& operand_index,
                                   HloInstruction* user,
                                   const ShapeIndex& user_index,
                                   const HloDataflowAnalysis& dataflow);

} // namespace xla
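Callers that still need both overloads now construct the two analyses up front, as the test fixture below does; a minimal sketch under that assumption (the helper name is hypothetical):

    #include <memory>

    #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
    #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

    // Hypothetical setup: run both analyses so either overload can be called.
    void RunBothAnalyses(xla::HloModule* module) {
      std::unique_ptr<xla::TuplePointsToAnalysis> points_to =
          xla::TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
      std::unique_ptr<xla::HloDataflowAnalysis> dataflow =
          xla::HloDataflowAnalysis::Run(module).ConsumeValueOrDie();
      // Pass *points_to to the TuplePointsToAnalysis overloads and *dataflow
      // to the HloDataflowAnalysis overloads declared above.
      (void)points_to;
      (void)dataflow;
    }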
@@ -35,6 +35,8 @@ class PointsToAnalysisTestBase : public HloTestBase {
    CHECK_NOTNULL(module_.get());
    points_to_analysis_ =
        TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
    dataflow_analysis_ =
        HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie();
  }

  void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
@@ -45,6 +47,7 @@ class PointsToAnalysisTestBase : public HloTestBase {
  std::unique_ptr<HloModule> module_;
  HloComputation* computation_ = nullptr;
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
  std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
};

class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
@@ -70,6 +73,11 @@ TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_));
  EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_));
  EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_));

  EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_));
  EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_));
  EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_));
  EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_));
}

TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
@@ -105,6 +113,10 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
      DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_));
  EXPECT_FALSE(
      DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_));

  EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_));
  EXPECT_FALSE(
      DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_));
}

class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
@@ -122,10 +134,15 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {},
                                            points_to_analysis_.get()));
  EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, log, {},
                                            points_to_analysis_.get()));
  EXPECT_TRUE(
      CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
  EXPECT_TRUE(
      CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_));

  EXPECT_TRUE(
      CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_));
  EXPECT_TRUE(
      CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_));
}

TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
@@ -143,9 +160,14 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
                                             points_to_analysis_.get()));
                                             *points_to_analysis_));
  EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
                                             points_to_analysis_.get()));
                                             *points_to_analysis_));

  EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
                                             *dataflow_analysis_));
  EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
                                             *dataflow_analysis_));
}

TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
@@ -161,10 +183,15 @@ TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {},
                                            points_to_analysis_.get()));
  EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, copy, {},
                                            points_to_analysis_.get()));
  EXPECT_TRUE(
      CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
  EXPECT_TRUE(
      CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_));

  EXPECT_TRUE(
      CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_));
  EXPECT_TRUE(
      CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_));
}

TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
@@ -197,9 +224,14 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {

  // The fusion instruction can share with tuple element 1.
  EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
                                             points_to_analysis_.get()));
                                             *points_to_analysis_));
  EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
                                            points_to_analysis_.get()));
                                            *points_to_analysis_));

  EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
                                             *dataflow_analysis_));
  EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
                                            *dataflow_analysis_));
}

TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
@@ -221,12 +253,19 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {

  // The DynamicUpdateSlice instruction can share with the data operand, but not
  // with update or starts.
  EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, dus, {},
                                            points_to_analysis_.get()));
  EXPECT_FALSE(CanShareOperandBufferWithUser(update, {}, dus, {},
                                             points_to_analysis_.get()));
  EXPECT_FALSE(CanShareOperandBufferWithUser(starts, {}, dus, {},
                                             points_to_analysis_.get()));
  EXPECT_TRUE(
      CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_));
  EXPECT_FALSE(
      CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_));
  EXPECT_FALSE(
      CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_));

  EXPECT_TRUE(
      CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_));
  EXPECT_FALSE(
      CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_));
  EXPECT_FALSE(
      CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_));
}

TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
@@ -256,7 +295,10 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {

  // Output fused dot add should be able to share buffer with 'add_operand'.
  EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
                                            points_to_analysis_.get()));
                                            *points_to_analysis_));

  EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
                                            *dataflow_analysis_));
}

TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) {
@@ -292,7 +334,10 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) {

  // Output fused transpose-dot-add should be share buffer with 'add_operand'.
  EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
                                            points_to_analysis_.get()));
                                            *points_to_analysis_));

  EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
                                            *dataflow_analysis_));
}

TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
@@ -320,7 +365,10 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {

  // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
  EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {},
                                             points_to_analysis_.get()));
                                             *points_to_analysis_));

  EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {},
                                             *dataflow_analysis_));
}

TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
@@ -360,8 +408,11 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
  RunAnalysis();

  // The While instruction can share with the data operand.
  EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, whil, {},
                                            points_to_analysis_.get()));
  EXPECT_TRUE(
      CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_));

  EXPECT_TRUE(
      CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_));
}

} // namespace