Remove automatic control dep wrapping from layers in v2.

PiperOrigin-RevId: 316929712
Change-Id: Ic1a7d125776eeb0c7654e321dd6f2351c8656a16
This commit is contained in:
Pavithra Vijay 2020-06-17 11:32:39 -07:00 committed by TensorFlower Gardener
parent 9eed8b3cf5
commit 653131dd38

View File

@ -40,7 +40,6 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import execute from tensorflow.python.eager import execute
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.eager import monitoring from tensorflow.python.eager import monitoring
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
@ -1105,17 +1104,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
try: try:
with ops.enable_auto_cast_variables(self._compute_dtype_object): with ops.enable_auto_cast_variables(self._compute_dtype_object):
# Add auto_control_deps in V2 when they are not already added by outputs = call_fn(cast_inputs, *args, **kwargs)
# a `tf.function`.
if (ops.executing_eagerly_outside_functions() and
not base_layer_utils.is_in_eager_or_tf_function()):
with auto_control_deps.AutomaticControlDependencies() as acd:
outputs = call_fn(cast_inputs, *args, **kwargs)
# Wrap Tensors in `outputs` in `tf.identity` to avoid
# circular dependencies.
outputs = base_layer_utils.mark_as_return(outputs, acd)
else:
outputs = call_fn(cast_inputs, *args, **kwargs)
except errors.OperatorNotAllowedInGraphError as e: except errors.OperatorNotAllowedInGraphError as e:
raise TypeError('You are attempting to use Python control ' raise TypeError('You are attempting to use Python control '