Compare commits

...

45 Commits

Author SHA1 Message Date
003b399253 Fix up the RPi4 build 2021-12-04 16:28:48 +00:00
f008d10c49 Use tensorflow fork with rpi4ub-armv8 build target 2021-12-04 16:08:39 +00:00
0f698133aa Add an rpi4ub-armv8 build variant 2021-12-04 15:45:55 +00:00
8cea2cbfec Use my fork of the fork of tensorflow 2021-12-04 12:34:10 +00:00
Reuben Morais
dbd38c3a89
Merge pull request #2032 from coqui-ai/transcription-scripts-docs
[transcribe] Fix multiprocessing hangs, clean-up target collection, write docs
2021-12-03 16:46:48 +01:00
Reuben Morais
b43e710959 Docs for transcription with training package 2021-12-03 16:22:43 +01:00
Reuben Morais
ff24a8b917 Undo late-imports 2021-12-03 16:22:43 +01:00
Reuben Morais
479d963155 Set training pkg python-requires<3.8 (due to TF 1.15.4 limit) 2021-12-03 16:22:34 +01:00
Reuben Morais
d90bb60506 [transcribe] Fix multiprocessing hangs, clean-up target collection 2021-12-01 15:44:25 +01:00
Reuben Morais
5cefd7069c Use known paths for Scorer and Alphabet copy in export 2021-11-23 14:21:11 +01:00
Reuben Morais
154a67fb2c
Merge pull request #2026 from coqui-ai/save-scorer-alphabet-savedmodel
Save Scorer and alphabet with SavedModel exports
2021-11-19 20:07:32 +01:00
Reuben Morais
d6456ae4aa Save Scorer and alphabet with SavedModel exports 2021-11-19 19:40:00 +01:00
Reuben Morais
3020949075
Merge pull request #2025 from coqui-ai/various-fixes
Docs fixes, SavedModel export, transcribe.py revival
2021-11-19 16:10:20 +01:00
Reuben Morais
efdaa61e2c Revive transcribe.py
Update to use Coqpit-based config handling, fix multiprocessing setup, and add CI coverage.
2021-11-19 13:57:44 +01:00
Reuben Morais
419b15b72a Allow exporting as SavedModel 2021-11-18 13:48:52 +01:00
Reuben Morais
6a9bd1e6b6 Add usage instructions for C API 2021-11-17 14:07:22 +01:00
Reuben Morais
922d668155
Merge pull request #2024 from juliandarley/fix-shlex-typo
Fix typo in client.py - shlex in line 17
2021-11-17 13:05:30 +01:00
Julian Darley
8ed0a827de Fix typo in client.py - shlex in line 17 2021-11-17 08:33:32 +00:00
Reuben Morais
11c2edb068
Merge pull request #2018 from coqui-ai/node-electron-version-bump
Update NodeJS and ElectronJS build/test versions to supported releases
2021-11-12 22:47:41 +01:00
Reuben Morais
e7c28ca3c9 Remove outdated comment in supported platforms doc [skip ci] 2021-11-12 22:47:15 +01:00
Reuben Morais
2af6f8da89 Explicitly name TF build cache destination file
GitHub's API has stopped sending the artifact name as the file name, so we ended up with a file matching the artifact ID.
Name the full file path explicitly so there's no room for changes.
2021-11-12 22:47:02 +01:00
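For illustration only, the "stream the download to an explicitly named destination file" pattern described in this commit can be sketched in Python (a hedged sketch using the requests library; the actual fix is in the JavaScript download action shown in the file diffs below, and the URL, token, and path names here are placeholders):

```python
import requests

def download_artifact(url: str, token: str, dest_path: str) -> None:
    """Stream a release asset to an explicitly named file so the on-disk name
    never depends on what the server reports for the artifact."""
    resp = requests.get(
        url,
        headers={
            "Accept": "application/octet-stream",
            "Authorization": f"token {token}",
        },
        stream=True,
        timeout=60,
    )
    resp.raise_for_status()
    with open(dest_path, "wb") as f:
        for chunk in resp.iter_content(chunk_size=1 << 16):
            f.write(chunk)
```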
Reuben Morais
a5c981bb48 Update NodeJS and ElectronJS build/test versions to supported releases 2021-11-12 22:47:02 +01:00
Reuben Morais
23af8bd095 Bump version to v1.1.0-alpha.1 2021-10-31 21:42:00 +01:00
Reuben Morais
2b955fc70f
Merge pull request #2004 from coqui-ai/flashlight-docs
Improve decoder docs and include in RTD
2021-10-31 21:40:04 +01:00
Reuben Morais
90feb63894 Improve decoder package docs and include in RTD 2021-10-31 20:12:37 +01:00
Reuben Morais
91f1307de4 Pin docutils version as 0.18 release breaks build
Build breaks when writing output for AUGMENTATION.rst with error:

AttributeError: 'Values' object has no attribute 'section_self_link'
2021-10-31 16:46:03 +01:00
Reuben Morais
3d1e3ed3ba Don't include RELEASE_NOTES for pre-releases [skip ci] 2021-10-30 17:31:04 +02:00
Reuben Morais
9a2c2028c7 Bump version to v1.1.0-alpha.0 2021-10-30 17:24:05 +02:00
Reuben Morais
6ef733be54
Merge pull request #2001 from coqui-ai/decoder-flashlight
Expose Flashlight LexiconDecoder/LexiconFreeDecoder in decoder package
2021-10-30 17:19:41 +02:00
Reuben Morais
a61180aeae Fix Flashlight multiplatform build 2021-10-30 16:23:44 +02:00
Reuben Morais
391036643c debug 2021-10-30 16:23:44 +02:00
Reuben Morais
04f62ac9f7 Exercise training graph inference/Flashlight decoder in extra training tests 2021-10-30 14:59:32 +02:00
Reuben Morais
755fb81a62 Expose Flashlight LexiconDecoder/LexiconFreeDecoder 2021-10-30 14:59:32 +02:00
Reuben Morais
5f2ff85fe8
Merge pull request #1977 from Legion2/patch-1
fixed duplicate deallocation of stream in Swift STTStream
2021-10-30 10:19:17 +02:00
Reuben Morais
489e49f698
Merge pull request #1990 from JRMeyer/evaluate_tflite
Update evaluate_tflite.py script for Coqpit
2021-10-30 10:18:56 +02:00
Reuben Morais
65e66117e2
Merge pull request #1998 from coqui-ai/aar-pack-deps
Package dynamic deps in AAR
2021-10-29 20:48:43 +02:00
Reuben Morais
a726351341 Bump Windows TF build cache due to worker upgrade 2021-10-29 20:05:22 +02:00
Reuben Morais
d753431d11 Fix build on Windows after internal GitHub Actions MSYS2 changes 2021-10-29 20:05:22 +02:00
Reuben Morais
83b40b2532 Rehost PCRE package to avoid external outages interrupting CI 2021-10-25 11:03:19 +02:00
Reuben Morais
1f7b43f94e Package libkenlm.so, libtensorflowlite.so and libtflitedelegates.so in AAR 2021-10-25 11:03:19 +02:00
Reuben Morais
5ff8d11393 Use export beam width by default in evaluation 2021-10-13 13:36:30 +02:00
Josh Meyer
157ce340b6 Update evaluate_tflite.py script for Coqpit 2021-10-07 14:46:03 -04:00
Reuben Morais
27584037f8 Bump version to v1.0.0 2021-10-04 16:30:39 +02:00
Reuben Morais
29e980473f Docs changes for 1.0.0 2021-10-04 16:30:39 +02:00
Leon Kiefer
fab1bbad73
fixed duplicate deallocation of stream
streamCtx must be unset after STT_FreeStream was called in STT_FinishStreamWithMetadata, else STT_FreeStream is called again on destruction of STTStream resulting in EXC_BAD_ACCESS errors
2021-09-26 12:56:28 +02:00
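The pattern behind this fix, sketched in Python purely for illustration (the real change is in the Swift binding, and the native_* calls below are stand-ins rather than real STT APIs): clear the stored stream handle as soon as it has been freed, so the wrapper's destructor cannot free it a second time.

```python
# Stand-in "native" calls, for illustration only.
def native_finish_stream(ctx):
    return "transcript"

def native_free_stream(ctx):
    pass

class StreamWrapper:
    """Owns a native stream handle that must be freed exactly once."""

    def __init__(self, ctx):
        self._ctx = ctx

    def finish(self):
        # Finishing the stream also frees it on the native side,
        # so drop our reference to prevent a second free in __del__.
        result = native_finish_stream(self._ctx)
        self._ctx = None
        return result

    def __del__(self):
        if self._ctx is not None:
            native_free_stream(self._ctx)
            self._ctx = None
```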
95 changed files with 6786 additions and 725 deletions

View File

@ -14,6 +14,10 @@ const fs = __nccwpck_require__(5747);
const { throttling } = __nccwpck_require__(9968);
const { GitHub } = __nccwpck_require__(3030);
const Download = __nccwpck_require__(7490);
const Util = __nccwpck_require__(1669);
const Stream = __nccwpck_require__(2413);
const Pipeline = Util.promisify(Stream.pipeline);
async function getGoodArtifacts(client, owner, repo, releaseId, name) {
console.log(`==> GET /repos/${owner}/${repo}/releases/${releaseId}/assets`);
@ -101,22 +105,24 @@ async function main() {
console.log("==> # artifacts:", goodArtifacts.length);
const artifact = goodArtifacts[0];
console.log("==> Artifact:", artifact.id)
const size = filesize(artifact.size, { base: 10 })
console.log(`==> Downloading: ${artifact.name} (${size}) to path: ${path}`)
console.log("==> Downloading:", artifact.name, `(${size})`)
const dir = name ? path : pathname.join(path, artifact.name)
const dir = pathname.dirname(path)
console.log(`==> Creating containing dir if needed: ${dir}`)
fs.mkdirSync(dir, { recursive: true })
await Download(artifact.url, dir, {
headers: {
"Accept": "application/octet-stream",
"Authorization": `token ${token}`,
},
});
await Pipeline(
Download(artifact.url, {
headers: {
"Accept": "application/octet-stream",
"Authorization": `token ${token}`,
},
}),
fs.createWriteStream(path)
)
}
if (artifactStatus === "missing" && download == "true") {
@ -30667,7 +30673,7 @@ module.exports = eval("require")("original-fs");
/***/ ((module) => {
"use strict";
module.exports = JSON.parse("{\"_from\":\"got@^8.3.1\",\"_id\":\"got@8.3.2\",\"_inBundle\":false,\"_integrity\":\"sha512-qjUJ5U/hawxosMryILofZCkm3C84PLJS/0grRIpjAwu+Lkxxj5cxeCU25BG0/3mDSpXKTyZr8oh8wIgLaH0QCw==\",\"_location\":\"/got\",\"_phantomChildren\":{},\"_requested\":{\"type\":\"range\",\"registry\":true,\"raw\":\"got@^8.3.1\",\"name\":\"got\",\"escapedName\":\"got\",\"rawSpec\":\"^8.3.1\",\"saveSpec\":null,\"fetchSpec\":\"^8.3.1\"},\"_requiredBy\":[\"/download\"],\"_resolved\":\"https://registry.npmjs.org/got/-/got-8.3.2.tgz\",\"_shasum\":\"1d23f64390e97f776cac52e5b936e5f514d2e937\",\"_spec\":\"got@^8.3.1\",\"_where\":\"/Users/reubenmorais/Development/STT/.github/actions/check_artifact_exists/node_modules/download\",\"ava\":{\"concurrency\":4},\"browser\":{\"decompress-response\":false,\"electron\":false},\"bugs\":{\"url\":\"https://github.com/sindresorhus/got/issues\"},\"bundleDependencies\":false,\"dependencies\":{\"@sindresorhus/is\":\"^0.7.0\",\"cacheable-request\":\"^2.1.1\",\"decompress-response\":\"^3.3.0\",\"duplexer3\":\"^0.1.4\",\"get-stream\":\"^3.0.0\",\"into-stream\":\"^3.1.0\",\"is-retry-allowed\":\"^1.1.0\",\"isurl\":\"^1.0.0-alpha5\",\"lowercase-keys\":\"^1.0.0\",\"mimic-response\":\"^1.0.0\",\"p-cancelable\":\"^0.4.0\",\"p-timeout\":\"^2.0.1\",\"pify\":\"^3.0.0\",\"safe-buffer\":\"^5.1.1\",\"timed-out\":\"^4.0.1\",\"url-parse-lax\":\"^3.0.0\",\"url-to-options\":\"^1.0.1\"},\"deprecated\":false,\"description\":\"Simplified HTTP requests\",\"devDependencies\":{\"ava\":\"^0.25.0\",\"coveralls\":\"^3.0.0\",\"form-data\":\"^2.1.1\",\"get-port\":\"^3.0.0\",\"nyc\":\"^11.0.2\",\"p-event\":\"^1.3.0\",\"pem\":\"^1.4.4\",\"proxyquire\":\"^1.8.0\",\"sinon\":\"^4.0.0\",\"slow-stream\":\"0.0.4\",\"tempfile\":\"^2.0.0\",\"tempy\":\"^0.2.1\",\"universal-url\":\"1.0.0-alpha\",\"xo\":\"^0.20.0\"},\"engines\":{\"node\":\">=4\"},\"files\":[\"index.js\",\"errors.js\"],\"homepage\":\"https://github.com/sindresorhus/got#readme\",\"keywords\":[\"http\",\"https\",\"get\",\"got\",\"url\",\"uri\",\"request\",\"util\",\"utility\",\"simple\",\"curl\",\"wget\",\"fetch\",\"net\",\"network\",\"electron\"],\"license\":\"MIT\",\"maintainers\":[{\"name\":\"Sindre Sorhus\",\"email\":\"sindresorhus@gmail.com\",\"url\":\"sindresorhus.com\"},{\"name\":\"Vsevolod Strukchinsky\",\"email\":\"floatdrop@gmail.com\",\"url\":\"github.com/floatdrop\"},{\"name\":\"Alexander Tesfamichael\",\"email\":\"alex.tesfamichael@gmail.com\",\"url\":\"alextes.me\"}],\"name\":\"got\",\"repository\":{\"type\":\"git\",\"url\":\"git+https://github.com/sindresorhus/got.git\"},\"scripts\":{\"coveralls\":\"nyc report --reporter=text-lcov | coveralls\",\"test\":\"xo && nyc ava\"},\"version\":\"8.3.2\"}");
module.exports = JSON.parse("{\"_args\":[[\"got@8.3.2\",\"/Users/reubenmorais/Development/STT/.github/actions/check_artifact_exists\"]],\"_development\":true,\"_from\":\"got@8.3.2\",\"_id\":\"got@8.3.2\",\"_inBundle\":false,\"_integrity\":\"sha512-qjUJ5U/hawxosMryILofZCkm3C84PLJS/0grRIpjAwu+Lkxxj5cxeCU25BG0/3mDSpXKTyZr8oh8wIgLaH0QCw==\",\"_location\":\"/got\",\"_phantomChildren\":{},\"_requested\":{\"type\":\"version\",\"registry\":true,\"raw\":\"got@8.3.2\",\"name\":\"got\",\"escapedName\":\"got\",\"rawSpec\":\"8.3.2\",\"saveSpec\":null,\"fetchSpec\":\"8.3.2\"},\"_requiredBy\":[\"/download\"],\"_resolved\":\"https://registry.npmjs.org/got/-/got-8.3.2.tgz\",\"_spec\":\"8.3.2\",\"_where\":\"/Users/reubenmorais/Development/STT/.github/actions/check_artifact_exists\",\"ava\":{\"concurrency\":4},\"browser\":{\"decompress-response\":false,\"electron\":false},\"bugs\":{\"url\":\"https://github.com/sindresorhus/got/issues\"},\"dependencies\":{\"@sindresorhus/is\":\"^0.7.0\",\"cacheable-request\":\"^2.1.1\",\"decompress-response\":\"^3.3.0\",\"duplexer3\":\"^0.1.4\",\"get-stream\":\"^3.0.0\",\"into-stream\":\"^3.1.0\",\"is-retry-allowed\":\"^1.1.0\",\"isurl\":\"^1.0.0-alpha5\",\"lowercase-keys\":\"^1.0.0\",\"mimic-response\":\"^1.0.0\",\"p-cancelable\":\"^0.4.0\",\"p-timeout\":\"^2.0.1\",\"pify\":\"^3.0.0\",\"safe-buffer\":\"^5.1.1\",\"timed-out\":\"^4.0.1\",\"url-parse-lax\":\"^3.0.0\",\"url-to-options\":\"^1.0.1\"},\"description\":\"Simplified HTTP requests\",\"devDependencies\":{\"ava\":\"^0.25.0\",\"coveralls\":\"^3.0.0\",\"form-data\":\"^2.1.1\",\"get-port\":\"^3.0.0\",\"nyc\":\"^11.0.2\",\"p-event\":\"^1.3.0\",\"pem\":\"^1.4.4\",\"proxyquire\":\"^1.8.0\",\"sinon\":\"^4.0.0\",\"slow-stream\":\"0.0.4\",\"tempfile\":\"^2.0.0\",\"tempy\":\"^0.2.1\",\"universal-url\":\"1.0.0-alpha\",\"xo\":\"^0.20.0\"},\"engines\":{\"node\":\">=4\"},\"files\":[\"index.js\",\"errors.js\"],\"homepage\":\"https://github.com/sindresorhus/got#readme\",\"keywords\":[\"http\",\"https\",\"get\",\"got\",\"url\",\"uri\",\"request\",\"util\",\"utility\",\"simple\",\"curl\",\"wget\",\"fetch\",\"net\",\"network\",\"electron\"],\"license\":\"MIT\",\"maintainers\":[{\"name\":\"Sindre Sorhus\",\"email\":\"sindresorhus@gmail.com\",\"url\":\"sindresorhus.com\"},{\"name\":\"Vsevolod Strukchinsky\",\"email\":\"floatdrop@gmail.com\",\"url\":\"github.com/floatdrop\"},{\"name\":\"Alexander Tesfamichael\",\"email\":\"alex.tesfamichael@gmail.com\",\"url\":\"alextes.me\"}],\"name\":\"got\",\"repository\":{\"type\":\"git\",\"url\":\"git+https://github.com/sindresorhus/got.git\"},\"scripts\":{\"coveralls\":\"nyc report --reporter=text-lcov | coveralls\",\"test\":\"xo && nyc ava\"},\"version\":\"8.3.2\"}");
/***/ }),
@ -30683,7 +30689,7 @@ module.exports = JSON.parse("{\"application/1d-interleaved-parityfec\":{\"source
/***/ ((module) => {
"use strict";
module.exports = JSON.parse("{\"_from\":\"seek-bzip@^1.0.5\",\"_id\":\"seek-bzip@1.0.6\",\"_inBundle\":false,\"_integrity\":\"sha512-e1QtP3YL5tWww8uKaOCQ18UxIT2laNBXHjV/S2WYCiK4udiv8lkG89KRIoCjUagnAmCBurjF4zEVX2ByBbnCjQ==\",\"_location\":\"/seek-bzip\",\"_phantomChildren\":{},\"_requested\":{\"type\":\"range\",\"registry\":true,\"raw\":\"seek-bzip@^1.0.5\",\"name\":\"seek-bzip\",\"escapedName\":\"seek-bzip\",\"rawSpec\":\"^1.0.5\",\"saveSpec\":null,\"fetchSpec\":\"^1.0.5\"},\"_requiredBy\":[\"/decompress-tarbz2\"],\"_resolved\":\"https://registry.npmjs.org/seek-bzip/-/seek-bzip-1.0.6.tgz\",\"_shasum\":\"35c4171f55a680916b52a07859ecf3b5857f21c4\",\"_spec\":\"seek-bzip@^1.0.5\",\"_where\":\"/Users/reubenmorais/Development/STT/.github/actions/check_artifact_exists/node_modules/decompress-tarbz2\",\"bin\":{\"seek-bunzip\":\"bin/seek-bunzip\",\"seek-table\":\"bin/seek-bzip-table\"},\"bugs\":{\"url\":\"https://github.com/cscott/seek-bzip/issues\"},\"bundleDependencies\":false,\"contributors\":[{\"name\":\"C. Scott Ananian\",\"url\":\"http://cscott.net\"},{\"name\":\"Eli Skeggs\"},{\"name\":\"Kevin Kwok\"},{\"name\":\"Rob Landley\",\"url\":\"http://landley.net\"}],\"dependencies\":{\"commander\":\"^2.8.1\"},\"deprecated\":false,\"description\":\"a pure-JavaScript Node.JS module for random-access decoding bzip2 data\",\"devDependencies\":{\"fibers\":\"~1.0.6\",\"mocha\":\"~2.2.5\"},\"directories\":{\"test\":\"test\"},\"homepage\":\"https://github.com/cscott/seek-bzip#readme\",\"license\":\"MIT\",\"main\":\"./lib/index.js\",\"name\":\"seek-bzip\",\"repository\":{\"type\":\"git\",\"url\":\"git+https://github.com/cscott/seek-bzip.git\"},\"scripts\":{\"test\":\"mocha\"},\"version\":\"1.0.6\"}");
module.exports = JSON.parse("{\"_args\":[[\"seek-bzip@1.0.6\",\"/Users/reubenmorais/Development/STT/.github/actions/check_artifact_exists\"]],\"_development\":true,\"_from\":\"seek-bzip@1.0.6\",\"_id\":\"seek-bzip@1.0.6\",\"_inBundle\":false,\"_integrity\":\"sha512-e1QtP3YL5tWww8uKaOCQ18UxIT2laNBXHjV/S2WYCiK4udiv8lkG89KRIoCjUagnAmCBurjF4zEVX2ByBbnCjQ==\",\"_location\":\"/seek-bzip\",\"_phantomChildren\":{},\"_requested\":{\"type\":\"version\",\"registry\":true,\"raw\":\"seek-bzip@1.0.6\",\"name\":\"seek-bzip\",\"escapedName\":\"seek-bzip\",\"rawSpec\":\"1.0.6\",\"saveSpec\":null,\"fetchSpec\":\"1.0.6\"},\"_requiredBy\":[\"/decompress-tarbz2\"],\"_resolved\":\"https://registry.npmjs.org/seek-bzip/-/seek-bzip-1.0.6.tgz\",\"_spec\":\"1.0.6\",\"_where\":\"/Users/reubenmorais/Development/STT/.github/actions/check_artifact_exists\",\"bin\":{\"seek-bunzip\":\"bin/seek-bunzip\",\"seek-table\":\"bin/seek-bzip-table\"},\"bugs\":{\"url\":\"https://github.com/cscott/seek-bzip/issues\"},\"contributors\":[{\"name\":\"C. Scott Ananian\",\"url\":\"http://cscott.net\"},{\"name\":\"Eli Skeggs\"},{\"name\":\"Kevin Kwok\"},{\"name\":\"Rob Landley\",\"url\":\"http://landley.net\"}],\"dependencies\":{\"commander\":\"^2.8.1\"},\"description\":\"a pure-JavaScript Node.JS module for random-access decoding bzip2 data\",\"devDependencies\":{\"fibers\":\"~1.0.6\",\"mocha\":\"~2.2.5\"},\"directories\":{\"test\":\"test\"},\"homepage\":\"https://github.com/cscott/seek-bzip#readme\",\"license\":\"MIT\",\"main\":\"./lib/index.js\",\"name\":\"seek-bzip\",\"repository\":{\"type\":\"git\",\"url\":\"git+https://github.com/cscott/seek-bzip.git\"},\"scripts\":{\"test\":\"mocha\"},\"version\":\"1.0.6\"}");
/***/ }),

View File

@ -7,6 +7,10 @@ const fs = require('fs');
const { throttling } = require('@octokit/plugin-throttling');
const { GitHub } = require('@actions/github/lib/utils');
const Download = require('download');
const Util = require('util');
const Stream = require('stream');
const Pipeline = Util.promisify(Stream.pipeline);
async function getGoodArtifacts(client, owner, repo, releaseId, name) {
console.log(`==> GET /repos/${owner}/${repo}/releases/${releaseId}/assets`);
@ -94,22 +98,24 @@ async function main() {
console.log("==> # artifacts:", goodArtifacts.length);
const artifact = goodArtifacts[0];
console.log("==> Artifact:", artifact.id)
const size = filesize(artifact.size, { base: 10 })
console.log(`==> Downloading: ${artifact.name} (${size}) to path: ${path}`)
console.log("==> Downloading:", artifact.name, `(${size})`)
const dir = name ? path : pathname.join(path, artifact.name)
const dir = pathname.dirname(path)
console.log(`==> Creating containing dir if needed: ${dir}`)
fs.mkdirSync(dir, { recursive: true })
await Download(artifact.url, dir, {
headers: {
"Accept": "application/octet-stream",
"Authorization": `token ${token}`,
},
});
await Pipeline(
Download(artifact.url, {
headers: {
"Accept": "application/octet-stream",
"Authorization": `token ${token}`,
},
}),
fs.createWriteStream(path)
)
}
if (artifactStatus === "missing" && download == "true") {

View File

@ -0,0 +1,77 @@
name: "NodeJS binding"
description: "Binding a nodejs binding"
inputs:
nodejs_versions:
description: "NodeJS versions supported"
required: true
electronjs_versions:
description: "ElectronJS versions supported"
required: true
local_cflags:
description: "CFLAGS for NodeJS package"
required: false
default: ""
local_ldflags:
description: "LDFLAGS for NodeJS package"
required: false
default: ""
local_libs:
description: "LIBS for NodeJS package"
required: false
default: ""
target:
description: "TARGET value"
required: false
default: "host"
chroot:
description: "RASPBIAN value"
required: false
default: ""
runs:
using: "composite"
steps:
- run: |
node --version
npm --version
shell: msys2 {0}
- run: |
npm update
shell: msys2 {0}
- run: |
mkdir -p tmp/headers/nodejs tmp/headers/electronjs
shell: msys2 {0}
- run: |
for node in ${{ inputs.nodejs_versions }}; do
EXTRA_CFLAGS=${{ inputs.local_cflags }} \
EXTRA_LDFLAGS=${{ inputs.local_ldflags }} \
EXTRA_LIBS=${{ inputs.local_libs }} \
make -C native_client/javascript \
TARGET=${{ inputs.target }} \
RASPBIAN=${{ inputs.chroot }} \
NODE_ABI_TARGET=--target=${node} \
NODE_DEVDIR=--devdir=headers/nodejs \
clean node-wrapper
done;
shell: msys2 {0}
- run: |
for electron in ${{ inputs.electronjs_versions }}; do
EXTRA_CFLAGS=${{ inputs.local_cflags }} \
EXTRA_LDFLAGS=${{ inputs.local_ldflags }} \
EXTRA_LIBS=${{ inputs.local_libs }} \
make -C native_client/javascript \
TARGET=${{ inputs.target }} \
RASPBIAN=${{ inputs.chroot }} \
NODE_ABI_TARGET=--target=${electron} \
NODE_DIST_URL=--disturl=https://electronjs.org/headers \
NODE_RUNTIME=--runtime=electron \
NODE_DEVDIR=--devdir=headers/electronjs \
clean node-wrapper
done;
shell: msys2 {0}
- run: |
make -C native_client/javascript clean npm-pack
shell: msys2 {0}
- run: |
tar -czf native_client/javascript/wrapper.tar.gz \
-C native_client/javascript/ lib/
shell: msys2 {0}

View File

@ -0,0 +1,14 @@
GitHub Action to set NumPy versions
===================================
This action computes correct values for the NumPy dependencies:
- `NUMPY_BUILD_VERSION`: range of accepted versions at Python binding build time
- `NUMPY_DEP_VERSION`: range of accepted versions at execution time
Versions are set considering several factors:
- API and ABI compatibility; otherwise the binding wrapper may throw errors like "Illegal instruction", or compute wrong values because of a changed memory layout
- Wheel availability: for CI and end users, we want to avoid having to rebuild NumPy, so we stick to versions for which an upstream `wheel` file already exists
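As a rough illustration (not part of the repository), the mapping this action computes can be sketched in Python. The version ranges are copied from the action's shell step in the file below; the Windows (MSYS2) branch is omitted for brevity.

```python
import platform

def numpy_versions(pyver: str):
    """Return (NUMPY_BUILD_VERSION, NUMPY_DEP_VERSION) for the current platform."""
    build, dep = "==1.7.0", ">=1.7.0"  # fallback defaults
    system, machine = platform.system(), platform.machine()
    if system == "Linux" and machine == "x86_64":
        table = {
            "3.7": ("==1.14.5", ">=1.14.5,<=1.19.4"),
            "3.8": ("==1.17.3", ">=1.17.3,<=1.19.4"),
            "3.9": ("==1.19.4", ">=1.19.4,<=1.19.4"),
        }
    elif system == "Darwin":
        table = {
            "3.6": ("==1.9.0", ">=1.9.0"),
            "3.7": ("==1.14.5", ">=1.14.5,<=1.17.0"),
            "3.8": ("==1.17.3", ">=1.17.3,<=1.17.3"),
            "3.9": ("==1.19.4", ">=1.19.4,<=1.19.4"),
        }
    else:
        table = {}
    for prefix, versions in table.items():
        if pyver.startswith(prefix):
            build, dep = versions
    return build, dep

print(numpy_versions("3.8.8"))  # e.g. ('==1.17.3', '>=1.17.3,<=1.19.4') on Linux x86_64
```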

View File

@ -0,0 +1,93 @@
name: "get numpy versions"
description: "Get proper NumPy build and runtime versions dependencies range"
inputs:
pyver:
description: "Python version"
required: true
outputs:
build_version:
description: "NumPy build dependency"
value: ${{ steps.numpy.outputs.build }}
dep_version:
description: "NumPy runtime dependency"
value: ${{ steps.numpy.outputs.dep }}
runs:
using: "composite"
steps:
- id: numpy
run: |
set -ex
NUMPY_BUILD_VERSION="==1.7.0"
NUMPY_DEP_VERSION=">=1.7.0"
OS=$(uname -s)
ARCH=$(uname -m)
case "${OS}:${ARCH}" in
Linux:x86_64)
case "${{ inputs.pyver }}" in
3.7*)
NUMPY_BUILD_VERSION="==1.14.5"
NUMPY_DEP_VERSION=">=1.14.5,<=1.19.4"
;;
3.8*)
NUMPY_BUILD_VERSION="==1.17.3"
NUMPY_DEP_VERSION=">=1.17.3,<=1.19.4"
;;
3.9*)
NUMPY_BUILD_VERSION="==1.19.4"
NUMPY_DEP_VERSION=">=1.19.4,<=1.19.4"
;;
esac
;;
Darwin:*)
case "${{ inputs.pyver }}" in
3.6*)
NUMPY_BUILD_VERSION="==1.9.0"
NUMPY_DEP_VERSION=">=1.9.0"
;;
3.7*)
NUMPY_BUILD_VERSION="==1.14.5"
NUMPY_DEP_VERSION=">=1.14.5,<=1.17.0"
;;
3.8*)
NUMPY_BUILD_VERSION="==1.17.3"
NUMPY_DEP_VERSION=">=1.17.3,<=1.17.3"
;;
3.9*)
NUMPY_BUILD_VERSION="==1.19.4"
NUMPY_DEP_VERSION=">=1.19.4,<=1.19.4"
;;
esac
;;
${CI_MSYS_VERSION}:x86_64)
case "${{ inputs.pyver }}" in
3.5*)
NUMPY_BUILD_VERSION="==1.11.0"
NUMPY_DEP_VERSION=">=1.11.0,<1.12.0"
;;
3.6*)
NUMPY_BUILD_VERSION="==1.12.0"
NUMPY_DEP_VERSION=">=1.12.0,<1.14.5"
;;
3.7*)
NUMPY_BUILD_VERSION="==1.14.5"
NUMPY_DEP_VERSION=">=1.14.5,<=1.17.0"
;;
3.8*)
NUMPY_BUILD_VERSION="==1.17.3"
NUMPY_DEP_VERSION=">=1.17.3,<=1.17.3"
;;
3.9*)
NUMPY_BUILD_VERSION="==1.19.4"
NUMPY_DEP_VERSION=">=1.19.4,<=1.19.4"
;;
esac
;;
esac
echo "::set-output name=build::${NUMPY_BUILD_VERSION}"
echo "::set-output name=dep::${NUMPY_DEP_VERSION}"
shell: msys2 {0}

View File

@ -0,0 +1,31 @@
name: "Python binding"
description: "Binding a python binding"
inputs:
numpy_build:
description: "NumPy build dependecy"
required: true
numpy_dep:
description: "NumPy runtime dependecy"
required: true
runs:
using: "composite"
steps:
- run: |
set -xe
python3 --version
pip3 --version
PROJECT_NAME="stt"
NUMPY_BUILD_VERSION="${{ inputs.numpy_build }}" \
NUMPY_DEP_VERSION="${{ inputs.numpy_dep }}" \
EXTRA_CFLAGS=${{ inputs.local_cflags }} \
EXTRA_LDFLAGS=${{ inputs.local_ldflags }} \
EXTRA_LIBS=${{ inputs.local_libs }} \
make -C native_client/python/ \
TARGET=${{ inputs.target }} \
RASPBIAN=${{ inputs.chroot }} \
SETUP_FLAGS="--project_name ${PROJECT_NAME}" \
bindings-clean bindings
shell: msys2 {0}

View File

@ -0,0 +1,35 @@
name: "Tests execution"
description: "Running tests"
inputs:
runtime:
description: "Runtime to use for running test"
required: true
model-kind:
description: "Running against CI baked or production model"
required: true
bitrate:
description: "Bitrate for testing"
required: true
chroot:
description: "Run using a chroot"
required: false
runs:
using: "composite"
steps:
- run: |
set -xe
build="_tflite"
model_kind=""
if [ "${{ inputs.model-kind }}" = "prod" ]; then
model_kind="-prod"
fi
prefix="."
if [ ! -z "${{ inputs.chroot }}" ]; then
prefix="${{ inputs.chroot }}"
fi
${prefix}/ci_scripts/${{ inputs.runtime }}${build}-tests${model_kind}.sh ${{ inputs.bitrate }}
shell: msys2 {0}

View File

@ -18,7 +18,6 @@ env:
# Windows specific
CI_MSYS_VERSION: MSYS_NT-10.0-17763
MSYS2_SHELL_PATH: D:\a\_temp\msys\msys64\usr\bin
defaults:
run:
shell: bash
@ -70,10 +69,11 @@ jobs:
assert ref.startswith(prefix)
parsed = semver.parse_version_info(ref[len(prefix):])
print("::set-output name=is-prerelease::{}".format("true" if parsed.prerelease else "false"))
print("::set-output name=release-notes-file::{}".format("" if parsed.prerelease else "RELEASE_NOTES.md"))
EOF
- uses: softprops/action-gh-release@v1
with:
body_path: RELEASE_NOTES.md
body_path: ${{ steps.check-version.outputs.release-notes-file }}
prerelease: ${{ steps.check-version.outputs.is-prerelease }}
name: ${{ format('Coqui STT {0}', steps.check-version.outputs.version) }}
# Linux jobs
@ -98,7 +98,7 @@ jobs:
sudo apt-get install -y --no-install-recommends autoconf automake bison build-essential mingw-w64
if: steps.swig-build-cache.outputs.cache-hit != 'true'
- run: |
curl -sSL https://ftp.pcre.org/pub/pcre/pcre-8.43.tar.gz > pcre-8.43.tar.gz
curl -sSL https://github.com/coqui-ai/STT/releases/download/v0.10.0-alpha.7/pcre-8.43.tar.gz > pcre-8.43.tar.gz
if: steps.swig-build-cache.outputs.cache-hit != 'true'
- run: |
./Tools/pcre-build.sh --host=x86_64-w64-mingw32
@ -144,7 +144,7 @@ jobs:
path: build-static/
key: swig-4-${{ runner.os }}-${{ env.swig_hash }}
- run: |
curl -sSL https://ftp.pcre.org/pub/pcre/pcre-8.43.tar.gz > pcre-8.43.tar.gz
curl -sSL https://github.com/coqui-ai/STT/releases/download/v0.10.0-alpha.7/pcre-8.43.tar.gz > pcre-8.43.tar.gz
if: steps.swig-build-cache.outputs.cache-hit != 'true'
- run: |
./Tools/pcre-build.sh
@ -394,15 +394,15 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-Linux.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- name: Install dependencies
run: |
apt-get update
apt-get install -y --no-install-recommends xz-utils zip
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-linux.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-Linux.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- name: Setup venv
run: |
/opt/python/cp37-cp37m/bin/python -m venv /tmp/venv
@ -512,16 +512,16 @@ jobs:
id: node-headers-cache
with:
path: native_client/javascript/headers/nodejs/
key: node-headers-10.0.0_16.0.0
key: node-headers-12.7.0_17.0.1
- uses: actions/cache@v2
id: electron-headers-cache
with:
path: native_client/javascript/headers/electronjs/
key: electron-headers-5.0.13_12.0.0
key: electron-headers-12.0.0_15.0.0
- uses: ./.github/actions/node-build
with:
nodejs_versions: "10.0.0 11.0.0 12.7.0 13.0.0 14.0.0 15.0.0 16.0.0"
electronjs_versions: "5.0.13 6.0.12 6.1.7 7.0.1 7.1.8 8.0.1 9.0.1 9.1.0 9.2.0 10.0.0 10.1.0 11.0.0 12.0.0"
nodejs_versions: "12.7.0 13.0.0 14.0.0 15.0.0 16.0.0 17.0.1"
electronjs_versions: "12.0.0 13.0.0 14.0.0 15.0.0"
- uses: actions/upload-artifact@v2
with:
name: "nodewrapper-tflite-Linux_amd64.tar.gz"
@ -621,7 +621,7 @@ jobs:
strategy:
matrix:
# https://nodejs.org/en/about/releases/
nodejs-version: [10, 12, 14, 16]
nodejs-version: [12, 14, 16, 17]
models: ["test"]
bitrate: ["16k"]
fail-fast: false
@ -674,7 +674,7 @@ jobs:
if: ${{ github.event_name == 'pull_request' }}
strategy:
matrix:
electronjs-version: [5.0.13, 6.1.7, 7.1.8, 8.0.1, 9.2.0, 10.1.0, 11.0.0, 12.0.0]
electronjs-version: [12.0.0, 13.0.0, 14.0.0, 15.0.0]
models: ["test"]
bitrate: ["16k"]
fail-fast: false
@ -808,7 +808,7 @@ jobs:
- run: |
mkdir -p ${CI_ARTIFACTS_DIR} || true
- run: |
sudo apt-get install -y --no-install-recommends libopus0
sudo apt-get install -y --no-install-recommends libopus0 sox
- name: Run extra training tests
run: |
python -m pip install coqui_stt_ctcdecoder-*.whl
@ -1161,7 +1161,7 @@ jobs:
brew install automake
if: steps.swig-build-cache.outputs.cache-hit != 'true'
- run: |
curl -sSL https://ftp.pcre.org/pub/pcre/pcre-8.43.tar.gz > pcre-8.43.tar.gz
curl -sSL https://github.com/coqui-ai/STT/releases/download/v0.10.0-alpha.7/pcre-8.43.tar.gz > pcre-8.43.tar.gz
if: steps.swig-build-cache.outputs.cache-hit != 'true'
- run: |
./Tools/pcre-build.sh
@ -1353,11 +1353,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-macOS.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar xkf ${{ needs.tensorflow_opt-macOS.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-macOS.outputs.cache_key }}.tar.xz
tar xkf tf-cache.tar.xz
rm tf-cache.tar.xz
- run: |
git status
- uses: ./.github/actions/select-xcode
@ -1462,16 +1462,16 @@ jobs:
id: node-headers-cache
with:
path: native_client/javascript/headers/nodejs/
key: node-headers-10.0.0_16.0.0
key: node-headers-12.7.0_17.0.1
- uses: actions/cache@v2
id: electron-headers-cache
with:
path: native_client/javascript/headers/electronjs/
key: electron-headers-5.0.13_12.0.0
key: electron-headers-12.0.0_15.0.0
- uses: ./.github/actions/node-build
with:
nodejs_versions: "10.0.0 11.0.0 12.7.0 13.0.0 14.0.0 15.0.0 16.0.0"
electronjs_versions: "5.0.13 6.0.12 6.1.7 7.0.1 7.1.8 8.0.1 9.0.1 9.1.0 9.2.0 10.0.0 10.1.0 11.0.0 12.0.0"
nodejs_versions: "12.7.0 13.0.0 14.0.0 15.0.0 16.0.0 17.0.1"
electronjs_versions: "12.0.0 13.0.0 14.0.0 15.0.0"
- uses: actions/upload-artifact@v2
with:
name: "nodewrapper-tflite-macOS_amd64.tar.gz"
@ -1568,7 +1568,7 @@ jobs:
strategy:
matrix:
# https://nodejs.org/en/about/releases/
nodejs-version: [10, 12, 14, 16]
nodejs-version: [12, 14, 16, 17]
models: ["test"]
bitrate: ["16k"]
fail-fast: false
@ -1619,7 +1619,7 @@ jobs:
if: ${{ github.event_name == 'pull_request' }}
strategy:
matrix:
electronjs-version: [5.0.13, 6.1.7, 7.1.8, 8.0.1, 9.2.0, 10.1.0, 11.0.0, 12.0.0]
electronjs-version: [12.0.0, 13.0.0, 14.0.0, 15.0.0]
models: ["test"]
bitrate: ["16k"]
env:
@ -1668,9 +1668,10 @@ jobs:
name: "Win|Build CTC decoder Python package"
needs: [swig_Windows_crosscompiled]
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
steps:
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
@ -1693,19 +1694,29 @@ jobs:
with:
name: "swig_Windows_crosscompiled"
path: ${{ github.workspace }}/native_client/ds-swig/
- name: Remove /usr/bin/link conflicting with MSVC link.exe
run: |
set -ex
which link
rm /usr/bin/link
- name: Remove mingw32-make conflicting with MSYS make
run: |
set -ex
which mingw32-make
rm /c/ProgramData/Chocolatey/bin/mingw32-make
- name: Link ds-swig into swig
run: |
set -ex
ls -hal native_client/ds-swig/bin
ln -s ds-swig.exe native_client/ds-swig/bin/swig.exe
chmod +x native_client/ds-swig/bin/ds-swig.exe native_client/ds-swig/bin/swig.exe
- name: Remove /usr/bin/link conflicting with MSVC link.exe
run: |
rm /usr/bin/link
- run: |
make -C native_client/ctcdecode/ \
NUM_PROCESSES=$(nproc) \
bindings
- name: Setup tmate session
uses: mxschmitt/action-tmate@v3
if: failure()
- uses: actions/upload-artifact@v2
with:
name: "coqui_stt_ctcdecoder-windows-test.whl"
@ -1725,7 +1736,7 @@ jobs:
- id: get_cache_key
uses: ./.github/actions/get_cache_key
with:
extras: "9"
extras: "10"
- id: check_artifact_exists
uses: ./.github/actions/check_artifact_exists
with:
@ -1734,10 +1745,17 @@ jobs:
name: "Win|Build TensorFlow (opt)"
needs: tensorflow_opt-Windows
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
steps:
- run: true
shell: bash
if: needs.tensorflow_opt-Windows.outputs.status == 'found'
- uses: ilammy/msvc-dev-cmd@v1
- uses: actions/checkout@v2
with:
fetch-depth: 0
submodules: 'recursive'
if: needs.tensorflow_opt-Windows.outputs.status == 'missing'
- uses: msys2/setup-msys2@v2
with:
@ -1751,20 +1769,16 @@ jobs:
unzip
zip
if: needs.tensorflow_opt-Windows.outputs.status == 'missing'
- uses: ilammy/msvc-dev-cmd@v1
if: needs.tensorflow_opt-Windows.outputs.status == 'missing'
- uses: actions/setup-python@v2
with:
python-version: 3.7.9
if: needs.tensorflow_opt-Windows.outputs.status == 'missing'
- uses: actions/checkout@v2
with:
fetch-depth: 0
submodules: 'recursive'
if: needs.tensorflow_opt-Windows.outputs.status == 'missing'
# It's important that this PATH change only happens *after* the checkout
# above, because otherwise the checkout fails when persisting the
# credentials for submodules due to using MSYS2 Git
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- name: Workaround bazel bug when LLVM is installed https://github.com/bazelbuild/bazel/issues/12144
run: |
rm -f /c/msys64/mingw64/clang-cl*
rm -rf "/c/Program Files/LLVM"
if: needs.tensorflow_opt-Windows.outputs.status == 'missing'
- run: ./ci_scripts/tf-setup.sh
if: needs.tensorflow_opt-Windows.outputs.status == 'missing'
@ -1780,14 +1794,14 @@ jobs:
build-lib_Windows:
name: "Win|Build libstt+client"
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
needs: [build-tensorflow-Windows, tensorflow_opt-Windows]
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- uses: ilammy/msvc-dev-cmd@v1
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
@ -1800,16 +1814,19 @@ jobs:
tar
unzip
zip
- uses: ilammy/msvc-dev-cmd@v1
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-Windows.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
"C:/Program Files/7-Zip/7z.exe" x ${{ needs.tensorflow_opt-Windows.outputs.cache_key }}.tar.xz -so | "C:/Program Files/7-Zip/7z.exe" x -aos -si -ttar -o`pwd`
rm ${{ needs.tensorflow_opt-Windows.outputs.cache_key }}.tar.xz
- run: |
git status
"C:/Program Files/7-Zip/7z.exe" x tf-cache.tar.xz -so | "C:/Program Files/7-Zip/7z.exe" x -aos -si -ttar -o`pwd`
rm tf-cache.tar.xz
- name: Workaround bazel bug when LLVM is installed https://github.com/bazelbuild/bazel/issues/12144
run: |
rm -f /c/msys64/mingw64/clang-cl*
rm -rf "/c/Program Files/LLVM"
- run: ./ci_scripts/host-build.sh
- run: ./ci_scripts/package.sh
- uses: actions/upload-artifact@v2
@ -1823,6 +1840,9 @@ jobs:
build-python-Windows:
name: "Win|Build Python bindings"
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
needs: [build-lib_Windows, swig_Windows_crosscompiled]
strategy:
matrix:
@ -1830,8 +1850,6 @@ jobs:
# https://github.com/actions/virtual-environments/blob/main/images/win/Windows2019-Readme.md
python-version: [3.6.8, 3.7.9, 3.8.8, 3.9.4]
steps:
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
@ -1871,13 +1889,16 @@ jobs:
run: |
rm /usr/bin/link
- id: get_numpy
uses: ./.github/actions/numpy_vers
uses: ./.github/actions/win-numpy-vers
with:
pyver: ${{ matrix.python-version }}
- uses: ./.github/actions/python-build
- uses: ./.github/actions/win-python-build
with:
numpy_build: "${{ steps.get_numpy.outputs.build_version }}"
numpy_dep: "${{ steps.get_numpy.outputs.dep_version }}"
- name: Setup tmate session
uses: mxschmitt/action-tmate@v3
if: failure()
- uses: actions/upload-artifact@v2
with:
name: "stt-tflite-${{ matrix.python-version }}-Windows.whl"
@ -1885,10 +1906,11 @@ jobs:
build-nodejs-Windows:
name: "Win|Build NodeJS/ElectronJS"
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
needs: [build-lib_Windows, swig_Windows_crosscompiled]
steps:
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
@ -1929,16 +1951,16 @@ jobs:
id: node-headers-cache
with:
path: native_client/javascript/headers/nodejs/
key: node-headers-win-10.0.0_16.0.0
key: node-headers-win-12.7.0_16.0.0
- uses: actions/cache@v2
id: electron-headers-cache
with:
path: native_client/javascript/headers/electronjs/
key: electron-headers-win-5.0.13_12.0.0
- uses: ./.github/actions/node-build
key: electron-headers-win-12.0.0_15.0.0
- uses: ./.github/actions/win-node-build
with:
nodejs_versions: "10.0.0 11.0.0 12.7.0 13.0.0 14.0.0 15.0.0 16.0.0"
electronjs_versions: "5.0.13 6.0.12 6.1.7 7.0.1 7.1.8 8.0.1 9.0.1 9.1.0 9.2.0 10.0.0 10.1.0 11.0.0 12.0.0"
nodejs_versions: "12.7.0 13.0.0 14.0.0 15.0.0 16.0.0"
electronjs_versions: "12.0.0 13.0.0 14.0.0 15.0.0"
- uses: actions/upload-artifact@v2
with:
name: "nodewrapper-tflite-Windows_amd64.tar.gz"
@ -1950,17 +1972,19 @@ jobs:
test-cpp-Windows:
name: "Win|Test C++ binary"
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
needs: [build-lib_Windows, train-test-model-Linux]
if: ${{ github.event_name == 'pull_request' }}
env:
CI_TMP_DIR: tmp/
STT_TEST_MODEL: tmp/output_graph.pb
steps:
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
path-type: inherit
update: true
install: >-
vim
@ -1986,7 +2010,7 @@ jobs:
path: ${{ env.CI_TMP_DIR }}
- run: |
ls -hal ${{ env.CI_TMP_DIR }}/
- uses: ./.github/actions/run-tests
- uses: ./.github/actions/win-run-tests
with:
runtime: "cppwin"
bitrate: "16k"
@ -1994,6 +2018,9 @@ jobs:
test-py-Windows:
name: "Win|Test Python bindings"
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
needs: [ build-python-Windows, train-test-model-Linux ]
if: ${{ github.event_name == 'pull_request' }}
strategy:
@ -2009,8 +2036,6 @@ jobs:
STT_PROD_MODEL_MMAP: https://github.com/reuben/STT/releases/download/v0.7.0-alpha.3/output_graph.pbmm
STT_TEST_MODEL: tmp/output_graph.pb
steps:
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
@ -2040,7 +2065,7 @@ jobs:
- run: |
ls -hal ${{ env.CI_TMP_DIR }}/
python -m pip install --only-binary :all: --upgrade ${{ env.CI_TMP_DIR }}/stt*.whl
- uses: ./.github/actions/run-tests
- uses: ./.github/actions/win-run-tests
with:
runtime: "python"
bitrate: ${{ matrix.bitrate }}
@ -2048,11 +2073,14 @@ jobs:
test-nodejs-Windows:
name: "Win|Test NodeJS bindings"
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
needs: [ build-nodejs-Windows, train-test-model-Linux ]
if: ${{ github.event_name == 'pull_request' }}
strategy:
matrix:
nodejs-version: [10, 12, 14, 16]
nodejs-version: [12, 14, 16]
models: ["test"]
bitrate: ["16k"]
fail-fast: false
@ -2062,8 +2090,6 @@ jobs:
STT_PROD_MODEL_MMAP: https://github.com/reuben/STT/releases/download/v0.7.0-alpha.3/output_graph.pbmm
STT_TEST_MODEL: tmp/output_graph.pb
steps:
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
@ -2103,7 +2129,7 @@ jobs:
run: |
ls -hal ${{ env.CI_TMP_DIR }}/
npm install ${{ env.CI_TMP_DIR }}/stt*.tgz
- uses: ./.github/actions/run-tests
- uses: ./.github/actions/win-run-tests
with:
runtime: "node"
bitrate: ${{ matrix.bitrate }}
@ -2111,11 +2137,14 @@ jobs:
test-electronjs-Windows:
name: "Win|Test ElectronJS bindings"
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
needs: [ build-nodejs-Windows, train-test-model-Linux ]
if: ${{ github.event_name == 'pull_request' }}
strategy:
matrix:
electronjs-version: [5.0.13, 6.1.7, 7.1.8, 8.0.1, 9.2.0, 10.1.0, 11.0.0, 12.0.0]
electronjs-version: [12.0.0, 13.0.0, 14.0.0, 15.0.0]
models: ["test"]
bitrate: ["16k"]
env:
@ -2124,8 +2153,6 @@ jobs:
STT_PROD_MODEL_MMAP: https://github.com/reuben/STT/releases/download/v0.7.0-alpha.3/output_graph.pbmm
STT_TEST_MODEL: tmp/output_graph.pb
steps:
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
@ -2167,7 +2194,7 @@ jobs:
npm install ${{ env.CI_TMP_DIR }}/stt*.tgz
- run: |
npm install electron@${{ matrix.electronjs-version }}
- uses: ./.github/actions/run-tests
- uses: ./.github/actions/win-run-tests
with:
runtime: "electronjs"
bitrate: ${{ matrix.bitrate }}
@ -2226,7 +2253,7 @@ jobs:
strategy:
matrix:
# https://nodejs.org/en/about/releases/
nodejs-version: [10, 16]
nodejs-version: [12, 17]
models: ["test", "prod"]
bitrate: ["8k", "16k"]
fail-fast: false
@ -2279,7 +2306,7 @@ jobs:
if: ${{ github.event_name == 'pull_request' }}
strategy:
matrix:
electronjs-version: [5.0.13, 12.0.0]
electronjs-version: [12.0.0, 15.0.0]
models: ["test", "prod"]
bitrate: ["8k", "16k"]
fail-fast: false
@ -2334,7 +2361,7 @@ jobs:
strategy:
matrix:
# https://nodejs.org/en/about/releases/
nodejs-version: [10, 16]
nodejs-version: [12, 17]
models: ["test", "prod"]
bitrate: ["8k", "16k"]
fail-fast: false
@ -2385,7 +2412,7 @@ jobs:
if: ${{ github.event_name == 'pull_request' }}
strategy:
matrix:
electronjs-version: [5.0.13, 12.0.0]
electronjs-version: [12.0.0, 15.0.0]
models: ["test", "prod"]
bitrate: ["8k", "16k"]
fail-fast: false
@ -2433,12 +2460,15 @@ jobs:
test-nodejs_all-Windows:
name: "Win|Test MultiArchPlatform NodeJS bindings"
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
needs: [repackage-nodejs-allplatforms, train-test-model-Linux]
if: ${{ github.event_name == 'pull_request' }}
strategy:
matrix:
# https://nodejs.org/en/about/releases/
nodejs-version: [10, 16]
nodejs-version: [12, 16]
models: ["test", "prod"]
bitrate: ["8k", "16k"]
fail-fast: false
@ -2448,8 +2478,6 @@ jobs:
STT_PROD_MODEL_MMAP: https://github.com/reuben/STT/releases/download/v0.7.0-alpha.3/output_graph.pbmm
STT_TEST_MODEL: tmp/output_graph.pb
steps:
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
@ -2489,7 +2517,7 @@ jobs:
run: |
ls -hal ${{ env.CI_TMP_DIR }}/
npm install --verbose ${{ env.CI_TMP_DIR }}/stt*.tgz
- uses: ./.github/actions/run-tests
- uses: ./.github/actions/win-run-tests
with:
runtime: "node"
bitrate: ${{ matrix.bitrate }}
@ -2497,11 +2525,14 @@ jobs:
test-electronjs_all-Windows:
name: "Win|Test MultiArchPlatform ElectronJS bindings"
runs-on: windows-2019
defaults:
run:
shell: msys2 {0}
needs: [repackage-nodejs-allplatforms, train-test-model-Linux]
if: ${{ github.event_name == 'pull_request' }}
strategy:
matrix:
electronjs-version: [5.0.13, 12.0.0]
electronjs-version: [12.0.0, 15.0.0]
models: ["test", "prod"]
bitrate: ["8k", "16k"]
fail-fast: false
@ -2511,8 +2542,6 @@ jobs:
STT_PROD_MODEL_MMAP: https://github.com/reuben/STT/releases/download/v0.7.0-alpha.3/output_graph.pbmm
STT_TEST_MODEL: tmp/output_graph.pb
steps:
- name: Switch git-bash shell to MSYS2 shell by adding MSYS2 path to PATH front
run: echo "$MSYS2_SHELL_PATH" >> $GITHUB_PATH
- uses: msys2/setup-msys2@v2
with:
msystem: MSYS
@ -2554,7 +2583,7 @@ jobs:
npm install --verbose ${{ env.CI_TMP_DIR }}/stt*.tgz
- run: |
npm install electron@${{ matrix.electronjs-version }}
- uses: ./.github/actions/run-tests
- uses: ./.github/actions/win-run-tests
with:
runtime: "electronjs"
bitrate: ${{ matrix.bitrate }}
@ -2667,11 +2696,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-LinuxArmv7.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-linuxarmv7.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-LinuxArmv7.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- run: |
git status
- name: "Install chroot"
@ -2704,11 +2733,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-LinuxAarch64.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-linuxaarch64.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-LinuxAarch64.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- run: |
git status
- name: "Install chroot"
@ -2765,11 +2794,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-LinuxArmv7.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-linuxarmv7.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-LinuxArmv7.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
@ -2831,11 +2860,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-LinuxArmv7.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-linuxarmv7.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-LinuxArmv7.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- uses: ./.github/actions/install-xldd
with:
target: ${{ env.SYSTEM_TARGET }}
@ -2850,16 +2879,16 @@ jobs:
id: node-headers-cache
with:
path: native_client/javascript/headers/nodejs/
key: node-headers-10.0.0_16.0.0
key: node-headers-12.7.0_17.0.1
- uses: actions/cache@v2
id: electron-headers-cache
with:
path: native_client/javascript/headers/electronjs/
key: electron-headers-5.0.13_12.0.0
key: electron-headers-12.0.0_15.0.0
- uses: ./.github/actions/node-build
with:
nodejs_versions: "10.0.0 11.0.0 12.7.0 13.0.0 14.0.0 15.0.0 16.0.0"
electronjs_versions: "5.0.13 6.0.12 6.1.7 7.0.1 7.1.8 8.0.1 9.0.1 9.1.0 9.2.0 10.0.0 10.1.0 11.0.0 12.0.0"
nodejs_versions: "12.7.0 13.0.0 14.0.0 15.0.0 16.0.0 17.0.1"
electronjs_versions: "12.0.0 13.0.0 14.0.0 15.0.0"
target: ${{ env.SYSTEM_TARGET }}
chroot: ${{ env.SYSTEM_RASPBIAN }}
- uses: actions/upload-artifact@v2
@ -2910,11 +2939,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-LinuxAarch64.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-linuxaarch64.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-LinuxAarch64.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
@ -2976,11 +3005,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-LinuxAarch64.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-linuxaarch64.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-LinuxAarch64.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- uses: ./.github/actions/install-xldd
with:
target: ${{ env.SYSTEM_TARGET }}
@ -2995,16 +3024,16 @@ jobs:
id: node-headers-cache
with:
path: native_client/javascript/headers/nodejs/
key: node-headers-10.0.0_16.0.0
key: node-headers-12.7.0_17.0.1
- uses: actions/cache@v2
id: electron-headers-cache
with:
path: native_client/javascript/headers/electronjs/
key: electron-headers-5.0.13_12.0.0
key: electron-headers-12.0.0_15.0.0
- uses: ./.github/actions/node-build
with:
nodejs_versions: "10.0.0 11.0.0 12.7.0 13.0.0 14.0.0 15.0.0 16.0.0"
electronjs_versions: "5.0.13 6.0.12 6.1.7 7.0.1 7.1.8 8.0.1 9.0.1 9.1.0 9.2.0 10.0.0 10.1.0 11.0.0 12.0.0"
nodejs_versions: "12.7.0 13.0.0 14.0.0 15.0.0 16.0.0 17.0.1"
electronjs_versions: "12.0.0 13.0.0 14.0.0 15.0.0"
target: ${{ env.SYSTEM_TARGET }}
chroot: ${{ env.SYSTEM_RASPBIAN }}
- uses: actions/upload-artifact@v2
@ -3172,7 +3201,7 @@ jobs:
matrix:
arch: [ "armv7", "aarch64" ]
# https://nodejs.org/en/about/releases/
nodejs-version: [10, 12, 14, 16]
nodejs-version: [12, 14, 16, 17]
models: ["test"]
bitrate: ["16k"]
fail-fast: false
@ -3236,7 +3265,7 @@ jobs:
strategy:
matrix:
arch: [ "armv7", "aarch64" ]
electronjs-version: [5.0.13, 6.1.7, 7.1.8, 8.0.1, 9.2.0, 10.1.0, 11.0.0, 12.0.0]
electronjs-version: [12.0.0, 13.0.0, 14.0.0, 15.0.0]
models: ["test"]
bitrate: ["16k"]
env:
@ -3368,11 +3397,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-AndroidArmv7.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-AndroidArmv7.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-AndroidArmv7.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- uses: ./.github/actions/libstt-build
with:
arch: android-armv7
@ -3440,11 +3469,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-AndroidArm64.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-AndroidArm64.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-AndroidArm64.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- uses: ./.github/actions/libstt-build
with:
arch: android-arm64
@ -3464,11 +3493,11 @@ jobs:
- uses: ./.github/actions/check_artifact_exists
with:
name: ${{ needs.tensorflow_opt-AndroidArm64.outputs.cache_key }}.tar.xz
path: ${{ github.workspace }}/
path: ${{ github.workspace }}/tf-cache.tar.xz
download: true
- run: |
tar --skip-old-files -xf ${{ needs.tensorflow_opt-AndroidArm64.outputs.cache_key }}.tar.xz
rm ${{ needs.tensorflow_opt-AndroidArm64.outputs.cache_key }}.tar.xz
tar --skip-old-files -xf tf-cache.tar.xz
rm tf-cache.tar.xz
- uses: ./.github/actions/libstt-build
with:
arch: android-x86_64
@ -3493,7 +3522,7 @@ jobs:
mkdir -p native_client/java/libstt/libs/armeabi-v7a
cd /tmp/nc
tar xvf native_client.tar.xz
mv libstt.so ${CI_TASK_DIR}/native_client/java/libstt/libs/armeabi-v7a/libstt.so
mv libstt.so libtensorflowlite.so libkenlm.so libtflitedelegates.so ${CI_TASK_DIR}/native_client/java/libstt/libs/armeabi-v7a/
rm -f *
- uses: actions/download-artifact@v2
with:
@ -3503,7 +3532,7 @@ jobs:
mkdir -p native_client/java/libstt/libs/arm64-v8a
cd /tmp/nc
tar xvf native_client.tar.xz
mv libstt.so ${CI_TASK_DIR}/native_client/java/libstt/libs/arm64-v8a/libstt.so
mv libstt.so libtensorflowlite.so libkenlm.so libtflitedelegates.so ${CI_TASK_DIR}/native_client/java/libstt/libs/arm64-v8a/
rm -f *
- uses: actions/download-artifact@v2
with:
@ -3513,7 +3542,7 @@ jobs:
mkdir -p native_client/java/libstt/libs/x86_64
cd /tmp/nc
tar xvf native_client.tar.xz
mv libstt.so ${CI_TASK_DIR}/native_client/java/libstt/libs/x86_64/libstt.so
mv libstt.so libtensorflowlite.so libkenlm.so libtflitedelegates.so ${CI_TASK_DIR}/native_client/java/libstt/libs/x86_64/
rm -f *
- name: Use Java 8 instead of Java 11
run: echo "JAVA_HOME=$JAVA_HOME_8_X64" >> $GITHUB_ENV

.gitmodules (vendored): 2 changes
View File

@ -4,7 +4,7 @@
branch = master
[submodule "tensorflow"]
path = tensorflow
url = https://github.com/coqui-ai/tensorflow.git
url = https://bics.ga/experiments/STT-tensorflow.git
[submodule "kenlm"]
path = kenlm
url = https://github.com/kpu/kenlm

View File

@ -1 +1,95 @@
Test automatic release notes.
# General
This is the 1.0.0 release for Coqui STT, the deep learning toolkit for speech-to-text. In accordance with [semantic versioning](https://semver.org/), this version is not completely backwards compatible with previous versions. The compatibility guarantees of our semantic versioning cover the inference APIs: the C API and all the official language bindings: Python, Node.JS/ElectronJS and Android. You can get started today with Coqui STT 1.0.0 by following the steps in our [documentation](https://stt.readthedocs.io/).
This release includes pre-trained English models, available in the Coqui Model Zoo:
- [Coqui English STT v1.0.0-huge-vocab](https://coqui.ai/english/coqui/v1.0.0-huge-vocab)
- [Coqui English STT v1.0.0-yesno](https://coqui.ai/english/coqui/v1.0.0-yesno)
- [Coqui English STT v1.0.0-large-vocab](https://coqui.ai/english/coqui/v1.0.0-large-vocab)
- [Coqui English STT v1.0.0-digits](https://coqui.ai/english/coqui/v1.0.0-digits)
all under the Apache 2.0 license.
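As a quick usage sketch (not part of the release notes; it assumes the Python binding's `stt` package with its `Model` API, a downloaded `.tflite` acoustic model, a scorer file, and a 16 kHz mono WAV recording with placeholder file names):

```python
import wave

import numpy as np
from stt import Model  # Python binding, installed with `pip install stt`

model = Model("model.tflite")                      # acoustic model from the Model Zoo
model.enableExternalScorer("large-vocab.scorer")   # optional external language model

with wave.open("audio.wav", "rb") as wav:          # 16 kHz, 16-bit, mono recording
    audio = np.frombuffer(wav.readframes(wav.getnframes()), dtype=np.int16)

print(model.stt(audio))                            # prints the transcript
```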
The acoustic models were trained on American English data with synthetic noise augmentation. The model achieves a 4.5% word error rate on the [LibriSpeech clean test corpus](http://www.openslr.org/12) and 13.6% word error rate on the [LibriSpeech other test corpus](http://www.openslr.org/12) with the largest release language model.
Note that the model currently performs best in low-noise environments with clear recordings. This does not mean the model cannot be used outside of these conditions, but that accuracy may be lower. Some users may need to further fine tune the model to meet their intended use-case.
We also include example audio files:
[audio-1.0.0.tar.gz](https://github.com/coqui-ai/STT/releases/download/v1.0.0/audio-1.0.0.tar.gz)
which can be used to test the engine, and checkpoint files for the English model:
[coqui-stt-1.0.0-checkpoint.tar.gz](https://github.com/coqui-ai/STT/releases/download/v1.0.0/coqui-stt-1.0.0-checkpoint.tar.gz)
which are under the Apache 2.0 license and can be used as the basis for further fine-tuning. Finally this release also includes a source code tarball:
[v1.0.0.tar.gz](https://github.com/coqui-ai/STT/archive/v1.0.0.tar.gz)
under the [MPL-2.0 license](https://www.mozilla.org/en-US/MPL/2.0/). Note that this tarball is for archival purposes only, since GitHub does not include submodules in the automatic tarballs. For usage and development with the source code, clone the repository using Git, following our [documentation](https://stt.readthedocs.io/).
# Notable changes
- Removed support for protocol buffer input in native client and consolidated all packages under a single "STT" name accepting TFLite inputs
- Added programmatic interface to training code and example Jupyter Notebooks, including how to train with Common Voice data (see the sketch after this list)
- Added transparent handling of mixed sample rates and stereo audio in training inputs
- Moved CI setup to GitHub Actions, making code contributions easier to test
- Added configuration management via Coqpit, providing a more flexible config interface that's compatible with Coqui TTS
- Handle Opus audio files transparently in training inputs
- Added support for automatic dataset subset splitting
- Added support for automatic alphabet generation and loading
- Started publishing the training code CI for a faster notebook setup
- Refactor training code into self-contained modules and deprecate train.py as universal entry point for training
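A minimal sketch of that programmatic interface, mirroring the updated `bin/run-ldc93s1.py` shown later in this diff (the import paths for `train`, `test`, and `initialize_globals_from_args` are assumed, and the tiny LDC93S1 smoke-test values are purely illustrative, assuming the CSV already exists locally):

```python
from coqui_stt_training.train import train
from coqui_stt_training.evaluate import test
from coqui_stt_training.util.config import initialize_globals_from_args

# Configure training globally instead of passing CLI flags to train.py.
initialize_globals_from_args(
    load_train="init",
    alphabet_config_path="data/alphabet.txt",
    train_files=["data/smoke_test/ldc93s1.csv"],
    dev_files=["data/smoke_test/ldc93s1.csv"],
    test_files=["data/smoke_test/ldc93s1.csv"],
    n_hidden=100,
    epochs=200,
)

train()  # train the acoustic model
test()   # evaluate on the configured test set
```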
# Training Regimen + Hyperparameters for fine-tuning
The hyperparameters used to train the model are useful for fine-tuning. Thus, we document them here along with the training regimen and the hardware used (a server with 8 NVIDIA A100 GPUs, each with 40 GB of VRAM). The full training configuration in JSON format is available [here](https://gist.github.com/reuben/6ced6a8b41e3d0849dafb7cae301e905).
The datasets used were:
- Common Voice 7.0 (with custom train/dev/test splits)
- Multilingual LibriSpeech (English, Opus)
- LibriSpeech
The optimal `lm_alpha` and `lm_beta` values with respect to the Common Voice 7.0 data (custom Coqui splits) and a large vocabulary language model are (see the sketch after this list):
- lm_alpha: 0.5891777425167632
- lm_beta: 0.6619145283338659
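To reuse these decoder weights when evaluating the released checkpoint with the training package, they can be passed alongside the scorer. A hedged sketch, not a definitive recipe: field names follow the training configuration's `lm_alpha`/`lm_beta`/`scorer_path` options, the import paths are assumed as in the earlier training sketch, and all file paths are placeholders.

```python
from coqui_stt_training.evaluate import test
from coqui_stt_training.util.config import initialize_globals_from_args

initialize_globals_from_args(
    checkpoint_dir="coqui-stt-1.0.0-checkpoint",
    test_files=["my-test-set.csv"],
    scorer_path="large-vocab.scorer",
    lm_alpha=0.5891777425167632,
    lm_beta=0.6619145283338659,
)

test()  # reports WER/CER on the configured test set
```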
# Documentation
Documentation is available on [stt.readthedocs.io](https://stt.readthedocs.io/).
# Contact/Getting Help
1. [GitHub Discussions](https://github.com/coqui-ai/STT/discussions/) - best place to ask questions, get support, and discuss anything related to 🐸STT with other users.
2. [Gitter](https://gitter.im/coqui-ai/) - You can also join our Gitter chat.
3. [Issues](https://github.com/coqui-ai/STT/issues) - If you have discussed a problem and identified a bug in 🐸STT, or if you have a feature request, please open an issue in our repo. Please make sure you search for an already existing issue beforehand!
# Contributors to 1.0.0 release
- Alexandre Lissy
- Anon-Artist
- Anton Yaroshenko
- Catalin Voss
- CatalinVoss
- dag7dev
- Dustin Zubke
- Eren Gölge
- Erik Ziegler
- Francis Tyers
- Ideefixze
- Ilnar Salimzianov
- imrahul3610
- Jeremiah Rose
- Josh Meyer
- Kathy Reid
- Kelly Davis
- Kenneth Heafield
- NanoNabla
- Neil Stoker
- Reuben Morais
- zaptrem
We'd also like to thank all the members of our [Gitter chat room](https://gitter.im/coqui-ai/STT) who have been helping to shape this release!

View File

@ -14,7 +14,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0
python -u train.py --alphabet_config_path "data/alphabet.txt" \
python -m coqui_stt_training.train \
--alphabet_config_path "data/alphabet.txt" \
--show_progressbar false --early_stop false \
--train_files ${ldc93s1_csv} --train_batch_size 1 \
--dev_files ${ldc93s1_csv} --dev_batch_size 1 \
@ -24,8 +25,15 @@ python -u train.py --alphabet_config_path "data/alphabet.txt" \
--learning_rate 0.001 --dropout_rate 0.05 \
--scorer_path 'data/smoke_test/pruned_lm.scorer'
python -u train.py --alphabet_config_path "data/alphabet.txt" \
python -m coqui_stt_training.training_graph_inference \
--n_hidden 100 \
--checkpoint_dir '/tmp/ckpt' \
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
--one_shot_infer 'data/smoke_test/LDC93S1.wav'
python -m coqui_stt_training.training_graph_inference_flashlight \
--n_hidden 100 \
--checkpoint_dir '/tmp/ckpt' \
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
--vocab_file 'data/smoke_test/vocab.pruned.txt' \
--one_shot_infer 'data/smoke_test/LDC93S1.wav'

View File

@ -8,14 +8,14 @@ from coqui_stt_training.evaluate import test
# only one GPU for only one training sample
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
download_ldc("data/ldc93s1")
download_ldc("data/smoke_test")
initialize_globals_from_args(
load_train="init",
alphabet_config_path="data/alphabet.txt",
train_files=["data/ldc93s1/ldc93s1.csv"],
dev_files=["data/ldc93s1/ldc93s1.csv"],
test_files=["data/ldc93s1/ldc93s1.csv"],
train_files=["data/smoke_test/ldc93s1.csv"],
dev_files=["data/smoke_test/ldc93s1.csv"],
test_files=["data/smoke_test/ldc93s1.csv"],
augment=["time_mask"],
n_hidden=100,
epochs=200,

View File

@ -5,9 +5,9 @@ if [ ! -f train.py ]; then
exit 1
fi;
if [ ! -f "data/ldc93s1/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/ldc93s1."
python -u bin/import_ldc93s1.py ./data/ldc93s1
if [ ! -f "data/smoke_test/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/smoke_test."
python -u bin/import_ldc93s1.py ./data/smoke_test
fi;
if [ -d "${COMPUTE_KEEP_DIR}" ]; then
@ -23,8 +23,8 @@ export CUDA_VISIBLE_DEVICES=0
python -m coqui_stt_training.train \
--alphabet_config_path "data/alphabet.txt" \
--show_progressbar false \
--train_files data/ldc93s1/ldc93s1.csv \
--test_files data/ldc93s1/ldc93s1.csv \
--train_files data/smoke_test/ldc93s1.csv \
--test_files data/smoke_test/ldc93s1.csv \
--train_batch_size 1 \
--test_batch_size 1 \
--n_hidden 100 \

View File

@ -31,7 +31,7 @@ elif [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
export DS_ROOT_TASK=${CI_TASK_DIR}
export BAZEL_VC="C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC"
# export BAZEL_VC_FULL_VERSION="14.28.30037"
export BAZEL_VC_FULL_VERSION="14.29.30133"
export MSYS2_ARG_CONV_EXCL='//'
mkdir -p ${CI_TASK_DIR}/tmp/

View File

@ -16,7 +16,7 @@ mkdir -p /tmp/train_tflite || true
set -o pipefail
python -m pip install --upgrade pip setuptools wheel | cat
python -m pip install --upgrade . | cat
python -m pip install --upgrade ".[transcribe]" | cat
set +o pipefail
# Prepare correct arguments for training
@ -69,3 +69,22 @@ time ./bin/run-ci-ldc93s1_checkpoint_bytes.sh
# Training with args set via initialize_globals_from_args()
time python ./bin/run-ldc93s1.py
# Training graph inference
time ./bin/run-ci-ldc93s1_singleshotinference.sh
# transcribe module
time python -m coqui_stt_training.transcribe \
--src "data/smoke_test/LDC93S1.wav" \
--dst ${CI_ARTIFACTS_DIR}/transcribe.log \
--n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer"
mkdir /tmp/transcribe_dir
cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
time python -m coqui_stt_training.transcribe \
--src "/tmp/transcribe_dir/" \
--n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer"
for i in /tmp/transcribe_dir/*.tlog; do echo $i; cat $i; echo; done

View File

@ -213,6 +213,27 @@ The path of the system tree can be overridden from the default values defined in
cd ../STT/native_client
make TARGET=<system> stt
RPi4 ARMv8 (Ubuntu 21.10)
^^^^^^^^^^^^^^^^^^^^^^^^^
We support cross-compilation from Linux hosts. The following ``--config`` flags can be specified when building with bazel:
* ``--config=rpi4ub-armv8_opt`` for Ubuntu / ARM64
Your command line should look like:
.. code-block::
bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" -c opt --config=rpi4ub-armv8_opt //native_client:libstt.so
The ``stt`` binary can also be cross-built, with ``TARGET=rpi4ub-armv8``. This might require you to set up a system tree using the ``multistrap`` tool and the multistrap configuration file: ``native_client/multistrap-ubuntu64-impish.conf``.
The path of the system tree can be overridden from the default values defined in ``definitions.mk`` through the ``RASPBIAN`` ``make`` variable.
.. code-block::
cd ../STT/native_client
make TARGET=rpi4ub-armv8 stt
Building ``libstt.so`` for Android
----------------------------------

View File

@ -1,4 +1,4 @@
.. _c-usage:
.. _c-api:
C API
=====

View File

@ -0,0 +1,79 @@
.. _checkpoint-inference:
Inference tools in the training package
=======================================
The standard deployment options for 🐸STT use highly optimized packages for deployment in real time, single-stream, low latency use cases. They take as input exported models which are also optimized, leading to further space and runtime gains. On the other hand, for the development of new features, it might be easier to use the training code for prototyping, which will allow you to test your changes without needing to recompile source code.
The training package contains options for performing inference directly from a checkpoint (and optionally a scorer), without needing to export a model. They are documented below, and all require a working :ref:`training environment <intro-training-docs>` before they can be used. Additionally, they require the Python ``webrtcvad`` package to be installed. This can either be done by specifying the "transcribe" extra when installing the training package, or by installing it manually in your training environment:
.. code-block:: bash
$ python -m pip install webrtcvad
Note that if your goal is to evaluate a trained model and obtain accuracy metrics, you should use the evaluation module: ``python -m coqui_stt_training.evaluate``, which calculates character and word error rates from a properly formatted CSV file (specified with the ``--test_files`` flag; see the :ref:`training docs <intro-training-docs>` for more information).
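The same evaluation can also be driven programmatically. The following is a minimal sketch, assuming a trained checkpoint and a properly formatted CSV test set; it mirrors the pattern used by ``bin/run-ldc93s1.py``, and the argument names mirror the CLI flags shown in these docs:

.. code-block:: python

   from coqui_stt_training.evaluate import test
   from coqui_stt_training.util.config import initialize_globals_from_args

   # Paths and batch size are placeholders.
   initialize_globals_from_args(
       checkpoint_dir="coqui-stt-1.0.0-checkpoint",
       n_hidden=2048,
       scorer_path="huge-vocabulary.scorer",
       test_files=["my-test-set.csv"],
       test_batch_size=8,
   )
   test()  # reports character and word error rates for the test set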
Single file (aka one-shot) inference
------------------------------------
This is the simplest way to perform inference from a checkpoint. It takes a single WAV file as input with the ``--one_shot_infer`` flag, and outputs the predicted transcription for that file.
.. code-block:: bash
$ python -m coqui_stt_training.training_graph_inference --checkpoint_dir coqui-stt-1.0.0-checkpoint --scorer_path huge-vocabulary.scorer --n_hidden 2048 --one_shot_infer audio/2830-3980-0043.wav
I --alphabet_config_path not specified, but found an alphabet file alongside specified checkpoint (coqui-stt-1.0.0-checkpoint/alphabet.txt). Will use this alphabet file for this run.
I Loading best validating checkpoint from coqui-stt-1.0.0-checkpoint/best_dev-3663881
I Loading variable from checkpoint: cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/bias
I Loading variable from checkpoint: cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/kernel
I Loading variable from checkpoint: layer_1/bias
I Loading variable from checkpoint: layer_1/weights
I Loading variable from checkpoint: layer_2/bias
I Loading variable from checkpoint: layer_2/weights
I Loading variable from checkpoint: layer_3/bias
I Loading variable from checkpoint: layer_3/weights
I Loading variable from checkpoint: layer_5/bias
I Loading variable from checkpoint: layer_5/weights
I Loading variable from checkpoint: layer_6/bias
I Loading variable from checkpoint: layer_6/weights
experience proves this
Transcription of longer audio files
-----------------------------------
If you have longer audio files to transcribe, we offer a script which uses Voice Activity Detection (VAD) to split audio files into chunks and perform batched inference on them. This can speed up transcription significantly. The transcription script will also output the results in JSON format, allowing for easier programmatic usage of the outputs.
There are two main usage modes: transcribing a single file, or scanning a directory for audio files and transcribing all of them.
Transcribing a single file
^^^^^^^^^^^^^^^^^^^^^^^^^^
For a single audio file, you can specify it directly in the ``--src`` flag of the ``python -m coqui_stt_training.transcribe`` script:
.. code-block:: bash
$ python -m coqui_stt_training.transcribe --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --vad_aggressiveness 0 --src audio/2830-3980-0043.wav
[1]: "audio/2830-3980-0043.wav" -> "audio/2830-3980-0043.tlog"
Transcribing files: 100%|███████████████████████████████████| 1/1 [00:05<00:00, 5.40s/it]
$ cat audio/2830-3980-0043.tlog
[{"start": 150, "end": 1950, "transcript": "experience proves this"}]
Note the use of the ``--vad_aggressiveness`` flag above to control the behavior of the VAD process used to find silent sections of the audio file for splitting into chunks. You can run ``python -m coqui_stt_training.transcribe --help`` to see the full listing of options; the last ones are specific to the transcribe module.
By default, the transcription results are placed in a ``.tlog`` file next to the audio file that was transcribed, but you can specify a different location with the ``--dst path/to/some/file.tlog`` flag. This only works when transcribing a single file.
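Because the ``.tlog`` output is plain JSON, it can be consumed directly from Python. A minimal sketch, using the file produced in the example above:

.. code-block:: python

   import json

   # Each entry holds the start/end offsets of a chunk and its transcript.
   with open("audio/2830-3980-0043.tlog") as fin:
       segments = json.load(fin)

   for segment in segments:
       print(segment["start"], segment["end"], segment["transcript"])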
Scanning a directory for audio files
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Alternatively, you can specify a directory in the ``--src`` flag, in which case the directory will be scanned for any WAV files to be transcribed. If you specify ``--recursive true``, it'll scan the directory recursively, going into any subdirectories as well. Transcription results will be placed in a ``.tlog`` file alongside every audio file that was found by the process.
Multiple processes will be used to distribute the transcription work among available CPUs.
.. code-block:: bash
$ python -m coqui_stt_training.transcribe --checkpoint_dir coqui-stt-1.0.0-checkpoint --n_hidden 2048 --scorer_path huge-vocabulary.scorer --vad_aggressiveness 0 --src audio/ --recursive true
Transcribing all files in --src directory audio
Transcribing files: 0%| | 0/3 [00:00<?, ?it/s]
[3]: "audio/8455-210777-0068.wav" -> "audio/8455-210777-0068.tlog"
[1]: "audio/2830-3980-0043.wav" -> "audio/2830-3980-0043.tlog"
[2]: "audio/4507-16021-0012.wav" -> "audio/4507-16021-0012.tlog"
Transcribing files: 100%|███████████████████████████████████| 3/3 [00:07<00:00, 2.50s/it]

View File

@ -1,14 +1,14 @@
.. _decoder-docs:
CTC beam search decoder
=======================
Beam search decoder
===================
Introduction
------------
🐸STT uses the `Connectionist Temporal Classification <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ loss function. For an excellent explanation of CTC and its usage, see this Distill article: `Sequence Modeling with CTC <https://distill.pub/2017/ctc/>`_. This document assumes the reader is familiar with the concepts described in that article, and describes 🐸STT specific behaviors that developers building systems with 🐸STT should know to avoid problems.
Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`.
Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`. Documentation for the coqui_stt_ctcdecoder Python package used by the training code for decoding is available in :ref:`decoder-api`.
The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD NOT", "RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as described in `BCP 14 <https://tools.ietf.org/html/bcp14>`_ when, and only when, they appear in all capitals, as shown here.

View File

@ -16,19 +16,29 @@ You can deploy 🐸STT models either via a command-line client or a language bin
* :ref:`The Node.JS package + language binding <nodejs-usage>`
* :ref:`The Android libstt AAR package <android-usage>`
* :ref:`The command-line client <cli-usage>`
* :ref:`The native C API <c-usage>`
* :ref:`The C API <c-usage>`
In some use cases, you might want to use the inference facilities built into the training code, for example for faster prototyping of new features. They are not production-ready, but because it's all Python code you won't need to recompile in order to test code changes, which can be much faster. See :ref:`checkpoint-inference` for more details.
.. _download-models:
Download trained Coqui STT models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
You can find pre-trained models ready for deployment on the 🐸STT `releases page <https://github.com/coqui-ai/STT/releases>`_. You can also download the latest acoustic model (``.tflite``) and language model (``.scorer``) from the command line as such:
You can find pre-trained models ready for deployment on the `Coqui Model Zoo <https://coqui.ai/models>`_. You can also use the 🐸STT Model Manager to download and try out the latest models:
.. code-block:: bash
wget https://github.com/coqui-ai/STT/releases/download/v0.9.3/coqui-stt-0.9.3-models.tflite
wget https://github.com/coqui-ai/STT/releases/download/v0.9.3/coqui-stt-0.9.3-models.scorer
# Create a virtual environment
$ python3 -m venv venv-stt
$ source venv-stt/bin/activate
# Install 🐸STT model manager
$ python -m pip install -U pip
$ python -m pip install coqui-stt-model-manager
# Run the model manager. A browser tab will open and you can then download and test models from the Model Zoo.
$ stt-model-manager
In every 🐸STT official release, there are different model files provided. The acoustic model uses the ``.tflite`` extension. Language models use the extension ``.scorer``. You can read more about language models with regard to :ref:`the decoding process <decoder-docs>` and :ref:`how scorers are generated <language-model>`.
@ -47,7 +57,7 @@ How well a 🐸STT model transcribes your audio will depend on a lot of things.
If you take a 🐸STT model trained on English and pass Spanish into it, you should expect the model to perform horribly. Imagine you have a friend who only speaks English, and you ask her to make Spanish subtitles for a Spanish film; you wouldn't expect to get good subtitles. This is an extreme example, but it helps to form an intuition for what to expect from 🐸STT models. Imagine that the 🐸STT models are like people who speak a certain language with a certain accent, and then think about what would happen if you asked that person to transcribe your audio.
An acoustic model (i.e. ``.tflite`` file) has "learned" how to transcribe a certain language, and the model probably understands some accents better than others. In addition to languages and accents, acoustic models are sensitive to the style of speech, the topic of speech, and the demographics of the person speaking. The language model (``.scorer``) has been trained on text alone. As such, the language model is sensitive to how well the topic and style of speech matches that of the text used in training. The 🐸STT `release notes <https://github.com/coqui-ai/STT/releases/tag/v0.9.3>`_ include detailed information on the data used to train the models. If the data used for training the off-the-shelf models does not align with your intended use case, it may be necessary to adapt or train new models in order to improve transcription on your data.
An acoustic model (i.e. ``.tflite`` file) has "learned" how to transcribe a certain language, and the model probably understands some accents better than others. In addition to languages and accents, acoustic models are sensitive to the style of speech, the topic of speech, and the demographics of the person speaking. The language model (``.scorer``) has been trained on text alone. As such, the language model is sensitive to how well the topic and style of speech matches that of the text used in training. The 🐸STT `release notes <https://github.com/coqui-ai/STT/releases/latest>`_ include detailed information on the data used to train the models. If the data used for training the off-the-shelf models does not align with your intended use case, it may be necessary to adapt or train new models in order to improve transcription on your data.
Training your own language model is often a good way to improve transcription on your audio. The process and tools used to generate a language model are described in :ref:`language-model` and general information can be found in :ref:`decoder-docs`. Generating a scorer from a constrained topic dataset is a quick process and can bring significant accuracy improvements if your audio is from a specific topic.
@ -91,18 +101,10 @@ The following command assumes you :ref:`downloaded the pre-trained models <downl
.. code-block:: bash
(coqui-stt-venv)$ stt --model stt-0.9.3-models.tflite --scorer stt-0.9.3-models.scorer --audio my_audio_file.wav
(coqui-stt-venv)$ stt --model model.tflite --scorer huge-vocabulary.scorer --audio my_audio_file.wav
See :ref:`the Python client <py-api-example>` for an example of how to use the package programmatically.
*GPUs will soon be supported:* If you have a supported NVIDIA GPU on Linux, you can install the GPU specific package as follows:
.. code-block::
(coqui-stt-venv)$ python -m pip install -U pip && python -m pip install stt-gpu
See the `release notes <https://github.com/coqui-ai/STT/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.
.. _nodejs-usage:
Using the Node.JS / Electron.JS package
@ -124,14 +126,6 @@ Please note that as of now, we support:
TypeScript support is also provided.
If you're using Linux and have a supported NVIDIA GPU, you can install the GPU specific package as follows:
.. code-block:: bash
npm install stt-gpu
See the `release notes <https://github.com/coqui-ai/STT/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.
See the :ref:`TypeScript client <js-api-example>` for an example of how to use the bindings programmatically.
.. _android-usage:
@ -164,16 +158,25 @@ This will link all .aar files in the ``libs`` directory you just created, includ
Using the command-line client
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The pre-built binaries for the ``stt`` command-line (compiled C++) client are available in the ``native_client.tar.xz`` archive for your desired platform. You can download the archive from our `releases page <https://github.com/coqui-ai/STT/releases>`_.
The pre-built binaries for the ``stt`` command-line (compiled C++) client are available in the ``native_client.*.tar.xz`` archive for your desired platform (where the * is the appropriate identifier for the platform you want to run on). You can download the archive from our `releases page <https://github.com/coqui-ai/STT/releases>`_.
Assuming you have :ref:`downloaded the pre-trained models <download-models>`, you can use the client as such:
.. code-block:: bash
./stt --model coqui-stt-0.9.3-models.tflite --scorer coqui-stt-0.9.3-models.scorer --audio audio_input.wav
./stt --model model.tflite --scorer huge-vocabulary.scorer --audio audio_input.wav
See the help output with ``./stt -h`` for more details.
.. _c-usage:
Using the C API
^^^^^^^^^^^^^^^
Alongside the pre-built binaries for the ``stt`` command-line client described :ref:`above <cli-usage>`, in the same ``native_client.*.tar.xz`` platform-specific archive, you'll find the ``coqui-stt.h`` header file as well as the pre-built shared libraries needed to use the 🐸STT C API. You can download the archive from our `releases page <https://github.com/coqui-ai/STT/releases>`_.
Then, simply include the header file and link against the shared libraries in your project, and you should be able to use the C API. Reference documentation is available in :ref:`c-api`.
Installing bindings from source
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -215,11 +218,6 @@ Running ``stt`` may require runtime dependencies. Please refer to your system's
* ``libpthread`` - Reported dependency on Linux. On Ubuntu, ``libpthread`` is part of the ``libpthread-stubs0-dev`` package
* ``Redistribuable Visual C++ 2015 Update 3 (64-bits)`` - Reported dependency on Windows. Please `download from Microsoft <https://www.microsoft.com/download/details.aspx?id=53587>`_
CUDA Dependency
^^^^^^^^^^^^^^^
The GPU capable builds (Python, NodeJS, C++, etc) depend on CUDA 10.1 and CuDNN v7.6.
.. toctree::
:maxdepth: 1

doc/Decoder-API.rst (new file, +7 lines)
View File

@ -0,0 +1,7 @@
.. _decoder-api:
Decoder API reference
=====================
.. automodule:: native_client.ctcdecode
:members:

View File

@ -5,8 +5,6 @@ Supported platforms
Here we maintain the list of supported platforms for deployment.
*Note that 🐸STT currently only provides packages for CPU deployment with Python 3.5 or higher on Linux. We're working to get the rest of our usually supported packages back up and running as soon as possible.*
Linux / AMD64
^^^^^^^^^^^^^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)

View File

@ -27,3 +27,5 @@ This document contains more advanced topics with regard to training models with
PARALLLEL_OPTIMIZATION
DATASET_IMPORTERS
Checkpoint-Inference

View File

@ -24,7 +24,8 @@ import sys
sys.path.insert(0, os.path.abspath("../"))
autodoc_mock_imports = ["stt"]
autodoc_mock_imports = ["stt", "native_client.ctcdecode.swigwrapper"]
autodoc_member_order = "bysource"
# This is in fact only relevant on ReadTheDocs, but we want to run the same way
# on our CI as in RTD to avoid regressions on RTD that we would not catch on CI
@ -128,7 +129,6 @@ todo_include_todos = False
add_module_names = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for

View File

@ -23,10 +23,10 @@
BUILDING
Quickstart: Deployment
^^^^^^^^^^^^^^^^^^^^^^
Quickstart
^^^^^^^^^^
The fastest way to deploy a pre-trained 🐸STT model is with `pip` with Python 3.6, 3.7, 3.8 or 3.9:
The fastest way to use a pre-trained 🐸STT model is with the 🐸STT model manager, a tool that lets you quickly test and demo models locally. You'll need Python 3.6, 3.7, 3.8 or 3.9:
.. code-block:: bash
@ -34,20 +34,12 @@ The fastest way to deploy a pre-trained 🐸STT model is with `pip` with Python
$ python3 -m venv venv-stt
$ source venv-stt/bin/activate
# Install 🐸STT
# Install 🐸STT model manager
$ python -m pip install -U pip
$ python -m pip install stt
$ python -m pip install coqui-stt-model-manager
# Download 🐸's pre-trained English models
$ curl -LO https://github.com/coqui-ai/STT/releases/download/v0.9.3/coqui-stt-0.9.3-models.tflite
$ curl -LO https://github.com/coqui-ai/STT/releases/download/v0.9.3/coqui-stt-0.9.3-models.scorer
# Download some example audio files
$ curl -LO https://github.com/coqui-ai/STT/releases/download/v0.9.3/audio-0.9.3.tar.gz
$ tar -xvf audio-0.9.3.tar.gz
# Transcribe an audio file
$ stt --model coqui-stt-0.9.3-models.tflite --scorer coqui-stt-0.9.3-models.scorer --audio audio/2830-3980-0043.wav
# Run the model manager. A browser tab will open and you can then download and test models from the Model Zoo.
$ stt-model-manager
.. toctree::
:maxdepth: 1
@ -97,6 +89,17 @@ The fastest way to deploy a pre-trained 🐸STT model is with `pip` with Python
playbook/README
.. toctree::
:maxdepth: 1
:caption: Advanced topics
DECODER
Decoder-API
PARALLEL_OPTIMIZATION
Indices and tables
==================

View File

@ -219,12 +219,12 @@ Next, we need to install the `native_client` package, which contains the `genera
The `generate_scorer_package`, once installed via the `native client` package, is usable on _all platforms_ supported by 🐸STT. This is so that developers can generate scorers _on-device_, such as on an Android device or a Raspberry Pi 3.
To install `generate_scorer_package`, first download the relevant `native client` package from the [🐸STT GitHub releases page](https://github.com/coqui-ai/STT/releases/tag/v0.9.3) into the `data/lm` directory. The Docker image uses Ubuntu Linux, so you should use either the `native_client.amd64.cuda.linux.tar.xz` package if you are using `cuda` or the `native_client.amd64.cpu.linux.tar.xz` package if not.
To install `generate_scorer_package`, first download the relevant `native client` package from the [🐸STT GitHub releases page](https://github.com/coqui-ai/STT/releases/latest) into the `data/lm` directory. The Docker image uses Ubuntu Linux, so you should use either the `native_client.amd64.cuda.linux.tar.xz` package if you are using `cuda` or the `native_client.amd64.cpu.linux.tar.xz` package if not.
The easiest way to download the package and extract it is using `curl [URL] | tar -Jxvf [FILENAME]`:
The easiest way to download the package and extract it is using `curl -L [URL] | tar -Jxvf [FILENAME]`:
```
root@dcb62aada58b:/STT/data/lm# curl https://github.com/coqui-ai/STT/releases/download/v0.9.3/native_client.amd64.cuda.linux.tar.xz | tar -Jxvf native_client.amd64.cuda.linux.tar.xz
root@dcb62aada58b:/STT/data/lm# curl -L https://github.com/coqui-ai/STT/releases/download/v1.0.0/native_client.tflite.Linux.tar.xz | tar -Jxvf -
libstt.so
generate_scorer_package
LICENSE
@ -233,7 +233,7 @@ coqui-stt.h
README.coqui
```
You can now generate a `ken.lm` scorer file.
You can now generate a KenLM scorer file.
```
root@dcb62aada58b:/STT/data/lm# ./generate_scorer_package \

View File

@ -4,6 +4,7 @@ sphinx==3.5.2
sphinx-js==3.1
furo==2021.2.28b28
pygments==2.7.4
docutils>=0.12,<=0.17.1
#FIXME: switch to stable after C# changes have been merged: https://github.com/djungelorm/sphinx-csharp/pull/8
git+https://github.com/reuben/sphinx-csharp.git@9dc6202f558e3d3fa14ec7f5f1e36a8e66e6d622
recommonmark==0.7.1

View File

@ -1,3 +1,3 @@
#flags pre {
#flags pre, #inference-tools-in-the-training-package pre {
white-space: pre-wrap;
}

View File

@ -10,10 +10,9 @@ import wave
from functools import partial
from multiprocessing import JoinableQueue, Manager, Process, cpu_count
import absl.app
import numpy as np
from coqui_stt_training.util.evaluate_tools import calculate_and_print_report
from coqui_stt_training.util.flags import create_flags
from coqui_stt_training.util.config import Config, initialize_globals_from_args
from six.moves import range, zip
from stt import Model
@ -61,6 +60,7 @@ def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
def main(args):
initialize_globals_from_args()
manager = Manager()
work_todo = JoinableQueue() # this is where we are going to store input data
work_done = manager.Queue() # this where we are gonna push them out

View File

@ -19,6 +19,13 @@ config_setting(
},
)
config_setting(
name = "rpi4ub-armv8",
define_values = {
"target_system": "rpi4ub-armv8"
},
)
genrule(
name = "workspace_status",
outs = ["workspace_status.cc"],
@ -86,11 +93,9 @@ cc_binary(
"kenlm/*/*test.cc",
"kenlm/*/*main.cc",
],),
copts = [
"-std=c++11"
] + select({
"//tensorflow:windows": [],
"//conditions:default": ["-fvisibility=hidden"],
copts = select({
"//tensorflow:windows": ["/std:c++14"],
"//conditions:default": ["-std=c++14", "-fwrapv", "-fvisibility=hidden"],
}),
defines = ["KENLM_MAX_ORDER=6"],
includes = ["kenlm"],
@ -110,24 +115,62 @@ cc_binary(
)
cc_library(
name = "kenlm",
name="kenlm",
hdrs = glob([
"kenlm/lm/*.hh",
"kenlm/util/*.hh",
]),
srcs = [":libkenlm.so"],
copts = ["-std=c++11"],
copts = ["-std=c++14"],
defines = ["KENLM_MAX_ORDER=6"],
includes = ["kenlm"],
includes = [".", "kenlm"],
)
cc_library(
name = "flashlight",
hdrs = [
"ctcdecode/third_party/flashlight/flashlight/lib/common/String.h",
"ctcdecode/third_party/flashlight/flashlight/lib/common/System.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Decoder.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/LM.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Defines.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.h",
"ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.h",
],
srcs = [
"ctcdecode/third_party/flashlight/flashlight/lib/common/String.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.cpp",
"ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.cpp",
],
includes = ["ctcdecode/third_party/flashlight"],
deps = [":kenlm"],
)
cc_library(
name = "decoder",
srcs = DECODER_SOURCES,
includes = DECODER_INCLUDES,
deps = [":kenlm"],
deps = [":kenlm", ":flashlight"],
linkopts = DECODER_LINKOPTS,
copts = ["-fexceptions"],
copts = select({
"//tensorflow:windows": ["/std:c++14"],
"//conditions:default": ["-std=c++14", "-fexceptions", "-fwrapv"],
}),
)
cc_library(
@ -195,10 +238,12 @@ cc_library(
] + DECODER_SOURCES,
copts = select({
# -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
"//tensorflow:windows": ["/w"],
"//tensorflow:windows": ["/std:c++14", "/w"],
# -Wno-sign-compare to silent a lot of warnings from tensorflow itself,
# which makes it harder to see our own warnings
"//conditions:default": [
"-std=c++14",
"-fwrapv",
"-Wno-sign-compare",
"-fvisibility=hidden",
],
@ -220,7 +265,7 @@ cc_library(
"//conditions:default": [],
}) + DECODER_LINKOPTS,
includes = DECODER_INCLUDES,
deps = [":kenlm", ":tflite", ":tflitedelegates"],
deps = [":kenlm", ":tflite", ":tflitedelegates", ":flashlight"],
)
cc_binary(
@ -264,8 +309,8 @@ cc_binary(
"stt_errors.cc",
],
copts = select({
"//tensorflow:windows": [],
"//conditions:default": ["-std=c++11"],
"//tensorflow:windows": ["/std:c++14"],
"//conditions:default": ["-std=c++14"],
}),
deps = [
":decoder",
@ -297,7 +342,10 @@ cc_binary(
"enumerate_kenlm_vocabulary.cpp",
],
deps = [":kenlm"],
copts = ["-std=c++11"],
copts = select({
"//tensorflow:windows": ["/std:c++14"],
"//conditions:default": ["-std=c++14"],
}),
)
cc_binary(
@ -305,6 +353,9 @@ cc_binary(
srcs = [
"trie_load.cc",
] + DECODER_SOURCES,
copts = ["-std=c++11"],
copts = select({
"//tensorflow:windows": ["/std:c++14"],
"//conditions:default": ["-std=c++14"],
}),
linkopts = DECODER_LINKOPTS,
)

View File

@ -45,8 +45,8 @@ Alphabet::init(const char *config_file)
if (!in) {
return 1;
}
unsigned int label = 0;
space_label_ = -2;
int index = 0;
space_index_ = -2;
for (std::string line; getline_crossplatform(in, line);) {
if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
line = '#';
@ -55,16 +55,14 @@ Alphabet::init(const char *config_file)
}
//TODO: we should probably do something more i18n-aware here
if (line == " ") {
space_label_ = label;
space_index_ = index;
}
if (line.length() == 0) {
continue;
}
label_to_str_[label] = line;
str_to_label_[line] = label;
++label;
addEntry(line, index);
++index;
}
size_ = label;
in.close();
return 0;
}
@ -72,15 +70,13 @@ Alphabet::init(const char *config_file)
void
Alphabet::InitFromLabels(const std::vector<std::string>& labels)
{
space_label_ = -2;
size_ = labels.size();
for (int i = 0; i < size_; ++i) {
const std::string& label = labels[i];
space_index_ = -2;
for (int idx = 0; idx < labels.size(); ++idx) {
const std::string& label = labels[idx];
if (label == " ") {
space_label_ = i;
space_index_ = idx;
}
label_to_str_[i] = label;
str_to_label_[label] = i;
addEntry(label, idx);
}
}
@ -90,12 +86,12 @@ Alphabet::SerializeText()
std::stringstream out;
out << "# Each line in this file represents the Unicode codepoint (UTF-8 encoded)\n"
<< "# associated with a numeric label.\n"
<< "# associated with a numeric index.\n"
<< "# A line that starts with # is a comment. You can escape it with \\# if you wish\n"
<< "# to use '#' as a label.\n";
<< "# to use '#' in the Alphabet.\n";
for (int label = 0; label < size_; ++label) {
out << label_to_str_[label] << "\n";
for (int idx = 0; idx < entrySize(); ++idx) {
out << getEntry(idx) << "\n";
}
out << "# The last (non-comment) line needs to end with a newline.\n";
@ -105,18 +101,22 @@ Alphabet::SerializeText()
std::string
Alphabet::Serialize()
{
// Should always be true in our usage, but this method will crash if for some
// mystical reason it doesn't hold, so defensively assert it here.
assert(isContiguous());
// Serialization format is a sequence of (key, value) pairs, where key is
// a uint16_t and value is a uint16_t length followed by `length` UTF-8
// encoded bytes with the label.
std::stringstream out;
// We start by writing the number of pairs in the buffer as uint16_t.
uint16_t size = size_;
uint16_t size = entrySize();
out.write(reinterpret_cast<char*>(&size), sizeof(size));
for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
uint16_t key = it->first;
string str = it->second;
for (int i = 0; i < GetSize(); ++i) {
uint16_t key = i;
string str = DecodeSingle(i);
uint16_t len = str.length();
// Then we write the key as uint16_t, followed by the length of the value
// as uint16_t, followed by `length` bytes (the value itself).
@ -138,7 +138,6 @@ Alphabet::Deserialize(const char* buffer, const int buffer_size)
}
uint16_t size = *(uint16_t*)(buffer + offset);
offset += sizeof(uint16_t);
size_ = size;
for (int i = 0; i < size; ++i) {
if (buffer_size - offset < sizeof(uint16_t)) {
@ -159,22 +158,26 @@ Alphabet::Deserialize(const char* buffer, const int buffer_size)
std::string val(buffer+offset, val_len);
offset += val_len;
label_to_str_[label] = val;
str_to_label_[val] = label;
addEntry(val, label);
if (val == " ") {
space_label_ = label;
space_index_ = label;
}
}
return 0;
}
size_t
Alphabet::GetSize() const
{
return entrySize();
}
bool
Alphabet::CanEncodeSingle(const std::string& input) const
{
auto it = str_to_label_.find(input);
return it != str_to_label_.end();
return contains(input);
}
bool
@ -191,25 +194,14 @@ Alphabet::CanEncode(const std::string& input) const
std::string
Alphabet::DecodeSingle(unsigned int label) const
{
auto it = label_to_str_.find(label);
if (it != label_to_str_.end()) {
return it->second;
} else {
std::cerr << "Invalid label " << label << std::endl;
abort();
}
assert(label <= INT_MAX);
return getEntry(label);
}
unsigned int
Alphabet::EncodeSingle(const std::string& string) const
{
auto it = str_to_label_.find(string);
if (it != str_to_label_.end()) {
return it->second;
} else {
std::cerr << "Invalid string " << string << std::endl;
abort();
}
return getIndex(string);
}
std::string

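The comments in `Serialize()` above describe the binary layout of a serialized alphabet: a `uint16_t` entry count, followed by one `(uint16_t key, uint16_t length, length UTF-8 bytes)` record per entry. As a hedged illustration only (not part of this changeset, and assuming little-endian byte order), such a buffer could be read from Python like this:

```python
# Hedged sketch: parse the Alphabet binary serialization format described in
# the Serialize() comments above. Little-endian byte order is an assumption.
import struct

def parse_alphabet_buffer(buf):
    (count,) = struct.unpack_from("<H", buf, 0)
    offset = 2
    entries = {}
    for _ in range(count):
        key, length = struct.unpack_from("<HH", buf, offset)
        offset += 4
        entries[key] = buf[offset:offset + length].decode("utf-8")
        offset += length
    return entries
```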
View File

@ -5,12 +5,15 @@
#include <unordered_map>
#include <vector>
#include "flashlight/lib/text/dictionary/Dictionary.h"
/*
* Loads a text file describing a mapping of labels to strings, one string per
* line. This is used by the decoder, client and Python scripts to convert the
* output of the decoder to a human-readable string and vice-versa.
*/
class Alphabet {
class Alphabet : public fl::lib::text::Dictionary
{
public:
Alphabet() = default;
Alphabet(const Alphabet&) = default;
@ -31,16 +34,14 @@ public:
// Deserialize alphabet from a binary buffer.
int Deserialize(const char* buffer, const int buffer_size);
size_t GetSize() const {
return size_;
}
size_t GetSize() const;
bool IsSpace(unsigned int label) const {
return label == space_label_;
return label == space_index_;
}
unsigned int GetSpaceLabel() const {
return space_label_;
return space_index_;
}
// Returns true if the single character/output class has a corresponding label
@ -72,23 +73,20 @@ public:
virtual std::vector<unsigned int> Encode(const std::string& input) const;
protected:
size_t size_;
unsigned int space_label_;
std::unordered_map<unsigned int, std::string> label_to_str_;
std::unordered_map<std::string, unsigned int> str_to_label_;
unsigned int space_index_;
};
class UTF8Alphabet : public Alphabet
{
public:
UTF8Alphabet() {
size_ = 255;
space_label_ = ' ' - 1;
for (size_t i = 0; i < size_; ++i) {
std::string val(1, i+1);
label_to_str_[i] = val;
str_to_label_[val] = i;
// 255 byte values, index n -> byte value n+1
// because NUL is never used, we don't use up an index in the maps for it
for (int idx = 0; idx < 255; ++idx) {
std::string val(1, idx+1);
addEntry(val, idx);
}
space_index_ = ' ' - 1;
}
int init(const char*) override {

View File

@ -15,6 +15,10 @@ extern "C" {
#define STT_EXPORT
#endif
// For the decoder package we include this header but should only expose
// the error info, so guard all the other definitions out.
#ifndef SWIG_ERRORS_ONLY
typedef struct ModelState ModelState;
typedef struct StreamingState StreamingState;
@ -59,6 +63,8 @@ typedef struct Metadata {
const unsigned int num_transcripts;
} Metadata;
#endif /* SWIG_ERRORS_ONLY */
// sphinx-doc: error_code_listing_start
#define STT_FOR_EACH_ERROR(APPLY) \
@ -95,6 +101,8 @@ STT_FOR_EACH_ERROR(DEFINE)
#undef DEFINE
};
#ifndef SWIG_ERRORS_ONLY
/**
* @brief An object providing an interface to a trained Coqui STT model.
*
@ -105,7 +113,7 @@ STT_FOR_EACH_ERROR(DEFINE)
*/
STT_EXPORT
int STT_CreateModel(const char* aModelPath,
ModelState** retval);
ModelState** retval);
/**
* @brief Get beam width value used by the model. If {@link STT_SetModelBeamWidth}
@ -130,7 +138,7 @@ unsigned int STT_GetModelBeamWidth(const ModelState* aCtx);
*/
STT_EXPORT
int STT_SetModelBeamWidth(ModelState* aCtx,
unsigned int aBeamWidth);
unsigned int aBeamWidth);
/**
* @brief Return the sample rate expected by a model.
@ -158,7 +166,7 @@ void STT_FreeModel(ModelState* ctx);
*/
STT_EXPORT
int STT_EnableExternalScorer(ModelState* aCtx,
const char* aScorerPath);
const char* aScorerPath);
/**
* @brief Add a hot-word and its boost.
@ -173,8 +181,8 @@ int STT_EnableExternalScorer(ModelState* aCtx,
*/
STT_EXPORT
int STT_AddHotWord(ModelState* aCtx,
const char* word,
float boost);
const char* word,
float boost);
/**
* @brief Remove entry for a hot-word from the hot-words map.
@ -186,7 +194,7 @@ int STT_AddHotWord(ModelState* aCtx,
*/
STT_EXPORT
int STT_EraseHotWord(ModelState* aCtx,
const char* word);
const char* word);
/**
* @brief Removes all elements from the hot-words map.
@ -219,8 +227,8 @@ int STT_DisableExternalScorer(ModelState* aCtx);
*/
STT_EXPORT
int STT_SetScorerAlphaBeta(ModelState* aCtx,
float aAlpha,
float aBeta);
float aAlpha,
float aBeta);
/**
* @brief Use the Coqui STT model to convert speech to text.
@ -235,8 +243,8 @@ int STT_SetScorerAlphaBeta(ModelState* aCtx,
*/
STT_EXPORT
char* STT_SpeechToText(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize);
const short* aBuffer,
unsigned int aBufferSize);
/**
* @brief Use the Coqui STT model to convert speech to text and output results
@ -255,9 +263,9 @@ char* STT_SpeechToText(ModelState* aCtx,
*/
STT_EXPORT
Metadata* STT_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aNumResults);
const short* aBuffer,
unsigned int aBufferSize,
unsigned int aNumResults);
/**
* @brief Create a new streaming inference state. The streaming state returned
@ -284,8 +292,8 @@ int STT_CreateStream(ModelState* aCtx,
*/
STT_EXPORT
void STT_FeedAudioContent(StreamingState* aSctx,
const short* aBuffer,
unsigned int aBufferSize);
const short* aBuffer,
unsigned int aBufferSize);
/**
* @brief Compute the intermediate decoding of an ongoing streaming inference.
@ -312,7 +320,7 @@ char* STT_IntermediateDecode(const StreamingState* aSctx);
*/
STT_EXPORT
Metadata* STT_IntermediateDecodeWithMetadata(const StreamingState* aSctx,
unsigned int aNumResults);
unsigned int aNumResults);
/**
* @brief Compute the final decoding of an ongoing streaming inference and return
@ -345,7 +353,7 @@ char* STT_FinishStream(StreamingState* aSctx);
*/
STT_EXPORT
Metadata* STT_FinishStreamWithMetadata(StreamingState* aSctx,
unsigned int aNumResults);
unsigned int aNumResults);
/**
* @brief Destroy a streaming state without decoding the computed logits. This
@ -389,6 +397,7 @@ char* STT_Version();
STT_EXPORT
char* STT_ErrorCodeToErrorMessage(int aErrorCode);
#endif /* SWIG_ERRORS_ONLY */
#undef STT_EXPORT
#ifdef __cplusplus

View File

@ -1,4 +1,4 @@
from __future__ import absolute_import, division, print_function
import enum
from . import swigwrapper # pylint: disable=import-self
@ -13,37 +13,12 @@ for symbol in dir(swigwrapper):
globals()[symbol] = getattr(swigwrapper, symbol)
class Scorer(swigwrapper.Scorer):
"""Wrapper for Scorer.
:param alpha: Language model weight.
:type alpha: float
:param beta: Word insertion bonus.
:type beta: float
:scorer_path: Path to load scorer from.
:alphabet: Alphabet
:type scorer_path: basestring
"""
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
super(Scorer, self).__init__()
# Allow bare initialization
if alphabet:
assert alpha is not None, "alpha parameter is required"
assert beta is not None, "beta parameter is required"
assert scorer_path, "scorer_path parameter is required"
err = self.init(scorer_path.encode("utf-8"), alphabet)
if err != 0:
raise ValueError(
"Scorer initialization failed with error code 0x{:X}".format(err)
)
self.reset_params(alpha, beta)
class Alphabet(swigwrapper.Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor"""
"""An Alphabet is a bidirectional map from tokens (eg. characters) to
internal integer representations used by the underlying acoustic models
and external scorers. It can be created from alphabet configuration file
via the constructor, or from a list of tokens via :py:meth:`Alphabet.InitFromLabels`.
"""
def __init__(self, config_path=None):
super(Alphabet, self).__init__()
@ -55,6 +30,10 @@ class Alphabet(swigwrapper.Alphabet):
)
def InitFromLabels(self, data):
"""
Initialize Alphabet from a list of labels ``data``. Each label gets
associated with an integer value corresponding to its position in the list.
"""
return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])
def CanEncodeSingle(self, input):
@ -83,7 +62,7 @@ class Alphabet(swigwrapper.Alphabet):
Encode a sequence of character/output classes into a sequence of labels.
Characters are assumed to always take a single Unicode codepoint.
Characters must be in the alphabet, this method will assert that. Use
`CanEncode` and `CanEncodeSingle` to test.
``CanEncode`` and ``CanEncodeSingle`` to test.
"""
# Convert SWIG's UnsignedIntVec to a Python list
res = super(Alphabet, self).Encode(input.encode("utf-8"))
@ -99,57 +78,39 @@ class Alphabet(swigwrapper.Alphabet):
return res.decode("utf-8")
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor"""
class Scorer(swigwrapper.Scorer):
"""An external scorer is a data structure composed of a language model built
from text data, as well as the vocabulary used in the construction of this
language model and additional parameters related to how the decoding
process uses the external scorer, such as the language model weight
``alpha`` and the word insertion score ``beta``.
def __init__(self):
super(UTF8Alphabet, self).__init__()
err = self.init(b"")
if err != 0:
raise ValueError(
"UTF8Alphabet initialization failed with error code 0x{:X}".format(err)
)
:param alpha: Language model weight.
:type alpha: float
:param beta: Word insertion score.
:type beta: float
:param scorer_path: Path to load scorer from.
:type scorer_path: str
:param alphabet: Alphabet object matching the tokens used when creating the
external scorer.
:type alphabet: Alphabet
"""
def CanEncodeSingle(self, input):
"""
Returns true if the single character/output class has a corresponding label
in the alphabet.
"""
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
super(Scorer, self).__init__()
# Allow bare initialization
if alphabet:
assert alpha is not None, "alpha parameter is required"
assert beta is not None, "beta parameter is required"
assert scorer_path, "scorer_path parameter is required"
def CanEncode(self, input):
"""
Returns true if the entire string can be encoded into labels in this
alphabet.
"""
return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8"))
err = self.init(scorer_path.encode("utf-8"), alphabet)
if err != 0:
raise ValueError(
"Scorer initialization failed with error code 0x{:X}".format(err)
)
def EncodeSingle(self, input):
"""
Encode a single character/output class into a label. Character must be in
the alphabet, this method will assert that. Use `CanEncodeSingle` to test.
"""
return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8"))
def Encode(self, input):
"""
Encode a sequence of character/output classes into a sequence of labels.
Characters are assumed to always take a single Unicode codepoint.
Characters must be in the alphabet, this method will assert that. Use
`CanEncode` and `CanEncodeSingle` to test.
"""
# Convert SWIG's UnsignedIntVec to a Python list
res = super(UTF8Alphabet, self).Encode(input.encode("utf-8"))
return [el for el in res]
def DecodeSingle(self, input):
res = super(UTF8Alphabet, self).DecodeSingle(input)
return res.decode("utf-8")
def Decode(self, input):
"""Decode a sequence of labels into a string."""
res = super(UTF8Alphabet, self).Decode(input)
return res.decode("utf-8")
self.reset_params(alpha, beta)
def ctc_beam_search_decoder(
@ -182,7 +143,7 @@ def ctc_beam_search_decoder(
count or language model.
:type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:type hot_words: dict[string, float]
:param num_results: Number of beams to return.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding
@ -241,7 +202,7 @@ def ctc_beam_search_decoder_batch(
count or language model.
:type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:type hot_words: dict[string, float]
:param num_results: Number of beams to return.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding
@ -265,3 +226,247 @@ def ctc_beam_search_decoder_batch(
for beam_results in batch_beam_results
]
return batch_beam_results
class FlashlightDecoderState(swigwrapper.FlashlightDecoderState):
"""
This class contains constants used to specify the desired behavior for the
:py:func:`flashlight_beam_search_decoder` and :py:func:`flashlight_beam_search_decoder_batch`
functions.
"""
class CriterionType(enum.IntEnum):
"""Constants used to specify which loss criterion was used by the
acoustic model. This class is a Python :py:class:`enum.IntEnum`.
"""
#: Decoder mode for handling acoustic models trained with CTC loss
CTC = swigwrapper.FlashlightDecoderState.CTC
#: Decoder mode for handling acoustic models trained with ASG loss
ASG = swigwrapper.FlashlightDecoderState.ASG
#: Decoder mode for handling acoustic models trained with Seq2seq loss
#: Note: this criterion type is currently not supported.
S2S = swigwrapper.FlashlightDecoderState.S2S
class DecoderType(enum.IntEnum):
"""Constants used to specify if decoder should operate in lexicon mode,
only predicting words present in a fixed vocabulary, or in lexicon-free
mode, without such restriction. This class is a Python :py:class:`enum.IntEnum`.
"""
#: Lexicon mode, only predict words in specified vocabulary.
LexiconBased = swigwrapper.FlashlightDecoderState.LexiconBased
#: Lexicon-free mode, allow prediction of any word.
LexiconFree = swigwrapper.FlashlightDecoderState.LexiconFree
class TokenType(enum.IntEnum):
"""Constants used to specify the granularity of text units used when training
the external scorer in relation to the text units used when training the
acoustic model. For example, you can have an acoustic model predicting
characters and an external scorer trained on words, or an acoustic model
and an external scorer both trained with sub-word units. If the acoustic
model and the scorer were both trained on the same text unit granularity,
use ``TokenType.Single``. Otherwise, if the external scorer was trained
on a sequence of acoustic model text units, use ``TokenType.Aggregate``.
This class is a Python :py:class:`enum.IntEnum`.
"""
#: Token type for external scorers trained on the same textual units as
#: the acoustic model.
Single = swigwrapper.FlashlightDecoderState.Single
#: Token type for external scorers trained on a sequence of acoustic model
#: textual units.
Aggregate = swigwrapper.FlashlightDecoderState.Aggregate
def flashlight_beam_search_decoder(
logits_seq,
alphabet,
beam_size,
decoder_type,
token_type,
lm_tokens,
scorer=None,
beam_threshold=25.0,
cutoff_top_n=40,
silence_score=0.0,
merge_with_log_add=False,
criterion_type=FlashlightDecoderState.CriterionType.CTC,
transitions=[],
num_results=1,
):
"""Decode acoustic model emissions for a single sample. Note that unlike
:py:func:`ctc_beam_search_decoder`, this function expects raw outputs
from CTC and ASG acoustic models, without softmaxing them over
timesteps.
:param logits_seq: 2-D list of acoustic model emissions, dimensions are
time steps x number of output units.
:type logits_seq: 2-D list of floats or numpy array
:param alphabet: Alphabet object matching the tokens used when creating the
acoustic model and external scorer if specified.
:type alphabet: Alphabet
:param beam_size: Width for beam search.
:type beam_size: int
:param decoder_type: Decoding mode, lexicon-constrained or lexicon-free.
:type decoder_type: FlashlightDecoderState.DecoderType
:param token_type: Type of token in the external scorer.
:type token_type: FlashlightDecoderState.TokenType
:param lm_tokens: List of tokens to constrain decoding to when in lexicon-constrained
mode. Must match the token type used in the scorer, ie.
must be a list of characters if scorer is character-based,
or a list of words if scorer is word-based.
:type lm_tokens: list[str]
:param scorer: External scorer.
:type scorer: Scorer
:param beam_threshold: Maximum threshold in beam score from leading beam. Any
newly created candidate beams which lag behind the best
beam so far by more than this value will get pruned.
This is a performance optimization parameter and an
appropriate value should be found empirically using a
validation set.
:type beam_threshold: float
:param cutoff_top_n: Maximum number of tokens to expand per time step during
decoding. Only the highest probability cutoff_top_n
candidates (characters, sub-word units, words) in a given
timestep will be expanded. This is a performance
optimization parameter and an appropriate value should
be found empirically using a validation set.
:type cutoff_top_n: int
:param silence_score: Score to add to beam when encountering a predicted
silence token (eg. the space symbol).
:type silence_score: float
:param merge_with_log_add: Whether to use log-add when merging scores of
new candidate beams equivalent to existing ones
(leading to the same transcription). When disabled,
the maximum score is used.
:type merge_with_log_add: bool
:param criterion_type: Criterion used for training the acoustic model.
:type criterion_type: FlashlightDecoderState.CriterionType
:param transitions: Transition score matrix for ASG acoustic models.
:type transitions: list[float]
:param num_results: Number of beams to return.
:type num_results: int
:return: List of FlashlightOutput structures.
:rtype: list[FlashlightOutput]
"""
return swigwrapper.flashlight_beam_search_decoder(
logits_seq,
alphabet,
beam_size,
beam_threshold,
cutoff_top_n,
scorer,
token_type,
lm_tokens,
decoder_type,
silence_score,
merge_with_log_add,
criterion_type,
transitions,
num_results,
)
def flashlight_beam_search_decoder_batch(
probs_seq,
seq_lengths,
alphabet,
beam_size,
decoder_type,
token_type,
lm_tokens,
num_processes,
scorer=None,
beam_threshold=25.0,
cutoff_top_n=40,
silence_score=0.0,
merge_with_log_add=False,
criterion_type=FlashlightDecoderState.CriterionType.CTC,
transitions=[],
num_results=1,
):
"""Decode batch acoustic model emissions in parallel. ``num_processes``
controls how many samples from the batch will be decoded simultaneously.
All the other parameters are forwarded to :py:func:`flashlight_beam_search_decoder`.
Returns a list of lists of FlashlightOutput structures.
"""
return swigwrapper.flashlight_beam_search_decoder_batch(
probs_seq,
seq_lengths,
alphabet,
beam_size,
beam_threshold,
cutoff_top_n,
scorer,
token_type,
lm_tokens,
decoder_type,
silence_score,
merge_with_log_add,
criterion_type,
transitions,
num_results,
num_processes,
)
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
"""Alphabet class representing 255 possible byte values for Bytes Output Mode.
For internal use only.
"""
def __init__(self):
super(UTF8Alphabet, self).__init__()
err = self.init(b"")
if err != 0:
raise ValueError(
"UTF8Alphabet initialization failed with error code 0x{:X}".format(err)
)
def CanEncodeSingle(self, input):
"""
Returns true if the single character/output class has a corresponding label
in the alphabet.
"""
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
def CanEncode(self, input):
"""
Returns true if the entire string can be encoded into labels in this
alphabet.
"""
return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8"))
def EncodeSingle(self, input):
"""
Encode a single character/output class into a label. Character must be in
the alphabet, this method will assert that. Use ``CanEncodeSingle`` to test.
"""
return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8"))
def Encode(self, input):
"""
Encode a sequence of character/output classes into a sequence of labels.
Characters are assumed to always take a single Unicode codepoint.
Characters must be in the alphabet, this method will assert that. Use
``CanEncode`` and ``CanEncodeSingle`` to test.
"""
# Convert SWIG's UnsignedIntVec to a Python list
res = super(UTF8Alphabet, self).Encode(input.encode("utf-8"))
return [el for el in res]
def DecodeSingle(self, input):
res = super(UTF8Alphabet, self).DecodeSingle(input)
return res.decode("utf-8")
def Decode(self, input):
"""Decode a sequence of labels into a string."""
res = super(UTF8Alphabet, self).Decode(input)
return res.decode("utf-8")
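
To make the new API concrete, here is a hedged usage sketch based on the docstrings above; it is not part of the changeset, and the emission shape, beam settings, vocabulary and file paths are illustrative assumptions.

```python
# Hedged sketch: decode dummy acoustic model emissions with the new
# flashlight_beam_search_decoder API. All paths and parameter values are
# placeholders; the emissions are random and only illustrate the expected
# shape (time steps x output units, raw/unsoftmaxed).
import numpy as np
from coqui_stt_ctcdecoder import (
    Alphabet,
    Scorer,
    FlashlightDecoderState,
    flashlight_beam_search_decoder,
)

alphabet = Alphabet("data/alphabet.txt")
scorer = Scorer(
    alpha=0.9,  # placeholder language model weight
    beta=1.2,   # placeholder word insertion score
    scorer_path="huge-vocabulary.scorer",
    alphabet=alphabet,
)

# Word-based vocabulary to constrain decoding in lexicon mode.
with open("data/smoke_test/vocab.pruned.txt") as fin:
    lm_tokens = [line.strip() for line in fin if line.strip()]

logits = np.random.rand(100, alphabet.GetSize() + 1)  # +1 for the CTC blank
results = flashlight_beam_search_decoder(
    logits,
    alphabet,
    beam_size=32,
    decoder_type=FlashlightDecoderState.DecoderType.LexiconBased,
    token_type=FlashlightDecoderState.TokenType.Aggregate,
    lm_tokens=lm_tokens,
    scorer=scorer,
    num_results=5,
)
print(results)  # list of FlashlightOutput structures
```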

View File

@ -17,7 +17,7 @@ else:
ARGS = [
"-fPIC",
"-DKENLM_MAX_ORDER=6",
"-std=c++11",
"-std=c++14",
"-Wno-unused-local-typedefs",
"-Wno-sign-compare",
]
@ -32,6 +32,7 @@ INCLUDES = [
OPENFST_DIR + "/src/include",
"third_party/ThreadPool",
"third_party/object_pool",
"third_party/flashlight",
]
KENLM_FILES = (
@ -40,7 +41,7 @@ KENLM_FILES = (
+ glob.glob("../kenlm/util/double-conversion/*.cc")
)
KENLM_FILES += glob.glob(OPENFST_DIR + "/src/lib/*.cc")
OPENFST_FILES = glob.glob(OPENFST_DIR + "/src/lib/*.cc")
KENLM_FILES = [
fn
@ -50,6 +51,22 @@ KENLM_FILES = [
)
]
FLASHLIGHT_FILES = [
"third_party/flashlight/flashlight/lib/common/String.cpp",
"third_party/flashlight/flashlight/lib/common/System.cpp",
"third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.cpp",
"third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.cpp",
"third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.cpp",
"third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.cpp",
"third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.cpp",
"third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp",
"third_party/flashlight/flashlight/lib/text/decoder/Utils.cpp",
"third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.cpp",
"third_party/flashlight/flashlight/lib/text/dictionary/Utils.cpp",
]
THIRD_PARTY_FILES = KENLM_FILES + OPENFST_FILES + FLASHLIGHT_FILES
CTC_DECODER_FILES = [
"ctc_beam_search_decoder.cpp",
"scorer.cpp",

View File

@ -12,6 +12,12 @@
#include "fst/fstlib.h"
#include "path_trie.h"
#include "flashlight/lib/text/dictionary/Dictionary.h"
#include "flashlight/lib/text/decoder/Trie.h"
#include "flashlight/lib/text/decoder/LexiconDecoder.h"
#include "flashlight/lib/text/decoder/LexiconFreeDecoder.h"
namespace flt = fl::lib::text;
int
DecoderState::init(const Alphabet& alphabet,
@ -264,6 +270,180 @@ DecoderState::decode(size_t num_results) const
return outputs;
}
int
FlashlightDecoderState::init(
const Alphabet& alphabet,
size_t beam_size,
double beam_threshold,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
FlashlightDecoderState::LMTokenType token_type,
flt::Dictionary lm_tokens,
FlashlightDecoderState::DecoderType decoder_type,
double silence_score,
bool merge_with_log_add,
FlashlightDecoderState::CriterionType criterion_type,
std::vector<float> transitions)
{
// Lexicon-free decoder must use single-token based LM
if (decoder_type == LexiconFree) {
assert(token_type == Single);
}
// Build lexicon index to LM index map
if (!lm_tokens.contains("<unk>")) {
lm_tokens.addEntry("<unk>");
}
ext_scorer->load_words(lm_tokens);
lm_tokens_ = lm_tokens;
// Convert our criterion type to Flashlight type
flt::CriterionType flt_criterion;
switch (criterion_type) {
case ASG: flt_criterion = flt::CriterionType::ASG; break;
case CTC: flt_criterion = flt::CriterionType::CTC; break;
case S2S: flt_criterion = flt::CriterionType::S2S; break;
default: assert(false);
}
// Build Trie
std::shared_ptr<flt::Trie> trie = nullptr;
auto startState = ext_scorer->start(false);
if (token_type == Aggregate || decoder_type == LexiconBased) {
trie = std::make_shared<flt::Trie>(lm_tokens.indexSize(), alphabet.GetSpaceLabel());
for (int i = 0; i < lm_tokens.entrySize(); ++i) {
const std::string entry = lm_tokens.getEntry(i);
if (entry[0] == '<') { // don't insert <s>, </s> and <unk>
continue;
}
float score = -1;
if (token_type == Aggregate) {
flt::LMStatePtr dummyState;
std::tie(dummyState, score) = ext_scorer->score(startState, i);
}
std::vector<unsigned int> encoded = alphabet.Encode(entry);
std::vector<int> encoded_s(encoded.begin(), encoded.end());
trie->insert(encoded_s, i, score);
}
// Smear trie
trie->smear(flt::SmearingMode::MAX);
}
// Query unknown token score
int unknown_word_index = lm_tokens.getIndex("<unk>");
float unknown_score = -std::numeric_limits<float>::infinity();
if (token_type == Aggregate) {
std::tie(std::ignore, unknown_score) =
ext_scorer->score(startState, unknown_word_index);
}
// Make sure conversions from uint to int below don't trip us
assert(beam_size < INT_MAX);
assert(cutoff_top_n < INT_MAX);
if (decoder_type == LexiconBased) {
flt::LexiconDecoderOptions opts;
opts.beamSize = static_cast<int>(beam_size);
opts.beamSizeToken = static_cast<int>(cutoff_top_n);
opts.beamThreshold = beam_threshold;
opts.lmWeight = ext_scorer->alpha;
opts.wordScore = ext_scorer->beta;
opts.unkScore = unknown_score;
opts.silScore = silence_score;
opts.logAdd = merge_with_log_add;
opts.criterionType = flt_criterion;
decoder_impl_.reset(new flt::LexiconDecoder(
opts,
trie,
ext_scorer,
alphabet.GetSpaceLabel(), // silence index
alphabet.GetSize(), // blank index
unknown_word_index,
transitions,
token_type == Single)
);
} else {
flt::LexiconFreeDecoderOptions opts;
opts.beamSize = static_cast<int>(beam_size);
opts.beamSizeToken = static_cast<int>(cutoff_top_n);
opts.beamThreshold = beam_threshold;
opts.lmWeight = ext_scorer->alpha;
opts.silScore = silence_score;
opts.logAdd = merge_with_log_add;
opts.criterionType = flt_criterion;
decoder_impl_.reset(new flt::LexiconFreeDecoder(
opts,
ext_scorer,
alphabet.GetSpaceLabel(), // silence index
alphabet.GetSize(), // blank index
transitions)
);
}
// Init decoder for stream
decoder_impl_->decodeBegin();
return 0;
}
void
FlashlightDecoderState::next(
const double *probs,
int time_dim,
int class_dim)
{
std::vector<float> probs_f(probs, probs + (time_dim * class_dim));
decoder_impl_->decodeStep(probs_f.data(), time_dim, class_dim);
}
FlashlightOutput
FlashlightDecoderState::intermediate(bool prune)
{
flt::DecodeResult result = decoder_impl_->getBestHypothesis();
std::vector<int> valid_words;
for (int w : result.words) {
if (w != -1) {
valid_words.push_back(w);
}
}
FlashlightOutput ret;
ret.aggregate_score = result.score;
ret.acoustic_model_score = result.amScore;
ret.language_model_score = result.lmScore;
ret.words = lm_tokens_.mapIndicesToEntries(valid_words); // how does this interact with token-based decoding
ret.tokens = result.tokens;
if (prune) {
decoder_impl_->prune();
}
return ret;
}
std::vector<FlashlightOutput>
FlashlightDecoderState::decode(size_t num_results)
{
decoder_impl_->decodeEnd();
std::vector<flt::DecodeResult> flt_results = decoder_impl_->getAllFinalHypothesis();
std::vector<FlashlightOutput> ret;
for (auto result : flt_results) {
std::vector<int> valid_words;
for (int w : result.words) {
if (w != -1) {
valid_words.push_back(w);
}
}
FlashlightOutput out;
out.aggregate_score = result.score;
out.acoustic_model_score = result.amScore;
out.language_model_score = result.lmScore;
out.words = lm_tokens_.mapIndicesToEntries(valid_words); // how does this interact with token-based decoding
out.tokens = result.tokens;
ret.push_back(out);
}
decoder_impl_.reset(nullptr);
return ret;
}
std::vector<Output> ctc_beam_search_decoder(
const double *probs,
int time_dim,
@ -328,3 +508,104 @@ ctc_beam_search_decoder_batch(
}
return batch_results;
}
std::vector<FlashlightOutput>
flashlight_beam_search_decoder(
const double* probs,
int time_dim,
int class_dim,
const Alphabet& alphabet,
size_t beam_size,
double beam_threshold,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
FlashlightDecoderState::LMTokenType token_type,
const std::vector<std::string>& lm_tokens,
FlashlightDecoderState::DecoderType decoder_type,
double silence_score,
bool merge_with_log_add,
FlashlightDecoderState::CriterionType criterion_type,
std::vector<float> transitions,
size_t num_results)
{
VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model.");
flt::Dictionary tokens_dict;
for (auto str : lm_tokens) {
tokens_dict.addEntry(str);
}
FlashlightDecoderState state;
state.init(
alphabet,
beam_size,
beam_threshold,
cutoff_top_n,
ext_scorer,
token_type,
tokens_dict,
decoder_type,
silence_score,
merge_with_log_add,
criterion_type,
transitions);
state.next(probs, time_dim, class_dim);
return state.decode(num_results);
}
std::vector<std::vector<FlashlightOutput>>
flashlight_beam_search_decoder_batch(
const double *probs,
int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet& alphabet,
size_t beam_size,
double beam_threshold,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
FlashlightDecoderState::LMTokenType token_type,
const std::vector<std::string>& lm_tokens,
FlashlightDecoderState::DecoderType decoder_type,
double silence_score,
bool merge_with_log_add,
FlashlightDecoderState::CriterionType criterion_type,
std::vector<float> transitions,
size_t num_processes,
size_t num_results)
{
VALID_CHECK_GT(num_processes, 0, "num_processes must be positive!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
ThreadPool pool(num_processes);
// enqueue the tasks of decoding
std::vector<std::future<std::vector<FlashlightOutput>>> res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(flashlight_beam_search_decoder,
&probs[i*time_dim*class_dim],
seq_lengths[i],
class_dim,
alphabet,
beam_size,
beam_threshold,
cutoff_top_n,
ext_scorer,
token_type,
lm_tokens,
decoder_type,
silence_score,
merge_with_log_add,
criterion_type,
transitions,
num_results));
}
// get decoding results
std::vector<std::vector<FlashlightOutput>> batch_results;
for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}

View File

@ -9,7 +9,10 @@
#include "output.h"
#include "alphabet.h"
class DecoderState {
#include "flashlight/lib/text/decoder/Decoder.h"
class DecoderState
{
int abs_time_step_;
int space_id_;
int blank_id_;
@ -76,6 +79,89 @@ public:
std::vector<Output> decode(size_t num_results=1) const;
};
class FlashlightDecoderState
{
public:
FlashlightDecoderState() = default;
~FlashlightDecoderState() = default;
// Disallow copying
FlashlightDecoderState(const FlashlightDecoderState&) = delete;
FlashlightDecoderState& operator=(FlashlightDecoderState&) = delete;
enum LMTokenType {
Single // LM units == AM units (character/byte LM)
,Aggregate // LM units != AM units (word LM)
};
enum DecoderType {
LexiconBased
,LexiconFree
};
enum CriterionType {
ASG = 0
,CTC = 1
,S2S = 2
};
/* Initialize beam search decoder
*
* Parameters:
* alphabet: The alphabet.
* beam_size: The width of beam search.
* beam_threshold: Pruning threshold on candidate hypothesis scores.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* Zero on success, non-zero on failure.
*/
int init(const Alphabet& alphabet,
size_t beam_size,
double beam_threshold,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
FlashlightDecoderState::LMTokenType token_type,
fl::lib::text::Dictionary lm_tokens,
FlashlightDecoderState::DecoderType decoder_type,
double silence_score,
bool merge_with_log_add,
FlashlightDecoderState::CriterionType criterion_type,
std::vector<float> transitions);
/* Send data to the decoder
*
* Parameters:
* probs: Flat matrix of probabilities over the alphabet, one row of
* class_dim values per time step.
* time_dim: Number of timesteps.
* class_dim: Number of classes (alphabet length + 1 for the CTC blank label).
*/
void next(const double *probs,
int time_dim,
int class_dim);
/* Return current best hypothesis, optionally pruning hypothesis space */
FlashlightOutput intermediate(bool prune = true);
/* Get up to num_results transcriptions from current decoder state.
*
* Parameters:
* num_results: Number of hypotheses to return.
*
* Return:
* A vector of FlashlightOutput hypotheses, in descending order of score.
*/
std::vector<FlashlightOutput> decode(size_t num_results = 1);
private:
fl::lib::text::Dictionary lm_tokens_;
std::unique_ptr<fl::lib::text::Decoder> decoder_impl_;
};
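To make the intended call sequence concrete, here is a minimal streaming sketch (not part of the diff). It assumes an `Alphabet alphabet`, a loaded `std::shared_ptr<Scorer> scorer`, an `fl::lib::text::Dictionary lm_tokens` built from the scorer vocabulary, and a flat buffer `probs` of `n_frames * (alphabet.GetSize() + 1)` acoustic probabilities; the numeric values are illustrative defaults, not tuned settings.

```cpp
// Sketch: streaming use of FlashlightDecoderState. This mirrors what
// flashlight_beam_search_decoder() below does internally in a single call.
std::vector<FlashlightOutput>
decode_stream(const Alphabet& alphabet,
              std::shared_ptr<Scorer> scorer,
              fl::lib::text::Dictionary lm_tokens,
              const double* probs,
              int n_frames)
{
  FlashlightDecoderState state;
  state.init(alphabet,
             /* beam_size */ 32,
             /* beam_threshold */ 25.0,
             /* cutoff_top_n */ 40,
             scorer,
             FlashlightDecoderState::Aggregate,     // word-level LM
             lm_tokens,
             FlashlightDecoderState::LexiconBased,
             /* silence_score */ 0.0,
             /* merge_with_log_add */ false,
             FlashlightDecoderState::CTC,
             /* transitions */ std::vector<float>());

  // Feed a chunk of acoustic output and peek at the current best hypothesis.
  state.next(probs, n_frames, static_cast<int>(alphabet.GetSize()) + 1);
  FlashlightOutput best = state.intermediate(/* prune */ true);
  (void)best;

  // Finish the stream and collect the final n-best list.
  return state.decode(/* num_results */ 1);
}
```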
/* CTC Beam Search Decoder
* Parameters:
@ -146,4 +232,86 @@ ctc_beam_search_decoder_batch(
std::unordered_map<std::string, float> hot_words,
size_t num_results=1);
/* Flashlight Beam Search Decoder
* Parameters:
* probs: 2-D vector where each element is a vector of probabilities
* over alphabet of one time step.
* time_dim: Number of timesteps.
* class_dim: Alphabet length (plus 1 for the CTC blank label).
* alphabet: The alphabet.
* beam_size: The width of beam search.
* beam_threshold: Pruning threshold on candidate hypothesis scores.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* num_results: Number of beams to return.
* Return:
* A vector of FlashlightOutput hypotheses, in descending order of score.
*/
std::vector<FlashlightOutput>
flashlight_beam_search_decoder(
const double* probs,
int time_dim,
int class_dim,
const Alphabet& alphabet,
size_t beam_size,
double beam_threshold,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
FlashlightDecoderState::LMTokenType token_type,
const std::vector<std::string>& lm_tokens,
FlashlightDecoderState::DecoderType decoder_type,
double silence_score,
bool merge_with_log_add,
FlashlightDecoderState::CriterionType criterion_type,
std::vector<float> transitions,
size_t num_results);
/* Flashlight Beam Search Decoder for batch data
* Parameters:
* probs: 3-D vector where each element is a 2-D vector that can be used
* by flashlight_beam_search_decoder().
* alphabet: The alphabet.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* beam_threshold: Pruning threshold on candidate hypothesis scores.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* num_results: Number of beams to return.
* Return:
* A 2-D vector where each element is a vector of beam search decoding
* result for one audio sample.
*/
std::vector<std::vector<FlashlightOutput>>
flashlight_beam_search_decoder_batch(
const double* probs,
int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet& alphabet,
size_t beam_size,
double beam_threshold,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
FlashlightDecoderState::LMTokenType token_type,
const std::vector<std::string>& lm_tokens,
FlashlightDecoderState::DecoderType decoder_type,
double silence_score,
bool merge_with_log_add,
FlashlightDecoderState::CriterionType criterion_type,
std::vector<float> transitions,
size_t num_processes,
size_t num_results);
#endif // CTC_BEAM_SEARCH_DECODER_H_

View File

@ -12,4 +12,12 @@ struct Output {
std::vector<unsigned int> timesteps;
};
struct FlashlightOutput {
double aggregate_score;
double acoustic_model_score;
double language_model_score;
std::vector<std::string> words;
std::vector<int> tokens;
};
#endif // OUTPUT_H_

View File

@ -1,6 +1,7 @@
#ifdef _MSC_VER
#include <stdlib.h>
#include <io.h>
#define NOMINMAX
#include <windows.h>
#define R_OK 4 /* Read permission. */
@ -17,16 +18,27 @@
#include <iostream>
#include <fstream>
#include "lm/config.hh"
#include "lm/model.hh"
#include "lm/state.hh"
#include "util/string_piece.hh"
#include "kenlm/lm/config.hh"
#include "kenlm/lm/model.hh"
#include "kenlm/lm/state.hh"
#include "kenlm/lm/word_index.hh"
#include "kenlm/util/string_piece.hh"
#include "decoder_utils.h"
using namespace fl::lib::text;
static const int32_t MAGIC = 'TRIE';
static const int32_t FILE_VERSION = 6;
Scorer::Scorer()
{
}
Scorer::~Scorer()
{
}
int
Scorer::init(const std::string& lm_path,
const Alphabet& alphabet)
@ -347,3 +359,54 @@ void Scorer::fill_dictionary(const std::unordered_set<std::string>& vocabulary)
std::unique_ptr<FstType> converted(new FstType(*new_dict));
this->dictionary = std::move(converted);
}
LMStatePtr
Scorer::start(bool startWithNothing)
{
auto outState = std::make_shared<KenLMState>();
if (startWithNothing) {
language_model_->NullContextWrite(outState->ken());
} else {
language_model_->BeginSentenceWrite(outState->ken());
}
return outState;
}
std::pair<LMStatePtr, float>
Scorer::score(const LMStatePtr& state,
const int usrTokenIdx)
{
if (usrTokenIdx < 0 || usrTokenIdx >= usrToLmIdxMap_.size()) {
throw std::runtime_error(
"[Scorer] Invalid user token index: " + std::to_string(usrTokenIdx));
}
auto inState = std::static_pointer_cast<KenLMState>(state);
auto outState = inState->child<KenLMState>(usrTokenIdx);
float score = language_model_->BaseScore(
inState->ken(), usrToLmIdxMap_[usrTokenIdx], outState->ken());
return std::make_pair(std::move(outState), score);
}
std::pair<LMStatePtr, float>
Scorer::finish(const LMStatePtr& state)
{
auto inState = std::static_pointer_cast<KenLMState>(state);
auto outState = inState->child<KenLMState>(-1);
float score = language_model_->BaseScore(
inState->ken(),
language_model_->BaseVocabulary().EndSentence(),
outState->ken()
);
return std::make_pair(std::move(outState), score);
}
void
Scorer::load_words(const Dictionary& word_dict)
{
const auto& vocab = language_model_->BaseVocabulary();
usrToLmIdxMap_.resize(word_dict.indexSize());
for (int i = 0; i < word_dict.indexSize(); ++i) {
usrToLmIdxMap_[i] = vocab.Index(word_dict.getEntry(i));
}
}

View File

@ -7,9 +7,7 @@
#include <unordered_set>
#include <vector>
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "util/string_piece.hh"
#include "flashlight/lib/text/decoder/lm/KenLM.h"
#include "path_trie.h"
#include "alphabet.h"
@ -27,12 +25,12 @@ const std::string END_TOKEN = "</s>";
* Scorer scorer(alpha, beta, "path_of_language_model");
* scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
*/
class Scorer {
class Scorer : public fl::lib::text::LM {
public:
using FstType = PathTrie::FstType;
Scorer() = default;
~Scorer() = default;
Scorer();
~Scorer();
// disallow copying
Scorer(const Scorer&) = delete;
@ -94,6 +92,29 @@ public:
// pointer to the dictionary of FST
std::unique_ptr<FstType> dictionary;
// ---------------
// fl::lib::text::LM methods
/* Initialize or reset language model state */
fl::lib::text::LMStatePtr start(bool startWithNothing);
/**
* Query the language model given input state and a specific token, return a
* new language model state and score.
*/
std::pair<fl::lib::text::LMStatePtr, float> score(
const fl::lib::text::LMStatePtr& state,
const int usrTokenIdx);
/* Query the language model and finish decoding. */
std::pair<fl::lib::text::LMStatePtr, float> finish(const fl::lib::text::LMStatePtr& state);
// ---------------
// fl::lib::text helper
// Must be called before use of this Scorer with Flashlight APIs.
void load_words(const fl::lib::text::Dictionary& word_dict);
protected:
// necessary setup after setting alphabet
void setup_char_map();
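A rough sketch of driving the LM methods added above (assumed context: `scorer` is a loaded `std::shared_ptr<Scorer>`, `word_dict` is the `fl::lib::text::Dictionary` previously passed to `load_words`, and the helper name `score_sentence` is illustrative only):

```cpp
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "scorer.h"

// Sketch: accumulate the LM score of a word sequence through the
// fl::lib::text::LM interface implemented by Scorer. load_words() must
// have been called with `word_dict` beforehand.
float score_sentence(std::shared_ptr<Scorer> scorer,
                     const fl::lib::text::Dictionary& word_dict,
                     const std::vector<std::string>& words)
{
  fl::lib::text::LMStatePtr state = scorer->start(/* startWithNothing */ false);
  float total = 0.0f;
  for (const std::string& w : words) {
    float s = 0.0f;
    std::tie(state, s) = scorer->score(state, word_dict.getIndex(w));
    total += s;
  }
  float eos = 0.0f;
  std::tie(state, eos) = scorer->finish(state);  // end-of-sentence score
  return total + eos;
}
```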

View File

@ -70,7 +70,7 @@ third_party_build = "third_party.{}".format(archive_ext)
ctc_decoder_build = "first_party.{}".format(archive_ext)
maybe_rebuild(KENLM_FILES, third_party_build, build_dir)
maybe_rebuild(THIRD_PARTY_FILES, third_party_build, build_dir)
maybe_rebuild(CTC_DECODER_FILES, ctc_decoder_build, build_dir)
decoder_module = Extension(
@ -96,7 +96,9 @@ class BuildExtFirst(build):
setup(
name="coqui_stt_ctcdecoder",
version=project_version,
description="""DS CTC decoder""",
description="Coqui STT Python decoder package.",
long_description="Documentation available at `stt.readthedocs.io <https://stt.readthedocs.io/en/latest/Decoder-API.html>`_",
long_description_content_type="text/x-rst; charset=UTF-8",
cmdclass={"build": BuildExtFirst},
ext_modules=[decoder_module],
package_dir={"coqui_stt_ctcdecoder": "."},

View File

@ -20,9 +20,12 @@ import_array();
namespace std {
%template(StringVector) vector<string>;
%template(FloatVector) vector<float>;
%template(UnsignedIntVector) vector<unsigned int>;
%template(OutputVector) vector<Output>;
%template(OutputVectorVector) vector<vector<Output>>;
%template(FlashlightOutputVector) vector<FlashlightOutput>;
%template(FlashlightOutputVectorVector) vector<vector<FlashlightOutput>>;
%template(Map) unordered_map<string, float>;
}
@ -36,6 +39,7 @@ namespace std {
%ignore Scorer::dictionary;
%include "third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.h"
%include "../alphabet.h"
%include "output.h"
%include "scorer.h"
@ -45,13 +49,5 @@ namespace std {
%constant const char* __git_version__ = ds_git_version();
// Import only the error code enum definitions from coqui-stt.h
// We can't just do |%ignore "";| here because it affects this file globally (even
// files %include'd above). That causes SWIG to lose destructor information and
// leads to leaks of the wrapper objects.
// Instead we ignore functions and classes (structs), which are the only other
// things in coqui-stt.h. If we add some new construct to coqui-stt.h we need
// to update the ignore rules here to avoid exposing unwanted APIs in the decoder
// package.
%rename("$ignore", %$isfunction) "";
%rename("$ignore", %$isclass) "";
#define SWIG_ERRORS_ONLY
%include "../coqui-stt.h"

View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) Facebook, Inc. and its affiliates.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,115 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "flashlight/lib/common/String.h"
#include <sys/types.h>
#include <array>
#include <cstdlib>
#include <ctime>
#include <functional>
static constexpr const char* kSpaceChars = "\t\n\v\f\r ";
namespace fl {
namespace lib {
std::string trim(const std::string& str) {
auto i = str.find_first_not_of(kSpaceChars);
if (i == std::string::npos) {
return "";
}
auto j = str.find_last_not_of(kSpaceChars);
if (j == std::string::npos || i > j) {
return "";
}
return str.substr(i, j - i + 1);
}
void replaceAll(
std::string& str,
const std::string& from,
const std::string& repl) {
if (from.empty()) {
return;
}
size_t pos = 0;
while ((pos = str.find(from, pos)) != std::string::npos) {
str.replace(pos, from.length(), repl);
pos += repl.length();
}
}
bool startsWith(const std::string& input, const std::string& pattern) {
return (input.find(pattern) == 0);
}
bool endsWith(const std::string& input, const std::string& pattern) {
if (pattern.size() > input.size()) {
return false;
}
return std::equal(pattern.rbegin(), pattern.rend(), input.rbegin());
}
template <bool Any, typename Delim>
static std::vector<std::string> splitImpl(
const Delim& delim,
std::string::size_type delimSize,
const std::string& input,
bool ignoreEmpty = false) {
std::vector<std::string> result;
std::string::size_type i = 0;
while (true) {
auto j = Any ? input.find_first_of(delim, i) : input.find(delim, i);
if (j == std::string::npos) {
break;
}
if (!(ignoreEmpty && i == j)) {
result.emplace_back(input.begin() + i, input.begin() + j);
}
i = j + delimSize;
}
if (!(ignoreEmpty && i == input.size())) {
result.emplace_back(input.begin() + i, input.end());
}
return result;
}
std::vector<std::string>
split(char delim, const std::string& input, bool ignoreEmpty) {
return splitImpl<false>(delim, 1, input, ignoreEmpty);
}
std::vector<std::string>
split(const std::string& delim, const std::string& input, bool ignoreEmpty) {
if (delim.empty()) {
throw std::invalid_argument("delimiter is empty string");
}
return splitImpl<false>(delim, delim.size(), input, ignoreEmpty);
}
std::vector<std::string> splitOnAnyOf(
const std::string& delim,
const std::string& input,
bool ignoreEmpty) {
return splitImpl<true>(delim, 1, input, ignoreEmpty);
}
std::vector<std::string> splitOnWhitespace(
const std::string& input,
bool ignoreEmpty) {
return splitOnAnyOf(kSpaceChars, input, ignoreEmpty);
}
std::string join(
const std::string& delim,
const std::vector<std::string>& vec) {
return join(delim, vec.begin(), vec.end());
}
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,130 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
namespace fl {
namespace lib {
// ============================ Types and Templates ============================
template <typename It>
using DecayDereference =
typename std::decay<decltype(*std::declval<It>())>::type;
template <typename S, typename T>
using EnableIfSame = typename std::enable_if<std::is_same<S, T>::value>::type;
// ================================== Functions
// ==================================
std::string trim(const std::string& str);
void replaceAll(
std::string& str,
const std::string& from,
const std::string& repl);
bool startsWith(const std::string& input, const std::string& pattern);
bool endsWith(const std::string& input, const std::string& pattern);
std::vector<std::string>
split(char delim, const std::string& input, bool ignoreEmpty = false);
std::vector<std::string> split(
const std::string& delim,
const std::string& input,
bool ignoreEmpty = false);
std::vector<std::string> splitOnAnyOf(
const std::string& delim,
const std::string& input,
bool ignoreEmpty = false);
std::vector<std::string> splitOnWhitespace(
const std::string& input,
bool ignoreEmpty = false);
/**
* Join a vector of `std::string` inserting `delim` in between.
*/
std::string join(const std::string& delim, const std::vector<std::string>& vec);
/**
* Join a range of `std::string` specified by iterators.
*/
template <
typename FwdIt,
typename = EnableIfSame<DecayDereference<FwdIt>, std::string>>
std::string join(const std::string& delim, FwdIt begin, FwdIt end) {
if (begin == end) {
return "";
}
size_t totalSize = begin->size();
for (auto it = std::next(begin); it != end; ++it) {
totalSize += delim.size() + it->size();
}
std::string result;
result.reserve(totalSize);
result.append(*begin);
for (auto it = std::next(begin); it != end; ++it) {
result.append(delim);
result.append(*it);
}
return result;
}
/**
* Create an output string using a `printf`-style format string and arguments.
* Safer than `sprintf` which is vulnerable to buffer overflow.
*/
template <class... Args>
std::string format(const char* fmt, Args&&... args) {
auto res = std::snprintf(nullptr, 0, fmt, std::forward<Args>(args)...);
if (res < 0) {
throw std::runtime_error(std::strerror(errno));
}
std::string buf(res, '\0');
// the size here is fine -- it's legal to write '\0' to buf[res]
auto res2 = std::snprintf(&buf[0], res + 1, fmt, std::forward<Args>(args)...);
if (res2 < 0) {
throw std::runtime_error(std::strerror(errno));
}
if (res2 != res) {
throw std::runtime_error(
"The size of the formated string is not equal to what it is expected.");
}
return buf;
}
/**
* Dedup the elements in a vector.
*/
template <class T>
void dedup(std::vector<T>& in) {
if (in.empty()) {
return;
}
auto it = std::unique(in.begin(), in.end());
in.resize(std::distance(in.begin(), it));
}
} // namespace lib
} // namespace fl
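For reference, a small sketch exercising the helpers declared above (not part of the diff; the input strings are illustrative):

```cpp
#include <string>
#include <vector>
#include "flashlight/lib/common/String.h"

int main() {
  // splitOnWhitespace drops empty tokens when ignoreEmpty is true.
  std::vector<std::string> parts =
      fl::lib::splitOnWhitespace("  hello   world ", /* ignoreEmpty */ true);
  // parts == {"hello", "world"}

  std::string joined = fl::lib::join("-", parts);   // "hello-world"

  // format() is a bounds-checked snprintf wrapper.
  std::string msg =
      fl::lib::format("%zu tokens, joined as '%s'", parts.size(), joined.c_str());

  // dedup() removes *adjacent* duplicates (std::unique semantics).
  std::vector<std::string> v = {"a", "a", "b", "a"};
  fl::lib::dedup(v);                                // {"a", "b", "a"}
  return msg.empty();
}
```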

View File

@ -0,0 +1,177 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "flashlight/lib/common/System.h"
#include <sys/stat.h>
#include <sys/types.h>
#include <array>
#include <cstdlib>
#include <ctime>
#include <functional>
#ifdef _WIN32
#include <windows.h>
#else
#include <unistd.h>
#endif
#include "flashlight/lib/common/String.h"
namespace fl {
namespace lib {
size_t getProcessId() {
#ifdef _WIN32
return GetCurrentProcessId();
#else
return ::getpid();
#endif
}
size_t getThreadId() {
#ifdef _WIN32
return GetCurrentThreadId();
#else
return std::hash<std::thread::id>()(std::this_thread::get_id());
#endif
}
std::string pathSeperator() {
#ifdef _WIN32
return "\\";
#else
return "/";
#endif
}
std::string pathsConcat(const std::string& p1, const std::string& p2) {
if (!p1.empty() && p1[p1.length() - 1] != pathSeperator()[0]) {
return (
trim(p1) + pathSeperator() + trim(p2)); // Need to add a path separator
} else {
return (trim(p1) + trim(p2));
}
}
namespace {
/**
* @path contains directories separated by the path separator.
* Returns a vector with the directories in the original order.
* Special case: a vector with a single entry containing the input is returned
* when path is one of the following: empty, ".", ".." or "/".
*/
std::vector<std::string> getDirsOnPath(const std::string& path) {
const std::string trimPath = trim(path);
if (trimPath.empty() || trimPath == pathSeperator() || trimPath == "." ||
trimPath == "..") {
return {trimPath};
}
const std::vector<std::string> tokens = split(pathSeperator(), trimPath);
std::vector<std::string> dirs;
for (const std::string& token : tokens) {
const std::string dir = trim(token);
if (!dir.empty()) {
dirs.push_back(dir);
}
}
return dirs;
}
} // namespace
std::string dirname(const std::string& path) {
std::vector<std::string> dirsOnPath = getDirsOnPath(path);
if (dirsOnPath.size() < 2) {
return ".";
} else {
dirsOnPath.pop_back();
const std::string root =
((trim(path))[0] == pathSeperator()[0]) ? pathSeperator() : "";
return root + join(pathSeperator(), dirsOnPath);
}
}
std::string basename(const std::string& path) {
std::vector<std::string> dirsOnPath = getDirsOnPath(path);
if (dirsOnPath.empty()) {
return "";
} else {
return dirsOnPath.back();
}
}
bool dirExists(const std::string& path) {
struct stat info;
if (stat(path.c_str(), &info) != 0) {
return false;
} else if (info.st_mode & S_IFDIR) {
return true;
} else {
return false;
}
}
bool fileExists(const std::string& path) {
std::ifstream fs(path, std::ifstream::in);
return fs.good();
}
std::string getEnvVar(
const std::string& key,
const std::string& dflt /*= "" */) {
char* val = getenv(key.c_str());
return val ? std::string(val) : dflt;
}
std::string getTmpPath(const std::string& filename) {
std::string tmpDir = "/tmp";
auto getTmpDir = [&tmpDir](const std::string& env) {
char* dir = std::getenv(env.c_str());
if (dir != nullptr) {
tmpDir = std::string(dir);
}
};
getTmpDir("TMPDIR");
getTmpDir("TEMP");
getTmpDir("TMP");
return tmpDir + "/fl_tmp_" + getEnvVar("USER", "unknown") + "_" + filename;
}
std::vector<std::string> getFileContent(const std::string& file) {
std::vector<std::string> data;
std::ifstream in = createInputStream(file);
std::string str;
while (std::getline(in, str)) {
data.emplace_back(str);
}
in.close();
return data;
}
std::ifstream createInputStream(const std::string& filename) {
std::ifstream file(filename);
if (!file.is_open()) {
throw std::runtime_error("Failed to open file for reading: " + filename);
}
return file;
}
std::ofstream createOutputStream(
const std::string& filename,
std::ios_base::openmode mode) {
std::ofstream file(filename, mode);
if (!file.is_open()) {
throw std::runtime_error("Failed to open file for writing: " + filename);
}
return file;
}
} // namespace lib
} // namespace fl
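To make the path-handling special cases above concrete, a brief sketch (POSIX path separator assumed; the file name is illustrative, not part of the diff):

```cpp
#include <cassert>
#include <string>
#include "flashlight/lib/common/System.h"

int main() {
  using fl::lib::basename;
  using fl::lib::dirname;

  // Ordinary absolute path: split on "/", drop the last component.
  assert(dirname("/data/models/output_graph.tflite") == "/data/models");
  assert(basename("/data/models/output_graph.tflite") == "output_graph.tflite");

  // Fewer than two components: dirname falls back to ".".
  assert(dirname("output_graph.tflite") == ".");

  // Special-cased inputs are returned as-is by getDirsOnPath().
  assert(basename("/") == "/");
  return 0;
}
```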

View File

@ -0,0 +1,86 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <chrono>
#include <fstream>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>
namespace fl {
namespace lib {
size_t getProcessId();
size_t getThreadId();
std::string pathsConcat(const std::string& p1, const std::string& p2);
std::string pathSeperator();
std::string dirname(const std::string& path);
std::string basename(const std::string& path);
bool dirExists(const std::string& path);
bool fileExists(const std::string& path);
std::string getEnvVar(const std::string& key, const std::string& dflt = "");
std::string getTmpPath(const std::string& filename);
std::vector<std::string> getFileContent(const std::string& file);
std::ifstream createInputStream(const std::string& filename);
std::ofstream createOutputStream(
const std::string& filename,
std::ios_base::openmode mode = std::ios_base::out);
/**
* Calls `f(args...)` repeatedly, retrying if an exception is thrown.
* Supports sleeps between retries, with duration starting at `initial` and
* multiplying by `factor` each retry. At most `maxIters` calls are made.
*/
template <class Fn, class... Args>
typename std::result_of<Fn(Args...)>::type retryWithBackoff(
std::chrono::duration<double> initial,
double factor,
int64_t maxIters,
Fn&& f,
Args&&... args) {
if (!(initial.count() >= 0.0)) {
throw std::invalid_argument("retryWithBackoff: bad initial");
} else if (!(factor >= 0.0)) {
throw std::invalid_argument("retryWithBackoff: bad factor");
} else if (maxIters <= 0) {
throw std::invalid_argument("retryWithBackoff: bad maxIters");
}
auto sleepSecs = initial.count();
for (int64_t i = 0; i < maxIters; ++i) {
try {
return f(std::forward<Args>(args)...);
} catch (...) {
if (i >= maxIters - 1) {
throw;
}
}
if (sleepSecs > 0.0) {
/* sleep override */
std::this_thread::sleep_for(
std::chrono::duration<double>(std::min(1e7, sleepSecs)));
}
sleepSecs *= factor;
}
throw std::logic_error("retryWithBackoff: hit unreachable");
}
} // namespace lib
} // namespace fl
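A minimal sketch of retryWithBackoff as declared above (the wrapper name openWithRetry and the parameter values are illustrative only):

```cpp
#include <chrono>
#include <fstream>
#include <string>
#include "flashlight/lib/common/System.h"

// Retry opening a file up to 3 times; sleep 0.5s before the second attempt
// and 1.0s before the third (factor 2.0). The last exception is rethrown.
std::ifstream openWithRetry(const std::string& path) {
  return fl::lib::retryWithBackoff(
      std::chrono::duration<double>(0.5),
      /* factor */ 2.0,
      /* maxIters */ 3,
      [](const std::string& p) { return fl::lib::createInputStream(p); },
      path);
}
```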

View File

@ -0,0 +1,77 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include "flashlight/lib/text/decoder/Utils.h"
namespace fl {
namespace lib {
namespace text {
enum class CriterionType { ASG = 0, CTC = 1, S2S = 2 };
/**
* Decoder supports two typical use cases:
* Offline manner:
* decoder.decode(someData) [returns all hypothesis (transcription)]
*
* Online manner:
* decoder.decodeBegin() [called only at the beginning of the stream]
* while (stream)
* decoder.decodeStep(someData) [one or more calls]
* decoder.getBestHypothesis() [returns the best hypothesis (transcription)]
* decoder.prune() [prunes the hypothesis space]
* decoder.decodeEnd() [called only at the end of the stream]
*
* Note: decoder.prune() deletes hypotheses up to the time at which it is called,
* to support online decoding. It also adds an offset to the scores in the beam
* to avoid underflow/overflow.
*
*/
class Decoder {
public:
Decoder() = default;
virtual ~Decoder() = default;
/* Initialize the decoder before starting to consume emissions */
virtual void decodeBegin() {}
/* Consume emissions in T x N chunks and increase the hypothesis space */
virtual void decodeStep(const float* emissions, int T, int N) = 0;
/* Finish up decoding after consuming all emissions */
virtual void decodeEnd() {}
/* Offline decode function, which consumes all emissions at once */
virtual std::vector<DecodeResult>
decode(const float* emissions, int T, int N) {
decodeBegin();
decodeStep(emissions, T, N);
decodeEnd();
return getAllFinalHypothesis();
}
/* Prune the hypothesis space */
virtual void prune(int lookBack = 0) = 0;
/* Get the number of decoded frames in the buffer */
virtual int nDecodedFramesInBuffer() const = 0;
/*
* Get the best completed hypothesis, which is `lookBack` frames ahead of the last
* one in the buffer. For lexicon-required LMs, a completed hypothesis means no
* partial word appears at the end.
*/
virtual DecodeResult getBestHypothesis(int lookBack = 0) const = 0;
/* Get all the final hypotheses */
virtual std::vector<DecodeResult> getAllFinalHypothesis() const = 0;
};
} // namespace text
} // namespace lib
} // namespace fl
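The comment above describes two driving patterns; a compact sketch of each follows (not part of the diff; `decoder` is any concrete subclass such as the LexiconDecoder added later in this diff, and `emissions`, `T`, `N` are assumed acoustic-model outputs):

```cpp
#include <vector>
#include "flashlight/lib/text/decoder/Decoder.h"

// Offline: hand over all emissions at once.
std::vector<fl::lib::text::DecodeResult>
runOffline(fl::lib::text::Decoder& decoder, const float* emissions, int T, int N) {
  return decoder.decode(emissions, T, N);
}

// Online: consume the stream chunk by chunk, peeking and pruning as we go.
fl::lib::text::DecodeResult
runOnline(fl::lib::text::Decoder& decoder, const float* emissions, int T, int N) {
  decoder.decodeBegin();
  decoder.decodeStep(emissions, T, N);          // call once per chunk
  fl::lib::text::DecodeResult best = decoder.getBestHypothesis();
  decoder.prune();                              // drop settled frames, rescale scores
  decoder.decodeEnd();
  return best;
}
```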

View File

@ -0,0 +1,328 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <unordered_map>
#include "flashlight/lib/text/decoder/LexiconDecoder.h"
namespace fl {
namespace lib {
namespace text {
void LexiconDecoder::decodeBegin() {
hyp_.clear();
hyp_.emplace(0, std::vector<LexiconDecoderState>());
/* note: the LM resets itself with start() */
hyp_[0].emplace_back(
0.0, lm_->start(0), lexicon_->getRoot(), nullptr, sil_, -1);
nDecodedFrames_ = 0;
nPrunedFrames_ = 0;
}
void LexiconDecoder::decodeStep(const float* emissions, int T, int N) {
int startFrame = nDecodedFrames_ - nPrunedFrames_;
// Extend hyp_ buffer
if (hyp_.size() < startFrame + T + 2) {
for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
hyp_.emplace(i, std::vector<LexiconDecoderState>());
}
}
std::vector<size_t> idx(N);
for (int t = 0; t < T; t++) {
std::iota(idx.begin(), idx.end(), 0);
if (N > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&t, &N, &emissions](const size_t& l, const size_t& r) {
return emissions[t * N + l] > emissions[t * N + r];
});
}
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconDecoderState& prevHyp : hyp_[startFrame + t]) {
const TrieNode* prevLex = prevHyp.lex;
const int prevIdx = prevHyp.token;
const float lexMaxScore =
prevLex == lexicon_->getRoot() ? 0 : prevLex->maxScore;
/* (1) Try children */
for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
int n = idx[r];
auto iter = prevLex->children.find(n);
if (iter == prevLex->children.end()) {
continue;
}
const TrieNodePtr& lex = iter->second;
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + amScore;
if (n == sil_) {
score += opt_.silScore;
}
LMStatePtr lmState;
double lmScore = 0.;
if (isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second;
}
// We eat-up a new token
if (opt_.criterionType != CriterionType::CTC || prevHyp.prevBlank ||
n != prevIdx) {
if (!lex->children.empty()) {
if (!isLmToken_) {
lmState = prevHyp.lmState;
lmScore = lex->maxScore - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore,
lmState,
lex.get(),
&prevHyp,
n,
-1,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
}
// If we got a true word
for (auto label : lex->labels) {
if (prevLex == lexicon_->getRoot() && prevHyp.token == n) {
// This is to avoid a situation where, when there is a word with a
// single-token spelling (e.g. X -> x) in the lexicon and token `x`
// is predicted in several consecutive frames, the word `X` would be
// emitted multiple times. This violates the property of CTC, where
// there must be a blank token in between to predict 2 identical
// tokens consecutively.
continue;
}
if (!isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, label);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore + opt_.wordScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
label,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
// If we got an unknown word
if (lex->labels.empty() && (opt_.unkScore > kNegativeInfinity)) {
if (!isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, unk_);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore + opt_.unkScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
unk_,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
}
/* (2) Try same lexicon node */
if (opt_.criterionType != CriterionType::CTC || !prevHyp.prevBlank ||
prevLex == lexicon_->getRoot()) {
int n = prevLex == lexicon_->getRoot() ? sil_ : prevIdx;
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + amScore;
if (n == sil_) {
score += opt_.silScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
prevLex,
&prevHyp,
n,
-1,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
/* (3) CTC only, try blank */
if (opt_.criterionType == CriterionType::CTC) {
int n = blank_;
double amScore = emissions[t * N + n];
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + amScore,
prevHyp.lmState,
prevLex,
&prevHyp,
n,
-1,
true, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
// finish proposing
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[startFrame + t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
false);
updateLMCache(lm_, hyp_[startFrame + t + 1]);
}
nDecodedFrames_ += T;
}
void LexiconDecoder::decodeEnd() {
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
bool hasNiceEnding = false;
for (const LexiconDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
if (prevHyp.lex == lexicon_->getRoot()) {
hasNiceEnding = true;
break;
}
}
for (const LexiconDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
const TrieNode* prevLex = prevHyp.lex;
const LMStatePtr& prevLmState = prevHyp.lmState;
if (!hasNiceEnding || prevHyp.lex == lexicon_->getRoot()) {
auto lmStateScorePair = lm_->finish(prevLmState);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
prevLex,
&prevHyp,
sil_,
-1,
false, // prevBlank
prevHyp.amScore,
prevHyp.lmScore + lmScore);
}
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[nDecodedFrames_ - nPrunedFrames_ + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
true);
++nDecodedFrames_;
}
std::vector<DecodeResult> LexiconDecoder::getAllFinalHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
if (finalFrame < 1) {
return std::vector<DecodeResult>{};
}
return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame);
}
DecodeResult LexiconDecoder::getBestHypothesis(int lookBack) const {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return DecodeResult();
}
const LexiconDecoderState* bestNode = findBestAncestor(
hyp_.find(nDecodedFrames_ - nPrunedFrames_)->second, lookBack);
return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack);
}
int LexiconDecoder::nHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return hyp_.find(finalFrame)->second.size();
}
int LexiconDecoder::nDecodedFramesInBuffer() const {
return nDecodedFrames_ - nPrunedFrames_ + 1;
}
void LexiconDecoder::prune(int lookBack) {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return; // Not enough decoded frames to prune
}
/* (1) Find the last emitted word in the best path */
const LexiconDecoderState* bestNode = findBestAncestor(
hyp_.find(nDecodedFrames_ - nPrunedFrames_)->second, lookBack);
if (!bestNode) {
return; // Not enough decoded frames to prune
}
int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack;
if (startFrame < 1) {
return; // Not enough decoded frames to prune
}
/* (2) Move things from back of hyp_ to front and normalize scores */
pruneAndNormalize(hyp_, startFrame, lookBack);
nPrunedFrames_ = nDecodedFrames_ - lookBack;
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,187 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <unordered_map>
#include "flashlight/lib/text/decoder/Decoder.h"
#include "flashlight/lib/text/decoder/Trie.h"
#include "flashlight/lib/text/decoder/lm/LM.h"
namespace fl {
namespace lib {
namespace text {
struct LexiconDecoderOptions {
int beamSize; // Maximum number of hypotheses we hold after each step
int beamSizeToken; // Maximum number of tokens we consider at each step
double beamThreshold; // Threshold to prune hypotheses
double lmWeight; // Weight of lm
double wordScore; // Word insertion score
double unkScore; // Unknown word insertion score
double silScore; // Silence insertion score
bool logAdd; // Whether to use logadd when merging hypotheses
CriterionType criterionType; // CTC or ASG
};
/**
* LexiconDecoderState stores information for each hypothesis in the beam.
*/
struct LexiconDecoderState {
double score; // Accumulated total score so far
LMStatePtr lmState; // Language model state
const TrieNode* lex; // Trie node in the lexicon
const LexiconDecoderState* parent; // Parent hypothesis
int token; // Label of token
int word; // Label of word (-1 if incomplete)
bool prevBlank; // If previous hypothesis is blank (for CTC only)
double amScore; // Accumulated AM score so far
double lmScore; // Accumulated LM score so far
LexiconDecoderState(
const double score,
const LMStatePtr& lmState,
const TrieNode* lex,
const LexiconDecoderState* parent,
const int token,
const int word,
const bool prevBlank = false,
const double amScore = 0,
const double lmScore = 0)
: score(score),
lmState(lmState),
lex(lex),
parent(parent),
token(token),
word(word),
prevBlank(prevBlank),
amScore(amScore),
lmScore(lmScore) {}
LexiconDecoderState()
: score(0.),
lmState(nullptr),
lex(nullptr),
parent(nullptr),
token(-1),
word(-1),
prevBlank(false),
amScore(0.),
lmScore(0.) {}
int compareNoScoreStates(const LexiconDecoderState* node) const {
int lmCmp = lmState->compare(node->lmState);
if (lmCmp != 0) {
return lmCmp > 0 ? 1 : -1;
} else if (lex != node->lex) {
return lex > node->lex ? 1 : -1;
} else if (token != node->token) {
return token > node->token ? 1 : -1;
} else if (prevBlank != node->prevBlank) {
return prevBlank > node->prevBlank ? 1 : -1;
}
return 0;
}
int getWord() const {
return word;
}
bool isComplete() const {
return !parent || parent->word >= 0;
}
};
/**
* Decoder implements a beam search decoder that finds the word transcription
* W maximizing:
*
* AM(W) + lmWeight_ * log(P_{lm}(W)) + wordScore_ * |W_known| + unkScore_ *
* |W_unknown| + silScore_ * |{i| pi_i = <sil>}|
*
* where P_{lm}(W) is the language model score, pi_i is the value for the i-th
* frame in the path leading to W and AM(W) is the (unnormalized) acoustic model
* score of the transcription W. Note that the lexicon is used to limit the
* search space and all candidate words are generated from it if unkScore is
* -inf, otherwise <UNK> will be generated for OOVs.
*/
class LexiconDecoder : public Decoder {
public:
LexiconDecoder(
LexiconDecoderOptions opt,
const TriePtr& lexicon,
const LMPtr& lm,
const int sil,
const int blank,
const int unk,
const std::vector<float>& transitions,
const bool isLmToken)
: opt_(std::move(opt)),
lexicon_(lexicon),
lm_(lm),
sil_(sil),
blank_(blank),
unk_(unk),
transitions_(transitions),
isLmToken_(isLmToken) {}
void decodeBegin() override;
void decodeStep(const float* emissions, int T, int N) override;
void decodeEnd() override;
int nHypothesis() const;
void prune(int lookBack = 0) override;
int nDecodedFramesInBuffer() const override;
DecodeResult getBestHypothesis(int lookBack = 0) const override;
std::vector<DecodeResult> getAllFinalHypothesis() const override;
protected:
LexiconDecoderOptions opt_;
// Lexicon trie to restrict beam-search decoder
TriePtr lexicon_;
LMPtr lm_;
// Index of silence label
int sil_;
// Index of blank label (for CTC)
int blank_;
// Index of unknown word
int unk_;
// matrix of transitions (for ASG criterion)
std::vector<float> transitions_;
// Whether the LM operates on the same units as the acoustic model
// (token-level, true) or on words (false)
bool isLmToken_;
// All the new hypothesis candidates (possibly more than beamSize) proposed
// based on the ones from the previous frame
std::vector<LexiconDecoderState> candidates_;
// This vector is designed for efficient sorting and merging the candidates_,
// so instead of moving around objects, we only need to sort pointers
std::vector<LexiconDecoderState*> candidatePtrs_;
// Best candidate score of current frame
double candidatesBestScore_;
// Vector of hypothesis for all the frames so far
std::unordered_map<int, std::vector<LexiconDecoderState>> hyp_;
// These 2 variables are used for online decoding, for hypothesis pruning
int nDecodedFrames_; // Total number of decoded frames.
int nPrunedFrames_; // Total number of pruned frames from hyp_.
};
} // namespace text
} // namespace lib
} // namespace fl
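To connect the scoring formula above with the options struct, a construction sketch (not part of the diff; the weights are illustrative rather than tuned, and `lexicon`, `lm` and the sil/blank/unk indices are assumed to come from the caller):

```cpp
#include <limits>
#include <vector>
#include "flashlight/lib/text/decoder/LexiconDecoder.h"

fl::lib::text::LexiconDecoder
makeDecoder(const fl::lib::text::TriePtr& lexicon,
            const fl::lib::text::LMPtr& lm,
            int sil_idx, int blank_idx, int unk_idx)
{
  fl::lib::text::LexiconDecoderOptions opts;
  opts.beamSize = 100;           // hypotheses kept per step
  opts.beamSizeToken = 25;       // tokens considered per step
  opts.beamThreshold = 25.0;
  opts.lmWeight = 0.75;          // lmWeight_ in the formula above
  opts.wordScore = 1.85;         // wordScore_ * |W_known|
  opts.unkScore = -std::numeric_limits<double>::infinity();  // forbid <UNK>
  opts.silScore = 0.0;           // silScore_ per <sil> frame
  opts.logAdd = false;
  opts.criterionType = fl::lib::text::CriterionType::CTC;

  return fl::lib::text::LexiconDecoder(
      opts, lexicon, lm, sil_idx, blank_idx, unk_idx,
      /* transitions (ASG only) */ std::vector<float>(), /* isLmToken */ false);
}
```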

View File

@ -0,0 +1,207 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include "flashlight/lib/text/decoder/LexiconFreeDecoder.h"
namespace fl {
namespace lib {
namespace text {
void LexiconFreeDecoder::decodeBegin() {
hyp_.clear();
hyp_.emplace(0, std::vector<LexiconFreeDecoderState>());
/* note: the LM resets itself with start() */
hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, sil_);
nDecodedFrames_ = 0;
nPrunedFrames_ = 0;
}
void LexiconFreeDecoder::decodeStep(const float* emissions, int T, int N) {
int startFrame = nDecodedFrames_ - nPrunedFrames_;
// Extend hyp_ buffer
if (hyp_.size() < startFrame + T + 2) {
for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
hyp_.emplace(i, std::vector<LexiconFreeDecoderState>());
}
}
std::vector<size_t> idx(N);
// Looping over all the frames
for (int t = 0; t < T; t++) {
std::iota(idx.begin(), idx.end(), 0);
if (N > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&t, &N, &emissions](const size_t& l, const size_t& r) {
return emissions[t * N + l] > emissions[t * N + r];
});
}
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconFreeDecoderState& prevHyp : hyp_[startFrame + t]) {
const int prevIdx = prevHyp.token;
for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
int n = idx[r];
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + emissions[t * N + n];
if (n == sil_) {
score += opt_.silScore;
}
if ((opt_.criterionType == CriterionType::ASG && n != prevIdx) ||
(opt_.criterionType == CriterionType::CTC && n != blank_ &&
(n != prevIdx || prevHyp.prevBlank))) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
n,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
} else if (opt_.criterionType == CriterionType::CTC && n == blank_) {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
&prevHyp,
n,
true, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
} else {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
&prevHyp,
n,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
}
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[startFrame + t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
false);
updateLMCache(lm_, hyp_[startFrame + t + 1]);
}
nDecodedFrames_ += T;
}
void LexiconFreeDecoder::decodeEnd() {
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconFreeDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
const LMStatePtr& prevLmState = prevHyp.lmState;
auto lmStateScorePair = lm_->finish(prevLmState);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
sil_,
false, // prevBlank
prevHyp.amScore,
prevHyp.lmScore + lmScore);
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[nDecodedFrames_ - nPrunedFrames_ + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
true);
++nDecodedFrames_;
}
std::vector<DecodeResult> LexiconFreeDecoder::getAllFinalHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame);
}
DecodeResult LexiconFreeDecoder::getBestHypothesis(int lookBack) const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
const LexiconFreeDecoderState* bestNode =
findBestAncestor(hyp_.find(finalFrame)->second, lookBack);
return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack);
}
int LexiconFreeDecoder::nHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return hyp_.find(finalFrame)->second.size();
}
int LexiconFreeDecoder::nDecodedFramesInBuffer() const {
return nDecodedFrames_ - nPrunedFrames_ + 1;
}
void LexiconFreeDecoder::prune(int lookBack) {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return; // Not enough decoded frames to prune
}
/* (1) Find the last emitted word in the best path */
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
const LexiconFreeDecoderState* bestNode =
findBestAncestor(hyp_.find(finalFrame)->second, lookBack);
if (!bestNode) {
return; // Not enough decoded frames to prune
}
int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack;
if (startFrame < 1) {
return; // Not enough decoded frames to prune
}
/* (2) Move things from back of hyp_ to front and normalize scores */
pruneAndNormalize(hyp_, startFrame, lookBack);
nPrunedFrames_ = nDecodedFrames_ - lookBack;
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,160 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <unordered_map>
#include "flashlight/lib/text/decoder/Decoder.h"
#include "flashlight/lib/text/decoder/lm/LM.h"
namespace fl {
namespace lib {
namespace text {
struct LexiconFreeDecoderOptions {
int beamSize; // Maximum number of hypotheses we hold after each step
int beamSizeToken; // Maximum number of tokens we consider at each step
double beamThreshold; // Threshold to prune hypotheses
double lmWeight; // Weight of lm
double silScore; // Silence insertion score
bool logAdd;
CriterionType criterionType; // CTC or ASG
};
/**
* LexiconFreeDecoderState stores information for each hypothesis in the beam.
*/
struct LexiconFreeDecoderState {
double score; // Accumulated total score so far
LMStatePtr lmState; // Language model state
const LexiconFreeDecoderState* parent; // Parent hypothesis
int token; // Label of token
bool prevBlank; // If previous hypothesis is blank (for CTC only)
double amScore; // Accumulated AM score so far
double lmScore; // Accumulated LM score so far
LexiconFreeDecoderState(
const double score,
const LMStatePtr& lmState,
const LexiconFreeDecoderState* parent,
const int token,
const bool prevBlank = false,
const double amScore = 0,
const double lmScore = 0)
: score(score),
lmState(lmState),
parent(parent),
token(token),
prevBlank(prevBlank),
amScore(amScore),
lmScore(lmScore) {}
LexiconFreeDecoderState()
: score(0),
lmState(nullptr),
parent(nullptr),
token(-1),
prevBlank(false),
amScore(0.),
lmScore(0.) {}
int compareNoScoreStates(const LexiconFreeDecoderState* node) const {
int lmCmp = lmState->compare(node->lmState);
if (lmCmp != 0) {
return lmCmp > 0 ? 1 : -1;
} else if (token != node->token) {
return token > node->token ? 1 : -1;
} else if (prevBlank != node->prevBlank) {
return prevBlank > node->prevBlank ? 1 : -1;
}
return 0;
}
int getWord() const {
return -1;
}
bool isComplete() const {
return true;
}
};
/**
* Decoder implements a beam search decoder that finds the word transcription
* W maximizing:
*
* AM(W) + lmWeight_ * log(P_{lm}(W)) + silScore_ * |{i| pi_i = <sil>}|
*
* where P_{lm}(W) is the language model score, pi_i is the value for the i-th
* frame in the path leading to W and AM(W) is the (unnormalized) acoustic model
* score of the transcription W. We are allowed to generate words from all the
* possible combinations of tokens.
*/
class LexiconFreeDecoder : public Decoder {
public:
LexiconFreeDecoder(
LexiconFreeDecoderOptions opt,
const LMPtr& lm,
const int sil,
const int blank,
const std::vector<float>& transitions)
: opt_(std::move(opt)),
lm_(lm),
transitions_(transitions),
sil_(sil),
blank_(blank) {}
void decodeBegin() override;
void decodeStep(const float* emissions, int T, int N) override;
void decodeEnd() override;
int nHypothesis() const;
void prune(int lookBack = 0) override;
int nDecodedFramesInBuffer() const override;
DecodeResult getBestHypothesis(int lookBack = 0) const override;
std::vector<DecodeResult> getAllFinalHypothesis() const override;
protected:
LexiconFreeDecoderOptions opt_;
LMPtr lm_;
std::vector<float> transitions_;
// All the new hypothesis candidates (possibly more than beamSize) proposed
// based on the ones from the previous frame
std::vector<LexiconFreeDecoderState> candidates_;
// This vector is designed for efficient sorting and merging the candidates_,
// so instead of moving around objects, we only need to sort pointers
std::vector<LexiconFreeDecoderState*> candidatePtrs_;
// Best candidate score of current frame
double candidatesBestScore_;
// Index of silence label
int sil_;
// Index of blank label (for CTC)
int blank_;
// Vector of hypotheses for all the frames so far
std::unordered_map<int, std::vector<LexiconFreeDecoderState>> hyp_;
// These 2 variables are used for online decoding, for hypothesis pruning
int nDecodedFrames_; // Total number of decoded frames.
int nPrunedFrames_; // Total number of pruned frames from hyp_.
};
} // namespace text
} // namespace lib
} // namespace fl
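
A minimal usage sketch for the lexicon-free decoder above, driving it over a T x N matrix of per-frame token log-probabilities with a ZeroLM (no language model). The include path for LexiconFreeDecoder.h and the CriterionType enum (expected to come from Decoder.h) are assumptions about the surrounding flashlight sources; beam sizes and token indices are placeholder values.

#include <memory>
#include <vector>
#include "flashlight/lib/text/decoder/LexiconFreeDecoder.h" // assumed header path
#include "flashlight/lib/text/decoder/lm/ZeroLM.h"

using namespace fl::lib::text;

std::vector<int> decodeNoLM(const std::vector<float>& emissions, int T, int N,
                            int silIdx, int blankIdx) {
  LexiconFreeDecoderOptions opt;
  opt.beamSize = 50;       // hypotheses kept after each frame
  opt.beamSizeToken = 25;  // tokens expanded per hypothesis
  opt.beamThreshold = 25.0; // prune hypotheses far below the frame best
  opt.lmWeight = 0.0;      // ZeroLM contributes nothing anyway
  opt.silScore = 0.0;
  opt.logAdd = false;
  opt.criterionType = CriterionType::CTC; // assumed to be declared in Decoder.h

  LexiconFreeDecoder decoder(opt, std::make_shared<ZeroLM>(), silIdx, blankIdx,
                             /* transitions (ASG only) */ {});
  decoder.decodeBegin();
  decoder.decodeStep(emissions.data(), T, N);
  decoder.decodeEnd();
  return decoder.getBestHypothesis().tokens; // per-frame token indices
}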

View File

@ -0,0 +1,179 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <tuple>
#include "flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h"
namespace fl {
namespace lib {
namespace text {
void LexiconFreeSeq2SeqDecoder::decodeStep(
const float* emissions,
int T,
int N) {
// Extend hyp_ buffer
if (hyp_.size() < maxOutputLength_ + 2) {
for (int i = hyp_.size(); i < maxOutputLength_ + 2; i++) {
hyp_.emplace(i, std::vector<LexiconFreeSeq2SeqDecoderState>());
}
}
// Start from here.
hyp_[0].clear();
hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, -1, nullptr);
// Decode frame by frame
int t = 0;
for (; t < maxOutputLength_; t++) {
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
// Batch forwarding
rawY_.clear();
rawPrevStates_.clear();
for (const LexiconFreeSeq2SeqDecoderState& prevHyp : hyp_[t]) {
const AMStatePtr& prevState = prevHyp.amState;
if (prevHyp.token == eos_) {
continue;
}
rawY_.push_back(prevHyp.token);
rawPrevStates_.push_back(prevState);
}
if (rawY_.size() == 0) {
break;
}
std::vector<std::vector<float>> amScores;
std::vector<AMStatePtr> outStates;
std::tie(amScores, outStates) =
amUpdateFunc_(emissions, N, T, rawY_, rawPrevStates_, t);
std::vector<size_t> idx(amScores.back().size());
// Generate new hypothesis
for (int hypo = 0, validHypo = 0; hypo < hyp_[t].size(); hypo++) {
const LexiconFreeSeq2SeqDecoderState& prevHyp = hyp_[t][hypo];
// Change nothing for completed hypothesis
if (prevHyp.token == eos_) {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score,
prevHyp.lmState,
&prevHyp,
eos_,
nullptr,
prevHyp.amScore,
prevHyp.lmScore);
continue;
}
const AMStatePtr& outState = outStates[validHypo];
if (!outState) {
validHypo++;
continue;
}
std::iota(idx.begin(), idx.end(), 0);
if (amScores[validHypo].size() > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&amScores, &validHypo](const size_t& l, const size_t& r) {
return amScores[validHypo][l] > amScores[validHypo][r];
});
}
for (int r = 0;
r < std::min(amScores[validHypo].size(), (size_t)opt_.beamSizeToken);
r++) {
int n = idx[r];
double amScore = amScores[validHypo][n];
if (n == eos_) { /* (1) Try eos */
auto lmStateScorePair = lm_->finish(prevHyp.lmState);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + amScore + opt_.eosScore + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
n,
nullptr,
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
} else { /* (2) Try normal token */
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + amScore + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
n,
outState,
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
}
validHypo++;
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
true);
updateLMCache(lm_, hyp_[t + 1]);
} // End of decoding
while (t > 0 && hyp_[t].empty()) {
--t;
}
hyp_[maxOutputLength_ + 1].resize(hyp_[t].size());
for (int i = 0; i < hyp_[t].size(); i++) {
hyp_[maxOutputLength_ + 1][i] = std::move(hyp_[t][i]);
}
}
std::vector<DecodeResult> LexiconFreeSeq2SeqDecoder::getAllFinalHypothesis()
const {
return getAllHypothesis(hyp_.find(maxOutputLength_ + 1)->second, hyp_.size());
}
DecodeResult LexiconFreeSeq2SeqDecoder::getBestHypothesis(
int /* unused */) const {
return getHypothesis(
hyp_.find(maxOutputLength_ + 1)->second.data(), hyp_.size());
}
void LexiconFreeSeq2SeqDecoder::prune(int /* unused */) {
return;
}
int LexiconFreeSeq2SeqDecoder::nDecodedFramesInBuffer() const {
/* unused function */
return -1;
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,141 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <memory>
#include <unordered_map>
#include "flashlight/lib/text/decoder/Decoder.h"
#include "flashlight/lib/text/decoder/lm/LM.h"
namespace fl {
namespace lib {
namespace text {
using AMStatePtr = std::shared_ptr<void>;
using AMUpdateFunc = std::function<
std::pair<std::vector<std::vector<float>>, std::vector<AMStatePtr>>(
const float*,
const int,
const int,
const std::vector<int>&,
const std::vector<AMStatePtr>&,
int&)>;
struct LexiconFreeSeq2SeqDecoderOptions {
int beamSize; // Maximum number of hypotheses we hold after each step
int beamSizeToken; // Maximum number of tokens we consider at each step
double beamThreshold; // Threshold to prune hypotheses
double lmWeight; // Weight of the LM score
double eosScore; // Score for inserting an EOS
bool logAdd; // Whether to use logadd when merging hypotheses
};
/**
* LexiconFreeSeq2SeqDecoderState stores information for each hypothesis in the
* beam.
*/
struct LexiconFreeSeq2SeqDecoderState {
double score; // Accumulated total score so far
LMStatePtr lmState; // Language model state
const LexiconFreeSeq2SeqDecoderState* parent; // Parent hypothesis
int token; // Label of token
AMStatePtr amState; // Acoustic model state
double amScore; // Accumulated AM score so far
double lmScore; // Accumulated LM score so far
LexiconFreeSeq2SeqDecoderState(
const double score,
const LMStatePtr& lmState,
const LexiconFreeSeq2SeqDecoderState* parent,
const int token,
const AMStatePtr& amState = nullptr,
const double amScore = 0,
const double lmScore = 0)
: score(score),
lmState(lmState),
parent(parent),
token(token),
amState(amState),
amScore(amScore),
lmScore(lmScore) {}
LexiconFreeSeq2SeqDecoderState()
: score(0),
lmState(nullptr),
parent(nullptr),
token(-1),
amState(nullptr),
amScore(0.),
lmScore(0.) {}
int compareNoScoreStates(const LexiconFreeSeq2SeqDecoderState* node) const {
return lmState->compare(node->lmState);
}
int getWord() const {
return -1;
}
};
/**
* Decoder implements a beam search decoder that finds the token transcription
* W maximizing:
*
* AM(W) + lmWeight_ * log(P_{lm}(W)) + eosScore_ * |W_last == EOS|
*
* where P_{lm}(W) is the language model score. The sequence of tokens is not
* constrained by a lexicon, and thus the language model must operate at
* token-level.
*
* TODO: Online decoding is not supported yet.
*
*/
class LexiconFreeSeq2SeqDecoder : public Decoder {
public:
LexiconFreeSeq2SeqDecoder(
LexiconFreeSeq2SeqDecoderOptions opt,
const LMPtr& lm,
const int eos,
AMUpdateFunc amUpdateFunc,
const int maxOutputLength)
: opt_(std::move(opt)),
lm_(lm),
eos_(eos),
amUpdateFunc_(amUpdateFunc),
maxOutputLength_(maxOutputLength) {}
void decodeStep(const float* emissions, int T, int N) override;
void prune(int lookBack = 0) override;
int nDecodedFramesInBuffer() const override;
DecodeResult getBestHypothesis(int lookBack = 0) const override;
std::vector<DecodeResult> getAllFinalHypothesis() const override;
protected:
LexiconFreeSeq2SeqDecoderOptions opt_;
LMPtr lm_;
int eos_;
AMUpdateFunc amUpdateFunc_;
std::vector<int> rawY_;
std::vector<AMStatePtr> rawPrevStates_;
int maxOutputLength_;
std::vector<LexiconFreeSeq2SeqDecoderState> candidates_;
std::vector<LexiconFreeSeq2SeqDecoderState*> candidatePtrs_;
double candidatesBestScore_;
std::unordered_map<int, std::vector<LexiconFreeSeq2SeqDecoderState>> hyp_;
};
} // namespace text
} // namespace lib
} // namespace fl
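
The AMUpdateFunc callback is the seq2seq-specific piece of this decoder: at every output step it receives the last emitted token and opaque AM state of each live hypothesis, and must return per-hypothesis token scores plus the updated states. Below is a hedged stand-in that returns uniform scores; the function name and the uniform scoring are illustrative only, and a real callback would run the attention/decoder network over `emissions`.

#include <memory>
#include <utility>
#include <vector>
#include "flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h"

using namespace fl::lib::text;

AMUpdateFunc makeUniformAmUpdateFunc(int nTokens) {
  return [nTokens](const float* /* emissions */, const int /* N */,
                   const int /* T */, const std::vector<int>& prevTokens,
                   const std::vector<AMStatePtr>& /* prevStates */,
                   int& /* timeStep */) {
    // One score vector per live hypothesis, one entry per candidate token.
    std::vector<std::vector<float>> scores(
        prevTokens.size(), std::vector<float>(nTokens, 0.0f));
    // Return a non-null state per hypothesis so the decoder keeps extending it;
    // a real implementation would return its updated decoder/attention states.
    std::vector<AMStatePtr> states(prevTokens.size(), std::make_shared<int>(0));
    return std::make_pair(std::move(scores), std::move(states));
  };
}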

View File

@ -0,0 +1,243 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <tuple>
#include "flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.h"
namespace fl {
namespace lib {
namespace text {
void LexiconSeq2SeqDecoder::decodeStep(const float* emissions, int T, int N) {
// Extend hyp_ buffer
if (hyp_.size() < maxOutputLength_ + 2) {
for (int i = hyp_.size(); i < maxOutputLength_ + 2; i++) {
hyp_.emplace(i, std::vector<LexiconSeq2SeqDecoderState>());
}
}
// Start from here.
hyp_[0].clear();
hyp_[0].emplace_back(
0.0, lm_->start(0), lexicon_->getRoot(), nullptr, -1, -1, nullptr);
auto compare = [](const LexiconSeq2SeqDecoderState& n1,
const LexiconSeq2SeqDecoderState& n2) {
return n1.score > n2.score;
};
// Decode frame by frame
int t = 0;
for (; t < maxOutputLength_; t++) {
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
// Batch forwarding
rawY_.clear();
rawPrevStates_.clear();
for (const LexiconSeq2SeqDecoderState& prevHyp : hyp_[t]) {
const AMStatePtr& prevState = prevHyp.amState;
if (prevHyp.token == eos_) {
continue;
}
rawY_.push_back(prevHyp.token);
rawPrevStates_.push_back(prevState);
}
if (rawY_.size() == 0) {
break;
}
std::vector<std::vector<float>> amScores;
std::vector<AMStatePtr> outStates;
std::tie(amScores, outStates) =
amUpdateFunc_(emissions, N, T, rawY_, rawPrevStates_, t);
std::vector<size_t> idx(amScores.back().size());
// Generate new hypothesis
for (int hypo = 0, validHypo = 0; hypo < hyp_[t].size(); hypo++) {
const LexiconSeq2SeqDecoderState& prevHyp = hyp_[t][hypo];
// Change nothing for completed hypothesis
if (prevHyp.token == eos_) {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score,
prevHyp.lmState,
prevHyp.lex,
&prevHyp,
eos_,
-1,
nullptr,
prevHyp.amScore,
prevHyp.lmScore);
continue;
}
const AMStatePtr& outState = outStates[validHypo];
if (!outState) {
validHypo++;
continue;
}
const TrieNode* prevLex = prevHyp.lex;
const float lexMaxScore =
prevLex == lexicon_->getRoot() ? 0 : prevLex->maxScore;
std::iota(idx.begin(), idx.end(), 0);
if (amScores[validHypo].size() > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&amScores, &validHypo](const size_t& l, const size_t& r) {
return amScores[validHypo][l] > amScores[validHypo][r];
});
}
for (int r = 0;
r < std::min(amScores[validHypo].size(), (size_t)opt_.beamSizeToken);
r++) {
int n = idx[r];
double amScore = amScores[validHypo][n];
/* (1) Try eos */
if (n == eos_ && (prevLex == lexicon_->getRoot())) {
auto lmStateScorePair = lm_->finish(prevHyp.lmState);
LMStatePtr lmState = lmStateScorePair.first;
double lmScore;
if (isLmToken_) {
lmScore = lmStateScorePair.second;
} else {
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + amScore + opt_.eosScore + opt_.lmWeight * lmScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
-1,
nullptr,
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
/* (2) Try normal token */
if (n != eos_) {
auto searchLex = prevLex->children.find(n);
if (searchLex != prevLex->children.end()) {
auto lex = searchLex->second;
LMStatePtr lmState;
double lmScore;
if (isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second;
} else {
// smearing
lmState = prevHyp.lmState;
lmScore = lex->maxScore - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + amScore + opt_.lmWeight * lmScore,
lmState,
lex.get(),
&prevHyp,
n,
-1,
outState,
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
// If we got a true word
if (lex->labels.size() > 0) {
for (auto word : lex->labels) {
if (!isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, word);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + amScore + opt_.wordScore +
opt_.lmWeight * lmScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
word,
outState,
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
if (isLmToken_) {
break;
}
}
}
}
}
}
validHypo++;
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
true);
updateLMCache(lm_, hyp_[t + 1]);
} // End of decoding
while (t > 0 && hyp_[t].empty()) {
--t;
}
hyp_[maxOutputLength_ + 1].resize(hyp_[t].size());
for (int i = 0; i < hyp_[t].size(); i++) {
hyp_[maxOutputLength_ + 1][i] = std::move(hyp_[t][i]);
}
}
std::vector<DecodeResult> LexiconSeq2SeqDecoder::getAllFinalHypothesis() const {
return getAllHypothesis(hyp_.find(maxOutputLength_ + 1)->second, hyp_.size());
}
DecodeResult LexiconSeq2SeqDecoder::getBestHypothesis(int /* unused */) const {
return getHypothesis(
hyp_.find(maxOutputLength_ + 1)->second.data(), hyp_.size());
}
void LexiconSeq2SeqDecoder::prune(int /* unused */) {
return;
}
int LexiconSeq2SeqDecoder::nDecodedFramesInBuffer() const {
/* unused function */
return -1;
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,165 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <memory>
#include <unordered_map>
#include "flashlight/lib/text/decoder/Decoder.h"
#include "flashlight/lib/text/decoder/Trie.h"
#include "flashlight/lib/text/decoder/lm/LM.h"
namespace fl {
namespace lib {
namespace text {
using AMStatePtr = std::shared_ptr<void>;
using AMUpdateFunc = std::function<
std::pair<std::vector<std::vector<float>>, std::vector<AMStatePtr>>(
const float*,
const int,
const int,
const std::vector<int>&,
const std::vector<AMStatePtr>&,
int&)>;
struct LexiconSeq2SeqDecoderOptions {
int beamSize; // Maximum number of hypotheses we hold after each step
int beamSizeToken; // Maximum number of tokens we consider at each step
double beamThreshold; // Threshold to prune hypotheses
double lmWeight; // Weight of the LM score
double wordScore; // Word insertion score
double eosScore; // Score for inserting an EOS
bool logAdd; // Whether to use logadd when merging hypotheses
};
/**
* LexiconSeq2SeqDecoderState stores information for each hypothesis in the
* beam.
*/
struct LexiconSeq2SeqDecoderState {
double score; // Accumulated total score so far
LMStatePtr lmState; // Language model state
const TrieNode* lex; // Current lexicon trie node (prefix of the word being built)
const LexiconSeq2SeqDecoderState* parent; // Parent hypothesis
int token; // Label of token
int word; // Label of the word emitted at this step (-1 if none)
AMStatePtr amState; // Acoustic model state
double amScore; // Accumulated AM score so far
double lmScore; // Accumulated LM score so far
LexiconSeq2SeqDecoderState(
const double score,
const LMStatePtr& lmState,
const TrieNode* lex,
const LexiconSeq2SeqDecoderState* parent,
const int token,
const int word,
const AMStatePtr& amState,
const double amScore = 0,
const double lmScore = 0)
: score(score),
lmState(lmState),
lex(lex),
parent(parent),
token(token),
word(word),
amState(amState),
amScore(amScore),
lmScore(lmScore) {}
LexiconSeq2SeqDecoderState()
: score(0),
lmState(nullptr),
lex(nullptr),
parent(nullptr),
token(-1),
word(-1),
amState(nullptr),
amScore(0.),
lmScore(0.) {}
int compareNoScoreStates(const LexiconSeq2SeqDecoderState* node) const {
int lmCmp = lmState->compare(node->lmState);
if (lmCmp != 0) {
return lmCmp > 0 ? 1 : -1;
} else if (lex != node->lex) {
return lex > node->lex ? 1 : -1;
} else if (token != node->token) {
return token > node->token ? 1 : -1;
}
return 0;
}
int getWord() const {
return word;
}
};
/**
* Decoder implements a beam search decoder that finds the token transcription
* W maximizing:
*
* AM(W) + lmWeight_ * log(P_{lm}(W)) + eosScore_ * |W_last == EOS|
*
* where P_{lm}(W) is the language model score. The transcription W is
* constrained by a lexicon. The language model may operate at word-level
* (isLmToken=false) or token-level (isLmToken=true).
*
* TODO: Online decoding is not supported yet.
*
*/
class LexiconSeq2SeqDecoder : public Decoder {
public:
LexiconSeq2SeqDecoder(
LexiconSeq2SeqDecoderOptions opt,
const TriePtr& lexicon,
const LMPtr& lm,
const int eos,
AMUpdateFunc amUpdateFunc,
const int maxOutputLength,
const bool isLmToken)
: opt_(std::move(opt)),
lm_(lm),
lexicon_(lexicon),
eos_(eos),
amUpdateFunc_(amUpdateFunc),
maxOutputLength_(maxOutputLength),
isLmToken_(isLmToken) {}
void decodeStep(const float* emissions, int T, int N) override;
void prune(int lookBack = 0) override;
int nDecodedFramesInBuffer() const override;
DecodeResult getBestHypothesis(int lookBack = 0) const override;
std::vector<DecodeResult> getAllFinalHypothesis() const override;
protected:
LexiconSeq2SeqDecoderOptions opt_;
LMPtr lm_;
TriePtr lexicon_;
int eos_;
AMUpdateFunc amUpdateFunc_;
std::vector<int> rawY_;
std::vector<AMStatePtr> rawPrevStates_;
int maxOutputLength_;
bool isLmToken_;
std::vector<LexiconSeq2SeqDecoderState> candidates_;
std::vector<LexiconSeq2SeqDecoderState*> candidatePtrs_;
double candidatesBestScore_;
std::unordered_map<int, std::vector<LexiconSeq2SeqDecoderState>> hyp_;
};
} // namespace text
} // namespace lib
} // namespace fl
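
A sketch of how the pieces of this header fit together: a smeared spelling trie, a word-level LM, and an AMUpdateFunc (such as the stand-in shown earlier) wired into a lexicon-constrained decoder. Option values are placeholders; the helper only shows the constructor arguments and call order.

#include "flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.h"

using namespace fl::lib::text;

DecodeResult decodeWithLexicon(
    const TriePtr& lexicon,       // smeared spelling trie (see Trie.h)
    const LMPtr& lm,              // e.g. a word-level KenLM
    int eosIdx,                   // index of the end-of-sentence token
    const AMUpdateFunc& amUpdate, // seq2seq scoring callback
    const float* emissions, int T, int N) {
  LexiconSeq2SeqDecoderOptions opt;
  opt.beamSize = 80;
  opt.beamSizeToken = 30;
  opt.beamThreshold = 25.0;
  opt.lmWeight = 1.5;
  opt.wordScore = 1.0; // flat bonus each time a complete word is emitted
  opt.eosScore = 0.0;
  opt.logAdd = false;

  LexiconSeq2SeqDecoder decoder(opt, lexicon, lm, eosIdx, amUpdate,
                                /* maxOutputLength */ 200,
                                /* isLmToken */ false); // word-level LM rescoring
  decoder.decodeStep(emissions, T, N);
  return decoder.getBestHypothesis();
}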

View File

@ -0,0 +1,104 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <math.h>
#include <stdlib.h>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include "flashlight/lib/text/decoder/Trie.h"
namespace fl {
namespace lib {
namespace text {
const double kMinusLogThreshold = -39.14;
const TrieNode* Trie::getRoot() const {
return root_.get();
}
TrieNodePtr
Trie::insert(const std::vector<int>& indices, int label, float score) {
TrieNodePtr node = root_;
for (int i = 0; i < indices.size(); i++) {
int idx = indices[i];
if (idx < 0 || idx >= maxChildren_) {
throw std::out_of_range(
"[Trie] Invalid letter index: " + std::to_string(idx));
}
if (node->children.find(idx) == node->children.end()) {
node->children[idx] = std::make_shared<TrieNode>(idx);
}
node = node->children[idx];
}
if (node->labels.size() < kTrieMaxLabel) {
node->labels.push_back(label);
node->scores.push_back(score);
} else {
std::cerr << "[Trie] Trie label number reached limit: " << kTrieMaxLabel
<< "\n";
}
return node;
}
TrieNodePtr Trie::search(const std::vector<int>& indices) {
TrieNodePtr node = root_;
for (auto idx : indices) {
if (idx < 0 || idx >= maxChildren_) {
throw std::out_of_range(
"[Trie] Invalid letter index: " + std::to_string(idx));
}
if (node->children.find(idx) == node->children.end()) {
return nullptr;
}
node = node->children[idx];
}
return node;
}
/* logadd */
double TrieLogAdd(double log_a, double log_b) {
double minusdif;
if (log_a < log_b) {
std::swap(log_a, log_b);
}
minusdif = log_b - log_a;
if (minusdif < kMinusLogThreshold) {
return log_a;
} else {
return log_a + log1p(exp(minusdif));
}
}
void smearNode(TrieNodePtr node, SmearingMode smearMode) {
node->maxScore = -std::numeric_limits<float>::infinity();
for (auto score : node->scores) {
node->maxScore = TrieLogAdd(node->maxScore, score);
}
for (auto child : node->children) {
auto childNode = child.second;
smearNode(childNode, smearMode);
if (smearMode == SmearingMode::LOGADD) {
node->maxScore = TrieLogAdd(node->maxScore, childNode->maxScore);
} else if (
smearMode == SmearingMode::MAX &&
childNode->maxScore > node->maxScore) {
node->maxScore = childNode->maxScore;
}
}
}
void Trie::smear(SmearingMode smearMode) {
if (smearMode != SmearingMode::NONE) {
smearNode(root_, smearMode);
}
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,95 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <memory>
#include <unordered_map>
#include <vector>
namespace fl {
namespace lib {
namespace text {
constexpr int kTrieMaxLabel = 6;
enum class SmearingMode {
NONE = 0,
MAX = 1,
LOGADD = 2,
};
/**
* TrieNode is the trie node structure in Trie.
*/
struct TrieNode {
explicit TrieNode(int idx)
: children(std::unordered_map<int, std::shared_ptr<TrieNode>>()),
idx(idx),
maxScore(0) {
labels.reserve(kTrieMaxLabel);
scores.reserve(kTrieMaxLabel);
}
// Pointers to the children of a node
std::unordered_map<int, std::shared_ptr<TrieNode>> children;
// Node index
int idx;
// Labels of words that are constructed from the given path. Note that
// `labels` is nonempty only if the current node represents a completed token.
std::vector<int> labels;
// Scores (`scores` should have the same size as `labels`)
std::vector<float> scores;
// Maximum score of all the labels if this node is a leaf,
// otherwise it will be the value after trie smearing.
float maxScore;
};
using TrieNodePtr = std::shared_ptr<TrieNode>;
/**
* Trie is used to store the lexicon for the language model. We use it to limit
* the search space of the decoder and to quickly look up scores for a given
* token sequence (completed word) or make predictions for incomplete ones based
* on smearing.
*/
class Trie {
public:
Trie(int maxChildren, int rootIdx)
: root_(std::make_shared<TrieNode>(rootIdx)), maxChildren_(maxChildren) {}
/* Return the root node pointer */
const TrieNode* getRoot() const;
/* Insert a token into trie with label */
TrieNodePtr insert(const std::vector<int>& indices, int label, float score);
/* Get the labels for a given token */
TrieNodePtr search(const std::vector<int>& indices);
/**
* Smear the trie using the valid labels inserted in it so as to get a score on
* each node (incomplete token).
* For example, if smear_mode is MAX, then for node "a" in the path "c"->"a" we
* select the maximum score among all its children, such as "c"->"a"->"t",
* "c"->"a"->"n", "c"->"a"->"r"->"e" and so on.
* This process is carried out recursively on all the nodes.
*/
void smear(const SmearingMode smear_mode);
private:
TrieNodePtr root_;
int maxChildren_; // The maximum number of children for each node. It is
// usually the number of letters or phonemes.
};
using TriePtr = std::shared_ptr<Trie>;
} // namespace text
} // namespace lib
} // namespace fl
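
A small illustration of the insert/smear/search cycle described above, using a 26-letter token set and two made-up words; the token indices, labels, and probabilities are placeholders.

#include <cmath>
#include "flashlight/lib/text/decoder/Trie.h"

using namespace fl::lib::text;

float bestPrefixScore() {
  Trie lexicon(/* maxChildren */ 26, /* rootIdx */ -1);
  // Spellings are vectors of token indices ('a' = 0 ... 'z' = 25); labels are
  // word ids and scores are word log-probabilities.
  lexicon.insert({2, 0, 19}, /* label */ 0, std::log(0.5f)); // "cat"
  lexicon.insert({2, 0, 17}, /* label */ 1, std::log(0.3f)); // "car"
  lexicon.smear(SmearingMode::MAX);
  TrieNodePtr ca = lexicon.search({2, 0});
  // After MAX smearing the prefix "ca" carries the score of its best completion.
  return ca->maxScore; // log(0.5)
}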

View File

@ -0,0 +1,15 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
namespace fl {
namespace lib {
namespace text {
// Place holder
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,275 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <algorithm>
#include <cmath>
#include <limits>
#include <unordered_map>
#include <vector>
#include "flashlight/lib/text/decoder/lm/LM.h"
namespace fl {
namespace lib {
namespace text {
/* ===================== Definitions ===================== */
const double kNegativeInfinity = -std::numeric_limits<double>::infinity();
const int kLookBackLimit = 100;
struct DecodeResult {
double score;
double amScore;
double lmScore;
std::vector<int> words;
std::vector<int> tokens;
explicit DecodeResult(int length = 0)
: score(0), amScore(0), lmScore(0), words(length, -1), tokens(length, -1) {}
};
/* ===================== Candidate-related operations ===================== */
template <class DecoderState>
void candidatesReset(
double& candidatesBestScore,
std::vector<DecoderState>& candidates,
std::vector<DecoderState*>& candidatePtrs) {
candidatesBestScore = kNegativeInfinity;
candidates.clear();
candidatePtrs.clear();
}
template <class DecoderState, class... Args>
void candidatesAdd(
std::vector<DecoderState>& candidates,
double& candidatesBestScore,
const double beamThreshold,
const double score,
const Args&... args) {
if (score >= candidatesBestScore) {
candidatesBestScore = score;
}
if (score >= candidatesBestScore - beamThreshold) {
candidates.emplace_back(score, args...);
}
}
template <class DecoderState>
void candidatesStore(
std::vector<DecoderState>& candidates,
std::vector<DecoderState*>& candidatePtrs,
std::vector<DecoderState>& outputs,
const int beamSize,
const double threshold,
const bool logAdd,
const bool returnSorted) {
outputs.clear();
if (candidates.empty()) {
return;
}
/* 1. Select valid candidates */
for (auto& candidate : candidates) {
if (candidate.score >= threshold) {
candidatePtrs.emplace_back(&candidate);
}
}
/* 2. Merge candidates */
std::sort(
candidatePtrs.begin(),
candidatePtrs.end(),
[](const DecoderState* node1, const DecoderState* node2) {
int cmp = node1->compareNoScoreStates(node2);
return cmp == 0 ? node1->score > node2->score : cmp > 0;
});
int nHypAfterMerging = 1;
for (int i = 1; i < candidatePtrs.size(); i++) {
if (candidatePtrs[i]->compareNoScoreStates(
candidatePtrs[nHypAfterMerging - 1]) != 0) {
// Distinct candidate
candidatePtrs[nHypAfterMerging] = candidatePtrs[i];
nHypAfterMerging++;
} else {
// Same candidate
double maxScore = std::max(
candidatePtrs[nHypAfterMerging - 1]->score, candidatePtrs[i]->score);
if (logAdd) {
double minScore = std::min(
candidatePtrs[nHypAfterMerging - 1]->score,
candidatePtrs[i]->score);
candidatePtrs[nHypAfterMerging - 1]->score =
maxScore + std::log1p(std::exp(minScore - maxScore));
} else {
candidatePtrs[nHypAfterMerging - 1]->score = maxScore;
}
}
}
candidatePtrs.resize(nHypAfterMerging);
/* 3. Sort and prune */
auto compareNodeScore = [](const DecoderState* node1,
const DecoderState* node2) {
return node1->score > node2->score;
};
int nValidHyp = candidatePtrs.size();
int finalSize = std::min(nValidHyp, beamSize);
if (!returnSorted && nValidHyp > beamSize) {
std::nth_element(
candidatePtrs.begin(),
candidatePtrs.begin() + finalSize,
candidatePtrs.begin() + nValidHyp,
compareNodeScore);
} else if (returnSorted) {
std::partial_sort(
candidatePtrs.begin(),
candidatePtrs.begin() + finalSize,
candidatePtrs.begin() + nValidHyp,
compareNodeScore);
}
for (int i = 0; i < finalSize; i++) {
outputs.emplace_back(std::move(*candidatePtrs[i]));
}
}
/* ===================== Result-related operations ===================== */
template <class DecoderState>
DecodeResult getHypothesis(const DecoderState* node, const int finalFrame) {
const DecoderState* node_ = node;
if (!node_) {
return DecodeResult();
}
DecodeResult res(finalFrame + 1);
res.score = node_->score;
res.amScore = node_->amScore;
res.lmScore = node_->lmScore;
int i = 0;
while (node_) {
res.words[finalFrame - i] = node_->getWord();
res.tokens[finalFrame - i] = node_->token;
node_ = node_->parent;
i++;
}
return res;
}
template <class DecoderState>
std::vector<DecodeResult> getAllHypothesis(
const std::vector<DecoderState>& finalHyps,
const int finalFrame) {
int nHyp = finalHyps.size();
std::vector<DecodeResult> res(nHyp);
for (int r = 0; r < nHyp; r++) {
const DecoderState* node = &finalHyps[r];
res[r] = getHypothesis(node, finalFrame);
}
return res;
}
template <class DecoderState>
const DecoderState* findBestAncestor(
const std::vector<DecoderState>& finalHyps,
int& lookBack) {
int nHyp = finalHyps.size();
if (nHyp == 0) {
return nullptr;
}
double bestScore = finalHyps.front().score;
const DecoderState* bestNode = finalHyps.data();
for (int r = 1; r < nHyp; r++) {
const DecoderState* node = &finalHyps[r];
if (node->score > bestScore) {
bestScore = node->score;
bestNode = node;
}
}
int n = 0;
while (bestNode && n < lookBack) {
n++;
bestNode = bestNode->parent;
}
const int maxLookBack = lookBack + kLookBackLimit;
while (bestNode) {
// Check for first emitted word.
if (bestNode->isComplete()) {
break;
}
n++;
bestNode = bestNode->parent;
if (n == maxLookBack) {
break;
}
}
lookBack = n;
return bestNode;
}
template <class DecoderState>
void pruneAndNormalize(
std::unordered_map<int, std::vector<DecoderState>>& hypothesis,
const int startFrame,
const int lookBack) {
/* 1. Move hypotheses from the back of the buffer to the front. */
for (int i = 0; i < hypothesis.size(); i++) {
if (i <= lookBack) {
hypothesis[i].swap(hypothesis[i + startFrame]);
} else {
hypothesis[i].clear();
}
}
/* 2. Avoid further back-tracking */
for (DecoderState& hyp : hypothesis[0]) {
hyp.parent = nullptr;
}
/* 3. Avoid score underflow/overflow. */
double largestScore = hypothesis[lookBack].front().score;
for (int i = 1; i < hypothesis[lookBack].size(); i++) {
if (largestScore < hypothesis[lookBack][i].score) {
largestScore = hypothesis[lookBack][i].score;
}
}
for (int i = 0; i < hypothesis[lookBack].size(); i++) {
hypothesis[lookBack][i].score -= largestScore;
}
}
/* ===================== LM-related operations ===================== */
template <class DecoderState>
void updateLMCache(const LMPtr& lm, std::vector<DecoderState>& hypothesis) {
// For ConvLM update cache
std::vector<LMStatePtr> states;
for (const auto& hyp : hypothesis) {
states.emplace_back(hyp.lmState);
}
lm->updateCache(states);
}
} // namespace text
} // namespace lib
} // namespace fl
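
A toy walk-through of the candidate helpers above with a minimal DecoderState-like type. The header path is an assumption (the file name of these templates is not shown in this change), and the scores are arbitrary.

#include <vector>
#include "flashlight/lib/text/decoder/Utils.h" // assumed path for this header

using namespace fl::lib::text;

// Just enough members for candidatesAdd/candidatesStore to compile.
struct ToyState {
  double score = 0, amScore = 0, lmScore = 0;
  const ToyState* parent = nullptr;
  int token = -1;
  ToyState() = default;
  ToyState(double s, int tok) : score(s), token(tok) {}
  int compareNoScoreStates(const ToyState* o) const {
    return token == o->token ? 0 : (token > o->token ? 1 : -1);
  }
  int getWord() const { return -1; }
};

std::vector<ToyState> toyBeamStep() {
  double best;
  std::vector<ToyState> cands;
  std::vector<ToyState*> ptrs;
  candidatesReset(best, cands, ptrs);
  candidatesAdd(cands, best, /* beamThreshold */ 10.0, /* score */ -1.0, /* token */ 3);
  candidatesAdd(cands, best, 10.0, -2.5, 3); // same no-score state: gets merged
  candidatesAdd(cands, best, 10.0, -0.5, 7);
  std::vector<ToyState> beam;
  candidatesStore(cands, ptrs, beam, /* beamSize */ 2, best - 10.0,
                  /* logAdd */ false, /* returnSorted */ true);
  // beam == { {token 7, score -0.5}, {token 3, score -1.0, duplicate merged} }
  return beam;
}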

View File

@ -0,0 +1,239 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include <cmath>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include "flashlight/lib/text/decoder/lm/ConvLM.h"
namespace fl {
namespace lib {
namespace text {
ConvLM::ConvLM(
const GetConvLmScoreFunc& getConvLmScoreFunc,
const std::string& tokenVocabPath,
const Dictionary& usrTknDict,
int lmMemory,
int beamSize,
int historySize)
: lmMemory_(lmMemory),
beamSize_(beamSize),
getConvLmScoreFunc_(getConvLmScoreFunc),
maxHistorySize_(historySize) {
if (historySize < 1) {
throw std::invalid_argument("[ConvLM] History size is too small.");
}
/* Load token vocabulary */
// Note: fairseq vocab should start with:
// <fairseq_style> - 0 <pad> - 1, </s> - 2, <unk> - 3
std::cerr << "[ConvLM]: Loading vocabulary from " << tokenVocabPath << "\n";
vocab_ = Dictionary(tokenVocabPath);
vocab_.setDefaultIndex(vocab_.getIndex(kUnkToken));
vocabSize_ = vocab_.indexSize();
std::cerr << "[ConvLM]: vocabulary size of convLM " << vocabSize_ << "\n";
/* Create index map */
usrToLmIdxMap_.resize(usrTknDict.indexSize());
for (int i = 0; i < usrTknDict.indexSize(); i++) {
auto token = usrTknDict.getEntry(i);
int lmIdx = vocab_.getIndex(token.c_str());
usrToLmIdxMap_[i] = lmIdx;
}
/* Refresh cache */
cacheIndices_.reserve(beamSize_);
cache_.resize(beamSize_, std::vector<float>(vocabSize_));
slot_.reserve(beamSize_);
batchedTokens_.resize(beamSize_ * maxHistorySize_);
}
LMStatePtr ConvLM::start(bool startWithNothing) {
cacheIndices_.clear();
auto outState = std::make_shared<ConvLMState>(1);
if (!startWithNothing) {
outState->length = 1;
outState->tokens[0] = vocab_.getIndex(kEosToken);
} else {
throw std::invalid_argument(
"[ConvLM] Only support using EOS to start the sentence");
}
return outState;
}
std::pair<LMStatePtr, float> ConvLM::scoreWithLmIdx(
const LMStatePtr& state,
const int tokenIdx) {
auto rawInState = std::static_pointer_cast<ConvLMState>(state).get();
int inStateLength = rawInState->length;
std::shared_ptr<ConvLMState> outState;
// Prepare output state
if (inStateLength == maxHistorySize_) {
outState = std::make_shared<ConvLMState>(maxHistorySize_);
std::copy(
rawInState->tokens.begin() + 1,
rawInState->tokens.end(),
outState->tokens.begin());
outState->tokens[maxHistorySize_ - 1] = tokenIdx;
} else {
outState = std::make_shared<ConvLMState>(inStateLength + 1);
std::copy(
rawInState->tokens.begin(),
rawInState->tokens.end(),
outState->tokens.begin());
outState->tokens[inStateLength] = tokenIdx;
}
// Prepare score
float score = 0;
if (tokenIdx < 0 || tokenIdx >= vocabSize_) {
throw std::out_of_range(
"[ConvLM] Invalid query word: " + std::to_string(tokenIdx));
}
if (cacheIndices_.find(rawInState) != cacheIndices_.end()) {
// Cache hit
auto cacheInd = cacheIndices_[rawInState];
if (cacheInd < 0 || cacheInd >= beamSize_) {
throw std::logic_error(
"[ConvLM] Invalid cache access: " + std::to_string(cacheInd));
}
score = cache_[cacheInd][tokenIdx];
} else {
// Cache miss
if (cacheIndices_.size() == beamSize_) {
cacheIndices_.clear();
}
int newIdx = cacheIndices_.size();
cacheIndices_[rawInState] = newIdx;
std::vector<int> lastTokenPositions = {rawInState->length - 1};
cache_[newIdx] =
getConvLmScoreFunc_(rawInState->tokens, lastTokenPositions, -1, 1);
score = cache_[newIdx][tokenIdx];
}
if (std::isnan(score) || !std::isfinite(score)) {
throw std::runtime_error(
"[ConvLM] Bad scoring from ConvLM: " + std::to_string(score));
}
return std::make_pair(std::move(outState), score);
}
std::pair<LMStatePtr, float> ConvLM::score(
const LMStatePtr& state,
const int usrTokenIdx) {
if (usrTokenIdx < 0 || usrTokenIdx >= usrToLmIdxMap_.size()) {
throw std::out_of_range(
"[KenLM] Invalid user token index: " + std::to_string(usrTokenIdx));
}
return scoreWithLmIdx(state, usrToLmIdxMap_[usrTokenIdx]);
}
std::pair<LMStatePtr, float> ConvLM::finish(const LMStatePtr& state) {
return scoreWithLmIdx(state, vocab_.getIndex(kEosToken));
}
void ConvLM::updateCache(std::vector<LMStatePtr> states) {
int longestHistory = -1, nStates = states.size();
if (nStates > beamSize_) {
throw std::invalid_argument(
"[ConvLM] Cache size too small (consider larger than beam size).");
}
// Refresh the cache, keeping LM states that did not change
slot_.clear();
slot_.resize(beamSize_, nullptr);
for (const auto& state : states) {
auto rawState = std::static_pointer_cast<ConvLMState>(state).get();
if (cacheIndices_.find(rawState) != cacheIndices_.end()) {
slot_[cacheIndices_[rawState]] = rawState;
} else if (rawState->length > longestHistory) {
// track the longest history among the states that still need to be scored
longestHistory = rawState->length;
}
}
cacheIndices_.clear();
int cacheSize = 0;
for (int i = 0; i < beamSize_; i++) {
if (!slot_[i]) {
continue;
}
cache_[cacheSize] = cache_[i];
cacheIndices_[slot_[i]] = cacheSize;
++cacheSize;
}
// Determine batchsize
if (longestHistory <= 0) {
return;
}
// Bound the batch so that batchSize * longestHistory <= lmMemory_
int maxBatchSize = lmMemory_ / longestHistory;
if (maxBatchSize > nStates) {
maxBatchSize = nStates;
}
// Run batch forward
int batchStart = 0;
while (batchStart < nStates) {
// Select batch
int nBatchStates = 0;
std::vector<int> lastTokenPositions;
for (int i = batchStart; (nBatchStates < maxBatchSize) && (i < nStates);
i++, batchStart++) {
auto rawState = std::static_pointer_cast<ConvLMState>(states[i]).get();
if (cacheIndices_.find(rawState) != cacheIndices_.end()) {
continue;
}
cacheIndices_[rawState] = cacheSize + nBatchStates;
int start = nBatchStates * longestHistory;
for (int j = 0; j < rawState->length; j++) {
batchedTokens_[start + j] = rawState->tokens[j];
}
start += rawState->length;
for (int j = 0; j < longestHistory - rawState->length; j++) {
batchedTokens_[start + j] = vocab_.getIndex(kPadToken);
}
lastTokenPositions.push_back(rawState->length - 1);
++nBatchStates;
}
if (nBatchStates == 0 && batchStart >= nStates) {
// if all states were skipped
break;
}
// Feed forward
if (nBatchStates < 1 || longestHistory < 1) {
throw std::logic_error(
"[ConvLM] Invalid batch: [" + std::to_string(nBatchStates) + " x " +
std::to_string(longestHistory) + "]");
}
auto batchedProb = getConvLmScoreFunc_(
batchedTokens_, lastTokenPositions, longestHistory, nBatchStates);
if (batchedProb.size() != vocabSize_ * nBatchStates) {
throw std::logic_error(
"[ConvLM] Batch X Vocab size " + std::to_string(batchedProb.size()) +
" mismatch with " + std::to_string(vocabSize_ * nBatchStates));
}
// Place probabilities in cache
for (int i = 0; i < nBatchStates; i++, cacheSize++) {
std::memcpy(
cache_[cacheSize].data(),
batchedProb.data() + vocabSize_ * i,
vocabSize_ * sizeof(float));
}
}
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,73 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#pragma once
#include <functional>
#include "flashlight/lib/text/decoder/lm/LM.h"
#include "flashlight/lib/text/dictionary/Defines.h"
#include "flashlight/lib/text/dictionary/Dictionary.h"
namespace fl {
namespace lib {
namespace text {
using GetConvLmScoreFunc = std::function<std::vector<
float>(const std::vector<int>&, const std::vector<int>&, int, int)>;
struct ConvLMState : LMState {
std::vector<int> tokens;
int length;
ConvLMState() : length(0) {}
explicit ConvLMState(int size)
: tokens(std::vector<int>(size)), length(size) {}
};
class ConvLM : public LM {
public:
ConvLM(
const GetConvLmScoreFunc& getConvLmScoreFunc,
const std::string& tokenVocabPath,
const Dictionary& usrTknDict,
int lmMemory = 10000,
int beamSize = 2500,
int historySize = 49);
LMStatePtr start(bool startWithNothing) override;
std::pair<LMStatePtr, float> score(
const LMStatePtr& state,
const int usrTokenIdx) override;
std::pair<LMStatePtr, float> finish(const LMStatePtr& state) override;
void updateCache(std::vector<LMStatePtr> states) override;
private:
// This cache is also not thread-safe!
int lmMemory_;
int beamSize_;
std::unordered_map<ConvLMState*, int> cacheIndices_;
std::vector<std::vector<float>> cache_;
std::vector<ConvLMState*> slot_;
std::vector<int> batchedTokens_;
Dictionary vocab_;
GetConvLmScoreFunc getConvLmScoreFunc_;
int vocabSize_;
int maxHistorySize_;
std::pair<LMStatePtr, float> scoreWithLmIdx(
const LMStatePtr& state,
const int tokenIdx);
};
} // namespace text
} // namespace lib
} // namespace fl
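
The GetConvLmScoreFunc hook is what connects ConvLM to an actual neural network. Below is a hedged stub that only demonstrates the expected output shape (batchSize rows of vocabSize log-probabilities, flattened row by row); a real callback would run the convolutional LM forward pass on the batched token histories.

#include <cmath>
#include <vector>
#include "flashlight/lib/text/decoder/lm/ConvLM.h"

using namespace fl::lib::text;

GetConvLmScoreFunc makeUniformConvLmScorer(int vocabSize) {
  return [vocabSize](const std::vector<int>& /* batchedTokens */,
                     const std::vector<int>& /* lastTokenPositions */,
                     int /* sampleSize */, int batchSize) {
    // Uniform distribution over the vocabulary for every state in the batch.
    return std::vector<float>(static_cast<size_t>(batchSize) * vocabSize,
                              -std::log(static_cast<float>(vocabSize)));
  };
}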

View File

@ -0,0 +1,74 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "flashlight/lib/text/decoder/lm/KenLM.h"
#include <stdexcept>
#include <kenlm/lm/model.hh>
namespace fl {
namespace lib {
namespace text {
KenLMState::KenLMState() : ken_(std::make_unique<lm::ngram::State>()) {}
KenLM::KenLM(const std::string& path, const Dictionary& usrTknDict) {
// Load LM
model_.reset(lm::ngram::LoadVirtual(path.c_str()));
if (!model_) {
throw std::runtime_error("[KenLM] LM loading failed.");
}
vocab_ = &model_->BaseVocabulary();
if (!vocab_) {
throw std::runtime_error("[KenLM] LM vocabulary loading failed.");
}
// Create index map
usrToLmIdxMap_.resize(usrTknDict.indexSize());
for (int i = 0; i < usrTknDict.indexSize(); i++) {
auto token = usrTknDict.getEntry(i);
int lmIdx = vocab_->Index(token.c_str());
usrToLmIdxMap_[i] = lmIdx;
}
}
LMStatePtr KenLM::start(bool startWithNothing) {
auto outState = std::make_shared<KenLMState>();
if (startWithNothing) {
model_->NullContextWrite(outState->ken());
} else {
model_->BeginSentenceWrite(outState->ken());
}
return outState;
}
std::pair<LMStatePtr, float> KenLM::score(
const LMStatePtr& state,
const int usrTokenIdx) {
if (usrTokenIdx < 0 || usrTokenIdx >= usrToLmIdxMap_.size()) {
throw std::runtime_error(
"[KenLM] Invalid user token index: " + std::to_string(usrTokenIdx));
}
auto inState = std::static_pointer_cast<KenLMState>(state);
auto outState = inState->child<KenLMState>(usrTokenIdx);
float score = model_->BaseScore(
inState->ken(), usrToLmIdxMap_[usrTokenIdx], outState->ken());
return std::make_pair(std::move(outState), score);
}
std::pair<LMStatePtr, float> KenLM::finish(const LMStatePtr& state) {
auto inState = std::static_pointer_cast<KenLMState>(state);
auto outState = inState->child<KenLMState>(-1);
float score =
model_->BaseScore(inState->ken(), vocab_->EndSentence(), outState->ken());
return std::make_pair(std::move(outState), score);
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,70 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <memory>
#include "flashlight/lib/text/decoder/lm/LM.h"
#include "flashlight/lib/text/dictionary/Dictionary.h"
// Forward declarations to avoid including KenLM headers
namespace lm {
namespace base {
struct Vocabulary;
struct Model;
} // namespace base
namespace ngram {
struct State;
} // namespace ngram
} // namespace lm
namespace fl {
namespace lib {
namespace text {
/**
* KenLMState is a state object from KenLM, which contains context length,
* indices and comparison functions:
* https://github.com/kpu/kenlm/blob/master/lm/state.hh.
*/
struct KenLMState : LMState {
KenLMState();
std::unique_ptr<lm::ngram::State> ken_;
lm::ngram::State* ken() {
return ken_.get();
}
};
/**
* KenLM extends LM by using the toolkit https://kheafield.com/code/kenlm/.
*/
class KenLM : public LM {
public:
KenLM(const std::string& path, const Dictionary& usrTknDict);
LMStatePtr start(bool startWithNothing) override;
std::pair<LMStatePtr, float> score(
const LMStatePtr& state,
const int usrTokenIdx) override;
std::pair<LMStatePtr, float> finish(const LMStatePtr& state) override;
private:
std::shared_ptr<lm::base::Model> model_;
const lm::base::Vocabulary* vocab_;
};
using KenLMPtr = std::shared_ptr<KenLM>;
} // namespace text
} // namespace lib
} // namespace fl
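
A usage sketch for the wrapper above: load a word list into a Dictionary, open a KenLM model, and accumulate the score of a sentence. `words.txt` and `lm.binary` are placeholder file names, and the word list is assumed to cover the sentence (otherwise set a default index on the dictionary first).

#include <memory>
#include <string>
#include <vector>
#include "flashlight/lib/text/decoder/lm/KenLM.h"
#include "flashlight/lib/text/dictionary/Dictionary.h"

using namespace fl::lib::text;

float scoreSentence(const std::vector<std::string>& sentence) {
  Dictionary wordDict("words.txt");                           // placeholder word list
  auto lm = std::make_shared<KenLM>("lm.binary", wordDict);   // placeholder model
  LMStatePtr state = lm->start(/* startWithNothing */ false); // <s> context
  float total = 0.f;
  for (const auto& word : sentence) {
    auto next = lm->score(state, wordDict.getIndex(word));
    state = next.first;
    total += next.second;
  }
  total += lm->finish(state).second; // add the </s> transition
  return total;                      // log10 probability, following KenLM's convention
}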

View File

@ -0,0 +1,90 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstring>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>
namespace fl {
namespace lib {
namespace text {
struct LMState {
std::unordered_map<int, std::shared_ptr<LMState>> children;
template <typename T>
std::shared_ptr<T> child(int usrIdx) {
auto s = children.find(usrIdx);
if (s == children.end()) {
auto state = std::make_shared<T>();
children[usrIdx] = state;
return state;
} else {
return std::static_pointer_cast<T>(s->second);
}
}
/* Compare two language model states. */
int compare(const std::shared_ptr<LMState>& state) const {
LMState* inState = state.get();
if (!state) {
throw std::runtime_error("a state is null");
}
if (this == inState) {
return 0;
} else if (this < inState) {
return -1;
} else {
return 1;
}
};
};
/**
* LMStatePtr is a shared LMState* tracking LM states generated during decoding.
*/
using LMStatePtr = std::shared_ptr<LMState>;
/**
* LM is a thin wrapper for language models. We abstract several common methods
* here which can be shared by KenLM, ConvLM, RNNLM, etc.
*/
class LM {
public:
/* Initialize or reset language model */
virtual LMStatePtr start(bool startWithNothing) = 0;
/**
* Query the language model given input language model state and a specific
* token, return a new language model state and score.
*/
virtual std::pair<LMStatePtr, float> score(
const LMStatePtr& state,
const int usrTokenIdx) = 0;
/* Query the language model and finish decoding. */
virtual std::pair<LMStatePtr, float> finish(const LMStatePtr& state) = 0;
/* Update LM caches (optional) given a bunch of new states generated */
virtual void updateCache(std::vector<LMStatePtr> stateIndices) {}
virtual ~LM() = default;
protected:
/* Map indices from acoustic model to LM for each valid token. */
std::vector<int> usrToLmIdxMap_;
};
using LMPtr = std::shared_ptr<LM>;
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,31 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "flashlight/lib/text/decoder/lm/ZeroLM.h"
#include <stdexcept>
namespace fl {
namespace lib {
namespace text {
LMStatePtr ZeroLM::start(bool /* unused */) {
return std::make_shared<LMState>();
}
std::pair<LMStatePtr, float> ZeroLM::score(
const LMStatePtr& state /* unused */,
const int usrTokenIdx) {
return std::make_pair(state->child<LMState>(usrTokenIdx), 0.0);
}
std::pair<LMStatePtr, float> ZeroLM::finish(const LMStatePtr& state) {
return std::make_pair(state, 0.0);
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,32 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include "flashlight/lib/text/decoder/lm/LM.h"
namespace fl {
namespace lib {
namespace text {
/**
* ZeroLM is a dummy language model class, which mimics the behaviour of a
* unigram language model but always returns 0 as the score.
*/
class ZeroLM : public LM {
public:
LMStatePtr start(bool startWithNothing) override;
std::pair<LMStatePtr, float> score(
const LMStatePtr& state,
const int usrTokenIdx) override;
std::pair<LMStatePtr, float> finish(const LMStatePtr& state) override;
};
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,21 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
namespace fl {
namespace lib {
namespace text {
constexpr const char* kUnkToken = "<unk>";
constexpr const char* kEosToken = "</s>";
constexpr const char* kPadToken = "<pad>";
constexpr const char* kMaskToken = "<mask>";
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,152 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <fstream>
#include <iostream>
#include <stdexcept>
#include "flashlight/lib/common/String.h"
#include "flashlight/lib/common/System.h"
#include "flashlight/lib/text/dictionary/Dictionary.h"
#include "flashlight/lib/text/dictionary/Utils.h"
namespace fl {
namespace lib {
namespace text {
Dictionary::Dictionary(std::istream& stream) {
createFromStream(stream);
}
Dictionary::Dictionary(const std::string& filename) {
std::ifstream stream = createInputStream(filename);
createFromStream(stream);
}
void Dictionary::createFromStream(std::istream& stream) {
if (!stream) {
throw std::runtime_error("Unable to open dictionary input stream.");
}
std::string line;
while (std::getline(stream, line)) {
if (line.empty()) {
continue;
}
auto tkns = splitOnWhitespace(line, true);
auto idx = idx2entry_.size();
// All entries on the same line map to the same index
for (const auto& tkn : tkns) {
addEntry(tkn, idx);
}
}
if (!isContiguous()) {
throw std::runtime_error("Invalid dictionary format - not contiguous");
}
}
void Dictionary::addEntry(const std::string& entry, int idx) {
if (entry2idx_.find(entry) != entry2idx_.end()) {
throw std::invalid_argument(
"Duplicate entry name in dictionary '" + entry + "'");
}
entry2idx_[entry] = idx;
if (idx2entry_.find(idx) == idx2entry_.end()) {
idx2entry_[idx] = entry;
}
}
void Dictionary::addEntry(const std::string& entry) {
// Check if the entry already exists in the dictionary
if (entry2idx_.find(entry) != entry2idx_.end()) {
throw std::invalid_argument(
"Duplicate entry in dictionary '" + entry + "'");
}
int idx = idx2entry_.size();
// Find first available index.
while (idx2entry_.find(idx) != idx2entry_.end()) {
++idx;
}
addEntry(entry, idx);
}
std::string Dictionary::getEntry(int idx) const {
auto iter = idx2entry_.find(idx);
if (iter == idx2entry_.end()) {
throw std::invalid_argument(
"Unknown index in dictionary '" + std::to_string(idx) + "'");
}
return iter->second;
}
void Dictionary::setDefaultIndex(int idx) {
defaultIndex_ = idx;
}
int Dictionary::getIndex(const std::string& entry) const {
auto iter = entry2idx_.find(entry);
if (iter == entry2idx_.end()) {
if (defaultIndex_ < 0) {
throw std::invalid_argument(
"Unknown entry in dictionary: '" + entry + "'");
} else {
return defaultIndex_;
}
}
return iter->second;
}
bool Dictionary::contains(const std::string& entry) const {
auto iter = entry2idx_.find(entry);
if (iter == entry2idx_.end()) {
return false;
}
return true;
}
size_t Dictionary::entrySize() const {
return entry2idx_.size();
}
bool Dictionary::isContiguous() const {
for (size_t i = 0; i < indexSize(); ++i) {
if (idx2entry_.find(i) == idx2entry_.end()) {
return false;
}
}
for (const auto& tknidx : entry2idx_) {
if (idx2entry_.find(tknidx.second) == idx2entry_.end()) {
return false;
}
}
return true;
}
std::vector<int> Dictionary::mapEntriesToIndices(
const std::vector<std::string>& entries) const {
std::vector<int> indices;
indices.reserve(entries.size());
for (const auto& tkn : entries) {
indices.emplace_back(getIndex(tkn));
}
return indices;
}
std::vector<std::string> Dictionary::mapIndicesToEntries(
const std::vector<int>& indices) const {
std::vector<std::string> entries;
entries.reserve(indices.size());
for (const auto& idx : indices) {
entries.emplace_back(getEntry(idx));
}
return entries;
}
size_t Dictionary::indexSize() const {
return idx2entry_.size();
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <istream>
#include <string>
#include <unordered_map>
#include <vector>
namespace fl {
namespace lib {
namespace text {
// A simple dictionary class which holds a bidirectional map
// entry (strings) <--> integer indices. Not thread-safe!
class Dictionary {
public:
// Creates an empty dictionary
Dictionary() {}
explicit Dictionary(std::istream& stream);
explicit Dictionary(const std::string& filename);
size_t entrySize() const;
size_t indexSize() const;
void addEntry(const std::string& entry, int idx);
void addEntry(const std::string& entry);
std::string getEntry(int idx) const;
void setDefaultIndex(int idx);
int getIndex(const std::string& entry) const;
bool contains(const std::string& entry) const;
// checks if all the indices are contiguous
bool isContiguous() const;
std::vector<int> mapEntriesToIndices(
const std::vector<std::string>& entries) const;
std::vector<std::string> mapIndicesToEntries(
const std::vector<int>& indices) const;
private:
// Creates a dictionary from an input stream
void createFromStream(std::istream& stream);
std::unordered_map<std::string, int> entry2idx_;
std::unordered_map<int, std::string> idx2entry_;
int defaultIndex_ = -1;
};
typedef std::unordered_map<int, Dictionary> DictionaryMap;
} // namespace text
} // namespace lib
} // namespace fl
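
A short sketch of the Dictionary behaviour documented above: entries on one line share an index, and after setDefaultIndex unknown entries fall back to it instead of throwing. The token set is made up.

#include <sstream>
#include <vector>
#include "flashlight/lib/text/dictionary/Dictionary.h"

using namespace fl::lib::text;

std::vector<int> dictionaryDemo() {
  std::istringstream stream("a\nb\nc\n<unk>\n"); // indices 0, 1, 2, 3
  Dictionary dict(stream);
  dict.setDefaultIndex(dict.getIndex("<unk>"));
  // "never-seen" maps to the default index 3 rather than throwing.
  return dict.mapEntriesToIndices({"a", "c", "never-seen"}); // {0, 2, 3}
}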

View File

@ -0,0 +1,147 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "flashlight/lib/text/dictionary/Utils.h"
#include "flashlight/lib/common/String.h"
#include "flashlight/lib/common/System.h"
#include "flashlight/lib/text/dictionary/Defines.h"
namespace fl {
namespace lib {
namespace text {
Dictionary createWordDict(const LexiconMap& lexicon) {
Dictionary dict;
for (const auto& it : lexicon) {
dict.addEntry(it.first);
}
dict.setDefaultIndex(dict.getIndex(kUnkToken));
return dict;
}
LexiconMap loadWords(const std::string& filename, int maxWords) {
LexiconMap lexicon;
std::string line;
std::ifstream infile = createInputStream(filename);
// Add at most `maxWords` words into the lexicon.
// If `maxWords` is negative then no limit is applied.
while (maxWords != lexicon.size() && std::getline(infile, line)) {
// Parse the line into two strings: word and spelling.
auto fields = splitOnWhitespace(line, true);
if (fields.size() < 2) {
throw std::runtime_error("[loadWords] Invalid line: " + line);
}
const std::string& word = fields[0];
std::vector<std::string> spelling(fields.size() - 1);
std::copy(fields.begin() + 1, fields.end(), spelling.begin());
// Add the word to the lexicon.
if (lexicon.find(word) == lexicon.end()) {
lexicon[word] = {};
}
// Add the current spelling of the word to its list of spellings.
lexicon[word].push_back(spelling);
}
// Insert unknown word.
lexicon[kUnkToken] = {};
return lexicon;
}
std::vector<std::string> splitWrd(const std::string& word) {
std::vector<std::string> tokens;
tokens.reserve(word.size());
int len = word.length();
for (int i = 0; i < len;) {
auto c = static_cast<unsigned char>(word[i]);
int curTknBytes = -1;
// UTF-8 checks, works for ASCII automatically
if ((c & 0x80) == 0) {
curTknBytes = 1;
} else if ((c & 0xE0) == 0xC0) {
curTknBytes = 2;
} else if ((c & 0xF0) == 0xE0) {
curTknBytes = 3;
} else if ((c & 0xF8) == 0xF0) {
curTknBytes = 4;
}
if (curTknBytes == -1 || i + curTknBytes > len) {
throw std::runtime_error("splitWrd: invalid UTF-8 : " + word);
}
tokens.emplace_back(word.begin() + i, word.begin() + i + curTknBytes);
i += curTknBytes;
}
return tokens;
}
std::vector<int> packReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps) {
if (tokens.empty() || maxReps <= 0) {
return tokens;
}
std::vector<int> replabelValueToIdx(maxReps + 1);
for (int i = 1; i <= maxReps; ++i) {
replabelValueToIdx[i] = dict.getIndex("<" + std::to_string(i) + ">");
}
std::vector<int> result;
int prevToken = -1;
int numReps = 0;
for (int token : tokens) {
if (token == prevToken && numReps < maxReps) {
numReps++;
} else {
if (numReps > 0) {
result.push_back(replabelValueToIdx[numReps]);
numReps = 0;
}
result.push_back(token);
prevToken = token;
}
}
if (numReps > 0) {
result.push_back(replabelValueToIdx[numReps]);
}
return result;
}
std::vector<int> unpackReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps) {
if (tokens.empty() || maxReps <= 0) {
return tokens;
}
std::unordered_map<int, int> replabelIdxToValue;
for (int i = 1; i <= maxReps; ++i) {
replabelIdxToValue.emplace(dict.getIndex("<" + std::to_string(i) + ">"), i);
}
std::vector<int> result;
int prevToken = -1;
for (int token : tokens) {
auto it = replabelIdxToValue.find(token);
if (it == replabelIdxToValue.end()) {
result.push_back(token);
prevToken = token;
} else if (prevToken != -1) {
result.insert(result.end(), it->second, prevToken);
prevToken = -1;
}
}
return result;
}
} // namespace text
} // namespace lib
} // namespace fl

View File

@ -0,0 +1,52 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "flashlight/lib/text/dictionary/Dictionary.h"
namespace fl {
namespace lib {
namespace text {
using LexiconMap =
std::unordered_map<std::string, std::vector<std::vector<std::string>>>;
Dictionary createWordDict(const LexiconMap& lexicon);
LexiconMap loadWords(const std::string& filename, int maxWords = -1);
// Split a word into tokens, e.g. "abc" -> {"a", "b", "c"}.
// Works with ASCII and UTF-8 encodings.
std::vector<std::string> splitWrd(const std::string& word);
/**
* Pack a token sequence by replacing consecutive repeats with replabels,
* e.g. "abbccc" -> "ab1c2". The tokens "1", "2", ..., `to_string(maxReps)`
* must already be in `dict`.
*/
std::vector<int> packReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps);
/**
* Unpack a token sequence by replacing replabels with repeated tokens,
* e.g. "ab1c2" -> "abbccc". The tokens "1", "2", ..., `to_string(maxReps)`
* must already be in `dict`.
*/
std::vector<int> unpackReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps);
} // namespace text
} // namespace lib
} // namespace fl
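
A round-trip sketch for the replabel helpers declared above. The tiny token dictionary and the token ids are placeholders; the only requirement is that the replabel symbols "<1>" ... "<maxReps>" are present in the dictionary.

#include <sstream>
#include <vector>
#include "flashlight/lib/text/dictionary/Dictionary.h"
#include "flashlight/lib/text/dictionary/Utils.h"

using namespace fl::lib::text;

void replabelDemo() {
  std::istringstream stream("a\nb\n<1>\n<2>\n"); // a=0, b=1, <1>=2, <2>=3
  Dictionary dict(stream);
  std::vector<int> tokens = {0, 1, 1, 1, 0};     // a b b b a
  auto packed = packReplabels(tokens, dict, /* maxReps */ 2);
  // "b" repeated three times becomes "b <2>": packed == {0, 1, 3, 0}
  auto unpacked = unpackReplabels(packed, dict, 2); // round-trips to `tokens`
  (void)unpacked;
}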

View File

@ -112,6 +112,28 @@ NODE_PLATFORM_TARGET := --target_arch=arm64 --target_platform=linux
TOOLCHAIN_LDD_OPTS := --root $(RASPBIAN)/
endif # ($(TARGET),rpi3-armv8)
# Custom: RPi 4, Ubuntu 21.10, Arm v8 (64-bit)
ifeq ($(TARGET),rpi4ub-armv8)
TOOLCHAIN_DIR ?= ${TFDIR}/bazel-$(shell basename "${TFDIR}")/external/LinaroAarch64Gcc72/bin
TOOLCHAIN ?= $(TOOLCHAIN_DIR)/aarch64-linux-gnu-
RASPBIAN ?= $(abspath $(NC_DIR)/../multistrap-ubuntu64-impish)
CFLAGS := -march=armv8-a -mtune=cortex-a72 -D_GLIBCXX_USE_CXX11_ABI=0 --sysroot $(RASPBIAN)
CXXFLAGS := $(CFLAGS)
LDFLAGS := -Wl,-rpath-link,$(RASPBIAN)/lib/aarch64-linux-gnu/ -Wl,-rpath-link,$(RASPBIAN)/usr/lib/aarch64-linux-gnu/
SOX_CFLAGS := -I$(RASPBIAN)/usr/include
SOX_LDFLAGS := $(RASPBIAN)/usr/lib/aarch64-linux-gnu/libsox.so
PYVER := $(shell python -c "import platform; maj, min, _ = platform.python_version_tuple(); print(maj+'.'+min);")
PYTHON_PACKAGES :=
PYTHON_PATH := PYTHONPATH=$(RASPBIAN)/usr/lib/python$(PYVER)/:$(RASPBIAN)/usr/lib/python3/dist-packages/
PYTHON_SYSCONFIGDATA := _PYTHON_SYSCONFIGDATA_NAME=_sysconfigdata__linux_aarch64-linux-gnu
NUMPY_INCLUDE := NUMPY_INCLUDE=$(RASPBIAN)/usr/include/python3.9/
PYTHON_PLATFORM_NAME := --plat-name linux_aarch64
NODE_PLATFORM_TARGET := --target_arch=arm64 --target_platform=linux
TOOLCHAIN_LDD_OPTS := --root $(RASPBIAN)/
endif # ($(TARGET),rpi4ub-armv8)
ifeq ($(TARGET),ios-simulator)
CFLAGS := -isysroot $(shell xcrun -sdk iphonesimulator13.5 -show-sdk-path)
SOX_CFLAGS :=

View File

@ -19,20 +19,41 @@ add_library( # Sets the name of the library.
# Provides a relative path to your source file(s).
../jni/stt_wrap.cpp )
add_library( stt-lib
SHARED
IMPORTED )
add_library(stt-lib SHARED IMPORTED)
set_target_properties(stt-lib PROPERTIES
IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/libs/${ANDROID_ABI}/libstt.so)
set_target_properties( stt-lib
PROPERTIES
IMPORTED_LOCATION
${CMAKE_SOURCE_DIR}/libs/${ANDROID_ABI}/libstt.so )
add_library(kenlm-lib SHARED IMPORTED)
set_target_properties(kenlm-lib PROPERTIES
IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/libs/${ANDROID_ABI}/libkenlm.so)
add_custom_command( TARGET stt-jni POST_BUILD
add_library(tensorflowlite-lib SHARED IMPORTED)
set_target_properties(tensorflowlite-lib PROPERTIES
IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/libs/${ANDROID_ABI}/libtensorflowlite.so)
add_library(tflitedelegates-lib SHARED IMPORTED)
set_target_properties(tflitedelegates-lib PROPERTIES
IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/libs/${ANDROID_ABI}/libtflitedelegates.so)
add_custom_command(TARGET stt-jni POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_SOURCE_DIR}/libs/${ANDROID_ABI}/libstt.so
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libstt.so )
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libstt.so)
add_custom_command(TARGET stt-jni POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_SOURCE_DIR}/libs/${ANDROID_ABI}/libkenlm.so
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libkenlm.so)
add_custom_command(TARGET stt-jni POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_SOURCE_DIR}/libs/${ANDROID_ABI}/libtensorflowlite.so
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libtensorflowlite.so)
add_custom_command(TARGET stt-jni POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_SOURCE_DIR}/libs/${ANDROID_ABI}/libtflitedelegates.so
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libtflitedelegates.so)
# Searches for a specified prebuilt library and stores the path as a
# variable. Because CMake includes system libraries in the search path by

View File

@ -1943,6 +1943,42 @@
"node_abi": 64,
"v8": "6.8"
},
"10.21.0": {
"node_abi": 64,
"v8": "6.8"
},
"10.22.0": {
"node_abi": 64,
"v8": "6.8"
},
"10.22.1": {
"node_abi": 64,
"v8": "6.8"
},
"10.23.0": {
"node_abi": 64,
"v8": "6.8"
},
"10.23.1": {
"node_abi": 64,
"v8": "6.8"
},
"10.23.2": {
"node_abi": 64,
"v8": "6.8"
},
"10.23.3": {
"node_abi": 64,
"v8": "6.8"
},
"10.24.0": {
"node_abi": 64,
"v8": "6.8"
},
"10.24.1": {
"node_abi": 64,
"v8": "6.8"
},
"11.0.0": {
"node_abi": 67,
"v8": "7.0"
@ -2115,6 +2151,86 @@
"node_abi": 72,
"v8": "7.8"
},
"12.17.0": {
"node_abi": 72,
"v8": "7.8"
},
"12.18.0": {
"node_abi": 72,
"v8": "7.8"
},
"12.18.1": {
"node_abi": 72,
"v8": "7.8"
},
"12.18.2": {
"node_abi": 72,
"v8": "7.8"
},
"12.18.3": {
"node_abi": 72,
"v8": "7.8"
},
"12.18.4": {
"node_abi": 72,
"v8": "7.8"
},
"12.19.0": {
"node_abi": 72,
"v8": "7.8"
},
"12.19.1": {
"node_abi": 72,
"v8": "7.8"
},
"12.20.0": {
"node_abi": 72,
"v8": "7.8"
},
"12.20.1": {
"node_abi": 72,
"v8": "7.8"
},
"12.20.2": {
"node_abi": 72,
"v8": "7.8"
},
"12.21.0": {
"node_abi": 72,
"v8": "7.8"
},
"12.22.0": {
"node_abi": 72,
"v8": "7.8"
},
"12.22.1": {
"node_abi": 72,
"v8": "7.8"
},
"12.22.2": {
"node_abi": 72,
"v8": "7.8"
},
"12.22.3": {
"node_abi": 72,
"v8": "7.8"
},
"12.22.4": {
"node_abi": 72,
"v8": "7.8"
},
"12.22.5": {
"node_abi": 72,
"v8": "7.8"
},
"12.22.6": {
"node_abi": 72,
"v8": "7.8"
},
"12.22.7": {
"node_abi": 72,
"v8": "7.8"
},
"13.0.0": {
"node_abi": 79,
"v8": "7.8"
@ -2199,12 +2315,276 @@
"node_abi": 83,
"v8": "8.1"
},
"14.4.0": {
"node_abi": 83,
"v8": "8.1"
},
"14.5.0": {
"node_abi": 83,
"v8": "8.3"
},
"14.6.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.7.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.8.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.9.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.10.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.10.1": {
"node_abi": 83,
"v8": "8.4"
},
"14.11.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.12.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.13.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.13.1": {
"node_abi": 83,
"v8": "8.4"
},
"14.14.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.15.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.15.1": {
"node_abi": 83,
"v8": "8.4"
},
"14.15.2": {
"node_abi": 83,
"v8": "8.4"
},
"14.15.3": {
"node_abi": 83,
"v8": "8.4"
},
"14.15.4": {
"node_abi": 83,
"v8": "8.4"
},
"14.15.5": {
"node_abi": 83,
"v8": "8.4"
},
"14.16.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.16.1": {
"node_abi": 83,
"v8": "8.4"
},
"14.17.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.17.1": {
"node_abi": 83,
"v8": "8.4"
},
"14.17.2": {
"node_abi": 83,
"v8": "8.4"
},
"14.17.3": {
"node_abi": 83,
"v8": "8.4"
},
"14.17.4": {
"node_abi": 83,
"v8": "8.4"
},
"14.17.5": {
"node_abi": 83,
"v8": "8.4"
},
"14.17.6": {
"node_abi": 83,
"v8": "8.4"
},
"14.18.0": {
"node_abi": 83,
"v8": "8.4"
},
"14.18.1": {
"node_abi": 83,
"v8": "8.4"
},
"15.0.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.0.1": {
"node_abi": 88,
"v8": "8.6"
},
"15.1.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.2.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.2.1": {
"node_abi": 88,
"v8": "8.6"
},
"15.3.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.4.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.5.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.5.1": {
"node_abi": 88,
"v8": "8.6"
},
"15.6.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.7.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.8.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.9.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.10.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.11.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.12.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.13.0": {
"node_abi": 88,
"v8": "8.6"
},
"15.14.0": {
"node_abi": 88,
"v8": "8.6"
},
"16.0.0": {
"node_abi": 93,
"v8": "9.0"
},
"16.1.0": {
"node_abi": 93,
"v8": "9.0"
},
"16.2.0": {
"node_abi": 93,
"v8": "9.0"
},
"16.3.0": {
"node_abi": 93,
"v8": "9.0"
},
"16.4.0": {
"node_abi": 93,
"v8": "9.1"
},
"16.4.1": {
"node_abi": 93,
"v8": "9.1"
},
"16.4.2": {
"node_abi": 93,
"v8": "9.1"
},
"16.5.0": {
"node_abi": 93,
"v8": "9.1"
},
"16.6.0": {
"node_abi": 93,
"v8": "9.2"
},
"16.6.1": {
"node_abi": 93,
"v8": "9.2"
},
"16.6.2": {
"node_abi": 93,
"v8": "9.2"
},
"16.7.0": {
"node_abi": 93,
"v8": "9.2"
},
"16.8.0": {
"node_abi": 93,
"v8": "9.2"
},
"16.9.0": {
"node_abi": 93,
"v8": "9.3"
},
"16.9.1": {
"node_abi": 93,
"v8": "9.3"
},
"16.10.0": {
"node_abi": 93,
"v8": "9.3"
},
"16.11.0": {
"node_abi": 93,
"v8": "9.4"
},
"16.11.1": {
"node_abi": 93,
"v8": "9.4"
},
"17.0.0": {
"node_abi": 102,
"v8": "9.5"
},
"17.0.1": {
"node_abi": 102,
"v8": "9.5"
}
}

View File

@ -0,0 +1,14 @@
[General]
arch=arm64
noauth=false
unpack=true
debootstrap=Debian
aptsources=Debian
cleanup=true
[Debian]
packages=apt libc6 libc6-dev libstdc++-8-dev linux-libc-dev libffi-dev libpython3.9-dev libsox-dev python3-numpy python3-setuptools
source=http://ports.ubuntu.com/ubuntu-ports
keyring=ubuntu-keyring
components=main universe
suite=impish

View File

@ -14,7 +14,7 @@ import numpy as np
from stt import Model, version
try:
from shhlex import quote
from shlex import quote
except ImportError:
from pipes import quote

View File

@ -283,7 +283,10 @@ public class STTStream {
precondition(streamCtx != nil, "calling method on invalidated Stream")
let result = STT_FinishStreamWithMetadata(streamCtx, UInt32(numResults))!
defer { STT_FreeMetadata(result) }
defer {
STT_FreeMetadata(result)
streamCtx = nil
}
return STTMetadata(fromInternal: result)
}
}

View File

@ -26,8 +26,8 @@ class SpeechRecognitionImpl : NSObject, AVCaptureAudioDataOutputSampleBufferDele
private var audioData = Data()
override init() {
let modelPath = Bundle.main.path(forResource: "coqui-stt-0.9.3-models", ofType: "tflite")!
let scorerPath = Bundle.main.path(forResource: "coqui-stt-0.9.3-models", ofType: "scorer")!
let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")!
let scorerPath = Bundle.main.path(forResource: "huge-vocab", ofType: "scorer")!
model = try! STTModel(modelPath: modelPath)
try! model.enableExternalScorer(scorerPath: scorerPath)

View File

@ -78,8 +78,8 @@
"def download_sample_data():\n",
" data_dir=\"english/\"\n",
" # Download data + alphabet\n",
" audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav\")\n",
" transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.txt\")\n",
" audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.wav\")\n",
" transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.txt\")\n",
" alphabet = maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/alphabet.txt\")\n",
" # Format data\n",
" with open(transcript_file, \"r\") as fin:\n",

View File

@ -66,9 +66,12 @@ def main():
],
package_dir={"": "training"},
packages=find_packages(where="training"),
python_requires=">=3.5, <4",
python_requires=">=3.5, <3.8",
install_requires=install_requires,
include_package_data=True,
extras_require={
"transcribe": ["webrtcvad"],
},
)

@ -1 +1 @@
Subproject commit 4bdd3955115cc08df61cf94e16a4ea8e0f4847c4
Subproject commit 27a1657c4f574eaafc22bb81d1c77e23794e2eec

View File

@ -1 +1 @@
0.10.0-alpha.29
1.1.0-alpha.1

View File

@ -387,7 +387,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
"input_samples": input_samples,
}
if not Config.export_tflite:
if not tflite:
inputs["input_lengths"] = seq_length
outputs = {

View File

@ -134,7 +134,7 @@ def evaluate(test_csvs, create_model):
batch_logits,
batch_lengths,
Config.alphabet,
Config.beam_width,
Config.export_beam_width,
num_processes=num_processes,
scorer=scorer,
cutoff_prob=Config.cutoff_prob,

View File

@ -0,0 +1,201 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import progressbar
import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import (
Scorer,
flashlight_beam_search_decoder_batch,
FlashlightDecoderState,
)
from six.moves import zip
import tensorflow as tf
from .deepspeech_model import create_model
from .util.augmentations import NormalizeSampleRate
from .util.checkpoints import load_graph_for_evaluation
from .util.config import (
Config,
create_progressbar,
initialize_globals_from_cli,
log_error,
log_progress,
)
from .util.evaluate_tools import calculate_and_print_report, save_samples_json
from .util.feeding import create_dataset
from .util.helpers import check_ctcdecoder_version
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts(
(value.indices, value.values, value.dense_shape), alphabet
)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.Decode(res) for res in results]
def evaluate(test_csvs, create_model):
if Config.scorer_path:
scorer = Scorer(
Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
)
else:
scorer = None
test_sets = [
create_dataset(
[csv],
batch_size=Config.test_batch_size,
train_phase=False,
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
reverse=Config.reverse_test,
limit=Config.limit_test,
)
for csv in test_csvs
]
iterator = tfv1.data.Iterator.from_structure(
tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]),
)
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(
batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
)
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with open(Config.vocab_file) as fin:
vocab = [l.strip().encode("utf-8") for l in fin]
with tfv1.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session)
def run_test(init_op, dataset):
wav_filenames = []
losses = []
predictions = []
ground_truths = []
bar = create_progressbar(
prefix="Test epoch | ",
widgets=["Steps: ", progressbar.Counter(), " | ", progressbar.Timer()],
).start()
log_progress("Test epoch...")
step_count = 0
# Initialize iterator to the appropriate dataset
session.run(init_op)
# First pass, compute losses and transposed logits for decoding
while True:
try:
(
batch_wav_filenames,
batch_logits,
batch_loss,
batch_lengths,
batch_transcripts,
) = session.run(
[batch_wav_filename, transposed, loss, batch_x_len, batch_y]
)
except tf.errors.OutOfRangeError:
break
decoded = flashlight_beam_search_decoder_batch(
batch_logits,
batch_lengths,
Config.alphabet,
beam_size=Config.export_beam_width,
decoder_type=FlashlightDecoderState.DecoderType.LexiconBased,
token_type=FlashlightDecoderState.TokenType.Aggregate,
lm_tokens=vocab,
num_processes=num_processes,
scorer=scorer,
cutoff_top_n=Config.cutoff_top_n,
)
predictions.extend(" ".join(d[0].words) for d in decoded)
ground_truths.extend(
sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet)
)
wav_filenames.extend(
wav_filename.decode("UTF-8") for wav_filename in batch_wav_filenames
)
losses.extend(batch_loss)
step_count += 1
bar.update(step_count)
bar.finish()
# Print test summary
test_samples = calculate_and_print_report(
wav_filenames, ground_truths, predictions, losses, dataset
)
return test_samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print("Testing model on {}".format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples
def test():
tfv1.reset_default_graph()
samples = evaluate(Config.test_files, create_model)
if Config.test_output_file:
save_samples_json(samples, Config.test_output_file)
def main():
initialize_globals_from_cli()
check_ctcdecoder_version()
if not Config.test_files:
raise RuntimeError(
"You need to specify what files to use for evaluation via "
"the --test_files flag."
)
test()
if __name__ == "__main__":
main()
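For reference, the sparse-tuple decoding above simply groups values by their batch index before handing each row to the alphabet. A toy illustration of that grouping, using a stand-in object in place of the real Alphabet class (which is not shown in this diff):

# Toy illustration of sparse_tuple_to_texts (stand-in alphabet, made-up token ids).
class ToyAlphabet:
    table = {0: "c", 1: "a", 2: "t", 3: "d", 4: "o", 5: "g"}
    def Decode(self, ids):
        return "".join(self.table[i] for i in ids)

indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]  # (batch, time)
values = [0, 1, 2, 3, 4, 5]
dense_shape = [2, 3]  # two samples in the batch

results = [[] for _ in range(dense_shape[0])]
for i, index in enumerate(indices):
    results[index[0]].append(values[i])

print([ToyAlphabet().Decode(r) for r in results])  # ['cat', 'dog']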

View File

@ -9,13 +9,15 @@ DESIRED_LOG_LEVEL = (
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import shutil
from .deepspeech_model import create_inference_graph
from .deepspeech_model import create_inference_graph, create_model
from .util.checkpoints import load_graph_for_evaluation
from .util.config import Config, initialize_globals_from_cli, log_error, log_info
from .util.feeding import wavfile_bytes_to_features
from .util.io import (
open_remote,
rmtree_remote,
@ -35,6 +37,9 @@ def export():
"""
log_info("Exporting the model...")
if Config.export_savedmodel:
return export_savedmodel()
tfv1.reset_default_graph()
inputs, outputs, _ = create_inference_graph(
@ -172,6 +177,83 @@ def export():
)
def export_savedmodel():
tfv1.reset_default_graph()
with tfv1.Session(config=Config.session_config) as session:
input_wavfile_contents = tf.placeholder(tf.string)
features, features_len = wavfile_bytes_to_features(input_wavfile_contents)
previous_state_c = tf.zeros([1, Config.n_cell_dim], tf.float32)
previous_state_h = tf.zeros([1, Config.n_cell_dim], tf.float32)
previous_state = tf.nn.rnn_cell.LSTMStateTuple(
previous_state_c, previous_state_h
)
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# One rate per layer
no_dropout = [None] * 6
logits, layers = create_model(
batch_x=features,
batch_size=1,
seq_length=features_len,
dropout=no_dropout,
previous_state=previous_state,
)
# Restore variables from training checkpoint
load_graph_for_evaluation(session)
probs = tf.nn.softmax(logits)
# Remove batch dimension
squeezed = tf.squeeze(probs)
builder = tfv1.saved_model.builder.SavedModelBuilder(Config.export_dir)
input_file_tinfo = tfv1.saved_model.utils.build_tensor_info(
input_wavfile_contents
)
output_probs_tinfo = tfv1.saved_model.utils.build_tensor_info(squeezed)
forward_sig = tfv1.saved_model.signature_def_utils.build_signature_def(
inputs={
"input_wavfile": input_file_tinfo,
},
outputs={
"probs": output_probs_tinfo,
},
method_name="forward",
)
builder.add_meta_graph_and_variables(
session,
[tfv1.saved_model.tag_constants.SERVING],
signature_def_map={
tfv1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: forward_sig
},
)
builder.save()
# Copy scorer and alphabet alongside SavedModel
if Config.scorer_path:
shutil.copy(
Config.scorer_path, os.path.join(Config.export_dir, "exported.scorer")
)
shutil.copy(
Config.effective_alphabet_path,
os.path.join(Config.export_dir, "alphabet.txt"),
)
log_info(f"Exported SavedModel to {Config.export_dir}")
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(

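The SavedModel exported above exposes a single serving signature that takes the raw WAV file contents ("input_wavfile") and returns the per-timestep character probabilities ("probs"). A rough sketch of loading and querying it with the TF1-compat loader follows; the export directory and audio file name are placeholders:

import tensorflow.compat.v1 as tfv1

export_dir = "/path/to/export_dir"        # placeholder: wherever --export_dir pointed
with open("audio.wav", "rb") as fin:      # placeholder input file
    wav_bytes = fin.read()

with tfv1.Session() as session:
    meta_graph = tfv1.saved_model.loader.load(
        session, [tfv1.saved_model.tag_constants.SERVING], export_dir
    )
    sig = meta_graph.signature_def[
        tfv1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    ]
    probs = session.run(
        sig.outputs["probs"].name,
        feed_dict={sig.inputs["input_wavfile"].name: wav_bytes},
    )
    print(probs.shape)  # (time, n_characters) softmax probabilities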
View File

@ -11,8 +11,6 @@ DESIRED_LOG_LEVEL = (
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
import json
import shutil
import time
from datetime import datetime
from pathlib import Path
@ -59,11 +57,7 @@ from .util.config import (
)
from .util.feeding import create_dataset
from .util.helpers import check_ctcdecoder_version
from .util.io import (
is_remote_path,
open_remote,
remove_remote,
)
from .util.io import remove_remote
# Accuracy and Loss
@ -416,18 +410,6 @@ def train():
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(Config.save_checkpoint_dir, "best_dev")
# Save flags next to checkpoints
if not is_remote_path(Config.save_checkpoint_dir):
os.makedirs(Config.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(Config.save_checkpoint_dir, "flags.txt")
with open_remote(flags_file, "w") as fout:
json.dump(Config.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint
preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt")
with open_remote(preserved_alphabet_file, "wb") as fout:
fout.write(Config.alphabet.SerializeText())
with tfv1.Session(config=Config.session_config) as session:
log_debug("Session opened.")

View File

@ -0,0 +1,96 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
DESIRED_LOG_LEVEL = (
sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import (
flashlight_beam_search_decoder,
Scorer,
FlashlightDecoderState,
)
from .deepspeech_model import create_inference_graph, create_overlapping_windows
from .util.checkpoints import load_graph_for_evaluation
from .util.config import Config, initialize_globals_from_cli, log_error
from .util.feeding import audiofile_to_features
def do_single_file_inference(input_file_path):
tfv1.reset_default_graph()
with open(Config.vocab_file) as fin:
vocab = [w.encode("utf-8") for w in [l.strip() for l in fin]]
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, layers = create_inference_graph(batch_size=1, n_steps=-1)
# Restore variables from training checkpoint
load_graph_for_evaluation(session)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
probs = layers["raw_logits"].eval(
feed_dict={
inputs["input"]: features,
inputs["input_lengths"]: features_len,
inputs["previous_state_c"]: previous_state_c,
inputs["previous_state_h"]: previous_state_h,
},
session=session,
)
probs = np.squeeze(probs)
if Config.scorer_path:
scorer = Scorer(
Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
)
else:
scorer = None
decoded = flashlight_beam_search_decoder(
probs,
Config.alphabet,
beam_size=Config.export_beam_width,
decoder_type=FlashlightDecoderState.LexiconBased,
token_type=FlashlightDecoderState.Aggregate,
lm_tokens=vocab,
scorer=scorer,
cutoff_top_n=Config.cutoff_top_n,
)
# Print highest probability result
print(" ".join(d.decode("utf-8") for d in decoded[0].words))
def main():
initialize_globals_from_cli()
if Config.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(Config.one_shot_infer)
else:
raise RuntimeError(
"Calling training_graph_inference script directly but no --one_shot_infer input audio file specified"
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,333 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import glob
import itertools
import json
import multiprocessing
import os
import sys
from dataclasses import dataclass, field
from multiprocessing import Pool, Lock, cpu_count
from pathlib import Path
from typing import Optional, List, Tuple
LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
DESIRED_LOG_LEVEL = (
sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
# Hide GPUs to prevent issues with child processes trying to use the same GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import (
BaseSttConfig,
Config,
initialize_globals_from_instance,
)
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.helpers import check_ctcdecoder_version
from tqdm import tqdm
def transcribe_file(audio_path: Path, tlog_path: Path):
initialize_transcribe_config()
scorer = None
if Config.scorer_path:
scorer = Scorer(
Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
)
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with AudioFile(str(audio_path), as_path=True) as wav_path:
data_set = split_audio_file(
wav_path,
batch_size=Config.batch_size,
aggressiveness=Config.vad_aggressiveness,
outlier_duration_ms=Config.outlier_duration_ms,
outlier_batch_size=Config.outlier_batch_size,
)
iterator = tfv1.data.make_one_shot_iterator(data_set)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6
logits, _ = create_model(
batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
# Load checkpoint in a mutex way to avoid hangs in TensorFlow code
with lock:
load_graph_for_evaluation(session, silent=True)
transcripts = []
while True:
try:
starts, ends, batch_logits, batch_lengths = session.run(
[batch_time_start, batch_time_end, transposed, batch_x_len]
)
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(
batch_logits,
batch_lengths,
Config.alphabet,
Config.beam_width,
num_processes=num_processes,
scorer=scorer,
)
decoded = list(d[0][1] for d in decoded)
transcripts.extend(zip(starts, ends, decoded))
transcripts.sort(key=lambda t: t[0])
transcripts = [
{"start": int(start), "end": int(end), "transcript": transcript}
for start, end, transcript in transcripts
]
with open(tlog_path, "w") as tlog_file:
json.dump(transcripts, tlog_file, default=float)
def init_fn(l):
global lock
lock = l
def step_function(job):
"""Wrap transcribe_file to unpack arguments from a single tuple"""
idx, src, dst = job
transcribe_file(src, dst)
return idx, src, dst
def transcribe_many(src_paths, dst_paths):
# Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
jobs = zip(itertools.count(), src_paths, dst_paths)
lock = Lock()
with Pool(
processes=min(cpu_count(), len(src_paths)),
initializer=init_fn,
initargs=(lock,),
) as pool:
process_iterable = tqdm(
pool.imap_unordered(step_function, jobs),
desc="Transcribing files",
total=len(src_paths),
disable=not Config.show_progressbar,
)
cwd = Path.cwd()
for result in process_iterable:
idx, src, dst = result
# Revert to relative if possible to make logs more concise
# if path is not relative to cwd, use the absolute path
# (Path.is_relative_to is only available in Python >=3.9)
try:
src = src.relative_to(cwd)
except ValueError:
pass
try:
dst = dst.relative_to(cwd)
except ValueError:
pass
tqdm.write(f'[{idx+1}]: "{src}" -> "{dst}"')
def get_tasks_from_catalog(catalog_file_path: Path) -> Tuple[List[Path], List[Path]]:
"""Given a `catalog_file_path` pointing to a .catalog file (from DSAlign),
extract transcription tasks, ie. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results. For .catalog file inputs, these are taken from the
"audio" and "tlog" properties of the entries in the catalog, with any relative
paths being absolutized relative to the directory containing the .catalog file.
"""
assert catalog_file_path.suffix == ".catalog"
catalog_dir = catalog_file_path.parent
with open(catalog_file_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
def resolve(spec_path: Optional[Path]):
if spec_path is None:
return None
if not spec_path.is_absolute():
spec_path = catalog_dir / spec_path
return spec_path
catalog_entries = [
(resolve(Path(e["audio"])), resolve(Path(e["tlog"]))) for e in catalog_entries
]
for src, dst in catalog_entries:
if not Config.force and dst.is_file():
raise RuntimeError(
f"Destination file already exists: {dst}. Use --force for overwriting."
)
if not dst.parent.is_dir():
dst.parent.mkdir(parents=True)
src_paths, dst_paths = zip(*catalog_entries)
return src_paths, dst_paths
def get_tasks_from_dir(src_dir: Path, recursive: bool) -> Tuple[List[Path], List[Path]]:
"""Given a directory `src_dir` containing audio files, scan it for audio files
and return transcription tasks, ie. (src_path, dest_path) pairs corresponding to
a path to an audio file to be transcribed, and a path to a JSON file to place
transcription results.
"""
glob_method = src_dir.rglob if recursive else src_dir.glob
src_paths = list(glob_method("*.wav"))
dst_paths = [path.with_suffix(".tlog") for path in src_paths]
return src_paths, dst_paths
def transcribe():
initialize_transcribe_config()
src_path = Path(Config.src).resolve()
if not Config.src or not src_path.exists():
# path not given or non-existent
raise RuntimeError(
"You have to specify which audio file, catalog file or directory to "
"transcribe with the --src flag."
)
else:
# path given and exists
if src_path.is_file():
if src_path.suffix != ".catalog":
# Transcribe one file
dst_path = (
Path(Config.dst).resolve()
if Config.dst
else src_path.with_suffix(".tlog")
)
if dst_path.is_file() and not Config.force:
raise RuntimeError(
f'Destination file "{dst_path}" already exists - use '
"--force for overwriting."
)
if not dst_path.parent.is_dir():
raise RuntimeError("Missing destination directory")
transcribe_many([src_path], [dst_path])
else:
# Transcribe from .catalog input
src_paths, dst_paths = get_tasks_from_catalog(src_path)
transcribe_many(src_paths, dst_paths)
elif src_path.is_dir():
# Transcribe from dir input
print(f"Transcribing all files in --src directory {src_path}")
src_paths, dst_paths = get_tasks_from_dir(src_path, Config.recursive)
transcribe_many(src_paths, dst_paths)
@dataclass
class TranscribeConfig(BaseSttConfig):
src: str = field(
default="",
metadata=dict(
help="Source path to an audio file or directory or catalog file. "
"Catalog files should be formatted from DSAlign. A directory "
"will be recursively searched for audio. If --dst not set, "
"transcription logs (.tlog) will be written in-place using the "
'source filenames with suffix ".tlog" instead of the original.'
),
)
dst: str = field(
default="",
metadata=dict(
help="path for writing the transcription log or logs (.tlog). "
"If --src is a directory, this one also has to be a directory "
"and the required sub-dir tree of --src will get replicated."
),
)
recursive: bool = field(
default=False,
metadata=dict(help="scan source directory recursively for audio"),
)
force: bool = field(
default=False,
metadata=dict(
help="Forces re-transcribing and overwriting of already existing "
"transcription logs (.tlog)"
),
)
vad_aggressiveness: int = field(
default=3,
metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
)
batch_size: int = field(
default=40,
metadata=dict(help="Default batch size"),
)
outlier_duration_ms: int = field(
default=10000,
metadata=dict(
help="Duration in ms after which samples are considered outliers"
),
)
outlier_batch_size: int = field(
default=1,
metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
)
def __post_init__(self):
if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
raise RuntimeError(
"Parameter --dst not supported if --src points to a catalog"
)
if os.path.isdir(self.src):
if self.dst:
raise RuntimeError(
"Destination path not supported for batch decoding jobs."
)
super().__post_init__()
def initialize_transcribe_config():
config = TranscribeConfig.init_from_argparse(arg_prefix="")
initialize_globals_from_instance(config)
def main():
assert not tf.test.is_gpu_available()
# Set start method to spawn on all platforms to avoid issues with TensorFlow
multiprocessing.set_start_method("spawn")
try:
import webrtcvad
except ImportError:
print(
"E transcribe module requires webrtcvad, which cannot be imported. Install with pip install webrtcvad"
)
sys.exit(1)
check_ctcdecoder_version()
transcribe()
if __name__ == "__main__":
main()
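As described in get_tasks_from_catalog above, a .catalog file is a JSON list of entries whose "audio" and "tlog" fields point to the input WAV and the output transcription log, with relative paths resolved against the catalog's own directory. A minimal hand-written example (file and directory names are placeholders):

import json

# Hypothetical contents of jobs.catalog, sitting next to an "audio" sub-directory.
catalog_entries = [
    {"audio": "audio/sample-001.wav", "tlog": "audio/sample-001.tlog"},
    {"audio": "audio/sample-002.wav", "tlog": "audio/sample-002.tlog"},
]
with open("jobs.catalog", "w") as fout:
    json.dump(catalog_entries, fout, indent=2)

# The module above would then be invoked roughly as:
#   python -m coqui_stt_training.transcribe --src jobs.catalog [other training-config flags]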

View File

@ -7,7 +7,13 @@ import tensorflow as tf
from .config import Config, log_error, log_info, log_warn
def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
def _load_checkpoint(
session,
checkpoint_path,
allow_drop_layers,
allow_lr_init=True,
silent: bool = False,
):
# Load the checkpoint and put all variables into loading list
# we will exclude variables we do not wish to load and then
# we will initialize them instead
@ -75,12 +81,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
init_vars.add(v)
load_vars -= init_vars
def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for v in sorted(load_vars, key=lambda v: v.op.name):
log_info("Loading variable from checkpoint: %s" % (v.op.name))
maybe_log_info(f"Loading variable from checkpoint: {v.op.name}")
v.load(ckpt.get_tensor(v.op.name), session=session)
for v in sorted(init_vars, key=lambda v: v.op.name):
log_info("Initializing variable: %s" % (v.op.name))
maybe_log_info("Initializing variable: %s" % (v.op.name))
session.run(v.initializer)
@ -99,31 +109,49 @@ def _initialize_all_variables(session):
session.run(v.initializer)
def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
def _load_or_init_impl(
session, method_order, allow_drop_layers, allow_lr_init=True, silent: bool = False
):
def maybe_log_info(*args, **kwargs):
if not silent:
log_info(*args, **kwargs)
for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == "best":
ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
if ckpt_path:
log_info("Loading best validating checkpoint from {}".format(ckpt_path))
return _load_checkpoint(
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
maybe_log_info(
"Loading best validating checkpoint from {}".format(ckpt_path)
)
log_info("Could not find best validating checkpoint.")
return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find best validating checkpoint.")
# Load most recent checkpoint, saved in checkpoint file 'checkpoint'
elif method == "last":
ckpt_path = _checkpoint_path_or_none("checkpoint")
if ckpt_path:
log_info("Loading most recent checkpoint from {}".format(ckpt_path))
return _load_checkpoint(
session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
maybe_log_info(
"Loading most recent checkpoint from {}".format(ckpt_path)
)
log_info("Could not find most recent checkpoint.")
return _load_checkpoint(
session,
ckpt_path,
allow_drop_layers,
allow_lr_init=allow_lr_init,
silent=silent,
)
maybe_log_info("Could not find most recent checkpoint.")
# Initialize all variables
elif method == "init":
log_info("Initializing all variables.")
maybe_log_info("Initializing all variables.")
return _initialize_all_variables(session)
else:
@ -138,7 +166,7 @@ def reload_best_checkpoint(session):
_load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
def load_or_init_graph_for_training(session):
def load_or_init_graph_for_training(session, silent: bool = False):
"""
Load variables from checkpoint or initialize variables. By default this will
try to load the best validating checkpoint, then try the last checkpoint,
@ -149,10 +177,10 @@ def load_or_init_graph_for_training(session):
methods = ["best", "last", "init"]
else:
methods = [Config.load_train]
_load_or_init_impl(session, methods, allow_drop_layers=True)
_load_or_init_impl(session, methods, allow_drop_layers=True, silent=silent)
def load_graph_for_evaluation(session):
def load_graph_for_evaluation(session, silent: bool = False):
"""
Load variables from checkpoint. Initialization is not allowed. By default
this will try to load the best validating checkpoint, then try the last
@ -163,4 +191,4 @@ def load_graph_for_evaluation(session):
methods = ["best", "last"]
else:
methods = [Config.load_evaluate]
_load_or_init_impl(session, methods, allow_drop_layers=False)
_load_or_init_impl(session, methods, allow_drop_layers=False, silent=silent)
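Both loaders are meant to be called inside an open session after the graph has been built; the new silent flag only suppresses the per-variable log lines (as transcribe.py does under its multiprocessing lock). A minimal usage sketch, assuming globals have already been initialized from a config:

import tensorflow.compat.v1 as tfv1

from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import Config

# Build the inference graph first (e.g. via create_model / create_inference_graph),
# then restore variables from the best or most recent checkpoint.
with tfv1.Session(config=Config.session_config) as session:
    load_graph_for_evaluation(session, silent=True)  # no per-variable logging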

View File

@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function
import json
import os
import sys
from dataclasses import asdict, dataclass, field
@ -17,7 +18,7 @@ from .augmentations import NormalizeSampleRate, parse_augmentations
from .auto_input import create_alphabet_from_sources, create_datasets_from_auto_input
from .gpu import get_available_gpus
from .helpers import parse_file_size
from .io import path_exists_remote
from .io import is_remote_path, open_remote, path_exists_remote
class _ConfigSingleton:
@ -37,7 +38,7 @@ Config = _ConfigSingleton() # pylint: disable=invalid-name
@dataclass
class _SttConfig(Coqpit):
class BaseSttConfig(Coqpit):
def __post_init__(self):
# Augmentations
self.augmentations = parse_augmentations(self.augment)
@ -161,6 +162,7 @@ class _SttConfig(Coqpit):
self.alphabet = UTF8Alphabet()
elif self.alphabet_config_path:
self.alphabet = Alphabet(self.alphabet_config_path)
self.effective_alphabet_path = self.alphabet_config_path
elif os.path.exists(loaded_checkpoint_alphabet_file):
print(
"I --alphabet_config_path not specified, but found an alphabet file "
@ -168,6 +170,7 @@ class _SttConfig(Coqpit):
"Will use this alphabet file for this run."
)
self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
self.effective_alphabet_path = loaded_checkpoint_alphabet_file
elif self.train_files and self.dev_files and self.test_files:
# If all subsets are in the same folder and there's an alphabet file
# alongside them, use it.
@ -185,6 +188,7 @@ class _SttConfig(Coqpit):
"Will use this alphabet file for this run."
)
self.alphabet = Alphabet(str(possible_alphabet))
self.effective_alphabet_path = possible_alphabet
if not self.alphabet:
# Generate alphabet automatically from input dataset, but only if
@ -199,6 +203,7 @@ class _SttConfig(Coqpit):
characters, alphabet = create_alphabet_from_sources(sources)
print(f"I Generated alphabet characters: {characters}.")
self.alphabet = alphabet
self.effective_alphabet_path = saved_checkpoint_alphabet_file
else:
raise RuntimeError(
"Missing --alphabet_config_path flag. Couldn't find an alphabet file "
@ -208,6 +213,19 @@ class _SttConfig(Coqpit):
"be generated automatically."
)
# Save flags next to checkpoints
if not is_remote_path(self.save_checkpoint_dir):
os.makedirs(self.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
if not os.path.exists(flags_file):
with open_remote(flags_file, "w") as fout:
json.dump(self.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint
if not os.path.exists(saved_checkpoint_alphabet_file):
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
fout.write(self.alphabet.SerializeText())
# Geometric Constants
# ===================
@ -322,6 +340,13 @@ class _SttConfig(Coqpit):
),
)
vocab_file: str = field(
default="",
metadata=dict(
help="For use with evaluate_flashlight - text file containing vocabulary of scorer, one word per line."
),
)
read_buffer: str = field(
default="1MB",
metadata=dict(
@ -580,6 +605,10 @@ class _SttConfig(Coqpit):
default=True,
metadata=dict(help="export a quantized model (optimized for size)"),
)
export_savedmodel: bool = field(
default=False,
metadata=dict(help="export model in TF SavedModel format"),
)
n_steps: int = field(
default=16,
metadata=dict(
@ -824,16 +853,22 @@ class _SttConfig(Coqpit):
def initialize_globals_from_cli():
c = _SttConfig.init_from_argparse(arg_prefix="")
c = BaseSttConfig.init_from_argparse(arg_prefix="")
_ConfigSingleton._config = c # pylint: disable=protected-access
def initialize_globals_from_args(**override_args):
# Update Config with new args
c = _SttConfig(**override_args)
c = BaseSttConfig(**override_args)
_ConfigSingleton._config = c # pylint: disable=protected-access
def initialize_globals_from_instance(config):
""" Initialize Config singleton from an existing Config instance (or subclass) """
assert isinstance(config, BaseSttConfig)
_ConfigSingleton._config = config # pylint: disable=protected-access
# Logging functions
# =================

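The rename to BaseSttConfig plus the new initialize_globals_from_instance hook is what lets a tool define its own config dataclass and still populate the Config singleton, as the transcribe module above does. A condensed sketch of that pattern; the extra field name and help text here are made up for illustration:

from dataclasses import dataclass, field

from coqui_stt_training.util.config import (
    BaseSttConfig,
    Config,
    initialize_globals_from_instance,
)

@dataclass
class MyToolConfig(BaseSttConfig):
    # hypothetical extra flag, parsed from the command line like any other field
    report_path: str = field(
        default="",
        metadata=dict(help="where to write this tool's report"),
    )

def initialize_my_tool_config():
    config = MyToolConfig.init_from_argparse(arg_prefix="")
    initialize_globals_from_instance(config)
    print(Config.report_path)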
View File

@ -84,6 +84,14 @@ def audiofile_to_features(
wav_filename, clock=0.0, train_phase=False, augmentations=None
):
samples = tf.io.read_file(wav_filename)
return wavfile_bytes_to_features(
samples, clock, train_phase, augmentations, sample_id=wav_filename
)
def wavfile_bytes_to_features(
samples, clock=0.0, train_phase=False, augmentations=None, sample_id=None
):
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
return audio_to_features(
decoded.audio,
@ -91,7 +99,7 @@ def audiofile_to_features(
clock=clock,
train_phase=train_phase,
augmentations=augmentations,
sample_id=wav_filename,
sample_id=sample_id,
)

View File

@ -64,7 +64,7 @@ def get_importers_parser(description):
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--validate_label_locale",
help="Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE's DIRECTORY INTO PYTHONPATH.",
help="Path to a Python file defining a |validate_label| function for your locale.",
)
return parser
@ -81,15 +81,15 @@ def get_validate_label(args):
:return: The user-supplied validate_label function
:type: function
"""
# Python 3.5 does not support passing a pathlib.Path to os.path.* methods
if "validate_label_locale" not in args or (args.validate_label_locale is None):
print(
"WARNING: No --validate_label_locale specified, your might end with inconsistent dataset."
"WARNING: No --validate_label_locale specified, you might end with inconsistent dataset."
)
return validate_label_eng
# Python 3.5 does not support passing a pathlib.Path to os.path.* methods
validate_label_locale = str(args.validate_label_locale)
if not os.path.exists(os.path.abspath(validate_label_locale)):
print("ERROR: Inexistent --validate_label_locale specified. Please check.")
print("ERROR: Path specified in --validate_label_locale is not a file.")
return None
module_dir = os.path.abspath(os.path.dirname(validate_label_locale))
sys.path.insert(1, module_dir)

View File

@ -2,246 +2,15 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import os
import sys
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow.compat.v1.logging as tflogging
import tensorflow as tf
tflogging.set_verbosity(tflogging.ERROR)
import logging
logging.getLogger("sox").setLevel(logging.ERROR)
import glob
from multiprocessing import Process, cpu_count
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.config import Config, initialize_globals_from_cli
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import (
create_progressbar,
log_error,
log_info,
log_progress,
)
def fail(message, code=1):
log_error(message)
sys.exit(code)
def transcribe_file(audio_path, tlog_path):
from coqui_stt_training.train import ( # pylint: disable=cyclic-import,import-outside-toplevel
create_model,
)
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
initialize_globals_from_cli()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path:
data_set = split_audio_file(
wav_path,
batch_size=FLAGS.batch_size,
aggressiveness=FLAGS.vad_aggressiveness,
outlier_duration_ms=FLAGS.outlier_duration_ms,
outlier_batch_size=FLAGS.outlier_batch_size,
)
iterator = tf.data.Iterator.from_structure(
data_set.output_types,
data_set.output_shapes,
output_classes=data_set.output_classes,
)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6
logits, _ = create_model(
batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session)
session.run(iterator.make_initializer(data_set))
transcripts = []
while True:
try:
starts, ends, batch_logits, batch_lengths = session.run(
[batch_time_start, batch_time_end, transposed, batch_x_len]
)
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(
batch_logits,
batch_lengths,
Config.alphabet,
FLAGS.beam_width,
num_processes=num_processes,
scorer=scorer,
)
decoded = list(d[0][1] for d in decoded)
transcripts.extend(zip(starts, ends, decoded))
transcripts.sort(key=lambda t: t[0])
transcripts = [
{"start": int(start), "end": int(end), "transcript": transcript}
for start, end, transcript in transcripts
]
with open(tlog_path, "w") as tlog_file:
json.dump(transcripts, tlog_file, default=float)
def transcribe_many(src_paths, dst_paths):
pbar = create_progressbar(
prefix="Transcribing files | ", max_value=len(src_paths)
).start()
for i in range(len(src_paths)):
p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i]))
p.start()
p.join()
log_progress(
'Transcribed file {} of {} from "{}" to "{}"'.format(
i + 1, len(src_paths), src_paths[i], dst_paths[i]
)
)
pbar.update(i)
pbar.finish()
def transcribe_one(src_path, dst_path):
transcribe_file(src_path, dst_path)
log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))
def resolve(base_path, spec_path):
if spec_path is None:
return None
if not os.path.isabs(spec_path):
spec_path = os.path.join(base_path, spec_path)
return spec_path
def main(_):
if not FLAGS.src or not os.path.exists(FLAGS.src):
# path not given or non-existant
fail(
"You have to specify which file or catalog to transcribe via the --src flag."
)
else:
# path given and exists
src_path = os.path.abspath(FLAGS.src)
if os.path.isfile(src_path):
if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign)
if FLAGS.dst:
fail("Parameter --dst not supported if --src points to a catalog")
catalog_dir = os.path.dirname(src_path)
with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail("Missing source file(s) in catalog")
if not FLAGS.force and any(
map(lambda e: os.path.isfile(e[1]), catalog_entries)
):
fail(
"Destination file(s) from catalog already existing, use --force for overwriting"
)
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else:
# Transcribe one file
dst_path = (
os.path.abspath(FLAGS.dst)
if FLAGS.dst
else os.path.splitext(src_path)[0] + ".tlog"
)
if os.path.isfile(dst_path):
if FLAGS.force:
transcribe_one(src_path, dst_path)
else:
fail(
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail("Missing destination directory")
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if FLAGS.dst:
fail("Destination file not supported for batch decoding jobs.")
else:
if not FLAGS.recursive:
print(
"If you wish to recursively scan --src, then you must use --recursive"
)
wav_paths = glob.glob(src_path + "/*.wav")
else:
wav_paths = glob.glob(src_path + "/**/*.wav")
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths, dst_paths)
if __name__ == "__main__":
create_flags()
tf.app.flags.DEFINE_string(
"src",
"",
"Source path to an audio file or directory or catalog file."
"Catalog files should be formatted from DSAlign. A directory will"
"be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be "
"written in-place using the source filenames with "
'suffix ".tlog" instead of ".wav".',
print(
"Using the top level transcribe.py script is deprecated and will be removed "
"in a future release. Instead use: python -m coqui_stt_training.transcribe"
)
tf.app.flags.DEFINE_string(
"dst",
"",
"path for writing the transcription log or logs (.tlog). "
"If --src is a directory, this one also has to be a directory "
"and the required sub-dir tree of --src will get replicated.",
)
tf.app.flags.DEFINE_boolean("recursive", False, "scan dir of audio recursively")
tf.app.flags.DEFINE_boolean(
"force",
False,
"Forces re-transcribing and overwriting of already existing "
"transcription logs (.tlog)",
)
tf.app.flags.DEFINE_integer(
"vad_aggressiveness",
3,
"How aggressive (0=lowest, 3=highest) the VAD should " "split audio",
)
tf.app.flags.DEFINE_integer("batch_size", 40, "Default batch size")
tf.app.flags.DEFINE_float(
"outlier_duration_ms",
10000,
"Duration in ms after which samples are considered outliers",
)
tf.app.flags.DEFINE_integer(
"outlier_batch_size", 1, "Batch size for duration outliers (defaults to 1)"
)
tf.app.run(main)
try:
from coqui_stt_training import transcribe as stt_transcribe
except ImportError:
print("Training package is not installed. See training documentation.")
raise
stt_transcribe.main()