[MLIR:TF/XLA] Use TF op registry's statefulness flag in side-effect analysis

for ops that are not yet defined in ODS.

PiperOrigin-RevId: 286109214
Change-Id: I198c0bf22c8bd1792d90269c917b6e58c074464d
This commit is contained in:
Yuanzhong Xu 2019-12-17 19:22:35 -08:00 committed by TensorFlower Gardener
parent e1ae3ca451
commit 03341c4342
2 changed files with 59 additions and 1 deletions

View File

@ -310,7 +310,25 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) {
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
return while_op.is_stateless();
}
return false;
// Try to get the statefulness flag from the registry.
//
// TODO(yuanzx): Remove this after all ops are defined in the dialect.
if (op->getName().getDialect() !=
TF::TensorFlowDialect::getDialectNamespace()) {
return false;
}
StringRef op_name = op->getName().getStringRef();
// Drop the `tf.` prefix to query TF registry.
auto node_name =
op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1);
const tensorflow::OpRegistrationData* op_reg_data;
if (!tensorflow::OpRegistry::Global()
->LookUp(node_name.data(), &op_reg_data)
.ok()) {
return false;
}
return !op_reg_data->op_def.is_stateful();
}
} // namespace

View File

@ -737,3 +737,43 @@ func @while_cond(
// expected-remark@above {{ID: 6}}
// expected-remark@above {{Predecessors: {5}}}
}
// -----
// Tests that the pass tracks control dependencies based on TF op registry
// statefulness flag, for ops not yet defined in ODS.
// CHECK-LABEL: func @tf_registry_ops
func @tf_registry_ops(
// expected-remark@above {{ID: 8}}
%arg0: tensor<!tf.string>, %arg1: tensor<!tf.string>) {
tf_executor.graph {
// expected-remark@above {{ID: 6}}
// expected-remark@above {{Successors: {7}}}
%island = tf_executor.island {
// expected-remark@above {{ID: 4}}
// expected-remark@above {{Successors: {5}}}
"tf.PrintV2"(%arg0) { output_stream = "stderr", end = "\n" }
// expected-remark@above {{ID: 0}}
// expected-remark@above {{Successors: {2}}}
: (tensor<!tf.string>) -> ()
%merge_summary = "tf.MergeSummary"(%arg0, %arg1) { N = 2 }
// expected-remark@above {{ID: 1}}
: (tensor<!tf.string>, tensor<!tf.string>) -> (tensor<!tf.string>)
"tf.PrintV2"(%merge_summary) { output_stream = "stderr", end = "\n" }
// expected-remark@above {{ID: 2}}
// expected-remark@above {{Predecessors: {0}}}
// expected-remark@above {{Successors: {3}}}
: (tensor<!tf.string>) -> ()
tf_executor.yield
// expected-remark@above {{ID: 3}}
// expected-remark@above {{Predecessors: {2}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 5}}
// expected-remark@above {{Predecessors: {4}}}
}
return
// expected-remark@above {{ID: 7}}
// expected-remark@above {{Predecessors: {6}}}
}