From fc9c057b1ac7883551e72c833a5429f4bb6dc47a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 3 Aug 2020 15:36:31 -0700
Subject: [PATCH] Adds ParseShardingFromEdgeSource(). Makes
 DistributedTPURewritePass::AssignArgsAndRetvalsToCores() support TUPLE
 sharding for return values.

PiperOrigin-RevId: 324697874
Change-Id: I3039da1731c9622ebeb0bf9c3b45185e220267af
---
 tensorflow/compiler/tf2xla/sharding_util.cc   | 24 +++++++++++++++++++
 tensorflow/compiler/tf2xla/sharding_util.h    |  3 +++
 .../distributed_tpu_rewrite_pass.cc           | 20 ++++++++++------
 .../distributed_tpu_rewrite_pass.h            |  4 +++-
 tensorflow/core/tpu/kernels/BUILD             |  1 -
 .../core/tpu/kernels/tpu_compile_op_common.cc |  2 +-
 6 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc
index 366e8d49228..90585c9d98a 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.cc
+++ b/tensorflow/compiler/tf2xla/sharding_util.cc
@@ -80,6 +80,30 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
   return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
 }
 
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
+    const Edge& edge, int num_cores_per_replica) {
+  if (edge.src() == nullptr) {
+    return tensorflow::errors::InvalidArgument(
+        "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
+  }
+  TF_ASSIGN_OR_RETURN(
+      absl::optional<xla::OpSharding> sharding,
+      ParseShardingFromDevice(*edge.src(), num_cores_per_replica));
+  if (sharding.has_value() &&
+      sharding.value().type() == xla::OpSharding::TUPLE) {
+    if (edge.src_output() < 0 ||
+        edge.src_output() >= sharding.value().tuple_shardings_size()) {
+      return tensorflow::errors::InvalidArgument(
+          "Tuple index out of bound: edge=", edge.DebugString(),
+          " sharding=", sharding->DebugString());
+    }
+    absl::optional<xla::OpSharding> subsharding =
+        sharding.value().tuple_shardings(edge.src_output());
+    return subsharding;
+  }
+  return sharding;
+}
+
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
   string device_name = src.assigned_device_name();
   if (device_name.empty()) {
diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h
index 196434826f9..07657c656d3 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.h
+++ b/tensorflow/compiler/tf2xla/sharding_util.h
@@ -43,6 +43,9 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
     const NodeDef& node_def, int num_cores_per_replica);
 
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
+    const Edge& edge, int num_cores_per_replica);
+
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
 
 // Get sharding information from node.
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index 075a1ec9069..5fdc74b79fc 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -1813,7 +1813,8 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
     } else if (sharding->type() != xla::OpSharding::REPLICATED &&
               sharding->type() != xla::OpSharding::OTHER) {
       return tensorflow::errors::InvalidArgument(
-          "Unsupported argument sharding: ", sharding->DebugString());
+          "Unsupported argument sharding (for arg ", n->DebugString(),
+          "): ", sharding->DebugString());
     }
     if (assigned_core.has_value()) {
       args_device_selector.ReportDeviceAssigned(*assigned_core, i);
@@ -1855,7 +1856,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
 
     TF_ASSIGN_OR_RETURN(
         absl::optional<xla::OpSharding> sharding,
-        ParseShardingFromDevice(*edge->src(), num_cores_per_replica));
+        ParseShardingFromEdgeSource(*edge, num_cores_per_replica));
 
     if (partitioned_output_nodes.contains(i)) {
       Node* output_node = partitioned_output_nodes[i];
@@ -1883,7 +1884,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
     } else if (sharding.value().type() != xla::OpSharding::REPLICATED &&
               sharding.value().type() != xla::OpSharding::OTHER) {
       return tensorflow::errors::InvalidArgument(
-          "Unsupported argument sharding: ", sharding->DebugString());
+          "Unsupported argument sharding for retval ",
+          retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
+          sharding->DebugString());
     }
   } else {
     if (use_spmd) {
@@ -2472,7 +2475,8 @@ xla::StatusOr<Node*> CreateOrGetPerHostVariableCopy(
 
 Status DistributedTPURewritePass::BuildExecuteNodes(
     const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
-    const Node& replicate_node, const DataTypeVector& arg_types,
+    const Node& replicate_node, const std::vector<std::string>& arg_names,
+    const DataTypeVector& arg_types,
     const std::vector<InferredShape>& arg_shapes,
     const DataTypeVector& retval_types,
     const std::vector<xla::OpSharding>& arg_shardings,
@@ -2595,7 +2599,9 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
       }
     } else {
       return tensorflow::errors::InvalidArgument(
-          "Unsupported argument sharding: ", sharding.DebugString());
+          "Unsupported argument sharding for arg=", arg_names[i],
+          " shape=", arg_shapes[i].shape.DebugString(), ": ",
+          sharding.DebugString());
     }
   }
   std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
@@ -3922,8 +3928,8 @@ Status DistributedTPURewritePass::FingerprintFunctionLibrary(
 
   std::vector<VariableWrite> variable_writes;
   TF_RETURN_IF_ERROR(BuildExecuteNodes(
-      params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_types,
-      arg_shapes, retval_types, arg_sharding, retval_sharding,
+      params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
+      arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
       tf_device_assignment, compile_node, variable_reads,
       control_after_compilation, control_after, &variable_writes, graph));
   bool contains_resource_write_op =
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
index 1931b4ac80f..a9692cc0edb 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
@@ -413,9 +413,10 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
   // * `num_cores_per_replica` is the number of cores which are dedicated to
   //   each replica.
   // * `replicate_node` is the original TPUReplicate node.
-  // * `arg_types` are the types of the arguments to the computation function
+  // * `arg_names` are the names of the arguments to the computation function
   //   passed as argument to TPUReplicate, including per-replica,
   //   broadcast, and variable arguments.
+  // * `arg_types` are the corresponding types of the arguments.
   // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
   //   applicable).
   // * `arg_shardings` and `retval_shardings` are mappings from
@@ -431,6 +432,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
   static Status BuildExecuteNodes(
       const ParameterInfo& params_info, int num_tasks,
       int num_cores_per_replica, const Node& replicate_node,
+      const std::vector<std::string>& arg_names,
       const DataTypeVector& arg_types,
       const std::vector<InferredShape>& arg_shapes,
       const DataTypeVector& retval_types,
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 6f74123131f..1336f52ed34 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -71,7 +71,6 @@ cc_library(
         "//tensorflow/core/tpu:tpu_api",
         "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
-        "//tensorflow/stream_executor/tpu:status_helper",
         "//tensorflow/stream_executor/tpu:tpu_platform_interface",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
index 8bd45db2206..ce18e844e66 100644
--- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
@@ -117,7 +117,7 @@ Status SetPerCoreArgShapes(
   } else {
     TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
         << "Unsupported argument sharding: "
-        << proto_arg.sharding().DebugString();
+        << " proto_arg=" << proto_arg.DebugString();
     for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
       (*arg_core_mapping)[arg_index].indices.push_back(
           (*per_core_arg_shapes)[core].size());
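
Note for readers (not part of the patch): the heart of the change is the TUPLE
branch in ParseShardingFromEdgeSource() above. When an edge's source node
carries a TUPLE OpSharding, the helper returns the element sharding at the
edge's output index rather than the whole tuple, so each return value can get
its own sharding. The snippet below is a minimal standalone sketch of that
control flow, not TensorFlow code: FakeSharding and FakeEdge are invented
stand-ins for xla::OpSharding and tensorflow::Edge, the simplified return type
stands in for xla::StatusOr<absl::optional<xla::OpSharding>>, and the assert
replaces the real code's InvalidArgument error.

#include <cassert>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Stand-in for the xla::OpSharding::Type values used by the patch.
enum class ShardingType { kReplicated, kOther, kTuple };

// Stand-in for xla::OpSharding (illustrative only).
struct FakeSharding {
  ShardingType type = ShardingType::kReplicated;
  std::vector<FakeSharding> tuple_shardings;  // non-empty only for kTuple
  std::string label;                          // for printing in the demo
};

// Stand-in for the two Edge properties the helper consults.
struct FakeEdge {
  std::optional<FakeSharding> src_sharding;  // what ParseShardingFromDevice
                                             // would return for edge.src()
  int src_output = 0;  // which output of the source node this edge reads
};

// Mirrors the control flow of the new helper: a TUPLE sharding is unpacked
// per output index (with a bounds check); anything else passes through.
std::optional<FakeSharding> ParseShardingFromEdgeSource(const FakeEdge& edge) {
  const std::optional<FakeSharding>& sharding = edge.src_sharding;
  if (sharding.has_value() && sharding->type == ShardingType::kTuple) {
    assert(edge.src_output >= 0 &&
           edge.src_output <
               static_cast<int>(sharding->tuple_shardings.size()));
    return sharding->tuple_shardings[edge.src_output];
  }
  return sharding;  // REPLICATED/OTHER (or no sharding) is returned as-is.
}

int main() {
  // A source node with a two-element TUPLE sharding, read via output index 1:
  // the helper returns the second element's sharding, not the whole tuple.
  FakeEdge edge;
  edge.src_sharding = FakeSharding{
      ShardingType::kTuple,
      {{ShardingType::kReplicated, {}, "elem0"},
       {ShardingType::kOther, {}, "elem1"}},
      "tuple"};
  edge.src_output = 1;
  std::cout << ParseShardingFromEdgeSource(edge)->label << "\n";  // elem1
}

Compiled as a plain C++17 program, this prints "elem1": the edge reads output
index 1, so that tuple element's sharding is selected. That per-element
unpacking is what lets AssignArgsAndRetvalsToCores() accept a retval edge
whose source is TUPLE-sharded instead of failing the REPLICATED/OTHER check.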