Update DistributedTPURewritePass to populate OpMetadata in OpSharding.

Arguments and results to a TPU computation will now have OpMetadata containing the source of the sharding attribute (TF op type and name).

PiperOrigin-RevId: 352596417
Change-Id: I75a7114f039b9938ac1479d61ffd37a3ed0f843d
This commit is contained in:
Andy Ly 2021-01-19 10:16:05 -08:00 committed by TensorFlower Gardener
parent bfc75faae7
commit 5b69bbfb41
8 changed files with 190 additions and 43 deletions

View File

@ -26,6 +26,26 @@ const char kShardingAttribute[] = "_XlaSharding";
} // namespace
namespace {
xla::OpMetadata CreateOpMetadata(const std::string& op_type,
const std::string& op_name) {
xla::OpMetadata metadata;
metadata.set_op_type(op_type);
metadata.set_op_name(op_name);
return metadata;
}
void AssignOpMetadataToSharding(xla::OpSharding& sharding,
const string& op_type, const string& op_name) {
auto metadata = CreateOpMetadata(op_type, op_name);
if (sharding.type() == xla::OpSharding::TUPLE) {
for (auto& sharding_element : *sharding.mutable_tuple_shardings()) {
*sharding_element.add_metadata() = metadata;
}
} else {
*sharding.add_metadata() = metadata;
}
}
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
return errors::InvalidArgument(
"Invalid replicated core id: ", core,
@ -35,7 +55,8 @@ Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const string& device_name, int num_cores_per_replica,
absl::optional<xla::OpSharding> explicit_sharding) {
absl::optional<xla::OpSharding> explicit_sharding,
absl::optional<xla::OpMetadata> metadata) {
if (device_name.empty()) {
return explicit_sharding;
}
@ -56,39 +77,50 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
if (core < 0 || core >= num_cores_per_replica) {
return CoreOutOfRangeError(core, num_cores_per_replica);
}
return absl::optional<xla::OpSharding>(
xla::sharding_builder::AssignDevice(core));
auto sharding = xla::sharding_builder::AssignDevice(core);
if (metadata.has_value()) {
*sharding.add_metadata() = metadata.value();
}
return absl::optional<xla::OpSharding>(sharding);
}
}
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const NodeDef& node_def, int num_cores_per_replica) {
const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) {
const string& device_name = node_def.device();
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
GetShardingFromNodeDef(node_def));
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
GetShardingFromNodeDef(node_def, add_metadata));
return ParseShardingFromDevice(
device_name, num_cores_per_replica, sharding,
add_metadata ? absl::optional<xla::OpMetadata>(
CreateOpMetadata(node_def.op(), node_def.name()))
: absl::nullopt);
}
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const Node& node, int num_cores_per_replica) {
const Node& node, int num_cores_per_replica, bool add_metadata) {
string device_name = node.assigned_device_name();
if (device_name.empty()) {
device_name = node.requested_device();
}
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
GetShardingFromNodeDef(node.def()));
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
GetShardingFromNodeDef(node.def(), add_metadata));
return ParseShardingFromDevice(
device_name, num_cores_per_replica, sharding,
add_metadata ? absl::optional<xla::OpMetadata>(
CreateOpMetadata(node.type_string(), node.name()))
: absl::nullopt);
}
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
const Edge& edge, int num_cores_per_replica) {
const Edge& edge, int num_cores_per_replica, bool add_metadata) {
if (edge.src() == nullptr) {
return tensorflow::errors::InvalidArgument(
"Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
}
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> sharding,
ParseShardingFromDevice(*edge.src(), num_cores_per_replica));
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
ParseShardingFromDevice(
*edge.src(), num_cores_per_replica, add_metadata));
if (sharding.has_value() &&
sharding.value().type() == xla::OpSharding::TUPLE) {
if (edge.src_output() < 0 ||
@ -116,7 +148,7 @@ void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
}
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
const NodeDef& node_def) {
const NodeDef& node_def, bool add_metadata) {
if (!HasNodeAttr(node_def, kShardingAttribute)) {
return absl::optional<xla::OpSharding>();
}
@ -128,6 +160,9 @@ xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
"Experimental _XlaSharding attribute was not a valid encoded "
"xla::OpSharding proto.");
}
if (add_metadata) {
AssignOpMetadataToSharding(sharding, node_def.op(), node_def.name());
}
return absl::optional<xla::OpSharding>(sharding);
}
} // namespace tensorflow

View File

