diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index 42b1767c..6be4ef55 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -1116,7 +1116,10 @@ def export(): input_tensors=inputs.values(), output_tensors=outputs.values(), ) - converter.optimizations = [tf.lite.Optimize.DEFAULT] + + if Config.export_quantize: + converter.optimizations = [tf.lite.Optimize.DEFAULT] + # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite converter.allow_custom_ops = True tflite_model = converter.convert() @@ -1256,6 +1259,10 @@ def early_training_checks(): "for loading and saving." ) + if not Config.alphabet_config_path and not Config.bytes_output_mode: + log_error("Missing --alphabet_config_path flag, can't continue") + sys.exit(1) + def main(): initialize_globals_from_cli() diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index b426fb7a..0b072243 100755 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -22,7 +22,7 @@ class _ConfigSingleton: _config = None def __getattr__(self, name): - if not _ConfigSingleton._config: + if _ConfigSingleton._config is None: raise RuntimeError("Global configuration not yet initialized.") if not hasattr(_ConfigSingleton._config, name): raise RuntimeError( @@ -478,6 +478,10 @@ class _SttConfig(Coqpit): export_tflite: bool = field( default=False, metadata=dict(help="export a graph ready for TF Lite engine") ) + export_quantize: bool = field( + default=True, + metadata=dict(help="export a quantized model (optimized for size)"), + ) n_steps: int = field( default=16, metadata=dict(