Refactor xla_sharding to be more useful.

PiperOrigin-RevId: 295838039
Change-Id: Ia138c41a9e2739379ecf3e2222686a195b0fe56d
This commit is contained in:
Youlong Cheng 2020-02-18 15:34:25 -08:00 committed by TensorFlower Gardener
parent 9189ce99fc
commit 1519ef5c6a
2 changed files with 19 additions and 16 deletions

View File

@ -26,22 +26,6 @@ const char kShardingAttribute[] = "_XlaSharding";
} // namespace
namespace {
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
const NodeDef& node_def) {
if (!HasNodeAttr(node_def, kShardingAttribute)) {
return absl::optional<xla::OpSharding>();
}
string value;
xla::OpSharding sharding;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
if (!sharding.ParseFromString(value)) {
return xla::InvalidArgument(
"Experimental _XlaSharding attribute was not a valid encoded "
"xla::OpSharding proto.");
}
return absl::optional<xla::OpSharding>(sharding);
}
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
return errors::InvalidArgument(
"Invalid replicated core id: ", core,
@ -107,4 +91,19 @@ void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
}
}
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
const NodeDef& node_def) {
if (!HasNodeAttr(node_def, kShardingAttribute)) {
return absl::optional<xla::OpSharding>();
}
string value;
xla::OpSharding sharding;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
if (!sharding.ParseFromString(value)) {
return xla::InvalidArgument(
"Experimental _XlaSharding attribute was not a valid encoded "
"xla::OpSharding proto.");
}
return absl::optional<xla::OpSharding>(sharding);
}
} // namespace tensorflow

View File

@ -45,6 +45,10 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
// Get sharding inforamtion from node.
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
const NodeDef& node_def);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_