diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index c812df42355..cc195879a6b 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -1156,7 +1156,7 @@ bool IsWhileBody(const HloComputation* computation, HloModule* module) { std::unique_ptr call_graph = CallGraph::Build(module); TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(module)); + HloDataflowAnalysis::Run(*module)); bool changed = false; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 916b556fd43..9db85bc788b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -49,7 +49,7 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(module)); + HloDataflowAnalysis::Run(*module)); // Make sure all operands of a library call are in memory instead of constants // in IR. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 6d2a3aa5b53..30e32a46d7d 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -419,7 +419,7 @@ StatusOr> HloAliasAnalysis::Run( auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); TF_ASSIGN_OR_RETURN( alias_analysis->dataflow_analysis_, - HloDataflowAnalysis::Run(module, /*ssa_form=*/true, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, /*bitcast_defines_value=*/false)); BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index ccbbe8f1966..934e43ba487 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -38,12 +38,12 @@ namespace xla { using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form, +HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value) : module_(module), ssa_form_(ssa_form), bitcast_defines_value_(bitcast_defines_value), - call_graph_(CallGraph::Build(module)) {} + call_graph_(CallGraph::Build(&module)) {} bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -115,9 +115,9 @@ void HloDataflowAnalysis::DeleteMarkedValues() { } string HloDataflowAnalysis::ToString() const { - string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n"); + string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n"); StrAppend(&out, " Instruction value sets:\n"); - for (const HloComputation* computation : module_->computations()) { + for (const HloComputation* computation : module_.computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { @@ -592,7 +592,7 @@ void HloDataflowAnalysis::Propagate() { } }; - for (HloComputation* computation : module_->computations()) { + for (HloComputation* computation : module_.computations()) { for (HloInstruction* instruction : computation->instructions()) { add_to_worklist(instruction); } @@ -686,7 +686,7 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( } Status HloDataflowAnalysis::InitializeInstructionValueSets() { - for (const HloComputation* computation : module_->computations()) { + for (const HloComputation* computation : module_.computations()) { const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); for (HloInstruction* instruction : computation->instructions()) { // Create an empty shape tree. @@ -787,9 +787,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { /* static */ StatusOr> HloDataflowAnalysis::Run( - HloModule* module, bool ssa_form, bool bitcast_defines_value) { - VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name(); - XLA_VLOG_LINES(2, module->ToString()); + const HloModule& module, bool ssa_form, bool bitcast_defines_value) { + VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); auto dataflow_analysis = WrapUnique( new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); @@ -806,7 +806,7 @@ StatusOr> HloDataflowAnalysis::Run( // lookup is faster. std::vector> value_positions( dataflow_analysis->next_value_id_); - for (const HloComputation* computation : module->computations()) { + for (const HloComputation* computation : module.computations()) { for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : dataflow_analysis->GetInstructionValueSet(instruction)) { @@ -858,7 +858,7 @@ Status HloDataflowAnalysis::Verify() const { // For each value in each value set, verify that the value set's position // appears in the value's positions(). - for (const auto& computation : module_->computations()) { + for (const auto& computation : module_.computations()) { for (const auto& instruction : computation->instructions()) { for (const auto& pair : GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 89d318188f0..7b8a74b096f 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -60,7 +60,7 @@ class HloDataflowAnalysis { // a new HLO value in the analysis. If false then Bitcast forwards the // value of its operand. static StatusOr> Run( - HloModule* module, bool ssa_form = false, + const HloModule& module, bool ssa_form = false, bool bitcast_defines_value = false); // Returns true if 'instruction' defines an HLO value at the given shape index @@ -119,7 +119,7 @@ class HloDataflowAnalysis { string ToString() const; protected: - HloDataflowAnalysis(HloModule* module, bool ssa_form, + HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false); // Returns a new HloValue defined at the given instruction and shape index. @@ -180,7 +180,7 @@ class HloDataflowAnalysis { // Verify various invariants of the dataflow analysis. Status Verify() const; - HloModule* const module_; + const HloModule& module_; const bool ssa_form_; const bool bitcast_defines_value_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index e714b2567fd..7bf3a1a0604 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -50,7 +50,7 @@ class HloDataflowAnalysisTest : public HloTestBase, bool bitcast_defines_value = false) { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis"); analysis_ = - HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value) + HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); return *analysis_; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index aba66114de6..a989fce6323 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -262,8 +262,8 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { scalar_shape, HloOpcode::kAdd, constant, xla_while)); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); // Init value is defined before the while, but live range is not before the diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 2c2a02f6375..f8b309488ee 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -35,8 +35,7 @@ class PointsToAnalysisTestBase : public HloTestBase { CHECK_NOTNULL(module_.get()); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - dataflow_analysis_ = - HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie(); + dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index b060fb13b14..0bc7df2a65b 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -287,7 +287,7 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { StatusOr>> MakeFakeArguments( HloModule* const module) { - TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); std::minstd_rand0 engine; std::vector> arguments(params.size());