Enables per-host dummy args for TPUExecute (TF1) and adds XLA options.

Enabling this logic removes the cross-worker Send/Recv dependencies that TPUExecuteOp nodes otherwise require to access a model's variables, which reduces overhead at the start of a training loop.

The approach replaces remote variable reads with zero tensors on every worker except the primary worker; the zero tensors feed the TPUExecute nodes local to that worker. For large distributed systems with large variables, this eliminates the initial Send/Recv variable broadcast, which can be expensive.
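For illustration, the zero tensor built on a non-primary worker is roughly equivalent to the following TF C++ graph-construction sketch; the {1024, 1024} shape, float dtype, and helper name are placeholders rather than part of this change, and the rewrite pass itself builds the same Const/Const/Fill subgraph directly from NodeDefs on the worker's host CPU device.

    #include "tensorflow/cc/framework/ops.h"
    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"

    namespace tf = tensorflow;

    // Builds a local, zero-filled stand-in for a float variable of shape
    // {1024, 1024}; this is what feeds the worker's TPUExecute nodes instead
    // of a cross-worker read of the real variable value.
    tf::Output MakeDummyArg(const tf::Scope& scope) {
      auto shape = tf::ops::Const(scope, {1024, 1024});  // variable shape, int32
      auto zero = tf::ops::Const(scope, 0.0f);           // scalar fill value
      return tf::ops::Fill(scope, shape, zero);          // zeros, built locally
    }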

PiperOrigin-RevId: 351904109
Change-Id: I9f1ed63c2401f227646010a94a70c04f1c96cb7e
Tayo Oguntebi authored 2021-01-14 16:45:45 -08:00, committed by TensorFlower Gardener
parent a18d8e6276
commit 6983bacea1
11 changed files with 175 additions and 10 deletions

View File

@@ -83,6 +83,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_deduplicate_hlo(
return *this;
}
ExecutableBuildOptions& ExecutableBuildOptions::set_broadcast_replicated_params(
bool broadcast_replicated_params) {
broadcast_replicated_params_ = broadcast_replicated_params;
return *this;
}
ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment(
const DeviceAssignment& device_assignment) {
device_assignment_ = device_assignment;

View File

@@ -85,6 +85,12 @@ class ExecutableBuildOptions {
bool deduplicate_hlo() const { return deduplicate_hlo_; }
ExecutableBuildOptions& set_deduplicate_hlo(bool deduplicate_hlo);
bool broadcast_replicated_params() const {
return broadcast_replicated_params_;
}
ExecutableBuildOptions& set_broadcast_replicated_params(
bool broadcast_replicated_params);
// If set, this specifies a static device assignment for the computation.
// Otherwise, the computation will be compiled generically and can be run with
// any device assignment compatible with the computation's replica and
@@ -135,6 +141,7 @@ class ExecutableBuildOptions {
int num_partitions_ = 1;
bool use_spmd_partitioning_ = false;
bool deduplicate_hlo_ = false;
bool broadcast_replicated_params_ = false;
absl::optional<DeviceAssignment> device_assignment_;
bool alias_passthrough_params_ = false;
bool run_backend_only_ = false;
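For reference, a client could opt in through the new build option roughly as follows; the helper name and num_replicas argument are placeholders, the header path is the assumed client location, and the option continues to default to false.

    #include "tensorflow/compiler/xla/client/executable_build_options.h"

    // Requests that replicated parameters be broadcast to all replicas via XLA
    // collectives instead of being materialized separately for each replica.
    xla::ExecutableBuildOptions MakeBuildOptions(int num_replicas) {
      xla::ExecutableBuildOptions options;
      options.set_num_replicas(num_replicas);
      options.set_broadcast_replicated_params(true);
      return options;
    }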

View File

@@ -93,6 +93,8 @@ CompileOnlyService::CompileAheadOfTime(
}
execution_options.set_use_spmd_partitioning(options.use_spmd_partitioning());
execution_options.set_deduplicate_hlo(options.deduplicate_hlo());
execution_options.set_broadcast_replicated_parameters_via_collectives(
options.broadcast_replicated_params());
for (const AotXlaComputationInstance& instance : computations) {
TF_RET_CHECK(instance.computation.has_host_program_shape());
*execution_options.mutable_shape_with_output_layout() =

View File

@@ -77,6 +77,7 @@ class AotCompilationOptions {
virtual int64 replica_count() const { return 0; }
virtual int64 num_cores() const { return 0; }
virtual bool broadcast_replicated_params() const { return false; }
virtual bool use_spmd_partitioning() const { return false; }
virtual bool deduplicate_hlo() const { return false; }

View File

@@ -447,6 +447,8 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
module_config.set_use_spmd_partitioning(
execution_options->use_spmd_partitioning());
module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo());
module_config.set_broadcast_replicated_params(
execution_options->broadcast_replicated_parameters_via_collectives());
if (execution_options->has_device_assignment()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
DeviceAssignment::Deserialize(

View File

@@ -133,6 +133,13 @@ class HloModuleConfig {
}
int64 num_partitions() const { return num_partitions_; }
void set_broadcast_replicated_params(bool broadcast_replicated_params) {
broadcast_replicated_params_ = broadcast_replicated_params;
}
bool broadcast_replicated_params() const {
return broadcast_replicated_params_;
}
void set_use_spmd_partitioning(bool use_spmd_partitioning) {
use_spmd_partitioning_ = use_spmd_partitioning;
}
@@ -249,6 +256,9 @@ class HloModuleConfig {
// The number of partitions (model parallelism) to compile this binary for.
int64 num_partitions_ = 1;
// Whether to use XLA collectives to broadcast params to all replicas.
bool broadcast_replicated_params_ = false;
// Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA
// needs to partition the module.
bool use_spmd_partitioning_ = false;

View File

@@ -96,6 +96,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
config->set_use_spmd_partitioning(
execution_options->use_spmd_partitioning());
config->set_deduplicate_hlo(execution_options->deduplicate_hlo());
config->set_broadcast_replicated_params(
execution_options->broadcast_replicated_parameters_via_collectives());
config->set_seed(execution_options->seed());
config->set_launch_id(execution_options->launch_id());
config->set_debug_options(execution_options->debug_options());

View File

@@ -374,6 +374,10 @@ message ExecutionOptions {
// If set, deduplicate hlo into function calls to reduce binary size. Only
// works on TPU.
bool deduplicate_hlo = 12;
// If set, broadcast replicated parameters to all replicas, using collectives.
// Only applicable to TPU.
bool broadcast_replicated_parameters_via_collectives = 13;
}
message GetDeviceHandlesRequest {

View File

@@ -115,4 +115,8 @@ message TPUCompileMetadataProto {
// Whether to use XLA's SPMD or MPMD partitioner when compiler partitioning is
// requested.
bool use_spmd_for_xla_partitioning = 15;
// Enables use of XLA collectives for broadcast of replicated parameters to
// all replicas, instead of using TensorFlow Send/Recv.
bool broadcast_replicated_parameters_via_collectives = 16;
}

View File

@@ -2229,6 +2229,8 @@ Status DistributedTPURewritePass::BuildCompileNode(
return s.type() == xla::OpSharding::MAXIMAL;
});
proto.set_use_spmd_for_xla_partitioning(use_spmd);
proto.set_broadcast_replicated_parameters_via_collectives(
enable_xla_param_broadcast_);
// Get and fill padding map.
if (replicate_node != nullptr) {
@@ -2578,6 +2580,83 @@ Status DistributedTPURewritePass::BuildVariableWrites(
namespace {
// Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
const string& host_cpu_device,
Node* var_read,
absl::string_view name_prefix,
Graph* graph) {
Status status;
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
TensorShape var_shape;
if (!raw_var_shape.handle_shape.AsTensorShape(&var_shape) &&
!raw_var_shape.shape.AsTensorShape(&var_shape)) {
return Status(error::FAILED_PRECONDITION, "Failed to read arg shape.");
}
// Const - shape_as_tensor
NodeDef shape_tensor_def;
shape_tensor_def.set_op("Const");
shape_tensor_def.set_name(graph->NewName(
strings::StrCat(name_prefix, "/Initializer/zeros/shape_as_tensor")));
AddNodeAttr("dtype", DT_INT32, &shape_tensor_def);
TensorProto tensorshape_proto;
tensorshape_proto.set_dtype(DT_INT32);
for (int i = 0; i < var_shape.dims(); ++i) {
tensorshape_proto.add_int_val(var_shape.dim_size(i));
}
TensorShape shape_shape({var_shape.dims()});
shape_shape.AsProto(tensorshape_proto.mutable_tensor_shape());
AddNodeAttr("value", tensorshape_proto, &shape_tensor_def);
Node* shape_as_tensor_node = graph->AddNode(shape_tensor_def, &status);
TF_RETURN_IF_ERROR(status);
// Const - initializer value
NodeDef init_val_def;
init_val_def.set_op("Const");
init_val_def.set_name(graph->NewName(
strings::StrCat(name_prefix, "/Initializer/zeros/const_val")));
TensorProto tensor_proto;
tensor_proto.set_dtype(dtype);
if (dtype == DT_FLOAT) {
tensor_proto.add_float_val(0.0f);
} else if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
tensor_proto.add_half_val(0);
} else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8) {
tensor_proto.add_int_val(0);
} else if (dtype == DT_INT64) {
tensor_proto.add_int64_val(0);
} else if (dtype == DT_BOOL) {
tensor_proto.add_bool_val(false);
} else {
return errors::Internal(
"Unable to create zero-init dummy arg tensor for type ", dtype);
}
TensorShape scalar_shape({});
scalar_shape.AsProto(tensor_proto.mutable_tensor_shape());
AddNodeAttr("value", tensor_proto, &init_val_def);
AddNodeAttr("dtype", dtype, &init_val_def);
Node* init_val_node = graph->AddNode(init_val_def, &status);
TF_RETURN_IF_ERROR(status);
// Fill node
NodeDef fill_def;
fill_def.set_op("Fill");
fill_def.set_device(host_cpu_device);
fill_def.set_name(
graph->NewName(strings::StrCat(name_prefix, "/Initializer/zeros")));
AddNodeAttr("T", dtype, &fill_def);
AddNodeAttr("index_type", DT_INT32, &fill_def);
Node* fill_node = graph->AddNode(fill_def, &status);
TF_RETURN_IF_ERROR(status);
graph->AddEdge(shape_as_tensor_node, 0, fill_node, 0);
graph->AddEdge(init_val_node, 0, fill_node, 1);
return fill_node;
}
// Helper that creates an IdentityN node containing all of the variables
// values on CPU device 'device', except for those that will be split across
// cores. (For split variables, this may cause additional cross-host data
@@ -2592,6 +2671,11 @@ namespace {
// simple, and most models use pure replication where all cores want all the
// variables.
//
// If enable_xla_param_broadcast is set to true, then per-host dummy
// tensor args are created on all hosts except for the primary host. In this
// scheme, the dummy args feed the IdentityN node on their local host. All
// are zero-initialized.
//
// Returns the node and its output index to be consumed by TPUExecute for the
// requested variable index.
xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
@@ -2599,7 +2683,9 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
const std::vector<Node*>& variable_reads,
const DistributedTPURewritePass::ParameterInfo& params_info,
const std::vector<xla::OpSharding>& arg_shardings,
const Node& replicate_node,
const Node& replicate_node, const bool enable_xla_param_broadcast,
const int num_cores_per_replica,
const std::vector<InferredShape>& arg_shapes,
absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
Graph* graph) {
auto it = per_host_var_copies->find(host_cpu_device);
@@ -2607,6 +2693,12 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
return it->second[var_index];
}
// Variable replication relies on identification of a master.
DeviceNameUtils::ParsedName parsed_device;
TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
TF_RET_CHECK(parsed_device.has_task);
VLOG(1) << "Creating per-host IdentityN node for task " << parsed_device.task;
DataTypeVector dtypes;
// Per-variable data source for TPUExecute.
std::vector<NodeOut> index_mapping;
@@ -2632,6 +2724,8 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
ndef.set_op("IdentityN");
ndef.set_device(host_cpu_device);
AddNodeAttr("T", dtypes, &ndef);
// TF meta-optimizer should skip this node for constant folding.
AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &ndef);
Status s;
Node* id_node = graph->AddNode(ndef, &s);
TF_RETURN_IF_ERROR(s);
@@ -2641,8 +2735,36 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
if (index_mapping[i].node == nullptr) {
// Fill index_mapping with the actual IdentityN node.
index_mapping[i].node = id_node;
// Add the edge to id_node.
graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
if (parsed_device.task == 0 || !enable_xla_param_broadcast) {
// XLA broadcast mode is not enabled, so use the variable reads as args
// to TPUExecuteOp. For task 0, variable reads are always used
// regardless of XLA broadcast.
// Add the variable read edge to id_node.
graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
} else {
// XLA broadcast mode is enabled. Create zero-valued dummy tensors to
// use as variable args in the TPUExecuteOp.
int64 orig_arg_num = i + params_info.NumPerReplicaArgs() +
params_info.NumBroadcastArgs();
if (num_cores_per_replica > 1) {
LOG(WARNING) << "XLA parameter broadcast is only supported for "
"replicated parameters. Falling back to "
"non-broadcast mode for the parameter associated "
"with the following variable read: "
<< variable_reads[i]->name();
graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
continue;
}
string dummy_name =
strings::StrCat(variable_reads[i]->name(),
absl::StrFormat("/dummy_%d", parsed_device.task));
TF_ASSIGN_OR_RETURN(
Node * var_read,
CreatePerHostDummyArgs(arg_shapes[orig_arg_num], host_cpu_device,
variable_reads[i], dummy_name, graph));
graph->AddEdge(var_read, 0, id_node, index_mapping[i].index);
}
}
}
@@ -3008,11 +3130,13 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
string device;
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
tpu_device_names[replica][core], &device));
TF_ASSIGN_OR_RETURN(auto var_data,
CreateOrGetPerHostVariableCopy(
device, variable_num, variable_reads,
params_info, arg_shardings, replicate_node,
&per_host_var_copies, graph));
TF_ASSIGN_OR_RETURN(
auto var_data,
CreateOrGetPerHostVariableCopy(
device, variable_num, variable_reads, params_info,
arg_shardings, replicate_node, enable_xla_param_broadcast_,
num_cores_per_replica, arg_shapes, &per_host_var_copies,
graph));
if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
@@ -4321,12 +4445,13 @@ bool DistributedTPURewritePass::
bool DistributedTPURewritePass::
enable_cross_replica_sharding_mirrored_variables_ = true;
bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
bool DistributedTPURewritePass::enable_xla_param_broadcast_ = false;
/*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
bool distribute_vars, bool allow_xla_spmd_partition,
bool replicate_inputs_outputs_by_default_for_xla_spmd,
bool enable_cross_replica_sharding_mirrored_variables,
bool enable_automatic_model_parallelism) {
bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast) {
distribute_vars_ = distribute_vars;
allow_xla_spmd_partition_ = allow_xla_spmd_partition;
replicate_inputs_outputs_by_default_for_xla_spmd_ =
@@ -4334,6 +4459,7 @@ bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
enable_cross_replica_sharding_mirrored_variables_ =
enable_cross_replica_sharding_mirrored_variables;
enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
enable_xla_param_broadcast_ = enable_xla_param_broadcast;
}
} // namespace tensorflow
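At the TF graph-rewrite level, the behavior is gated by the new enable_xla_param_broadcast flag of the static setter above. A minimal sketch of flipping it on follows; the include path is the assumed header location, and the other flag values are placeholders for the example, not the pass defaults.

    #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"

    void EnableXlaParamBroadcast() {
      // Only the final argument is new in this change.
      tensorflow::DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
          /*distribute_vars=*/true,
          /*allow_xla_spmd_partition=*/true,
          /*replicate_inputs_outputs_by_default_for_xla_spmd=*/false,
          /*enable_cross_replica_sharding_mirrored_variables=*/true,
          /*enable_automatic_model_parallelism=*/false,
          /*enable_xla_param_broadcast=*/true);
    }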

View File

@@ -132,7 +132,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
bool distribute_vars, bool allow_xla_spmd_partition,
bool replicate_inputs_outputs_by_default_for_xla_spmd,
bool enable_cross_replica_sharding_mirrored_variables,
bool enable_automatic_model_parallelism);
bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast);
Status Run(const GraphOptimizationPassOptions& options) override;
@@ -588,6 +588,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
static bool replicate_inputs_outputs_by_default_for_xla_spmd_;
static bool enable_cross_replica_sharding_mirrored_variables_;
static bool enable_automatic_model_parallelism_;
static bool enable_xla_param_broadcast_;
};
} // namespace tensorflow