Merge pull request #1920 from JRMeyer/transfer-learning-notebook

Fix config checkpoint handling and add notebook
This commit is contained in:
Josh Meyer 2021-07-30 10:15:12 -04:00 committed by GitHub
commit df26eca4d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 313 additions and 23 deletions

4
notebooks/README.md Normal file
View File

@ -0,0 +1,4 @@
# Python Notebooks for 🐸 STT
1. Train a new Speech-to-Text model from scratch [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/coqui-ai/STT/blob/main/notebooks/train-your-first-coqui-STT-model.ipynb)
2. Transfer learning (English --> Russian) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/coqui-ai/STT/blob/main/notebooks/easy-transfer-learning.ipynb)

View File

@ -0,0 +1,286 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "45ea3ef5",
"metadata": {},
"source": [
"# Easy transfer learning with 🐸 STT ⚡\n",
"\n",
"You want to train a Coqui (🐸) STT model, but you don't have a lot of data. What do you do?\n",
"\n",
"The answer 💡: Grab a pre-trained model and fine-tune it to your data. This is called `\"Transfer Learning\"` ⚡\n",
"\n",
"🐸 STT comes with transfer learning support out-of-the box.\n",
"\n",
"You can even take a pre-trained model and fine-tune it to _any new language_, even if the alphabets are completely different. Likewise, you can fine-tune a model to your own data and improve performance if the language is the same.\n",
"\n",
"In this notebook, we will:\n",
"\n",
"1. Download a pre-trained English STT model.\n",
"2. Download data for the Russian language.\n",
"3. Fine-tune the English model to Russian language.\n",
"4. Test the new Russian model and display its performance.\n",
"\n",
"So, let's jump right in!\n",
"\n",
"*PS - If you just want a working, off-the-shelf model, check out the [🐸 Model Zoo](https://www.coqui.ai/models)*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa2aec77",
"metadata": {},
"outputs": [],
"source": [
"## Install Coqui STT if you need to\n",
"# !git clone --depth 1 https://github.com/coqui-ai/STT.git\n",
"# !cd STT; pip install -U pip wheel setuptools; pip install ."
]
},
{
"cell_type": "markdown",
"id": "8c07a273",
"metadata": {},
"source": [
"## ✅ Download pre-trained English model\n",
"\n",
"We're going to download a very small (but very accurate) pre-trained STT model for English. This model was trained to only transcribe the English words \"yes\" and \"no\", but with transfer learning we can train a new model which could transcribe any words in any language. In this notebook, we will turn this \"constrained vocabulary\" English model into an \"open vocabulary\" Russian model.\n",
"\n",
"Coqui STT models as typically stored as checkpoints (for training) and protobufs (for deployment). For transfer learning, we want the **model checkpoints**.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "608d203f",
"metadata": {},
"outputs": [],
"source": [
"### Download pre-trained model\n",
"import os\n",
"import tarfile\n",
"from coqui_stt_training.util.downloader import maybe_download\n",
"\n",
"def download_pretrained_model():\n",
" model_dir=\"english/\"\n",
" if not os.path.exists(\"english/coqui-yesno-checkpoints\"):\n",
" maybe_download(\"model.tar.gz\", model_dir, \"https://github.com/coqui-ai/STT-models/releases/download/english%2Fcoqui%2Fyesno-v0.0.1/coqui-yesno-checkpoints.tar.gz\")\n",
" print('\\nNo extracted pre-trained model found. Extracting now...')\n",
" tar = tarfile.open(\"english/model.tar.gz\")\n",
" tar.extractall(\"english/\")\n",
" tar.close()\n",
" else:\n",
" print('Found \"english/coqui-yesno-checkpoints\" - not extracting.')\n",
"\n",
"# Download + extract pre-trained English model\n",
"download_pretrained_model()"
]
},
{
"cell_type": "markdown",
"id": "ed9dd7ab",
"metadata": {},
"source": [
"## ✅ Download data for Russian\n",
"\n",
"**First things first**: we need some data.\n",
"\n",
"We're training a Speech-to-Text model, so we need some _speech_ and we need some _text_. Specificially, we want _transcribed speech_. Let's download a Russian audio file and its transcript, pre-formatted for 🐸 STT. \n",
"\n",
"**Second things second**: we want a Russian alphabet. The output layer of a typical* 🐸 STT model represents letters in the alphabet. Let's download a Russian alphabet from Coqui and use that.\n",
"\n",
"*_If you are working with languages with large character sets (e.g. Chinese), you can set `bytes_output_mode=True` instead of supplying an `alphabet.txt` file. In this case, the output layer of the STT model will correspond to individual UTF-8 bytes instead of individual characters._"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5105ea7",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"### Download sample data\n",
"from coqui_stt_training.util.downloader import maybe_download\n",
"\n",
"def download_sample_data():\n",
" data_dir=\"russian/\"\n",
" maybe_download(\"ru.wav\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/russian_sample_data/ru.wav\")\n",
" maybe_download(\"ru.csv\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/russian_sample_data/ru.csv\")\n",
" maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/russian_sample_data/alphabet.ru\")\n",
"\n",
"# Download sample Russian data\n",
"download_sample_data()"
]
},
{
"cell_type": "markdown",
"id": "b46b7227",
"metadata": {},
"source": [
"## ✅ Configure the training run\n",
"\n",
"Coqui STT comes with a long list of hyperparameters you can tweak. We've set default values, but you can use `initialize_globals_from_args()` to set your own. \n",
"\n",
"You must **always** configure the paths to your data, and you must **always** configure your alphabet. For transfer learning, it's good practice to define different `load_checkpoint_dir` and `save_checkpoint_dir` paths so that you keep your new model (Russian STT) separate from the old one (English STT). The parameter `drop_source_layers` allows you to remove layers from the original (aka \"source\") model, and re-initialize them from scratch. If you are fine-tuning to a new alphabet you will have to use _at least_ `drop_source_layers=1` to remove the output layer and add a new output layer which matches your new alphabet.\n",
"\n",
"We are fine-tuning a pre-existing model, so `n_hidden` should be the same as the original English model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cff3c5a0",
"metadata": {},
"outputs": [],
"source": [
"from coqui_stt_training.util.config import initialize_globals_from_args\n",
"\n",
"initialize_globals_from_args(\n",
" n_hidden=64,\n",
" load_checkpoint_dir=\"english/coqui-yesno-checkpoints\",\n",
" save_checkpoint_dir=\"russian/checkpoints\",\n",
" drop_source_layers=1,\n",
" alphabet_config_path=\"russian/alphabet.txt\",\n",
" train_files=[\"russian/ru.csv\"],\n",
" dev_files=[\"russian/ru.csv\"],\n",
" epochs=200,\n",
" load_cudnn=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "419828c1",
"metadata": {},
"source": [
"### View all Config settings (*Optional*) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cac6ea3d",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from coqui_stt_training.util.config import Config\n",
"\n",
"print(Config.to_json())"
]
},
{
"cell_type": "markdown",
"id": "c8e700d1",
"metadata": {},
"source": [
"## ✅ Train a new Russian model\n",
"\n",
"Let's kick off a training run 🚀🚀🚀 (using the configure you set above).\n",
"\n",
"This notebook should work on either a GPU or a CPU. However, in case you're running this on _multiple_ GPUs we want to only use one, because the sample dataset (one audio file) is too small to split across multiple GPUs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8aab2195",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from coqui_stt_training.train import train, early_training_checks\n",
"import tensorflow.compat.v1 as tfv1\n",
"\n",
"# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n",
"early_training_checks()\n",
"\n",
"tfv1.reset_default_graph()\n",
"train()"
]
},
{
"cell_type": "markdown",
"id": "3c87ba61",
"metadata": {},
"source": [
"## ✅ Configure the testing run\n",
"\n",
"Let's add the path to our testing data and update `load_checkpoint_dir` to our new model checkpoints."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2be7beb5",
"metadata": {},
"outputs": [],
"source": [
"from coqui_stt_training.util.config import Config\n",
"\n",
"Config.test_files=[\"russian/ru.csv\"]\n",
"Config.load_checkpoint_dir=\"russian/checkpoints\""
]
},
{
"cell_type": "markdown",
"id": "c6a5c971",
"metadata": {},
"source": [
"## ✅ Test the new Russian model\n",
"\n",
"We made it! 🙌\n",
"\n",
"Let's kick off the testing run, which displays performance metrics.\n",
"\n",
"We're committing the cardinal sin of ML 😈 (aka - testing on our training data) so you don't want to deploy this model into production. In this notebook we're focusing on the workflow itself, so it's forgivable 😇\n",
"\n",
"You can see from the test output that our tiny model has overfit to the data, and basically memorized this one sentence.\n",
"\n",
"When you start training your own models, make sure your testing data doesn't include your training data 😅"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6222dc69",
"metadata": {},
"outputs": [],
"source": [
"from coqui_stt_training.train import test\n",
"\n",
"tfv1.reset_default_graph()\n",
"test()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -25,6 +25,18 @@
"*PS - If you just want a working, off-the-shelf model, check out the [🐸 Model Zoo](https://www.coqui.ai/models)*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa2aec78",
"metadata": {},
"outputs": [],
"source": [
"## Install Coqui STT if you need to\n",
"# !git clone --depth 1 https://github.com/coqui-ai/STT.git\n",
"# !cd STT; pip install -U pip wheel setuptools; pip install ."
]
},
{
"cell_type": "markdown",
"id": "be5fe49c",
@ -46,7 +58,7 @@
"\n",
"**Second things second**: we want an alphabet. The output layer of a typical* 🐸 STT model represents letters in the alphabet, and you should specify this alphabet before training. Let's download an English alphabet from Coqui and use that.\n",
"\n",
"_*If you are working with languages with large character sets (e.g. Chinese), you can set `bytes_output_mode=True` instead of supplying an `alphabet.txt` file. In this case, the output layer of the STT model will correspond to individual UTF-8 bytes instead of individual characters._"
"*_If you are working with languages with large character sets (e.g. Chinese), you can set `bytes_output_mode=True` instead of supplying an `alphabet.txt` file. In this case, the output layer of the STT model will correspond to individual UTF-8 bytes instead of individual characters._"
]
},
{

View File

@ -74,32 +74,20 @@ class _SttConfig(Coqpit):
if self.dropout_rate6 < 0:
self.dropout_rate6 = self.dropout_rate
# Set default checkpoint dir
# If separate save and load flags were not specified, default to load and save
# from the same dir.
# if save_checkpoint_dir hasn't been set, or checkpoint_dir is new
if (not self.save_checkpoint_dir) or (
self.save_checkpoint_dir is not self.checkpoint_dir
):
if not self.checkpoint_dir:
self.checkpoint_dir = xdg.save_data_path(
os.path.join("stt", "checkpoints")
)
# Checkpoint dir logic #
if self.checkpoint_dir:
# checkpoint_dir always overrides {save,load}_checkpoint_dir
self.save_checkpoint_dir = self.checkpoint_dir
else:
self.save_checkpoint_dir = self.checkpoint_dir
# if load_checkpoint_dir hasn't been set, or checkpoint_dir is new
if (not self.load_checkpoint_dir) or (
self.load_checkpoint_dir is not self.checkpoint_dir
):
if not self.checkpoint_dir:
self.checkpoint_dir = xdg.load_data_path(
os.path.join("stt", "checkpoints")
)
self.load_checkpoint_dir = self.checkpoint_dir
else:
self.load_checkpoint_dir = self.checkpoint_dir
if not self.save_checkpoint_dir:
self.save_checkpoint_dir = xdg.save_data_path(
os.path.join("stt", "checkpoints")
)
if not self.load_checkpoint_dir:
self.load_checkpoint_dir = xdg.save_data_path(
os.path.join("stt", "checkpoints")
)
if self.load_train not in ["last", "best", "init", "auto"]:
self.load_train = "auto"