Merge of pull requests #49, #50, and #52. Fixes issues #2, #4, #11, #12, #46, #47, and #48

Kelly Davis 2016-10-13 15:15:39 -04:00
parent 9fb60a7ebc
commit a3abc9d92a
12 changed files with 739 additions and 243 deletions

.gitignore (2 changes)

@ -2,3 +2,5 @@
*.pyc *.pyc
.DS_Store .DS_Store
/logs /logs
/data/ted/TEDLIUM_release2
/data/ted/TEDLIUM_release2.tar.gz


@ -83,11 +83,12 @@
"import tempfile\n", "import tempfile\n",
"import subprocess\n", "import subprocess\n",
"import numpy as np\n", "import numpy as np\n",
"from math import ceil\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"from util.log import merge_logs\n", "from util.log import merge_logs\n",
"from util.gpu import get_available_gpus\n", "from util.gpu import get_available_gpus\n",
"from util.importers.ted_lium import read_data_sets\n", "from util.importers.ted_lium import read_data_sets\n",
"from util.text import sparse_tensor_value_to_text, wers\n", "from util.text import sparse_tensor_value_to_texts, wers\n",
"from tensorflow.python.ops import ctc_ops" "from tensorflow.python.ops import ctc_ops"
] ]
}, },
@ -123,11 +124,11 @@
"beta1 = 0.9 # TODO: Determine a reasonable value for this\n", "beta1 = 0.9 # TODO: Determine a reasonable value for this\n",
"beta2 = 0.999 # TODO: Determine a reasonable value for this\n", "beta2 = 0.999 # TODO: Determine a reasonable value for this\n",
"epsilon = 1e-8 # TODO: Determine a reasonable value for this\n", "epsilon = 1e-8 # TODO: Determine a reasonable value for this\n",
"training_iters = 1250 # TODO: Determine a reasonable value for this\n", "training_iters = 15 # TODO: Determine a reasonable value for this\n",
"batch_size = 1 # TODO: Determine a reasonable value for this\n", "batch_size = 5 # TODO: Determine a reasonable value for this\n",
"display_step = 10 # TODO: Determine a reasonable value for this\n", "display_step = 10 # TODO: Determine a reasonable value for this\n",
"validation_step = 50 # TODO: Determine a reasonable value for this\n", "validation_step = 50 # TODO: Determine a reasonable value for this\n",
"checkpoint_step = 1000 # TODO: Determine a reasonable value for this\n", "checkpoint_step = 5 # TODO: Determine a reasonable value for this\n",
"checkpoint_dir = tempfile.gettempdir() # TODO: Determine a reasonable value for this" "checkpoint_dir = tempfile.gettempdir() # TODO: Determine a reasonable value for this"
] ]
}, },
@ -191,14 +192,14 @@
"source": [ "source": [
"Now we will introduce several constants related to the geometry of the network.\n", "Now we will introduce several constants related to the geometry of the network.\n",
"\n", "\n",
"The network views each speech sample as a sequence of time-slices $x^{(i)}_t$ of length $T^{(i)}$. As the speech samples vary in length, we know that $T^{(i)}$ need not equal $T^{(j)}$ for $i \\ne j$. However, BRNN in TensorFlow are unable to deal with sequences with differing lengths. Thus, we must pad speech sample sequences with trailing zeros such that they are all of the same length. This common padded length is captured in the variable `n_steps` which will be set after the data set is loaded. " "The network views each speech sample as a sequence of time-slices $x^{(i)}_t$ of length $T^{(i)}$. As the speech samples vary in length, we know that $T^{(i)}$ need not equal $T^{(j)}$ for $i \\ne j$. For each batch, BRNN in TensorFlow needs to know `n_steps` which is the maximum $T^{(i)}$ for the batch."
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Each of the `n_steps` vectors is a vector of MFCC features of a time-slice of the speech sample. We will make the number of MFCC features dependent upon the sample rate of the data set. Generically, if the sample rate is 8kHz we use 13 features. If the sample rate is 16kHz we use 26 features... We capture the dimension of these vectors, equivalently the number of MFCC features, in the variable `n_input`" "Each of the at maximum `n_steps` vectors is a vector of MFCC features of a time-slice of the speech sample. We will make the number of MFCC features dependent upon the sample rate of the data set. Generically, if the sample rate is 8kHz we use 13 features. If the sample rate is 16kHz we use 26 features... We capture the dimension of these vectors, equivalently the number of MFCC features, in the variable `n_input`"
] ]
}, },
{ {
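To make the per-batch padding described above concrete, here is a minimal NumPy sketch (illustrative, not part of the commit) of padding a batch of feature arrays to its own `n_steps`, mirroring what the new importers do with `np.pad`:

```python
import numpy as np

def pad_batch(sequences):
    # sequences: list of (T_i, n_input) feature arrays for one batch
    n_steps = max(seq.shape[0] for seq in sequences)
    padded = [np.pad(seq, ((0, n_steps - seq.shape[0]), (0, 0)), mode='constant')
              for seq in sequences]
    return np.array(padded), n_steps

batch, n_steps = pad_batch([np.ones((3, 2)), np.ones((5, 2))])
# batch.shape == (2, 5, 2), n_steps == 5
```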
@ -604,10 +605,13 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"def calculate_accuracy_and_loss(n_steps, batch_set):\n", "def calculate_accuracy_and_loss(batch_set):\n",
" # Obtain the next batch of data\n", " # Obtain the next batch of data\n",
" batch_x, batch_y, batch_seq_len = batch_set.next_batch(batch_size)\n", " batch_x, batch_y, n_steps = ted_lium.train.next_batch()\n",
"\n", "\n",
" # Set batch_seq_len for the batch\n",
" batch_seq_len = batch_x.shape[0] * [n_steps]\n",
" \n",
" # Calculate the logits of the batch using BiRNN\n", " # Calculate the logits of the batch using BiRNN\n",
" logits = BiRNN(batch_x, n_steps)\n", " logits = BiRNN(batch_x, n_steps)\n",
" \n", " \n",
@ -639,14 +643,21 @@
"source": [ "source": [
"The first lines of `calculate_accuracy_and_loss()`\n", "The first lines of `calculate_accuracy_and_loss()`\n",
"```python\n", "```python\n",
"def calculate_accuracy_and_loss(n_steps, batch_set):\n", "def calculate_accuracy_and_loss(batch_set):\n",
" # Obtain the next batch of data\n", " # Obtain the next batch of data\n",
" batch_x, batch_y, batch_seq_len = batch_set.next_batch(batch_size)\n", " batch_x, batch_y, n_steps = ted_lium.train.next_batch()\n",
"```\n", "```\n",
"simply obtian the next mini-batch of data.\n", "simply obtian the next mini-batch of data.\n",
"\n", "\n",
"The next line\n", "The next line\n",
"```python\n", "```python\n",
" # Set batch_seq_len for the batch\n",
" batch_seq_len = batch_x.shape[0] * [n_steps]\n",
"```\n",
"creates `batch_seq_len` a list of the lengths of the sequences in `batch_x`. (As the sequences are zero padded to the same length, the list contains the value `n_steps` a total of `batch_x.shape[0]` times.)\n",
"\n",
"The next line\n",
"```python\n",
" # Calculate the logits from the BiRNN\n", " # Calculate the logits from the BiRNN\n",
" logits = BiRNN(batch_x, n_steps)\n", " logits = BiRNN(batch_x, n_steps)\n",
"```\n", "```\n",
@ -863,7 +874,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"def get_tower_results(n_steps, batch_set, optimizer=None):\n", "def get_tower_results(batch_set, optimizer=None):\n",
" # Tower decodings to return\n", " # Tower decodings to return\n",
" tower_decodings = []\n", " tower_decodings = []\n",
" # Tower labels to return\n", " # Tower labels to return\n",
@ -879,10 +890,7 @@
" with tf.name_scope('tower_%d' % i) as scope:\n", " with tf.name_scope('tower_%d' % i) as scope:\n",
" # Calculate the avg_loss and accuracy and retrieve the decoded \n", " # Calculate the avg_loss and accuracy and retrieve the decoded \n",
" # batch along with the original batch's labels (Y) of this tower\n", " # batch along with the original batch's labels (Y) of this tower\n",
" avg_loss, accuracy, decoded, labels = calculate_accuracy_and_loss(\\\n", " avg_loss, accuracy, decoded, labels = calculate_accuracy_and_loss(batch_set)\n",
" n_steps, \\\n",
" batch_set \\\n",
" )\n",
" \n", " \n",
" # Allow for variables to be re-used by the next tower\n", " # Allow for variables to be re-used by the next tower\n",
" tf.get_variable_scope().reuse_variables()\n", " tf.get_variable_scope().reuse_variables()\n",
@ -1090,17 +1098,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"def decode_batch(data_set):\n", "def decode_batch(data_set):\n",
" # Set n_steps parameter\n",
" n_steps = data_set.max_batch_seq_len\n",
"\n",
" # Calculate the total number of batches\n",
" total_batch = int(data_set.num_examples/batch_size)\n",
"\n",
" # Require that we have at least as many batches as devices\n",
" assert total_batch >= len(available_devices)\n",
" \n",
" # Get gradients for each tower (Runs across all GPU's)\n", " # Get gradients for each tower (Runs across all GPU's)\n",
" tower_decodings, tower_labels, _, _, _ = get_tower_results(n_steps, data_set)\n", " tower_decodings, tower_labels, _, _, _ = get_tower_results(data_set)\n",
" return tower_decodings, tower_labels\n", " return tower_decodings, tower_labels\n",
" " " "
] ]
@ -1130,8 +1129,8 @@
" # Iterating over the towers\n", " # Iterating over the towers\n",
" for i in range(len(tower_decodings)):\n", " for i in range(len(tower_decodings)):\n",
" decoded, labels = session.run([tower_decodings[i], tower_labels[i]], feed_dict)\n", " decoded, labels = session.run([tower_decodings[i], tower_labels[i]], feed_dict)\n",
" originals.extend(sparse_tensor_value_to_text(labels))\n", " originals.extend(sparse_tensor_value_to_texts(labels))\n",
" results.extend(sparse_tensor_value_to_text(decoded))\n", " results.extend(sparse_tensor_value_to_texts(decoded))\n",
" \n", " \n",
" # Pairwise calculation of all rates\n", " # Pairwise calculation of all rates\n",
" rates, mean = wers(originals, results)\n", " rates, mean = wers(originals, results)\n",
@ -1186,24 +1185,18 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"def train(session, data_sets):\n", "def train(session, data_sets):\n",
" # Set n_steps parameter\n",
" n_steps = data_sets.train.max_batch_seq_len\n",
"\n",
" # Calculate the total number of batches\n", " # Calculate the total number of batches\n",
" total_batch = int(data_sets.train.num_examples/batch_size)\n", " total_batches = data_sets.train.total_batches\n",
"\n", " \n",
" # Require that we have at least as many batches as devices\n",
" assert total_batch >= len(available_devices)\n",
"\n",
" # Create optimizer\n", " # Create optimizer\n",
" optimizer = create_optimizer()\n", " optimizer = create_optimizer()\n",
"\n", "\n",
" # Get gradients for each tower (Runs across all GPU's)\n", " # Get gradients for each tower (Runs across all GPU's)\n",
" tower_decodings, tower_labels, tower_gradients, tower_loss, accuracy = \\\n", " tower_decodings, tower_labels, tower_gradients, tower_loss, accuracy = \\\n",
" get_tower_results(n_steps, data_sets.train, optimizer)\n", " get_tower_results(data_sets.train, optimizer)\n",
" \n", " \n",
" # Validation step preparation\n", " # Validation step preparation\n",
" validation_tower_decodings, validation_tower_labels = decode_batch(data_sets.validation)\n", " validation_tower_decodings, validation_tower_labels = decode_batch(data_sets.dev)\n",
"\n", "\n",
" # Average tower gradients\n", " # Average tower gradients\n",
" avg_tower_gradients = average_gradients(tower_gradients)\n", " avg_tower_gradients = average_gradients(tower_gradients)\n",
@ -1239,7 +1232,7 @@
" print\n", " print\n",
"\n", "\n",
" # Loop over the batches\n", " # Loop over the batches\n",
" for batch in range(total_batch/len(available_devices)):\n", " for batch in range(int(ceil(float(total_batches)/len(available_devices)))):\n",
" # Compute the average loss for the last batch\n", " # Compute the average loss for the last batch\n",
" _, batch_avg_loss = session.run([apply_gradient_op, tower_loss], feed_dict_train)\n", " _, batch_avg_loss = session.run([apply_gradient_op, tower_loss], feed_dict_train)\n",
"\n", "\n",
@ -1247,14 +1240,14 @@
" total_accuracy += session.run(accuracy, feed_dict_train)\n", " total_accuracy += session.run(accuracy, feed_dict_train)\n",
"\n", "\n",
" # Log all variable states in current step\n", " # Log all variable states in current step\n",
" step = epoch * total_batch + batch * len(available_devices)\n", " step = epoch * total_batches + batch * len(available_devices)\n",
" summary_str = session.run(merged, feed_dict_train)\n", " summary_str = session.run(merged, feed_dict_train)\n",
" writer.add_summary(summary_str, step)\n", " writer.add_summary(summary_str, step)\n",
" writer.flush()\n", " writer.flush()\n",
" \n", " \n",
" # Print progress message\n", " # Print progress message\n",
" if epoch % display_step == 0:\n", " if epoch % display_step == 0:\n",
" print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batch))\n", " print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batches))\n",
" _, last_train_wer = print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n", " _, last_train_wer = print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
" print\n", " print\n",
"\n", "\n",
@ -1285,24 +1278,26 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Create session in which to execute\n", "# Define CPU as device on which the muti-gpu training is orchestrated\n",
"session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))\n", "with tf.device('/cpu:0'):\n",
"\n", " # Obtain ted lium data\n",
"# Obtain ted lium data\n", " ted_lium = read_data_sets(tf.get_default_graph(), './data/ted', batch_size, n_input, n_context)\n",
"ted_lium = read_data_sets('./data/smoke_test', n_input, n_context)\n", " \n",
"\n", " # Create session in which to execute\n",
"# Take start time for time measurement\n", " session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))\n",
"time_started = datetime.datetime.utcnow()\n", " \n",
"\n", " # Take start time for time measurement\n",
"# Train the network\n", " time_started = datetime.datetime.utcnow()\n",
"last_train_wer, last_validation_wer = train(session, ted_lium)\n", " \n",
"\n", " # Train the network\n",
"# Take final time for time measurement\n", " last_train_wer, last_validation_wer = train(session, ted_lium)\n",
"time_finished = datetime.datetime.utcnow()\n", " \n",
"\n", " # Take final time for time measurement\n",
"# Calculate duration in seconds\n", " time_finished = datetime.datetime.utcnow()\n",
"duration = time_finished - time_started\n", " \n",
"duration = duration.days * 86400 + duration.seconds" " # Calculate duration in seconds\n",
" duration = time_finished - time_started\n",
" duration = duration.days * 86400 + duration.seconds"
] ]
}, },
{ {
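A minimal sketch (illustrative, not from the commit) of the placement pattern used in the cell above: orchestration ops are pinned to `/cpu:0`, `allow_soft_placement=True` lets TensorFlow fall back to another device for any op that cannot run where it was requested, and `log_device_placement=True` prints the final placements.

```python
import tensorflow as tf

with tf.device('/cpu:0'):
    # bookkeeping variables live on the CPU; tower computations are placed
    # on GPUs elsewhere via tf.device('/gpu:i')
    counter = tf.Variable(0, name='counter')

config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
session = tf.Session(config=config)
session.run(tf.initialize_all_variables())
```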
@ -1320,9 +1315,11 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Test network\n", "# Define CPU as device on which the muti-gpu testing is orchestrated\n",
"test_decodings, test_labels = decode_batch(ted_lium.test)\n", "with tf.device('/cpu:0'):\n",
"_, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)" " # Test network\n",
" test_decodings, test_labels = decode_batch(ted_lium.test)\n",
" _, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)"
] ]
}, },
{ {
@ -1374,9 +1371,9 @@
" 'n_hidden_6': n_hidden_6, \\\n", " 'n_hidden_6': n_hidden_6, \\\n",
" 'n_cell_dim': n_cell_dim, \\\n", " 'n_cell_dim': n_cell_dim, \\\n",
" 'n_character': n_character, \\\n", " 'n_character': n_character, \\\n",
" 'num_examples_train': ted_lium.train.num_examples, \\\n", " 'total_batches_train': ted_lium.train.total_batches, \\\n",
" 'num_examples_validation': ted_lium.validation.num_examples, \\\n", " 'total_batches_validation': ted_lium.validation.total_batches, \\\n",
" 'num_examples_test': ted_lium.test.num_examples \\\n", " 'total_batches_test': ted_lium.test.total_batches \\\n",
" }, \\\n", " }, \\\n",
" 'results': { \\\n", " 'results': { \\\n",
" 'duration': duration, \\\n", " 'duration': duration, \\\n",
@ -1422,7 +1419,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython2", "pygments_lexer": "ipython2",
"version": "2.7.12" "version": "2.7.11"
} }
}, },
"nbformat": 4, "nbformat": 4,

data/ted/.gitkeep (new empty file)

util/audio.py (new file, 68 lines)

@ -0,0 +1,68 @@
import numpy as np
import scipy.io.wavfile as wav
from python_speech_features import mfcc


def audiofile_to_input_vector(audio_filename, numcep, numcontext):
    # Load wav files
    fs, audio = wav.read(audio_filename)

    # Get mfcc coefficients
    orig_inputs = mfcc(audio, samplerate=fs, numcep=numcep)

    # For each time slice of the training set, we need to copy the context. This makes
    # the numcep-dimensional vector into a (numcep + 2*numcep*numcontext)-dimensional vector
    # because of:
    #   - numcep dimensions for the current mfcc feature set
    #   - numcontext*numcep dimensions for each of the past and future (x2) mfcc feature sets
    # => so numcep + 2*numcontext*numcep
    train_inputs = np.array([], np.float32)
    train_inputs.resize((orig_inputs.shape[0], numcep + 2*numcep*numcontext))

    # Prepare prefix and postfix context (TODO: Fill empty_mfcc with MFCC of silence)
    empty_mfcc = np.array([])
    empty_mfcc.resize((numcep))

    # Prepare train_inputs with past and future contexts
    time_slices = range(train_inputs.shape[0])
    context_past_min = time_slices[0] + numcontext
    context_future_max = time_slices[-1] - numcontext
    for time_slice in time_slices:
        ### Reminder: array[start:stop:step]
        ### slices from index |start| up to |stop| (not included), every |step|

        # Pick up to numcontext time slices in the past, and complete with empty
        # mfcc features
        need_empty_past = max(0, (context_past_min - time_slice))
        empty_source_past = list(empty_mfcc for empty_slots in range(need_empty_past))
        data_source_past = orig_inputs[max(0, time_slice - numcontext):time_slice]
        assert(len(empty_source_past) + len(data_source_past) == numcontext)

        # Pick up to numcontext time slices in the future, and complete with empty
        # mfcc features
        need_empty_future = max(0, (time_slice - context_future_max))
        empty_source_future = list(empty_mfcc for empty_slots in range(need_empty_future))
        data_source_future = orig_inputs[time_slice + 1:time_slice + numcontext + 1]
        assert(len(empty_source_future) + len(data_source_future) == numcontext)

        if need_empty_past:
            past = np.concatenate((empty_source_past, data_source_past))
        else:
            past = data_source_past

        if need_empty_future:
            future = np.concatenate((data_source_future, empty_source_future))
        else:
            future = data_source_future

        past = np.reshape(past, numcontext*numcep)
        now = orig_inputs[time_slice]
        future = np.reshape(future, numcontext*numcep)

        train_inputs[time_slice] = np.concatenate((past, now, future))
        assert(len(train_inputs[time_slice]) == numcep + 2*numcep*numcontext)

    # Whiten inputs (TODO: Should we whiten?)
    train_inputs = (train_inputs - np.mean(train_inputs))/np.std(train_inputs)

    # Return results
    return train_inputs
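A hedged usage sketch of the function above (the wav path is hypothetical): with the notebook's values numcep=26 and numcontext=9, each time-slice becomes a vector of 26 + 2*26*9 = 494 values.

```python
from util.audio import audiofile_to_input_vector

# 'data/example.wav' is a hypothetical 16 kHz mono recording
features = audiofile_to_input_vector('data/example.wav', numcep=26, numcontext=9)

# One row per time-slice: 26 current MFCCs plus 9 past and 9 future frames
# of 26 MFCCs each, i.e. 26 + 2*26*9 = 494 values per row
# features.shape == (number_of_time_slices, 494)
```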


@ -1,80 +0,0 @@
import numpy as np
import scipy.io.wavfile as wav
from python_speech_features import mfcc
def audiofiles_to_audio_data_sets(audio_filenames, numcep, numcontext):
# Define audio_data_sets to return
inputs = []
input_seq_lens = []
# Loop over audio_filenames
for audio_filename in audio_filenames:
# Load wav files
fs, audio = wav.read(audio_filename)
# Get mfcc coefficients
orig_inputs = mfcc(audio, samplerate=fs, numcep=numcep)
# For each time slice of the training set, we need to copy the context this makes
# the numcep dimensions vector into a numcep + 2*numcep*numcontext dimensions
# because of:
# - numcep dimensions for the current mfcc feature set
# - numcontext*numcep dimensions for each of the past and future (x2) mfcc feature set
# => so numcep + 2*numcontext*numcep
train_inputs = np.array([], np.float32)
train_inputs.resize((orig_inputs.shape[0], numcep + 2*numcep*numcontext))
# Prepare pre-fix post fix context (TODO: Fill empty_mfcc with MCFF of silence)
empty_mfcc = np.array([])
empty_mfcc.resize((numcep))
# Prepare train_inputs with past and future contexts
time_slices = range(train_inputs.shape[0])
context_past_min = time_slices[0] + numcontext
context_future_max = time_slices[-1] - numcontext
for time_slice in time_slices:
### Reminder: array[start:stop:step]
### slices from indice |start| up to |stop| (not included), every |step|
# Pick up to numcontext time slices in the past, and complete with empty
# mfcc features
need_empty_past = max(0, (context_past_min - time_slice))
empty_source_past = list(empty_mfcc for empty_slots in range(need_empty_past))
data_source_past = orig_inputs[max(0, time_slice - numcontext):time_slice]
assert(len(empty_source_past) + len(data_source_past) == numcontext)
# Pick up to numcontext time slices in the future, and complete with empty
# mfcc features
need_empty_future = max(0, (time_slice - context_future_max))
empty_source_future = list(empty_mfcc for empty_slots in range(need_empty_future))
data_source_future = orig_inputs[time_slice + 1:time_slice + numcontext + 1]
assert(len(empty_source_future) + len(data_source_future) == numcontext)
if need_empty_past:
past = np.concatenate((empty_source_past, data_source_past))
else:
past = data_source_past
if need_empty_future:
future = np.concatenate((data_source_future, empty_source_future))
else:
future = data_source_future
past = np.reshape(past, numcontext*numcep)
now = orig_inputs[time_slice]
future = np.reshape(future, numcontext*numcep)
train_inputs[time_slice] = np.concatenate((past, now, future))
assert(len(train_inputs[time_slice]) == numcep + 2*numcep*numcontext)
# Whiten inputs (TODO: Should we whiten)
train_inputs = (train_inputs - np.mean(train_inputs))/np.std(train_inputs)
# Obtain array of sequence lengths
input_seq_lens.append(train_inputs.shape[0])
# Convert train_inputs to proper form
inputs.append(train_inputs)
# Return results
return (np.asarray(inputs), input_seq_lens)

util/importers/librivox.py (new file, 254 lines)

@ -0,0 +1,254 @@
import fnmatch
import numpy as np
import os
import random
import subprocess
import tarfile
from glob import glob
from itertools import cycle
from math import ceil
from sox import Transformer
from Queue import PriorityQueue
from Queue import Queue
from shutil import rmtree
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.python.platform import gfile
from threading import Thread
from util.audio import audiofile_to_input_vector
from util.gpu import get_available_gpus
from util.text import texts_to_sparse_tensor
class DataSets(object):
def __init__(self, train, dev, test):
self._dev = dev
self._test = test
self._train = train
@property
def train(self):
return self._train
@property
def dev(self):
return self._dev
@property
def test(self):
return self._test
class DataSet(object):
def __init__(self, graph, txt_files, thread_count, batch_size, numcep, numcontext):
self._graph = graph
self._numcep = numcep
self._batch_queue = Queue(2 * self._get_device_count())
self._txt_files = txt_files
self._batch_size = batch_size
self._numcontext = numcontext
self._thread_count = thread_count
self._files_circular_list = self._create_files_circular_list()
self._start_queue_threads()
def _get_device_count(self):
available_gpus = get_available_gpus()
return max(len(available_gpus), 1)
def _start_queue_threads(self):
batch_threads = [Thread(target=self._populate_batch_queue) for i in xrange(self._thread_count)]
for batch_thread in batch_threads:
batch_thread.daemon = True
batch_thread.start()
def _create_files_circular_list(self):
priorityQueue = PriorityQueue()
for txt_file in self._txt_files:
wav_file = os.path.splitext(txt_file)[0] + ".wav"
wav_file_size = os.path.getsize(wav_file)
priorityQueue.put((wav_file_size, (txt_file, wav_file)))
files_list = []
while not priorityQueue.empty():
priority, (txt_file, wav_file) = priorityQueue.get()
files_list.append((txt_file, wav_file))
return cycle(files_list)
def _populate_batch_queue(self):
with self._graph.as_default():
while True:
n_steps = 0
sources = []
targets = []
for index, (txt_file, wav_file) in enumerate(self._files_circular_list):
if index >= self._batch_size:
break
next_source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext)
if n_steps < next_source.shape[0]:
n_steps = next_source.shape[0]
sources.append(next_source)
with open(txt_file) as open_txt_file:
targets.append(open_txt_file.read())
target = texts_to_sparse_tensor(targets)
for index, next_source in enumerate(sources):
npad = ((0,(n_steps - next_source.shape[0])), (0,0))
sources[index] = np.pad(next_source, pad_width=npad, mode='constant')
source = np.array(sources)
self._batch_queue.put((source, target))
def next_batch(self):
source, target = self._batch_queue.get()
return (source, target, source.shape[1])
@property
def total_batches(self):
# Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files
return int(ceil(float(len(self._txt_files)) /float(self._batch_size)))
def read_data_sets(graph, data_dir, batch_size, numcep, numcontext, thread_count=8):
# Check if we can convert FLAC with SoX before we start
sox_help_out = subprocess.check_output(["sox", "-h"])
if sox_help_out.find("flac") == -1:
print("Error: SoX doesn't support FLAC. Please install SoX with FLAC support and try again.")
exit(1)
# Conditionally download data to data_dir
TRAIN_CLEAN_100_URL = "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
TRAIN_CLEAN_360_URL = "http://www.openslr.org/resources/12/train-clean-360.tar.gz"
TRAIN_OTHER_500_URL = "http://www.openslr.org/resources/12/train-other-500.tar.gz"
DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz"
DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz"
TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz"
TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz"
train_clean_100 = base.maybe_download("train-clean-100.tar.gz", data_dir, TRAIN_CLEAN_100_URL)
train_clean_360 = base.maybe_download("train-clean-360.tar.gz", data_dir, TRAIN_CLEAN_360_URL)
train_other_500 = base.maybe_download("train-other-500.tar.gz", data_dir, TRAIN_OTHER_500_URL)
dev_clean = base.maybe_download("dev-clean.tar.gz", data_dir, DEV_CLEAN_URL)
dev_other = base.maybe_download("dev-other.tar.gz", data_dir, DEV_OTHER_URL)
test_clean = base.maybe_download("test-clean.tar.gz", data_dir, TEST_CLEAN_URL)
test_other = base.maybe_download("test-other.tar.gz", data_dir, TEST_OTHER_URL)
# Conditionally extract LibriSpeech data
# We extract each archive into data_dir, but test for existence in
# data_dir/LibriSpeech because the archives share that root.
LIBRIVOX_DIR = "LibriSpeech"
work_dir = os.path.join(data_dir, LIBRIVOX_DIR)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-other"), dev_other)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "test-clean"), test_clean)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "test-other"), test_other)
# Conditionally convert FLAC data to wav, from:
# data_dir/LibriSpeech/split/1/2/1-2-3.flac
# to:
# data_dir/LibriSpeech/split-wav/1-2-3.wav
_maybe_convert_wav(work_dir, "train-clean-100", "train-clean-100-wav")
_maybe_convert_wav(work_dir, "train-clean-360", "train-clean-360-wav")
_maybe_convert_wav(work_dir, "train-other-500", "train-other-500-wav")
_maybe_convert_wav(work_dir, "dev-clean", "dev-clean-wav")
_maybe_convert_wav(work_dir, "dev-other", "dev-other-wav")
_maybe_convert_wav(work_dir, "test-clean", "test-clean-wav")
_maybe_convert_wav(work_dir, "test-other", "test-other-wav")
# Conditionally split LibriSpeech transcriptions, from:
# data_dir/LibriSpeech/split/1/2/1-2.trans.txt
# to:
# data_dir/LibriSpeech/split-wav/1-2-0.txt
# data_dir/LibriSpeech/split-wav/1-2-1.txt
# data_dir/LibriSpeech/split-wav/1-2-2.txt
# ...
_maybe_split_transcriptions(work_dir, "train-clean-100", "train-clean-100-wav")
_maybe_split_transcriptions(work_dir, "train-clean-360", "train-clean-360-wav")
_maybe_split_transcriptions(work_dir, "train-other-500", "train-other-500-wav")
_maybe_split_transcriptions(work_dir, "dev-clean", "dev-clean-wav")
_maybe_split_transcriptions(work_dir, "dev-other", "dev-other-wav")
_maybe_split_transcriptions(work_dir, "test-clean", "test-clean-wav")
_maybe_split_transcriptions(work_dir, "test-other", "test-other-wav")
# Create train DataSet from all the train archives
train = _read_data_set(graph, work_dir, "train-*-wav", thread_count, batch_size, numcep, numcontext)
# Create dev DataSet from all the dev archives
dev = _read_data_set(graph, work_dir, "dev-*-wav", thread_count, batch_size, numcep, numcontext)
# Create test DataSet from all the test archives
test = _read_data_set(graph, work_dir, "test-*-wav", thread_count, batch_size, numcep, numcontext)
# Return DataSets
return DataSets(train, dev, test)
def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(os.path.join(data_dir, extracted_data)):
tar = tarfile.open(archive)
tar.extractall(data_dir)
tar.close()
# os.remove(archive)
def _maybe_convert_wav(data_dir, extracted_data, converted_data):
source_dir = os.path.join(data_dir, extracted_data)
target_dir = os.path.join(data_dir, converted_data)
# Conditionally convert FLAC files to wav files
if not gfile.Exists(target_dir):
# Create target_dir
os.makedirs(target_dir)
# Loop over FLAC files in source_dir and convert each to wav
for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, '*.flac'):
flac_file = os.path.join(root, filename)
wav_filename = os.path.splitext(os.path.basename(flac_file))[0] + ".wav"
wav_file = os.path.join(target_dir, wav_filename)
transformer = Transformer()
transformer.build(flac_file, wav_file)
os.remove(flac_file)
def _maybe_split_transcriptions(extracted_dir, data_set, dest_dir):
source_dir = os.path.join(extracted_dir, data_set)
target_dir = os.path.join(extracted_dir, dest_dir)
# Loop over transcription files and split each one
#
# The format for each file 1-2.trans.txt is:
# 1-2-0 transcription of 1-2-0.flac
# 1-2-1 transcription of 1-2-1.flac
# ...
#
# Each file is then split into several files:
# 1-2-0.txt (contains transcription of 1-2-0.flac)
# 1-2-1.txt (contains transcription of 1-2-1.flac)
# ...
for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, '*.trans.txt'):
trans_filename = os.path.join(root, filename)
with open(trans_filename, "r") as fin:
for line in fin:
first_space = line.find(" ")
txt_file = line[:first_space] + ".txt"
with open(os.path.join(target_dir, txt_file), "w") as fout:
fout.write(line[first_space+1:].lower().strip("\n"))
os.remove(trans_filename)
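# Example (illustrative): the line "1-2-0 HELLO WORLD" in 1-2.trans.txt becomes the
# file 1-2-0.txt in the corresponding *-wav directory, containing "hello world".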
def _read_data_set(graph, work_dir, data_set, thread_count, batch_size, numcep, numcontext):
# Create data set dir
dataset_dir = os.path.join(work_dir, data_set)
# Obtain list of txt files
txt_files = glob(os.path.join(dataset_dir, "*.txt"))
# Return DataSet
return DataSet(graph, txt_files, thread_count, batch_size, numcep, numcontext)
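A hedged usage sketch of this importer (the directory and hyper-parameters are assumptions that mirror the notebook's TED-LIUM call; the full LibriSpeech download is large, so this is not meant to be run as-is):

```python
import tensorflow as tf
from util.importers.librivox import read_data_sets

# graph, data_dir, batch_size, numcep, numcontext
librivox = read_data_sets(tf.get_default_graph(), './data/librivox', 5, 26, 9)

source, target, n_steps = librivox.train.next_batch()
# source: numpy array of shape (5, n_steps, 26 + 2*26*9), zero-padded per batch
# target: tf.SparseTensor holding the encoded transcripts of the batch
```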

util/importers/ted_lium.py (new file, 294 lines)

@ -0,0 +1,294 @@
import wave
import random
import tarfile
import threading
import numpy as np
from os import path
from os import rmdir
from os import remove
from glob import glob
from math import ceil
from Queue import Queue
from os import makedirs
from sox import Transformer
from itertools import cycle
from os.path import getsize
from threading import Thread
from Queue import PriorityQueue
from util.stm import parse_stm_file
from util.gpu import get_available_gpus
from util.text import texts_to_sparse_tensor
from tensorflow.python.platform import gfile
from util.audio import audiofile_to_input_vector
from tensorflow.contrib.learn.python.learn.datasets import base
class DataSets(object):
def __init__(self, train, dev, test):
self._dev = dev
self._test = test
self._train = train
@property
def train(self):
return self._train
@property
def dev(self):
return self._dev
@property
def test(self):
return self._test
class DataSet(object):
def __init__(self, graph, txt_files, thread_count, batch_size, numcep, numcontext):
self._graph = graph
self._numcep = numcep
self._batch_queue = Queue(2 * self._get_device_count())
self._txt_files = txt_files
self._batch_size = batch_size
self._numcontext = numcontext
self._thread_count = thread_count
self._files_circular_list = self._create_files_circular_list()
self._start_queue_threads()
def _get_device_count(self):
available_gpus = get_available_gpus()
return max(len(available_gpus), 1)
def _start_queue_threads(self):
batch_threads = [Thread(target=self._populate_batch_queue) for i in xrange(self._thread_count)]
for batch_thread in batch_threads:
batch_thread.daemon = True
batch_thread.start()
def _create_files_circular_list(self):
priorityQueue = PriorityQueue()
for txt_file in self._txt_files:
stm_dir = path.sep + "stm" + path.sep
wav_dir = path.sep + "wav" + path.sep
wav_file = path.splitext(txt_file.replace(stm_dir, wav_dir))[0] + ".wav"
wav_file_size = getsize(wav_file)
priorityQueue.put((wav_file_size, (txt_file, wav_file)))
files_list = []
while not priorityQueue.empty():
priority, (txt_file, wav_file) = priorityQueue.get()
files_list.append((txt_file, wav_file))
return cycle(files_list)
def _populate_batch_queue(self):
with self._graph.as_default():
while True:
n_steps = 0
sources = []
targets = []
for index, (txt_file, wav_file) in enumerate(self._files_circular_list):
if index >= self._batch_size:
break
next_source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext)
if n_steps < next_source.shape[0]:
n_steps = next_source.shape[0]
sources.append(next_source)
with open(txt_file) as open_txt_file:
targets.append(open_txt_file.read())
target = texts_to_sparse_tensor(targets)
for index, next_source in enumerate(sources):
npad = ((0,(n_steps - next_source.shape[0])), (0,0))
sources[index] = np.pad(next_source, pad_width=npad, mode='constant')
source = np.array(sources)
self._batch_queue.put((source, target))
def next_batch(self):
source, target = self._batch_queue.get()
return (source, target, source.shape[1])
@property
def total_batches(self):
# Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files
return int(ceil(float(len(self._txt_files)) /float(self._batch_size)))
def read_data_sets(graph, data_dir, batch_size, numcep, numcontext, thread_count=8):
# Conditionally download data
TED_DATA = "TEDLIUM_release2.tar.gz"
TED_DATA_URL = "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz"
local_file = base.maybe_download(TED_DATA, data_dir, TED_DATA_URL)
# Conditionally extract TED data
TED_DIR = "TEDLIUM_release2"
_maybe_extract(data_dir, TED_DIR, local_file)
# Conditionally convert TED sph data to wav
_maybe_convert_wav(data_dir, TED_DIR)
# Conditionally split TED wav data
_maybe_split_wav(data_dir, TED_DIR)
# Conditionally split TED stm data
_maybe_split_stm(data_dir, TED_DIR)
# Create dev DataSet
dev = _read_data_set(graph, data_dir, TED_DIR, "dev", thread_count, batch_size, numcep, numcontext)
# Create test DataSet
test = _read_data_set(graph, data_dir, TED_DIR, "test", thread_count, batch_size, numcep, numcontext)
# Create train DataSet
train = _read_data_set(graph, data_dir, TED_DIR, "train", thread_count, batch_size, numcep, numcontext)
# Return DataSets
return DataSets(train, dev, test)
def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(path.join(data_dir, extracted_data)):
tar = tarfile.open(archive)
tar.extractall(data_dir)
tar.close()
remove(archive)
def _maybe_convert_wav(data_dir, extracted_data):
# Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data)
# Conditionally convert dev sph to wav
_maybe_convert_wav_dataset(extracted_dir, "dev")
# Conditionally convert train sph to wav
_maybe_convert_wav_dataset(extracted_dir, "train")
# Conditionally convert test sph to wav
_maybe_convert_wav_dataset(extracted_dir, "test")
def _maybe_convert_wav_dataset(extracted_dir, data_set):
# Create source dir
source_dir = path.join(extracted_dir, data_set, "sph")
# Create target dir
target_dir = path.join(extracted_dir, data_set, "wav")
# Conditionally convert sph files to wav files
if not gfile.Exists(target_dir):
# Create target_dir
makedirs(target_dir)
# Loop over sph files in source_dir and convert each to wav
for sph_file in glob(path.join(source_dir, "*.sph")):
transformer = Transformer()
wav_filename = path.splitext(path.basename(sph_file))[0] + ".wav"
wav_file = path.join(target_dir, wav_filename)
transformer.build(sph_file, wav_file)
remove(sph_file)
# Remove source_dir
rmdir(source_dir)
def _maybe_split_wav(data_dir, extracted_data):
# Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data)
# Conditionally split dev wav
_maybe_split_wav_dataset(extracted_dir, "dev")
# Conditionally split train wav
_maybe_split_wav_dataset(extracted_dir, "train")
# Conditionally split test wav
_maybe_split_wav_dataset(extracted_dir, "test")
def _maybe_split_wav_dataset(extracted_dir, data_set):
# Create stm dir
stm_dir = path.join(extracted_dir, data_set, "stm")
# Create wav dir
wav_dir = path.join(extracted_dir, data_set, "wav")
# Loop over stm files and split corresponding wav
for stm_file in glob(path.join(stm_dir, "*.stm")):
# Parse stm file
stm_segments = parse_stm_file(stm_file)
# Open wav corresponding to stm_file
wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav"
wav_file = path.join(wav_dir, wav_filename)
origAudio = wave.open(wav_file,'r')
# Loop over stm_segments and split wav_file for each segment
for stm_segment in stm_segments:
# Create wav segment filename
start_time = stm_segment.start_time
stop_time = stm_segment.stop_time
new_wav_filename = path.splitext(path.basename(stm_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
new_wav_file = path.join(wav_dir, new_wav_filename)
# If the wav segment filename does not exist create it
if not gfile.Exists(new_wav_file):
_split_wav(origAudio, start_time, stop_time, new_wav_file)
# Close origAudio
origAudio.close()
# Remove wav_file
remove(wav_file)
def _split_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio.getframerate()
origAudio.setpos(int(start_time*frameRate))
chunkData = origAudio.readframes(int((stop_time - start_time)*frameRate))
chunkAudio = wave.open(new_wav_file,'w')
chunkAudio.setnchannels(origAudio.getnchannels())
chunkAudio.setsampwidth(origAudio.getsampwidth())
chunkAudio.setframerate(frameRate)
chunkAudio.writeframes(chunkData)
chunkAudio.close()
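# Example (illustrative): for a segment with start_time=12.5, stop_time=14.5 and a
# 16000 Hz wav, setpos seeks to frame int(12.5*16000) = 200000 and readframes
# copies int(2.0*16000) = 32000 frames (2 seconds of audio).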
def _maybe_split_stm(data_dir, extracted_data):
# Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data)
# Conditionally split dev stm
_maybe_split_stm_dataset(extracted_dir, "dev")
# Conditionally split train stm
_maybe_split_stm_dataset(extracted_dir, "train")
# Conditionally split test stm
_maybe_split_stm_dataset(extracted_dir, "test")
def _maybe_split_stm_dataset(extracted_dir, data_set):
# Create stm dir
stm_dir = path.join(extracted_dir, data_set, "stm")
# Obtain stm files
stm_files = glob(path.join(stm_dir, "*.stm"))
# Loop over stm files and split each one
for stm_file in stm_files:
# Parse stm file
stm_segments = parse_stm_file(stm_file)
# Loop over stm_segments and create txt file for each one
for stm_segment in stm_segments:
start_time = stm_segment.start_time
stop_time = stm_segment.stop_time
txt_filename = path.splitext(path.basename(stm_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".txt"
txt_file = path.join(stm_dir, txt_filename)
# If the txt segment file does not exist create it
if not gfile.Exists(txt_file):
with open(txt_file, "w+") as f:
f.write(stm_segment.transcript)
# Remove stm_file
remove(stm_file)
def _read_data_set(graph, data_dir, extracted_data, data_set, thread_count, batch_size, numcep, numcontext):
# Create stm dir
stm_dir = path.join(data_dir, extracted_data, data_set, "stm")
# Obtain list of txt files
txt_files = glob(path.join(stm_dir, "*.txt"))
# Return DataSet
return DataSet(graph, txt_files, thread_count, batch_size, numcep, numcontext)


@ -1,86 +0,0 @@
import numpy as np
from os import path
from util.text import text_to_sparse_tensor
from util.audio import audiofiles_to_audio_data_sets
class DataSets(object):
def __init__(self, train, validation, test):
self._train = train
self._validation = validation
self._test = test
@property
def train(self):
return self._train
@property
def validation(self):
return self._validation
@property
def test(self):
return self._test
class DataSet(object):
def __init__(self, inputs, outputs, seq_len):
self._offset = 0
self._inputs = inputs
self._outputs = outputs
self._seq_len = seq_len
def next_batch(self, batch_size):
next_batch = (self._inputs, self._outputs, self._seq_len) # TODO: Choose only batch_size elements
self._offset += batch_size
return next_batch
@property
def max_batch_seq_len(self):
return np.amax(self._seq_len)
@property
def num_examples(self):
return self._inputs.shape[0]
def read_data_sets(data_dir, numcep, numcontext):
# Get train data
train_outputs = read_text_data_sets(data_dir, 'train')
train_inputs, train_seq_len = read_audio_data_sets(data_dir, numcep, numcontext, 'train')
# Get validation data
validation_outputs = read_text_data_sets(data_dir, 'validation')
validation_inputs, validation_seq_len = read_audio_data_sets(data_dir, numcep, numcontext, 'validation')
# Get test data
test_outputs = read_text_data_sets(data_dir, 'test')
test_inputs, test_seq_len = read_audio_data_sets(data_dir, numcep, numcontext, 'test')
# Create train, validation, and test DataSet's
train = DataSet(inputs=train_inputs, outputs=train_outputs, seq_len=train_seq_len)
validation = DataSet(inputs=validation_inputs, outputs=validation_outputs, seq_len=validation_seq_len)
test = DataSet(inputs=test_inputs, outputs=test_outputs, seq_len=test_seq_len)
# Return DataSets
return DataSets(train=train, validation=validation, test=test)
def read_text_data_sets(data_dir, data_type):
# TODO: Do not ignore data_type = ['train'|'validation'|'test']
# Create file names
text_filename = path.join(data_dir, 'LDC93S1.txt')
# Read text file and create list of sentence's words w/spaces replaced by ''
with open(text_filename, 'rb') as f:
for line in f.readlines():
original = ' '.join(line.strip().lower().split(' ')[2:]).replace('.', '')
return text_to_sparse_tensor([original])
def read_audio_data_sets(data_dir, numcep, numcontext, data_type):
# TODO: Do not ignore data_type = ['train'|'validation'|'test']
# Create file name
audio_filename = path.join(data_dir, 'LDC93S1.wav')
# Return properly formatted data
return audiofiles_to_audio_data_sets([audio_filename], numcep, numcontext)

util/stm.py (new file, 50 lines)

@ -0,0 +1,50 @@
class STMSegment(object):
    def __init__(self, stm_line):
        tokens = stm_line.split()
        self._filename = tokens[0]
        self._channel = tokens[1]
        self._speaker_id = tokens[2]
        self._start_time = float(tokens[3])
        self._stop_time = float(tokens[4])
        self._labels = tokens[5]
        self._transcript = ""
        for token in tokens[6:]:
            self._transcript += token + " "
        self._transcript = self._transcript.strip()

    @property
    def filename(self):
        return self._filename

    @property
    def channel(self):
        return self._channel

    @property
    def speaker_id(self):
        return self._speaker_id

    @property
    def start_time(self):
        return self._start_time

    @property
    def stop_time(self):
        return self._stop_time

    @property
    def labels(self):
        return self._labels

    @property
    def transcript(self):
        return self._transcript


def parse_stm_file(stm_file):
    stm_segments = []
    with open(stm_file) as stm_lines:
        for stm_line in stm_lines:
            stmSegment = STMSegment(stm_line)
            if not "ignore_time_segment_in_scoring" == stmSegment.transcript:
                stm_segments.append(stmSegment)
    return stm_segments
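A hedged usage example (the STM line below is invented, but follows the field order parsed above: filename, channel, speaker id, start time, stop time, labels, transcript):

```python
line = "TalkA_2016 1 TalkA_2016 12.50 14.50 <o,f0,male> hello world"
segment = STMSegment(line)

# segment.filename   -> "TalkA_2016"
# segment.start_time -> 12.5
# segment.stop_time  -> 14.5
# segment.transcript -> "hello world"
```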


@ -6,21 +6,17 @@ SPACE_TOKEN = '<space>'
SPACE_INDEX = 0 SPACE_INDEX = 0
FIRST_INDEX = ord('a') - 1 # 0 is reserved to space FIRST_INDEX = ord('a') - 1 # 0 is reserved to space
def text_to_sparse_tensor(originals):
return tf.SparseTensor.from_value(text_to_sparse_tensor_value(originals))
def text_to_sparse_tensor_value(originals): def texts_to_sparse_tensor(originals):
tuple = text_to_sparse_tuple(originals)
return tf.SparseTensorValue(indices=tuple[0], values=tuple[1], shape=tuple[2])
def text_to_sparse_tuple(originals):
# Define list to hold results # Define list to hold results
results = [] results = []
# Process each original in originals # Process each original in originals
for original in originals: for original in originals:
# Create list of sentence's words w/spaces replaced by '' # Create list of sentence's words w/spaces replaced by ''
result = original.replace(' ', '  ') result = original.replace(" '", "") # TODO: Deal with this properly
result = result.replace("'", "") # TODO: Deal with this properly
result = result.replace(' ', '  ') result = result.replace(' ', '  ')
result = result.split(' ') result = result.split(' ')
# Tokenize words into letters adding in SPACE_TOKEN where required # Tokenize words into letters adding in SPACE_TOKEN where required
@ -35,6 +31,7 @@ def text_to_sparse_tuple(originals):
# Creating sparse representation to feed the placeholder # Creating sparse representation to feed the placeholder
return sparse_tuple_from(results) return sparse_tuple_from(results)
def sparse_tuple_from(sequences, dtype=np.int32): def sparse_tuple_from(sequences, dtype=np.int32):
"""Create a sparse representention of x. """Create a sparse representention of x.
Args: Args:
@ -53,12 +50,12 @@ def sparse_tuple_from(sequences, dtype=np.int32):
values = np.asarray(values, dtype=dtype) values = np.asarray(values, dtype=dtype)
shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1]+1], dtype=np.int64) shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1]+1], dtype=np.int64)
return (indices, values, shape); return tf.SparseTensor(indices=indices, values=values, shape=shape)
def sparse_tensor_value_to_text(value): def sparse_tensor_value_to_texts(value):
return sparse_tuple_to_text((value.indices, value.values, value.shape)) return sparse_tuple_to_texts((value.indices, value.values, value.shape))
def sparse_tuple_to_text(tuple): def sparse_tuple_to_texts(tuple):
indices = tuple[0] indices = tuple[0]
values = tuple[1] values = tuple[1]
results = [''] * tuple[2][0] results = [''] * tuple[2][0]
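As a concrete illustration (values are made up, not from the diff): for two already-encoded label sequences, `sparse_tuple_from` builds the index/value/shape triple that the updated code now wraps directly in a `tf.SparseTensor`.

```python
sequences = [[1, 2], [3]]

# indices -> [[0, 0], [0, 1], [1, 0]]   one (batch, time) pair per value
# values  -> [1, 2, 3]
# shape   -> [2, 2]                     2 sequences, the longest has 2 entries
sparse_labels = sparse_tuple_from(sequences)
```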