Do not cache variable scopes in legacy_tf_layers.base.Layer when Eager is enabled (regardless of the current context).

Caching the variable scope causes the layer to be "poisoned" when used within a tf.function, since if the layer is called for the first time inside a tf.function, then a FuncGraph scope is captured and then re-entered on every subsequent call. This caching was simply a graph-building (Python) performance optimization and can be skipped if Eager is enabled.

PiperOrigin-RevId: 351286068
Change-Id: Ia82958b8fca83bac36f6bc3ce4dd08a8e5011ca0
This commit is contained in:
RJ Skerry-Ryan 2021-01-11 19:36:37 -08:00 committed by TensorFlower Gardener
parent 8f9ba797cc
commit 6ae54f7c58
2 changed files with 11 additions and 2 deletions

View File

@ -527,12 +527,20 @@ class Layer(base_layer.Layer):
# rather than initializing to None we check for an AttributeError.
scope_context_manager = self._always_reuse_variable_scope
except AttributeError:
scope_context_manager = None
if scope_context_manager is None:
# From this point we will always set reuse=True, so create a "final"
# variable scope with this setting. We avoid re-creating variable scopes
# after this point as an optimization.
self._always_reuse_variable_scope = vs.variable_scope(
scope_context_manager = vs.variable_scope(
self._scope, reuse=True, auxiliary_name_scope=False)
scope_context_manager = self._always_reuse_variable_scope
# Do not cache variable scopes if Eager mode is enabled. If Eager mode
# is enabled then we don't want to reuse scopes because the cached scope
# might be from a FuncGraph or Eager scope we are no longer in.
if not ops.executing_eagerly_outside_functions():
self._always_reuse_variable_scope = scope_context_manager
else:
scope_context_manager = vs.variable_scope(
self._scope, reuse=self._reuse, auxiliary_name_scope=False)

View File

@ -112,6 +112,7 @@ do_pylint() {
"^tensorflow/python/keras/engine/base_layer.py.*\[E1102.*not-callable "\
"^tensorflow/python/keras/layers/preprocessing/.*\[E1102.*not-callable "\
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/keras/legacy_tf_layers/base\.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable "\
"^tensorflow/python/autograph/.*_py3_test\.py.*\[E0001.*syntax-error "\