Set tf2_behavior to 1 to enable V2 for early loading cases
PiperOrigin-RevId: 303110867 Change-Id: I5ca4706a707e346581928b5a7b91af49c8b5d29e
This commit is contained in:
parent
f76195954f
commit
410852dbd2
@ -42,6 +42,7 @@ from tensorflow.python.tools import module_util as _module_util
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
|
||||
|
||||
# Make sure code inside the TensorFlow codebase can use tf2.enabled() at import.
|
||||
_os.environ['TF2_BEHAVIOR'] = '1'
|
||||
from tensorflow.python import tf2 as _tf2
|
||||
_tf2.enable()
|
||||
|
||||
|
||||
@ -79,6 +79,18 @@ class ModuleTest(test.TestCase):
|
||||
tf.compat.v1.summary.FileWriter
|
||||
# pylint: enable=pointless-statement
|
||||
|
||||
def testInternalKerasImport(self):
|
||||
# pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras import layers
|
||||
normalization_parent = layers.Normalization.__module__.split('.')[-1]
|
||||
if tf._major_api_version == 2:
|
||||
self.assertEqual('normalization', normalization_parent)
|
||||
self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR)
|
||||
else:
|
||||
self.assertEqual('normalization_v1', normalization_parent)
|
||||
self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR)
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user