Do not cache variable scopes in legacy_tf_layers.base.Layer when Eager is enabled (regardless of the current context).
Caching the variable scope causes the layer to be "poisoned" when used within a tf.function, since if the layer is called for the first time inside a tf.function, then a FuncGraph scope is captured and then re-entered on every subsequent call. This caching was simply a graph-building (Python) performance optimization and can be skipped if Eager is enabled. PiperOrigin-RevId: 351286068 Change-Id: Ia82958b8fca83bac36f6bc3ce4dd08a8e5011ca0
This commit is contained in:
parent
8f9ba797cc
commit
6ae54f7c58
@ -527,12 +527,20 @@ class Layer(base_layer.Layer):
|
||||
# rather than initializing to None we check for an AttributeError.
|
||||
scope_context_manager = self._always_reuse_variable_scope
|
||||
except AttributeError:
|
||||
scope_context_manager = None
|
||||
|
||||
if scope_context_manager is None:
|
||||
# From this point we will always set reuse=True, so create a "final"
|
||||
# variable scope with this setting. We avoid re-creating variable scopes
|
||||
# after this point as an optimization.
|
||||
self._always_reuse_variable_scope = vs.variable_scope(
|
||||
scope_context_manager = vs.variable_scope(
|
||||
self._scope, reuse=True, auxiliary_name_scope=False)
|
||||
scope_context_manager = self._always_reuse_variable_scope
|
||||
|
||||
# Do not cache variable scopes if Eager mode is enabled. If Eager mode
|
||||
# is enabled then we don't want to reuse scopes because the cached scope
|
||||
# might be from a FuncGraph or Eager scope we are no longer in.
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
self._always_reuse_variable_scope = scope_context_manager
|
||||
else:
|
||||
scope_context_manager = vs.variable_scope(
|
||||
self._scope, reuse=self._reuse, auxiliary_name_scope=False)
|
||||
|
@ -112,6 +112,7 @@ do_pylint() {
|
||||
"^tensorflow/python/keras/engine/base_layer.py.*\[E1102.*not-callable "\
|
||||
"^tensorflow/python/keras/layers/preprocessing/.*\[E1102.*not-callable "\
|
||||
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
|
||||
"^tensorflow/python/keras/legacy_tf_layers/base\.py.*\[E0203.*access-member-before-definition "\
|
||||
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
|
||||
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable "\
|
||||
"^tensorflow/python/autograph/.*_py3_test\.py.*\[E0001.*syntax-error "\
|
||||
|
Loading…
Reference in New Issue
Block a user