Exclude CudnnRNNV2 from automatic control dependencies.
PiperOrigin-RevId: 264736469
This commit is contained in:
parent
f78666b88e
commit
475e736eea
@ -815,35 +815,28 @@ bool MarkedForXlaCompilation(const Node* n) {
|
||||
}
|
||||
|
||||
const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
|
||||
static const auto* exemption = new absl::flat_hash_set<string>({
|
||||
// LINT.IfChange
|
||||
// Op types that should not run in program order, e.g. because they need
|
||||
// to run asynchronously to avoid deadlock.
|
||||
"CollectiveGather",
|
||||
"CollectiveReduce",
|
||||
"CollectiveBcastSend",
|
||||
"CollectiveBcastRecv",
|
||||
"NcclAllReduce",
|
||||
static const auto* exemption = new absl::flat_hash_set<string>(
|
||||
{// LINT.IfChange
|
||||
// Op types that should not run in program order, e.g. because they need
|
||||
// to run asynchronously to avoid deadlock.
|
||||
"CollectiveGather", "CollectiveReduce", "CollectiveBcastSend",
|
||||
"CollectiveBcastRecv", "NcclAllReduce",
|
||||
|
||||
// Legacy random ops.
|
||||
// See details in tensorflow/python/framework/auto_control_deps.py.
|
||||
"RandomUniform",
|
||||
"RandomUniformInt",
|
||||
"RandomStandardNormal",
|
||||
"ParameterizedTruncatedNormal",
|
||||
"TruncatedNormal",
|
||||
"RandomShuffle",
|
||||
"Multinomial",
|
||||
"RandomGamma",
|
||||
"RandomGammaGrad",
|
||||
"RandomPoisson",
|
||||
"RandomPoissonV2",
|
||||
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
|
||||
// Legacy random ops.
|
||||
// See details in tensorflow/python/framework/auto_control_deps.py.
|
||||
"RandomUniform", "RandomUniformInt", "RandomStandardNormal",
|
||||
"ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
|
||||
"Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
|
||||
"RandomPoissonV2",
|
||||
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
|
||||
|
||||
// ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
|
||||
// but it can't generate any observable side-effect.
|
||||
"ReadVariableOp",
|
||||
});
|
||||
// ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
|
||||
// but it can't generate any observable side-effect.
|
||||
"ReadVariableOp",
|
||||
|
||||
// CudnnRNN ops are stateful but they can't generate any observable
|
||||
// side-effect.
|
||||
"CudnnRNNV2", "CudnnRNNV3", "CudnnRNNBackpropV2", "CudnnRNNBackpropV3"});
|
||||
return exemption->contains(op);
|
||||
}
|
||||
|
||||
|
@ -86,9 +86,15 @@ LEGACY_RANDOM_OPS = [
|
||||
"RandomPoisson",
|
||||
"RandomPoissonV2",
|
||||
]
|
||||
|
||||
_ORDER_INSENSITIVE_STATEFUL_OPS = [
|
||||
"CudnnRNNV2", "CudnnRNNV3", "CudnnRNNBackpropV2", "CudnnRNNBackpropV3"
|
||||
]
|
||||
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)
|
||||
|
||||
_ALL_BLACKLISTED_OPS = set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
|
||||
_ALL_BLACKLISTED_OPS = (
|
||||
set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
|
||||
| set(_ORDER_INSENSITIVE_STATEFUL_OPS))
|
||||
|
||||
|
||||
def op_is_stateful(op):
|
||||
|
Loading…
Reference in New Issue
Block a user