Exclude CudnnRNNV2 from automatic control dependencies.
PiperOrigin-RevId: 264736469
This commit is contained in:
parent
f78666b88e
commit
475e736eea
@ -815,35 +815,28 @@ bool MarkedForXlaCompilation(const Node* n) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
|
const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
|
||||||
static const auto* exemption = new absl::flat_hash_set<string>({
|
static const auto* exemption = new absl::flat_hash_set<string>(
|
||||||
// LINT.IfChange
|
{// LINT.IfChange
|
||||||
// Op types that should not run in program order, e.g. because they need
|
// Op types that should not run in program order, e.g. because they need
|
||||||
// to run asynchronously to avoid deadlock.
|
// to run asynchronously to avoid deadlock.
|
||||||
"CollectiveGather",
|
"CollectiveGather", "CollectiveReduce", "CollectiveBcastSend",
|
||||||
"CollectiveReduce",
|
"CollectiveBcastRecv", "NcclAllReduce",
|
||||||
"CollectiveBcastSend",
|
|
||||||
"CollectiveBcastRecv",
|
|
||||||
"NcclAllReduce",
|
|
||||||
|
|
||||||
// Legacy random ops.
|
// Legacy random ops.
|
||||||
// See details in tensorflow/python/framework/auto_control_deps.py.
|
// See details in tensorflow/python/framework/auto_control_deps.py.
|
||||||
"RandomUniform",
|
"RandomUniform", "RandomUniformInt", "RandomStandardNormal",
|
||||||
"RandomUniformInt",
|
"ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
|
||||||
"RandomStandardNormal",
|
"Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
|
||||||
"ParameterizedTruncatedNormal",
|
|
||||||
"TruncatedNormal",
|
|
||||||
"RandomShuffle",
|
|
||||||
"Multinomial",
|
|
||||||
"RandomGamma",
|
|
||||||
"RandomGammaGrad",
|
|
||||||
"RandomPoisson",
|
|
||||||
"RandomPoissonV2",
|
"RandomPoissonV2",
|
||||||
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
|
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
|
||||||
|
|
||||||
// ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
|
// ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
|
||||||
// but it can't generate any observable side-effect.
|
// but it can't generate any observable side-effect.
|
||||||
"ReadVariableOp",
|
"ReadVariableOp",
|
||||||
});
|
|
||||||
|
// CudnnRNN ops are stateful but they can't generate any observable
|
||||||
|
// side-effect.
|
||||||
|
"CudnnRNNV2", "CudnnRNNV3", "CudnnRNNBackpropV2", "CudnnRNNBackpropV3"});
|
||||||
return exemption->contains(op);
|
return exemption->contains(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,9 +86,15 @@ LEGACY_RANDOM_OPS = [
|
|||||||
"RandomPoisson",
|
"RandomPoisson",
|
||||||
"RandomPoissonV2",
|
"RandomPoissonV2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_ORDER_INSENSITIVE_STATEFUL_OPS = [
|
||||||
|
"CudnnRNNV2", "CudnnRNNV3", "CudnnRNNBackpropV2", "CudnnRNNBackpropV3"
|
||||||
|
]
|
||||||
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)
|
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)
|
||||||
|
|
||||||
_ALL_BLACKLISTED_OPS = set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
|
_ALL_BLACKLISTED_OPS = (
|
||||||
|
set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
|
||||||
|
| set(_ORDER_INSENSITIVE_STATEFUL_OPS))
|
||||||
|
|
||||||
|
|
||||||
def op_is_stateful(op):
|
def op_is_stateful(op):
|
||||||
|
Loading…
Reference in New Issue
Block a user