From 1519ef5c6a92b0c397b3c95e3646f1d8e0b6a678 Mon Sep 17 00:00:00 2001 From: Youlong Cheng Date: Tue, 18 Feb 2020 15:34:25 -0800 Subject: [PATCH] Refactor xla_sharding to be more useful. PiperOrigin-RevId: 295838039 Change-Id: Ia138c41a9e2739379ecf3e2222686a195b0fe56d --- tensorflow/compiler/tf2xla/sharding_util.cc | 31 ++++++++++----------- tensorflow/compiler/tf2xla/sharding_util.h | 4 +++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 4d5bf0835e1..366e8d49228 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -26,22 +26,6 @@ const char kShardingAttribute[] = "_XlaSharding"; } // namespace namespace { -xla::StatusOr> GetShardingFromNodeDef( - const NodeDef& node_def) { - if (!HasNodeAttr(node_def, kShardingAttribute)) { - return absl::optional(); - } - string value; - xla::OpSharding sharding; - TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value)); - if (!sharding.ParseFromString(value)) { - return xla::InvalidArgument( - "Experimental _XlaSharding attribute was not a valid encoded " - "xla::OpSharding proto."); - } - return absl::optional(sharding); -} - Status CoreOutOfRangeError(int core, int num_cores_per_replica) { return errors::InvalidArgument( "Invalid replicated core id: ", core, @@ -107,4 +91,19 @@ void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { } } +xla::StatusOr> GetShardingFromNodeDef( + const NodeDef& node_def) { + if (!HasNodeAttr(node_def, kShardingAttribute)) { + return absl::optional(); + } + string value; + xla::OpSharding sharding; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value)); + if (!sharding.ParseFromString(value)) { + return xla::InvalidArgument( + "Experimental _XlaSharding attribute was not a valid encoded " + "xla::OpSharding proto."); + } + return absl::optional(sharding); +} } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index ab67d4f1542..196434826f9 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -45,6 +45,10 @@ xla::StatusOr> ParseShardingFromDevice( void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); +// Get sharding inforamtion from node. +xla::StatusOr> GetShardingFromNodeDef( + const NodeDef& node_def); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_