diff --git a/data/lm/generate_lm.py b/data/lm/generate_lm.py
index 755d6921..2099c2ae 100644
--- a/data/lm/generate_lm.py
+++ b/data/lm/generate_lm.py
@@ -18,11 +18,22 @@ def convert_and_filter_topk(args):
     with io.TextIOWrapper(
         io.BufferedWriter(gzip.open(data_lower, "w+")), encoding="utf-8"
     ) as file_out:
-        with open(args.input_txt, encoding="utf-8") as file_in:
-            for line in progressbar.progressbar(file_in):
-                line_lower = line.lower()
-                counter.update(line_lower.split())
-                file_out.write(line_lower)
+
+        # Open the input file either from input.txt or input.txt.gz
+        _, file_extension = os.path.splitext(args.input_txt)
+        if file_extension == ".gz":
+            file_in = io.TextIOWrapper(
+                io.BufferedReader(gzip.open(args.input_txt)), encoding="utf-8"
+            )
+        else:
+            file_in = open(args.input_txt, encoding="utf-8")
+
+        # Close the input even if reading or writing a line raises
+        with file_in:
+            for line in progressbar.progressbar(file_in):
+                line_lower = line.lower()
+                counter.update(line_lower.split())
+                file_out.write(line_lower)
 
     # Save top-k words
     print("\nSaving top {} words ...".format(args.top_k))
@@ -122,7 +133,7 @@ def main():
     )
     parser.add_argument(
         "--input_txt",
-        help="File path to a .txt with sample sentences",
+        help="Path to a file.txt or file.txt.gz with sample sentences",
         type=str,
         required=True,
     )