Merge pull request #1690 from lissyx/convert-toco

Convert toco
This commit is contained in:
lissyx 2018-11-01 10:54:01 +01:00 committed by GitHub
commit dc85977b6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 84 additions and 57 deletions

View File

@ -3,7 +3,7 @@
virtualenv -p python3 ../tmp/venv virtualenv -p python3 ../tmp/venv
source ../tmp/venv/bin/activate source ../tmp/venv/bin/activate
pip install -r <(grep -v tensorflow requirements.txt) pip install -r <(grep -v tensorflow requirements.txt)
pip install tensorflow-gpu==1.11.0 pip install tensorflow-gpu==1.12.0rc2
python3 util/taskcluster.py --arch gpu --target ../tmp/native_client python3 util/taskcluster.py --arch gpu --target ../tmp/native_client

View File

@ -18,10 +18,12 @@ import time
import traceback import traceback
import inspect import inspect
import progressbar import progressbar
import tempfile
from functools import partial from functools import partial
from six.moves import zip, range, filter, urllib, BaseHTTPServer from six.moves import zip, range, filter, urllib, BaseHTTPServer
from tensorflow.python.tools import freeze_graph from tensorflow.python.tools import freeze_graph
from tensorflow.contrib.lite.python import tflite_convert
from threading import Thread, Lock from threading import Thread, Lock
from util.audio import audiofile_to_input_vector from util.audio import audiofile_to_input_vector
from util.feeding import DataSet, ModelFeeder from util.feeding import DataSet, ModelFeeder
@ -1831,9 +1833,8 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False, tfli
return ( return (
{ {
'input': input_tensor, 'input': input_tensor,
'input_lengths': seq_length, 'previous_state_c': previous_state_c,
'new_state_c': new_state_c, 'previous_state_h': previous_state_h,
'new_state_h': new_state_h,
}, },
{ {
'outputs': logits, 'outputs': logits,
@ -1849,11 +1850,17 @@ def export():
''' '''
log_info('Exporting the model...') log_info('Exporting the model...')
with tf.device('/cpu:0'): with tf.device('/cpu:0'):
from tensorflow.python.framework.ops import Tensor, Operation
tf.reset_default_graph() tf.reset_default_graph()
session = tf.Session(config=session_config) session = tf.Session(config=session_config)
inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite) inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
input_names = ",".join(tensor.op.name for tensor in inputs.values())
output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor) ]
output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation) ]
output_names = ",".join(output_names_tensors + output_names_ops)
input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())
if not FLAGS.export_tflite: if not FLAGS.export_tflite:
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')} mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
@ -1872,11 +1879,7 @@ def export():
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
checkpoint_path = checkpoint.model_checkpoint_path checkpoint_path = checkpoint.model_checkpoint_path
if not FLAGS.export_tflite: output_filename = 'output_graph.pb'
output_filename = 'output_graph.pb'
else:
output_filename = 'output_graph.fb'
if FLAGS.remove_export: if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir): if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export') log_info('Removing old export')
@ -1887,31 +1890,61 @@ def export():
if not os.path.isdir(FLAGS.export_dir): if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir) os.makedirs(FLAGS.export_dir)
if not FLAGS.export_tflite: def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
output_node_names = 'logits,initialize_state' freeze_graph.freeze_graph_with_def_protos(
variables_blacklist = 'previous_state_c,previous_state_h' input_graph_def=session.graph_def,
else: input_saver_def=saver.as_saver_def(),
output_node_names = 'logits,new_state_c,new_state_h' input_checkpoint=checkpoint_path,
variables_blacklist = '' output_node_names=output_node_names,
restore_op_name=None,
filename_tensor_name=None,
output_graph=output_file,
clear_devices=False,
variable_names_blacklist=variables_blacklist,
initializer_nodes='')
# Freeze graph if not FLAGS.export_tflite:
freeze_graph.freeze_graph_with_def_protos( do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
input_graph_def=session.graph_def, else:
input_saver_def=saver.as_saver_def(), temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
input_checkpoint=checkpoint_path, os.close(temp_fd)
output_node_names=output_node_names, do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
restore_op_name=None, output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
filename_tensor_name=None, class TFLiteFlags():
output_graph=output_graph_path, def __init__(self):
clear_devices=False, self.graph_def_file = temp_freeze
variable_names_blacklist=variables_blacklist, self.inference_type = 'FLOAT'
initializer_nodes='') self.input_arrays = input_names
self.input_shapes = input_shapes
self.output_arrays = output_names
self.output_file = output_tflite_path
self.output_format = 'TFLITE'
default_empty = [
'inference_input_type',
'mean_values',
'default_ranges_min', 'default_ranges_max',
'drop_control_dependency',
'reorder_across_fake_quant',
'change_concat_input_ranges',
'allow_custom_ops',
'converter_mode',
'post_training_quantize',
'dump_graphviz_dir',
'dump_graphviz_video'
]
for e in default_empty:
self.__dict__[e] = None
flags = TFLiteFlags()
tflite_convert._convert_model(flags)
os.unlink(temp_freeze)
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
log_info('Models exported at %s' % (FLAGS.export_dir)) log_info('Models exported at %s' % (FLAGS.export_dir))
except RuntimeError as e: except RuntimeError as e:
log_error(str(e)) log_error(str(e))
def do_single_file_inference(input_file_path): def do_single_file_inference(input_file_path):
with tf.Session(config=session_config) as session: with tf.Session(config=session_config) as session:
inputs, outputs = create_inference_graph(batch_size=1, use_new_decoder=True) inputs, outputs = create_inference_graph(batch_size=1, use_new_decoder=True)

