From 7effe6da6f1a8857059d3e36ddde5b2f0bfbe8a0 Mon Sep 17 00:00:00 2001 From: Jinliang Wei Date: Wed, 4 Nov 2020 12:13:48 -0800 Subject: [PATCH] With other internal changes, extend the TF XLA interface to support running post-optimizations HLO modules. Also fix a few bugs in XLA related to saving and loading HLO from protobuf. PiperOrigin-RevId: 340702861 Change-Id: Ide5278d164cc09b5bf72d38e87b3a77bb07c85d1 --- .../xla/client/executable_build_options.h | 12 ++++++ .../compiler/xla/service/hlo_instruction.cc | 29 +++++++------ .../compiler/xla/service/hlo_instruction.h | 6 ++- .../compiler/xla/service/hlo_instructions.cc | 1 + .../compiler/xla/service/hlo_proto_util.cc | 6 ++- .../compiler/xla/service/hlo_proto_util.h | 6 ++- .../compiler/xla/service/local_service.cc | 6 ++- tensorflow/compiler/xla/service/service.cc | 41 +++++++++++++------ tensorflow/compiler/xla/service/service.h | 6 ++- 9 files changed, 79 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index d034eaa7fd6..d3f5dd3e662 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -104,6 +104,17 @@ class ExecutableBuildOptions { alias_passthrough_params_ = alias_passthrough_params; } + bool run_backend_only() const { return run_backend_only_; } + // By default, XLA builds an executable by invoking standard compilation, i.e, + // running Compiler::Compile, or both Compiler::RunHloPasses and + // Compiler::RunBackend. When run_backend_only is set to true, XLA builds an + // executable by invoking only RunBackend and skip invoking RunHloPasses, + // which can be used to compile post-optimizations HLO modules. + ExecutableBuildOptions& set_run_backend_only(bool run_backend_only) { + run_backend_only_ = run_backend_only; + return *this; + } + private: int device_ordinal_ = -1; Shape result_layout_; @@ -116,6 +127,7 @@ class ExecutableBuildOptions { bool deduplicate_hlo_ = false; absl::optional device_assignment_; bool alias_passthrough_params_ = false; + bool run_backend_only_ = false; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 149bf0158c9..9b3bbbff857 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -496,11 +496,11 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kReplicaId: { - instruction = CreateReplicaId(); + instruction = CreateReplicaId(shape); break; } case HloOpcode::kPartitionId: { - instruction = CreatePartitionId(); + instruction = CreatePartitionId(shape); break; } case HloOpcode::kConvolution: { @@ -1080,15 +1080,20 @@ HloInstruction::CreateCollectivePermuteStart( channel_id); } -/* static */ std::unique_ptr HloInstruction::CreateReplicaId() { - return absl::WrapUnique( - new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {}))); +/* static */ std::unique_ptr HloInstruction::CreateReplicaId( + const Shape& shape) { + CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {}))) + << "HloInstruction replica-id must have a shape of u32[], but " + << shape.ToString() << " is specified"; + return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape)); } -/* static */ std::unique_ptr -HloInstruction::CreatePartitionId() { - return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, - ShapeUtil::MakeShape(U32, {}))); +/* static */ std::unique_ptr HloInstruction::CreatePartitionId( + const Shape& shape) { + CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {}))) + << "HloInstruction partition-id must have a shape of u32[], but " + << shape.ToString() << " is specified"; + return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape)); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -1799,13 +1804,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kReplicaId: CHECK_EQ(new_operands.size(), 0); - clone = CreateReplicaId(); - *clone->mutable_shape() = shape; + clone = CreateReplicaId(shape); break; case HloOpcode::kPartitionId: CHECK_EQ(new_operands.size(), 0); - clone = CreatePartitionId(); - *clone->mutable_shape() = shape; + clone = CreatePartitionId(shape); break; } // SetupDerivedInstruction will setup the precision_config_ field. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 1dcaeb4e114..012b7d428f7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -706,10 +706,12 @@ class HloInstruction { const absl::optional& channel_id); // Creates an instruction that returns a U32 replica ID. - static std::unique_ptr CreateReplicaId(); + static std::unique_ptr CreateReplicaId( + const Shape& shape = ShapeUtil::MakeShape(U32, {})); // Creates an instruction that returns a U32 partition ID. - static std::unique_ptr CreatePartitionId(); + static std::unique_ptr CreatePartitionId( + const Shape& shape = ShapeUtil::MakeShape(U32, {})); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 7b84e6e0700..f2ea03f063a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -646,6 +646,7 @@ HloAllGatherInstruction::CloneWithNewOperandsImpl( HloInstructionProto HloAllGatherInstruction::ToProto() const { HloInstructionProto proto = HloCollectiveInstruction::ToProto(); proto.add_dimensions(all_gather_dimension_); + proto.set_use_global_device_ids(use_global_device_ids_); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 3a9ee57e555..803fad57bf2 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -38,12 +38,14 @@ HloProto MakeHloProto(const HloModule& module) { } StatusOr> CreateModuleFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config) { + const HloModuleProto& proto, const HloModuleConfig& module_config, + bool is_module_post_optimizations) { VLOG(4) << proto.ShortDebugString(); TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(proto, module_config)); TF_RETURN_IF_ERROR( - HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + HloVerifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/is_module_post_optimizations) .Run(module.get()) .status()); return std::move(module); diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 31ea2aaffd9..cc01306a19e 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -38,8 +38,12 @@ HloProto MakeHloProto(const HloModule& module); // Create an HLO state from serialized representation. In addition to // creating the proto with HloModule::CreateFromProto(...) it also // uses HloVerifier to ensure basic invariants are held. +// The HLO module could be a pre-optimizations (default) or post-optimizations +// module, which affects how the HLO module is verified, e.g., mixed-precision +// is allowed in post-optimizations HLOs. StatusOr> CreateModuleFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config); + const HloModuleProto& proto, const HloModuleConfig& module_config, + bool is_module_post_optimizations = false); // Returns the shapes of the parameters of the entry computation. Shape pointers // refer to shapes inside of the given HloProto. diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 5def5bbe9db..0eff81c9a0d 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -193,7 +193,8 @@ LocalService::CompileExecutables( TF_ASSIGN_OR_RETURN( std::unique_ptr executable, BuildExecutable(proto, std::move(module_config), execute_backend_.get(), - executor, build_options.device_allocator())); + executor, build_options.device_allocator(), + build_options.run_backend_only())); std::vector> executables; executables.push_back(std::move(executable)); return executables; @@ -207,7 +208,8 @@ LocalService::CompileExecutables( return BuildExecutables({&proto}, std::move(module_configs), execute_backend_.get(), {executors}, - build_options.device_allocator()); + build_options.device_allocator(), + build_options.run_backend_only()); } } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index e72bce71a8e..a6d23c18797 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -357,7 +357,7 @@ StatusOr>> Service::BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, - se::DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) { VLOG(1) << StrFormat("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. @@ -379,15 +379,28 @@ StatusOr>> Service::BuildExecutables( for (int64 i = 0, end = module_protos.size(); i < end; ++i) { const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; - TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); + TF_ASSIGN_OR_RETURN( + auto module, CreateModuleFromProto(*proto, config, run_backend_only)); DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); module_group->push_back(std::move(module)); } - TF_ASSIGN_OR_RETURN( - std::vector> executables, - backend->compiler()->Compile(std::move(module_group), - std::move(executors), device_allocator)); + std::vector> executables; + if (!run_backend_only) { + TF_ASSIGN_OR_RETURN( + executables, + backend->compiler()->Compile(std::move(module_group), + std::move(executors), device_allocator)); + } else { + auto modules = module_group->ConsumeModules(); + for (std::unique_ptr& module : modules) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + backend->compiler()->RunBackend(std::move(module), executors[0][0], + device_allocator)); + executables.push_back(std::move(executable)); + } + } for (size_t i = 0; i < module_protos.size(); ++i) { const auto& debug_opts = module_configs[i]->debug_options(); @@ -797,18 +810,22 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) { + se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator, + bool run_backend_only) { VLOG(1) << StrFormat( "BuildExecutable on service %p with serialized module proto: %s", this, module_proto.name()); - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - CreateModuleFromProto(module_proto, *module_config)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + CreateModuleFromProto(module_proto, *module_config, run_backend_only)); DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); - TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor, - device_allocator)); + if (!run_backend_only) { + TF_ASSIGN_OR_RETURN( + module, backend->compiler()->RunHloPasses(std::move(module), executor, + device_allocator)); + } TF_ASSIGN_OR_RETURN(std::unique_ptr executable, backend->compiler()->RunBackend( diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index d58020655de..712ccc44d91 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -236,7 +236,8 @@ class Service : public ServiceInterface { const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator = nullptr); + se::DeviceMemoryAllocator* device_allocator = nullptr, + bool run_backend_only = false); // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. @@ -244,7 +245,8 @@ class Service : public ServiceInterface { const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, - se::DeviceMemoryAllocator* device_allocator); + se::DeviceMemoryAllocator* device_allocator, + bool run_backend_only = false); // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is