From 88b90a0a37a8df9dd4ae97bba537917f08da2b25 Mon Sep 17 00:00:00 2001
From: Blake Hechtman
Date: Fri, 29 Jan 2021 20:46:11 -0800
Subject: [PATCH] [XLA] Use array all to all in the client frontend

PiperOrigin-RevId: 354660949
Change-Id: I201e941172ab949d2632f83efc5cd80f4ebdddca
---
 tensorflow/compiler/xla/client/xla_builder.cc |  76 ++++++++++
 tensorflow/compiler/xla/client/xla_builder.h  |  18 +++
 .../compiler/xla/client/xla_builder_test.cc   |  19 ++-
 tensorflow/compiler/xla/service/BUILD         |  21 +++
 .../xla/service/all_to_all_decomposer.cc      | 131 ++++++++++++++++++
 .../xla/service/all_to_all_decomposer.h       |  49 +++++++
 tensorflow/compiler/xla/service/cpu/BUILD     |   1 +
 .../compiler/xla/service/cpu/cpu_compiler.cc  |   2 +
 tensorflow/compiler/xla/service/gpu/BUILD     |   1 +
 .../compiler/xla/service/gpu/gpu_compiler.cc  |  46 +++---
 .../compiler/xla/service/hlo_verifier.cc      |   3 +-
 11 files changed, 342 insertions(+), 25 deletions(-)
 create mode 100644 tensorflow/compiler/xla/service/all_to_all_decomposer.cc
 create mode 100644 tensorflow/compiler/xla/service/all_to_all_decomposer.h

diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 980200fcb15..7e62b237215 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@@ -2847,6 +2848,72 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
                            int64 concat_dimension, int64 split_count,
                            const std::vector<ReplicaGroup>& replica_groups,
                            const absl::optional<Layout>& layout) {
+  // An array all-to-all may need to violate the layout constraint to be
+  // legal, so use the tuple version whenever a layout is provided.
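+  // (The tuple version lowers to slice -> all-to-all -> get-tuple-element ->
+  // concatenate ops, which can carry an explicitly requested layout.)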
+  if (layout.has_value()) {
+    return AllToAllTuple(operand, split_dimension, concat_dimension,
+                         split_count, replica_groups, layout);
+  }
+  return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
+                       replica_groups);
+}
+
+XlaOp XlaBuilder::AllToAllArray(
+    XlaOp operand, int64 split_dimension, int64 concat_dimension,
+    int64 split_count, const std::vector<ReplicaGroup>& replica_groups) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
+    TF_ASSIGN_OR_RETURN(
+        const Shape all_to_all_shape,
+        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
+                                           concat_dimension, split_count));
+    HloInstructionProto instr;
+    *instr.mutable_shape() = operand_shape->ToProto();
+    if (replica_groups.empty()) {
+      auto* group = instr.add_replica_groups();
+      for (int64 i = 0; i < split_count; ++i) {
+        group->add_replica_ids(i);
+      }
+    } else {
+      for (const ReplicaGroup& group : replica_groups) {
+        *instr.add_replica_groups() = group;
+      }
+    }
+    instr.add_dimensions(split_dimension);
+    TF_ASSIGN_OR_RETURN(
+        XlaOp all_to_all,
+        AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
+    if (split_dimension == concat_dimension) {
+      return all_to_all;
+    }
+    DimensionVector sizes;
+    for (int64 i = 0; i < operand_shape->rank(); ++i) {
+      if (i != split_dimension) {
+        sizes.push_back(operand_shape->dimensions(i));
+        continue;
+      }
+      sizes.push_back(split_count);
+      sizes.push_back(operand_shape->dimensions(i) / split_count);
+    }
+    all_to_all = Reshape(all_to_all, sizes);
+
+    std::vector<int64> permutation;
+    for (int64 i = 0; i < operand_shape->rank(); ++i) {
+      int64 dim_after_reshape = i >= split_dimension ? i + 1 : i;
+      if (i == concat_dimension) {
+        permutation.push_back(split_dimension);
+      }
+      permutation.push_back(dim_after_reshape);
+    }
+    all_to_all = Transpose(all_to_all, permutation);
+    return Reshape(all_to_all_shape, all_to_all);
+  });
+}
+
+XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64 split_dimension,
+                                int64 concat_dimension, int64 split_count,
+                                const std::vector<ReplicaGroup>& replica_groups,
+                                const absl::optional<Layout>& layout) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
 
@@ -4581,6 +4648,15 @@ XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
                split_count, replica_groups, layout);
 }
 
+XlaOp AllToAllTuple(const XlaOp operand, int64 split_dimension,
+                    int64 concat_dimension, int64 split_count,
+                    const std::vector<ReplicaGroup>& replica_groups,
+                    const absl::optional<Layout>& layout) {
+  return operand.builder()->AllToAllTuple(operand, split_dimension,
+                                          concat_dimension, split_count,
+                                          replica_groups, layout);
+}
+
 XlaOp CollectivePermute(
     const XlaOp operand,
     const std::vector<std::pair<int64, int64>>& source_target_pairs) {
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index c23704b04fe..a459a3616f2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -745,6 +745,11 @@ class XlaBuilder {
                  const std::vector<ReplicaGroup>& replica_groups,
                  const absl::optional<Layout>& layout = absl::nullopt);
 
+  XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
+                      int64 concat_dimension, int64 split_count,
+                      const std::vector<ReplicaGroup>& replica_groups,
+                      const absl::optional<Layout>& layout);
+
   XlaOp CollectivePermute(
       XlaOp operand,
       const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1297,6 +1302,10 @@ class XlaBuilder {
                          int64 concat_dimension, int64 split_count,
                          const std::vector<ReplicaGroup>& replica_groups,
                          const absl::optional<Layout>& layout);
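+  // AllToAllTuple keeps the original tuple-shaped lowering reachable for
+  // callers that must pin an explicit layout.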
+  friend XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
+                             int64 concat_dimension, int64 split_count,
+                             const std::vector<ReplicaGroup>& replica_groups,
+                             const absl::optional<Layout>& layout);
   friend XlaOp CollectivePermute(
       XlaOp operand,
       const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1425,6 +1434,10 @@ class XlaBuilder {
       absl::Span<const XlaComputation* const> branch_computations,
       absl::Span<const XlaOp> branch_operands);
 
+  XlaOp AllToAllArray(XlaOp operand, int64 split_dimension,
+                      int64 concat_dimension, int64 split_count,
+                      const std::vector<ReplicaGroup>& replica_groups);
+
   // Creates an op with the given opcode and the output shape.
   virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                          absl::Span<const XlaOp> operands);
@@ -2203,6 +2216,11 @@ XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
               const std::vector<ReplicaGroup>& replica_groups = {},
               const absl::optional<Layout>& layout = absl::nullopt);
 
+XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
+                    int64 concat_dimension, int64 split_count,
+                    const std::vector<ReplicaGroup>& replica_groups = {},
+                    const absl::optional<Layout>& layout = absl::nullopt);
+
 // Enqueues a collective operation that sends and receives data across
 // replicas.
 //
 // - `source_target_pair`: a list of (source_replica_id, target_replica_id)
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index cd21c6dc414..08f4f2b2456 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -414,12 +414,27 @@ TEST_F(XlaBuilderTest, AllToAll) {
   auto root = module->entry_computation()->root_instruction();
 
   // AllToAll is decomposed into slices -> all-to-all -> gte -> concat.
-  EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
-  EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll);
+  EXPECT_EQ(root->opcode(), HloOpcode::kReshape);
+  EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->opcode(),
+            HloOpcode::kAllToAll);
   EXPECT_TRUE(
       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
 }
 
+TEST_F(XlaBuilderTest, AllToAllSpecial) {
+  XlaBuilder b(TestName());
+  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16, 8}), "x");
+  AllToAll(x, /*split_dimension=*/0, /*concat_dimension=*/0,
+           /*split_count=*/2);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+
+  // AllToAll is converted into a single all-to-all HloInstruction.
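+  // (Here split_dimension == concat_dimension, so the array lowering needs
+  // no surrounding reshape or transpose.)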
+  EXPECT_EQ(root->opcode(), HloOpcode::kAllToAll);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 16, 8})));
+}
+
 TEST_F(XlaBuilderTest, CollectivePermute) {
   XlaBuilder b(TestName());
   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index a9cb87d44bd..fc57210359f 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2649,6 +2649,27 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "all_to_all_decomposer",
+    srcs = ["all_to_all_decomposer.cc"],
+    hdrs = ["all_to_all_decomposer.h"],
+    deps = [
+        ":hlo",
+        ":hlo_casting_utils",
+        ":hlo_pass",
+        ":op_expander_pass",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
 cc_library(
     name = "all_gather_decomposer",
     srcs = ["all_gather_decomposer.cc"],
diff --git a/tensorflow/compiler/xla/service/all_to_all_decomposer.cc b/tensorflow/compiler/xla/service/all_to_all_decomposer.cc
new file mode 100644
index 00000000000..adf05ddb19a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/all_to_all_decomposer.cc
@@ -0,0 +1,131 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
+
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+bool AllToAllDecomposer::InstructionMatchesPattern(
+    HloInstruction* instruction) {
+  auto* all_to_all = DynCast<HloAllToAllInstruction>(instruction);
+  if (all_to_all == nullptr) {
+    return false;
+  }
+  // Do not attempt to change layout-constrained collectives.
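+  // (The client already emits layout-constrained all-to-alls in tuple form,
+  // and rewriting them here could invalidate the chosen layout.)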
+  if (all_to_all->constrain_layout()) {
+    return false;
+  }
+  if (all_to_all->shape().IsTuple()) {
+    return false;
+  }
+  if (decompose_to_tuple_) {
+    return true;
+  }
+  return all_to_all->shape().rank() < min_array_rank_;
+}
+StatusOr<HloInstruction*> AllToAllDecomposer::ExpandInstruction(
+    HloInstruction* instruction) {
+  auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
+  int64 split_dim = *all_to_all->split_dimension();
+  int64 all_to_all_group_size =
+      all_to_all->replica_groups().empty()
+          ? instruction->parent()->parent()->config().replica_count()
+          : all_to_all->replica_groups()[0].replica_ids_size();
+  int64 split_size =
+      all_to_all->shape().dimensions(split_dim) / all_to_all_group_size;
+  if (!decompose_to_tuple_) {
+    Shape new_all_to_all_shape;
+    new_all_to_all_shape.set_element_type(
+        instruction->operand(0)->shape().element_type());
+    for (int64 i = 0; i < instruction->shape().rank(); ++i) {
+      if (i != split_dim) {
+        new_all_to_all_shape.add_dimensions(all_to_all->shape().dimensions(i));
+        continue;
+      }
+      new_all_to_all_shape.add_dimensions(all_to_all_group_size);
+      new_all_to_all_shape.add_dimensions(split_size);
+      for (int64 j = all_to_all->shape().rank() + 1; j < min_array_rank_; ++j) {
+        new_all_to_all_shape.add_dimensions(1);
+      }
+    }
+    *(new_all_to_all_shape.mutable_layout()) =
+        LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);
+    HloInstruction* operand_reshape =
+        instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
+            new_all_to_all_shape, instruction->mutable_operand(0)));
+    instruction->SetupDerivedInstruction(operand_reshape);
+    HloInstruction* all_to_all =
+        instruction->parent()->AddInstruction(instruction->CloneWithNewOperands(
+            new_all_to_all_shape, {operand_reshape}));
+    HloInstruction* output_reshape = instruction->parent()->AddInstruction(
+        HloInstruction::CreateReshape(instruction->shape(), all_to_all));
+    instruction->SetupDerivedInstruction(output_reshape);
+    return output_reshape;
+  }
+  DimensionVector slice_starts(all_to_all->shape().rank(), 0);
+  DimensionVector slice_strides(all_to_all->shape().rank(), 1);
+  DimensionVector slice_limits(all_to_all->shape().dimensions().begin(),
+                               all_to_all->shape().dimensions().end());
+  slice_limits[split_dim] = split_size;
+  Shape slice_shape = all_to_all->shape();
+  slice_shape.set_dimensions(split_dim, split_size);
+  std::vector<HloInstruction*> slices;
+  slices.reserve(all_to_all_group_size);
+  HloInstruction* operand = all_to_all->mutable_operand(0);
+  for (int64 i = 0; i < all_to_all_group_size; ++i) {
+    slices.push_back(
+        all_to_all->parent()->AddInstruction(HloInstruction::CreateSlice(
+            slice_shape, operand, slice_starts, slice_limits, slice_strides)));
+    all_to_all->SetupDerivedInstruction(slices.back());
+    slice_starts[split_dim] = slice_limits[split_dim];
+    slice_limits[split_dim] += split_size;
+  }
+  Shape all_to_all_shape = ShapeUtil::MakeTupleShape(
+      std::vector<Shape>(all_to_all_group_size, slice_shape));
+  HloInstruction* new_all_to_all =
+      all_to_all->parent()->AddInstruction(HloInstruction::CreateAllToAll(
+          all_to_all_shape, slices, all_to_all->replica_groups(), false,
+          all_to_all->channel_id(), absl::nullopt));
+  std::vector<HloInstruction*> gtes;
+  gtes.reserve(all_to_all_group_size);
+  for (int64 i = 0; i < all_to_all_group_size; ++i) {
+    gtes.push_back(all_to_all->parent()->AddInstruction(
+        HloInstruction::CreateGetTupleElement(slice_shape, new_all_to_all, i)));
+    all_to_all->SetupDerivedInstruction(new_all_to_all);
+  }
+  HloInstruction* concat = all_to_all->parent()->AddInstruction(
+      HloInstruction::CreateConcatenate(all_to_all->shape(), gtes, split_dim));
+  all_to_all->SetupDerivedInstruction(concat);
+  return concat;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/all_to_all_decomposer.h b/tensorflow/compiler/xla/service/all_to_all_decomposer.h
new file mode 100644
index 00000000000..1d804c9cd5a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/all_to_all_decomposer.h
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/service/op_expander_pass.h"
+
+namespace xla {
+
+// AllToAllDecomposer is a pass that converts unsupported array all-to-alls
+// either into tuple all-to-alls or into array all-to-alls of at least
+// min_array_rank, where the split dimension is reshaped to the size of the
+// replica group.
+class AllToAllDecomposer : public OpExpanderPass {
+ public:
+  explicit AllToAllDecomposer(bool decompose_to_tuple = true,
+                              int64 min_array_rank = 0)
+      : decompose_to_tuple_(decompose_to_tuple),
+        min_array_rank_(min_array_rank) {}
+  ~AllToAllDecomposer() override = default;
+  absl::string_view name() const override { return "all_to_all_decomposer"; }
+
+ private:
+  bool InstructionMatchesPattern(HloInstruction* instruction) override;
+  StatusOr<HloInstruction*> ExpandInstruction(
+      HloInstruction* instruction) override;
+  bool decompose_to_tuple_;
+  int64 min_array_rank_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 8a7b57810d6..41e2a3e08d2 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -157,6 +157,7 @@ cc_library(
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:VectorOps",
         "//tensorflow/compiler/xla/service:all_gather_decomposer",
+        "//tensorflow/compiler/xla/service:all_to_all_decomposer",
         "//tensorflow/compiler/xla/service:copy_insertion",
         "//tensorflow/compiler/xla/service:dump",
         "//tensorflow/compiler/xla/service:topk_rewriter",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 34d3fda49b1..fd8dd6a067d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -57,6 +57,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/all_gather_decomposer.h" +#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h" #include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -292,6 +293,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); // Inline computations with a single call site. pipeline.AddPass(/*single_call_site=*/true); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a1dc98ece40..20961e29274 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1427,6 +1427,7 @@ cc_library( "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:all_gather_decomposer", "//tensorflow/compiler/xla/service:all_reduce_combiner", + "//tensorflow/compiler/xla/service:all_to_all_decomposer", "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:call_inliner", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 9de5a68e231..954131ebf29 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/all_gather_decomposer.h" #include "tensorflow/compiler/xla/service/all_reduce_combiner.h" +#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/call_inliner.h" @@ -157,6 +158,7 @@ Status GpuCompiler::OptimizeHloModule( [](const HloAllGatherInstruction& ag) { return !NcclAllGatherThunk::CanImplement(&ag); }); + pipeline.AddPass(); pipeline.AddPass(); @@ -804,30 +806,30 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config, llvm_modules.size()); tensorflow::BlockingCounter counter(llvm_modules.size()); for (int i = 0; i < llvm_modules.size(); i++) { - thread_pool->Schedule([&compile_results, compile_single_module, i, - &llvm_modules, &counter] { - llvm::Module* original_module = llvm_modules[i].get(); - llvm::LLVMContext context; - std::string buffer; - llvm::raw_string_ostream error(buffer); + thread_pool->Schedule( + [&compile_results, compile_single_module, i, &llvm_modules, &counter] { + llvm::Module* original_module = llvm_modules[i].get(); + llvm::LLVMContext context; + std::string buffer; + llvm::raw_string_ostream error(buffer); - std::unique_ptr new_llvm_module; - // Switch to a new context by dumping and re-parsing LLVM IR. Each thread - // has its own context to avoid race conditions. - { - std::string ir; - { - llvm::raw_string_ostream os(ir); - original_module->print(os, nullptr); - } - llvm::SMDiagnostic err; - new_llvm_module = llvm::parseAssemblyString(ir, err, context); - } + std::unique_ptr new_llvm_module; + // Switch to a new context by dumping and re-parsing LLVM IR. 
+          // thread has its own context to avoid race conditions.
+          {
+            std::string ir;
+            {
+              llvm::raw_string_ostream os(ir);
+              original_module->print(os, nullptr);
+            }
+            llvm::SMDiagnostic err;
+            new_llvm_module = llvm::parseAssemblyString(ir, err, context);
+          }
 
-      compile_results[i] = compile_single_module(
-          new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
-      counter.DecrementCount();
-    });
+          compile_results[i] = compile_single_module(
+              new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
+          counter.DecrementCount();
+        });
   }
   counter.Wait();
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 80d2cd3e7b7..2c8736ad42c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -229,7 +229,8 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
   }
 
   int64 replica_count = hlo->GetModule()->config().replica_count();
-  if (!replicas_seen.empty() && replicas_seen.size() != replica_count) {
+  if (replica_count != 1 && !replicas_seen.empty() &&
+      replicas_seen.size() != replica_count) {
     return InternalError(
         "Replica count in HloModuleConfig is %d, but ReplicaGroup config "
         "contains %d replicas: %s",
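
Usage sketch (illustrative only, not part of the commit): after this change,
an all-to-all built through the client without a layout stays a single array
all-to-all HLO, wrapped in reshapes and a transpose only when the split and
concat dimensions differ:

  XlaBuilder b("all_to_all_example");
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 8}), "x");
  // No layout argument, so the new AllToAllArray path is taken.
  AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0,
           /*split_count=*/4);

Passing a layout, or calling AllToAllTuple directly, still yields the
slice -> all-to-all -> get-tuple-element -> concatenate decomposition.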