Updates the weight broadcast option to go through a different proto: the module-wide broadcast flag is replaced by a per-argument field.

PiperOrigin-RevId: 359914825
Change-Id: I6013f943574c0cb0ca1a379da6d26db2bc4cc284
Authored by Tayo Oguntebi on 2021-02-27 01:51:33 -08:00; committed by TensorFlower Gardener
parent 7a1884b232
commit 32b8a33edb
6 changed files with 101 additions and 56 deletions


@@ -77,7 +77,6 @@ class AotCompilationOptions {
virtual int64 replica_count() const { return 0; }
virtual int64 num_cores() const { return 0; }
virtual bool broadcast_replicated_params() const { return false; }
virtual bool use_spmd_partitioning() const { return false; }
virtual bool deduplicate_hlo() const { return false; }


@@ -133,11 +133,12 @@ class HloModuleConfig {
}
int64 num_partitions() const { return num_partitions_; }
void set_broadcast_replicated_params(bool broadcast_replicated_params) {
broadcast_replicated_params_ = broadcast_replicated_params;
const std::vector<bool> param_requires_broadcast_via_collectives() const {
return param_requires_broadcast_via_collectives_;
}
bool broadcast_replicated_params() const {
return broadcast_replicated_params_;
void set_param_requires_broadcast_via_collectives(
const std::vector<bool> require_broadcast) {
param_requires_broadcast_via_collectives_ = std::move(require_broadcast);
}
void set_use_spmd_partitioning(bool use_spmd_partitioning) {
@@ -256,8 +257,8 @@ class HloModuleConfig {
// The number of partitions (model parallelism) to compile this binary for.
int64 num_partitions_ = 1;
// Whether to use XLA collectives to broadcast params to all replicas.
bool broadcast_replicated_params_ = false;
// Whether to broadcast args across all replicas. One entry per arg.
std::vector<bool> param_requires_broadcast_via_collectives_;
// Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA
// needs to partition the module.
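
An illustrative sketch of the new per-parameter interface above. The helper function and the three-parameter example are hypothetical; only the two accessors (and their signatures) come from this change, and the include path is the usual one for this header:

#include <vector>

#include "tensorflow/compiler/xla/service/hlo_module_config.h"

void ConfigureBroadcastFlags(xla::HloModuleConfig& config) {
  // One entry per entry-computation parameter; true means the parameter is
  // distributed to every replica via XLA collectives instead of being fed to
  // each replica separately.
  config.set_param_requires_broadcast_via_collectives({false, false, true});
  const std::vector<bool> flags =
      config.param_requires_broadcast_via_collectives();
  // flags == {false, false, true}
}
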


@@ -381,7 +381,7 @@ message ExecutionOptions {
// works on TPU.
bool deduplicate_hlo = 12;
bool broadcast_replicated_parameters_via_collectives = 13;
reserved 13; // Was broadcast_replicated_parameters_via_collectives
}
message GetDeviceHandlesRequest {


@@ -62,6 +62,10 @@ message TPUCompileMetadataProto {
// Name of the node that the arg comes from.
string name = 10;
// Whether to use XLA collectives to broadcast this parameter to all
// replicas, instead of using TensorFlow Send/Recv among the tasks.
bool requires_xla_broadcast = 11;
}
repeated Arg args = 1;
@@ -116,7 +120,5 @@ message TPUCompileMetadataProto {
// requested.
bool use_spmd_for_xla_partitioning = 15;
// Enables use of XLA collectives for broadcast of replicated parameters to
// all replicas, instead of using TensorFlow Send/Recv.
bool broadcast_replicated_parameters_via_collectives = 16;
reserved 16; // Was broadcast_replicated_parameters_via_collectives
}
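
For contrast with the removed module-wide flag, a hypothetical sketch of marking one argument for broadcast through the new per-arg field. Only the proto accessors are defined by this change; the helper function and the generated-header path are assumptions:

#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"

namespace tensorflow {

void MarkArgForXlaBroadcast(tpu::TPUCompileMetadataProto* proto) {
  // Before: one switch for the whole program, e.g.
  //   proto->set_broadcast_replicated_parameters_via_collectives(true);
  // After: the decision is recorded per argument (Arg field 11).
  tpu::TPUCompileMetadataProto::Arg* arg = proto->add_args();
  arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
  arg->set_requires_xla_broadcast(true);
}

}  // namespace tensorflow
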


@@ -2299,6 +2299,42 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
return Status::OK();
}
namespace {
bool XlaBroadcastTypeSupported(const DataType dtype) {
return (dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
dtype == DT_BOOL);
}
bool XlaBroadcastKindSupported(
const DistributedTPURewritePass::ParameterInfo& params_info,
int param_num) {
// NOTE: This is intended to cover non-sharded data parallel variables, for
// training only. Is it correct to just check if the arg_type is
// DT_RESOURCE?
return params_info.IsVariableArg(param_num) &&
!(params_info.IsPerReplicaArg(param_num) ||
params_info.IsDistributedArg(param_num) ||
params_info.IsBroadcastArg(param_num) ||
params_info.IsConstantArg(param_num));
}
bool EnableXlaParamBroadcast(
bool enable_xla_param_broadcast,
const DistributedTPURewritePass::ParameterInfo& params_info, int param_num,
DataType dtype, int num_cores_per_replica) {
// Conditions necessary to use XLA collectives for arg broadcast:
// 1. Globally enabled via enable_xla_param_broadcast.
// 2. DataType must be supported.
// 3. Parameter must be a variable, and not distributed or broadcasted.
// 4. Model parallelism is not currently supported.
return enable_xla_param_broadcast && XlaBroadcastTypeSupported(dtype) &&
XlaBroadcastKindSupported(params_info, param_num) &&
(num_cores_per_replica == 1);
}
} // namespace
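
Stated more plainly, the helper above is a conjunction of four independent conditions. A simplified, hypothetical mirror with plain booleans standing in for the ParameterInfo queries, for exposition only:

// All four conditions must hold for a parameter to be broadcast via XLA
// collectives.
bool ShouldBroadcastParam(bool globally_enabled, bool dtype_supported,
                          bool is_plain_replicated_variable,
                          int num_cores_per_replica) {
  return globally_enabled && dtype_supported && is_plain_replicated_variable &&
         num_cores_per_replica == 1;  // Model parallelism not supported yet.
}

Flipping any single input (for example, num_cores_per_replica = 2) turns broadcast off for that parameter.
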
// Builds a TPUCompile node that compiles the bodies of the function call
// `nodes`.
Status DistributedTPURewritePass::BuildCompileNode(
@@ -2315,7 +2351,7 @@ Status DistributedTPURewritePass::BuildCompileNode(
int num_cores_per_replica, const string& compile_device,
const xla::DeviceAssignment* xla_device_assignment,
const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
Node** compile_node, int64 autotuner_thresh) {
Node** compile_node, int64 autotuner_thresh, int num_tasks) {
VLOG(1) << "BuildCompileNode";
tpu::TPUCompileMetadataProto proto;
@@ -2334,8 +2370,6 @@ Status DistributedTPURewritePass::BuildCompileNode(
return s.type() == xla::OpSharding::MAXIMAL;
});
proto.set_use_spmd_for_xla_partitioning(use_spmd);
proto.set_broadcast_replicated_parameters_via_collectives(
enable_xla_param_broadcast_);
// Get and fill padding map.
if (replicate_node != nullptr) {
@@ -2383,6 +2417,15 @@ Status DistributedTPURewritePass::BuildCompileNode(
arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
}
}
// Use XLA collective primitives to distribute variables to all replicas,
// for multi-host systems.
arg->set_requires_xla_broadcast(
num_tasks > 1 &&
EnableXlaParamBroadcast(enable_xla_param_broadcast_, params_info, i,
arg_shape.handle_type /*arg.dtype?*/,
num_cores_per_replica));
// As long as the argument is not a per-replica one, it should have the same
// value for all replicas. For clarity, we keep the (redundant) checks for
// variable, broadcast and constant types, to prevent bugs in case new types
@@ -2686,20 +2729,39 @@ Status DistributedTPURewritePass::BuildVariableWrites(
namespace {
// Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
const string& host_cpu_device,
Node* var_read,
absl::string_view name_prefix,
Graph* graph) {
xla::StatusOr<Node*> MaybeCreatePerHostDummyArgs(
const std::vector<InferredShape>& arg_shapes, const string& host_cpu_device,
const DistributedTPURewritePass::ParameterInfo& params_info, Node* var_read,
int var_num, int num_cores_per_replica, Graph* graph) {
Status status;
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
if (!(dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
dtype == DT_BOOL)) {
if (num_cores_per_replica > 1) {
LOG_FIRST_N(WARNING, 1) << "XLA parameter broadcast is not supported for "
"model-partitioned parameters. Falling back to "
"non-broadcast mode for all parameters.";
return var_read;
}
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
DeviceNameUtils::ParsedName parsed_device;
TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
TF_RET_CHECK(parsed_device.has_task);
// Task 0 behaves as the primary task, where variables are assigned. Use the
// variable reads as arguments to TPUExecute.
// For other tasks, create dummies if the graph meets preconditions.
int64 orig_arg_num = var_num + params_info.NumPerReplicaArgs() +
params_info.NumDistributedArgs() +
params_info.NumBroadcastArgs();
if (parsed_device.task == 0 ||
!EnableXlaParamBroadcast(/*enable_xla_param_broadcast=*/true, params_info,
orig_arg_num, dtype, num_cores_per_replica)) {
return var_read;
}
auto raw_var_shape = arg_shapes[orig_arg_num];
TensorShape var_shape;
if (!raw_var_shape.handle_shape.AsTensorShape(&var_shape) &&
!raw_var_shape.shape.AsTensorShape(&var_shape)) {
@@ -2707,6 +2769,8 @@ xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
}
// Const - shape_as_tensor
const std::string name_prefix = strings::StrCat(
var_read->name(), absl::StrFormat("/dummy_%d", parsed_device.task));
NodeDef shape_tensor_def;
shape_tensor_def.set_op("Const");
shape_tensor_def.set_name(graph->NewName(
@@ -2801,12 +2865,6 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
return it->second[var_index];
}
// Variable replication relies on identification of a master.
DeviceNameUtils::ParsedName parsed_device;
TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
TF_RET_CHECK(parsed_device.has_task);
VLOG(1) << "Creating per-host IdentityN node for task " << parsed_device.task;
DataTypeVector dtypes;
// Per-variable data source for TPUExecute.
std::vector<NodeOut> index_mapping;
@@ -2814,8 +2872,9 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
dtypes.reserve(variable_reads.size());
for (int64 i = 0; i < variable_reads.size(); ++i) {
Node* read = variable_reads[i];
int64 orig_arg_num =
i + params_info.NumPerReplicaArgs() + params_info.NumBroadcastArgs();
int64 orig_arg_num = i + params_info.NumPerReplicaArgs() +
params_info.NumDistributedArgs() +
params_info.NumBroadcastArgs();
if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
// We haven't built the IdentityN node yet, so temporarily use nullptr.
index_mapping.push_back(
@@ -2843,34 +2902,18 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
if (index_mapping[i].node == nullptr) {
// Fill index_mapping with the actual IdentityN node.
index_mapping[i].node = id_node;
if (parsed_device.task == 0 || !enable_xla_param_broadcast) {
// XLA broadcast mode is not enabled, so use the variable reads as args
// to TPUExecuteOp. For task 0, variable reads are always used
// regardless of XLA broadcast.
if (!enable_xla_param_broadcast) {
// Add the variable read edge to id_node.
graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
} else {
// XLA broadcast mode is enabled. Create zero-valued dummy tensors to
// use as variable args in the TPUExecuteOp.
int64 orig_arg_num = i + params_info.NumPerReplicaArgs() +
params_info.NumBroadcastArgs();
if (num_cores_per_replica > 1) {
LOG(WARNING) << "XLA parameter broadcast is only supported for "
"replicated parameters. Falling back to "
"non-broadcast mode for the parameter associated "
"with the following variable read: "
<< variable_reads[i]->name();
graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
continue;
}
string dummy_name =
strings::StrCat(variable_reads[i]->name(),
absl::StrFormat("/dummy_%d", parsed_device.task));
// XLA param broadcast mode is enabled. Create zero-valued dummy
// tensors to use as variable args in the TPUExecuteOp, instead of
// original variable reads.
TF_ASSIGN_OR_RETURN(
Node * var_read,
CreatePerHostDummyArgs(arg_shapes[orig_arg_num], host_cpu_device,
variable_reads[i], dummy_name, graph));
MaybeCreatePerHostDummyArgs(arg_shapes, host_cpu_device,
params_info, variable_reads[i], i,
num_cores_per_replica, graph));
graph->AddEdge(var_read, 0, id_node, index_mapping[i].index);
}
}
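
The orig_arg_num computation in the hunks above relies on the fixed ordering of compiled arguments (per-replica args, then distributed args, then broadcast args, then variables); the change is that NumDistributedArgs() is now included in the offset. A tiny worked sketch with made-up counts:

// Hypothetical counts, purely to illustrate the index layout.
const int num_per_replica_args = 2;  // compiled arg indices 0..1
const int num_distributed_args = 1;  // compiled arg index  2
const int num_broadcast_args = 1;    // compiled arg index  3
const int var_num = 0;               // first variable read
const int orig_arg_num = var_num + num_per_replica_args +
                         num_distributed_args +
                         num_broadcast_args;  // == 4
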
@@ -4323,7 +4366,7 @@ DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
/*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
dynamic_shape_nodes, graph, &compile_node, autotuner_thresh, num_tasks));
// Compilation must be sequenced after the control node if the TPU computation
// is in a control-flow construct, such as a loop.


@@ -362,7 +362,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
int num_cores_per_replica, const string& compile_device,
const xla::DeviceAssignment* xla_device_assignment,
const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
Node** compile_node, int64 autotuner_thresh);
Node** compile_node, int64 autotuner_thresh, int num_tasks);
// Builds a TPUCompileSucceededAssert node that verifies that compilation
// succeeded and replaces the TPUCompilationStatus node in the graph.