From 5b69bbfb413952ccaf484f2ea55b95cfd1bc076c Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Tue, 19 Jan 2021 10:16:05 -0800 Subject: [PATCH] Update DistributedTPURewritePass to populate OpMetadata in OpSharding. Arguments and results to a TPU computation will now have OpMetadata containing the source of the sharding attribute (TF op type and name). PiperOrigin-RevId: 352596417 Change-Id: I75a7114f039b9938ac1479d61ffd37a3ed0f843d --- tensorflow/compiler/tf2xla/sharding_util.cc | 63 ++++++++++---- tensorflow/compiler/tf2xla/sharding_util.h | 11 +-- .../compiler/tf2xla/sharding_util_test.cc | 84 +++++++++++++++++++ tensorflow/compiler/tf2xla/tf2xla_util.cc | 3 +- .../compiler/tf2xla/tf2xla_util_test.cc | 9 +- .../compiler/tf2xla/xla_compilation_device.cc | 5 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 3 +- .../distributed_tpu_rewrite_pass.cc | 55 ++++++++---- 8 files changed, 190 insertions(+), 43 deletions(-) diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 90585c9d98a..1806c93a497 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -26,6 +26,26 @@ const char kShardingAttribute[] = "_XlaSharding"; } // namespace namespace { +xla::OpMetadata CreateOpMetadata(const std::string& op_type, + const std::string& op_name) { + xla::OpMetadata metadata; + metadata.set_op_type(op_type); + metadata.set_op_name(op_name); + return metadata; +} + +void AssignOpMetadataToSharding(xla::OpSharding& sharding, + const string& op_type, const string& op_name) { + auto metadata = CreateOpMetadata(op_type, op_name); + if (sharding.type() == xla::OpSharding::TUPLE) { + for (auto& sharding_element : *sharding.mutable_tuple_shardings()) { + *sharding_element.add_metadata() = metadata; + } + } else { + *sharding.add_metadata() = metadata; + } +} + Status CoreOutOfRangeError(int core, int num_cores_per_replica) { return errors::InvalidArgument( "Invalid replicated core id: ", core, @@ -35,7 +55,8 @@ Status CoreOutOfRangeError(int core, int num_cores_per_replica) { xla::StatusOr> ParseShardingFromDevice( const string& device_name, int num_cores_per_replica, - absl::optional explicit_sharding) { + absl::optional explicit_sharding, + absl::optional metadata) { if (device_name.empty()) { return explicit_sharding; } @@ -56,39 +77,50 @@ xla::StatusOr> ParseShardingFromDevice( if (core < 0 || core >= num_cores_per_replica) { return CoreOutOfRangeError(core, num_cores_per_replica); } - return absl::optional( - xla::sharding_builder::AssignDevice(core)); + auto sharding = xla::sharding_builder::AssignDevice(core); + if (metadata.has_value()) { + *sharding.add_metadata() = metadata.value(); + } + return absl::optional(sharding); } } xla::StatusOr> ParseShardingFromDevice( - const NodeDef& node_def, int num_cores_per_replica) { + const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) { const string& device_name = node_def.device(); TF_ASSIGN_OR_RETURN(absl::optional sharding, - GetShardingFromNodeDef(node_def)); - return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); + GetShardingFromNodeDef(node_def, add_metadata)); + return ParseShardingFromDevice( + device_name, num_cores_per_replica, sharding, + add_metadata ? absl::optional( + CreateOpMetadata(node_def.op(), node_def.name())) + : absl::nullopt); } xla::StatusOr> ParseShardingFromDevice( - const Node& node, int num_cores_per_replica) { + const Node& node, int num_cores_per_replica, bool add_metadata) { string device_name = node.assigned_device_name(); if (device_name.empty()) { device_name = node.requested_device(); } TF_ASSIGN_OR_RETURN(absl::optional sharding, - GetShardingFromNodeDef(node.def())); - return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); + GetShardingFromNodeDef(node.def(), add_metadata)); + return ParseShardingFromDevice( + device_name, num_cores_per_replica, sharding, + add_metadata ? absl::optional( + CreateOpMetadata(node.type_string(), node.name())) + : absl::nullopt); } xla::StatusOr> ParseShardingFromEdgeSource( - const Edge& edge, int num_cores_per_replica) { + const Edge& edge, int num_cores_per_replica, bool add_metadata) { if (edge.src() == nullptr) { return tensorflow::errors::InvalidArgument( "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString()); } - TF_ASSIGN_OR_RETURN( - absl::optional sharding, - ParseShardingFromDevice(*edge.src(), num_cores_per_replica)); + TF_ASSIGN_OR_RETURN(absl::optional sharding, + ParseShardingFromDevice( + *edge.src(), num_cores_per_replica, add_metadata)); if (sharding.has_value() && sharding.value().type() == xla::OpSharding::TUPLE) { if (edge.src_output() < 0 || @@ -116,7 +148,7 @@ void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { } xla::StatusOr> GetShardingFromNodeDef( - const NodeDef& node_def) { + const NodeDef& node_def, bool add_metadata) { if (!HasNodeAttr(node_def, kShardingAttribute)) { return absl::optional(); } @@ -128,6 +160,9 @@ xla::StatusOr> GetShardingFromNodeDef( "Experimental _XlaSharding attribute was not a valid encoded " "xla::OpSharding proto."); } + if (add_metadata) { + AssignOpMetadataToSharding(sharding, node_def.op(), node_def.name()); + } return absl::optional(sharding); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index 07657c656d3..728991bb6f1 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -35,22 +35,23 @@ namespace tensorflow { // - a sharding set as per xla::sharding_builder::AssignDevice. xla::StatusOr> ParseShardingFromDevice( const string& device_name, int num_cores_per_replica, - absl::optional explicit_sharding = absl::nullopt); + absl::optional explicit_sharding = absl::nullopt, + absl::optional metadata = absl::nullopt); xla::StatusOr> ParseShardingFromDevice( - const Node& node, int num_cores_per_replica); + const Node& node, int num_cores_per_replica, bool add_metadata); xla::StatusOr> ParseShardingFromDevice( - const NodeDef& node_def, int num_cores_per_replica); + const NodeDef& node_def, int num_cores_per_replica, bool add_metadata); xla::StatusOr> ParseShardingFromEdgeSource( - const Edge& edge, int num_cores_per_replica); + const Edge& edge, int num_cores_per_replica, bool add_metadata); void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); // Get sharding inforamtion from node. xla::StatusOr> GetShardingFromNodeDef( - const NodeDef& node_def); + const NodeDef& node_def, bool add_metadata); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc index a9ba7e98ce1..133dd4d551f 100644 --- a/tensorflow/compiler/tf2xla/sharding_util_test.cc +++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include + #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -54,4 +56,86 @@ TEST(CoreUtilTest, ParseShardingFromDevice) { EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie())); } +class ShardingWithMetadataTest + : public ::testing::TestWithParam {}; + +TEST_P(ShardingWithMetadataTest, GetShardingFromNode) { + NodeDef node_def; + { + node_def.set_op("_Arg"); + node_def.set_name("arg"); + AttrValue xla_sharding; + xla_sharding.set_s(""); + AttrValue index; + index.set_i(0); + AttrValue type; + type.set_type(DataType::DT_FLOAT); + node_def.mutable_attr()->insert( + {{"_XlaSharding", xla_sharding}, {"index", index}, {"T", type}}); + } + + auto check_metadata = [](const xla::OpSharding& sharding) { + ASSERT_EQ(sharding.metadata_size(), 1); + const auto& metadata = sharding.metadata(0); + EXPECT_EQ(metadata.op_type(), "_Arg"); + EXPECT_EQ(metadata.op_name(), "arg"); + }; + + auto test_sharding_metadata = + [&check_metadata]( + const std::function>()>& + fn) { + auto status_or_sharding = fn(); + TF_ASSERT_OK(status_or_sharding.status()); + ASSERT_TRUE(status_or_sharding.ValueOrDie().has_value()); + auto& sharding = status_or_sharding.ValueOrDie(); + ASSERT_TRUE(sharding.has_value()); + if (sharding->type() == xla::OpSharding::TUPLE) { + EXPECT_TRUE(sharding->metadata().empty()); + for (const auto& sharding_element : sharding->tuple_shardings()) { + check_metadata(sharding_element); + } + } else { + check_metadata(sharding.value()); + } + }; + + { + test_sharding_metadata([&node_def]() { + return GetShardingFromNodeDef(node_def, /*add_metadata=*/true); + }); + } + + { + test_sharding_metadata([&node_def]() { + return ParseShardingFromDevice(node_def, /*num_cores_per_replica=*/1, + /*add_metadata=*/true); + }); + } + + { + Graph graph(OpRegistry::Global()); + Status status; + Node* node = graph.AddNode(node_def, &status); + TF_ASSERT_OK(status); + + test_sharding_metadata([node]() { + return ParseShardingFromDevice(*node, /*num_cores_per_replica=*/1, + /*add_metadata=*/true); + }); + } +} + +xla::OpSharding CreateTupleSharding() { + xla::OpSharding sharding; + sharding.set_type(xla::OpSharding::TUPLE); + sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED); + sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED); + return sharding; +} + +INSTANTIATE_TEST_SUITE_P(GetShardingFromNode, ShardingWithMetadataTest, + ::testing::Values(xla::sharding_builder::Replicate(), + CreateTupleSharding())); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 53ff72a7d99..bab2f0913a7 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -506,7 +506,8 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { absl::optional sharding, ParseShardingFromDevice( *possible_match, - /*num_cores_per_replica=*/std::numeric_limits::max())); + /*num_cores_per_replica=*/std::numeric_limits::max(), + /*add_metadata=*/false)); if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) { const int core_annotation = sharding.value().tile_assignment_devices(0); if (core == -1 || core > core_annotation) { diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 827d0f389e2..a69637b61a8 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -242,7 +242,8 @@ TEST(SetNodeShardingFromNeighbors, Basic) { // Test where one input to c_node has a device. a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2"); TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false)); - auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica); + auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica, + /*add_metadata=*/false); TF_ASSERT_OK(parse_status.status()); ASSERT_TRUE(parse_status.ValueOrDie().has_value()); EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0)); @@ -250,14 +251,16 @@ TEST(SetNodeShardingFromNeighbors, Basic) { // Test where two inputs to c_node have a device. b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1"); TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false)); - parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica); + parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica, + /*add_metadata=*/false); TF_ASSERT_OK(parse_status.status()); ASSERT_TRUE(parse_status.ValueOrDie().has_value()); EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); // Test setting based on out edges. TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true)); - parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica); + parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica, + /*add_metadata=*/false); TF_ASSERT_OK(parse_status.status()); ASSERT_TRUE(parse_status.ValueOrDie().has_value()); EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 0de00581a2f..010a9b3f075 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -110,8 +110,9 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, AttachLocationToMetadata(metadata, op_kernel, xla_context); b->SetOpMetadata(metadata); - auto sharding_parse_result = ParseShardingFromDevice( - op_kernel->def(), std::numeric_limits::max()); + auto sharding_parse_result = + ParseShardingFromDevice(op_kernel->def(), std::numeric_limits::max(), + /*add_metadata=*/false); OP_REQUIRES_OK(context, sharding_parse_result.status()); absl::optional op_sharding = sharding_parse_result.ValueOrDie(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 8dc8da7efed..fe72698a07d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -93,7 +93,8 @@ ComputeArgAndRetvalShardings(const Graph& graph) { [](const Node* n) -> xla::StatusOr> { TF_ASSIGN_OR_RETURN( auto sharding, - ParseShardingFromDevice(*n, std::numeric_limits::max())); + ParseShardingFromDevice(*n, std::numeric_limits::max(), + /*add_metadata=*/false)); return sharding; }; std::map arg_shardings; diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index 614a3ef82ed..36a2c4ace6d 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -1172,9 +1172,17 @@ bool PlaceOpsOnTPU(Node* node) { return true; } +xla::OpMetadata CreateOpMetadataFromNode(const Node& node) { + xla::OpMetadata metadata; + metadata.set_op_type(node.type_string()); + metadata.set_op_name(node.name()); + return metadata; +} + // Validate sharding configuration derived from XlaSharding attribute. // Infer the core id from the OpSharding, if necessary. Status ParseAndValidateSharding(const xla::OpSharding& sharding, + const Node& node, const int num_cores_per_replica, int64* inferred_core_id, absl::optional* result) { @@ -1203,7 +1211,9 @@ Status ParseAndValidateSharding(const xla::OpSharding& sharding, if (result_value_serialized != sharding_serialized) { // We see different shardings, assign to core 0. - result->emplace(xla::sharding_builder::AssignDevice(0)); + auto core_zero_sharding = xla::sharding_builder::AssignDevice(0); + *core_zero_sharding.add_metadata() = CreateOpMetadataFromNode(node); + result->emplace(core_zero_sharding); } } } @@ -1232,7 +1242,8 @@ ParseInputShardingFromAdjacentNode(const int num_cores_per_replica, // If |node| has `device` attribute or is a XlaSharding op, // return the parsed OpSharding. TF_ASSIGN_OR_RETURN(absl::optional sharding, - ParseShardingFromDevice(node, num_cores_per_replica)); + ParseShardingFromDevice(node, num_cores_per_replica, + /*add_metadata=*/true)); if (sharding.has_value()) return sharding; // XlaShardingOp may be followed by an identity or followed by identity @@ -1244,9 +1255,10 @@ ParseInputShardingFromAdjacentNode(const int num_cores_per_replica, potential_nodes_with_input_sharding) { if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue; - TF_ASSIGN_OR_RETURN(absl::optional sharding_config, - ParseShardingFromDevice(*maybe_node_with_sharding_info, - num_cores_per_replica)); + TF_ASSIGN_OR_RETURN( + absl::optional sharding_config, + ParseShardingFromDevice(*maybe_node_with_sharding_info, + num_cores_per_replica, /*add_metadata=*/true)); if (sharding_config.has_value()) return sharding_config; } return sharding; @@ -1273,8 +1285,9 @@ Status ParseAndValidateShardingFromNeighbors( absl::optional sharding, ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node)); if (sharding.has_value()) { - TF_RETURN_IF_ERROR(ParseAndValidateSharding( - *sharding, num_cores_per_replica, inferred_core_id, result)); + TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, neighbor_node, + num_cores_per_replica, + inferred_core_id, result)); return Status::OK(); } @@ -1295,8 +1308,9 @@ Status ParseAndValidateShardingFromNeighbors( absl::optional sharding, ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst())); if (sharding.has_value()) { - TF_RETURN_IF_ERROR(ParseAndValidateSharding( - *sharding, num_cores_per_replica, inferred_core_id, result)); + TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, *e->dst(), + num_cores_per_replica, + inferred_core_id, result)); return Status::OK(); } } @@ -1775,7 +1789,8 @@ static Status ValidateCoreNumbers(const Graph& graph, int num_cores_per_replica) { for (Node* n : graph.nodes()) { TF_ASSIGN_OR_RETURN(absl::optional sharding, - ParseShardingFromDevice(*n, num_cores_per_replica)); + ParseShardingFromDevice(*n, num_cores_per_replica, + /*add_metadata=*/true)); } return Status::OK(); } @@ -1930,8 +1945,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( Node* input_node; TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node)); if (input_node->type_string() == kTPUPartitionedInput) { - TF_ASSIGN_OR_RETURN(absl::optional parsed_sharding, - GetShardingFromNodeDef(input_node->def())); + TF_ASSIGN_OR_RETURN( + absl::optional parsed_sharding, + GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true)); if (!parsed_sharding.has_value()) return errors::InvalidArgument("Missing _XlaSharding attr from: ", input_node->DebugString()); @@ -1946,8 +1962,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( Node* input_node; TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node)); if (input_node->type_string() == kVarHandleOp) { - TF_ASSIGN_OR_RETURN(absl::optional parsed_sharding, - GetShardingFromNodeDef(input_node->def())); + TF_ASSIGN_OR_RETURN( + absl::optional parsed_sharding, + GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true)); if (parsed_sharding.has_value()) { sharding = parsed_sharding; VLOG(1) << "Arg " << i << " parsed sharding information from " @@ -1988,6 +2005,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( } sharding = xla::sharding_builder::AssignDevice(*assigned_core); } + *sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node); } else if (sharding->type() == xla::OpSharding::MAXIMAL) { assigned_core = sharding->tile_assignment_devices(0); } else if (sharding->type() != xla::OpSharding::REPLICATED && @@ -2036,12 +2054,14 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( TF_ASSIGN_OR_RETURN( absl::optional sharding, - ParseShardingFromEdgeSource(*edge, num_cores_per_replica)); + ParseShardingFromEdgeSource(*edge, num_cores_per_replica, + /*add_metadata=*/true)); if (partitioned_output_nodes.contains(i)) { Node* output_node = partitioned_output_nodes[i]; - TF_ASSIGN_OR_RETURN(absl::optional parsed_sharding, - GetShardingFromNodeDef(output_node->def())); + TF_ASSIGN_OR_RETURN( + absl::optional parsed_sharding, + GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true)); if (parsed_sharding.has_value()) { sharding = parsed_sharding; VLOG(1) << "Retval " << i << " parsed sharding information from " @@ -2079,6 +2099,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( } sharding = xla::sharding_builder::AssignDevice(*assigned_core); } + *sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node); } if (assigned_core.has_value()) { retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));