@ -35,22 +35,23 @@ namespace tensorflow {
// - a sharding set as per xla::sharding_builder::AssignDevice.
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const string& device_name, int num_cores_per_replica,
absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt);
absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt,
absl::optional<xla::OpMetadata> metadata = absl::nullopt);
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const Node& node, int num_cores_per_replica);
const Node& node, int num_cores_per_replica, bool add_metadata);
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const NodeDef& node_def, int num_cores_per_replica);
const NodeDef& node_def, int num_cores_per_replica, bool add_metadata);
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
const Edge& edge, int num_cores_per_replica);
const Edge& edge, int num_cores_per_replica, bool add_metadata);
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
// Get sharding inforamtion from node.
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
const NodeDef& node_def);
const NodeDef& node_def, bool add_metadata);
} // namespace tensorflow

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include <functional>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -54,4 +56,86 @@ TEST(CoreUtilTest, ParseShardingFromDevice) {
EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
}
class ShardingWithMetadataTest
: public ::testing::TestWithParam<xla::OpSharding> {};
TEST_P(ShardingWithMetadataTest, GetShardingFromNode) {
NodeDef node_def;
{
node_def.set_op("_Arg");
node_def.set_name("arg");
AttrValue xla_sharding;
xla_sharding.set_s("");
AttrValue index;
index.set_i(0);
AttrValue type;
type.set_type(DataType::DT_FLOAT);
node_def.mutable_attr()->insert(
{{"_XlaSharding", xla_sharding}, {"index", index}, {"T", type}});
}
auto check_metadata = [](const xla::OpSharding& sharding) {
ASSERT_EQ(sharding.metadata_size(), 1);
const auto& metadata = sharding.metadata(0);
EXPECT_EQ(metadata.op_type(), "_Arg");
EXPECT_EQ(metadata.op_name(), "arg");
};
auto test_sharding_metadata =
[&check_metadata](
const std::function<xla::StatusOr<absl::optional<xla::OpSharding>>()>&
fn) {
auto status_or_sharding = fn();
TF_ASSERT_OK(status_or_sharding.status());
ASSERT_TRUE(status_or_sharding.ValueOrDie().has_value());
auto& sharding = status_or_sharding.ValueOrDie();
ASSERT_TRUE(sharding.has_value());
if (sharding->type() == xla::OpSharding::TUPLE) {
EXPECT_TRUE(sharding->metadata().empty());
for (const auto& sharding_element : sharding->tuple_shardings()) {
check_metadata(sharding_element);
}
} else {
check_metadata(sharding.value());
}
};
{
test_sharding_metadata([&node_def]() {
return GetShardingFromNodeDef(node_def, /*add_metadata=*/true);
});
}
{
test_sharding_metadata([&node_def]() {
return ParseShardingFromDevice(node_def, /*num_cores_per_replica=*/1,
/*add_metadata=*/true);
});
}
{
Graph graph(OpRegistry::Global());
Status status;
Node* node = graph.AddNode(node_def, &status);
TF_ASSERT_OK(status);
test_sharding_metadata([node]() {
return ParseShardingFromDevice(*node, /*num_cores_per_replica=*/1,
/*add_metadata=*/true);
});
}
}
xla::OpSharding CreateTupleSharding() {
xla::OpSharding sharding;
sharding.set_type(xla::OpSharding::TUPLE);
sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
return sharding;
}
INSTANTIATE_TEST_SUITE_P(GetShardingFromNode, ShardingWithMetadataTest,
::testing::Values(xla::sharding_builder::Replicate(),
CreateTupleSharding()));
} // namespace tensorflow

View File

@ -506,7 +506,8 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
absl::optional<xla::OpSharding> sharding,
ParseShardingFromDevice(
*possible_match,
/*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
/*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
/*add_metadata=*/false));
if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
const int core_annotation = sharding.value().tile_assignment_devices(0);
if (core == -1 || core > core_annotation) {

View File

@ -242,7 +242,8 @@ TEST(SetNodeShardingFromNeighbors, Basic) {
// Test where one input to c_node has a device.
a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2");
TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica,
/*add_metadata=*/false);
TF_ASSERT_OK(parse_status.status());
ASSERT_TRUE(parse_status.ValueOrDie().has_value());
EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0));
@ -250,14 +251,16 @@ TEST(SetNodeShardingFromNeighbors, Basic) {
// Test where two inputs to c_node have a device.
b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1");
TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica,
/*add_metadata=*/false);
TF_ASSERT_OK(parse_status.status());
ASSERT_TRUE(parse_status.ValueOrDie().has_value());
EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));
// Test setting based on out edges.
TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true));
parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica);
parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica,
/*add_metadata=*/false);
TF_ASSERT_OK(parse_status.status());
ASSERT_TRUE(parse_status.ValueOrDie().has_value());
EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));

View File

@ -110,8 +110,9 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
AttachLocationToMetadata(metadata, op_kernel, xla_context);
b->SetOpMetadata(metadata);
auto sharding_parse_result = ParseShardingFromDevice(
op_kernel->def(), std::numeric_limits<int>::max());
auto sharding_parse_result =
ParseShardingFromDevice(op_kernel->def(), std::numeric_limits<int>::max(),
/*add_metadata=*/false);
OP_REQUIRES_OK(context, sharding_parse_result.status());
absl::optional<xla::OpSharding> op_sharding =
sharding_parse_result.ValueOrDie();

View File

