diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 36a2560b7c8..47d070c1572 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -310,7 +310,25 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) { if (auto while_op = llvm::dyn_cast(op)) { return while_op.is_stateless(); } - return false; + + // Try to get the statefulness flag from the registry. + // + // TODO(yuanzx): Remove this after all ops are defined in the dialect. + if (op->getName().getDialect() != + TF::TensorFlowDialect::getDialectNamespace()) { + return false; + } + StringRef op_name = op->getName().getStringRef(); + // Drop the `tf.` prefix to query TF registry. + auto node_name = + op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1); + const tensorflow::OpRegistrationData* op_reg_data; + if (!tensorflow::OpRegistry::Global() + ->LookUp(node_name.data(), &op_reg_data) + .ok()) { + return false; + } + return !op_reg_data->op_def.is_stateful(); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index 9b17956f399..5ff3212db65 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -737,3 +737,43 @@ func @while_cond( // expected-remark@above {{ID: 6}} // expected-remark@above {{Predecessors: {5}}} } + +// ----- + +// Tests that the pass tracks control dependencies based on TF op registry +// statefulness flag, for ops not yet defined in ODS. + +// CHECK-LABEL: func @tf_registry_ops +func @tf_registry_ops( + // expected-remark@above {{ID: 8}} + %arg0: tensor, %arg1: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Successors: {7}}} + %island = tf_executor.island { + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Successors: {5}}} + "tf.PrintV2"(%arg0) { output_stream = "stderr", end = "\n" } + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {2}}} + : (tensor) -> () + %merge_summary = "tf.MergeSummary"(%arg0, %arg1) { N = 2 } + // expected-remark@above {{ID: 1}} + : (tensor, tensor) -> (tensor) + "tf.PrintV2"(%merge_summary) { output_stream = "stderr", end = "\n" } + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {3}}} + : (tensor) -> () + tf_executor.yield + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Predecessors: {2}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {4}}} + } + return + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {6}}} +}