With other internal changes, extend the TF XLA interface to support running post-optimizations HLO modules. Also fix a few bugs in XLA related to saving and loading HLO from protobuf.

PiperOrigin-RevId: 340702861
Change-Id: Ide5278d164cc09b5bf72d38e87b3a77bb07c85d1
This commit is contained in:
Jinliang Wei 2020-11-04 12:13:48 -08:00 committed by TensorFlower Gardener
parent cd0ef5f0a2
commit 7effe6da6f
9 changed files with 79 additions and 34 deletions

View File

@ -104,6 +104,17 @@ class ExecutableBuildOptions {
alias_passthrough_params_ = alias_passthrough_params; alias_passthrough_params_ = alias_passthrough_params;
} }
// Returns the current run_backend_only setting (false unless changed via
// set_run_backend_only).
bool run_backend_only() const { return run_backend_only_; }
// By default, XLA builds an executable by invoking standard compilation, i.e.,
// running Compiler::Compile, or both Compiler::RunHloPasses and
// Compiler::RunBackend. When run_backend_only is set to true, XLA builds an
// executable by invoking only RunBackend and skips RunHloPasses, which can be
// used to compile post-optimizations HLO modules.
// Returns *this so that setter calls can be chained.
ExecutableBuildOptions& set_run_backend_only(bool run_backend_only) {
run_backend_only_ = run_backend_only;
return *this;
}
private: private:
int device_ordinal_ = -1; int device_ordinal_ = -1;
Shape result_layout_; Shape result_layout_;
@ -116,6 +127,7 @@ class ExecutableBuildOptions {
bool deduplicate_hlo_ = false; bool deduplicate_hlo_ = false;
absl::optional<DeviceAssignment> device_assignment_; absl::optional<DeviceAssignment> device_assignment_;
bool alias_passthrough_params_ = false; bool alias_passthrough_params_ = false;
bool run_backend_only_ = false;
}; };
} // namespace xla } // namespace xla

View File

@ -496,11 +496,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break; break;
} }
case HloOpcode::kReplicaId: { case HloOpcode::kReplicaId: {
instruction = CreateReplicaId(); instruction = CreateReplicaId(shape);
break; break;
} }
case HloOpcode::kPartitionId: { case HloOpcode::kPartitionId: {
instruction = CreatePartitionId(); instruction = CreatePartitionId(shape);
break; break;
} }
case HloOpcode::kConvolution: { case HloOpcode::kConvolution: {
@ -1080,15 +1080,20 @@ HloInstruction::CreateCollectivePermuteStart(
channel_id); channel_id);
} }
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() { /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
return absl::WrapUnique( const Shape& shape) {
new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {}))); CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
<< "HloInstruction replica-id must have a shape of u32[], but "
<< shape.ToString() << " is specified";
return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
} }
/* static */ std::unique_ptr<HloInstruction> /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
HloInstruction::CreatePartitionId() { const Shape& shape) {
return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
ShapeUtil::MakeShape(U32, {}))); << "HloInstruction partition-id must have a shape of u32[], but "
<< shape.ToString() << " is specified";
return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
} }
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
@ -1799,13 +1804,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
break; break;
case HloOpcode::kReplicaId: case HloOpcode::kReplicaId:
CHECK_EQ(new_operands.size(), 0); CHECK_EQ(new_operands.size(), 0);
clone = CreateReplicaId(); clone = CreateReplicaId(shape);
*clone->mutable_shape() = shape;
break; break;
case HloOpcode::kPartitionId: case HloOpcode::kPartitionId:
CHECK_EQ(new_operands.size(), 0); CHECK_EQ(new_operands.size(), 0);
clone = CreatePartitionId(); clone = CreatePartitionId(shape);
*clone->mutable_shape() = shape;
break; break;
} }
// SetupDerivedInstruction will setup the precision_config_ field. // SetupDerivedInstruction will setup the precision_config_ field.

