[XLA] Pass the module to HloDataflowAnalysis by const reference.
PiperOrigin-RevId: 186072673
This commit is contained in:
parent
a189502cc3
commit
090bb9168c
@ -1156,7 +1156,7 @@ bool IsWhileBody(const HloComputation* computation,
|
||||
HloModule* module) {
|
||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
|
||||
HloDataflowAnalysis::Run(module));
|
||||
HloDataflowAnalysis::Run(*module));
|
||||
|
||||
bool changed = false;
|
||||
|
||||
|
@ -49,7 +49,7 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
|
||||
HloDataflowAnalysis::Run(module));
|
||||
HloDataflowAnalysis::Run(*module));
|
||||
|
||||
// Make sure all operands of a library call are in memory instead of constants
|
||||
// in IR.
|
||||
|
@ -419,7 +419,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
|
||||
auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
alias_analysis->dataflow_analysis_,
|
||||
HloDataflowAnalysis::Run(module, /*ssa_form=*/true,
|
||||
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
|
||||
/*bitcast_defines_value=*/false));
|
||||
|
||||
BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
|
||||
|
@ -38,12 +38,12 @@ namespace xla {
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form,
|
||||
HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
|
||||
bool bitcast_defines_value)
|
||||
: module_(module),
|
||||
ssa_form_(ssa_form),
|
||||
bitcast_defines_value_(bitcast_defines_value),
|
||||
call_graph_(CallGraph::Build(module)) {}
|
||||
call_graph_(CallGraph::Build(&module)) {}
|
||||
|
||||
bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
|
||||
const ShapeIndex& index) const {
|
||||
@ -115,9 +115,9 @@ void HloDataflowAnalysis::DeleteMarkedValues() {
|
||||
}
|
||||
|
||||
string HloDataflowAnalysis::ToString() const {
|
||||
string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
|
||||
string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
|
||||
StrAppend(&out, " Instruction value sets:\n");
|
||||
for (const HloComputation* computation : module_->computations()) {
|
||||
for (const HloComputation* computation : module_.computations()) {
|
||||
for (const HloInstruction* instruction : computation->instructions()) {
|
||||
StrAppend(&out, " ", instruction->name(), ":\n");
|
||||
if (ShapeUtil::IsTuple(instruction->shape())) {
|
||||
@ -592,7 +592,7 @@ void HloDataflowAnalysis::Propagate() {
|
||||
}
|
||||
};
|
||||
|
||||
for (HloComputation* computation : module_->computations()) {
|
||||
for (HloComputation* computation : module_.computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
add_to_worklist(instruction);
|
||||
}
|
||||
@ -686,7 +686,7 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
|
||||
}
|
||||
|
||||
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
for (const HloComputation* computation : module_->computations()) {
|
||||
for (const HloComputation* computation : module_.computations()) {
|
||||
const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
// Create an empty shape tree.
|
||||
@ -787,9 +787,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
||||
HloModule* module, bool ssa_form, bool bitcast_defines_value) {
|
||||
VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name();
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
const HloModule& module, bool ssa_form, bool bitcast_defines_value) {
|
||||
VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
|
||||
XLA_VLOG_LINES(2, module.ToString());
|
||||
|
||||
auto dataflow_analysis = WrapUnique(
|
||||
new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
|
||||
@ -806,7 +806,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
||||
// lookup is faster.
|
||||
std::vector<std::vector<HloPosition>> value_positions(
|
||||
dataflow_analysis->next_value_id_);
|
||||
for (const HloComputation* computation : module->computations()) {
|
||||
for (const HloComputation* computation : module.computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
for (const auto& pair :
|
||||
dataflow_analysis->GetInstructionValueSet(instruction)) {
|
||||
@ -858,7 +858,7 @@ Status HloDataflowAnalysis::Verify() const {
|
||||
|
||||
// For each value in each value set, verify that the value set's position
|
||||
// appears in the value's positions().
|
||||
for (const auto& computation : module_->computations()) {
|
||||
for (const auto& computation : module_.computations()) {
|
||||
for (const auto& instruction : computation->instructions()) {
|
||||
for (const auto& pair : GetInstructionValueSet(instruction)) {
|
||||
const ShapeIndex& index = pair.first;
|
||||
|
@ -60,7 +60,7 @@ class HloDataflowAnalysis {
|
||||
// a new HLO value in the analysis. If false then Bitcast forwards the
|
||||
// value of its operand.
|
||||
static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
|
||||
HloModule* module, bool ssa_form = false,
|
||||
const HloModule& module, bool ssa_form = false,
|
||||
bool bitcast_defines_value = false);
|
||||
|
||||
// Returns true if 'instruction' defines an HLO value at the given shape index
|
||||
@ -119,7 +119,7 @@ class HloDataflowAnalysis {
|
||||
string ToString() const;
|
||||
|
||||
protected:
|
||||
HloDataflowAnalysis(HloModule* module, bool ssa_form,
|
||||
HloDataflowAnalysis(const HloModule& module, bool ssa_form,
|
||||
bool bitcast_defines_value = false);
|
||||
|
||||
// Returns a new HloValue defined at the given instruction and shape index.
|
||||
@ -180,7 +180,7 @@ class HloDataflowAnalysis {
|
||||
// Verify various invariants of the dataflow analysis.
|
||||
Status Verify() const;
|
||||
|
||||
HloModule* const module_;
|
||||
const HloModule& module_;
|
||||
const bool ssa_form_;
|
||||
const bool bitcast_defines_value_;
|
||||
|
||||
|
@ -50,7 +50,7 @@ class HloDataflowAnalysisTest : public HloTestBase,
|
||||
bool bitcast_defines_value = false) {
|
||||
hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis");
|
||||
analysis_ =
|
||||
HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
|
||||
HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
|
||||
.ConsumeValueOrDie();
|
||||
return *analysis_;
|
||||
}
|
||||
|
@ -262,8 +262,8 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
|
||||
scalar_shape, HloOpcode::kAdd, constant, xla_while));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
|
||||
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
|
||||
DependencyHloOrdering ordering(module.get());
|
||||
|
||||
// Init value is defined before the while, but live range is not before the
|
||||
|
@ -35,8 +35,7 @@ class PointsToAnalysisTestBase : public HloTestBase {
|
||||
CHECK_NOTNULL(module_.get());
|
||||
points_to_analysis_ =
|
||||
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
|
||||
dataflow_analysis_ =
|
||||
HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie();
|
||||
dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
|
||||
|
@ -287,7 +287,7 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
|
||||
HloModule* const module) {
|
||||
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module));
|
||||
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
|
||||
const auto params = module->entry_computation()->parameter_instructions();
|
||||
std::minstd_rand0 engine;
|
||||
std::vector<std::unique_ptr<Literal>> arguments(params.size());
|
||||
|
Loading…
Reference in New Issue
Block a user