Adds ParseShardingFromEdgeSource().
Makes DistributedTPURewritePass::AssignArgsAndRetvalsToCores() support TUPLE
sharding for return values.

PiperOrigin-RevId: 324697874
Change-Id: I3039da1731c9622ebeb0bf9c3b45185e220267af
commit fc9c057b1a
parent a0f7e214ae
Changed paths:
  tensorflow/compiler/tf2xla
  tensorflow/core/tpu
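For context: a multi-output source node can carry a TUPLE xla::OpSharding holding one sub-sharding per output, and the new ParseShardingFromEdgeSource() helper returns the entry matching the edge's source output. The following standalone sketch (hypothetical, not part of the commit) illustrates that selection step using only the xla::OpSharding proto:

// Hypothetical illustration of per-output (TUPLE) sharding selection;
// not part of this commit. Only the xla::OpSharding proto is assumed.
#include <iostream>

#include "tensorflow/compiler/xla/xla_data.pb.h"

int main() {
  // TUPLE sharding for a node with two outputs: output 0 replicated,
  // output 1 pinned to a single device (MAXIMAL, device 3).
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::TUPLE);
  sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
  xla::OpSharding* pinned = sharding.add_tuple_shardings();
  pinned->set_type(xla::OpSharding::MAXIMAL);
  pinned->add_tile_assignment_devices(3);

  // An edge leaving output 1 sees the second sub-sharding; this is the
  // lookup the new helper performs via the edge's src_output().
  const int src_output = 1;  // hypothetical edge source output
  if (src_output >= 0 && src_output < sharding.tuple_shardings_size()) {
    std::cout << sharding.tuple_shardings(src_output).DebugString() << "\n";
  }
  return 0;
}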
@@ -80,6 +80,30 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
   return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
 }
 
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
+    const Edge& edge, int num_cores_per_replica) {
+  if (edge.src() == nullptr) {
+    return tensorflow::errors::InvalidArgument(
+        "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
+  }
+  TF_ASSIGN_OR_RETURN(
+      absl::optional<xla::OpSharding> sharding,
+      ParseShardingFromDevice(*edge.src(), num_cores_per_replica));
+  if (sharding.has_value() &&
+      sharding.value().type() == xla::OpSharding::TUPLE) {
+    if (edge.src_output() < 0 ||
+        edge.src_output() >= sharding.value().tuple_shardings_size()) {
+      return tensorflow::errors::InvalidArgument(
+          "Tuple index out of bound: edge=", edge.DebugString(),
+          " sharding=", sharding->DebugString());
+    }
+    absl::optional<xla::OpSharding> subsharding =
+        sharding.value().tuple_shardings(edge.src_output());
+    return subsharding;
+  }
+  return sharding;
+}
+
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
   string device_name = src.assigned_device_name();
   if (device_name.empty()) {
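A sketch of how a caller can consume the helper's xla::StatusOr<absl::optional<xla::OpSharding>> result (hypothetical wrapper, shown only to clarify the contract; the actual call-site change is in the rewrite-pass hunk further down):

// Hypothetical caller, not part of the commit: resolve the sharding seen
// along an edge, falling back to REPLICATED when the source sets none.
Status ResolveEdgeSharding(const Edge& edge, int num_cores_per_replica,
                           xla::OpSharding* result) {
  TF_ASSIGN_OR_RETURN(
      absl::optional<xla::OpSharding> sharding,
      ParseShardingFromEdgeSource(edge, num_cores_per_replica));
  if (sharding.has_value()) {
    // For a TUPLE-sharded source this is already the per-output entry.
    *result = *sharding;
  } else {
    result->set_type(xla::OpSharding::REPLICATED);  // assumed default policy
  }
  return Status::OK();
}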
@@ -43,6 +43,9 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
     const NodeDef& node_def, int num_cores_per_replica);
 
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
+    const Edge& edge, int num_cores_per_replica);
+
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
 
 // Get sharding inforamtion from node.
@@ -1813,7 +1813,8 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
     } else if (sharding->type() != xla::OpSharding::REPLICATED &&
                sharding->type() != xla::OpSharding::OTHER) {
       return tensorflow::errors::InvalidArgument(
-          "Unsupported argument sharding: ", sharding->DebugString());
+          "Unsupported argument sharding (for arg ", n->DebugString(),
+          "): ", sharding->DebugString());
     }
     if (assigned_core.has_value()) {
       args_device_selector.ReportDeviceAssigned(*assigned_core, i);
@@ -1855,7 +1856,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
 
     TF_ASSIGN_OR_RETURN(
         absl::optional<xla::OpSharding> sharding,
-        ParseShardingFromDevice(*edge->src(), num_cores_per_replica));
+        ParseShardingFromEdgeSource(*edge, num_cores_per_replica));
 
     if (partitioned_output_nodes.contains(i)) {
       Node* output_node = partitioned_output_nodes[i];
@@ -1883,7 +1884,9 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
       } else if (sharding.value().type() != xla::OpSharding::REPLICATED &&
                  sharding.value().type() != xla::OpSharding::OTHER) {
         return tensorflow::errors::InvalidArgument(
-            "Unsupported argument sharding: ", sharding->DebugString());
+            "Unsupported argument sharding for retval ",
+            retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
+            sharding->DebugString());
       }
     } else {
       if (use_spmd) {
@@ -2472,7 +2475,8 @@ xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
 
 Status DistributedTPURewritePass::BuildExecuteNodes(
     const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
-    const Node& replicate_node, const DataTypeVector& arg_types,
+    const Node& replicate_node, const std::vector<std::string>& arg_names,
+    const DataTypeVector& arg_types,
     const std::vector<InferredShape>& arg_shapes,
     const DataTypeVector& retval_types,
     const std::vector<xla::OpSharding>& arg_shardings,
@@ -2595,7 +2599,9 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
       }
     } else {
       return tensorflow::errors::InvalidArgument(
-          "Unsupported argument sharding: ", sharding.DebugString());
+          "Unsupported argument sharding for arg=", arg_names[i],
+          " shape=", arg_shapes[i].shape.DebugString(), ": ",
+          sharding.DebugString());
     }
   }
   std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
@@ -3922,8 +3928,8 @@ Status DistributedTPURewritePass::FingerprintFunctionLibrary(
 
   std::vector<VariableWrite> variable_writes;
   TF_RETURN_IF_ERROR(BuildExecuteNodes(
-      params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_types,
-      arg_shapes, retval_types, arg_sharding, retval_sharding,
+      params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
+      arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
       tf_device_assignment, compile_node, variable_reads,
       control_after_compilation, control_after, &variable_writes, graph));
   bool contains_resource_write_op =
@@ -413,9 +413,10 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
   // * `num_cores_per_replica` is the number of cores which are dedicated to
   //   each replica.
   // * `replicate_node` is the original TPUReplicate node.
-  // * `arg_types` are the types of the arguments to the computation function
+  // * `arg_names` are the names of the arguments to the computation function
   //   passed as argument to TPUReplicate, including per-replica,
   //   broadcast, and variable arguments.
+  // * `arg_types` are the corresponding types of the arguments.
   // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
   //   applicable).
   // * `arg_shardings` and `retval_shardings` are mappings from
@@ -431,6 +432,7 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
   static Status BuildExecuteNodes(
       const ParameterInfo& params_info, int num_tasks,
       int num_cores_per_replica, const Node& replicate_node,
+      const std::vector<std::string>& arg_names,
       const DataTypeVector& arg_types,
       const std::vector<InferredShape>& arg_shapes,
       const DataTypeVector& retval_types,
@@ -71,7 +71,6 @@ cc_library(
         "//tensorflow/core/tpu:tpu_api",
         "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
-        "//tensorflow/stream_executor/tpu:status_helper",
         "//tensorflow/stream_executor/tpu:tpu_platform_interface",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
@@ -117,7 +117,7 @@ Status SetPerCoreArgShapes(
   } else {
     TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
         << "Unsupported argument sharding: "
-        << proto_arg.sharding().DebugString();
+        << " proto_arg=" << proto_arg.DebugString();
     for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
       (*arg_core_mapping)[arg_index].indices.push_back(
           (*per_core_arg_shapes)[core].size());