diff --git a/util/flags.py b/util/flags.py index a4cb4979..468d8e50 100644 --- a/util/flags.py +++ b/util/flags.py @@ -2,10 +2,10 @@ from __future__ import absolute_import, division, print_function import tensorflow as tf +import os FLAGS = tf.app.flags.FLAGS - def create_flags(): # Importer # ======== @@ -113,3 +113,26 @@ def create_flags(): # Inference mode f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.') + + # Register validators for paths which require a file to be specified + + f.register_validator('lm_binary_path', + os.path.isfile, + message='The file pointed to by --lm_binary_path must exist and be readable.') + + f.register_validator('lm_trie_path', + os.path.isfile, + message='The file pointed to by --lm_trie_path must exist and be readable.') + + f.register_validator('alphabet_config_path', + os.path.isfile, + message='The file pointed to by --alphabet_config_path must exist and be readable.') + + f.register_validator('one_shot_infer', + lambda value: not value or os.path.isfile(value), + message='The file pointed to by --one_shot_infer must exist and be readable.') + + f.register_validator('export_dir', + lambda value: not value or os.path.isdir(value), + message='The path pointed to by --export_dir must exist and be a directory.') +