Updates the weight broadcast option to go through a different proto.
PiperOrigin-RevId: 359914825 Change-Id: I6013f943574c0cb0ca1a379da6d26db2bc4cc284
parent 7a1884b232
commit 32b8a33edb
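Overview of the change, as reflected in the hunks below: the single per-model broadcast flag (AotCompilationOptions::broadcast_replicated_params, HloModuleConfig::broadcast_replicated_params_, and the per-model proto field broadcast_replicated_parameters_via_collectives on TPUCompileMetadataProto) is removed in favor of per-argument plumbing: TPUCompileMetadataProto.Arg gains requires_xla_broadcast, and HloModuleConfig gains a per-argument vector param_requires_broadcast_via_collectives_. The sketch below is illustrative only; it assumes a glue step, not shown in this diff, that copies the per-arg proto flags into the module config, and the function name CopyBroadcastFlags is invented for the example.

#include <vector>
// (TensorFlow/XLA headers for the proto and HloModuleConfig are assumed.)

// Hypothetical glue (not part of this change): copy the new per-arg flag from
// the compile metadata into the new per-arg vector on HloModuleConfig.
void CopyBroadcastFlags(const tpu::TPUCompileMetadataProto& metadata,
                        xla::HloModuleConfig* config) {
  std::vector<bool> requires_broadcast;
  requires_broadcast.reserve(metadata.args_size());
  for (const auto& arg : metadata.args()) {
    // Reads Arg.requires_xla_broadcast (field 11), added below.
    requires_broadcast.push_back(arg.requires_xla_broadcast());
  }
  config->set_param_requires_broadcast_via_collectives(requires_broadcast);
}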
@@ -77,7 +77,6 @@ class AotCompilationOptions {
 
   virtual int64 replica_count() const { return 0; }
   virtual int64 num_cores() const { return 0; }
-  virtual bool broadcast_replicated_params() const { return false; }
   virtual bool use_spmd_partitioning() const { return false; }
   virtual bool deduplicate_hlo() const { return false; }
 
@@ -133,11 +133,12 @@ class HloModuleConfig {
   }
   int64 num_partitions() const { return num_partitions_; }
 
-  void set_broadcast_replicated_params(bool broadcast_replicated_params) {
-    broadcast_replicated_params_ = broadcast_replicated_params;
+  const std::vector<bool> param_requires_broadcast_via_collectives() const {
+    return param_requires_broadcast_via_collectives_;
   }
-  bool broadcast_replicated_params() const {
-    return broadcast_replicated_params_;
+  void set_param_requires_broadcast_via_collectives(
+      const std::vector<bool> require_broadcast) {
+    param_requires_broadcast_via_collectives_ = std::move(require_broadcast);
   }
 
   void set_use_spmd_partitioning(bool use_spmd_partitioning) {
@@ -256,8 +257,8 @@ class HloModuleConfig {
   // The number of partitions (model parallelism) to compile this binary for.
   int64 num_partitions_ = 1;
 
-  // Whether to use XLA collectives to broadcast params to all replicas.
-  bool broadcast_replicated_params_ = false;
+  // Whether to broadcast args across all replicas. One entry per arg.
+  std::vector<bool> param_requires_broadcast_via_collectives_;
 
   // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA
   // needs to partition the module.
@@ -381,7 +381,7 @@ message ExecutionOptions {
   // works on TPU.
   bool deduplicate_hlo = 12;
 
-  reserved 13;  // Was broadcast_replicated_parameters_via_collectives = 13;
+  reserved 13;  // Was broadcast_replicated_parameters_via_collectives
 }
 
 message GetDeviceHandlesRequest {
@@ -62,6 +62,10 @@ message TPUCompileMetadataProto {
 
     // Name of the node that the arg comes from.
     string name = 10;
+
+    // Whether to use XLA collectives to broadcast this parameter to all
+    // replicas, instead of using TensorFlow Send/Recv among the tasks.
+    bool requires_xla_broadcast = 11;
   }
   repeated Arg args = 1;
 
@@ -116,7 +120,5 @@ message TPUCompileMetadataProto {
   // requested.
   bool use_spmd_for_xla_partitioning = 15;
 
-  // Enables use of XLA collectives for broadcast of replicated parameters to
-  // all replicas, instead of using TensorFlow Send/Recv.
-  bool broadcast_replicated_parameters_via_collectives = 16;
+  reserved 16;  // Was broadcast_replicated_parameters_via_collectives
 }
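With field 16 reserved, a producer of TPUCompileMetadataProto marks individual arguments rather than setting one per-model flag. A minimal illustrative sketch; `metadata` is assumed to be an already populated proto, and `should_broadcast` is a hypothetical stand-in for the decision logic that this change adds to BuildCompileNode further down.

// Illustrative only: per-argument marking replaces the removed per-model flag.
// `should_broadcast` is a hypothetical predicate, not code from this change.
for (int i = 0; i < metadata.args_size(); ++i) {
  metadata.mutable_args(i)->set_requires_xla_broadcast(should_broadcast(i));
}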
@@ -2299,6 +2299,42 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
   return Status::OK();
 }
 
+namespace {
+
+bool XlaBroadcastTypeSupported(const DataType dtype) {
+  return (dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
+          dtype == DT_BOOL);
+}
+
+bool XlaBroadcastKindSupported(
+    const DistributedTPURewritePass::ParameterInfo& params_info,
+    int param_num) {
+  // NOTE: This is intended to cover non-sharded data parallel variables, for
+  // training only. Is it correct to just check if the arg_type is
+  // DT_RESOURCE?
+  return params_info.IsVariableArg(param_num) &&
+         !(params_info.IsPerReplicaArg(param_num) ||
+           params_info.IsDistributedArg(param_num) ||
+           params_info.IsBroadcastArg(param_num) ||
+           params_info.IsConstantArg(param_num));
+}
+
+bool EnableXlaParamBroadcast(
+    bool enable_xla_param_broadcast,
+    const DistributedTPURewritePass::ParameterInfo& params_info, int param_num,
+    DataType dtype, int num_cores_per_replica) {
+  // Conditions necessary to use XLA collectives for arg broadcast:
+  // 1. Globally enabled via enable_xla_param_broadcast.
+  // 2. DataType must be supported.
+  // 3. Parameter must be a variable, and not distributed or broadcasted.
+  // 4. Model parallelism is not currently supported.
+  return enable_xla_param_broadcast && XlaBroadcastTypeSupported(dtype) &&
+         XlaBroadcastKindSupported(params_info, param_num) &&
+         (num_cores_per_replica == 1);
+}
+
+}  // namespace
+
 // Builds a TPUCompile node that compiles the bodies of the function call
 // `nodes`.
 Status DistributedTPURewritePass::BuildCompileNode(
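The helper above is combined with a multi-host check at its call site in BuildCompileNode (next hunks). A condensed, illustrative restatement of the full per-argument gate; `dtype` here stands in for the argument's element type (the real call site passes arg_shape.handle_type):

// Illustrative only: broadcast an argument through XLA collectives when the
// job spans multiple tasks, the feature is enabled, the dtype is supported,
// the parameter is a plain replicated variable, and there is no model
// parallelism (num_cores_per_replica == 1).
const bool broadcast_this_arg =
    num_tasks > 1 &&
    EnableXlaParamBroadcast(enable_xla_param_broadcast_, params_info,
                            /*param_num=*/i, dtype, num_cores_per_replica);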
@@ -2315,7 +2351,7 @@ Status DistributedTPURewritePass::BuildCompileNode(
     int num_cores_per_replica, const string& compile_device,
     const xla::DeviceAssignment* xla_device_assignment,
     const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
-    Node** compile_node, int64 autotuner_thresh) {
+    Node** compile_node, int64 autotuner_thresh, int num_tasks) {
   VLOG(1) << "BuildCompileNode";
 
   tpu::TPUCompileMetadataProto proto;
@@ -2334,8 +2370,6 @@ Status DistributedTPURewritePass::BuildCompileNode(
         return s.type() == xla::OpSharding::MAXIMAL;
       });
   proto.set_use_spmd_for_xla_partitioning(use_spmd);
-  proto.set_broadcast_replicated_parameters_via_collectives(
-      enable_xla_param_broadcast_);
 
   // Get and fill padding map.
   if (replicate_node != nullptr) {
@@ -2383,6 +2417,15 @@ Status DistributedTPURewritePass::BuildCompileNode(
         arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
       }
     }
+
+    // Use XLA collective primitives to distribute variables to all replicas,
+    // for multi-host systems.
+    arg->set_requires_xla_broadcast(
+        num_tasks > 1 &&
+        EnableXlaParamBroadcast(enable_xla_param_broadcast_, params_info, i,
+                                arg_shape.handle_type /*arg.dtype?*/,
+                                num_cores_per_replica));
+
     // As long as the argument is not a per-replica one, it should have the same
     // value for all replicas. For clarity, we keep the (redundant) checks for
     // variable, broadcast and constant types, to prevent bugs in case new types
@@ -2686,20 +2729,39 @@ Status DistributedTPURewritePass::BuildVariableWrites(
 namespace {
 
 // Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
-xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
-                                            const string& host_cpu_device,
-                                            Node* var_read,
-                                            absl::string_view name_prefix,
-                                            Graph* graph) {
+xla::StatusOr<Node*> MaybeCreatePerHostDummyArgs(
+    const std::vector<InferredShape>& arg_shapes, const string& host_cpu_device,
+    const DistributedTPURewritePass::ParameterInfo& params_info, Node* var_read,
+    int var_num, int num_cores_per_replica, Graph* graph) {
   Status status;
-  DataType dtype;
-  TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
 
-  if (!(dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
-        dtype == DT_BOOL)) {
+  if (num_cores_per_replica > 1) {
+    LOG_FIRST_N(WARNING, 1) << "XLA parameter broadcast is not supported for "
+                               "model-partitioned parameters. Falling back to "
+                               "non-broadcast mode for all parameters.";
     return var_read;
   }
 
+  DataType dtype;
+  TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
+
+  DeviceNameUtils::ParsedName parsed_device;
+  TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
+  TF_RET_CHECK(parsed_device.has_task);
+
+  // Task 0 behaves as the primary task, where variables are assigned. Use the
+  // variable reads as arguments to TPUExecute.
+  // For other tasks, create dummies if the graph meets preconditions.
+  int64 orig_arg_num = var_num + params_info.NumPerReplicaArgs() +
+                       params_info.NumDistributedArgs() +
+                       params_info.NumBroadcastArgs();
+  if (parsed_device.task == 0 ||
+      !EnableXlaParamBroadcast(/*enable_xla_param_broadcast=*/true, params_info,
+                               orig_arg_num, dtype, num_cores_per_replica)) {
+    return var_read;
+  }
+
+  auto raw_var_shape = arg_shapes[orig_arg_num];
   TensorShape var_shape;
   if (!raw_var_shape.handle_shape.AsTensorShape(&var_shape) &&
       !raw_var_shape.shape.AsTensorShape(&var_shape)) {
@@ -2707,6 +2769,8 @@ xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
   }
 
   // Const - shape_as_tensor
+  const std::string name_prefix = strings::StrCat(
+      var_read->name(), absl::StrFormat("/dummy_%d", parsed_device.task));
   NodeDef shape_tensor_def;
   shape_tensor_def.set_op("Const");
   shape_tensor_def.set_name(graph->NewName(
@@ -2801,12 +2865,6 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
     return it->second[var_index];
   }
 
-  // Variable replication relies on identification of a master.
-  DeviceNameUtils::ParsedName parsed_device;
-  TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
-  TF_RET_CHECK(parsed_device.has_task);
-  VLOG(1) << "Creating per-host IdentityN node for task " << parsed_device.task;
-
   DataTypeVector dtypes;
   // Per-variable data source for TPUExecute.
   std::vector<NodeOut> index_mapping;
@@ -2814,8 +2872,9 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
   dtypes.reserve(variable_reads.size());
   for (int64 i = 0; i < variable_reads.size(); ++i) {
     Node* read = variable_reads[i];
-    int64 orig_arg_num =
-        i + params_info.NumPerReplicaArgs() + params_info.NumBroadcastArgs();
+    int64 orig_arg_num = i + params_info.NumPerReplicaArgs() +
+                         params_info.NumDistributedArgs() +
+                         params_info.NumBroadcastArgs();
     if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
       // We haven't built the IdentityN node yet, so temporarily use nullptr.
       index_mapping.push_back(
@@ -2843,34 +2902,18 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
     if (index_mapping[i].node == nullptr) {
       // Fill index_mapping with the actual IdentityN node.
       index_mapping[i].node = id_node;
 
-      if (parsed_device.task == 0 || !enable_xla_param_broadcast) {
-        // XLA broadcast mode is not enabled, so use the variable reads as args
-        // to TPUExecuteOp. For task 0, variable reads are always used
-        // regardless of XLA broadcast.
+      if (!enable_xla_param_broadcast) {
         // Add the variable read edge to id_node.
         graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
       } else {
-        // XLA broadcast mode is enabled. Create zero-valued dummy tensors to
-        // use as variable args in the TPUExecuteOp.
-        int64 orig_arg_num = i + params_info.NumPerReplicaArgs() +
-                             params_info.NumBroadcastArgs();
-        if (num_cores_per_replica > 1) {
-          LOG(WARNING) << "XLA parameter broadcast is only supported for "
-                          "replicated parameters. Falling back to "
-                          "non-broadcast mode for the parameter associated "
-                          "with the following variable read: "
-                       << variable_reads[i]->name();
-          graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
-          continue;
-        }
-        string dummy_name =
-            strings::StrCat(variable_reads[i]->name(),
-                            absl::StrFormat("/dummy_%d", parsed_device.task));
+        // XLA param broadcast mode is enabled. Create zero-valued dummy
+        // tensors to use as variable args in the TPUExecuteOp, instead of
+        // original variable reads.
         TF_ASSIGN_OR_RETURN(
             Node * var_read,
-            CreatePerHostDummyArgs(arg_shapes[orig_arg_num], host_cpu_device,
-                                   variable_reads[i], dummy_name, graph));
+            MaybeCreatePerHostDummyArgs(arg_shapes, host_cpu_device,
+                                        params_info, variable_reads[i], i,
+                                        num_cores_per_replica, graph));
         graph->AddEdge(var_read, 0, id_node, index_mapping[i].index);
       }
     }
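Taken together, the hunks above decide which node feeds each variable argument of TPUExecute on a given task. An illustrative helper capturing that decision; ChooseVariableArgSource, param_qualifies, and zero_dummy are invented names for the sketch (the real logic lives in MaybeCreatePerHostDummyArgs and the loop above).

// Illustrative sketch of the per-task data-source choice made above.
Node* ChooseVariableArgSource(bool enable_xla_param_broadcast, int task,
                              bool param_qualifies, Node* var_read,
                              Node* zero_dummy) {
  // Broadcast disabled: every task feeds the real variable read.
  if (!enable_xla_param_broadcast) return var_read;
  // Task 0 holds the authoritative values and always feeds the real reads.
  if (task == 0 || !param_qualifies) return var_read;
  // Other tasks feed zero-valued dummies; the compiled program then
  // broadcasts task 0's values to all replicas via XLA collectives.
  return zero_dummy;
}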
@@ -4323,7 +4366,7 @@ DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
       arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
       arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
       /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
-      dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
+      dynamic_shape_nodes, graph, &compile_node, autotuner_thresh, num_tasks));
 
   // Compilation must be sequenced after the control node if the TPU computation
   // in a control-flow construct, such as a loop.
@@ -362,7 +362,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
       int num_cores_per_replica, const string& compile_device,
       const xla::DeviceAssignment* xla_device_assignment,
       const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
-      Node** compile_node, int64 autotuner_thresh);
+      Node** compile_node, int64 autotuner_thresh, int num_tasks);
 
   // Builds a TPUCompileSucceededAssert node that verifies that compilation
   // succeeded and replaces the TPUCompilationStatus node in the graph.