Refactor xla_sharding to be more useful.
PiperOrigin-RevId: 295838039 Change-Id: Ia138c41a9e2739379ecf3e2222686a195b0fe56d
This commit is contained in:
parent
9189ce99fc
commit
1519ef5c6a
@ -26,22 +26,6 @@ const char kShardingAttribute[] = "_XlaSharding";
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
|
||||
const NodeDef& node_def) {
|
||||
if (!HasNodeAttr(node_def, kShardingAttribute)) {
|
||||
return absl::optional<xla::OpSharding>();
|
||||
}
|
||||
string value;
|
||||
xla::OpSharding sharding;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
|
||||
if (!sharding.ParseFromString(value)) {
|
||||
return xla::InvalidArgument(
|
||||
"Experimental _XlaSharding attribute was not a valid encoded "
|
||||
"xla::OpSharding proto.");
|
||||
}
|
||||
return absl::optional<xla::OpSharding>(sharding);
|
||||
}
|
||||
|
||||
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid replicated core id: ", core,
|
||||
@ -107,4 +91,19 @@ void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
|
||||
}
|
||||
}
|
||||
|
||||
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
|
||||
const NodeDef& node_def) {
|
||||
if (!HasNodeAttr(node_def, kShardingAttribute)) {
|
||||
return absl::optional<xla::OpSharding>();
|
||||
}
|
||||
string value;
|
||||
xla::OpSharding sharding;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
|
||||
if (!sharding.ParseFromString(value)) {
|
||||
return xla::InvalidArgument(
|
||||
"Experimental _XlaSharding attribute was not a valid encoded "
|
||||
"xla::OpSharding proto.");
|
||||
}
|
||||
return absl::optional<xla::OpSharding>(sharding);
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -45,6 +45,10 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
|
||||
|
||||
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
|
||||
|
||||
// Get sharding inforamtion from node.
|
||||
xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
|
||||
const NodeDef& node_def);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_
|
||||
|
Loading…
Reference in New Issue
Block a user