With other internal changes, extend the TF XLA interface to support running post-optimizations HLO modules. Also fix a few bugs in XLA related to saving and loading HLO from protobuf.

PiperOrigin-RevId: 340702861
Change-Id: Ide5278d164cc09b5bf72d38e87b3a77bb07c85d1
This commit is contained in:
Jinliang Wei 2020-11-04 12:13:48 -08:00 committed by TensorFlower Gardener
parent cd0ef5f0a2
commit 7effe6da6f
9 changed files with 79 additions and 34 deletions

View File

@ -104,6 +104,17 @@ class ExecutableBuildOptions {
alias_passthrough_params_ = alias_passthrough_params;
}
bool run_backend_only() const { return run_backend_only_; }
// By default, XLA builds an executable by invoking standard compilation, i.e,
// running Compiler::Compile, or both Compiler::RunHloPasses and
// Compiler::RunBackend. When run_backend_only is set to true, XLA builds an
// executable by invoking only RunBackend and skip invoking RunHloPasses,
// which can be used to compile post-optimizations HLO modules.
ExecutableBuildOptions& set_run_backend_only(bool run_backend_only) {
run_backend_only_ = run_backend_only;
return *this;
}
private:
int device_ordinal_ = -1;
Shape result_layout_;
@ -116,6 +127,7 @@ class ExecutableBuildOptions {
bool deduplicate_hlo_ = false;
absl::optional<DeviceAssignment> device_assignment_;
bool alias_passthrough_params_ = false;
bool run_backend_only_ = false;
};
} // namespace xla

View File

@ -496,11 +496,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kReplicaId: {
instruction = CreateReplicaId();
instruction = CreateReplicaId(shape);
break;
}
case HloOpcode::kPartitionId: {
instruction = CreatePartitionId();
instruction = CreatePartitionId(shape);
break;
}
case HloOpcode::kConvolution: {
@ -1080,15 +1080,20 @@ HloInstruction::CreateCollectivePermuteStart(
channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
return absl::WrapUnique(
new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {})));
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
const Shape& shape) {
CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
<< "HloInstruction replica-id must have a shape of u32[], but "
<< shape.ToString() << " is specified";
return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreatePartitionId() {
return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId,
ShapeUtil::MakeShape(U32, {})));
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
const Shape& shape) {
CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
<< "HloInstruction partition-id must have a shape of u32[], but "
<< shape.ToString() << " is specified";
return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
@ -1799,13 +1804,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
break;
case HloOpcode::kReplicaId:
CHECK_EQ(new_operands.size(), 0);
clone = CreateReplicaId();
*clone->mutable_shape() = shape;
clone = CreateReplicaId(shape);
break;
case HloOpcode::kPartitionId:
CHECK_EQ(new_operands.size(), 0);
clone = CreatePartitionId();
*clone->mutable_shape() = shape;
clone = CreatePartitionId(shape);
break;
}
// SetupDerivedInstruction will setup the precision_config_ field.

View File

@ -706,10 +706,12 @@ class HloInstruction {
const absl::optional<int64>& channel_id);
// Creates an instruction that returns a U32 replica ID.
static std::unique_ptr<HloInstruction> CreateReplicaId();
static std::unique_ptr<HloInstruction> CreateReplicaId(
const Shape& shape = ShapeUtil::MakeShape(U32, {}));
// Creates an instruction that returns a U32 partition ID.
static std::unique_ptr<HloInstruction> CreatePartitionId();
static std::unique_ptr<HloInstruction> CreatePartitionId(
const Shape& shape = ShapeUtil::MakeShape(U32, {}));
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.

View File

@ -646,6 +646,7 @@ HloAllGatherInstruction::CloneWithNewOperandsImpl(
HloInstructionProto HloAllGatherInstruction::ToProto() const {
HloInstructionProto proto = HloCollectiveInstruction::ToProto();
proto.add_dimensions(all_gather_dimension_);
proto.set_use_global_device_ids(use_global_device_ids_);
return proto;
}

View File

@ -38,12 +38,14 @@ HloProto MakeHloProto(const HloModule& module) {
}
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) {
const HloModuleProto& proto, const HloModuleConfig& module_config,
bool is_module_post_optimizations) {
VLOG(4) << proto.ShortDebugString();
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(proto, module_config));
TF_RETURN_IF_ERROR(
HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
HloVerifier(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/is_module_post_optimizations)
.Run(module.get())
.status());
return std::move(module);