@ -93,7 +93,8 @@ ComputeArgAndRetvalShardings(const Graph& graph) {
[](const Node* n) -> xla::StatusOr<absl::optional<xla::OpSharding>> {
TF_ASSIGN_OR_RETURN(
auto sharding,
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
ParseShardingFromDevice(*n, std::numeric_limits<int32>::max(),
/*add_metadata=*/false));
return sharding;
};
std::map<int, xla::OpSharding> arg_shardings;

View File

@ -1172,9 +1172,17 @@ bool PlaceOpsOnTPU(Node* node) {
return true;
}
xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
xla::OpMetadata metadata;
metadata.set_op_type(node.type_string());
metadata.set_op_name(node.name());
return metadata;
}
// Validate sharding configuration derived from XlaSharding attribute.
// Infer the core id from the OpSharding, if necessary.
Status ParseAndValidateSharding(const xla::OpSharding& sharding,
const Node& node,
const int num_cores_per_replica,
int64* inferred_core_id,
absl::optional<xla::OpSharding>* result) {
@ -1203,7 +1211,9 @@ Status ParseAndValidateSharding(const xla::OpSharding& sharding,
if (result_value_serialized != sharding_serialized) {
// We see different shardings, assign to core 0.
result->emplace(xla::sharding_builder::AssignDevice(0));
auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
*core_zero_sharding.add_metadata() = CreateOpMetadataFromNode(node);
result->emplace(core_zero_sharding);
}
}
}
@ -1232,7 +1242,8 @@ ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
// If |node| has `device` attribute or is a XlaSharding op,
// return the parsed OpSharding.
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
ParseShardingFromDevice(node, num_cores_per_replica));
ParseShardingFromDevice(node, num_cores_per_replica,
/*add_metadata=*/true));
if (sharding.has_value()) return sharding;
// XlaShardingOp may be followed by an identity or followed by identity
@ -1244,9 +1255,10 @@ ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
potential_nodes_with_input_sharding) {
if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding_config,
ParseShardingFromDevice(*maybe_node_with_sharding_info,
num_cores_per_replica));
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> sharding_config,
ParseShardingFromDevice(*maybe_node_with_sharding_info,
num_cores_per_replica, /*add_metadata=*/true));
if (sharding_config.has_value()) return sharding_config;
}
return sharding;
@ -1273,8 +1285,9 @@ Status ParseAndValidateShardingFromNeighbors(
absl::optional<xla::OpSharding> sharding,
ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
if (sharding.has_value()) {
TF_RETURN_IF_ERROR(ParseAndValidateSharding(
*sharding, num_cores_per_replica, inferred_core_id, result));
TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, neighbor_node,
num_cores_per_replica,
inferred_core_id, result));
return Status::OK();
}
@ -1295,8 +1308,9 @@ Status ParseAndValidateShardingFromNeighbors(
absl::optional<xla::OpSharding> sharding,
ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
if (sharding.has_value()) {
TF_RETURN_IF_ERROR(ParseAndValidateSharding(
*sharding, num_cores_per_replica, inferred_core_id, result));
TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, *e->dst(),
num_cores_per_replica,
inferred_core_id, result));
return Status::OK();
}
}
@ -1775,7 +1789,8 @@ static Status ValidateCoreNumbers(const Graph& graph,
int num_cores_per_replica) {
for (Node* n : graph.nodes()) {
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
ParseShardingFromDevice(*n, num_cores_per_replica));
ParseShardingFromDevice(*n, num_cores_per_replica,
/*add_metadata=*/true));
}
return Status::OK();
}
@ -1930,8 +1945,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
Node* input_node;
TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node));
if (input_node->type_string() == kTPUPartitionedInput) {
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding,
GetShardingFromNodeDef(input_node->def()));
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> parsed_sharding,
GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
if (!parsed_sharding.has_value())
return errors::InvalidArgument("Missing _XlaSharding attr from: ",
input_node->DebugString());
@ -1946,8 +1962,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
Node* input_node;
TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node));
if (input_node->type_string() == kVarHandleOp) {
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding,
GetShardingFromNodeDef(input_node->def()));
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> parsed_sharding,
GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
if (parsed_sharding.has_value()) {
sharding = parsed_sharding;
VLOG(1) << "Arg " << i << " parsed sharding information from "
@ -1988,6 +2005,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
}
sharding = xla::sharding_builder::AssignDevice(*assigned_core);
}
*sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node);
} else if (sharding->type() == xla::OpSharding::MAXIMAL) {
assigned_core = sharding->tile_assignment_devices(0);
} else if (sharding->type() != xla::OpSharding::REPLICATED &&
@ -2036,12 +2054,14 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> sharding,
ParseShardingFromEdgeSource(*edge, num_cores_per_replica));
ParseShardingFromEdgeSource(*edge, num_cores_per_replica,
/*add_metadata=*/true));
if (partitioned_output_nodes.contains(i)) {
Node* output_node = partitioned_output_nodes[i];
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding,
GetShardingFromNodeDef(output_node->def()));
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> parsed_sharding,
GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true));
if (parsed_sharding.has_value()) {
sharding = parsed_sharding;
VLOG(1) << "Retval " << i << " parsed sharding information from "
@ -2079,6 +2099,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
}
sharding = xla::sharding_builder::AssignDevice(*assigned_core);
}
*sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node);
}
if (assigned_core.has_value()) {
retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));