Fix tests when tf._major_api_version does not exist

This commit is contained in:
Mihai Maruseac 2020-05-13 18:33:18 -07:00
parent 90d1854a80
commit d0106f72ea

View File

@ -74,20 +74,22 @@ class ModuleTest(test.TestCase):
tf.summary.image tf.summary.image
# If we use v2 API, check for create_file_writer, # If we use v2 API, check for create_file_writer,
# otherwise check for FileWriter. # otherwise check for FileWriter.
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2: if hasattr(tf, '_major_api_version'):
tf.summary.create_file_writer if tf._major_api_version == 2:
else: tf.summary.create_file_writer
tf.summary.FileWriter else:
tf.summary.FileWriter
# pylint: enable=pointless-statement # pylint: enable=pointless-statement
def testInternalKerasImport(self): def testInternalKerasImport(self):
normalization_parent = layers.BatchNormalization.__module__.split('.')[-1] normalization_parent = layers.BatchNormalization.__module__.split('.')[-1]
if tf._major_api_version == 2: if hasattr(tf, '_major_api_version'):
self.assertEqual('normalization_v2', normalization_parent) if tf._major_api_version == 2:
self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR) self.assertEqual('normalization_v2', normalization_parent)
else: self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR)
self.assertEqual('normalization', normalization_parent) else:
self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR) self.assertEqual('normalization', normalization_parent)
self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR)
if __name__ == '__main__': if __name__ == '__main__':