[XLA:GPU] Enforce collectives ordering to match their appearance in the module

The ordering is enforced using control edges. This avoids deadlocks in
multi-host launches that arise when non-determinism (e.g. coming from
autotuning) changes the relative order of collectives: if participating hosts
schedule the same collectives in different orders, each host can block waiting
on a collective that the others have not reached yet.

PiperOrigin-RevId: 361069409
Change-Id: I499639aafc74f4128226bccf55fe6d48ecce3f67
George Karpenkov 2021-03-04 21:33:28 -08:00 committed by TensorFlower Gardener
parent f13ee79e03
commit d379154ca0
6 changed files with 287 additions and 0 deletions
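As a rough sketch of the technique (hypothetical toy types, not XLA's API): walk the collectives in program order and chain consecutive ones with control edges, skipping any pair that is already ordered so the new edge cannot close a cycle. The actual pass below does this per computation, using HloReachabilityMap for the ordering check.

#include <functional>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an instruction; only control edges are modeled.
struct Op {
  std::vector<Op*> control_successors;
};

// Depth-first search over control edges. The real pass consults a
// precomputed HloReachabilityMap instead of searching on every query.
bool Reaches(const Op* from, const Op* to) {
  std::unordered_set<const Op*> seen;
  std::function<bool(const Op*)> dfs = [&](const Op* op) {
    if (op == to) return true;
    if (!seen.insert(op).second) return false;
    for (const Op* succ : op->control_successors) {
      if (dfs(succ)) return true;
    }
    return false;
  };
  return dfs(from);
}

// Chains collectives in their program order with control edges, unless a
// pair is already ordered in either direction (an extra edge would then be
// redundant or could create a cycle).
void Linearize(const std::vector<Op*>& collectives_in_program_order) {
  Op* prev = nullptr;
  for (Op* next : collectives_in_program_order) {
    if (prev != nullptr && !Reaches(prev, next) && !Reaches(next, prev)) {
      prev->control_successors.push_back(next);
    }
    prev = next;
  }
}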

tensorflow/compiler/xla/service/BUILD

@@ -2214,6 +2214,53 @@ tf_cc_test(
],
)
cc_library(
name = "collectives_schedule_linearizer",
srcs = ["collectives_schedule_linearizer.cc"],
hdrs = ["collectives_schedule_linearizer.h"],
deps = [
":hlo",
":hlo_casting_utils",
":hlo_domain_map",
":hlo_pass",
":hlo_query",
":hlo_reachability",
":shape_inference",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "collectives_schedule_linearizer_test",
srcs = ["collectives_schedule_linearizer_test.cc"],
deps = [
":collectives_schedule_linearizer",
":hlo",
":hlo_graph_dumper",
":hlo_matchers",
":hlo_runner",
":pattern_matcher",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library(
name = "all_reduce_combiner",
srcs = ["all_reduce_combiner.cc"],

tensorflow/compiler/xla/service/collectives_schedule_linearizer.cc

@@ -0,0 +1,68 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
// TODO(b/181653482): Fix for interprocedural collectives as well.
StatusOr<bool> CollectivesScheduleLinearizer::Run(HloModule* module) {
bool changed = false;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
std::unique_ptr<HloReachabilityMap> reachability =
HloReachabilityMap::Build(computation);
HloCollectiveInstruction* prev = nullptr;
for (HloInstruction* instruction : computation->instructions()) {
if (auto* next = DynCast<HloCollectiveInstruction>(instruction)) {
if (prev != nullptr && !reachability->IsConnected(next, prev)) {
// prev and next are not yet ordered relative to each other, so adding this
// control edge cannot create a cycle.
TF_RETURN_IF_ERROR(prev->AddControlDependencyTo(next));
VLOG(1) << "Adding control dependency from " << prev->ToString()
<< " to " << next->ToString();
changed = true;
}
prev = next;
}
}
}
return changed;
}
} // namespace xla

tensorflow/compiler/xla/service/collectives_schedule_linearizer.h

@@ -0,0 +1,46 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// Enforces a total order on all collectives present in the module, based on
// the order in which the collective instructions appear in their computation.
//
// Does not insert inter-computation dependencies; it only linearizes the
// order within each computation.
class CollectivesScheduleLinearizer : public HloModulePass {
public:
absl::string_view name() const override {
return "collectives-schedule-linearizer";
}
CollectivesScheduleLinearizer() = default;
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_
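A minimal usage sketch, assuming the standard HloPassPipeline API (this mirrors the gpu_compiler.cc hunk further below; the wrapper function itself is hypothetical):

#include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Hypothetical helper: runs the linearizer as a one-pass pipeline.
// Returns true iff at least one control edge was added to the module.
StatusOr<bool> RunCollectivesLinearizer(HloModule* module) {
  HloPassPipeline pipeline("collectives_schedule_linearizer");
  pipeline.AddPass<CollectivesScheduleLinearizer>();
  return pipeline.Run(module);
}

}  // namespace xla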

tensorflow/compiler/xla/service/collectives_schedule_linearizer_test.cc

@@ -0,0 +1,117 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace xla {
namespace {
namespace m = match;
int64 CountControlEdges(const HloComputation& computation) {
int64 count = 0;
for (const auto& instruction : computation.instructions()) {
count += instruction->control_successors().size();
}
return count;
}
class CollectivesScheduleLinearizerTest : public HloTestBase {
protected:
void InsertCollectivesSchedule(HloModule* module) {
CollectivesScheduleLinearizer collectives_schedule_linearizer;
ASSERT_IS_OK(collectives_schedule_linearizer.Run(module).status());
}
};
TEST_F(CollectivesScheduleLinearizerTest, FixOrdering) {
absl::string_view hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT out = f32[] add(a, b)
}
ENTRY entry {
p0 = f32[100] parameter(0), parameter_replication={false}
p1 = f32[100] parameter(1), parameter_replication={false}
c1 = f32[100] all-reduce(p0), replica_groups={}, to_apply=sum
c2 = f32[100] all-reduce(p1), replica_groups={}, to_apply=sum
ROOT out = f32[100] add(c1, c2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCollectivesSchedule(module.get());
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 1);
HloInstruction *c1 = nullptr, *c2 = nullptr;
for (HloInstruction* instr : module->entry_computation()->instructions()) {
if (Match(instr, m::AllReduce(m::Parameter(0)))) {
c1 = instr;
}
if (Match(instr, m::AllReduce(m::Parameter(1)))) {
c2 = instr;
}
}
ASSERT_TRUE(c1 != nullptr && c2 != nullptr);
EXPECT_TRUE(absl::c_linear_search(c2->control_predecessors(), c1));
}
TEST_F(CollectivesScheduleLinearizerTest, NoFixRequired) {
absl::string_view hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT out = f32[] add(a, b)
}
ENTRY entry {
p0 = f32[100] parameter(0), parameter_replication={false}
p1 = f32[100] parameter(1), parameter_replication={false}
c1 = f32[100] all-reduce(p0), replica_groups={}, to_apply=sum
c2 = f32[100] all-reduce(p1), replica_groups={}, to_apply=sum, control-predecessors={c1}
ROOT out = f32[100] add(c1, c2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCollectivesSchedule(module.get());
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 1);
}
} // namespace
} // namespace xla

tensorflow/compiler/xla/service/gpu/BUILD

@@ -1282,6 +1282,7 @@ cc_library(
"//tensorflow/compiler/xla/service:batchnorm_expander",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:collectives_schedule_linearizer",
"//tensorflow/compiler/xla/service:comparison_expander",
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
"//tensorflow/compiler/xla/service:conditional_simplifier",

tensorflow/compiler/xla/service/gpu/gpu_compiler.cc

@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
@@ -376,6 +377,13 @@ Status GpuCompiler::OptimizeHloModule(
/*combine_threshold_count=*/256);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
{
HloPassPipeline pipeline("collectives_schedule_linearizer");
pipeline.AddPass<CollectivesScheduleLinearizer>();
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
{
// Now we allow replacing any transposes outside of fusions with bitcasts.
HloPassPipeline pipeline("final_algebraic_simplifier");