[MLIR:TF/XLA] Use TF op registry's statefulness flag in side-effect analysis
for ops that are not yet defined in ODS. PiperOrigin-RevId: 286109214 Change-Id: I198c0bf22c8bd1792d90269c917b6e58c074464d
This commit is contained in:
parent
e1ae3ca451
commit
03341c4342
@ -310,7 +310,25 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) {
|
||||
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||
return while_op.is_stateless();
|
||||
}
|
||||
return false;
|
||||
|
||||
// Try to get the statefulness flag from the registry.
|
||||
//
|
||||
// TODO(yuanzx): Remove this after all ops are defined in the dialect.
|
||||
if (op->getName().getDialect() !=
|
||||
TF::TensorFlowDialect::getDialectNamespace()) {
|
||||
return false;
|
||||
}
|
||||
StringRef op_name = op->getName().getStringRef();
|
||||
// Drop the `tf.` prefix to query TF registry.
|
||||
auto node_name =
|
||||
op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1);
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
if (!tensorflow::OpRegistry::Global()
|
||||
->LookUp(node_name.data(), &op_reg_data)
|
||||
.ok()) {
|
||||
return false;
|
||||
}
|
||||
return !op_reg_data->op_def.is_stateful();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -737,3 +737,43 @@ func @while_cond(
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {5}}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass tracks control dependencies based on TF op registry
|
||||
// statefulness flag, for ops not yet defined in ODS.
|
||||
|
||||
// CHECK-LABEL: func @tf_registry_ops
|
||||
func @tf_registry_ops(
|
||||
// expected-remark@above {{ID: 8}}
|
||||
%arg0: tensor<!tf.string>, %arg1: tensor<!tf.string>) {
|
||||
tf_executor.graph {
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
%island = tf_executor.island {
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Successors: {5}}}
|
||||
"tf.PrintV2"(%arg0) { output_stream = "stderr", end = "\n" }
|
||||
// expected-remark@above {{ID: 0}}
|
||||
// expected-remark@above {{Successors: {2}}}
|
||||
: (tensor<!tf.string>) -> ()
|
||||
%merge_summary = "tf.MergeSummary"(%arg0, %arg1) { N = 2 }
|
||||
// expected-remark@above {{ID: 1}}
|
||||
: (tensor<!tf.string>, tensor<!tf.string>) -> (tensor<!tf.string>)
|
||||
"tf.PrintV2"(%merge_summary) { output_stream = "stderr", end = "\n" }
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {0}}}
|
||||
// expected-remark@above {{Successors: {3}}}
|
||||
: (tensor<!tf.string>) -> ()
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Predecessors: {2}}}
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
}
|
||||
return
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {6}}}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user