Fix CI errors, address comments, update examples

This commit is contained in:
Reuben Morais 2020-01-29 11:53:33 +01:00
parent c512383aec
commit 3637f88c06
10 changed files with 21 additions and 45 deletions

View File

@ -876,11 +876,6 @@ def package_zip():
with open(os.path.join(export_dir, 'info.json'), 'w') as f: with open(os.path.join(export_dir, 'info.json'), 'w') as f:
json.dump({ json.dump({
'name': FLAGS.export_language, 'name': FLAGS.export_language,
'parameters': {
'beamWidth': FLAGS.export_beam_width,
'lmAlpha': FLAGS.lm_alpha,
'lmBeta': FLAGS.lm_beta
}
}, f) }, f)
shutil.copy(FLAGS.scorer_path, export_dir) shutil.copy(FLAGS.scorer_path, export_dir)

View File

@ -33,8 +33,6 @@ public class BasicTest {
public static final String scorerFile = "/data/local/tmp/test/kenlm.scorer"; public static final String scorerFile = "/data/local/tmp/test/kenlm.scorer";
public static final String wavFile = "/data/local/tmp/test/LDC93S1.wav"; public static final String wavFile = "/data/local/tmp/test/LDC93S1.wav";
public static final int BEAM_WIDTH = 50;
private char readLEChar(RandomAccessFile f) throws IOException { private char readLEChar(RandomAccessFile f) throws IOException {
byte b1 = f.readByte(); byte b1 = f.readByte();
byte b2 = f.readByte(); byte b2 = f.readByte();
@ -117,7 +115,6 @@ public class BasicTest {
@Test @Test
public void loadDeepSpeech_stt_noLM() { public void loadDeepSpeech_stt_noLM() {
DeepSpeechModel m = new DeepSpeechModel(modelFile); DeepSpeechModel m = new DeepSpeechModel(modelFile);
m.setBeamWidth(BEAM_WIDTH);
String decoded = doSTT(m, false); String decoded = doSTT(m, false);
assertEquals("she had your dark suit in greasy wash water all year", decoded); assertEquals("she had your dark suit in greasy wash water all year", decoded);
@ -127,7 +124,6 @@ public class BasicTest {
@Test @Test
public void loadDeepSpeech_stt_withLM() { public void loadDeepSpeech_stt_withLM() {
DeepSpeechModel m = new DeepSpeechModel(modelFile); DeepSpeechModel m = new DeepSpeechModel(modelFile);
m.setBeamWidth(BEAM_WIDTH);
m.enableExternalScorer(scorerFile); m.enableExternalScorer(scorerFile);
String decoded = doSTT(m, false); String decoded = doSTT(m, false);
@ -138,7 +134,6 @@ public class BasicTest {
@Test @Test
public void loadDeepSpeech_sttWithMetadata_noLM() { public void loadDeepSpeech_sttWithMetadata_noLM() {
DeepSpeechModel m = new DeepSpeechModel(modelFile); DeepSpeechModel m = new DeepSpeechModel(modelFile);
m.setBeamWidth(BEAM_WIDTH);
String decoded = doSTT(m, true); String decoded = doSTT(m, true);
assertEquals("she had your dark suit in greasy wash water all year", decoded); assertEquals("she had your dark suit in greasy wash water all year", decoded);
@ -148,7 +143,6 @@ public class BasicTest {
@Test @Test
public void loadDeepSpeech_sttWithMetadata_withLM() { public void loadDeepSpeech_sttWithMetadata_withLM() {
DeepSpeechModel m = new DeepSpeechModel(modelFile); DeepSpeechModel m = new DeepSpeechModel(modelFile);
m.setBeamWidth(BEAM_WIDTH);
m.enableExternalScorer(scorerFile); m.enableExternalScorer(scorerFile);
String decoded = doSTT(m, true); String decoded = doSTT(m, true);

View File

@ -33,7 +33,7 @@ public class DeepSpeechModel {
* *
* @return Beam width value used by the model. * @return Beam width value used by the model.
*/ */
public int beamWidth() { public long beamWidth() {
return impl.GetModelBeamWidth(this._msp); return impl.GetModelBeamWidth(this._msp);
} }
@ -45,7 +45,7 @@ public class DeepSpeechModel {
* *
* @return Zero on success, non-zero on failure. * @return Zero on success, non-zero on failure.
*/ */
public int setBeamWidth(int beamWidth) { public int setBeamWidth(long beamWidth) {
return impl.SetModelBeamWidth(this._msp, beamWidth); return impl.SetModelBeamWidth(this._msp, beamWidth);
} }

View File

@ -42,9 +42,8 @@ function Model(aModelPath) {
} }
/** /**
* Get beam width value used by the model. If {@link DS_SetModelBeamWidth} * Get beam width value used by the model. If :js:func:Model.setBeamWidth was
* was not called before, will return the default value loaded from the * not called before, will return the default value loaded from the model file.
* model file.
* *
* @return {number} Beam width value used by the model. * @return {number} Beam width value used by the model.
*/ */
@ -63,16 +62,6 @@ Model.prototype.setBeamWidth = function(aBeamWidth) {
return binding.SetModelBeamWidth(this._impl, aBeamWidth); return binding.SetModelBeamWidth(this._impl, aBeamWidth);
} }
/**
* Return the sample rate expected by the model.
*
* @return {number} Sample rate.
*/
Model.prototype.beamWidth = function() {
return binding.GetModelBeamWidth(this._impl);
}
/** /**
* Return the sample rate expected by the model. * Return the sample rate expected by the model.
* *

View File

@ -45,9 +45,8 @@ class Model(object):
def beamWidth(self): def beamWidth(self):
""" """
Get beam width value used by the model. If {@link DS_SetModelBeamWidth} Get beam width value used by the model. If setModelBeamWidth was not
was not called before, will return the default value loaded from the called before, will return the default value loaded from the model file.
model file.
:return: Beam width value used by the model. :return: Beam width value used by the model.
:type: int :type: int

View File

@ -128,16 +128,16 @@ TFLiteModelState::init(const char* model_path)
int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate"); int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate");
int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len"); int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step"); int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet");
int metadata_beam_width_idx = get_output_tensor_by_name("metadata_beam_width"); int metadata_beam_width_idx = get_output_tensor_by_name("metadata_beam_width");
int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet");
std::vector<int> metadata_exec_plan; std::vector<int> metadata_exec_plan;
metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_beam_width_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_beam_width_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]);
for (int i = 0; i < metadata_exec_plan.size(); ++i) { for (int i = 0; i < metadata_exec_plan.size(); ++i) {
assert(metadata_exec_plan[i] > -1); assert(metadata_exec_plan[i] > -1);
@ -202,20 +202,20 @@ TFLiteModelState::init(const char* model_path)
audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0); audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0); audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0);
int* const beam_width = interpreter_->typed_tensor<int>(metadata_beam_width_idx);
beam_width_ = (unsigned int)(*beam_width);
tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0); tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len); err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len);
if (err != 0) { if (err != 0) {
return DS_ERR_INVALID_ALPHABET; return DS_ERR_INVALID_ALPHABET;
} }
int* const beam_width = interpreter_->typed_tensor<int>(metadata_beam_width_idx);
beam_width_ = (unsigned int)(*beam_width);
assert(sample_rate_ > 0); assert(sample_rate_ > 0);
assert(audio_win_len_ > 0); assert(audio_win_len_ > 0);
assert(audio_win_step_ > 0); assert(audio_win_step_ > 0);
assert(alphabet_.GetSize() > 0);
assert(beam_width_ > 0); assert(beam_width_ > 0);
assert(alphabet_.GetSize() > 0);
TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims; TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;

View File

@ -102,8 +102,8 @@ TFModelState::init(const char* model_path)
"metadata_sample_rate", "metadata_sample_rate",
"metadata_feature_win_len", "metadata_feature_win_len",
"metadata_feature_win_step", "metadata_feature_win_step",
"metadata_alphabet",
"metadata_beam_width", "metadata_beam_width",
"metadata_alphabet",
}, {}, &metadata_outputs); }, {}, &metadata_outputs);
if (!status.ok()) { if (!status.ok()) {
std::cout << "Unable to fetch metadata: " << status << std::endl; std::cout << "Unable to fetch metadata: " << status << std::endl;
@ -115,21 +115,20 @@ TFModelState::init(const char* model_path)
int win_step_ms = metadata_outputs[2].scalar<int>()(); int win_step_ms = metadata_outputs[2].scalar<int>()();
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0); audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0); audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
int beam_width = metadata_outputs[3].scalar<int>()();
beam_width_ = (unsigned int)(beam_width);
string serialized_alphabet = metadata_outputs[3].scalar<string>()(); string serialized_alphabet = metadata_outputs[4].scalar<string>()();
err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size()); err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size());
if (err != 0) { if (err != 0) {
return DS_ERR_INVALID_ALPHABET; return DS_ERR_INVALID_ALPHABET;
} }
int beam_width = metadata_outputs[4].scalar<int>()();
beam_width_ = (unsigned int)(beam_width);
assert(sample_rate_ > 0); assert(sample_rate_ > 0);
assert(audio_win_len_ > 0); assert(audio_win_len_ > 0);
assert(audio_win_step_ > 0); assert(audio_win_step_ > 0);
assert(alphabet_.GetSize() > 0);
assert(beam_width_ > 0); assert(beam_width_ > 0);
assert(alphabet_.GetSize() > 0);
for (int i = 0; i < graph_def_.node_size(); ++i) { for (int i = 0; i < graph_def_.node_size(); ++i) {
NodeDef node = graph_def_.node(i); NodeDef node = graph_def_.node(i);

View File

@ -30,11 +30,11 @@ then:
image: ${build.docker_image} image: ${build.docker_image}
env: env:
DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.7.0-alpha.1/models.tar.gz" DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.7.0-alpha.1/models_beam_width.tar.gz"
DEEPSPEECH_AUDIO: "https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/audio-0.4.1.tar.gz" DEEPSPEECH_AUDIO: "https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/audio-0.4.1.tar.gz"
PIP_DEFAULT_TIMEOUT: "60" PIP_DEFAULT_TIMEOUT: "60"
EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples" EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples"
EXAMPLES_CHECKOUT_TARGET: "4b97ac41d03ca0d23fa92526433db72a90f47d4a" EXAMPLES_CHECKOUT_TARGET: "embedded-beam-width"
command: command:
- "/bin/bash" - "/bin/bash"

View File

@ -44,7 +44,7 @@ payload:
MSYS: 'winsymlinks:nativestrict' MSYS: 'winsymlinks:nativestrict'
TENSORFLOW_BUILD_ARTIFACT: ${build.tensorflow} TENSORFLOW_BUILD_ARTIFACT: ${build.tensorflow}
EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples" EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples"
EXAMPLES_CHECKOUT_TARGET: "4b97ac41d03ca0d23fa92526433db72a90f47d4a" EXAMPLES_CHECKOUT_TARGET: "embedded-beam-width"
command: command:
- >- - >-

View File

@ -111,7 +111,7 @@ def create_flags():
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.') f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json') f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json')
f.DEFINE_string('export_name', 'output_graph', 'name for the export model') f.DEFINE_string('export_name', 'output_graph', 'name for the export model')
f.DEFINE_string('export_beam_width', 500, 'default beam width to embed into exported graph') f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph')
# Reporting # Reporting