View File

@ -62,7 +62,7 @@ RUN wget https://bootstrap.pypa.io/get-pip.py && \
# Clone TensoFlow from Mozilla repo # Clone TensoFlow from Mozilla repo
RUN git clone https://github.com/mozilla/tensorflow/ RUN git clone https://github.com/mozilla/tensorflow/
WORKDIR /tensorflow WORKDIR /tensorflow
RUN git checkout r1.11 RUN git checkout r1.12
# GPU Environment Setup # GPU Environment Setup
@ -190,7 +190,7 @@ RUN cp /tensorflow/bazel-bin/native_client/libctc_decoder_with_kenlm.so /DeepSpe
# Install TensorFlow # Install TensorFlow
WORKDIR /DeepSpeech/ WORKDIR /DeepSpeech/
RUN pip install tensorflow-gpu==1.11.0 RUN pip install tensorflow-gpu==1.12.0rc2
# Make DeepSpeech and install Python bindings # Make DeepSpeech and install Python bindings

View File

@ -45,6 +45,7 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech
- [Training a model](#training-a-model) - [Training a model](#training-a-model)
- [Checkpointing](#checkpointing) - [Checkpointing](#checkpointing)
- [Exporting a model for inference](#exporting-a-model-for-inference) - [Exporting a model for inference](#exporting-a-model-for-inference)
- [Exporting a model for TFLite](#exporting-a-model-for-tflite)
- [Distributed computing across more than one machine](#distributed-training-across-more-than-one-machine) - [Distributed computing across more than one machine](#distributed-training-across-more-than-one-machine)
- [Continuing training from a release model](#continuing-training-from-a-release-model) - [Continuing training from a release model](#continuing-training-from-a-release-model)
- [Code documentation](#code-documentation) - [Code documentation](#code-documentation)
@ -226,7 +227,7 @@ If you have a capable (Nvidia, at least 8GB of VRAM) GPU, it is highly recommend
```bash ```bash
pip3 uninstall tensorflow pip3 uninstall tensorflow
pip3 install 'tensorflow-gpu==1.11.0' pip3 install 'tensorflow-gpu==1.12.0rc2'
``` ```
### Common Voice training data ### Common Voice training data
@ -317,6 +318,10 @@ Be aware however that checkpoints are only valid for the same model geometry the
If the `--export_dir` parameter is provided, a model will have been exported to this directory during training. If the `--export_dir` parameter is provided, a model will have been exported to this directory during training.
Refer to the corresponding [README.md](native_client/README.md) for information on building and running a client that can use the exported model. Refer to the corresponding [README.md](native_client/README.md) for information on building and running a client that can use the exported model.
### Exporting a model for TFLite
If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--export_tflite` flag. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --export_tflite --export_dir /model/export/destination`.
### Making a mmap-able model for inference ### Making a mmap-able model for inference
The `output_graph.pb` model file generated in the above step will be loaded in memory to be dealt with when running inference. The `output_graph.pb` model file generated in the above step will be loaded in memory to be dealt with when running inference.

View File

@ -52,7 +52,7 @@ Check the [main README](../README.md) for more details.
If you'd like to build the binaries yourself, you'll need the following pre-requisites downloaded/installed: If you'd like to build the binaries yourself, you'll need the following pre-requisites downloaded/installed:
* [TensorFlow requirements](https://www.tensorflow.org/install/install_sources) * [TensorFlow requirements](https://www.tensorflow.org/install/install_sources)
* [TensorFlow `r1.11` sources](https://github.com/mozilla/tensorflow/tree/r1.11) * [TensorFlow `r1.12` sources](https://github.com/mozilla/tensorflow/tree/r1.12)
* [libsox](https://sourceforge.net/projects/sox/) * [libsox](https://sourceforge.net/projects/sox/)
It is required to use our fork of TensorFlow since it includes fixes for common problems encountered when building the native client files. It is required to use our fork of TensorFlow since it includes fixes for common problems encountered when building the native client files.

View File

@ -1,7 +1,7 @@
pandas pandas
progressbar2 progressbar2
python-utils python-utils
tensorflow == 1.11.0 tensorflow == 1.12.0rc2
numpy numpy
matplotlib matplotlib
scipy scipy

View File

@ -6,8 +6,7 @@ build:
- "index.project.deepspeech.deepspeech.native_client.osx.${event.head.sha}" - "index.project.deepspeech.deepspeech.native_client.osx.${event.head.sha}"
- "notify.irc-channel.${notifications.irc}.on-exception" - "notify.irc-channel.${notifications.irc}.on-exception"
- "notify.irc-channel.${notifications.irc}.on-failed" - "notify.irc-channel.${notifications.irc}.on-failed"
tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.osx/artifacts/public/home.tar.xz" tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.12.1c93ca24c99d7011ad639eea4cd96e4fe45e1a95.osx/artifacts/public/home.tar.xz"
summarize_graph: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.osx/artifacts/public/summarize_graph"
scripts: scripts:
build: "taskcluster/host-build.sh" build: "taskcluster/host-build.sh"
package: "taskcluster/package.sh" package: "taskcluster/package.sh"

View File

@ -39,7 +39,6 @@ payload:
training: { $eval: as_slugid("test-training_upstream-linux-amd64-py27mu-opt") } training: { $eval: as_slugid("test-training_upstream-linux-amd64-py27mu-opt") }
in: in:
TENSORFLOW_BUILD_ARTIFACT: ${build.tensorflow} TENSORFLOW_BUILD_ARTIFACT: ${build.tensorflow}
SUMMARIZE_GRAPH_BINARY: ${build.summarize_graph}
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
# There is no VM yet running tasks on OSX # There is no VM yet running tasks on OSX

View File

@ -14,8 +14,7 @@ build:
system_config: system_config:
> >
${swig.patch_nodejs.linux} ${swig.patch_nodejs.linux}
tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.cpu/artifacts/public/home.tar.xz" tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.12.1c93ca24c99d7011ad639eea4cd96e4fe45e1a95.cpu/artifacts/public/home.tar.xz"
summarize_graph: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.cpu/artifacts/public/summarize_graph"
scripts: scripts:
build: "taskcluster/host-build.sh" build: "taskcluster/host-build.sh"
package: "taskcluster/package.sh" package: "taskcluster/package.sh"

View File

@ -4,8 +4,7 @@ build:
- "pull_request.synchronize" - "pull_request.synchronize"
- "pull_request.reopened" - "pull_request.reopened"
template_file: linux-opt-base.tyml template_file: linux-opt-base.tyml
tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.cpu/artifacts/public/home.tar.xz" tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.12.1c93ca24c99d7011ad639eea4cd96e4fe45e1a95.cpu/artifacts/public/home.tar.xz"
summarize_graph: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.cpu/artifacts/public/summarize_graph"
scripts: scripts:
build: 'taskcluster/decoder-build.sh' build: 'taskcluster/decoder-build.sh'
package: 'taskcluster/decoder-package.sh' package: 'taskcluster/decoder-package.sh'

View File

@ -12,8 +12,7 @@ build:
system_config: system_config:
> >
${swig.patch_nodejs.linux} ${swig.patch_nodejs.linux}
tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.gpu/artifacts/public/home.tar.xz" tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.12.1c93ca24c99d7011ad639eea4cd96e4fe45e1a95.gpu/artifacts/public/home.tar.xz"
summarize_graph: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.gpu/artifacts/public/summarize_graph"
maxRunTime: 14400 maxRunTime: 14400
scripts: scripts:
build: "taskcluster/cuda-build.sh" build: "taskcluster/cuda-build.sh"

View File

@ -4,8 +4,7 @@ build:
- "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.arm64" - "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.arm64"
- "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.${event.head.sha}.arm64" - "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.${event.head.sha}.arm64"
- "index.project.deepspeech.deepspeech.native_client.arm64.${event.head.sha}" - "index.project.deepspeech.deepspeech.native_client.arm64.${event.head.sha}"
tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.arm64/artifacts/public/home.tar.xz" tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.12.1c93ca24c99d7011ad639eea4cd96e4fe45e1a95.arm64/artifacts/public/home.tar.xz"
summarize_graph: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.cpu/artifacts/public/summarize_graph"
## multistrap 2.2.0-ubuntu1 is broken in 14.04: https://bugs.launchpad.net/ubuntu/+source/multistrap/+bug/1313787 ## multistrap 2.2.0-ubuntu1 is broken in 14.04: https://bugs.launchpad.net/ubuntu/+source/multistrap/+bug/1313787
system_setup: system_setup:
> >

View File

@ -36,7 +36,6 @@ then:
training: { $eval: as_slugid("test-training_upstream-linux-amd64-py27mu-opt") } training: { $eval: as_slugid("test-training_upstream-linux-amd64-py27mu-opt") }
in: in:
TENSORFLOW_BUILD_ARTIFACT: ${build.tensorflow} TENSORFLOW_BUILD_ARTIFACT: ${build.tensorflow}
SUMMARIZE_GRAPH_BINARY: ${build.summarize_graph}
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
command: command:

View File

@ -4,8 +4,7 @@ build:
- "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.arm" - "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.arm"
- "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.${event.head.sha}.arm" - "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.${event.head.sha}.arm"
- "index.project.deepspeech.deepspeech.native_client.arm.${event.head.sha}" - "index.project.deepspeech.deepspeech.native_client.arm.${event.head.sha}"
tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.arm/artifacts/public/home.tar.xz" tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.12.1c93ca24c99d7011ad639eea4cd96e4fe45e1a95.arm/artifacts/public/home.tar.xz"
summarize_graph: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.cpu/artifacts/public/summarize_graph"
## multistrap 2.2.0-ubuntu1 is broken in 14.04: https://bugs.launchpad.net/ubuntu/+source/multistrap/+bug/1313787 ## multistrap 2.2.0-ubuntu1 is broken in 14.04: https://bugs.launchpad.net/ubuntu/+source/multistrap/+bug/1313787
system_setup: system_setup:
> >

View File

@ -16,8 +16,7 @@ build:
system_config: system_config:
> >
${swig.patch_nodejs.linux} ${swig.patch_nodejs.linux}
tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.cpu/artifacts/public/home.tar.xz" tensorflow: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.12.1c93ca24c99d7011ad639eea4cd96e4fe45e1a95.cpu/artifacts/public/home.tar.xz"
summarize_graph: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.cpu/artifacts/public/summarize_graph"
scripts: scripts:
build: "taskcluster/node-build.sh" build: "taskcluster/node-build.sh"
package: "taskcluster/node-package.sh" package: "taskcluster/node-package.sh"

View File

@ -35,7 +35,6 @@ then:
linux_arm64_build: { $eval: as_slugid("linux-arm64-cpu-opt") } linux_arm64_build: { $eval: as_slugid("linux-arm64-cpu-opt") }
node_package: { $eval: as_slugid("node-package") } node_package: { $eval: as_slugid("node-package") }
in: in:
CONVERT_GRAPHDEF_MEMMAPPED: ${build.convert_graphdef}
DEEPSPEECH_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_arm64_build}/artifacts/public DEEPSPEECH_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_arm64_build}/artifacts/public
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package}/artifacts/public DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package}/artifacts/public
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
@ -44,7 +43,7 @@ then:
PIP_DEFAULT_TIMEOUT: "60" PIP_DEFAULT_TIMEOUT: "60"
PIP_EXTRA_INDEX_URL: "https://lissyx.github.io/deepspeech-python-wheels/" PIP_EXTRA_INDEX_URL: "https://lissyx.github.io/deepspeech-python-wheels/"
EXTRA_PYTHON_CONFIGURE_OPTS: "--with-fpectl" # Required by Debian Stretch EXTRA_PYTHON_CONFIGURE_OPTS: "--with-fpectl" # Required by Debian Stretch
EXPECTED_TENSORFLOW_VERSION: "TensorFlow: v1.11.0-11-gbee8254" EXPECTED_TENSORFLOW_VERSION: "TensorFlow: v1.12.0-rc2-5-g1c93ca2"
command: command:
- "/bin/bash" - "/bin/bash"

View File

@ -41,7 +41,7 @@ then:
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.2.0-prod-ctcdecode/output_graph.pb DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.2.0-prod-ctcdecode/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.2.0-prod-ctcdecode/output_graph.pbmm DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.2.0-prod-ctcdecode/output_graph.pbmm
EXPECTED_TENSORFLOW_VERSION: "TensorFlow: v1.11.0-11-gbee8254" EXPECTED_TENSORFLOW_VERSION: "TensorFlow: v1.12.0-rc2-5-g1c93ca2"
command: command:
- - "/bin/bash" - - "/bin/bash"

View File

@ -44,7 +44,7 @@ then:
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.2.0-prod-ctcdecode/output_graph.pb DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.2.0-prod-ctcdecode/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.2.0-prod-ctcdecode/output_graph.pbmm DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.2.0-prod-ctcdecode/output_graph.pbmm
PIP_DEFAULT_TIMEOUT: "60" PIP_DEFAULT_TIMEOUT: "60"
EXPECTED_TENSORFLOW_VERSION: "TensorFlow: v1.11.0-11-gbee8254" EXPECTED_TENSORFLOW_VERSION: "TensorFlow: v1.12.0-rc2-5-g1c93ca2"
command: command:
- "/bin/bash" - "/bin/bash"

View File

@ -35,7 +35,6 @@ then:
linux_rpi3_build: { $eval: as_slugid("linux-rpi3-cpu-opt") } linux_rpi3_build: { $eval: as_slugid("linux-rpi3-cpu-opt") }
node_package: { $eval: as_slugid("node-package") } node_package: { $eval: as_slugid("node-package") }
in: in:
CONVERT_GRAPHDEF_MEMMAPPED: ${build.convert_graphdef}
DEEPSPEECH_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_rpi3_build}/artifacts/public DEEPSPEECH_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_rpi3_build}/artifacts/public
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package}/artifacts/public DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package}/artifacts/public
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
@ -44,7 +43,7 @@ then:
PIP_DEFAULT_TIMEOUT: "60" PIP_DEFAULT_TIMEOUT: "60"
PIP_EXTRA_INDEX_URL: "https://www.piwheels.org/simple" PIP_EXTRA_INDEX_URL: "https://www.piwheels.org/simple"
EXTRA_PYTHON_CONFIGURE_OPTS: "--with-fpectl" # Required by Raspbian Stretch / PiWheels EXTRA_PYTHON_CONFIGURE_OPTS: "--with-fpectl" # Required by Raspbian Stretch / PiWheels
EXPECTED_TENSORFLOW_VERSION: "TensorFlow: v1.11.0-11-gbee8254" EXPECTED_TENSORFLOW_VERSION: "TensorFlow: v1.12.0-rc2-5-g1c93ca2"
command: command:
- "/bin/bash" - "/bin/bash"

View File

@ -7,7 +7,7 @@ build:
apt-get -qq -y install ${python.packages_trusty.apt} apt-get -qq -y install ${python.packages_trusty.apt}
args: args:
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/tc-train-tests.sh 2.7.14:mu" tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/tc-train-tests.sh 2.7.14:mu"
convert_graphdef: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.11.bee825492fcf830bd65a024bf859cbfc218e1473.cpu/artifacts/public/convert_graphdef_memmapped_format" convert_graphdef: "https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.r1.12.1c93ca24c99d7011ad639eea4cd96e4fe45e1a95.cpu/artifacts/public/convert_graphdef_memmapped_format"
metadata: metadata:
name: "DeepSpeech Linux AMD64 CPU upstream training Py2.7 mu" name: "DeepSpeech Linux AMD64 CPU upstream training Py2.7 mu"
description: "Training a DeepSpeech LDC93S1 model for Linux/AMD64 using upstream TensorFlow Python 2.7 mu, CPU only, optimized version" description: "Training a DeepSpeech LDC93S1 model for Linux/AMD64 using upstream TensorFlow Python 2.7 mu, CPU only, optimized version"

View File

@ -66,7 +66,7 @@ pushd ${HOME}/DeepSpeech/ds/
popd popd
cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS} cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
cp /tmp/train/output_graph.fb ${TASKCLUSTER_ARTIFACTS} cp /tmp/train/output_graph.tflite ${TASKCLUSTER_ARTIFACTS}
if [ ! -z "${CONVERT_GRAPHDEF_MEMMAPPED}" ]; then if [ ! -z "${CONVERT_GRAPHDEF_MEMMAPPED}" ]; then
convert_graphdef=$(basename "${CONVERT_GRAPHDEF_MEMMAPPED}") convert_graphdef=$(basename "${CONVERT_GRAPHDEF_MEMMAPPED}")