View File

@ -38,8 +38,12 @@ HloProto MakeHloProto(const HloModule& module);
// Create an HLO state from serialized representation. In addition to
// creating the proto with HloModule::CreateFromProto(...) it also
// uses HloVerifier to ensure basic invariants are held.
// The HLO module could be a pre-optimizations (default) or post-optimizations
// module, which affects how the HLO module is verified, e.g., mixed-precision
// is allowed in post-optimizations HLOs.
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config);
const HloModuleProto& proto, const HloModuleConfig& module_config,
bool is_module_post_optimizations = false);
// Returns the shapes of the parameters of the entry computation. Shape pointers
// refer to shapes inside of the given HloProto.

View File

@ -193,7 +193,8 @@ LocalService::CompileExecutables(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
executor, build_options.device_allocator()));
executor, build_options.device_allocator(),
build_options.run_backend_only()));
std::vector<std::unique_ptr<Executable>> executables;
executables.push_back(std::move(executable));
return executables;
@ -207,7 +208,8 @@ LocalService::CompileExecutables(
return BuildExecutables({&proto}, std::move(module_configs),
execute_backend_.get(), {executors},
build_options.device_allocator());
build_options.device_allocator(),
build_options.run_backend_only());
}
}

View File

@ -357,7 +357,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
const std::vector<const HloModuleProto*>& module_protos,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
se::DeviceMemoryAllocator* device_allocator) {
se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) {
VLOG(1) << StrFormat("BuildExecutable on service %p", this);
// Dump computation proto state if flag is set.
@ -379,15 +379,28 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
for (int64 i = 0, end = module_protos.size(); i < end; ++i) {
const HloModuleProto* proto = module_protos[i];
const HloModuleConfig& config = *module_configs[i];
TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
TF_ASSIGN_OR_RETURN(
auto module, CreateModuleFromProto(*proto, config, run_backend_only));
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
module_group->push_back(std::move(module));
}
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<Executable>> executables,
backend->compiler()->Compile(std::move(module_group),
std::move(executors), device_allocator));
std::vector<std::unique_ptr<Executable>> executables;
if (!run_backend_only) {
TF_ASSIGN_OR_RETURN(
executables,
backend->compiler()->Compile(std::move(module_group),
std::move(executors), device_allocator));
} else {
auto modules = module_group->ConsumeModules();
for (std::unique_ptr<HloModule>& module : modules) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(std::move(module), executors[0][0],
device_allocator));
executables.push_back(std::move(executable));
}
}
for (size_t i = 0; i < module_protos.size(); ++i) {
const auto& debug_opts = module_configs[i]->debug_options();
@ -797,18 +810,22 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) {
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator,
bool run_backend_only) {
VLOG(1) << StrFormat(
"BuildExecutable on service %p with serialized module proto: %s", this,
module_proto.name());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
CreateModuleFromProto(module_proto, *module_config));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> module,
CreateModuleFromProto(module_proto, *module_config, run_backend_only));
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
if (!run_backend_only) {
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(

View File

@ -236,7 +236,8 @@ class Service : public ServiceInterface {
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator = nullptr);
se::DeviceMemoryAllocator* device_allocator = nullptr,
bool run_backend_only = false);
// Same as BuildExecutable() above, but builds a list of Executables for the
// given computations that may interact with each other.
@ -244,7 +245,8 @@ class Service : public ServiceInterface {
const std::vector<const HloModuleProto*>& module_protos,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
se::DeviceMemoryAllocator* device_allocator);
se::DeviceMemoryAllocator* device_allocator,
bool run_backend_only = false);
// Runs the given executable with the given arguments and register the result
// in the allocation tracker. The handle of the result from the tracker is