Update DistributedTPURewritePass and sharding util to populate OpMetadata in OpSharding.

Arguments and results to a TPU computation and TF -> HLO lowerings from ops with _XlaSharding attributes will now have OpMetadata containing the source of the sharding attribute (TF op type and name).

PiperOrigin-RevId: 351870054
Change-Id: I4e985e5b9d8851bcb187031d007e7ee39094184c
This commit is contained in:
A. Unique TensorFlower 2021-01-14 13:51:24 -08:00 committed by TensorFlower Gardener
parent 8b34c68f98
commit 2f171252d8
5 changed files with 12 additions and 199 deletions

View File

@@ -26,26 +26,6 @@ const char kShardingAttribute[] = "_XlaSharding";
} // namespace
namespace {
xla::OpMetadata CreateOpMetadata(const std::string& op_type,
const std::string& op_name) {
xla::OpMetadata metadata;
metadata.set_op_type(op_type);
metadata.set_op_name(op_name);
return metadata;
}
void AssignOpMetadataToSharding(xla::OpSharding& sharding,
const string& op_type, const string& op_name) {
auto metadata = CreateOpMetadata(op_type, op_name);
if (sharding.type() == xla::OpSharding::TUPLE) {
for (auto& sharding_element : *sharding.mutable_tuple_shardings()) {
*sharding_element.add_metadata() = metadata;
}
} else {
*sharding.add_metadata() = metadata;
}
}
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
return errors::InvalidArgument(
"Invalid replicated core id: ", core,
@@ -55,8 +35,7 @@ Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const string& device_name, int num_cores_per_replica,
absl::optional<xla::OpSharding> explicit_sharding,
absl::optional<xla::OpMetadata> metadata) {
absl::optional<xla::OpSharding> explicit_sharding) {
if (device_name.empty()) {
return explicit_sharding;
}
@@ -77,11 +56,8 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
if (core < 0 || core >= num_cores_per_replica) {
return CoreOutOfRangeError(core, num_cores_per_replica);
}
auto sharding = xla::sharding_builder::AssignDevice(core);
if (metadata.has_value()) {
*sharding.add_metadata() = metadata.value();
}
return absl::optional<xla::OpSharding>(sharding);
return absl::optional<xla::OpSharding>(
xla::sharding_builder::AssignDevice(core));
}
}
@@ -90,9 +66,7 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const string& device_name = node_def.device();
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
GetShardingFromNodeDef(node_def));
return ParseShardingFromDevice(
device_name, num_cores_per_replica, sharding,
CreateOpMetadata(node_def.op(), node_def.name()));
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
}
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
@@ -103,9 +77,7 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
}
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
GetShardingFromNodeDef(node.def()));
return ParseShardingFromDevice(
device_name, num_cores_per_replica, sharding,
CreateOpMetadata(node.type_string(), node.name()));
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
}
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
@@ -156,7 +128,6 @@ xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
"Experimental _XlaSharding attribute was not a valid encoded "
"xla::OpSharding proto.");
}
AssignOpMetadataToSharding(sharding, node_def.op(), node_def.name());
return absl::optional<xla::OpSharding>(sharding);
}
} // namespace tensorflow

View File

@@ -35,8 +35,7 @@ namespace tensorflow {
// - a sharding set as per xla::sharding_builder::AssignDevice.
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const string& device_name, int num_cores_per_replica,
absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt,
absl::optional<xla::OpMetadata> metadata = absl::nullopt);
absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt);
xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const Node& node, int num_cores_per_replica);

View File

@@ -14,8 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include <functional>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -56,83 +54,4 @@ TEST(CoreUtilTest, ParseShardingFromDevice) {
EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie()));
}
class ShardingWithMetadataTest
: public ::testing::TestWithParam<xla::OpSharding> {};
TEST_P(ShardingWithMetadataTest, GetShardingFromNode) {
NodeDef node_def;
{
node_def.set_op("_Arg");
node_def.set_name("arg");
AttrValue xla_sharding;
xla_sharding.set_s("");
AttrValue index;
index.set_i(0);
AttrValue type;
type.set_type(DataType::DT_FLOAT);
node_def.mutable_attr()->insert(
{{"_XlaSharding", xla_sharding}, {"index", index}, {"T", type}});
}
auto check_metadata = [](const xla::OpSharding& sharding) {
ASSERT_EQ(sharding.metadata_size(), 1);
const auto& metadata = sharding.metadata(0);
EXPECT_EQ(metadata.op_type(), "_Arg");
EXPECT_EQ(metadata.op_name(), "arg");
};
auto test_sharding_metadata =
[&check_metadata](
const std::function<xla::StatusOr<absl::optional<xla::OpSharding>>()>&
fn) {
auto status_or_sharding = fn();
TF_ASSERT_OK(status_or_sharding.status());
ASSERT_TRUE(status_or_sharding.ValueOrDie().has_value());
auto& sharding = status_or_sharding.ValueOrDie();
ASSERT_TRUE(sharding.has_value());
if (sharding->type() == xla::OpSharding::TUPLE) {
EXPECT_TRUE(sharding->metadata().empty());
for (const auto& sharding_element : sharding->tuple_shardings()) {
check_metadata(sharding_element);
}
} else {
check_metadata(sharding.value());
}
};
{
test_sharding_metadata(
[&node_def]() { return GetShardingFromNodeDef(node_def); });
}
{
test_sharding_metadata([&node_def]() {
return ParseShardingFromDevice(node_def, /*num_cores_per_replica=*/1);
});
}
{
Graph graph(OpRegistry::Global());
Status status;
Node* node = graph.AddNode(node_def, &status);
TF_ASSERT_OK(status);
test_sharding_metadata([node]() {
return ParseShardingFromDevice(*node, /*num_cores_per_replica=*/1);
});
}
}
xla::OpSharding CreateTupleSharding() {
xla::OpSharding sharding;
sharding.set_type(xla::OpSharding::TUPLE);
sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
return sharding;
}
INSTANTIATE_TEST_SUITE_P(GetShardingFromNode, ShardingWithMetadataTest,
::testing::Values(xla::sharding_builder::Replicate(),
CreateTupleSharding()));
} // namespace tensorflow

