Merge pull request #1920 from JRMeyer/transfer-learning-notebook
Fix config checkpoint handling and add notebook
commit df26eca4d2
@@ -0,0 +1,4 @@
# Python Notebooks for 🐸 STT

1. Train a new Speech-to-Text model from scratch [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/coqui-ai/STT/blob/main/notebooks/train-your-first-coqui-STT-model.ipynb)
2. Transfer learning (English --> Russian) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/coqui-ai/STT/blob/main/notebooks/easy-transfer-learning.ipynb)
notebooks/easy-transfer-learning.ipynb
@@ -0,0 +1,286 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "45ea3ef5",
   "metadata": {},
   "source": [
    "# Easy transfer learning with 🐸 STT ⚡\n",
    "\n",
    "You want to train a Coqui (🐸) STT model, but you don't have a lot of data. What do you do?\n",
    "\n",
    "The answer 💡: Grab a pre-trained model and fine-tune it to your data. This is called `\"Transfer Learning\"` ⚡\n",
    "\n",
    "🐸 STT comes with transfer learning support out-of-the-box.\n",
    "\n",
    "You can even take a pre-trained model and fine-tune it to _any new language_, even if the alphabets are completely different. Likewise, you can fine-tune a model to your own data and improve performance if the language is the same.\n",
    "\n",
    "In this notebook, we will:\n",
    "\n",
    "1. Download a pre-trained English STT model.\n",
    "2. Download data for the Russian language.\n",
    "3. Fine-tune the English model to the Russian language.\n",
    "4. Test the new Russian model and display its performance.\n",
    "\n",
    "So, let's jump right in!\n",
    "\n",
    "*PS - If you just want a working, off-the-shelf model, check out the [🐸 Model Zoo](https://www.coqui.ai/models)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa2aec77",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Install Coqui STT if you need to\n",
    "# !git clone --depth 1 https://github.com/coqui-ai/STT.git\n",
    "# !cd STT; pip install -U pip wheel setuptools; pip install ."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c07a273",
   "metadata": {},
   "source": [
    "## ✅ Download pre-trained English model\n",
    "\n",
    "We're going to download a very small (but very accurate) pre-trained STT model for English. This model was trained to only transcribe the English words \"yes\" and \"no\", but with transfer learning we can train a new model which could transcribe any words in any language. In this notebook, we will turn this \"constrained vocabulary\" English model into an \"open vocabulary\" Russian model.\n",
    "\n",
    "Coqui STT models are typically stored as checkpoints (for training) and protobufs (for deployment). For transfer learning, we want the **model checkpoints**.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "608d203f",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Download pre-trained model\n",
    "import os\n",
    "import tarfile\n",
    "from coqui_stt_training.util.downloader import maybe_download\n",
    "\n",
    "def download_pretrained_model():\n",
    "    model_dir=\"english/\"\n",
    "    if not os.path.exists(\"english/coqui-yesno-checkpoints\"):\n",
    "        maybe_download(\"model.tar.gz\", model_dir, \"https://github.com/coqui-ai/STT-models/releases/download/english%2Fcoqui%2Fyesno-v0.0.1/coqui-yesno-checkpoints.tar.gz\")\n",
    "        print('\\nNo extracted pre-trained model found. Extracting now...')\n",
    "        tar = tarfile.open(\"english/model.tar.gz\")\n",
    "        tar.extractall(\"english/\")\n",
    "        tar.close()\n",
    "    else:\n",
    "        print('Found \"english/coqui-yesno-checkpoints\" - not extracting.')\n",
    "\n",
    "# Download + extract pre-trained English model\n",
    "download_pretrained_model()"
   ]
  },
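  {
   "cell_type": "code",
   "execution_count": null,
   "id": "608d2040",
   "metadata": {},
   "outputs": [],
   "source": [
    "## (Optional) Sanity check: peek at the extracted checkpoint directory.\n",
    "## A minimal sketch added for illustration (this cell is not part of the\n",
    "## original workflow); exact filenames depend on the release tarball.\n",
    "import os\n",
    "\n",
    "print(sorted(os.listdir(\"english/coqui-yesno-checkpoints\")))"
   ]
  },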
  {
   "cell_type": "markdown",
   "id": "ed9dd7ab",
   "metadata": {},
   "source": [
    "## ✅ Download data for Russian\n",
    "\n",
    "**First things first**: we need some data.\n",
    "\n",
    "We're training a Speech-to-Text model, so we need some _speech_ and we need some _text_. Specifically, we want _transcribed speech_. Let's download a Russian audio file and its transcript, pre-formatted for 🐸 STT.\n",
    "\n",
    "**Second things second**: we want a Russian alphabet. The output layer of a typical* 🐸 STT model represents letters in the alphabet. Let's download a Russian alphabet from Coqui and use that.\n",
    "\n",
    "*_If you are working with languages with large character sets (e.g. Chinese), you can set `bytes_output_mode=True` instead of supplying an `alphabet.txt` file. In this case, the output layer of the STT model will correspond to individual UTF-8 bytes instead of individual characters. A sketch of this alternative follows the download cell below._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5105ea7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "### Download sample data\n",
    "from coqui_stt_training.util.downloader import maybe_download\n",
    "\n",
    "def download_sample_data():\n",
    "    data_dir=\"russian/\"\n",
    "    maybe_download(\"ru.wav\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/russian_sample_data/ru.wav\")\n",
    "    maybe_download(\"ru.csv\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/russian_sample_data/ru.csv\")\n",
    "    maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/russian_sample_data/alphabet.ru\")\n",
    "\n",
    "# Download sample Russian data\n",
    "download_sample_data()"
   ]
  },
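  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5105ea8",
   "metadata": {},
   "outputs": [],
   "source": [
    "## (Optional) Sketch of the large-character-set alternative mentioned above.\n",
    "## Illustrative only, and kept commented out because this notebook uses the\n",
    "## Russian alphabet.txt rather than byte-level outputs.\n",
    "# from coqui_stt_training.util.config import initialize_globals_from_args\n",
    "#\n",
    "# initialize_globals_from_args(\n",
    "#     bytes_output_mode=True,  # output layer predicts UTF-8 bytes\n",
    "#     # ...the rest of your configuration, without alphabet_config_path\n",
    "# )"
   ]
  },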
  {
   "cell_type": "markdown",
   "id": "b46b7227",
   "metadata": {},
   "source": [
    "## ✅ Configure the training run\n",
    "\n",
    "Coqui STT comes with a long list of hyperparameters you can tweak. We've set default values, but you can use `initialize_globals_from_args()` to set your own.\n",
    "\n",
    "You must **always** configure the paths to your data, and you must **always** configure your alphabet. For transfer learning, it's good practice to define different `load_checkpoint_dir` and `save_checkpoint_dir` paths so that you keep your new model (Russian STT) separate from the old one (English STT). The parameter `drop_source_layers` allows you to remove layers from the original (aka \"source\") model and re-initialize them from scratch. If you are fine-tuning to a new alphabet, you will have to use _at least_ `drop_source_layers=1` to remove the output layer and add a new output layer which matches your new alphabet.\n",
    "\n",
    "We are fine-tuning a pre-existing model, so `n_hidden` should be the same as the original English model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cff3c5a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from coqui_stt_training.util.config import initialize_globals_from_args\n",
    "\n",
    "initialize_globals_from_args(\n",
    "    n_hidden=64,\n",
    "    load_checkpoint_dir=\"english/coqui-yesno-checkpoints\",\n",
    "    save_checkpoint_dir=\"russian/checkpoints\",\n",
    "    drop_source_layers=1,\n",
    "    alphabet_config_path=\"russian/alphabet.txt\",\n",
    "    train_files=[\"russian/ru.csv\"],\n",
    "    dev_files=[\"russian/ru.csv\"],\n",
    "    epochs=200,\n",
    "    load_cudnn=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "419828c1",
   "metadata": {},
   "source": [
    "### View all Config settings (*Optional*)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cac6ea3d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from coqui_stt_training.util.config import Config\n",
    "\n",
    "print(Config.to_json())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8e700d1",
   "metadata": {},
   "source": [
    "## ✅ Train a new Russian model\n",
    "\n",
    "Let's kick off a training run 🚀🚀🚀 (using the configuration you set above).\n",
    "\n",
    "This notebook should work on either a GPU or a CPU. However, if you're running this on _multiple_ GPUs, we want to use only one, because the sample dataset (one audio file) is too small to split across multiple GPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aab2195",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from coqui_stt_training.train import train, early_training_checks\n",
    "import tensorflow.compat.v1 as tfv1\n",
    "\n",
    "# use maximum one GPU\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "early_training_checks()\n",
    "\n",
    "tfv1.reset_default_graph()\n",
    "train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c87ba61",
   "metadata": {},
   "source": [
    "## ✅ Configure the testing run\n",
    "\n",
    "Let's add the path to our testing data and update `load_checkpoint_dir` to our new model checkpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2be7beb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from coqui_stt_training.util.config import Config\n",
    "\n",
    "Config.test_files=[\"russian/ru.csv\"]\n",
    "Config.load_checkpoint_dir=\"russian/checkpoints\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6a5c971",
   "metadata": {},
   "source": [
    "## ✅ Test the new Russian model\n",
    "\n",
    "We made it! 🙌\n",
    "\n",
    "Let's kick off the testing run, which displays performance metrics.\n",
    "\n",
    "We're committing the cardinal sin of ML 😈 (aka testing on our training data), so you don't want to deploy this model into production. In this notebook we're focusing on the workflow itself, so it's forgivable 😇\n",
    "\n",
    "You can see from the test output that our tiny model has overfit to the data, and basically memorized this one sentence.\n",
    "\n",
    "When you start training your own models, make sure your testing data doesn't include your training data 😅"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6222dc69",
   "metadata": {},
   "outputs": [],
   "source": [
    "from coqui_stt_training.train import test\n",
    "import tensorflow.compat.v1 as tfv1\n",
    "\n",
    "tfv1.reset_default_graph()\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
notebooks/train-your-first-coqui-STT-model.ipynb
@@ -25,6 +25,18 @@
     "*PS - If you just want a working, off-the-shelf model, check out the [🐸 Model Zoo](https://www.coqui.ai/models)*"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fa2aec78",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## Install Coqui STT if you need to\n",
+    "# !git clone --depth 1 https://github.com/coqui-ai/STT.git\n",
+    "# !cd STT; pip install -U pip wheel setuptools; pip install ."
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "be5fe49c",
@@ -46,7 +58,7 @@
     "\n",
     "**Second things second**: we want an alphabet. The output layer of a typical* 🐸 STT model represents letters in the alphabet, and you should specify this alphabet before training. Let's download an English alphabet from Coqui and use that.\n",
     "\n",
-    "_*If you are working with languages with large character sets (e.g. Chinese), you can set `bytes_output_mode=True` instead of supplying an `alphabet.txt` file. In this case, the output layer of the STT model will correspond to individual UTF-8 bytes instead of individual characters._"
+    "*_If you are working with languages with large character sets (e.g. Chinese), you can set `bytes_output_mode=True` instead of supplying an `alphabet.txt` file. In this case, the output layer of the STT model will correspond to individual UTF-8 bytes instead of individual characters._"
    ]
   },
   {
coqui_stt_training/util/config.py
@@ -74,32 +74,20 @@ class _SttConfig(Coqpit):
         if self.dropout_rate6 < 0:
             self.dropout_rate6 = self.dropout_rate

-        # Set default checkpoint dir
-        # If separate save and load flags were not specified, default to load and save
-        # from the same dir.
-
-        # if save_checkpoint_dir hasn't been set, or checkpoint_dir is new
-        if (not self.save_checkpoint_dir) or (
-            self.save_checkpoint_dir is not self.checkpoint_dir
-        ):
-            if not self.checkpoint_dir:
-                self.checkpoint_dir = xdg.save_data_path(
-                    os.path.join("stt", "checkpoints")
-                )
-                self.save_checkpoint_dir = self.checkpoint_dir
-            else:
-                self.save_checkpoint_dir = self.checkpoint_dir
-
-        # if load_checkpoint_dir hasn't been set, or checkpoint_dir is new
-        if (not self.load_checkpoint_dir) or (
-            self.load_checkpoint_dir is not self.checkpoint_dir
-        ):
-            if not self.checkpoint_dir:
-                self.checkpoint_dir = xdg.load_data_path(
-                    os.path.join("stt", "checkpoints")
-                )
-                self.load_checkpoint_dir = self.checkpoint_dir
-            else:
-                self.load_checkpoint_dir = self.checkpoint_dir
+        # Checkpoint dir logic #
+        if self.checkpoint_dir:
+            # checkpoint_dir always overrides {save,load}_checkpoint_dir
+            self.save_checkpoint_dir = self.checkpoint_dir
+            self.load_checkpoint_dir = self.checkpoint_dir
+        else:
+            if not self.save_checkpoint_dir:
+                self.save_checkpoint_dir = xdg.save_data_path(
+                    os.path.join("stt", "checkpoints")
+                )
+            if not self.load_checkpoint_dir:
+                self.load_checkpoint_dir = xdg.save_data_path(
+                    os.path.join("stt", "checkpoints")
+                )

         if self.load_train not in ["last", "best", "init", "auto"]:
             self.load_train = "auto"
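For reference, the new resolution order distills to a minimal standalone sketch (our illustration, not code from this commit; `Dirs` and `resolve_checkpoint_dirs` are hypothetical stand-ins for the `_SttConfig` fields and its post-init logic):

```python
import os
from dataclasses import dataclass

@dataclass
class Dirs:
    checkpoint_dir: str = ""
    save_checkpoint_dir: str = ""
    load_checkpoint_dir: str = ""

def resolve_checkpoint_dirs(cfg: Dirs, default: str) -> Dirs:
    # checkpoint_dir, when given, overrides both save and load dirs
    if cfg.checkpoint_dir:
        cfg.save_checkpoint_dir = cfg.checkpoint_dir
        cfg.load_checkpoint_dir = cfg.checkpoint_dir
    else:
        # otherwise each dir independently falls back to the XDG default
        cfg.save_checkpoint_dir = cfg.save_checkpoint_dir or default
        cfg.load_checkpoint_dir = cfg.load_checkpoint_dir or default
    return cfg

# Transfer-learning case from the notebook: separate load/save, no override
print(resolve_checkpoint_dirs(
    Dirs(load_checkpoint_dir="english/coqui-yesno-checkpoints",
         save_checkpoint_dir="russian/checkpoints"),
    default=os.path.join("stt", "checkpoints"),
))
```

This is why the notebook can point `load_checkpoint_dir` at the English yes/no checkpoints while saving the fine-tuned Russian model elsewhere: neither dir is silently rewritten unless `checkpoint_dir` itself is set.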