With other internal changes, extend the TF XLA interface to support running post-optimizations HLO modules. Also fix a few bugs in XLA related to saving and loading HLO from protobuf.
PiperOrigin-RevId: 340702861 Change-Id: Ide5278d164cc09b5bf72d38e87b3a77bb07c85d1
This commit is contained in:
parent
cd0ef5f0a2
commit
7effe6da6f
@ -104,6 +104,17 @@ class ExecutableBuildOptions {
|
||||
alias_passthrough_params_ = alias_passthrough_params;
|
||||
}
|
||||
|
||||
bool run_backend_only() const { return run_backend_only_; }
|
||||
// By default, XLA builds an executable by invoking standard compilation, i.e.
|
||||
// running Compiler::Compile, or both Compiler::RunHloPasses and
|
||||
// Compiler::RunBackend. When run_backend_only is set to true, XLA builds an
|
||||
// executable by invoking only RunBackend and skipping RunHloPasses,
|
||||
// which can be used to compile post-optimizations HLO modules.
|
||||
ExecutableBuildOptions& set_run_backend_only(bool run_backend_only) {
|
||||
run_backend_only_ = run_backend_only;
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_ordinal_ = -1;
|
||||
Shape result_layout_;
|
||||
@ -116,6 +127,7 @@ class ExecutableBuildOptions {
|
||||
bool deduplicate_hlo_ = false;
|
||||
absl::optional<DeviceAssignment> device_assignment_;
|
||||
bool alias_passthrough_params_ = false;
|
||||
bool run_backend_only_ = false;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -496,11 +496,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kReplicaId: {
|
||||
instruction = CreateReplicaId();
|
||||
instruction = CreateReplicaId(shape);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kPartitionId: {
|
||||
instruction = CreatePartitionId();
|
||||
instruction = CreatePartitionId(shape);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kConvolution: {
|
||||
@ -1080,15 +1080,20 @@ HloInstruction::CreateCollectivePermuteStart(
|
||||
channel_id);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
|
||||
return absl::WrapUnique(
|
||||
new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {})));
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
|
||||
const Shape& shape) {
|
||||
CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
|
||||
<< "HloInstruction replica-id must have a shape of u32[], but "
|
||||
<< shape.ToString() << " is specified";
|
||||
return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction>
|
||||
HloInstruction::CreatePartitionId() {
|
||||
return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId,
|
||||
ShapeUtil::MakeShape(U32, {})));
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
|
||||
const Shape& shape) {
|
||||
CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
|
||||
<< "HloInstruction partition-id must have a shape of u32[], but "
|
||||
<< shape.ToString() << " is specified";
|
||||
return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
|
||||
@ -1799,13 +1804,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
break;
|
||||
case HloOpcode::kReplicaId:
|
||||
CHECK_EQ(new_operands.size(), 0);
|
||||
clone = CreateReplicaId();
|
||||
*clone->mutable_shape() = shape;
|
||||
clone = CreateReplicaId(shape);
|
||||
break;
|
||||
case HloOpcode::kPartitionId:
|
||||
CHECK_EQ(new_operands.size(), 0);
|
||||
clone = CreatePartitionId();
|
||||
*clone->mutable_shape() = shape;
|
||||
clone = CreatePartitionId(shape);
|
||||
break;
|
||||
}
|
||||
// SetupDerivedInstruction will set up the precision_config_ field.
|
||||
|
@ -706,10 +706,12 @@ class HloInstruction {
|
||||
const absl::optional<int64>& channel_id);
|
||||
|
||||
// Creates an instruction that returns a U32 replica ID.
|
||||
static std::unique_ptr<HloInstruction> CreateReplicaId();
|
||||
static std::unique_ptr<HloInstruction> CreateReplicaId(
|
||||
const Shape& shape = ShapeUtil::MakeShape(U32, {}));
|
||||
|
||||
// Creates an instruction that returns a U32 partition ID.
|
||||
static std::unique_ptr<HloInstruction> CreatePartitionId();
|
||||
static std::unique_ptr<HloInstruction> CreatePartitionId(
|
||||
const Shape& shape = ShapeUtil::MakeShape(U32, {}));
|
||||
|
||||
// Creates a conversion instruction, where operand is the data to convert and
|
||||
// shape is the target shape for the conversion.
|
||||
|
@ -646,6 +646,7 @@ HloAllGatherInstruction::CloneWithNewOperandsImpl(
|
||||
HloInstructionProto HloAllGatherInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloCollectiveInstruction::ToProto();
|
||||
proto.add_dimensions(all_gather_dimension_);
|
||||
proto.set_use_global_device_ids(use_global_device_ids_);
|
||||
return proto;
|
||||
}
|
||||
|
||||
|
@ -38,12 +38,14 @@ HloProto MakeHloProto(const HloModule& module) {
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config) {
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||
bool is_module_post_optimizations) {
|
||||
VLOG(4) << proto.ShortDebugString();
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
HloModule::CreateFromProto(proto, module_config));
|
||||
TF_RETURN_IF_ERROR(
|
||||
HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
|
||||
HloVerifier(/*layout_sensitive=*/false,
|
||||
/*allow_mixed_precision=*/is_module_post_optimizations)
|
||||
.Run(module.get())
|
||||
.status());
|
||||
return std::move(module);
|
||||
|
@ -38,8 +38,12 @@ HloProto MakeHloProto(const HloModule& module);
|
||||
// Create an HLO state from serialized representation. In addition to
|
||||
// creating the module with HloModule::CreateFromProto(...) it also
|
||||
// uses HloVerifier to ensure basic invariants are held.
|
||||
// The HLO module could be a pre-optimizations (default) or post-optimizations
|
||||
// module, which affects how the HLO module is verified, e.g., mixed-precision
|
||||
// is allowed in post-optimizations HLOs.
|
||||
StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config);
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||
bool is_module_post_optimizations = false);
|
||||
|
||||
// Returns the shapes of the parameters of the entry computation. Shape pointers
|
||||
// refer to shapes inside of the given HloProto.
|
||||
|
@ -193,7 +193,8 @@ LocalService::CompileExecutables(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Executable> executable,
|
||||
BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
|
||||
executor, build_options.device_allocator()));
|
||||
executor, build_options.device_allocator(),
|
||||
build_options.run_backend_only()));
|
||||
std::vector<std::unique_ptr<Executable>> executables;
|
||||
executables.push_back(std::move(executable));
|
||||
return executables;
|
||||
@ -207,7 +208,8 @@ LocalService::CompileExecutables(
|
||||
|
||||
return BuildExecutables({&proto}, std::move(module_configs),
|
||||
execute_backend_.get(), {executors},
|
||||
build_options.device_allocator());
|
||||
build_options.device_allocator(),
|
||||
build_options.run_backend_only());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -357,7 +357,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||
const std::vector<const HloModuleProto*>& module_protos,
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
|
||||
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) {
|
||||
VLOG(1) << StrFormat("BuildExecutable on service %p", this);
|
||||
|
||||
// Dump computation proto state if flag is set.
|
||||
@ -379,15 +379,28 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||
for (int64 i = 0, end = module_protos.size(); i < end; ++i) {
|
||||
const HloModuleProto* proto = module_protos[i];
|
||||
const HloModuleConfig& config = *module_configs[i];
|
||||
TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module, CreateModuleFromProto(*proto, config, run_backend_only));
|
||||
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
|
||||
module_group->push_back(std::move(module));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<std::unique_ptr<Executable>> executables,
|
||||
backend->compiler()->Compile(std::move(module_group),
|
||||
std::move(executors), device_allocator));
|
||||
std::vector<std::unique_ptr<Executable>> executables;
|
||||
if (!run_backend_only) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
executables,
|
||||
backend->compiler()->Compile(std::move(module_group),
|
||||
std::move(executors), device_allocator));
|
||||
} else {
|
||||
auto modules = module_group->ConsumeModules();
|
||||
for (std::unique_ptr<HloModule>& module : modules) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Executable> executable,
|
||||
backend->compiler()->RunBackend(std::move(module), executors[0][0],
|
||||
device_allocator));
|
||||
executables.push_back(std::move(executable));
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < module_protos.size(); ++i) {
|
||||
const auto& debug_opts = module_configs[i]->debug_options();
|
||||
@ -797,18 +810,22 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
|
||||
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||
const HloModuleProto& module_proto,
|
||||
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
|
||||
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) {
|
||||
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator,
|
||||
bool run_backend_only) {
|
||||
VLOG(1) << StrFormat(
|
||||
"BuildExecutable on service %p with serialized module proto: %s", this,
|
||||
module_proto.name());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
CreateModuleFromProto(module_proto, *module_config));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
CreateModuleFromProto(module_proto, *module_config, run_backend_only));
|
||||
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
module, backend->compiler()->RunHloPasses(std::move(module), executor,
|
||||
device_allocator));
|
||||
if (!run_backend_only) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
module, backend->compiler()->RunHloPasses(std::move(module), executor,
|
||||
device_allocator));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
||||
backend->compiler()->RunBackend(
|
||||
|
@ -236,7 +236,8 @@ class Service : public ServiceInterface {
|
||||
const HloModuleProto& module_proto,
|
||||
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
|
||||
se::StreamExecutor* executor,
|
||||
se::DeviceMemoryAllocator* device_allocator = nullptr);
|
||||
se::DeviceMemoryAllocator* device_allocator = nullptr,
|
||||
bool run_backend_only = false);
|
||||
|
||||
// Same as BuildExecutable() above, but builds a list of Executables for the
|
||||
// given computations that may interact with each other.
|
||||
@ -244,7 +245,8 @@ class Service : public ServiceInterface {
|
||||
const std::vector<const HloModuleProto*>& module_protos,
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
|
||||
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
|
||||
se::DeviceMemoryAllocator* device_allocator);
|
||||
se::DeviceMemoryAllocator* device_allocator,
|
||||
bool run_backend_only = false);
|
||||
|
||||
// Runs the given executable with the given arguments and registers the result
|
||||
// in the allocation tracker. The handle of the result from the tracker is
|
||||
|
Loading…
x
Reference in New Issue
Block a user