diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index 647232300e2..2600437f115 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -83,6 +83,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_deduplicate_hlo(
   return *this;
 }
 
+ExecutableBuildOptions& ExecutableBuildOptions::set_broadcast_replicated_params(
+    bool broadcast_replicated_params) {
+  broadcast_replicated_params_ = broadcast_replicated_params;
+  return *this;
+}
+
 ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment(
     const DeviceAssignment& device_assignment) {
   device_assignment_ = device_assignment;
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 000d2adb648..21f8e1fdb05 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -85,6 +85,12 @@ class ExecutableBuildOptions {
   bool deduplicate_hlo() const { return deduplicate_hlo_; }
   ExecutableBuildOptions& set_deduplicate_hlo(bool deduplicate_hlo);
 
+  bool broadcast_replicated_params() const {
+    return broadcast_replicated_params_;
+  }
+  ExecutableBuildOptions& set_broadcast_replicated_params(
+      bool broadcast_replicated_params);
+
   // If set, this specifies a static device assignment for the computation.
   // Otherwise, the computation will be compiled generically and can be run with
   // any device assignment compatible with the computation's replica and
@@ -135,6 +141,7 @@ class ExecutableBuildOptions {
   int num_partitions_ = 1;
   bool use_spmd_partitioning_ = false;
   bool deduplicate_hlo_ = false;
+  bool broadcast_replicated_params_ = false;
   absl::optional<DeviceAssignment> device_assignment_;
   bool alias_passthrough_params_ = false;
   bool run_backend_only_ = false;
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index f8e4f591a5d..df3210d9e3d 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -93,6 +93,8 @@ CompileOnlyService::CompileAheadOfTime(
   }
   execution_options.set_use_spmd_partitioning(options.use_spmd_partitioning());
   execution_options.set_deduplicate_hlo(options.deduplicate_hlo());
+  execution_options.set_broadcast_replicated_parameters_via_collectives(
+      options.broadcast_replicated_params());
   for (const AotXlaComputationInstance& instance : computations) {
     TF_RET_CHECK(instance.computation.has_host_program_shape());
     *execution_options.mutable_shape_with_output_layout() =
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 623b8262178..2f6bc8e2e9e 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -77,6 +77,7 @@ class AotCompilationOptions {
 
   virtual int64 replica_count() const { return 0; }
   virtual int64 num_cores() const { return 0; }
+  virtual bool broadcast_replicated_params() const { return false; }
   virtual bool use_spmd_partitioning() const { return false; }
   virtual bool deduplicate_hlo() const { return false; }
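Note: the new option chains like the existing setters on ExecutableBuildOptions. A minimal caller sketch, for orientation only; the num_replicas value and surrounding setup are illustrative assumptions, not part of this change:

    xla::ExecutableBuildOptions build_options;
    // Chains with the other builder-style setters; defaults to false.
    build_options.set_num_replicas(8).set_broadcast_replicated_params(true);
    CHECK(build_options.broadcast_replicated_params());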
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index e96f6d75b04..320ba25b3c3 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -447,6 +447,8 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
   module_config.set_use_spmd_partitioning(
       execution_options->use_spmd_partitioning());
   module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo());
+  module_config.set_broadcast_replicated_params(
+      execution_options->broadcast_replicated_parameters_via_collectives());
   if (execution_options->has_device_assignment()) {
     TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                         DeviceAssignment::Deserialize(
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index ae0a8aae838..1167f6df140 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -133,6 +133,13 @@ class HloModuleConfig {
   }
   int64 num_partitions() const { return num_partitions_; }
 
+  void set_broadcast_replicated_params(bool broadcast_replicated_params) {
+    broadcast_replicated_params_ = broadcast_replicated_params;
+  }
+  bool broadcast_replicated_params() const {
+    return broadcast_replicated_params_;
+  }
+
   void set_use_spmd_partitioning(bool use_spmd_partitioning) {
     use_spmd_partitioning_ = use_spmd_partitioning;
   }
@@ -249,6 +256,9 @@ class HloModuleConfig {
   // The number of partitions (model parallelism) to compile this binary for.
   int64 num_partitions_ = 1;
 
+  // Whether to use XLA collectives to broadcast params to all replicas.
+  bool broadcast_replicated_params_ = false;
+
   // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA
   // needs to partition the module.
   bool use_spmd_partitioning_ = false;
diff --git a/tensorflow/compiler/xla/service/hlo_module_util.cc b/tensorflow/compiler/xla/service/hlo_module_util.cc
index 106c50c6e8a..b1067a227d4 100644
--- a/tensorflow/compiler/xla/service/hlo_module_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_util.cc
@@ -96,6 +96,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
     config->set_use_spmd_partitioning(
         execution_options->use_spmd_partitioning());
     config->set_deduplicate_hlo(execution_options->deduplicate_hlo());
+    config->set_broadcast_replicated_params(
+        execution_options->broadcast_replicated_parameters_via_collectives());
     config->set_seed(execution_options->seed());
     config->set_launch_id(execution_options->launch_id());
     config->set_debug_options(execution_options->debug_options());
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index eb67010a651..2d86b6efac8 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -374,6 +374,10 @@ message ExecutionOptions {
   // If set, deduplicate hlo into function calls to reduce binary size. Only
   // works on TPU.
   bool deduplicate_hlo = 12;
+
+  // If set, broadcast replicated parameters to all replicas, using
+  // collectives. Only applicable to TPU.
+  bool broadcast_replicated_parameters_via_collectives = 13;
 }
 
 message GetDeviceHandlesRequest {
diff --git a/tensorflow/core/protobuf/tpu/compile_metadata.proto b/tensorflow/core/protobuf/tpu/compile_metadata.proto
index 2b29e8468b2..3d90cfb1cbf 100644
--- a/tensorflow/core/protobuf/tpu/compile_metadata.proto
+++ b/tensorflow/core/protobuf/tpu/compile_metadata.proto
@@ -115,4 +115,8 @@ message TPUCompileMetadataProto {
   // Whether to use XLA's SPMD or MPMD partitioner when compiler partitioning is
   // requested.
   bool use_spmd_for_xla_partitioning = 15;
+
+  // Enables use of XLA collectives for broadcast of replicated parameters to
+  // all replicas, instead of using TensorFlow Send/Recv.
+  bool broadcast_replicated_parameters_via_collectives = 16;
 }
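Note: end to end, the flag travels ExecutionOptions -> HloModuleConfig, while TPUCompileMetadataProto carries the graph-rewrite choice to the TPU compiler. A minimal sketch of the plumbing added above, using only names introduced or touched by this change:

    xla::ExecutionOptions execution_options;
    execution_options.set_broadcast_replicated_parameters_via_collectives(true);

    xla::HloModuleConfig config;
    // Mirrors the copies added in hlo_module.cc and hlo_module_util.cc.
    config.set_broadcast_replicated_params(
        execution_options.broadcast_replicated_parameters_via_collectives());
    CHECK(config.broadcast_replicated_params());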
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index f91234f1ab3..614a3ef82ed 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -2229,6 +2229,8 @@ Status DistributedTPURewritePass::BuildCompileNode(
         return s.type() == xla::OpSharding::MAXIMAL;
       });
   proto.set_use_spmd_for_xla_partitioning(use_spmd);
+  proto.set_broadcast_replicated_parameters_via_collectives(
+      enable_xla_param_broadcast_);
 
   // Get and fill padding map.
   if (replicate_node != nullptr) {
@@ -2578,6 +2580,83 @@ Status DistributedTPURewritePass::BuildVariableWrites
 
 namespace {
 
+// Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
+xla::StatusOr<Node*> CreatePerHostDummyArgs(const InferredShape& raw_var_shape,
+                                            const string& host_cpu_device,
+                                            Node* var_read,
+                                            absl::string_view name_prefix,
+                                            Graph* graph) {
+  Status status;
+  DataType dtype;
+  TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
+
+  TensorShape var_shape;
+  if (!raw_var_shape.handle_shape.AsTensorShape(&var_shape) &&
+      !raw_var_shape.shape.AsTensorShape(&var_shape)) {
+    return Status(error::FAILED_PRECONDITION, "Failed to read arg shape.");
+  }
+
+  // Const - shape_as_tensor
+  NodeDef shape_tensor_def;
+  shape_tensor_def.set_op("Const");
+  shape_tensor_def.set_name(graph->NewName(
+      strings::StrCat(name_prefix, "/Initializer/zeros/shape_as_tensor")));
+  AddNodeAttr("dtype", DT_INT32, &shape_tensor_def);
+  TensorProto tensorshape_proto;
+  tensorshape_proto.set_dtype(DT_INT32);
+  for (int i = 0; i < var_shape.dims(); ++i) {
+    tensorshape_proto.add_int_val(var_shape.dim_size(i));
+  }
+  TensorShape shape_shape({var_shape.dims()});
+  shape_shape.AsProto(tensorshape_proto.mutable_tensor_shape());
+  AddNodeAttr("value", tensorshape_proto, &shape_tensor_def);
+  Node* shape_as_tensor_node = graph->AddNode(shape_tensor_def, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  // Const - initializer value
+  NodeDef init_val_def;
+  init_val_def.set_op("Const");
+  init_val_def.set_name(graph->NewName(
+      strings::StrCat(name_prefix, "/Initializer/zeros/const_val")));
+  TensorProto tensor_proto;
+  tensor_proto.set_dtype(dtype);
+  if (dtype == DT_FLOAT) {
+    tensor_proto.add_float_val(0.0f);
+  } else if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
+    tensor_proto.add_half_val(0);
+  } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8) {
+    tensor_proto.add_int_val(0);
+  } else if (dtype == DT_INT64) {
+    tensor_proto.add_int64_val(0);
+  } else if (dtype == DT_BOOL) {
+    tensor_proto.add_bool_val(false);
+  } else {
+    return errors::Internal(
+        "Unable to create zero-init dummy arg tensor for type ", dtype);
+  }
+  TensorShape scalar_shape({});
+  scalar_shape.AsProto(tensor_proto.mutable_tensor_shape());
+  AddNodeAttr("value", tensor_proto, &init_val_def);
+  AddNodeAttr("dtype", dtype, &init_val_def);
+  Node* init_val_node = graph->AddNode(init_val_def, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  // Fill node
+  NodeDef fill_def;
+  fill_def.set_op("Fill");
+  fill_def.set_device(host_cpu_device);
+  fill_def.set_name(
+      graph->NewName(strings::StrCat(name_prefix, "/Initializer/zeros")));
+  AddNodeAttr("T", dtype, &fill_def);
+  AddNodeAttr("index_type", DT_INT32, &fill_def);
+  Node* fill_node = graph->AddNode(fill_def, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  graph->AddEdge(shape_as_tensor_node, 0, fill_node, 0);
+  graph->AddEdge(init_val_node, 0, fill_node, 1);
+
+  return fill_node;
+}
+
 // Helper that creates an IdentityN node containing all of the variables
 // values on CPU device 'device', except for those that will be split across
 // cores. (For split variables, this may cause additional cross-host data
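Note: CreatePerHostDummyArgs materializes the equivalent of a zero-filled tensor (a shape Const and a scalar-zero Const feeding a Fill) on the given host. A hypothetical call-site sketch; the device string, nodes, and name prefix below are placeholders, and the real call appears in CreateOrGetPerHostVariableCopy further down:

    // Yields a Fill node whose output 0 is a zero tensor with the
    // variable's shape and dtype, placed on the non-primary host.
    TF_ASSIGN_OR_RETURN(
        Node * dummy,
        CreatePerHostDummyArgs(raw_var_shape,
                               "/job:worker/replica:0/task:1/device:CPU:0",
                               var_read_node, "var/read/dummy_1", graph));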
@@ -2592,6 +2671,11 @@ namespace {
 // simple, and most models use pure replication where all cores want all the
 // variables.
 //
+// If enable_xla_param_broadcast is set to true, then per-host dummy
+// tensor args are created on all hosts except for the primary host. In this
+// scheme, the dummy args feed the IdentityN node on their local host. All
+// are zero-initialized.
+//
 // Returns the node and its output index to be consumed by TPUExecute for the
 // requested variable index.
 xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
@@ -2599,7 +2683,9 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
     const std::vector<Node*>& variable_reads,
     const DistributedTPURewritePass::ParameterInfo& params_info,
     const std::vector<xla::OpSharding>& arg_shardings,
-    const Node& replicate_node,
+    const Node& replicate_node, const bool enable_xla_param_broadcast,
+    const int num_cores_per_replica,
+    const std::vector<InferredShape>& arg_shapes,
     absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
     Graph* graph) {
   auto it = per_host_var_copies->find(host_cpu_device);
@@ -2607,6 +2693,12 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
     return it->second[var_index];
   }
 
+  // Variable replication relies on identification of a master.
+  DeviceNameUtils::ParsedName parsed_device;
+  TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
+  TF_RET_CHECK(parsed_device.has_task);
+  VLOG(1) << "Creating per-host IdentityN node for task " << parsed_device.task;
+
   DataTypeVector dtypes;
   // Per-variable data source for TPUExecute.
   std::vector<NodeOut> index_mapping;
@@ -2632,6 +2724,8 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
   ndef.set_op("IdentityN");
   ndef.set_device(host_cpu_device);
   AddNodeAttr("T", dtypes, &ndef);
+  // TF meta-optimizer should skip this node for constant folding.
+  AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &ndef);
   Status s;
   Node* id_node = graph->AddNode(ndef, &s);
   TF_RETURN_IF_ERROR(s);
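Note: the task number parsed above is what separates the primary host (task 0, which keeps the real variable reads) from all other hosts. An illustrative parse, with an example device string:

    DeviceNameUtils::ParsedName parsed;
    CHECK(DeviceNameUtils::ParseFullName(
        "/job:worker/replica:0/task:3/device:CPU:0", &parsed));
    CHECK(parsed.has_task);
    CHECK_EQ(parsed.task, 3);  // non-zero task: receives dummy args instead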
Falling back to " + "non-broadcast mode for the parameter associated " + "with the following variable read: " + << variable_reads[i]->name(); + graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index); + continue; + } + string dummy_name = + strings::StrCat(variable_reads[i]->name(), + absl::StrFormat("/dummy_%d", parsed_device.task)); + TF_ASSIGN_OR_RETURN( + Node * var_read, + CreatePerHostDummyArgs(arg_shapes[orig_arg_num], host_cpu_device, + variable_reads[i], dummy_name, graph)); + graph->AddEdge(var_read, 0, id_node, index_mapping[i].index); + } } } @@ -3008,11 +3130,13 @@ Status DistributedTPURewritePass::BuildExecuteNodes( string device; TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( tpu_device_names[replica][core], &device)); - TF_ASSIGN_OR_RETURN(auto var_data, - CreateOrGetPerHostVariableCopy( - device, variable_num, variable_reads, - params_info, arg_shardings, replicate_node, - &per_host_var_copies, graph)); + TF_ASSIGN_OR_RETURN( + auto var_data, + CreateOrGetPerHostVariableCopy( + device, variable_num, variable_reads, params_info, + arg_shardings, replicate_node, enable_xla_param_broadcast_, + num_cores_per_replica, arg_shapes, &per_host_var_copies, + graph)); if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) { const xla::OpSharding& sharding = arg_shardings[orig_arg_num]; @@ -4321,12 +4445,13 @@ bool DistributedTPURewritePass:: bool DistributedTPURewritePass:: enable_cross_replica_sharding_mirrored_variables_ = true; bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false; +bool DistributedTPURewritePass::enable_xla_param_broadcast_ = false; /*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions( bool distribute_vars, bool allow_xla_spmd_partition, bool replicate_inputs_outputs_by_default_for_xla_spmd, bool enable_cross_replica_sharding_mirrored_variables, - bool enable_automatic_model_parallelism) { + bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast) { distribute_vars_ = distribute_vars; allow_xla_spmd_partition_ = allow_xla_spmd_partition; replicate_inputs_outputs_by_default_for_xla_spmd_ = @@ -4334,6 +4459,7 @@ bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false; enable_cross_replica_sharding_mirrored_variables_ = enable_cross_replica_sharding_mirrored_variables; enable_automatic_model_parallelism_ = enable_automatic_model_parallelism; + enable_xla_param_broadcast_ = enable_xla_param_broadcast; } } // namespace tensorflow diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h index a9692cc0edb..add806fb74f 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h @@ -132,7 +132,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass { bool distribute_vars, bool allow_xla_spmd_partition, bool replicate_inputs_outputs_by_default_for_xla_spmd, bool enable_cross_replica_sharding_mirrored_variables, - bool enable_automatic_model_parallelism); + bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast); Status Run(const GraphOptimizationPassOptions& options) override; @@ -588,6 +588,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass { static bool replicate_inputs_outputs_by_default_for_xla_spmd_; static bool enable_cross_replica_sharding_mirrored_variables_; static bool enable_automatic_model_parallelism_; + static bool 
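Note on the orig_arg_num computation in the broadcast branch: variable args follow the per-replica and broadcast args in the compile-time argument list, so the shape lookup offsets the variable index by both counts. With assumed values (4 per-replica args, 2 broadcast args, variable index 3):

    const int64 orig_arg_num = 3 + /*NumPerReplicaArgs()=*/4 +
                               /*NumBroadcastArgs()=*/2;  // == 9
    // arg_shapes[9] is the InferredShape used to build the dummy Fill.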
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
index a9692cc0edb..add806fb74f 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
@@ -132,7 +132,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
       bool distribute_vars, bool allow_xla_spmd_partition,
       bool replicate_inputs_outputs_by_default_for_xla_spmd,
       bool enable_cross_replica_sharding_mirrored_variables,
-      bool enable_automatic_model_parallelism);
+      bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast);
 
   Status Run(const GraphOptimizationPassOptions& options) override;
 
@@ -588,6 +588,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
   static bool replicate_inputs_outputs_by_default_for_xla_spmd_;
   static bool enable_cross_replica_sharding_mirrored_variables_;
   static bool enable_automatic_model_parallelism_;
+  static bool enable_xla_param_broadcast_;
 };
 
 }  // namespace tensorflow
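Note: embedders opt in through the extended static setter. A hypothetical wiring sketch; the first three flag values are assumptions, while the last three reflect the defaults visible in this change:

    DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
        /*distribute_vars=*/true,
        /*allow_xla_spmd_partition=*/true,
        /*replicate_inputs_outputs_by_default_for_xla_spmd=*/false,
        /*enable_cross_replica_sharding_mirrored_variables=*/true,
        /*enable_automatic_model_parallelism=*/false,
        /*enable_xla_param_broadcast=*/true);  // new parameter, default false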