View File

@ -706,10 +706,12 @@ class HloInstruction {
const absl::optional<int64>& channel_id); const absl::optional<int64>& channel_id);
// Creates an instruction that returns a U32 replica ID. // Creates an instruction that returns a U32 replica ID.
static std::unique_ptr<HloInstruction> CreateReplicaId(); static std::unique_ptr<HloInstruction> CreateReplicaId(
const Shape& shape = ShapeUtil::MakeShape(U32, {}));
// Creates an instruction that returns a U32 partition ID. // Creates an instruction that returns a U32 partition ID.
static std::unique_ptr<HloInstruction> CreatePartitionId(); static std::unique_ptr<HloInstruction> CreatePartitionId(
const Shape& shape = ShapeUtil::MakeShape(U32, {}));
// Creates a conversion instruction, where operand is the data to convert and // Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion. // shape is the target shape for the conversion.

View File

@ -646,6 +646,7 @@ HloAllGatherInstruction::CloneWithNewOperandsImpl(
HloInstructionProto HloAllGatherInstruction::ToProto() const { HloInstructionProto HloAllGatherInstruction::ToProto() const {
HloInstructionProto proto = HloCollectiveInstruction::ToProto(); HloInstructionProto proto = HloCollectiveInstruction::ToProto();
proto.add_dimensions(all_gather_dimension_); proto.add_dimensions(all_gather_dimension_);
proto.set_use_global_device_ids(use_global_device_ids_);
return proto; return proto;
} }

View File

@ -38,12 +38,14 @@ HloProto MakeHloProto(const HloModule& module) {
} }
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto( StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) { const HloModuleProto& proto, const HloModuleConfig& module_config,
bool is_module_post_optimizations) {
VLOG(4) << proto.ShortDebugString(); VLOG(4) << proto.ShortDebugString();
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(proto, module_config)); HloModule::CreateFromProto(proto, module_config));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) HloVerifier(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/is_module_post_optimizations)
.Run(module.get()) .Run(module.get())
.status()); .status());
return std::move(module); return std::move(module);

View File

@ -38,8 +38,12 @@ HloProto MakeHloProto(const HloModule& module);
// Create an HLO state from serialized representation. In addition to // Create an HLO state from serialized representation. In addition to
// creating the proto with HloModule::CreateFromProto(...) it also // creating the proto with HloModule::CreateFromProto(...) it also
// uses HloVerifier to ensure basic invariants are held. // uses HloVerifier to ensure basic invariants are held.
// The HLO module could be a pre-optimizations (default) or post-optimizations
// module, which affects how the HLO module is verified, e.g., mixed-precision
// is allowed in post-optimizations HLOs.
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto( StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config); const HloModuleProto& proto, const HloModuleConfig& module_config,
bool is_module_post_optimizations = false);
// Returns the shapes of the parameters of the entry computation. Shape pointers // Returns the shapes of the parameters of the entry computation. Shape pointers
// refer to shapes inside of the given HloProto. // refer to shapes inside of the given HloProto.

View File

@ -193,7 +193,8 @@ LocalService::CompileExecutables(
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable, std::unique_ptr<Executable> executable,
BuildExecutable(proto, std::move(module_config), execute_backend_.get(), BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
executor, build_options.device_allocator())); executor, build_options.device_allocator(),
build_options.run_backend_only()));
std::vector<std::unique_ptr<Executable>> executables; std::vector<std::unique_ptr<Executable>> executables;
executables.push_back(std::move(executable)); executables.push_back(std::move(executable));
return executables; return executables;
@ -207,7 +208,8 @@ LocalService::CompileExecutables(
return BuildExecutables({&proto}, std::move(module_configs), return BuildExecutables({&proto}, std::move(module_configs),
execute_backend_.get(), {executors}, execute_backend_.get(), {executors},
build_options.device_allocator()); build_options.device_allocator(),
build_options.run_backend_only());
} }
} }

