diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 2f6bc8e2e9e..623b8262178 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -77,7 +77,6 @@ class AotCompilationOptions {
   virtual int64 replica_count() const { return 0; }
   virtual int64 num_cores() const { return 0; }
-  virtual bool broadcast_replicated_params() const { return false; }
   virtual bool use_spmd_partitioning() const { return false; }
   virtual bool deduplicate_hlo() const { return false; }
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 1167f6df140..8d02df045d6 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -133,11 +133,12 @@ class HloModuleConfig {
   }
   int64 num_partitions() const { return num_partitions_; }

-  void set_broadcast_replicated_params(bool broadcast_replicated_params) {
-    broadcast_replicated_params_ = broadcast_replicated_params;
+  const std::vector<bool> param_requires_broadcast_via_collectives() const {
+    return param_requires_broadcast_via_collectives_;
   }
-  bool broadcast_replicated_params() const {
-    return broadcast_replicated_params_;
+  void set_param_requires_broadcast_via_collectives(
+      const std::vector<bool> require_broadcast) {
+    param_requires_broadcast_via_collectives_ = std::move(require_broadcast);
   }

   void set_use_spmd_partitioning(bool use_spmd_partitioning) {
@@ -256,8 +257,8 @@ class HloModuleConfig {
   // The number of partitions (model parallelism) to compile this binary for.
   int64 num_partitions_ = 1;

-  // Whether to use XLA collectives to broadcast params to all replicas.
-  bool broadcast_replicated_params_ = false;
+  // Whether to broadcast args across all replicas. One entry per arg.
+  std::vector<bool> param_requires_broadcast_via_collectives_;

   // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA
   // needs to partition the module.
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 1929b90a44a..b7cf13e084a 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -381,7 +381,7 @@ message ExecutionOptions {
   // works on TPU.
   bool deduplicate_hlo = 12;

-  reserved 13;  // Was broadcast_replicated_parameters_via_collectives = 13;
+  reserved 13;  // Was broadcast_replicated_parameters_via_collectives
 }

 message GetDeviceHandlesRequest {
diff --git a/tensorflow/core/protobuf/tpu/compile_metadata.proto b/tensorflow/core/protobuf/tpu/compile_metadata.proto
index 3d90cfb1cbf..5c21f078299 100644
--- a/tensorflow/core/protobuf/tpu/compile_metadata.proto
+++ b/tensorflow/core/protobuf/tpu/compile_metadata.proto
@@ -62,6 +62,10 @@ message TPUCompileMetadataProto {

     // Name of the node that the arg comes from.
     string name = 10;
+
+    // Whether to use XLA collectives to broadcast this parameter to all
+    // replicas, instead of using TensorFlow Send/Recv among the tasks.
+    bool requires_xla_broadcast = 11;
   }
   repeated Arg args = 1;

@@ -116,7 +120,5 @@ message TPUCompileMetadataProto {
   // requested.
   bool use_spmd_for_xla_partitioning = 15;

-  // Enables use of XLA collectives for broadcast of replicated parameters to
-  // all replicas, instead of using TensorFlow Send/Recv.
-  bool broadcast_replicated_parameters_via_collectives = 16;
+  reserved 16;  // Was broadcast_replicated_parameters_via_collectives
 }
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index a183c3dc522..13ba2ce0fc5 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -2299,6 +2299,42 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
   return Status::OK();
 }

+namespace {
+
+bool XlaBroadcastTypeSupported(const DataType dtype) {
+  return (dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
+          dtype == DT_BOOL);
+}
+
+bool XlaBroadcastKindSupported(
+    const DistributedTPURewritePass::ParameterInfo& params_info,
+    int param_num) {
+  // NOTE: This is intended to cover non-sharded data parallel variables, for
+  // training only. Is it correct to just check if the arg_type is
+  // DT_RESOURCE?
+  return params_info.IsVariableArg(param_num) &&
+         !(params_info.IsPerReplicaArg(param_num) ||
+           params_info.IsDistributedArg(param_num) ||
+           params_info.IsBroadcastArg(param_num) ||
+           params_info.IsConstantArg(param_num));
+}
+
+bool EnableXlaParamBroadcast(
+    bool enable_xla_param_broadcast,
+    const DistributedTPURewritePass::ParameterInfo& params_info, int param_num,
+    DataType dtype, int num_cores_per_replica) {
+  // Conditions necessary to use XLA collectives for arg broadcast:
+  // 1. Globally enabled via enable_xla_param_broadcast.
+  // 2. DataType must be supported.
+  // 3. Parameter must be a variable, and not distributed or broadcasted.
+  // 4. Model parallelism is not currently supported.
+  return enable_xla_param_broadcast && XlaBroadcastTypeSupported(dtype) &&
+         XlaBroadcastKindSupported(params_info, param_num) &&
+         (num_cores_per_replica == 1);
+}
+
+}  // namespace
+
 // Builds a TPUCompile node that compiles the bodies of the function call
 // `nodes`.
 Status DistributedTPURewritePass::BuildCompileNode(
@@ -2315,7 +2351,7 @@ Status DistributedTPURewritePass::BuildCompileNode(
     int num_cores_per_replica, const string& compile_device,
     const xla::DeviceAssignment* xla_device_assignment,
     const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
-    Node** compile_node, int64 autotuner_thresh) {
+    Node** compile_node, int64 autotuner_thresh, int num_tasks) {
   VLOG(1) << "BuildCompileNode";

   tpu::TPUCompileMetadataProto proto;
@@ -2334,8 +2370,6 @@ Status DistributedTPURewritePass::BuildCompileNode(
         return s.type() == xla::OpSharding::MAXIMAL;
       });
   proto.set_use_spmd_for_xla_partitioning(use_spmd);
-  proto.set_broadcast_replicated_parameters_via_collectives(
-      enable_xla_param_broadcast_);

   // Get and fill padding map.
   if (replicate_node != nullptr) {
@@ -2383,6 +2417,15 @@ Status DistributedTPURewritePass::BuildCompileNode(
        arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
      }
    }
+
+    // Use XLA collective primitives to distribute variables to all replicas,
+    // for multi-host systems.
+    arg->set_requires_xla_broadcast(
+        num_tasks > 1 &&
+        EnableXlaParamBroadcast(enable_xla_param_broadcast_, params_info, i,
+                                arg_shape.handle_type /*arg.dtype?*/,
+                                num_cores_per_replica));
+
    // As long as the argument is not a per-replica one, it should have the same
    // value for all replicas. For clarity, we keep the (redundant) checks for
    // variable, broadcast and constant types, to prevent bugs in case new types
@@ -2686,20 +2729,39 @@ Status DistributedTPURewritePass::BuildVariableWrites(
 namespace {

 // Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
-xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
-                                            const string& host_cpu_device,
-                                            Node* var_read,
-                                            absl::string_view name_prefix,
-                                            Graph* graph) {
+xla::StatusOr<Node*> MaybeCreatePerHostDummyArgs(
+    const std::vector<InferredShape>& arg_shapes, const string& host_cpu_device,
+    const DistributedTPURewritePass::ParameterInfo& params_info, Node* var_read,
+    int var_num, int num_cores_per_replica, Graph* graph) {
   Status status;
-  DataType dtype;
-  TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
-  if (!(dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
-        dtype == DT_BOOL)) {
+  if (num_cores_per_replica > 1) {
+    LOG_FIRST_N(WARNING, 1) << "XLA parameter broadcast is not supported for "
+                               "model-partitioned parameters. Falling back to "
+                               "non-broadcast mode for all parameters.";
     return var_read;
   }
+
+  DataType dtype;
+  TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
+
+  DeviceNameUtils::ParsedName parsed_device;
+  TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
+  TF_RET_CHECK(parsed_device.has_task);
+
+  // Task 0 behaves as the primary task, where variables are assigned. Use the
+  // variable reads as arguments to TPUExecute.
+  // For other tasks, create dummies if the graph meets preconditions.
+  int64 orig_arg_num = var_num + params_info.NumPerReplicaArgs() +
+                       params_info.NumDistributedArgs() +
+                       params_info.NumBroadcastArgs();
+  if (parsed_device.task == 0 ||
+      !EnableXlaParamBroadcast(/*enable_xla_param_broadcast=*/true, params_info,
+                               orig_arg_num, dtype, num_cores_per_replica)) {
+    return var_read;
+  }
+
+  auto raw_var_shape = arg_shapes[orig_arg_num];
   TensorShape var_shape;
   if (!raw_var_shape.handle_shape.AsTensorShape(&var_shape) &&
       !raw_var_shape.shape.AsTensorShape(&var_shape)) {
@@ -2707,6 +2769,8 @@ xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
   }

   // Const - shape_as_tensor
+  const std::string name_prefix = strings::StrCat(
+      var_read->name(), absl::StrFormat("/dummy_%d", parsed_device.task));
   NodeDef shape_tensor_def;
   shape_tensor_def.set_op("Const");
   shape_tensor_def.set_name(graph->NewName(
@@ -2801,12 +2865,6 @@ xla::StatusOr<Node*> CreateOrGetPerHostVariableCopy(
     return it->second[var_index];
   }

-  // Variable replication relies on identification of a master.
-  DeviceNameUtils::ParsedName parsed_device;
-  TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
-  TF_RET_CHECK(parsed_device.has_task);
-  VLOG(1) << "Creating per-host IdentityN node for task " << parsed_device.task;
-
   DataTypeVector dtypes;
   // Per-variable data source for TPUExecute.
   std::vector index_mapping;
@@ -2814,8 +2872,9 @@ xla::StatusOr<Node*> CreateOrGetPerHostVariableCopy(
   dtypes.reserve(variable_reads.size());
   for (int64 i = 0; i < variable_reads.size(); ++i) {
     Node* read = variable_reads[i];
-    int64 orig_arg_num =
-        i + params_info.NumPerReplicaArgs() + params_info.NumBroadcastArgs();
+    int64 orig_arg_num = i + params_info.NumPerReplicaArgs() +
+                         params_info.NumDistributedArgs() +
+                         params_info.NumBroadcastArgs();
     if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
       // We haven't built the IdentityN node yet, so temporarily use nullptr.
       index_mapping.push_back(
@@ -2843,34 +2902,18 @@ xla::StatusOr<Node*> CreateOrGetPerHostVariableCopy(
     if (index_mapping[i].node == nullptr) {
       // Fill index_mapping with the actual IdentityN node.
       index_mapping[i].node = id_node;
-      if (parsed_device.task == 0 || !enable_xla_param_broadcast) {
-        // XLA broadcast mode is not enabled, so use the variable reads as args
-        // to TPUExecuteOp. For task 0, variable reads are always used
-        // regardless of XLA broadcast.
-
+      if (!enable_xla_param_broadcast) {
         // Add the variable read edge to id_node.
         graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
       } else {
-        // XLA broadcast mode is enabled. Create zero-valued dummy tensors to
-        // use as variable args in the TPUExecuteOp.
-        int64 orig_arg_num = i + params_info.NumPerReplicaArgs() +
-                             params_info.NumBroadcastArgs();
-        if (num_cores_per_replica > 1) {
-          LOG(WARNING) << "XLA parameter broadcast is only supported for "
-                          "replicated parameters. Falling back to "
-                          "non-broadcast mode for the parameter associated "
-                          "with the following variable read: "
-                       << variable_reads[i]->name();
-          graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
-          continue;
-        }
-        string dummy_name =
-            strings::StrCat(variable_reads[i]->name(),
-                            absl::StrFormat("/dummy_%d", parsed_device.task));
+        // XLA param broadcast mode is enabled. Create zero-valued dummy
+        // tensors to use as variable args in the TPUExecuteOp, instead of
+        // original variable reads.
         TF_ASSIGN_OR_RETURN(
             Node * var_read,
-            CreatePerHostDummyArgs(arg_shapes[orig_arg_num], host_cpu_device,
-                                   variable_reads[i], dummy_name, graph));
+            MaybeCreatePerHostDummyArgs(arg_shapes, host_cpu_device,
+                                        params_info, variable_reads[i], i,
+                                        num_cores_per_replica, graph));
         graph->AddEdge(var_read, 0, id_node, index_mapping[i].index);
       }
     }
@@ -4323,7 +4366,7 @@ DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
       arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
       arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
       /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
-      dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
+      dynamic_shape_nodes, graph, &compile_node, autotuner_thresh, num_tasks));

   // Compilation must be sequenced after the control node if the TPU computation
   // in a control-flow construct, such as a loop.
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
index fd755bcefbc..acbe4e00963 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
@@ -362,7 +362,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
       int num_cores_per_replica, const string& compile_device,
      const xla::DeviceAssignment* xla_device_assignment,
       const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
-      Node** compile_node, int64 autotuner_thresh);
+      Node** compile_node, int64 autotuner_thresh, int num_tasks);

   // Builds a TPUCompileSucceededAssert node that verifies that compilation
   // succeeded and replaces the TPUCompilationStatus node in the graph.
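Note (not part of the diff): a minimal sketch, under stated assumptions, of how the per-argument requires_xla_broadcast bits added to TPUCompileMetadataProto above could be carried over into the new per-parameter vector on HloModuleConfig. The helper name SetParamBroadcastFlags and its placement are hypothetical; only the proto field and the HloModuleConfig setter come from this change.

// Hypothetical helper (illustration only): copy the per-arg broadcast flags
// from the TPU compile metadata into the HLO module config, one entry per arg.
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"

void SetParamBroadcastFlags(
    const tensorflow::tpu::TPUCompileMetadataProto& metadata,
    xla::HloModuleConfig* config) {
  std::vector<bool> require_broadcast;
  require_broadcast.reserve(metadata.args_size());
  for (const auto& arg : metadata.args()) {
    // Args are visited in order, so entry i corresponds to parameter i.
    require_broadcast.push_back(arg.requires_xla_broadcast());
  }
  config->set_param_requires_broadcast_via_collectives(
      std::move(require_broadcast));
}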