Use StatefulPartitionedCall instead of PartitionedCall in auto-clustering
Resolves https://github.com/tensorflow/tensorflow/issues/46419 -- we're not sure if this is addressing the root cause of the issue, but this does fix *a* bug.

PiperOrigin-RevId: 357772529
Change-Id: I6f534b8505c92ef2a8b5e7903928653fff0450c6
parent 4feda2aa5e
commit 4853e35ae9
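The first hunk below changes the pass (ReplaceFunctionCallWithPartitionedCall); the remaining hunks update its tests (BuildXlaOpsTest). For context, here is a minimal sketch, not part of the commit, contrasting the two ops with the same C++ ops API the pass uses; the function name "cluster_0" and the constant argument are invented for illustration. The relevant difference is that PartitionedCall is registered as a stateless op, so graph optimizers may fold, deduplicate, or prune it, whereas StatefulPartitionedCall is registered as stateful and must be left intact.

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/types.h"

int main() {
  using namespace tensorflow;
  Scope root = Scope::NewRootScope();

  // Hypothetical function assumed to already exist in the graph's
  // function library.
  NameAttrList func;
  func.set_name("cluster_0");

  auto arg = ops::Const(root, 1.0f);

  // Stateless form: optimizers may fold, dedupe, or prune this call.
  ops::PartitionedCall stateless(root.WithOpName("partitioned_call"),
                                 {arg}, {DT_FLOAT}, func);

  // Stateful form (what auto-clustering emits after this change): treated
  // as side-effecting, so it survives graph optimization unchanged.
  ops::StatefulPartitionedCall stateful(
      root.WithOpName("stateful_partitioned_call"), {arg}, {DT_FLOAT}, func);
  return 0;
}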
@@ -309,9 +309,13 @@ xla::StatusOr<Node*> ReplaceFunctionCallWithPartitionedCall(
     }
   }
 
-  ops::PartitionedCall call(
-      root.WithOpName("partitioned_call"), args, n->output_types(), func,
-      ops::PartitionedCall::Attrs{}.ConfigProto(config_string));
+  // In theory we can use PartitionedCall if the XLA cluster does not have any
+  // stateful operations. However, for now we choose to be conservative since
+  // we don't have any evidence that choosing a stateless partitioned call helps
+  // for performance.
+  ops::StatefulPartitionedCall call(
+      root.WithOpName("stateful_partitioned_call"), args, n->output_types(),
+      func, ops::StatefulPartitionedCall::Attrs{}.ConfigProto(config_string));
 
   for (const Edge* e : n->in_edges()) {
     if (e->IsControlEdge()) {
@@ -194,7 +194,7 @@ TEST_F(BuildXlaOpsTest, OnNonXlaDevice) {
   auto xla_run =
       NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key)));
   auto tf_call =
-      NodeWith(Op("PartitionedCall"),
+      NodeWith(Op("StatefulPartitionedCall"),
               CtrlDeps(NodeWith(Op("Identity"),
                                 Inputs(Out(0, predicated_compilation_key)))));
   auto merge = NodeWith(Op("_XlaMerge"), Inputs(Out(tf_call), Out(xla_run)));
@@ -252,9 +252,10 @@ TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) {
   TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
 
   Node* sink_node = graph->sink_node();
-  EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")),
-                                           NodeWith(Op("PartitionedCall")),
-                                           NodeWith(Op("NoOp")))));
+  EXPECT_THAT(sink_node,
+              NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")),
+                                NodeWith(Op("StatefulPartitionedCall")),
+                                NodeWith(Op("NoOp")))));
 }
 
 #ifdef GOOGLE_CUDA
@@ -298,15 +299,15 @@ TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) {
   std::unique_ptr<Graph> graph;
   TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
 
-  Node* partitioned_call_op = nullptr;
+  Node* stateful_partitioned_call_op = nullptr;
   for (Node* n : graph->op_nodes()) {
-    if (n->type_string() == "PartitionedCall") {
-      ASSERT_EQ(partitioned_call_op, nullptr);
-      partitioned_call_op = n;
+    if (n->type_string() == "StatefulPartitionedCall") {
+      ASSERT_EQ(stateful_partitioned_call_op, nullptr);
+      stateful_partitioned_call_op = n;
     }
   }
 
-  ASSERT_NE(partitioned_call_op, nullptr);
+  ASSERT_NE(stateful_partitioned_call_op, nullptr);
   auto xla_compile = NodeWith(Op("_XlaCompile"));
   auto switch_on_compilation_pred =
       NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile)));
@@ -315,7 +316,7 @@ TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) {
   // Check that we pipe int32 inputs through an IdentityN to avoid extra D2H
   // copies.
   EXPECT_THAT(
-      partitioned_call_op,
+      stateful_partitioned_call_op,
      NodeWith(Inputs(Out(NodeWith(Op("IdentityN"), CtrlDeps(ctrl_dep))))));
 }
 #endif
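The new comment in the first hunk leaves the door open to emitting a stateless PartitionedCall when the clustered function contains no stateful ops. A hypothetical sketch of such a check, not part of this commit or of TensorFlow (the helper name is invented; the registry calls are real), might look like:

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_def.pb.h"

// Hypothetical helper: returns true if any node in the cluster function is
// registered as stateful, in which case the conservative
// StatefulPartitionedCall would still be required.
bool ClusterHasStatefulOps(const tensorflow::FunctionDef& fdef,
                           const tensorflow::FunctionLibraryDefinition& flib) {
  for (const tensorflow::NodeDef& node : fdef.node_def()) {
    const tensorflow::OpDef* op_def = nullptr;
    // LookUpOpDef resolves both globally registered ops and library functions.
    if (!flib.LookUpOpDef(node.op(), &op_def).ok()) {
      return true;  // Unknown op: conservatively treat it as stateful.
    }
    if (op_def->is_stateful()) {
      return true;
    }
  }
  return false;
}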