Save the numbers of replicas and logical cores in HLO proto. This will allow us to reproduce the same compilation process when there is data and model parallelism.
PiperOrigin-RevId: 273001189
This commit is contained in:
parent
33820a9e52
commit
38a28a71c6
@ -67,24 +67,21 @@ CompileOnlyService::CompileAheadOfTime(
|
|||||||
const AotCompilationOptions& options,
|
const AotCompilationOptions& options,
|
||||||
std::unique_ptr<AotCompilationMetadata>* metadata) {
|
std::unique_ptr<AotCompilationMetadata>* metadata) {
|
||||||
std::vector<std::unique_ptr<HloModule>> hlo_modules;
|
std::vector<std::unique_ptr<HloModule>> hlo_modules;
|
||||||
|
|
||||||
|
const DebugOptions& debug_options = options.debug_options();
|
||||||
|
ExecutionOptions execution_options;
|
||||||
|
*execution_options.mutable_debug_options() = debug_options;
|
||||||
|
|
||||||
for (const AotXlaComputationInstance& instance : computations) {
|
for (const AotXlaComputationInstance& instance : computations) {
|
||||||
TF_RET_CHECK(instance.computation.has_host_program_shape());
|
TF_RET_CHECK(instance.computation.has_host_program_shape());
|
||||||
|
|
||||||
const DebugOptions& debug_options = options.debug_options();
|
|
||||||
ExecutionOptions execution_options;
|
|
||||||
*execution_options.mutable_debug_options() = debug_options;
|
|
||||||
*execution_options.mutable_shape_with_output_layout() =
|
*execution_options.mutable_shape_with_output_layout() =
|
||||||
instance.result_layout->ToProto();
|
instance.result_layout->ToProto();
|
||||||
if (options.has_static_device_assignment()) {
|
|
||||||
TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize(
|
|
||||||
execution_options.mutable_device_assignment()));
|
|
||||||
}
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<HloModuleConfig> module_config,
|
std::unique_ptr<HloModuleConfig> module_config,
|
||||||
CreateModuleConfig(
|
CreateModuleConfig(
|
||||||
ProgramShape(instance.computation.host_program_shape()),
|
ProgramShape(instance.computation.host_program_shape()),
|
||||||
instance.argument_layouts, &execution_options,
|
instance.argument_layouts, &execution_options, &options));
|
||||||
options.fusion_config_collection(), options.fusion_config()));
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<HloModule> hlo_module,
|
std::unique_ptr<HloModule> hlo_module,
|
||||||
@ -93,6 +90,29 @@ CompileOnlyService::CompileAheadOfTime(
|
|||||||
hlo_modules.push_back(std::move(hlo_module));
|
hlo_modules.push_back(std::move(hlo_module));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture replica_count, num_cores, and device_assignment in ExecutionOptions
|
||||||
|
// to save in a proto dump.
|
||||||
|
if (options.replica_count() > 0) {
|
||||||
|
execution_options.set_num_replicas(options.replica_count());
|
||||||
|
if (options.has_static_device_assignment()) {
|
||||||
|
CHECK_EQ(options.replica_count(),
|
||||||
|
options.static_device_assignment().replica_count());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (options.num_cores() > 0) {
|
||||||
|
execution_options.set_num_partitions(options.num_cores());
|
||||||
|
if (options.has_static_device_assignment()) {
|
||||||
|
CHECK_EQ(options.num_cores(),
|
||||||
|
options.static_device_assignment().computation_count());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (options.has_static_device_assignment()) {
|
||||||
|
TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize(
|
||||||
|
execution_options.mutable_device_assignment()));
|
||||||
|
}
|
||||||
|
execution_options.clear_shape_with_output_layout();
|
||||||
|
DumpExecutionOptions(execution_options, debug_options);
|
||||||
|
|
||||||
return compiler_->CompileAheadOfTime(
|
return compiler_->CompileAheadOfTime(
|
||||||
absl::make_unique<HloModuleGroup>(hlo_modules[0]->name(),
|
absl::make_unique<HloModuleGroup>(hlo_modules[0]->name(),
|
||||||
absl::MakeSpan(hlo_modules)),
|
absl::MakeSpan(hlo_modules)),
|
||||||
|
@ -73,6 +73,9 @@ class AotCompilationOptions {
|
|||||||
// Returns the ID of the platform to which these options apply.
|
// Returns the ID of the platform to which these options apply.
|
||||||
virtual se::Platform::Id PlatformId() const = 0;
|
virtual se::Platform::Id PlatformId() const = 0;
|
||||||
|
|
||||||
|
virtual int64 replica_count() const { return 0; }
|
||||||
|
virtual int64 num_cores() const { return 0; }
|
||||||
|
|
||||||
// Optional allocator that may be used for allocating temp space on the device
|
// Optional allocator that may be used for allocating temp space on the device
|
||||||
// during compilation.
|
// during compilation.
|
||||||
se::DeviceMemoryAllocator* device_allocator() const {
|
se::DeviceMemoryAllocator* device_allocator() const {
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
|
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
@ -276,6 +277,23 @@ void DumpToFileInDirOrStdout(const HloModule& module, string_view suffix,
|
|||||||
CanonicalDebugOptions(module.config().debug_options()));
|
CanonicalDebugOptions(module.config().debug_options()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DumpExecutionOptions(const ExecutionOptions& execution_options,
|
||||||
|
const DebugOptions& debug_options) {
|
||||||
|
CanonicalDebugOptions opts(debug_options);
|
||||||
|
tensorflow::Env* env = tensorflow::Env::Default();
|
||||||
|
const string& dir = opts.dump_to;
|
||||||
|
if (env->IsDirectory(dir).ok()) {
|
||||||
|
string filename = tensorflow::io::JoinPath(dir, "execution_options");
|
||||||
|
if (opts.dump_as_text) {
|
||||||
|
TF_CHECK_OK(tensorflow::WriteTextProto(
|
||||||
|
env, absl::StrCat(filename, ".txt"), execution_options));
|
||||||
|
} else {
|
||||||
|
TF_CHECK_OK(tensorflow::WriteBinaryProto(
|
||||||
|
env, absl::StrCat(filename, ".pb"), execution_options));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void DumpHloModuleIfEnabled(const HloModule& module, string_view name) {
|
void DumpHloModuleIfEnabled(const HloModule& module, string_view name) {
|
||||||
CanonicalDebugOptions opts(module.config().debug_options());
|
CanonicalDebugOptions opts(module.config().debug_options());
|
||||||
if (opts.should_dump_module(module.name())) {
|
if (opts.should_dump_module(module.name())) {
|
||||||
|
@ -49,6 +49,11 @@ void DumpToFileInDirOrStdout(const HloModule& module,
|
|||||||
absl::string_view file_suffix,
|
absl::string_view file_suffix,
|
||||||
absl::string_view contents);
|
absl::string_view contents);
|
||||||
|
|
||||||
|
// Dumps the given execution options if dumping is enabled. Exactly
|
||||||
|
// where and in what formats it's dumped is determined by the debug options.
|
||||||
|
void DumpExecutionOptions(const ExecutionOptions& execution_options,
|
||||||
|
const DebugOptions& debug_options);
|
||||||
|
|
||||||
// Dumps the given HLO module if dumping is enabled for the module. Exactly
|
// Dumps the given HLO module if dumping is enabled for the module. Exactly
|
||||||
// where and in what formats it's dumped is determined by the module's config.
|
// where and in what formats it's dumped is determined by the module's config.
|
||||||
//
|
//
|
||||||
|
@ -1014,8 +1014,14 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
|
|||||||
<< operand->ToString() << ", used by " << instr->ToString();
|
<< operand->ToString() << ", used by " << instr->ToString();
|
||||||
new_operands.push_back(context->GetInstruction(replaced_operand));
|
new_operands.push_back(context->GetInstruction(replaced_operand));
|
||||||
}
|
}
|
||||||
instructions.push_back(
|
std::unique_ptr<HloInstruction> new_instr =
|
||||||
instr->CloneWithNewOperands(instr->shape(), new_operands, context));
|
instr->CloneWithNewOperands(instr->shape(), new_operands, context);
|
||||||
|
if (instr->opcode() == HloOpcode::kParameter &&
|
||||||
|
instr->parameter_replicated_at_leaf_buffers().has_value()) {
|
||||||
|
new_instr->set_parameter_replicated_at_leaf_buffers(
|
||||||
|
instr->parameter_replicated_at_leaf_buffers().value());
|
||||||
|
}
|
||||||
|
instructions.push_back(std::move(new_instr));
|
||||||
}
|
}
|
||||||
Builder builder(name() + "." + suffix);
|
Builder builder(name() + "." + suffix);
|
||||||
for (auto& instr : instructions) {
|
for (auto& instr : instructions) {
|
||||||
|
@ -383,14 +383,33 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
|
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
|
||||||
const HloModuleProto& module, const DebugOptions& debug_options) {
|
const ProgramShape& program_shape, const DebugOptions& debug_options,
|
||||||
TF_RET_CHECK(module.has_host_program_shape())
|
const ExecutionOptions* execution_options) {
|
||||||
<< "No program shape found in the proto";
|
|
||||||
ProgramShape program_shape(module.host_program_shape());
|
|
||||||
|
|
||||||
HloModuleConfig module_config(ProgramShape{program_shape});
|
HloModuleConfig module_config(ProgramShape{program_shape});
|
||||||
module_config.set_debug_options(debug_options);
|
module_config.set_debug_options(debug_options);
|
||||||
|
if (execution_options) {
|
||||||
|
if (execution_options->num_replicas() > 0) {
|
||||||
|
module_config.set_replica_count(execution_options->num_replicas());
|
||||||
|
}
|
||||||
|
if (execution_options->num_partitions() > 0) {
|
||||||
|
module_config.set_num_partitions(execution_options->num_partitions());
|
||||||
|
}
|
||||||
|
if (execution_options->has_device_assignment()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
|
||||||
|
DeviceAssignment::Deserialize(
|
||||||
|
execution_options->device_assignment()));
|
||||||
|
module_config.set_static_device_assignment(*device_assignment);
|
||||||
|
if (execution_options->num_replicas() > 0) {
|
||||||
|
CHECK_EQ(module_config.static_device_assignment().replica_count(),
|
||||||
|
module_config.replica_count());
|
||||||
|
}
|
||||||
|
if (execution_options->num_partitions() > 0) {
|
||||||
|
CHECK_EQ(module_config.static_device_assignment().computation_count(),
|
||||||
|
module_config.num_partitions());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// The module config is constructed with default layouts regardless of what is
|
// The module config is constructed with default layouts regardless of what is
|
||||||
// passed in via the ProgramShape. Set the layouts to the appropriate values.
|
// passed in via the ProgramShape. Set the layouts to the appropriate values.
|
||||||
@ -406,6 +425,17 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
|
|||||||
return module_config;
|
return module_config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
|
||||||
|
const HloModuleProto& module, const DebugOptions& debug_options,
|
||||||
|
const ExecutionOptions* execution_options) {
|
||||||
|
TF_RET_CHECK(module.has_host_program_shape())
|
||||||
|
<< "No program shape found in the proto";
|
||||||
|
ProgramShape program_shape(module.host_program_shape());
|
||||||
|
return CreateModuleConfigFromShape(program_shape, debug_options,
|
||||||
|
execution_options);
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Returns whether `hlo` is used outside the given subcomputation.
|
// Returns whether `hlo` is used outside the given subcomputation.
|
||||||
// `instructions_in_subcomputation` is the instruction set of the given
|
// `instructions_in_subcomputation` is the instruction set of the given
|
||||||
|
@ -221,7 +221,14 @@ class HloModule {
|
|||||||
// Creates and returns an HloModuleConfig with an appropriate program shape
|
// Creates and returns an HloModuleConfig with an appropriate program shape
|
||||||
// for the HLO module in the given proto.
|
// for the HLO module in the given proto.
|
||||||
static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
|
static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
|
||||||
const HloModuleProto& module, const DebugOptions& debug_options);
|
const HloModuleProto& module, const DebugOptions& debug_options,
|
||||||
|
const ExecutionOptions* execution_options = nullptr);
|
||||||
|
|
||||||
|
// Creates and returns an HloModuleConfig for the given program shape and
|
||||||
|
// debug options, applying `execution_options` when it is non-null.
|
||||||
|
static StatusOr<HloModuleConfig> CreateModuleConfigFromShape(
|
||||||
|
const ProgramShape& program_shape, const DebugOptions& debug_options,
|
||||||
|
const ExecutionOptions* execution_options = nullptr);
|
||||||
|
|
||||||
// Outlines the given expression from the given computation.
|
// Outlines the given expression from the given computation.
|
||||||
// instructions_to_outline contains the instructions that form the expression.
|
// instructions_to_outline contains the instructions that form the expression.
|
||||||
|
@ -113,6 +113,11 @@ class HloModuleConfig {
|
|||||||
}
|
}
|
||||||
int64 replica_count() const { return replica_count_; }
|
int64 replica_count() const { return replica_count_; }
|
||||||
|
|
||||||
|
void set_num_partitions(int64 num_partitions) {
|
||||||
|
num_partitions_ = num_partitions;
|
||||||
|
}
|
||||||
|
int64 num_partitions() const { return num_partitions_; }
|
||||||
|
|
||||||
// Return a string which unambiguously represents all the fields of this data
|
// Return a string which unambiguously represents all the fields of this data
|
||||||
// structure. Used for generating a cache key for storing the compiled
|
// structure. Used for generating a cache key for storing the compiled
|
||||||
// executable.
|
// executable.
|
||||||
@ -186,9 +191,12 @@ class HloModuleConfig {
|
|||||||
// Module/graph-level seed handle.
|
// Module/graph-level seed handle.
|
||||||
uint64 seed_ = 0;
|
uint64 seed_ = 0;
|
||||||
|
|
||||||
// The number of replicas to compile this binary for.
|
// The number of replicas (data parallelism) to compile this binary for.
|
||||||
int64 replica_count_ = 1;
|
int64 replica_count_ = 1;
|
||||||
|
|
||||||
|
// The number of partitions (model parallelism) to compile this binary for.
|
||||||
|
int64 num_partitions_ = 1;
|
||||||
|
|
||||||
// The target maximum parallelism at which to partition HLOs for parallel
|
// The target maximum parallelism at which to partition HLOs for parallel
|
||||||
// execution on the CPU backend.
|
// execution on the CPU backend.
|
||||||
int64 intra_op_parallelism_threads_ = -1;
|
int64 intra_op_parallelism_threads_ = -1;
|
||||||
|
@ -266,8 +266,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
|||||||
const ProgramShape& program_shape,
|
const ProgramShape& program_shape,
|
||||||
absl::Span<const Shape* const> argument_shapes,
|
absl::Span<const Shape* const> argument_shapes,
|
||||||
const ExecutionOptions* execution_options,
|
const ExecutionOptions* execution_options,
|
||||||
FusionConfigCollection fusion_config_collection,
|
const AotCompilationOptions* aot_options) {
|
||||||
const std::vector<std::vector<bool>>& fusion_config) {
|
|
||||||
auto config = absl::make_unique<HloModuleConfig>(program_shape);
|
auto config = absl::make_unique<HloModuleConfig>(program_shape);
|
||||||
ComputationLayout* computation_layout =
|
ComputationLayout* computation_layout =
|
||||||
config->mutable_entry_computation_layout();
|
config->mutable_entry_computation_layout();
|
||||||
@ -311,6 +310,9 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
|||||||
} else {
|
} else {
|
||||||
config->set_replica_count(options_.number_of_replicas());
|
config->set_replica_count(options_.number_of_replicas());
|
||||||
}
|
}
|
||||||
|
if (execution_options->num_partitions() > 0) {
|
||||||
|
config->set_num_partitions(execution_options->num_partitions());
|
||||||
|
}
|
||||||
config->set_seed(execution_options->seed());
|
config->set_seed(execution_options->seed());
|
||||||
config->set_debug_options(execution_options->debug_options());
|
config->set_debug_options(execution_options->debug_options());
|
||||||
} else {
|
} else {
|
||||||
@ -334,9 +336,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
|||||||
config->set_alias_passthrough_params(
|
config->set_alias_passthrough_params(
|
||||||
execution_options->alias_passthrough_params());
|
execution_options->alias_passthrough_params());
|
||||||
|
|
||||||
if (fusion_config_collection != FusionConfigCollection::kOff) {
|
if (aot_options != nullptr &&
|
||||||
config->set_fusion_config_collection(fusion_config_collection);
|
aot_options->fusion_config_collection() != FusionConfigCollection::kOff) {
|
||||||
*config->mutable_fusion_config() = fusion_config;
|
config->set_fusion_config_collection(
|
||||||
|
aot_options->fusion_config_collection());
|
||||||
|
*config->mutable_fusion_config() = aot_options->fusion_config();
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::move(config);
|
return std::move(config);
|
||||||
@ -345,12 +349,14 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
|||||||
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||||
const ProgramShape& program_shape,
|
const ProgramShape& program_shape,
|
||||||
absl::Span<const ShapedBuffer* const> arguments,
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
const ExecutionOptions& execution_options) {
|
const ExecutionOptions& execution_options,
|
||||||
|
const AotCompilationOptions* aot_options) {
|
||||||
std::vector<const Shape*> argument_shapes;
|
std::vector<const Shape*> argument_shapes;
|
||||||
for (const auto* arg : arguments) {
|
for (const auto* arg : arguments) {
|
||||||
argument_shapes.push_back(&arg->on_host_shape());
|
argument_shapes.push_back(&arg->on_host_shape());
|
||||||
}
|
}
|
||||||
return CreateModuleConfig(program_shape, argument_shapes, &execution_options);
|
return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
|
||||||
|
aot_options);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||||
|
@ -189,9 +189,7 @@ class Service : public ServiceInterface {
|
|||||||
const ProgramShape& program_shape,
|
const ProgramShape& program_shape,
|
||||||
absl::Span<const Shape* const> argument_shapes,
|
absl::Span<const Shape* const> argument_shapes,
|
||||||
const ExecutionOptions* execution_options,
|
const ExecutionOptions* execution_options,
|
||||||
FusionConfigCollection fusion_config_collection =
|
const AotCompilationOptions* aot_options = nullptr);
|
||||||
FusionConfigCollection::kOff,
|
|
||||||
const std::vector<std::vector<bool>>& fusion_config = {});
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// A private overload for Service itself, used by other methods within this
|
// A private overload for Service itself, used by other methods within this
|
||||||
@ -199,7 +197,8 @@ class Service : public ServiceInterface {
|
|||||||
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
||||||
const ProgramShape& program_shape,
|
const ProgramShape& program_shape,
|
||||||
absl::Span<const ShapedBuffer* const> arguments,
|
absl::Span<const ShapedBuffer* const> arguments,
|
||||||
const ExecutionOptions& execution_options);
|
const ExecutionOptions& execution_options,
|
||||||
|
const AotCompilationOptions* aot_options = nullptr);
|
||||||
|
|
||||||
// Prepare the executors for executing in parallel.
|
// Prepare the executors for executing in parallel.
|
||||||
StatusOr<std::vector<se::StreamExecutor*>> GetExecutors(
|
StatusOr<std::vector<se::StreamExecutor*>> GetExecutors(
|
||||||
|
@ -339,6 +339,10 @@ message ExecutionOptions {
|
|||||||
// Alias input and output buffers for parameters that are passed-through XLA
|
// Alias input and output buffers for parameters that are passed-through XLA
|
||||||
// modules without being changed.
|
// modules without being changed.
|
||||||
bool alias_passthrough_params = 8;
|
bool alias_passthrough_params = 8;
|
||||||
|
|
||||||
|
// Number of partitions of the computation to run (model parallelism).
|
||||||
|
// If zero, uses the default number of partitions for the XLA service.
|
||||||
|
int32 num_partitions = 9;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GetDeviceHandlesRequest {
|
message GetDeviceHandlesRequest {
|
||||||
|
Loading…
Reference in New Issue
Block a user