[MLIR:TF/XLA] Use TF op registry's statefulness flag in side-effect analysis
for ops that are not yet defined in ODS. PiperOrigin-RevId: 286109214 Change-Id: I198c0bf22c8bd1792d90269c917b6e58c074464d
This commit is contained in:
parent
e1ae3ca451
commit
03341c4342
@ -310,7 +310,25 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) {
|
|||||||
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||||
return while_op.is_stateless();
|
return while_op.is_stateless();
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
|
// Try to get the statefulness flag from the registry.
|
||||||
|
//
|
||||||
|
// TODO(yuanzx): Remove this after all ops are defined in the dialect.
|
||||||
|
if (op->getName().getDialect() !=
|
||||||
|
TF::TensorFlowDialect::getDialectNamespace()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
StringRef op_name = op->getName().getStringRef();
|
||||||
|
// Drop the `tf.` prefix to query TF registry.
|
||||||
|
auto node_name =
|
||||||
|
op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1);
|
||||||
|
const tensorflow::OpRegistrationData* op_reg_data;
|
||||||
|
if (!tensorflow::OpRegistry::Global()
|
||||||
|
->LookUp(node_name.data(), &op_reg_data)
|
||||||
|
.ok()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return !op_reg_data->op_def.is_stateful();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -737,3 +737,43 @@ func @while_cond(
|
|||||||
// expected-remark@above {{ID: 6}}
|
// expected-remark@above {{ID: 6}}
|
||||||
// expected-remark@above {{Predecessors: {5}}}
|
// expected-remark@above {{Predecessors: {5}}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Tests that the pass tracks control dependencies based on TF op registry
|
||||||
|
// statefulness flag, for ops not yet defined in ODS.
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tf_registry_ops
|
||||||
|
func @tf_registry_ops(
|
||||||
|
// expected-remark@above {{ID: 8}}
|
||||||
|
%arg0: tensor<!tf.string>, %arg1: tensor<!tf.string>) {
|
||||||
|
tf_executor.graph {
|
||||||
|
// expected-remark@above {{ID: 6}}
|
||||||
|
// expected-remark@above {{Successors: {7}}}
|
||||||
|
%island = tf_executor.island {
|
||||||
|
// expected-remark@above {{ID: 4}}
|
||||||
|
// expected-remark@above {{Successors: {5}}}
|
||||||
|
"tf.PrintV2"(%arg0) { output_stream = "stderr", end = "\n" }
|
||||||
|
// expected-remark@above {{ID: 0}}
|
||||||
|
// expected-remark@above {{Successors: {2}}}
|
||||||
|
: (tensor<!tf.string>) -> ()
|
||||||
|
%merge_summary = "tf.MergeSummary"(%arg0, %arg1) { N = 2 }
|
||||||
|
// expected-remark@above {{ID: 1}}
|
||||||
|
: (tensor<!tf.string>, tensor<!tf.string>) -> (tensor<!tf.string>)
|
||||||
|
"tf.PrintV2"(%merge_summary) { output_stream = "stderr", end = "\n" }
|
||||||
|
// expected-remark@above {{ID: 2}}
|
||||||
|
// expected-remark@above {{Predecessors: {0}}}
|
||||||
|
// expected-remark@above {{Successors: {3}}}
|
||||||
|
: (tensor<!tf.string>) -> ()
|
||||||
|
tf_executor.yield
|
||||||
|
// expected-remark@above {{ID: 3}}
|
||||||
|
// expected-remark@above {{Predecessors: {2}}}
|
||||||
|
}
|
||||||
|
tf_executor.fetch %island : !tf_executor.control
|
||||||
|
// expected-remark@above {{ID: 5}}
|
||||||
|
// expected-remark@above {{Predecessors: {4}}}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
// expected-remark@above {{ID: 7}}
|
||||||
|
// expected-remark@above {{Predecessors: {6}}}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user