Save the numbers of replicas and logical cores in HLO proto. This will allow us to reproduce the same compilation process when there is data and model parallelism.

PiperOrigin-RevId: 273001189
This commit is contained in:
A. Unique TensorFlower 2019-10-04 19:07:41 -07:00 committed by TensorFlower Gardener
parent 33820a9e52
commit 38a28a71c6
11 changed files with 137 additions and 31 deletions

View File

@@ -67,24 +67,21 @@ CompileOnlyService::CompileAheadOfTime(
const AotCompilationOptions& options, const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata) { std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<std::unique_ptr<HloModule>> hlo_modules; std::vector<std::unique_ptr<HloModule>> hlo_modules;
const DebugOptions& debug_options = options.debug_options();
ExecutionOptions execution_options;
*execution_options.mutable_debug_options() = debug_options;
for (const AotXlaComputationInstance& instance : computations) { for (const AotXlaComputationInstance& instance : computations) {
TF_RET_CHECK(instance.computation.has_host_program_shape()); TF_RET_CHECK(instance.computation.has_host_program_shape());
const DebugOptions& debug_options = options.debug_options();
ExecutionOptions execution_options;
*execution_options.mutable_debug_options() = debug_options;
*execution_options.mutable_shape_with_output_layout() = *execution_options.mutable_shape_with_output_layout() =
instance.result_layout->ToProto(); instance.result_layout->ToProto();
if (options.has_static_device_assignment()) {
TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize(
execution_options.mutable_device_assignment()));
}
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config, std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig( CreateModuleConfig(
ProgramShape(instance.computation.host_program_shape()), ProgramShape(instance.computation.host_program_shape()),
instance.argument_layouts, &execution_options, instance.argument_layouts, &execution_options, &options));
options.fusion_config_collection(), options.fusion_config()));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module, std::unique_ptr<HloModule> hlo_module,
@@ -93,6 +90,29 @@ CompileOnlyService::CompileAheadOfTime(
hlo_modules.push_back(std::move(hlo_module)); hlo_modules.push_back(std::move(hlo_module));
} }
// Capture replica_count, num_cores, and device_assignment in ExecutionOptions
// to save in a proto dump.
if (options.replica_count() > 0) {
execution_options.set_num_replicas(options.replica_count());
if (options.has_static_device_assignment()) {
CHECK_EQ(options.replica_count(),
options.static_device_assignment().replica_count());
}
}
if (options.num_cores() > 0) {
execution_options.set_num_partitions(options.num_cores());
if (options.has_static_device_assignment()) {
CHECK_EQ(options.num_cores(),
options.static_device_assignment().computation_count());
}
}
if (options.has_static_device_assignment()) {
TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize(
execution_options.mutable_device_assignment()));
}
execution_options.clear_shape_with_output_layout();
DumpExecutionOptions(execution_options, debug_options);
return compiler_->CompileAheadOfTime( return compiler_->CompileAheadOfTime(
absl::make_unique<HloModuleGroup>(hlo_modules[0]->name(), absl::make_unique<HloModuleGroup>(hlo_modules[0]->name(),
absl::MakeSpan(hlo_modules)), absl::MakeSpan(hlo_modules)),

View File

@@ -73,6 +73,9 @@ class AotCompilationOptions {
// Returns the ID of the platform to which these options apply. // Returns the ID of the platform to which these options apply.
virtual se::Platform::Id PlatformId() const = 0; virtual se::Platform::Id PlatformId() const = 0;
// Number of replicas (data parallelism) to compile for. The base
// implementation returns 0, meaning "unspecified"; backend-specific
// subclasses may override to request a concrete replica count.
virtual int64 replica_count() const { return 0; }
// Number of logical cores / partitions (model parallelism) to compile
// for. 0 means "unspecified"; subclasses may override.
virtual int64 num_cores() const { return 0; }
// Optional allocator that may be used for allocating temp space on the device // Optional allocator that may be used for allocating temp space on the device
// during compilation. // during compilation.
se::DeviceMemoryAllocator* device_allocator() const { se::DeviceMemoryAllocator* device_allocator() const {

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
@@ -276,6 +277,23 @@ void DumpToFileInDirOrStdout(const HloModule& module, string_view suffix,
CanonicalDebugOptions(module.config().debug_options())); CanonicalDebugOptions(module.config().debug_options()));
} }
// Writes `execution_options` as a proto file named "execution_options"
// (with a ".txt" or ".pb" suffix) into the dump directory configured in
// `debug_options`. Silently does nothing when that directory does not
// exist; this helper does not create it.
void DumpExecutionOptions(const ExecutionOptions& execution_options,
const DebugOptions& debug_options) {
CanonicalDebugOptions opts(debug_options);
tensorflow::Env* env = tensorflow::Env::Default();
const string& dir = opts.dump_to;
// Only dump when the requested directory already exists.
if (env->IsDirectory(dir).ok()) {
string filename = tensorflow::io::JoinPath(dir, "execution_options");
if (opts.dump_as_text) {
// Text proto: human-readable form for inspection.
TF_CHECK_OK(tensorflow::WriteTextProto(
env, absl::StrCat(filename, ".txt"), execution_options));
} else {
// Binary proto: exact, machine-readable round-trip form.
TF_CHECK_OK(tensorflow::WriteBinaryProto(
env, absl::StrCat(filename, ".pb"), execution_options));
}
}
}
void DumpHloModuleIfEnabled(const HloModule& module, string_view name) { void DumpHloModuleIfEnabled(const HloModule& module, string_view name) {
CanonicalDebugOptions opts(module.config().debug_options()); CanonicalDebugOptions opts(module.config().debug_options());
if (opts.should_dump_module(module.name())) { if (opts.should_dump_module(module.name())) {

View File

@@ -49,6 +49,11 @@ void DumpToFileInDirOrStdout(const HloModule& module,
absl::string_view file_suffix, absl::string_view file_suffix,
absl::string_view contents); absl::string_view contents);
// Dumps the given execution options if dumping is enabled. Exactly
// where and in what formats it's dumped is determined by the debug options.
void DumpExecutionOptions(const ExecutionOptions& execution_options,
const DebugOptions& debug_options);
// Dumps the given HLO module if dumping is enabled for the module. Exactly // Dumps the given HLO module if dumping is enabled for the module. Exactly
// where and in what formats it's dumped is determined by the module's config. // where and in what formats it's dumped is determined by the module's config.
// //

View File

@@ -1014,8 +1014,14 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
<< operand->ToString() << ", used by " << instr->ToString(); << operand->ToString() << ", used by " << instr->ToString();
new_operands.push_back(context->GetInstruction(replaced_operand)); new_operands.push_back(context->GetInstruction(replaced_operand));
} }
instructions.push_back( std::unique_ptr<HloInstruction> new_instr =
instr->CloneWithNewOperands(instr->shape(), new_operands, context)); instr->CloneWithNewOperands(instr->shape(), new_operands, context);
if (instr->opcode() == HloOpcode::kParameter &&
instr->parameter_replicated_at_leaf_buffers().has_value()) {
new_instr->set_parameter_replicated_at_leaf_buffers(
instr->parameter_replicated_at_leaf_buffers().value());
}
instructions.push_back(std::move(new_instr));
} }
Builder builder(name() + "." + suffix); Builder builder(name() + "." + suffix);
for (auto& instr : instructions) { for (auto& instr : instructions) {

View File

@@ -383,14 +383,33 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
} }
/* static */ /* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto( StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
const HloModuleProto& module, const DebugOptions& debug_options) { const ProgramShape& program_shape, const DebugOptions& debug_options,
TF_RET_CHECK(module.has_host_program_shape()) const ExecutionOptions* execution_options) {
<< "No program shape found in the proto";
ProgramShape program_shape(module.host_program_shape());
HloModuleConfig module_config(ProgramShape{program_shape}); HloModuleConfig module_config(ProgramShape{program_shape});
module_config.set_debug_options(debug_options); module_config.set_debug_options(debug_options);
if (execution_options) {
if (execution_options->num_replicas() > 0) {
module_config.set_replica_count(execution_options->num_replicas());
}
if (execution_options->num_partitions() > 0) {
module_config.set_num_partitions(execution_options->num_partitions());
}
if (execution_options->has_device_assignment()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
DeviceAssignment::Deserialize(
execution_options->device_assignment()));
module_config.set_static_device_assignment(*device_assignment);
if (execution_options->num_replicas() > 0) {
CHECK_EQ(module_config.static_device_assignment().replica_count(),
module_config.replica_count());
}
if (execution_options->num_partitions() > 0) {
CHECK_EQ(module_config.static_device_assignment().computation_count(),
module_config.num_partitions());
}
}
}
// The module config is constructed with default layouts regardless of what is // The module config is constructed with default layouts regardless of what is
// passed in via the ProgramShape. Set the layouts to the appropriate values. // passed in via the ProgramShape. Set the layouts to the appropriate values.
@@ -406,6 +425,17 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
return module_config; return module_config;
} }
/* static */
// Builds an HloModuleConfig for the module described by `module` by
// extracting its host program shape and delegating to
// CreateModuleConfigFromShape. Returns an error if the proto carries no
// host program shape. `execution_options` may be null, in which case only
// `debug_options` is applied.
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
const HloModuleProto& module, const DebugOptions& debug_options,
const ExecutionOptions* execution_options) {
TF_RET_CHECK(module.has_host_program_shape())
<< "No program shape found in the proto";
ProgramShape program_shape(module.host_program_shape());
return CreateModuleConfigFromShape(program_shape, debug_options,
execution_options);
}
namespace { namespace {
// Returns whether `hlo` is used outside the given subcomputation. // Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given // `instructions_in_subcomputation` is the instruction set of the given

View File

@@ -221,7 +221,14 @@ class HloModule {
// Creates and returns an HloModuleConfig with an appropriate program shape // Creates and returns an HloModuleConfig with an appropriate program shape
// for the HLO module in the given proto. // for the HLO module in the given proto.
static StatusOr<HloModuleConfig> CreateModuleConfigFromProto( static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
const HloModuleProto& module, const DebugOptions& debug_options); const HloModuleProto& module, const DebugOptions& debug_options,
const ExecutionOptions* execution_options = nullptr);
// Creates and returns an HloModuleConfig with an appropriate program shape
// for the HLO module in the given proto.
static StatusOr<HloModuleConfig> CreateModuleConfigFromShape(
const ProgramShape& program_shape, const DebugOptions& debug_options,
const ExecutionOptions* execution_options = nullptr);
// Outlines the given expression from the given computation. // Outlines the given expression from the given computation.
// instructions_to_outline contains the instructions that form the expression. // instructions_to_outline contains the instructions that form the expression.

View File

@@ -113,6 +113,11 @@ class HloModuleConfig {
} }
int64 replica_count() const { return replica_count_; } int64 replica_count() const { return replica_count_; }
// Sets the number of partitions (model parallelism) this module is
// compiled for.
void set_num_partitions(int64 num_partitions) {
num_partitions_ = num_partitions;
}
// Returns the number of partitions (model parallelism) this module is
// compiled for.
int64 num_partitions() const { return num_partitions_; }
// Return a string which unambiguously represents all the fields of this data // Return a string which unambiguously represents all the fields of this data
// structure. Used for generating a cache key for storing the compiled // structure. Used for generating a cache key for storing the compiled
// executable. // executable.
@@ -186,9 +191,12 @@ class HloModuleConfig {
// Module/graph-level seed handle. // Module/graph-level seed handle.
uint64 seed_ = 0; uint64 seed_ = 0;
// The number of replicas to compile this binary for. // The number of replicas (data parallelism) to compile this binary for.
int64 replica_count_ = 1; int64 replica_count_ = 1;
// The number of partitions (model parallelism) to compile this binary for.
int64 num_partitions_ = 1;
// The target maximum parallelism at which to partition HLOs for parallel // The target maximum parallelism at which to partition HLOs for parallel
// execution on the CPU backend. // execution on the CPU backend.
int64 intra_op_parallelism_threads_ = -1; int64 intra_op_parallelism_threads_ = -1;

View File

@@ -266,8 +266,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape, const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes, absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options, const ExecutionOptions* execution_options,
FusionConfigCollection fusion_config_collection, const AotCompilationOptions* aot_options) {
const std::vector<std::vector<bool>>& fusion_config) {
auto config = absl::make_unique<HloModuleConfig>(program_shape); auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout = ComputationLayout* computation_layout =
config->mutable_entry_computation_layout(); config->mutable_entry_computation_layout();
@@ -311,6 +310,9 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
} else { } else {
config->set_replica_count(options_.number_of_replicas()); config->set_replica_count(options_.number_of_replicas());
} }
if (execution_options->num_partitions() > 0) {
config->set_num_partitions(execution_options->num_partitions());
}
config->set_seed(execution_options->seed()); config->set_seed(execution_options->seed());
config->set_debug_options(execution_options->debug_options()); config->set_debug_options(execution_options->debug_options());
} else { } else {
@@ -334,9 +336,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
config->set_alias_passthrough_params( config->set_alias_passthrough_params(
execution_options->alias_passthrough_params()); execution_options->alias_passthrough_params());
if (fusion_config_collection != FusionConfigCollection::kOff) { if (aot_options != nullptr &&
config->set_fusion_config_collection(fusion_config_collection); aot_options->fusion_config_collection() != FusionConfigCollection::kOff) {
*config->mutable_fusion_config() = fusion_config; config->set_fusion_config_collection(
aot_options->fusion_config_collection());
*config->mutable_fusion_config() = aot_options->fusion_config();
} }
return std::move(config); return std::move(config);
@@ -345,12 +349,14 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape, const ProgramShape& program_shape,
absl::Span<const ShapedBuffer* const> arguments, absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options) { const ExecutionOptions& execution_options,
const AotCompilationOptions* aot_options) {
std::vector<const Shape*> argument_shapes; std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) { for (const auto* arg : arguments) {
argument_shapes.push_back(&arg->on_host_shape()); argument_shapes.push_back(&arg->on_host_shape());
} }
return CreateModuleConfig(program_shape, argument_shapes, &execution_options); return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
aot_options);
} }
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables( StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(

View File

@@ -189,9 +189,7 @@ class Service : public ServiceInterface {
const ProgramShape& program_shape, const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes, absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options, const ExecutionOptions* execution_options,
FusionConfigCollection fusion_config_collection = const AotCompilationOptions* aot_options = nullptr);
FusionConfigCollection::kOff,
const std::vector<std::vector<bool>>& fusion_config = {});
private: private:
// A private overload for Service itself, used by other methods within this // A private overload for Service itself, used by other methods within this
@@ -199,7 +197,8 @@ class Service : public ServiceInterface {
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig( StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape, const ProgramShape& program_shape,
absl::Span<const ShapedBuffer* const> arguments, absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options); const ExecutionOptions& execution_options,
const AotCompilationOptions* aot_options = nullptr);
// Prepare the executors for executing parallel. // Prepare the executors for executing parallel.
StatusOr<std::vector<se::StreamExecutor*>> GetExecutors( StatusOr<std::vector<se::StreamExecutor*>> GetExecutors(

View File

@@ -339,6 +339,10 @@ message ExecutionOptions {
// Alias input and output buffers for parameters that are passed-through XLA // Alias input and output buffers for parameters that are passed-through XLA
// modules without being changed. // modules without being changed.
bool alias_passthrough_params = 8; bool alias_passthrough_params = 8;
// Number of partitions of the computation to run (model parallelism).
// If zero, uses the default number of partitions for the XLA service.
int32 num_partitions = 9;
} }
message GetDeviceHandlesRequest { message GetDeviceHandlesRequest {