View File

@@ -1794,13 +1794,8 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
ASSERT_TRUE(root_instruction_proto);
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {2})});
xla::OpMetadata metadata;
metadata.set_op_type("_Retval");
metadata.set_op_name("B");
xla::HloSharding sharding_with_metadata =
xla::HloSharding::Tile(tile_assignment, {metadata});
xla::HloSharding tuple_sharding = xla::HloSharding::Tuple(
tuple_shape, std::vector<xla::HloSharding>{sharding_with_metadata});
tuple_shape, std::vector<xla::HloSharding>{sharding});
EXPECT_EQ(root_instruction_proto->sharding().SerializeAsString(),
tuple_sharding.ToProto().SerializeAsString());
}
@@ -1960,62 +1955,5 @@ TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) {
EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
}
TEST_F(XlaCompilerTest, ShardingMetadata) {
// Builds a graph that returns its only argument.
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
auto add = ops::AddV2(scope.WithOpName("add"), arg0, arg1);
auto ret0 = ops::_Retval(scope.WithOpName("ret0"), add, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Sets _XlaSharding attribute for the _Retval node.
auto node_name_index = graph->BuildNodeNameIndex();
auto add_node_it = node_name_index.find("add");
ASSERT_NE(add_node_it, node_name_index.end());
Node* add_node = add_node_it->second;
ASSERT_NE(add_node, nullptr);
xla::HloSharding sharding = xla::HloSharding::AssignDevice(0);
add_node->AddAttr("_XlaSharding", sharding.ToProto().SerializeAsString());
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({});
args[1].kind = XlaCompiler::Argument::kParameter;
args[1].type = DT_INT32;
args[1].shape = TensorShape({});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test",
std::move(graph), args, &result));
// Tests that we set sharding on the root TUPLE instruction.
const auto& hlo_module_proto = result.computation->proto();
ASSERT_EQ(hlo_module_proto.computations_size(), 1);
const auto& hlo_computation_proto = hlo_module_proto.computations(0);
absl::optional<xla::HloInstructionProto> add_instruction_proto;
for (const auto& inst : hlo_computation_proto.instructions()) {
if (inst.opcode() == "add") {
add_instruction_proto = inst;
break;
}
}
ASSERT_TRUE(add_instruction_proto.has_value());
xla::OpMetadata metadata;
metadata.set_op_type("AddV2");
metadata.set_op_name("add");
xla::HloSharding sharding_with_metadata =
xla::HloSharding::AssignDevice(0, {metadata});
EXPECT_EQ(add_instruction_proto->sharding().SerializeAsString(),
sharding_with_metadata.ToProto().SerializeAsString());
}
} // namespace
} // namespace tensorflow

View File

@@ -1172,17 +1172,9 @@ bool PlaceOpsOnTPU(Node* node) {
return true;
}
xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
xla::OpMetadata metadata;
metadata.set_op_type(node.type_string());
metadata.set_op_name(node.name());
return metadata;
}
// Validate sharding configuration derived from XlaSharding attribute.
// Infer the core id from the OpSharding, if necessary.
Status ParseAndValidateSharding(const xla::OpSharding& sharding,
const Node& node,
const int num_cores_per_replica,
int64* inferred_core_id,
absl::optional<xla::OpSharding>* result) {
@@ -1211,9 +1203,7 @@ Status ParseAndValidateSharding(const xla::OpSharding& sharding,
if (result_value_serialized != sharding_serialized) {
// We see different shardings, assign to core 0.
auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
*core_zero_sharding.add_metadata() = CreateOpMetadataFromNode(node);
result->emplace(core_zero_sharding);
result->emplace(xla::sharding_builder::AssignDevice(0));
}
}
}
@@ -1283,9 +1273,8 @@ Status ParseAndValidateShardingFromNeighbors(
absl::optional<xla::OpSharding> sharding,
ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
if (sharding.has_value()) {
TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, neighbor_node,
num_cores_per_replica,
inferred_core_id, result));
TF_RETURN_IF_ERROR(ParseAndValidateSharding(
*sharding, num_cores_per_replica, inferred_core_id, result));
return Status::OK();
}
@@ -1306,9 +1295,8 @@ Status ParseAndValidateShardingFromNeighbors(
absl::optional<xla::OpSharding> sharding,
ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
if (sharding.has_value()) {
TF_RETURN_IF_ERROR(ParseAndValidateSharding(*sharding, *e->dst(),
num_cores_per_replica,
inferred_core_id, result));
TF_RETURN_IF_ERROR(ParseAndValidateSharding(
*sharding, num_cores_per_replica, inferred_core_id, result));
return Status::OK();
}
}
@@ -2000,7 +1988,6 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
}
sharding = xla::sharding_builder::AssignDevice(*assigned_core);
}
*sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node);
} else if (sharding->type() == xla::OpSharding::MAXIMAL) {
assigned_core = sharding->tile_assignment_devices(0);
} else if (sharding->type() != xla::OpSharding::REPLICATED &&
@@ -2092,7 +2079,6 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
}
sharding = xla::sharding_builder::AssignDevice(*assigned_core);
}
*sharding->add_metadata() = CreateOpMetadataFromNode(*replicate_node);
}
if (assigned_core.has_value()) {
retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));