With other internal changes, extend the TF XLA interface to support running post-optimizations HLO modules. Also fix a few bugs in XLA related to saving and loading HLO from protobuf.
PiperOrigin-RevId: 340702861 Change-Id: Ide5278d164cc09b5bf72d38e87b3a77bb07c85d1
This commit is contained in:
parent
cd0ef5f0a2
commit
7effe6da6f
@ -104,6 +104,17 @@ class ExecutableBuildOptions {
|
|||||||
alias_passthrough_params_ = alias_passthrough_params;
|
alias_passthrough_params_ = alias_passthrough_params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool run_backend_only() const { return run_backend_only_; }
|
||||||
|
// By default, XLA builds an executable by invoking standard compilation, i.e,
|
||||||
|
// running Compiler::Compile, or both Compiler::RunHloPasses and
|
||||||
|
// Compiler::RunBackend. When run_backend_only is set to true, XLA builds an
|
||||||
|
// executable by invoking only RunBackend and skip invoking RunHloPasses,
|
||||||
|
// which can be used to compile post-optimizations HLO modules.
|
||||||
|
ExecutableBuildOptions& set_run_backend_only(bool run_backend_only) {
|
||||||
|
run_backend_only_ = run_backend_only;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int device_ordinal_ = -1;
|
int device_ordinal_ = -1;
|
||||||
Shape result_layout_;
|
Shape result_layout_;
|
||||||
@ -116,6 +127,7 @@ class ExecutableBuildOptions {
|
|||||||
bool deduplicate_hlo_ = false;
|
bool deduplicate_hlo_ = false;
|
||||||
absl::optional<DeviceAssignment> device_assignment_;
|
absl::optional<DeviceAssignment> device_assignment_;
|
||||||
bool alias_passthrough_params_ = false;
|
bool alias_passthrough_params_ = false;
|
||||||
|
bool run_backend_only_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -496,11 +496,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case HloOpcode::kReplicaId: {
|
case HloOpcode::kReplicaId: {
|
||||||
instruction = CreateReplicaId();
|
instruction = CreateReplicaId(shape);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case HloOpcode::kPartitionId: {
|
case HloOpcode::kPartitionId: {
|
||||||
instruction = CreatePartitionId();
|
instruction = CreatePartitionId(shape);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case HloOpcode::kConvolution: {
|
case HloOpcode::kConvolution: {
|
||||||
@ -1080,15 +1080,20 @@ HloInstruction::CreateCollectivePermuteStart(
|
|||||||
channel_id);
|
channel_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
|
||||||
return absl::WrapUnique(
|
const Shape& shape) {
|
||||||
new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {})));
|
CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
|
||||||
|
<< "HloInstruction replica-id must have a shape of u32[], but "
|
||||||
|
<< shape.ToString() << " is specified";
|
||||||
|
return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction>
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
|
||||||
HloInstruction::CreatePartitionId() {
|
const Shape& shape) {
|
||||||
return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId,
|
CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
|
||||||
ShapeUtil::MakeShape(U32, {})));
|
<< "HloInstruction partition-id must have a shape of u32[], but "
|
||||||
|
<< shape.ToString() << " is specified";
|
||||||
|
return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
|
||||||
@ -1799,13 +1804,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|||||||
break;
|
break;
|
||||||
case HloOpcode::kReplicaId:
|
case HloOpcode::kReplicaId:
|
||||||
CHECK_EQ(new_operands.size(), 0);
|
CHECK_EQ(new_operands.size(), 0);
|
||||||
clone = CreateReplicaId();
|
clone = CreateReplicaId(shape);
|
||||||
*clone->mutable_shape() = shape;
|
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kPartitionId:
|
case HloOpcode::kPartitionId:
|
||||||
CHECK_EQ(new_operands.size(), 0);
|
CHECK_EQ(new_operands.size(), 0);
|
||||||
clone = CreatePartitionId();
|
clone = CreatePartitionId(shape);
|
||||||
*clone->mutable_shape() = shape;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// SetupDerivedInstruction will setup the precision_config_ field.
|
// SetupDerivedInstruction will setup the precision_config_ field.
|
||||||
|
@ -706,10 +706,12 @@ class HloInstruction {
|
|||||||
const absl::optional<int64>& channel_id);
|
const absl::optional<int64>& channel_id);
|
||||||
|
|
||||||
// Creates an instruction that returns a U32 replica ID.
|
// Creates an instruction that returns a U32 replica ID.
|
||||||
static std::unique_ptr<HloInstruction> CreateReplicaId();
|
static std::unique_ptr<HloInstruction> CreateReplicaId(
|
||||||
|
const Shape& shape = ShapeUtil::MakeShape(U32, {}));
|
||||||
|
|
||||||
// Creates an instruction that returns a U32 partition ID.
|
// Creates an instruction that returns a U32 partition ID.
|
||||||
static std::unique_ptr<HloInstruction> CreatePartitionId();
|
static std::unique_ptr<HloInstruction> CreatePartitionId(
|
||||||
|
const Shape& shape = ShapeUtil::MakeShape(U32, {}));
|
||||||
|
|
||||||
// Creates a conversion instruction, where operand is the data to convert and
|
// Creates a conversion instruction, where operand is the data to convert and
|
||||||
// shape is the target shape for the conversion.
|
// shape is the target shape for the conversion.
|
||||||
|
@ -646,6 +646,7 @@ HloAllGatherInstruction::CloneWithNewOperandsImpl(
|
|||||||
HloInstructionProto HloAllGatherInstruction::ToProto() const {
|
HloInstructionProto HloAllGatherInstruction::ToProto() const {
|
||||||
HloInstructionProto proto = HloCollectiveInstruction::ToProto();
|
HloInstructionProto proto = HloCollectiveInstruction::ToProto();
|
||||||
proto.add_dimensions(all_gather_dimension_);
|
proto.add_dimensions(all_gather_dimension_);
|
||||||
|
proto.set_use_global_device_ids(use_global_device_ids_);
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,12 +38,14 @@ HloProto MakeHloProto(const HloModule& module) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
||||||
const HloModuleProto& proto, const HloModuleConfig& module_config) {
|
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||||
|
bool is_module_post_optimizations) {
|
||||||
VLOG(4) << proto.ShortDebugString();
|
VLOG(4) << proto.ShortDebugString();
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||||
HloModule::CreateFromProto(proto, module_config));
|
HloModule::CreateFromProto(proto, module_config));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
|
HloVerifier(/*layout_sensitive=*/false,
|
||||||
|
/*allow_mixed_precision=*/is_module_post_optimizations)
|
||||||
.Run(module.get())
|
.Run(module.get())
|
||||||
.status());
|
.status());
|
||||||
return std::move(module);
|
return std::move(module);
|
||||||
|
@ -38,8 +38,12 @@ HloProto MakeHloProto(const HloModule& module);
|
|||||||
// Create an HLO state from serialized representation. In addition to
|
// Create an HLO state from serialized representation. In addition to
|
||||||
// creating the proto with HloModule::CreateFromProto(...) it also
|
// creating the proto with HloModule::CreateFromProto(...) it also
|
||||||
// uses HloVerifier to ensure basic invariants are held.
|
// uses HloVerifier to ensure basic invariants are held.
|
||||||
|
// The HLO module could be a pre-optimizations (default) or post-optimizations
|
||||||
|
// module, which affects how the HLO module is verified, e.g., mixed-precision
|
||||||
|
// is allowed in post-optimizations HLOs.
|
||||||
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
||||||
const HloModuleProto& proto, const HloModuleConfig& module_config);
|
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||||
|
bool is_module_post_optimizations = false);
|
||||||
|
|
||||||
// Returns the shapes of the parameters of the entry computation. Shape pointers
|
// Returns the shapes of the parameters of the entry computation. Shape pointers
|
||||||
// refer to shapes inside of the given HloProto.
|
// refer to shapes inside of the given HloProto.
|
||||||
|
@ -193,7 +193,8 @@ LocalService::CompileExecutables(
|
|||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<Executable> executable,
|
std::unique_ptr<Executable> executable,
|
||||||
BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
|
BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
|
||||||
executor, build_options.device_allocator()));
|
executor, build_options.device_allocator(),
|
||||||
|
build_options.run_backend_only()));
|
||||||
std::vector<std::unique_ptr<Executable>> executables;
|
std::vector<std::unique_ptr<Executable>> executables;
|
||||||
executables.push_back(std::move(executable));
|
executables.push_back(std::move(executable));
|
||||||
return executables;
|
return executables;
|
||||||
@ -207,7 +208,8 @@ LocalService::CompileExecutables(
|
|||||||
|
|
||||||
return BuildExecutables({&proto}, std::move(module_configs),
|
return BuildExecutables({&proto}, std::move(module_configs),
|
||||||
execute_backend_.get(), {executors},
|
execute_backend_.get(), {executors},
|
||||||
build_options.device_allocator());
|
build_options.device_allocator(),
|
||||||
|
build_options.run_backend_only());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,7 +357,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
|||||||
const std::vector<const HloModuleProto*>& module_protos,
|
const std::vector<const HloModuleProto*>& module_protos,
|
||||||
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
|
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
|
||||||
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
|
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
|
||||||
se::DeviceMemoryAllocator* device_allocator) {
|
se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) {
|
||||||
VLOG(1) << StrFormat("BuildExecutable on service %p", this);
|
VLOG(1) << StrFormat("BuildExecutable on service %p", this);
|
||||||
|
|
||||||
// Dump computation proto state if flag is set.
|
// Dump computation proto state if flag is set.
|
||||||
@ -379,15 +379,28 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
|||||||
for (int64 i = 0, end = module_protos.size(); i < end; ++i) {
|
for (int64 i = 0, end = module_protos.size(); i < end; ++i) {
|
||||||
const HloModuleProto* proto = module_protos[i];
|
const HloModuleProto* proto = module_protos[i];
|
||||||
const HloModuleConfig& config = *module_configs[i];
|
const HloModuleConfig& config = *module_configs[i];
|
||||||
TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto module, CreateModuleFromProto(*proto, config, run_backend_only));
|
||||||
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
|
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
|
||||||
module_group->push_back(std::move(module));
|
module_group->push_back(std::move(module));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<Executable>> executables;
|
||||||
|
if (!run_backend_only) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::vector<std::unique_ptr<Executable>> executables,
|
executables,
|
||||||
backend->compiler()->Compile(std::move(module_group),
|
backend->compiler()->Compile(std::move(module_group),
|
||||||
std::move(executors), device_allocator));
|
std::move(executors), device_allocator));
|
||||||
|
} else {
|
||||||
|
auto modules = module_group->ConsumeModules();
|
||||||
|
for (std::unique_ptr<HloModule>& module : modules) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::unique_ptr<Executable> executable,
|
||||||
|
backend->compiler()->RunBackend(std::move(module), executors[0][0],
|
||||||
|
device_allocator));
|
||||||
|
executables.push_back(std::move(executable));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < module_protos.size(); ++i) {
|
for (size_t i = 0; i < module_protos.size(); ++i) {
|
||||||
const auto& debug_opts = module_configs[i]->debug_options();
|
const auto& debug_opts = module_configs[i]->debug_options();
|
||||||
@ -797,18 +810,22 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
|
|||||||
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||||
const HloModuleProto& module_proto,
|
const HloModuleProto& module_proto,
|
||||||
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
|
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
|
||||||
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) {
|
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator,
|
||||||
|
bool run_backend_only) {
|
||||||
VLOG(1) << StrFormat(
|
VLOG(1) << StrFormat(
|
||||||
"BuildExecutable on service %p with serialized module proto: %s", this,
|
"BuildExecutable on service %p with serialized module proto: %s", this,
|
||||||
module_proto.name());
|
module_proto.name());
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
TF_ASSIGN_OR_RETURN(
|
||||||
CreateModuleFromProto(module_proto, *module_config));
|
std::unique_ptr<HloModule> module,
|
||||||
|
CreateModuleFromProto(module_proto, *module_config, run_backend_only));
|
||||||
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
|
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
|
||||||
|
|
||||||
|
if (!run_backend_only) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
module, backend->compiler()->RunHloPasses(std::move(module), executor,
|
module, backend->compiler()->RunHloPasses(std::move(module), executor,
|
||||||
device_allocator));
|
device_allocator));
|
||||||
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
||||||
backend->compiler()->RunBackend(
|
backend->compiler()->RunBackend(
|
||||||
|
@ -236,7 +236,8 @@ class Service : public ServiceInterface {
|
|||||||
const HloModuleProto& module_proto,
|
const HloModuleProto& module_proto,
|
||||||
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
|
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
|
||||||
se::StreamExecutor* executor,
|
se::StreamExecutor* executor,
|
||||||
se::DeviceMemoryAllocator* device_allocator = nullptr);
|
se::DeviceMemoryAllocator* device_allocator = nullptr,
|
||||||
|
bool run_backend_only = false);
|
||||||
|
|
||||||
// Same as BuildExecutable() above, but builds a list of Executables for the
|
// Same as BuildExecutable() above, but builds a list of Executables for the
|
||||||
// given computations that may interact with each other.
|
// given computations that may interact with each other.
|
||||||
@ -244,7 +245,8 @@ class Service : public ServiceInterface {
|
|||||||
const std::vector<const HloModuleProto*>& module_protos,
|
const std::vector<const HloModuleProto*>& module_protos,
|
||||||
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
|
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
|
||||||
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
|
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
|
||||||
se::DeviceMemoryAllocator* device_allocator);
|
se::DeviceMemoryAllocator* device_allocator,
|
||||||
|
bool run_backend_only = false);
|
||||||
|
|
||||||
// Runs the given executable with the given arguments and register the result
|
// Runs the given executable with the given arguments and register the result
|
||||||
// in the allocation tracker. The handle of the result from the tracker is
|
// in the allocation tracker. The handle of the result from the tracker is
|
||||||
|
Loading…
x
Reference in New Issue
Block a user