Fix remote I/O handling in train

This commit is contained in:
CatalinVoss 2020-11-12 16:29:16 -08:00
parent 8f31072998
commit a6322b384e

View File

@ -35,7 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
from .util.io import open_remote, remove_remote, listdir_remote
from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote
check_ctcdecoder_version()
@ -513,7 +513,8 @@ def train():
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
if not is_remote_path(FLAGS.save_checkpoint_dir):
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open_remote(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
@ -813,11 +814,11 @@ def export():
if FLAGS.remove_remote_export:
if isdir_remote(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
remove_remote(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not isdir_remote(FLAGS.export_dir):
if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(