Update DistributedTPURewritePass and sharding util to populate OpMetadata in OpSharding.
Arguments and results to a TPU computation and TF -> HLO lowerings from ops with _XlaSharding attributes will now have OpMetadata containing the source of the sharding attribute (TF op type and name). PiperOrigin-RevId: 351846058 Change-Id: I4cebdf6df466159717eb400dd89fac384bffdecf
This commit is contained in:
parent
8cdf1d3b4d
commit
58e1d18268
@ -26,6 +26,26 @@ const char kShardingAttribute[] = "_XlaSharding";
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
xla::OpMetadata CreateOpMetadata(const std::string& op_type,
|
||||||
|
const std::string& op_name) {
|
||||||
|
xla::OpMetadata metadata;
|
||||||
|
metadata.set_op_type(op_type);
|
||||||
|
metadata.set_op_name(op_name);
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AssignOpMetadataToSharding(xla::OpSharding& sharding,
|
||||||
|
const string& op_type, const string& op_name) {
|
||||||
|
auto metadata = CreateOpMetadata(op_type, op_name);
|
||||||
|
if (sharding.type() == xla::OpSharding::TUPLE) {
|
||||||
|
for (auto& sharding_element : *sharding.mutable_tuple_shardings()) {
|
||||||
|
*sharding_element.add_metadata() = metadata;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
*sharding.add_metadata() = metadata;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
|
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Invalid replicated core id: ", core,
|
"Invalid replicated core id: ", core,
|
||||||
@ -35,7 +55,8 @@ Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
|
|||||||
|
|
||||||
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
||||||
const string& device_name, int num_cores_per_replica,
|
const string& device_name, int num_cores_per_replica,
|
||||||
absl::optional<xla::OpSharding> explicit_sharding) {
|
absl::optional<xla::OpSharding> explicit_sharding,
|
||||||
|
absl::optional<xla::OpMetadata> metadata) {
|
||||||
if (device_name.empty()) {
|
if (device_name.empty()) {
|
||||||
return explicit_sharding;
|
return explicit_sharding;
|
||||||
}
|
}
|
||||||
@ -56,8 +77,11 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
|||||||
if (core < 0 || core >= num_cores_per_replica) {
|
if (core < 0 || core >= num_cores_per_replica) {
|
||||||
return CoreOutOfRangeError(core, num_cores_per_replica);
|
return CoreOutOfRangeError(core, num_cores_per_replica);
|
||||||
}
|
}
|
||||||
return absl::optional<xla::OpSharding>(
|
auto sharding = xla::sharding_builder::AssignDevice(core);
|
||||||
xla::sharding_builder::AssignDevice(core));
|
if (metadata.has_value()) {
|
||||||
|
*sharding.add_metadata() = metadata.value();
|
||||||
|
}
|
||||||
|
return absl::optional<xla::OpSharding>(sharding);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,7 +90,9 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
|||||||
const string& device_name = node_def.device();
|
const string& device_name = node_def.device();
|
||||||
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
|
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
|
||||||
GetShardingFromNodeDef(node_def));
|
GetShardingFromNodeDef(node_def));
|
||||||
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
|
return ParseShardingFromDevice(
|
||||||
|
device_name, num_cores_per_replica, sharding,
|
||||||
|
CreateOpMetadata(node_def.op(), node_def.name()));
|
||||||
}
|
}
|
||||||
|
|
||||||
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
||||||
@ -77,7 +103,9 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
|||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
|
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
|
||||||
GetShardingFromNodeDef(node.def()));
|
GetShardingFromNodeDef(node.def()));
|
||||||
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
|
return ParseShardingFromDevice(
|
||||||
|
device_name, num_cores_per_replica, sharding,
|
||||||
|
CreateOpMetadata(node.type_string(), node.name()));
|
||||||
}
|
}
|
||||||
|
|
||||||
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
|
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
|
||||||
@ -128,6 +156,7 @@ xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
|
|||||||
"Experimental _XlaSharding attribute was not a valid encoded "
|
"Experimental _XlaSharding attribute was not a valid encoded "
|
||||||
"xla::OpSharding proto.");
|
"xla::OpSharding proto.");
|
||||||
}
|
}
|
||||||
|
AssignOpMetadataToSharding(sharding, node_def.op(), node_def.name());
|
||||||
return absl::optional<xla::OpSharding>(sharding);
|
return absl::optional<xla::OpSharding>(sharding);
|
||||||
}
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -35,7 +35,8 @@ namespace tensorflow {
|
|||||||
// - a sharding set as per xla::sharding_builder::AssignDevice.
|
// - a sharding set as per xla::sharding_builder::AssignDevice.
|
||||||
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
||||||
const string& device_name, int num_cores_per_replica,
|
const string& device_name, int num_cores_per_replica,
|
||||||
absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt);
|
absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt,
|
||||||
|
absl::optional<xla::OpMetadata> metadata = absl::nullopt);
|
||||||
|
|
||||||
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
||||||
const Node& node, int num_cores_per_replica);
|
const Node& node, int num_cores_per_replica);
|
||||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
@ -54,4 +56,83 @@ TEST(CoreUtilTest, ParseShardingFromDevice) {
|
|||||||
EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
|
EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ShardingWithMetadataTest
|
||||||
|
: public ::testing::TestWithParam<xla::OpSharding> {};
|
||||||
|
|
||||||
|
TEST_P(ShardingWithMetadataTest, GetShardingFromNode) {
|
||||||
|
NodeDef node_def;
|
||||||
|
{
|
||||||
|
node_def.set_op("_Arg");
|
||||||
|
node_def.set_name("arg");
|
||||||
|
AttrValue xla_sharding;
|
||||||
|
xla_sharding.set_s("");
|
||||||
|
AttrValue index;
|
||||||
|
index.set_i(0);
|
||||||
|
AttrValue type;
|
||||||
|
type.set_type(DataType::DT_FLOAT);
|
||||||
|
node_def.mutable_attr()->insert(
|
||||||
|
{{"_XlaSharding", xla_sharding}, {"index", index}, {"T", type}});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto check_metadata = [](const xla::OpSharding& sharding) {
|
||||||
|
ASSERT_EQ(sharding.metadata_size(), 1);
|
||||||
|
const auto& metadata = sharding.metadata(0);
|
||||||
|
EXPECT_EQ(metadata.op_type(), "_Arg");
|
||||||
|
EXPECT_EQ(metadata.op_name(), "arg");
|
||||||
|
};
|
||||||
|
|
||||||
|
auto test_sharding_metadata =
|
||||||
|
[&check_metadata](
|
||||||
|
const std::function<xla::StatusOr<absl::optional<xla::OpSharding>>()>&
|
||||||
|
fn) {
|
||||||
|
auto status_or_sharding = fn();
|
||||||
|
TF_ASSERT_OK(status_or_sharding.status());
|
||||||
|
ASSERT_TRUE(status_or_sharding.ValueOrDie().has_value());
|
||||||
|
auto& sharding = status_or_sharding.ValueOrDie();
|
||||||
|
ASSERT_TRUE(sharding.has_value());
|
||||||
|
if (sharding->type() == xla::OpSharding::TUPLE) {
|
||||||
|
EXPECT_TRUE(sharding->metadata().empty());
|
||||||
|
for (const auto& sharding_element : sharding->tuple_shardings()) {
|
||||||
|
check_metadata(sharding_element);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
check_metadata(sharding.value());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
test_sharding_metadata(
|
||||||
|
[&node_def]() { return GetShardingFromNodeDef(node_def); });
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
test_sharding_metadata([&node_def]() {
|
||||||
|
return ParseShardingFromDevice(node_def, /*num_cores_per_replica=*/1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Graph graph(OpRegistry::Global());
|
||||||
|
Status status;
|
||||||
|
Node* node = graph.AddNode(node_def, &status);
|
||||||
|
TF_ASSERT_OK(status);
|
||||||
|
|
||||||
|
test_sharding_metadata([node]() {
|
||||||
|
return ParseShardingFromDevice(*node, /*num_cores_per_replica=*/1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::OpSharding CreateTupleSharding() {
|
||||||
|
xla::OpSharding sharding;
|
||||||
|
sharding.set_type(xla::OpSharding::TUPLE);
|
||||||
|
sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
|
||||||
|
sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
|
||||||
|
return sharding;
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(GetShardingFromNode, ShardingWithMetadataTest,
|
||||||
|
::testing::Values(xla::sharding_builder::Replicate(),
|
||||||
|
CreateTupleSharding()));
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -1794,8 +1794,13 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
|
|||||||
ASSERT_TRUE(root_instruction_proto);
|
ASSERT_TRUE(root_instruction_proto);
|
||||||
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(
|
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(
|
||||||
{xla::ShapeUtil::MakeShape(xla::S32, {2})});
|
{xla::ShapeUtil::MakeShape(xla::S32, {2})});
|
||||||
|
xla::OpMetadata metadata;
|
||||||
|
metadata.set_op_type("_Retval");
|
||||||
|
metadata.set_op_name("B");
|
||||||
|
xla::HloSharding sharding_with_metadata =
|
||||||
|
xla::HloSharding::Tile(tile_assignment, {metadata});
|
||||||
xla::HloSharding tuple_sharding = xla::HloSharding::Tuple(
|
xla::HloSharding tuple_sharding = xla::HloSharding::Tuple(
|
||||||
tuple_shape, std::vector<xla::HloSharding>{sharding});
|
tuple_shape, std::vector<xla::HloSharding>{sharding_with_metadata});
|
||||||
EXPECT_EQ(root_instruction_proto->sharding().SerializeAsString(),
|
EXPECT_EQ(root_instruction_proto->sharding().SerializeAsString(),
|
||||||
tuple_sharding.ToProto().SerializeAsString());
|
tuple_sharding.ToProto().SerializeAsString());
|
||||||
}
|
}
|
||||||
@ -1955,5 +1960,62 @@ TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) {
|
|||||||
EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
|
EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(XlaCompilerTest, ShardingMetadata) {
|
||||||
|
// Builds a graph that returns its only argument.
|
||||||
|
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||||
|
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
|
||||||
|
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
|
||||||
|
auto add = ops::AddV2(scope.WithOpName("add"), arg0, arg1);
|
||||||
|
auto ret0 = ops::_Retval(scope.WithOpName("ret0"), add, 0);
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
// Sets _XlaSharding attribute for the _Retval node.
|
||||||
|
auto node_name_index = graph->BuildNodeNameIndex();
|
||||||
|
auto add_node_it = node_name_index.find("add");
|
||||||
|
ASSERT_NE(add_node_it, node_name_index.end());
|
||||||
|
Node* add_node = add_node_it->second;
|
||||||
|
ASSERT_NE(add_node, nullptr);
|
||||||
|
|
||||||
|
xla::HloSharding sharding = xla::HloSharding::AssignDevice(0);
|
||||||
|
add_node->AddAttr("_XlaSharding", sharding.ToProto().SerializeAsString());
|
||||||
|
|
||||||
|
// Builds a description of the arguments.
|
||||||
|
std::vector<XlaCompiler::Argument> args(2);
|
||||||
|
args[0].kind = XlaCompiler::Argument::kParameter;
|
||||||
|
args[0].type = DT_INT32;
|
||||||
|
args[0].shape = TensorShape({});
|
||||||
|
args[1].kind = XlaCompiler::Argument::kParameter;
|
||||||
|
args[1].type = DT_INT32;
|
||||||
|
args[1].shape = TensorShape({});
|
||||||
|
|
||||||
|
// Compiles the graph.
|
||||||
|
XlaCompiler compiler(DefaultOptions());
|
||||||
|
|
||||||
|
XlaCompiler::CompilationResult result;
|
||||||
|
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test",
|
||||||
|
std::move(graph), args, &result));
|
||||||
|
|
||||||
|
// Tests that we set sharding on the root TUPLE instruction.
|
||||||
|
const auto& hlo_module_proto = result.computation->proto();
|
||||||
|
ASSERT_EQ(hlo_module_proto.computations_size(), 1);
|
||||||
|
const auto& hlo_computation_proto = hlo_module_proto.computations(0);
|
||||||
|
absl::optional<xla::HloInstructionProto> add_instruction_proto;
|
||||||
|
for (const auto& inst : hlo_computation_proto.instructions()) {
|
||||||
|
if (inst.opcode() == "add") {
|
||||||
|
add_instruction_proto = inst;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(add_instruction_proto.has_value());
|
||||||
|
xla::OpMetadata metadata;
|
||||||
|
metadata.set_op_type("AddV2");
|
||||||
|
metadata.set_op_name("add");
|
||||||
|
xla::HloSharding sharding_with_metadata =
|
||||||
|
xla::HloSharding::AssignDevice(0, {metadata});
|
||||||
|
EXPECT_EQ(add_instruction_proto->sharding().SerializeAsString(),
|
||||||
|
sharding_with_metadata.ToProto().SerializeAsString());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -1172,9 +1172,17 @@ bool PlaceOpsOnTPU(Node* node) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
|
||||||
|
xla::OpMetadata metadata;
|
||||||
|
metadata.set_op_type(node.type_string());
|
||||||
|
metadata.set_op_name(node.name());
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
// Validate sharding configuration derived from XlaSharding attribute.
|
// Validate sharding configuration derived from XlaSharding attribute.
|
||||||
// Infer the core id from the OpSharding, if necessary.
|
// Infer the core id from the OpSharding, if necessary.
|
||||||
Status ParseAndValidateSharding(const xla::OpSharding& sharding,
|
Status ParseAndValidateSharding(const xla::OpSharding& sharding,
|
||||||
|
const Node& node,
|
||||||
const int num_cores_per_replica,
|
const int num_cores_per_replica,
|
||||||
int64* inferred_core_id,
|
int64* inferred_core_id,
|
||||||
absl::optional<xla::OpSharding>* result) {
|
absl::optional<xla::OpSharding>* result) {
|
||||||
@ -1203,7 +1211,9 @@ Status ParseAndValidateSharding(const xla::OpSharding& sharding,
|
|||||||
|
|
||||||
if (result_value_serialized != sharding_serialized) {
|
if (result_value_serialized != sharding_serialized) {
|
||||||
// We see different shardings, assign to core 0.
|
// We see different shardings, assign to core 0.
|
||||||
result->emplace(xla::sharding_builder::AssignDevice(0));
|
auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
|
||||||
|
*core_zero_sharding.add_metadata() = CreateOpMetadataFromNode(node);
|
||||||
|
result->emplace(core_zero_sharding);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1273,8 +1283,9 @@ Status ParseAndValidateShardingFromNeighbors(
|
|||||||
absl::optional<xla::OpSharding> sharding,
|
absl::optional<xla::OpSharding> sharding,
|
||||||
ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
|
ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
|
||||||
if (sharding.has_value()) {
|
if (sharding.has_value()) {
|
||||||
TF_RETURN_IF_ERROR(ParseAndValidateSharding(
|
TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, neighbor_node,
|
||||||
*sharding, num_cores_per_replica, inferred_core_id, result));
|
num_cores_per_replica,
|
||||||
|
inferred_core_id, result));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1295,8 +1306,9 @@ Status ParseAndValidateShardingFromNeighbors(
|
|||||||
absl::optional<xla::OpSharding> sharding,
|
absl::optional<xla::OpSharding> sharding,
|
||||||
ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
|
ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
|
||||||
if (sharding.has_value()) {
|
if (sharding.has_value()) {
|
||||||
TF_RETURN_IF_ERROR(ParseAndValidateSharding(
|
TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, *e->dst(),
|
||||||
*sharding, num_cores_per_replica, inferred_core_id, result));
|
num_cores_per_replica,
|
||||||
|
inferred_core_id, result));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1988,6 +2000,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
|
|||||||
}
|
}
|
||||||
sharding = xla::sharding_builder::AssignDevice(*assigned_core);
|
sharding = xla::sharding_builder::AssignDevice(*assigned_core);
|
||||||
}
|
}
|
||||||
|
*sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node);
|
||||||
} else if (sharding->type() == xla::OpSharding::MAXIMAL) {
|
} else if (sharding->type() == xla::OpSharding::MAXIMAL) {
|
||||||
assigned_core = sharding->tile_assignment_devices(0);
|
assigned_core = sharding->tile_assignment_devices(0);
|
||||||
} else if (sharding->type() != xla::OpSharding::REPLICATED &&
|
} else if (sharding->type() != xla::OpSharding::REPLICATED &&
|
||||||
@ -2079,6 +2092,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
|
|||||||
}
|
}
|
||||||
sharding = xla::sharding_builder::AssignDevice(*assigned_core);
|
sharding = xla::sharding_builder::AssignDevice(*assigned_core);
|
||||||
}
|
}
|
||||||
|
*sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node);
|
||||||
}
|
}
|
||||||
if (assigned_core.has_value()) {
|
if (assigned_core.has_value()) {
|
||||||
retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
|
retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user