diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 09fc57b9696..95c1b4e3e38 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -815,35 +815,28 @@ bool MarkedForXlaCompilation(const Node* n) { } const bool IsExemptFromSideEffectsExecutionValidation(const string& op) { - static const auto* exemption = new absl::flat_hash_set({ - // LINT.IfChange - // Op types that should not run in program order, e.g. because they need - // to run asynchronously to avoid deadlock. - "CollectiveGather", - "CollectiveReduce", - "CollectiveBcastSend", - "CollectiveBcastRecv", - "NcclAllReduce", + static const auto* exemption = new absl::flat_hash_set( + {// LINT.IfChange + // Op types that should not run in program order, e.g. because they need + // to run asynchronously to avoid deadlock. + "CollectiveGather", "CollectiveReduce", "CollectiveBcastSend", + "CollectiveBcastRecv", "NcclAllReduce", - // Legacy random ops. - // See details in tensorflow/python/framework/auto_control_deps.py. - "RandomUniform", - "RandomUniformInt", - "RandomStandardNormal", - "ParameterizedTruncatedNormal", - "TruncatedNormal", - "RandomShuffle", - "Multinomial", - "RandomGamma", - "RandomGammaGrad", - "RandomPoisson", - "RandomPoissonV2", - // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py) + // Legacy random ops. + // See details in tensorflow/python/framework/auto_control_deps.py. + "RandomUniform", "RandomUniformInt", "RandomStandardNormal", + "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle", + "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson", + "RandomPoissonV2", + // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py) - // ReadVariableOp marked as stateful because it consumes DT_RESOURCE, - // but it can't generate any observable side-effect. - "ReadVariableOp", - }); + // ReadVariableOp marked as stateful because it consumes DT_RESOURCE, + // but it can't generate any observable side-effect. + "ReadVariableOp", + + // CudnnRNN ops are stateful but they can't generate any observable + // side-effect. + "CudnnRNNV2", "CudnnRNNV3", "CudnnRNNBackpropV2", "CudnnRNNBackpropV3"}); return exemption->contains(op); } diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py index 1d2757bdacf..f50e8c8e3e8 100644 --- a/tensorflow/python/framework/auto_control_deps.py +++ b/tensorflow/python/framework/auto_control_deps.py @@ -86,9 +86,15 @@ LEGACY_RANDOM_OPS = [ "RandomPoisson", "RandomPoissonV2", ] + +_ORDER_INSENSITIVE_STATEFUL_OPS = [ + "CudnnRNNV2", "CudnnRNNV3", "CudnnRNNBackpropV2", "CudnnRNNBackpropV3" +] # LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc) -_ALL_BLACKLISTED_OPS = set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS) +_ALL_BLACKLISTED_OPS = ( + set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS) + | set(_ORDER_INSENSITIVE_STATEFUL_OPS)) def op_is_stateful(op):