Fix remote I/O handling in train
This commit is contained in:
parent
8f31072998
commit
a6322b384e
@ -35,7 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
|
||||
from .util.io import open_remote, remove_remote, listdir_remote
|
||||
from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
@ -513,7 +513,8 @@ def train():
|
||||
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
|
||||
|
||||
# Save flags next to checkpoints
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
if not is_remote_path(FLAGS.save_checkpoint_dir):
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
||||
with open_remote(flags_file, 'w') as fout:
|
||||
fout.write(FLAGS.flags_into_string())
|
||||
@ -813,11 +814,11 @@ def export():
|
||||
if FLAGS.remove_remote_export:
|
||||
if isdir_remote(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
remove_remote(FLAGS.export_dir)
|
||||
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||
|
||||
if not isdir_remote(FLAGS.export_dir):
|
||||
if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||
|
Loading…
x
Reference in New Issue
Block a user