diff --git a/bin/import_cv2.py b/bin/import_cv2.py
index df098f5c..7c85c8f0 100755
--- a/bin/import_cv2.py
+++ b/bin/import_cv2.py
@@ -31,11 +31,14 @@ MAX_SECS = 10
 
 
 def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
-    for dataset in ["train", "test", "dev", "validated", "other"]:
-        input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset + ".tsv")
-        if os.path.isfile(input_tsv):
-            print("Loading TSV file: ", input_tsv)
-            _maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
+    exclude = []
+    for dataset in ["test", "dev", "train", "validated", "other"]:
+        set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character)
+        if dataset in ["test", "dev"]:
+            exclude += set_samples
+        if dataset == "validated":
+            _maybe_convert_set("train-all", tsv_dir, audio_dir, space_after_every_character,
+                               rows=set_samples, exclude=exclude)
 
 
 def one_sample(sample):
@@ -72,47 +75,63 @@ def one_sample(sample):
         counter["too_long"] += 1
     else:
         # This one is good - keep it for the target CSV
-        rows.append((os.path.split(wav_filename)[-1], file_size, label))
+        rows.append((os.path.split(wav_filename)[-1], file_size, label, sample[2]))
     counter["all"] += 1
     counter["total_time"] += frames
 
     return (counter, rows)
 
 
-def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
-    output_csv = os.path.join(
-        audio_dir, os.path.split(input_tsv)[-1].replace("tsv", "csv")
-    )
+def _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character=None, rows=None, exclude=None):
+    exclude_transcripts = set()
+    exclude_speakers = set()
+    if exclude is not None:
+        for sample in exclude:
+            exclude_transcripts.add(sample[2])
+            exclude_speakers.add(sample[3])
+
+    if rows is None:
+        rows = []
+        input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset + ".tsv")
+        if not os.path.isfile(input_tsv):
+            return rows
+        print("Loading TSV file: ", input_tsv)
+        # Get audiofile path and transcript for each sentence in tsv
+        samples = []
+        with open(input_tsv, encoding="utf-8") as input_tsv_file:
+            reader = csv.DictReader(input_tsv_file, delimiter="\t")
+            for row in reader:
+                samples.append((os.path.join(audio_dir, row["path"]), row["sentence"], row["client_id"]))
+
+        counter = get_counter()
+        num_samples = len(samples)
+
+        print("Importing mp3 files...")
+        pool = Pool()
+        bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
+        for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+            counter += processed[0]
+            rows += processed[1]
+            bar.update(i)
+        bar.update(num_samples)
+        pool.close()
+        pool.join()
+
+        imported_samples = get_imported_samples(counter)
+        assert counter["all"] == num_samples
+        assert len(rows) == imported_samples
+        print_import_report(counter, SAMPLE_RATE, MAX_SECS)
+
+    output_csv = os.path.join(os.path.abspath(audio_dir), dataset + ".csv")
     print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
-
-    # Get audiofile path and transcript for each sentence in tsv
-    samples = []
-    with open(input_tsv, encoding="utf-8") as input_tsv_file:
-        reader = csv.DictReader(input_tsv_file, delimiter="\t")
-        for row in reader:
-            samples.append((os.path.join(audio_dir, row["path"]), row["sentence"]))
-
-    counter = get_counter()
-    num_samples = len(samples)
-    rows = []
-
-    print("Importing mp3 files...")
-    pool = Pool()
-    bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
-    for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
-        counter += processed[0]
-        rows += processed[1]
-        bar.update(i)
-    bar.update(num_samples)
-    pool.close()
-    pool.join()
-
     with open(output_csv, "w", encoding="utf-8") as output_csv_file:
         print("Writing CSV file for DeepSpeech.py as: ", output_csv)
         writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
         writer.writeheader()
         bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
-        for filename, file_size, transcript in bar(rows):
+        for filename, file_size, transcript, speaker in bar(rows):
+            if transcript in exclude_transcripts or speaker in exclude_speakers:
+                continue
             if space_after_every_character:
                 writer.writerow(
                     {
@@ -129,12 +148,7 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
                             "transcript": transcript,
                         }
                     )
-
-    imported_samples = get_imported_samples(counter)
-    assert counter["all"] == num_samples
-    assert len(rows) == imported_samples
-
-    print_import_report(counter, SAMPLE_RATE, MAX_SECS)
+    return rows
 
 
 def _maybe_convert_wav(mp3_filename, wav_filename):