Adds ParseShardingFromEdgeSource().

Makes DistributedTPURewritePass::AssignArgsAndRetvalsToCores() support TUPLE sharding for return values.

PiperOrigin-RevId: 324697874
Change-Id: I3039da1731c9622ebeb0bf9c3b45185e220267af
A. Unique TensorFlower 2020-08-03 15:36:31 -07:00 committed by TensorFlower Gardener
parent a0f7e214ae
commit fc9c057b1a
6 changed files with 44 additions and 10 deletions
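
For context: a TUPLE xla::OpSharding carries one sub-sharding per output of the source node in its tuple_shardings field, and the new ParseShardingFromEdgeSource() returns the entry selected by the edge's src_output(). A minimal sketch of such a sharding, assuming a hypothetical two-output node with output 0 replicated and output 1 placed on core 1:

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Sketch only: a hypothetical TUPLE sharding for a node with two outputs.
// Output 0 is replicated across cores; output 1 is pinned to core 1 (MAXIMAL).
xla::OpSharding MakeExampleTupleSharding() {
  xla::OpSharding tuple_sharding;
  tuple_sharding.set_type(xla::OpSharding::TUPLE);
  tuple_sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
  xla::OpSharding* output1 = tuple_sharding.add_tuple_shardings();
  output1->set_type(xla::OpSharding::MAXIMAL);
  output1->add_tile_assignment_devices(1);
  return tuple_sharding;
}

// For an edge whose source node carries this sharding,
// ParseShardingFromEdgeSource(edge, num_cores_per_replica) yields
// tuple_shardings(edge.src_output()): the REPLICATED entry for output 0,
// the core-1 MAXIMAL entry for output 1.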

@@ -80,6 +80,30 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
   return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
 }
 
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
+    const Edge& edge, int num_cores_per_replica) {
+  if (edge.src() == nullptr) {
+    return tensorflow::errors::InvalidArgument(
+        "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
+  }
+  TF_ASSIGN_OR_RETURN(
+      absl::optional<xla::OpSharding> sharding,
+      ParseShardingFromDevice(*edge.src(), num_cores_per_replica));
+  if (sharding.has_value() &&
+      sharding.value().type() == xla::OpSharding::TUPLE) {
+    if (edge.src_output() < 0 ||
+        edge.src_output() >= sharding.value().tuple_shardings_size()) {
+      return tensorflow::errors::InvalidArgument(
+          "Tuple index out of bound: edge=", edge.DebugString(),
+          " sharding=", sharding->DebugString());
+    }
+    absl::optional<xla::OpSharding> subsharding =
+        sharding.value().tuple_shardings(edge.src_output());
+    return subsharding;
+  }
+  return sharding;
+}
+
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
   string device_name = src.assigned_device_name();
   if (device_name.empty()) {

@@ -43,6 +43,9 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
     const NodeDef& node_def, int num_cores_per_replica);
 
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
+    const Edge& edge, int num_cores_per_replica);
+
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
 
 // Get sharding information from node.

@@ -1813,7 +1813,8 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
     } else if (sharding->type() != xla::OpSharding::REPLICATED &&
                sharding->type() != xla::OpSharding::OTHER) {
       return tensorflow::errors::InvalidArgument(
-          "Unsupported argument sharding: ", sharding->DebugString());
+          "Unsupported argument sharding (for arg ", n->DebugString(),
+          "): ", sharding->DebugString());
     }
     if (assigned_core.has_value()) {
       args_device_selector.ReportDeviceAssigned(*assigned_core, i);
@@ -1855,7 +1856,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
       TF_ASSIGN_OR_RETURN(
           absl::optional<xla::OpSharding> sharding,
-          ParseShardingFromDevice(*edge->src(), num_cores_per_replica));
+          ParseShardingFromEdgeSource(*edge, num_cores_per_replica));
       if (partitioned_output_nodes.contains(i)) {
         Node* output_node = partitioned_output_nodes[i];
@@ -1883,7 +1884,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
       } else if (sharding.value().type() != xla::OpSharding::REPLICATED &&
                  sharding.value().type() != xla::OpSharding::OTHER) {
         return tensorflow::errors::InvalidArgument(
-            "Unsupported argument sharding: ", sharding->DebugString());
+            "Unsupported argument sharding for retval ",
+            retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
+            sharding->DebugString());
       }
     } else {
       if (use_spmd) {
@@ -2472,7 +2475,8 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
 
 Status DistributedTPURewritePass::BuildExecuteNodes(
     const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
-    const Node& replicate_node, const DataTypeVector& arg_types,
+    const Node& replicate_node, const std::vector<std::string>& arg_names,
+    const DataTypeVector& arg_types,
     const std::vector<InferredShape>& arg_shapes,
     const DataTypeVector& retval_types,
     const std::vector<xla::OpSharding>& arg_shardings,
@@ -2595,7 +2599,9 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
       }
     } else {
       return tensorflow::errors::InvalidArgument(
-          "Unsupported argument sharding: ", sharding.DebugString());
+          "Unsupported argument sharding for arg=", arg_names[i],
+          " shape=", arg_shapes[i].shape.DebugString(), ": ",
+          sharding.DebugString());
     }
   }
   std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
@@ -3922,8 +3928,8 @@ Status DistributedTPURewritePass::FingerprintFunctionLibrary(
   std::vector<VariableWrite> variable_writes;
   TF_RETURN_IF_ERROR(BuildExecuteNodes(
-      params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_types,
-      arg_shapes, retval_types, arg_sharding, retval_sharding,
+      params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
+      arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
       tf_device_assignment, compile_node, variable_reads,
       control_after_compilation, control_after, &variable_writes, graph));
   bool contains_resource_write_op =

@@ -413,9 +413,10 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
   // * `num_cores_per_replica` is the number of cores which are dedicated to
   //   each replica.
   // * `replicate_node` is the original TPUReplicate node.
-  // * `arg_types` are the types of the arguments to the computation function
+  // * `arg_names` are the names of the arguments to the computation function
   //   passed as argument to TPUReplicate, including per-replica,
   //   broadcast, and variable arguments.
+  // * `arg_types` are the corresponding types of the arguments.
   // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
   //   applicable).
   // * `arg_shardings` and `retval_shardings` are mappings from
@@ -431,6 +432,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
   static Status BuildExecuteNodes(
       const ParameterInfo& params_info, int num_tasks,
       int num_cores_per_replica, const Node& replicate_node,
+      const std::vector<std::string>& arg_names,
       const DataTypeVector& arg_types,
       const std::vector<InferredShape>& arg_shapes,
       const DataTypeVector& retval_types,

@@ -71,7 +71,6 @@ cc_library(
         "//tensorflow/core/tpu:tpu_api",
         "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
-        "//tensorflow/stream_executor/tpu:status_helper",
        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",

@@ -117,7 +117,7 @@ Status SetPerCoreArgShapes(
   } else {
     TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
         << "Unsupported argument sharding: "
-        << proto_arg.sharding().DebugString();
+        << " proto_arg=" << proto_arg.DebugString();
     for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
       (*arg_core_mapping)[arg_index].indices.push_back(
           (*per_core_arg_shapes)[core].size());