Fix tests when tf._major_api_version
does not exist
This commit is contained in:
parent
90d1854a80
commit
d0106f72ea
@ -74,20 +74,22 @@ class ModuleTest(test.TestCase):
|
|||||||
tf.summary.image
|
tf.summary.image
|
||||||
# If we use v2 API, check for create_file_writer,
|
# If we use v2 API, check for create_file_writer,
|
||||||
# otherwise check for FileWriter.
|
# otherwise check for FileWriter.
|
||||||
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
|
if hasattr(tf, '_major_api_version'):
|
||||||
tf.summary.create_file_writer
|
if tf._major_api_version == 2:
|
||||||
else:
|
tf.summary.create_file_writer
|
||||||
tf.summary.FileWriter
|
else:
|
||||||
|
tf.summary.FileWriter
|
||||||
# pylint: enable=pointless-statement
|
# pylint: enable=pointless-statement
|
||||||
|
|
||||||
def testInternalKerasImport(self):
|
def testInternalKerasImport(self):
|
||||||
normalization_parent = layers.BatchNormalization.__module__.split('.')[-1]
|
normalization_parent = layers.BatchNormalization.__module__.split('.')[-1]
|
||||||
if tf._major_api_version == 2:
|
if hasattr(tf, '_major_api_version'):
|
||||||
self.assertEqual('normalization_v2', normalization_parent)
|
if tf._major_api_version == 2:
|
||||||
self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR)
|
self.assertEqual('normalization_v2', normalization_parent)
|
||||||
else:
|
self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR)
|
||||||
self.assertEqual('normalization', normalization_parent)
|
else:
|
||||||
self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR)
|
self.assertEqual('normalization', normalization_parent)
|
||||||
|
self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
x
Reference in New Issue
Block a user