Update DistributedTPURewritePass to populate OpMetadata in OpSharding.
Arguments and results of a TPU computation will now have OpMetadata attached to their shardings, recording the source of the sharding attribute (the TF op type and name).

PiperOrigin-RevId: 352596417
Change-Id: I75a7114f039b9938ac1479d61ffd37a3ed0f843d
parent bfc75faae7
commit 5b69bbfb41
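The practical effect: every sharding the rewrite pass attaches now carries an xla::OpMetadata entry naming the TF op it was derived from, which makes sharding decisions traceable in XLA tooling. A minimal sketch of the resulting proto, built only with APIs that appear in the diff below (the op name is illustrative):

// Sketch: an argument's sharding after this change (illustrative values).
xla::OpSharding sharding = xla::sharding_builder::AssignDevice(0);
xla::OpMetadata metadata;
metadata.set_op_type("XlaSharding");                   // source TF op type
metadata.set_op_name("model/dense/kernel/sharding");   // hypothetical op name
*sharding.add_metadata() = metadata;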
@@ -26,6 +26,26 @@ const char kShardingAttribute[] = "_XlaSharding";
 }  // namespace
 
 namespace {
+xla::OpMetadata CreateOpMetadata(const std::string& op_type,
+                                 const std::string& op_name) {
+  xla::OpMetadata metadata;
+  metadata.set_op_type(op_type);
+  metadata.set_op_name(op_name);
+  return metadata;
+}
+
+void AssignOpMetadataToSharding(xla::OpSharding& sharding,
+                                const string& op_type, const string& op_name) {
+  auto metadata = CreateOpMetadata(op_type, op_name);
+  if (sharding.type() == xla::OpSharding::TUPLE) {
+    for (auto& sharding_element : *sharding.mutable_tuple_shardings()) {
+      *sharding_element.add_metadata() = metadata;
+    }
+  } else {
+    *sharding.add_metadata() = metadata;
+  }
+}
+
 Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
   return errors::InvalidArgument(
       "Invalid replicated core id: ", core,
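These two file-local helpers carry the whole change: CreateOpMetadata builds the proto, and AssignOpMetadataToSharding attaches it, fanning out over tuple elements because metadata lives on each element rather than on the enclosing TUPLE sharding. A minimal sketch of that behavior, assuming the xla::OpSharding proto from xla_data.proto (values illustrative, not part of the diff):

// Sketch: metadata fan-out for a tuple sharding.
xla::OpSharding tuple_sharding;
tuple_sharding.set_type(xla::OpSharding::TUPLE);
tuple_sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
tuple_sharding.add_tuple_shardings()->set_type(xla::OpSharding::MAXIMAL);
AssignOpMetadataToSharding(tuple_sharding, /*op_type=*/"XlaSharding",
                           /*op_name=*/"x/sharding");
// tuple_sharding.metadata_size() == 0: the tuple itself stays bare.
// Each tuple element now carries one metadata entry with the op type/name.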
@@ -35,7 +55,8 @@ Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
 
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
     const string& device_name, int num_cores_per_replica,
-    absl::optional<xla::OpSharding> explicit_sharding) {
+    absl::optional<xla::OpSharding> explicit_sharding,
+    absl::optional<xla::OpMetadata> metadata) {
   if (device_name.empty()) {
     return explicit_sharding;
   }
@@ -56,39 +77,50 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
     if (core < 0 || core >= num_cores_per_replica) {
       return CoreOutOfRangeError(core, num_cores_per_replica);
     }
-    return absl::optional<xla::OpSharding>(
-        xla::sharding_builder::AssignDevice(core));
+    auto sharding = xla::sharding_builder::AssignDevice(core);
+    if (metadata.has_value()) {
+      *sharding.add_metadata() = metadata.value();
+    }
+    return absl::optional<xla::OpSharding>(sharding);
   }
 }
 
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
-    const NodeDef& node_def, int num_cores_per_replica) {
+    const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) {
   const string& device_name = node_def.device();
   TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
-                      GetShardingFromNodeDef(node_def));
-  return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
+                      GetShardingFromNodeDef(node_def, add_metadata));
+  return ParseShardingFromDevice(
+      device_name, num_cores_per_replica, sharding,
+      add_metadata ? absl::optional<xla::OpMetadata>(
+                         CreateOpMetadata(node_def.op(), node_def.name()))
+                   : absl::nullopt);
 }
 
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
-    const Node& node, int num_cores_per_replica) {
+    const Node& node, int num_cores_per_replica, bool add_metadata) {
   string device_name = node.assigned_device_name();
   if (device_name.empty()) {
     device_name = node.requested_device();
   }
   TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
-                      GetShardingFromNodeDef(node.def()));
-  return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
+                      GetShardingFromNodeDef(node.def(), add_metadata));
+  return ParseShardingFromDevice(
+      device_name, num_cores_per_replica, sharding,
+      add_metadata ? absl::optional<xla::OpMetadata>(
+                         CreateOpMetadata(node.type_string(), node.name()))
+                   : absl::nullopt);
 }
 
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
-    const Edge& edge, int num_cores_per_replica) {
+    const Edge& edge, int num_cores_per_replica, bool add_metadata) {
   if (edge.src() == nullptr) {
     return tensorflow::errors::InvalidArgument(
         "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
   }
-  TF_ASSIGN_OR_RETURN(
-      absl::optional<xla::OpSharding> sharding,
-      ParseShardingFromDevice(*edge.src(), num_cores_per_replica));
+  TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
+                      ParseShardingFromDevice(
+                          *edge.src(), num_cores_per_replica, add_metadata));
   if (sharding.has_value() &&
       sharding.value().type() == xla::OpSharding::TUPLE) {
     if (edge.src_output() < 0 ||
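All three graph-facing overloads now thread an add_metadata flag down to the device-string overload, which takes the metadata itself as an optional so that callers holding only a device string are unaffected. A hedged usage sketch (the surrounding function and core count are hypothetical):

// Sketch: opting in at a call site. `node` is a hypothetical Node whose
// assigned device is e.g. "/device:TPU_REPLICATED_CORE:0".
Status LogShardingSource(const Node& node) {
  TF_ASSIGN_OR_RETURN(
      absl::optional<xla::OpSharding> sharding,
      ParseShardingFromDevice(node, /*num_cores_per_replica=*/8,
                              /*add_metadata=*/true));
  if (sharding.has_value() && sharding->metadata_size() > 0) {
    VLOG(1) << "sharding derived from " << sharding->metadata(0).op_name();
  }
  return Status::OK();
}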
@@ -116,7 +148,7 @@ void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
 }
 
 xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
-    const NodeDef& node_def) {
+    const NodeDef& node_def, bool add_metadata) {
   if (!HasNodeAttr(node_def, kShardingAttribute)) {
     return absl::optional<xla::OpSharding>();
   }
@@ -128,6 +160,9 @@ xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
         "Experimental _XlaSharding attribute was not a valid encoded "
         "xla::OpSharding proto.");
   }
+  if (add_metadata) {
+    AssignOpMetadataToSharding(sharding, node_def.op(), node_def.name());
+  }
   return absl::optional<xla::OpSharding>(sharding);
 }
 }  // namespace tensorflow
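GetShardingFromNodeDef is the single point where the _XlaSharding attribute is decoded, so tagging metadata here covers every caller that opts in. For reference, a sketch of how such an attribute is encoded on a NodeDef (it mirrors the test further down; the node and sharding values are illustrative):

// Sketch: encoding an _XlaSharding attribute (values illustrative).
NodeDef node_def;
node_def.set_op("_Arg");
node_def.set_name("arg");
AttrValue sharding_attr;
sharding_attr.set_s(xla::sharding_builder::Replicate().SerializeAsString());
(*node_def.mutable_attr())["_XlaSharding"] = sharding_attr;
// GetShardingFromNodeDef(node_def, /*add_metadata=*/true) then yields a
// REPLICATED sharding whose metadata records op type "_Arg" and name "arg".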
@@ -35,22 +35,23 @@ namespace tensorflow {
 // - a sharding set as per xla::sharding_builder::AssignDevice.
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
     const string& device_name, int num_cores_per_replica,
-    absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt);
+    absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt,
+    absl::optional<xla::OpMetadata> metadata = absl::nullopt);
 
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
-    const Node& node, int num_cores_per_replica);
+    const Node& node, int num_cores_per_replica, bool add_metadata);
 
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
-    const NodeDef& node_def, int num_cores_per_replica);
+    const NodeDef& node_def, int num_cores_per_replica, bool add_metadata);
 
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
-    const Edge& edge, int num_cores_per_replica);
+    const Edge& edge, int num_cores_per_replica, bool add_metadata);
 
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
 
 // Get sharding inforamtion from node.
 xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
-    const NodeDef& node_def);
+    const NodeDef& node_def, bool add_metadata);
 
 }  // namespace tensorflow
 
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/compiler/tf2xla/sharding_util.h"
 
+#include <functional>
+
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -54,4 +56,86 @@ TEST(CoreUtilTest, ParseShardingFromDevice) {
   EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
 }
 
+class ShardingWithMetadataTest
+    : public ::testing::TestWithParam<xla::OpSharding> {};
+
+TEST_P(ShardingWithMetadataTest, GetShardingFromNode) {
+  NodeDef node_def;
+  {
+    node_def.set_op("_Arg");
+    node_def.set_name("arg");
+    AttrValue xla_sharding;
+    xla_sharding.set_s(GetParam().SerializeAsString());
+    AttrValue index;
+    index.set_i(0);
+    AttrValue type;
+    type.set_type(DataType::DT_FLOAT);
+    node_def.mutable_attr()->insert(
+        {{"_XlaSharding", xla_sharding}, {"index", index}, {"T", type}});
+  }
+
+  auto check_metadata = [](const xla::OpSharding& sharding) {
+    ASSERT_EQ(sharding.metadata_size(), 1);
+    const auto& metadata = sharding.metadata(0);
+    EXPECT_EQ(metadata.op_type(), "_Arg");
+    EXPECT_EQ(metadata.op_name(), "arg");
+  };
+
+  auto test_sharding_metadata =
+      [&check_metadata](
+          const std::function<xla::StatusOr<absl::optional<xla::OpSharding>>()>&
+              fn) {
+        auto status_or_sharding = fn();
+        TF_ASSERT_OK(status_or_sharding.status());
+        ASSERT_TRUE(status_or_sharding.ValueOrDie().has_value());
+        auto& sharding = status_or_sharding.ValueOrDie();
+        ASSERT_TRUE(sharding.has_value());
+        if (sharding->type() == xla::OpSharding::TUPLE) {
+          EXPECT_TRUE(sharding->metadata().empty());
+          for (const auto& sharding_element : sharding->tuple_shardings()) {
+            check_metadata(sharding_element);
+          }
+        } else {
+          check_metadata(sharding.value());
+        }
+      };
+
+  {
+    test_sharding_metadata([&node_def]() {
+      return GetShardingFromNodeDef(node_def, /*add_metadata=*/true);
+    });
+  }
+
+  {
+    test_sharding_metadata([&node_def]() {
+      return ParseShardingFromDevice(node_def, /*num_cores_per_replica=*/1,
+                                     /*add_metadata=*/true);
+    });
+  }
+
+  {
+    Graph graph(OpRegistry::Global());
+    Status status;
+    Node* node = graph.AddNode(node_def, &status);
+    TF_ASSERT_OK(status);
+
+    test_sharding_metadata([node]() {
+      return ParseShardingFromDevice(*node, /*num_cores_per_replica=*/1,
+                                     /*add_metadata=*/true);
+    });
+  }
+}
+
+xla::OpSharding CreateTupleSharding() {
+  xla::OpSharding sharding;
+  sharding.set_type(xla::OpSharding::TUPLE);
+  sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
+  sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
+  return sharding;
+}
+
+INSTANTIATE_TEST_SUITE_P(GetShardingFromNode, ShardingWithMetadataTest,
+                         ::testing::Values(xla::sharding_builder::Replicate(),
+                                           CreateTupleSharding()));
+
 }  // namespace tensorflow
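The parameterized test serializes each parameter into the _XlaSharding attribute and runs the same assertions through all three entry points (GetShardingFromNodeDef and both ParseShardingFromDevice overloads) for two cases: a flat replicated sharding and a two-element tuple sharding. The tuple case checks the fan-out rule directly: the enclosing sharding's metadata stays empty while every tuple element carries exactly one entry naming "_Arg"/"arg".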
@@ -506,7 +506,8 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
           absl::optional<xla::OpSharding> sharding,
           ParseShardingFromDevice(
               *possible_match,
-              /*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
+              /*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
+              /*add_metadata=*/false));
       if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
         const int core_annotation = sharding.value().tile_assignment_devices(0);
         if (core == -1 || core > core_annotation) {
@@ -242,7 +242,8 @@ TEST(SetNodeShardingFromNeighbors, Basic) {
   // Test where one input to c_node has a device.
   a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2");
   TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
-  auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
+  auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica,
+                                              /*add_metadata=*/false);
   TF_ASSERT_OK(parse_status.status());
   ASSERT_TRUE(parse_status.ValueOrDie().has_value());
   EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0));
@@ -250,14 +251,16 @@ TEST(SetNodeShardingFromNeighbors, Basic) {
   // Test where two inputs to c_node have a device.
   b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1");
   TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
-  parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
+  parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica,
+                                         /*add_metadata=*/false);
   TF_ASSERT_OK(parse_status.status());
   ASSERT_TRUE(parse_status.ValueOrDie().has_value());
   EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));
 
   // Test setting based on out edges.
   TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true));
-  parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica);
+  parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica,
+                                         /*add_metadata=*/false);
   TF_ASSERT_OK(parse_status.status());
   ASSERT_TRUE(parse_status.ValueOrDie().has_value());
   EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));
@@ -110,8 +110,9 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
   AttachLocationToMetadata(metadata, op_kernel, xla_context);
   b->SetOpMetadata(metadata);
 
-  auto sharding_parse_result = ParseShardingFromDevice(
-      op_kernel->def(), std::numeric_limits<int>::max());
+  auto sharding_parse_result =
+      ParseShardingFromDevice(op_kernel->def(), std::numeric_limits<int>::max(),
+                              /*add_metadata=*/false);
   OP_REQUIRES_OK(context, sharding_parse_result.status());
   absl::optional<xla::OpSharding> op_sharding =
       sharding_parse_result.ValueOrDie();
@@ -93,7 +93,8 @@ ComputeArgAndRetvalShardings(const Graph& graph) {
       [](const Node* n) -> xla::StatusOr<absl::optional<xla::OpSharding>> {
         TF_ASSIGN_OR_RETURN(
             auto sharding,
-            ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
+            ParseShardingFromDevice(*n, std::numeric_limits<int32>::max(),
+                                    /*add_metadata=*/false));
         return sharding;
       };
   std::map<int, xla::OpSharding> arg_shardings;
@@ -1172,9 +1172,17 @@ bool PlaceOpsOnTPU(Node* node) {
   return true;
 }
 
+xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
+  xla::OpMetadata metadata;
+  metadata.set_op_type(node.type_string());
+  metadata.set_op_name(node.name());
+  return metadata;
+}
+
 // Validate sharding configuration derived from XlaSharding attribute.
 // Infer the core id from the OpSharding, if necessary.
 Status ParseAndValidateSharding(const xla::OpSharding& sharding,
+                                const Node& node,
                                 const int num_cores_per_replica,
                                 int64* inferred_core_id,
                                 absl::optional<xla::OpSharding>* result) {
@@ -1203,7 +1211,9 @@ Status ParseAndValidateSharding(const xla::OpSharding& sharding,
 
       if (result_value_serialized != sharding_serialized) {
         // We see different shardings, assign to core 0.
-        result->emplace(xla::sharding_builder::AssignDevice(0));
+        auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
+        *core_zero_sharding.add_metadata() = CreateOpMetadataFromNode(node);
+        result->emplace(core_zero_sharding);
       }
     }
   }
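CreateOpMetadataFromNode mirrors the sharding_util.cc helper for graph nodes, and ParseAndValidateSharding now receives the node so that even the conflict fallback can be attributed. A sketch of that fallback path (assuming a Node n; illustrative, not part of the diff):

// Sketch: when neighboring shardings disagree, the pass pins core 0 but
// still records which node forced the decision. `n` is a hypothetical Node.
xla::OpSharding fallback = xla::sharding_builder::AssignDevice(0);
*fallback.add_metadata() = CreateOpMetadataFromNode(n);
// fallback.metadata(0).op_type() == n.type_string()
// fallback.metadata(0).op_name() == n.name()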
@@ -1232,7 +1242,8 @@ ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
   // If |node| has `device` attribute or is a XlaSharding op,
   // return the parsed OpSharding.
   TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
-                      ParseShardingFromDevice(node, num_cores_per_replica));
+                      ParseShardingFromDevice(node, num_cores_per_replica,
+                                              /*add_metadata=*/true));
   if (sharding.has_value()) return sharding;
 
   // XlaShardingOp may be followed by an identity or followed by identity
@@ -1244,9 +1255,10 @@ ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
        potential_nodes_with_input_sharding) {
     if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
 
-    TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding_config,
-                        ParseShardingFromDevice(*maybe_node_with_sharding_info,
-                                                num_cores_per_replica));
+    TF_ASSIGN_OR_RETURN(
+        absl::optional<xla::OpSharding> sharding_config,
+        ParseShardingFromDevice(*maybe_node_with_sharding_info,
+                                num_cores_per_replica, /*add_metadata=*/true));
     if (sharding_config.has_value()) return sharding_config;
   }
   return sharding;
@@ -1273,8 +1285,9 @@ Status ParseAndValidateShardingFromNeighbors(
         absl::optional<xla::OpSharding> sharding,
         ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
     if (sharding.has_value()) {
-      TF_RETURN_IF_ERROR(ParseAndValidateSharding(
-          *sharding, num_cores_per_replica, inferred_core_id, result));
+      TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, neighbor_node,
+                                                  num_cores_per_replica,
+                                                  inferred_core_id, result));
       return Status::OK();
     }
 
@@ -1295,8 +1308,9 @@ Status ParseAndValidateShardingFromNeighbors(
           absl::optional<xla::OpSharding> sharding,
           ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
       if (sharding.has_value()) {
-        TF_RETURN_IF_ERROR(ParseAndValidateSharding(
-            *sharding, num_cores_per_replica, inferred_core_id, result));
+        TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, *e->dst(),
+                                                    num_cores_per_replica,
+                                                    inferred_core_id, result));
         return Status::OK();
       }
     }
@@ -1775,7 +1789,8 @@ static Status ValidateCoreNumbers(const Graph& graph,
                                   int num_cores_per_replica) {
   for (Node* n : graph.nodes()) {
     TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
                         ParseShardingFromDevice(*n, num_cores_per_replica,
                                                 /*add_metadata=*/true));
   }
   return Status::OK();
 }
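Note that in ValidateCoreNumbers the parsed sharding is immediately discarded; the TF_ASSIGN_OR_RETURN exists only to surface parse errors, so /*add_metadata=*/true here appears to matter only for keeping call sites uniform, not for any stored result.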
@@ -1930,8 +1945,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
       Node* input_node;
       TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node));
       if (input_node->type_string() == kTPUPartitionedInput) {
-        TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding,
-                            GetShardingFromNodeDef(input_node->def()));
+        TF_ASSIGN_OR_RETURN(
+            absl::optional<xla::OpSharding> parsed_sharding,
+            GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
         if (!parsed_sharding.has_value())
           return errors::InvalidArgument("Missing _XlaSharding attr from: ",
                                          input_node->DebugString());
@@ -1946,8 +1962,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
       Node* input_node;
       TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node));
       if (input_node->type_string() == kVarHandleOp) {
-        TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding,
-                            GetShardingFromNodeDef(input_node->def()));
+        TF_ASSIGN_OR_RETURN(
+            absl::optional<xla::OpSharding> parsed_sharding,
+            GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
         if (parsed_sharding.has_value()) {
           sharding = parsed_sharding;
           VLOG(1) << "Arg " << i << " parsed sharding information from "
@@ -1988,6 +2005,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
         }
         sharding = xla::sharding_builder::AssignDevice(*assigned_core);
       }
+      *sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node);
     } else if (sharding->type() == xla::OpSharding::MAXIMAL) {
       assigned_core = sharding->tile_assignment_devices(0);
     } else if (sharding->type() != xla::OpSharding::REPLICATED &&
@@ -2036,12 +2054,14 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
 
     TF_ASSIGN_OR_RETURN(
         absl::optional<xla::OpSharding> sharding,
-        ParseShardingFromEdgeSource(*edge, num_cores_per_replica));
+        ParseShardingFromEdgeSource(*edge, num_cores_per_replica,
+                                    /*add_metadata=*/true));
 
     if (partitioned_output_nodes.contains(i)) {
       Node* output_node = partitioned_output_nodes[i];
-      TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding,
-                          GetShardingFromNodeDef(output_node->def()));
+      TF_ASSIGN_OR_RETURN(
+          absl::optional<xla::OpSharding> parsed_sharding,
+          GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true));
       if (parsed_sharding.has_value()) {
         sharding = parsed_sharding;
         VLOG(1) << "Retval " << i << " parsed sharding information from "
@@ -2079,6 +2099,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
         }
         sharding = xla::sharding_builder::AssignDevice(*assigned_core);
       }
+      *sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node);
     }
     if (assigned_core.has_value()) {
       retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
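The two one-line hunks above cover args and retvals whose sharding the pass assigns itself (inferred from neighbors or pinned to a core) rather than reads from an _XlaSharding attribute: there the metadata appears to cite the replicate node, the closest identifiable source. A sketch of what such an inferred sharding ends up carrying (values illustrative; the node name is hypothetical):

// Sketch: metadata on an inferred argument sharding.
xla::OpSharding inferred = xla::sharding_builder::AssignDevice(2);
*inferred.add_metadata() = CreateOpMetadataFromNode(*replicate_node);
// For a TPUReplicate node named "cluster/replicate" this records
// op_type "TPUReplicate" and op_name "cluster/replicate".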