View File

@ -357,7 +357,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
const std::vector<const HloModuleProto*>& module_protos, const std::vector<const HloModuleProto*>& module_protos,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs, std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors, Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
se::DeviceMemoryAllocator* device_allocator) { se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) {
VLOG(1) << StrFormat("BuildExecutable on service %p", this); VLOG(1) << StrFormat("BuildExecutable on service %p", this);
// Dump computation proto state if flag is set. // Dump computation proto state if flag is set.
@ -379,15 +379,28 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
for (int64 i = 0, end = module_protos.size(); i < end; ++i) { for (int64 i = 0, end = module_protos.size(); i < end; ++i) {
const HloModuleProto* proto = module_protos[i]; const HloModuleProto* proto = module_protos[i];
const HloModuleConfig& config = *module_configs[i]; const HloModuleConfig& config = *module_configs[i];
TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); TF_ASSIGN_OR_RETURN(
auto module, CreateModuleFromProto(*proto, config, run_backend_only));
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
module_group->push_back(std::move(module)); module_group->push_back(std::move(module));
} }
TF_ASSIGN_OR_RETURN( std::vector<std::unique_ptr<Executable>> executables;
std::vector<std::unique_ptr<Executable>> executables, if (!run_backend_only) {
backend->compiler()->Compile(std::move(module_group), TF_ASSIGN_OR_RETURN(
std::move(executors), device_allocator)); executables,
backend->compiler()->Compile(std::move(module_group),
std::move(executors), device_allocator));
} else {
auto modules = module_group->ConsumeModules();
for (std::unique_ptr<HloModule>& module : modules) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(std::move(module), executors[0][0],
device_allocator));
executables.push_back(std::move(executable));
}
}
for (size_t i = 0; i < module_protos.size(); ++i) { for (size_t i = 0; i < module_protos.size(); ++i) {
const auto& debug_opts = module_configs[i]->debug_options(); const auto& debug_opts = module_configs[i]->debug_options();
@ -797,18 +810,22 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const HloModuleProto& module_proto, const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend, std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) { se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator,
bool run_backend_only) {
VLOG(1) << StrFormat( VLOG(1) << StrFormat(
"BuildExecutable on service %p with serialized module proto: %s", this, "BuildExecutable on service %p with serialized module proto: %s", this,
module_proto.name()); module_proto.name());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, TF_ASSIGN_OR_RETURN(
CreateModuleFromProto(module_proto, *module_config)); std::unique_ptr<HloModule> module,
CreateModuleFromProto(module_proto, *module_config, run_backend_only));
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
TF_ASSIGN_OR_RETURN( if (!run_backend_only) {
module, backend->compiler()->RunHloPasses(std::move(module), executor, TF_ASSIGN_OR_RETURN(
device_allocator)); module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable, TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend( backend->compiler()->RunBackend(

View File

@ -236,7 +236,8 @@ class Service : public ServiceInterface {
const HloModuleProto& module_proto, const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend, std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator = nullptr); se::DeviceMemoryAllocator* device_allocator = nullptr,
bool run_backend_only = false);
// Same as BuildExecutable() above, but builds a list of Executables for the // Same as BuildExecutable() above, but builds a list of Executables for the
// given computations that may interact with each other. // given computations that may interact with each other.
@ -244,7 +245,8 @@ class Service : public ServiceInterface {
const std::vector<const HloModuleProto*>& module_protos, const std::vector<const HloModuleProto*>& module_protos,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs, std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors, Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
se::DeviceMemoryAllocator* device_allocator); se::DeviceMemoryAllocator* device_allocator,
bool run_backend_only = false);
// Runs the given executable with the given arguments and register the result // Runs the given executable with the given arguments and register the result
// in the allocation tracker. The handle of the result from the tracker is // in the allocation tracker. The handle of the result from the tracker is