Use StatefulPartitionedCall instead of PartitionedCall in auto-clustering

Resolves https://github.com/tensorflow/tensorflow/issues/46419 -- we're not sure
if this is addressing the root cause of the issue, but this does fix *a* bug.

PiperOrigin-RevId: 357772529
Change-Id: I6f534b8505c92ef2a8b5e7903928653fff0450c6
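For context, the operational difference between the two call ops is the statefulness bit in their op registrations: graph optimizations (pruning, deduplication) treat a node whose op is not marked stateful as safe to remove or merge, so a stateless PartitionedCall wrapping a cluster with side effects could, in principle, be optimized away. A minimal sketch that queries the op registry for that bit (illustrative only, not part of this change; assumes a TensorFlow C++ build):

#include <iostream>

#include "tensorflow/core/framework/op.h"

// Print the registered is_stateful flag for both call ops. Optimization
// passes consult this flag before pruning or deduplicating a node, which is
// why the stateful variant is the conservative choice for XLA clusters.
int main() {
  for (const char* op_name : {"PartitionedCall", "StatefulPartitionedCall"}) {
    const tensorflow::OpRegistrationData* op_reg_data = nullptr;
    if (tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data).ok()) {
      std::cout << op_name << ": is_stateful="
                << op_reg_data->op_def.is_stateful() << "\n";
    }
  }
  return 0;
}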
Sanjoy Das 2021-02-16 11:56:20 -08:00 committed by TensorFlower Gardener
parent 4feda2aa5e
commit 4853e35ae9
2 changed files with 18 additions and 13 deletions

tensorflow/compiler/jit/build_xla_ops_pass.cc

@@ -309,9 +309,13 @@ xla::StatusOr<Node*> ReplaceFunctionCallWithPartitionedCall(
     }
   }
 
-  ops::PartitionedCall call(
-      root.WithOpName("partitioned_call"), args, n->output_types(), func,
-      ops::PartitionedCall::Attrs{}.ConfigProto(config_string));
+  // In theory we can use PartitionedCall if the XLA cluster does not have any
+  // stateful operations.  However, for now we choose to be conservative since
+  // we don't have any evidence that choosing a stateless partitioned call helps
+  // for performance.
+  ops::StatefulPartitionedCall call(
+      root.WithOpName("stateful_partitioned_call"), args, n->output_types(),
+      func, ops::StatefulPartitionedCall::Attrs{}.ConfigProto(config_string));
 
   for (const Edge* e : n->in_edges()) {
     if (e->IsControlEdge()) {

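The new comment in the hunk above leaves the stateless-cluster case open. As a rough illustration of what such a check could look like, here is a hypothetical helper (ClusterHasStatefulOps is an illustrative name, not part of this commit or of TensorFlow) that scans the cluster's FunctionDef and reports whether any node is registered as stateful:

#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op.h"

// Hypothetical helper, not part of this change: returns true if any node in
// the cluster's function body is registered as stateful, i.e. the case where
// substituting a stateless PartitionedCall would be unsafe.
bool ClusterHasStatefulOps(const tensorflow::FunctionDef& fdef) {
  for (const tensorflow::NodeDef& node : fdef.node_def()) {
    const tensorflow::OpRegistrationData* op_reg_data = nullptr;
    if (!tensorflow::OpRegistry::Global()
             ->LookUp(node.op(), &op_reg_data)
             .ok()) {
      return true;  // Unknown op: stay conservative and treat it as stateful.
    }
    if (op_reg_data->op_def.is_stateful()) return true;
  }
  return false;
}

Even with such a check in hand, the comment's point stands: absent evidence that the stateless form helps performance, always emitting StatefulPartitionedCall is the safer default.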
tensorflow/compiler/jit/build_xla_ops_pass_test.cc

@@ -194,7 +194,7 @@ TEST_F(BuildXlaOpsTest, OnNonXlaDevice) {
   auto xla_run =
       NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key)));
   auto tf_call =
-      NodeWith(Op("PartitionedCall"),
+      NodeWith(Op("StatefulPartitionedCall"),
               CtrlDeps(NodeWith(Op("Identity"),
                                 Inputs(Out(0, predicated_compilation_key)))));
   auto merge = NodeWith(Op("_XlaMerge"), Inputs(Out(tf_call), Out(xla_run)));
@@ -252,9 +252,10 @@ TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) {
   TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
 
   Node* sink_node = graph->sink_node();
-  EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")),
-                                           NodeWith(Op("PartitionedCall")),
-                                           NodeWith(Op("NoOp")))));
+  EXPECT_THAT(sink_node,
+              NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")),
+                                NodeWith(Op("StatefulPartitionedCall")),
+                                NodeWith(Op("NoOp")))));
 }
 
 #ifdef GOOGLE_CUDA
@@ -298,15 +299,15 @@ TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) {
   std::unique_ptr<Graph> graph;
   TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
 
-  Node* partitioned_call_op = nullptr;
+  Node* stateful_partitioned_call_op = nullptr;
   for (Node* n : graph->op_nodes()) {
-    if (n->type_string() == "PartitionedCall") {
-      ASSERT_EQ(partitioned_call_op, nullptr);
-      partitioned_call_op = n;
+    if (n->type_string() == "StatefulPartitionedCall") {
+      ASSERT_EQ(stateful_partitioned_call_op, nullptr);
+      stateful_partitioned_call_op = n;
     }
   }
 
-  ASSERT_NE(partitioned_call_op, nullptr);
+  ASSERT_NE(stateful_partitioned_call_op, nullptr);
   auto xla_compile = NodeWith(Op("_XlaCompile"));
   auto switch_on_compilation_pred =
       NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile)));
@@ -315,7 +316,7 @@ TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) {
   // Check that we pipe int32 inputs through an IdentityN to avoid extra D2H
   // copies.
   EXPECT_THAT(
-      partitioned_call_op,
+      stateful_partitioned_call_op,
       NodeWith(Inputs(Out(NodeWith(Op("IdentityN"), CtrlDeps(ctrl_dep))))));
 }
 #endif