From ebbab37776bb6b10e839b36b2f118139eacdb401 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 8 May 2023 20:48:17 +0800 Subject: [PATCH] Fix broken code in download_lm.py (#1046) --- egs/librispeech/ASR/local/download_lm.py | 50 +++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 3518db524..da1648d06 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -41,9 +41,57 @@ import os import shutil from pathlib import Path -from lhotse.utils import urlretrieve_progress from tqdm.auto import tqdm +# This function is copied from lhotse +def tqdm_urlretrieve_hook(t): + """Wraps tqdm instance. + Don't forget to close() or __exit__() + the tqdm instance once you're done with it (easiest using `with` syntax). + Example + ------- + >>> from urllib.request import urlretrieve + >>> with tqdm(...) as t: + ... reporthook = tqdm_urlretrieve_hook(t) + ... urlretrieve(..., reporthook=reporthook) + + Source: https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py + """ + last_b = [0] + + def update_to(b=1, bsize=1, tsize=None): + """ + b : int, optional + Number of blocks transferred so far [default: 1]. + bsize : int, optional + Size of each block (in tqdm units) [default: 1]. + tsize : int, optional + Total size (in tqdm units). If [default: None] or -1, + remains unchanged. + """ + if tsize not in (None, -1): + t.total = tsize + displayed = t.update((b - last_b[0]) * bsize) + last_b[0] = b + return displayed + + return update_to + + +# This function is copied from lhotse +def urlretrieve_progress(url, filename=None, data=None, desc=None): + """ + Works exactly like urllib.request.urlretrieve, but attaches a tqdm hook to + display a progress bar of the download. + Use "desc" argument to display a user-readable string that informs what is + being downloaded. + """ + from urllib.request import urlretrieve + + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=desc) as t: + reporthook = tqdm_urlretrieve_hook(t) + return urlretrieve(url=url, filename=filename, reporthook=reporthook, data=data) + def get_args(): parser = argparse.ArgumentParser()