Update Keras code to use public TF file API.
1. Update to use v2 public API from their v1 version. 2. Use the Gfile API to read/write file content for save_model code. PiperOrigin-RevId: 328159642 Change-Id: I38373fb16449ab8d19f15f4e22ad99d3c598266b
This commit is contained in:
parent
a4219770e9
commit
ac4e209f8c
@ -192,7 +192,7 @@ def DenseNet(
|
|||||||
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
||||||
using a pretrained top layer.
|
using a pretrained top layer.
|
||||||
"""
|
"""
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -269,7 +269,7 @@ def EfficientNet(
|
|||||||
if blocks_args == 'default':
|
if blocks_args == 'default':
|
||||||
blocks_args = DEFAULT_BLOCKS_ARGS
|
blocks_args = DEFAULT_BLOCKS_ARGS
|
||||||
|
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -112,7 +112,7 @@ def InceptionResNetV2(include_top=True,
|
|||||||
layers = VersionAwareLayers()
|
layers = VersionAwareLayers()
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError('Unknown argument(s): %s' % (kwargs,))
|
raise ValueError('Unknown argument(s): %s' % (kwargs,))
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -108,7 +108,7 @@ def InceptionV3(
|
|||||||
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
||||||
using a pretrained top layer.
|
using a pretrained top layer.
|
||||||
"""
|
"""
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -164,7 +164,7 @@ def MobileNet(input_shape=None,
|
|||||||
layers = VersionAwareLayers()
|
layers = VersionAwareLayers()
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError('Unknown argument(s): %s' % (kwargs,))
|
raise ValueError('Unknown argument(s): %s' % (kwargs,))
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -180,7 +180,7 @@ def MobileNetV2(input_shape=None,
|
|||||||
layers = VersionAwareLayers()
|
layers = VersionAwareLayers()
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError('Unknown argument(s): %s' % (kwargs,))
|
raise ValueError('Unknown argument(s): %s' % (kwargs,))
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -158,7 +158,7 @@ def MobileNetV3(stack_fn,
|
|||||||
pooling=None,
|
pooling=None,
|
||||||
dropout_rate=0.2,
|
dropout_rate=0.2,
|
||||||
classifier_activation='softmax'):
|
classifier_activation='softmax'):
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -150,7 +150,7 @@ def NASNet(
|
|||||||
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
||||||
using a pretrained top layer.
|
using a pretrained top layer.
|
||||||
"""
|
"""
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -137,7 +137,7 @@ def ResNet(stack_fn,
|
|||||||
layers = VersionAwareLayers()
|
layers = VersionAwareLayers()
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError('Unknown argument(s): %s' % (kwargs,))
|
raise ValueError('Unknown argument(s): %s' % (kwargs,))
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -113,7 +113,7 @@ def VGG16(
|
|||||||
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
||||||
using a pretrained top layer.
|
using a pretrained top layer.
|
||||||
"""
|
"""
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -113,7 +113,7 @@ def VGG19(
|
|||||||
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
||||||
using a pretrained top layer.
|
using a pretrained top layer.
|
||||||
"""
|
"""
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -113,7 +113,7 @@ def Xception(
|
|||||||
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
||||||
using a pretrained top layer.
|
using a pretrained top layer.
|
||||||
"""
|
"""
|
||||||
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
|
if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
|
||||||
raise ValueError('The `weights` argument should be either '
|
raise ValueError('The `weights` argument should be either '
|
||||||
'`None` (random initialization), `imagenet` '
|
'`None` (random initialization), `imagenet` '
|
||||||
'(pre-training on ImageNet), '
|
'(pre-training on ImageNet), '
|
||||||
|
@ -1399,9 +1399,10 @@ class ModelCheckpoint(Callback):
|
|||||||
def _checkpoint_exists(self, filepath):
|
def _checkpoint_exists(self, filepath):
|
||||||
"""Returns whether the checkpoint `filepath` refers to exists."""
|
"""Returns whether the checkpoint `filepath` refers to exists."""
|
||||||
if filepath.endswith('.h5'):
|
if filepath.endswith('.h5'):
|
||||||
return file_io.file_exists(filepath)
|
return file_io.file_exists_v2(filepath)
|
||||||
tf_saved_model_exists = file_io.file_exists(filepath)
|
tf_saved_model_exists = file_io.file_exists_v2(filepath)
|
||||||
tf_weights_only_checkpoint_exists = file_io.file_exists(filepath + '.index')
|
tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
|
||||||
|
filepath + '.index')
|
||||||
return tf_saved_model_exists or tf_weights_only_checkpoint_exists
|
return tf_saved_model_exists or tf_weights_only_checkpoint_exists
|
||||||
|
|
||||||
def _get_most_recently_modified_file_matching_pattern(self, pattern):
|
def _get_most_recently_modified_file_matching_pattern(self, pattern):
|
||||||
@ -1466,7 +1467,7 @@ class ModelCheckpoint(Callback):
|
|||||||
n_file_with_latest_mod_time = 0
|
n_file_with_latest_mod_time = 0
|
||||||
file_path_with_largest_file_name = None
|
file_path_with_largest_file_name = None
|
||||||
|
|
||||||
if file_io.file_exists(dir_name):
|
if file_io.file_exists_v2(dir_name):
|
||||||
for file_name in os.listdir(dir_name):
|
for file_name in os.listdir(dir_name):
|
||||||
# Only consider if `file_name` matches the pattern.
|
# Only consider if `file_name` matches the pattern.
|
||||||
if re.match(base_name_regex, file_name):
|
if re.match(base_name_regex, file_name):
|
||||||
@ -2505,7 +2506,7 @@ class CSVLogger(Callback):
|
|||||||
|
|
||||||
def on_train_begin(self, logs=None):
|
def on_train_begin(self, logs=None):
|
||||||
if self.append:
|
if self.append:
|
||||||
if file_io.file_exists(self.filename):
|
if file_io.file_exists_v2(self.filename):
|
||||||
with open(self.filename, 'r' + self.file_flags) as f:
|
with open(self.filename, 'r' + self.file_flags) as f:
|
||||||
self.append_header = not bool(len(f.readline()))
|
self.append_header = not bool(len(f.readline()))
|
||||||
mode = 'a'
|
mode = 'a'
|
||||||
|
@ -37,9 +37,10 @@ from tensorflow.python.platform import test
|
|||||||
def checkpoint_exists(filepath):
|
def checkpoint_exists(filepath):
|
||||||
"""Returns whether the checkpoint `filepath` refers to exists."""
|
"""Returns whether the checkpoint `filepath` refers to exists."""
|
||||||
if filepath.endswith('.h5'):
|
if filepath.endswith('.h5'):
|
||||||
return file_io.file_exists(filepath)
|
return file_io.file_exists_v2(filepath)
|
||||||
tf_saved_model_exists = file_io.file_exists(filepath)
|
tf_saved_model_exists = file_io.file_exists_v2(filepath)
|
||||||
tf_weights_only_checkpoint_exists = file_io.file_exists(filepath + '.index')
|
tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
|
||||||
|
filepath + '.index')
|
||||||
return tf_saved_model_exists or tf_weights_only_checkpoint_exists
|
return tf_saved_model_exists or tf_weights_only_checkpoint_exists
|
||||||
|
|
||||||
|
|
||||||
@ -145,7 +146,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
num_epoch = 2
|
num_epoch = 2
|
||||||
|
|
||||||
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
||||||
test_obj.assertFalse(file_io.file_exists(saving_filepath))
|
test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
|
||||||
|
|
||||||
model.fit(
|
model.fit(
|
||||||
x=train_ds,
|
x=train_ds,
|
||||||
@ -153,7 +154,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
steps_per_epoch=steps,
|
steps_per_epoch=steps,
|
||||||
callbacks=[callbacks.ModelCheckpoint(filepath=saving_filepath)])
|
callbacks=[callbacks.ModelCheckpoint(filepath=saving_filepath)])
|
||||||
|
|
||||||
test_obj.assertTrue(file_io.file_exists(saving_filepath))
|
test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
|
||||||
|
|
||||||
saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint')
|
saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint')
|
||||||
|
|
||||||
@ -185,7 +186,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
num_epoch = 4
|
num_epoch = 4
|
||||||
|
|
||||||
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
||||||
test_obj.assertFalse(file_io.file_exists(saving_filepath))
|
test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
|
||||||
bar_dir = os.path.join(os.path.dirname(saving_filepath), 'backup')
|
bar_dir = os.path.join(os.path.dirname(saving_filepath), 'backup')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -204,8 +205,8 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
|
|
||||||
multi_process_runner.barrier().wait()
|
multi_process_runner.barrier().wait()
|
||||||
backup_filepath = os.path.join(bar_dir, 'checkpoint')
|
backup_filepath = os.path.join(bar_dir, 'checkpoint')
|
||||||
test_obj.assertTrue(file_io.file_exists(backup_filepath))
|
test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
|
||||||
test_obj.assertTrue(file_io.file_exists(saving_filepath))
|
test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
|
||||||
|
|
||||||
model.fit(
|
model.fit(
|
||||||
x=train_ds,
|
x=train_ds,
|
||||||
@ -217,8 +218,8 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
AssertCallback()
|
AssertCallback()
|
||||||
])
|
])
|
||||||
multi_process_runner.barrier().wait()
|
multi_process_runner.barrier().wait()
|
||||||
test_obj.assertFalse(file_io.file_exists(backup_filepath))
|
test_obj.assertFalse(file_io.file_exists_v2(backup_filepath))
|
||||||
test_obj.assertTrue(file_io.file_exists(saving_filepath))
|
test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
|
||||||
|
|
||||||
saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint')
|
saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint')
|
||||||
|
|
||||||
@ -244,7 +245,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
'logfile_%s_%d' % (task_config['type'], task_config['index']))
|
'logfile_%s_%d' % (task_config['type'], task_config['index']))
|
||||||
|
|
||||||
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
||||||
test_obj.assertFalse(file_io.file_exists(saving_filepath))
|
test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
|
||||||
|
|
||||||
model.fit(
|
model.fit(
|
||||||
x=train_ds,
|
x=train_ds,
|
||||||
@ -257,7 +258,8 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
# `file_io.list_directory()` since the directory may be created at this
|
# `file_io.list_directory()` since the directory may be created at this
|
||||||
# point.
|
# point.
|
||||||
test_obj.assertEqual(
|
test_obj.assertEqual(
|
||||||
bool(file_io.list_directory(saving_filepath)), test_base.is_chief())
|
bool(file_io.list_directory_v2(saving_filepath)),
|
||||||
|
test_base.is_chief())
|
||||||
|
|
||||||
multi_process_runner.run(
|
multi_process_runner.run(
|
||||||
proc_tensorboard_saves_on_chief_but_not_otherwise,
|
proc_tensorboard_saves_on_chief_but_not_otherwise,
|
||||||
@ -280,7 +282,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
|
|
||||||
# Verifies that even if `saving_filepath_for_temp` exists, tensorboard
|
# Verifies that even if `saving_filepath_for_temp` exists, tensorboard
|
||||||
# can still save to temporary directory.
|
# can still save to temporary directory.
|
||||||
test_obj.assertTrue(file_io.file_exists(saving_filepath_for_temp))
|
test_obj.assertTrue(file_io.file_exists_v2(saving_filepath_for_temp))
|
||||||
|
|
||||||
model.fit(
|
model.fit(
|
||||||
x=train_ds,
|
x=train_ds,
|
||||||
@ -301,7 +303,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
num_epoch = 2
|
num_epoch = 2
|
||||||
|
|
||||||
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
||||||
test_obj.assertFalse(file_io.file_exists(saving_filepath))
|
test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
|
||||||
|
|
||||||
multi_process_runner.barrier().wait()
|
multi_process_runner.barrier().wait()
|
||||||
|
|
||||||
@ -313,7 +315,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
|||||||
|
|
||||||
multi_process_runner.barrier().wait()
|
multi_process_runner.barrier().wait()
|
||||||
|
|
||||||
test_obj.assertTrue(file_io.list_directory(saving_filepath))
|
test_obj.assertTrue(file_io.list_directory_v2(saving_filepath))
|
||||||
|
|
||||||
saving_filepath = os.path.join(self.get_temp_dir(), 'logfile')
|
saving_filepath = os.path.join(self.get_temp_dir(), 'logfile')
|
||||||
|
|
||||||
|
@ -159,9 +159,10 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
|
|||||||
# Make sure chief finishes saving before non-chief's assertions.
|
# Make sure chief finishes saving before non-chief's assertions.
|
||||||
multi_process_runner.barrier().wait()
|
multi_process_runner.barrier().wait()
|
||||||
|
|
||||||
if not file_io.file_exists(model_path):
|
if not file_io.file_exists_v2(model_path):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
if file_io.file_exists(write_model_path) != _is_chief(task_type, task_id):
|
if file_io.file_exists_v2(write_model_path) != _is_chief(
|
||||||
|
task_type, task_id):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
loaded_model = keras.saving.save.load_model(model_path)
|
loaded_model = keras.saving.save.load_model(model_path)
|
||||||
@ -179,9 +180,9 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
|
|||||||
# Make sure chief finishes saving before non-chief's assertions.
|
# Make sure chief finishes saving before non-chief's assertions.
|
||||||
multi_process_runner.barrier().wait()
|
multi_process_runner.barrier().wait()
|
||||||
|
|
||||||
if not file_io.file_exists(checkpoint_dir):
|
if not file_io.file_exists_v2(checkpoint_dir):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
if file_io.file_exists(write_checkpoint_dir) != _is_chief(
|
if file_io.file_exists_v2(write_checkpoint_dir) != _is_chief(
|
||||||
task_type, task_id):
|
task_type, task_id):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
|
@ -112,12 +112,12 @@ class WorkerTrainingState(object):
|
|||||||
successfully finishes.
|
successfully finishes.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
for pathname in file_io.get_matching_files(
|
for pathname in file_io.get_matching_files_v2(
|
||||||
self.write_checkpoint_manager._prefix + '*'):
|
self.write_checkpoint_manager._prefix + '*'):
|
||||||
file_io.delete_recursively(pathname)
|
file_io.delete_recursively_v2(pathname)
|
||||||
for pathname in file_io.get_matching_files(
|
for pathname in file_io.get_matching_files_v2(
|
||||||
os.path.join(self.write_checkpoint_manager.directory, 'checkpoint')):
|
os.path.join(self.write_checkpoint_manager.directory, 'checkpoint')):
|
||||||
file_io.delete_recursively(pathname)
|
file_io.delete_recursively_v2(pathname)
|
||||||
|
|
||||||
def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
|
def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
|
||||||
"""Maybe load initial epoch from ckpt considering possible worker recovery.
|
"""Maybe load initial epoch from ckpt considering possible worker recovery.
|
||||||
|
@ -48,7 +48,7 @@ class ModelCheckpointTest(test_base.IndependentWorkerTestBase,
|
|||||||
callbacks.ModelCheckpoint(
|
callbacks.ModelCheckpoint(
|
||||||
filepath=saving_filepath, save_weights_only=save_weights_only)
|
filepath=saving_filepath, save_weights_only=save_weights_only)
|
||||||
]
|
]
|
||||||
self.assertFalse(file_io.file_exists(saving_filepath))
|
self.assertFalse(file_io.file_exists_v2(saving_filepath))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model.fit(
|
model.fit(
|
||||||
@ -56,9 +56,9 @@ class ModelCheckpointTest(test_base.IndependentWorkerTestBase,
|
|||||||
except NotFoundError as e:
|
except NotFoundError as e:
|
||||||
if 'Failed to create a NewWriteableFile' in e.message:
|
if 'Failed to create a NewWriteableFile' in e.message:
|
||||||
self.skipTest('b/138941852, path not found error in Windows py35.')
|
self.skipTest('b/138941852, path not found error in Windows py35.')
|
||||||
tf_saved_model_exists = file_io.file_exists(saving_filepath)
|
tf_saved_model_exists = file_io.file_exists_v2(saving_filepath)
|
||||||
tf_weights_only_checkpoint_exists = file_io.file_exists(saving_filepath +
|
tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
|
||||||
'.index')
|
saving_filepath + '.index')
|
||||||
self.assertTrue(tf_saved_model_exists or tf_weights_only_checkpoint_exists)
|
self.assertTrue(tf_saved_model_exists or tf_weights_only_checkpoint_exists)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1276,7 +1276,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
|
|||||||
prefix = 'ackpt'
|
prefix = 'ackpt'
|
||||||
self.evaluate(v.assign(42.))
|
self.evaluate(v.assign(42.))
|
||||||
m.save_weights(prefix)
|
m.save_weights(prefix)
|
||||||
self.assertTrue(file_io.file_exists('ackpt.index'))
|
self.assertTrue(file_io.file_exists_v2('ackpt.index'))
|
||||||
self.evaluate(v.assign(1.))
|
self.evaluate(v.assign(1.))
|
||||||
m.load_weights(prefix)
|
m.load_weights(prefix)
|
||||||
self.assertEqual(42., self.evaluate(v))
|
self.assertEqual(42., self.evaluate(v))
|
||||||
@ -1284,7 +1284,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
|
|||||||
prefix = 'subdir/ackpt'
|
prefix = 'subdir/ackpt'
|
||||||
self.evaluate(v.assign(43.))
|
self.evaluate(v.assign(43.))
|
||||||
m.save_weights(prefix)
|
m.save_weights(prefix)
|
||||||
self.assertTrue(file_io.file_exists('subdir/ackpt.index'))
|
self.assertTrue(file_io.file_exists_v2('subdir/ackpt.index'))
|
||||||
self.evaluate(v.assign(2.))
|
self.evaluate(v.assign(2.))
|
||||||
m.load_weights(prefix)
|
m.load_weights(prefix)
|
||||||
self.assertEqual(43., self.evaluate(v))
|
self.assertEqual(43., self.evaluate(v))
|
||||||
@ -1292,7 +1292,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
|
|||||||
prefix = 'ackpt/'
|
prefix = 'ackpt/'
|
||||||
self.evaluate(v.assign(44.))
|
self.evaluate(v.assign(44.))
|
||||||
m.save_weights(prefix)
|
m.save_weights(prefix)
|
||||||
self.assertTrue(file_io.file_exists('ackpt/.index'))
|
self.assertTrue(file_io.file_exists_v2('ackpt/.index'))
|
||||||
self.evaluate(v.assign(3.))
|
self.evaluate(v.assign(3.))
|
||||||
m.load_weights(prefix)
|
m.load_weights(prefix)
|
||||||
self.assertEqual(44., self.evaluate(v))
|
self.assertEqual(44., self.evaluate(v))
|
||||||
|
@ -30,8 +30,8 @@ from tensorflow.python.keras.saving import model_config
|
|||||||
from tensorflow.python.keras.saving import saving_utils
|
from tensorflow.python.keras.saving import saving_utils
|
||||||
from tensorflow.python.keras.utils import mode_keys
|
from tensorflow.python.keras.utils import mode_keys
|
||||||
from tensorflow.python.keras.utils.generic_utils import LazyLoader
|
from tensorflow.python.keras.utils.generic_utils import LazyLoader
|
||||||
from tensorflow.python.lib.io import file_io
|
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.saved_model import builder as saved_model_builder
|
from tensorflow.python.saved_model import builder as saved_model_builder
|
||||||
from tensorflow.python.saved_model import constants
|
from tensorflow.python.saved_model import constants
|
||||||
@ -152,7 +152,8 @@ def _export_model_json(model, saved_model_path):
|
|||||||
model_json_filepath = os.path.join(
|
model_json_filepath = os.path.join(
|
||||||
saved_model_utils.get_or_create_assets_dir(saved_model_path),
|
saved_model_utils.get_or_create_assets_dir(saved_model_path),
|
||||||
compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
|
compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
|
||||||
file_io.write_string_to_file(model_json_filepath, model_json)
|
with gfile.Open(model_json_filepath, 'w') as f:
|
||||||
|
f.write(model_json)
|
||||||
|
|
||||||
|
|
||||||
def _export_model_variables(model, saved_model_path):
|
def _export_model_variables(model, saved_model_path):
|
||||||
@ -417,7 +418,8 @@ def load_from_saved_model(saved_model_path, custom_objects=None):
|
|||||||
compat.as_bytes(saved_model_path),
|
compat.as_bytes(saved_model_path),
|
||||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
|
compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
|
||||||
model_json = file_io.read_file_to_string(model_json_filepath)
|
with gfile.Open(model_json_filepath, 'r') as f:
|
||||||
|
model_json = f.read()
|
||||||
model = model_config.model_from_json(
|
model = model_config.model_from_json(
|
||||||
model_json, custom_objects=custom_objects)
|
model_json, custom_objects=custom_objects)
|
||||||
|
|
||||||
|
@ -38,8 +38,8 @@ class ModelToDotFormatTest(test.TestCase):
|
|||||||
try:
|
try:
|
||||||
vis_utils.plot_model(
|
vis_utils.plot_model(
|
||||||
model, to_file=dot_img_file, show_shapes=True, show_dtype=True)
|
model, to_file=dot_img_file, show_shapes=True, show_dtype=True)
|
||||||
self.assertTrue(file_io.file_exists(dot_img_file))
|
self.assertTrue(file_io.file_exists_v2(dot_img_file))
|
||||||
file_io.delete_file(dot_img_file)
|
file_io.delete_file_v2(dot_img_file)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -68,8 +68,8 @@ class ModelToDotFormatTest(test.TestCase):
|
|||||||
show_shapes=True,
|
show_shapes=True,
|
||||||
show_dtype=True,
|
show_dtype=True,
|
||||||
expand_nested=True)
|
expand_nested=True)
|
||||||
self.assertTrue(file_io.file_exists(dot_img_file))
|
self.assertTrue(file_io.file_exists_v2(dot_img_file))
|
||||||
file_io.delete_file(dot_img_file)
|
file_io.delete_file_v2(dot_img_file)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -86,8 +86,8 @@ class ModelToDotFormatTest(test.TestCase):
|
|||||||
show_shapes=True,
|
show_shapes=True,
|
||||||
show_dtype=True,
|
show_dtype=True,
|
||||||
expand_nested=True)
|
expand_nested=True)
|
||||||
self.assertTrue(file_io.file_exists(dot_img_file))
|
self.assertTrue(file_io.file_exists_v2(dot_img_file))
|
||||||
file_io.delete_file(dot_img_file)
|
file_io.delete_file_v2(dot_img_file)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -102,8 +102,8 @@ class ModelToDotFormatTest(test.TestCase):
|
|||||||
show_shapes=True,
|
show_shapes=True,
|
||||||
show_dtype=True,
|
show_dtype=True,
|
||||||
expand_nested=True)
|
expand_nested=True)
|
||||||
self.assertTrue(file_io.file_exists(dot_img_file))
|
self.assertTrue(file_io.file_exists_v2(dot_img_file))
|
||||||
file_io.delete_file(dot_img_file)
|
file_io.delete_file_v2(dot_img_file)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user