From 2eb21668e08d8c6a7b971c9aaf0fa008c452dbae Mon Sep 17 00:00:00 2001 From: Erwan Date: Mon, 20 Jun 2022 11:42:46 +0200 Subject: [PATCH] Update results and move rnn directory --- egs/librispeech/ASR/RESULTS.md | 40 +++++++++++++++---- egs/librispeech/ASR/conformer_ctc/decode.py | 22 ++++------ egs/librispeech/ASR/prepare.sh | 8 ++-- egs/librispeech/ASR/rnn_lm/__init__.py | 0 .../rnn_lm/compute_perplexity.py | 17 +++----- .../ASR => icefall}/rnn_lm/dataset.py | 0 .../ASR => icefall}/rnn_lm/model.py | 0 .../ASR => icefall}/rnn_lm/test_dataset.py | 0 .../rnn_lm/test_dataset_ddp.py | 0 .../ASR => icefall}/rnn_lm/test_model.py | 2 +- .../ASR => icefall}/rnn_lm/train.py | 11 +---- 11 files changed, 54 insertions(+), 46 deletions(-) delete mode 100644 egs/librispeech/ASR/rnn_lm/__init__.py rename {egs/librispeech/ASR => icefall}/rnn_lm/compute_perplexity.py (91%) rename {egs/librispeech/ASR => icefall}/rnn_lm/dataset.py (100%) rename {egs/librispeech/ASR => icefall}/rnn_lm/model.py (100%) rename {egs/librispeech/ASR => icefall}/rnn_lm/test_dataset.py (100%) rename {egs/librispeech/ASR => icefall}/rnn_lm/test_dataset_ddp.py (100%) rename {egs/librispeech/ASR => icefall}/rnn_lm/test_model.py (98%) rename {egs/librispeech/ASR => icefall}/rnn_lm/train.py (98%) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 66410ef40..cc15b79b4 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1139,17 +1139,18 @@ You can find the tensorboard log at: AttributeDict: params = AttributeDict( { @@ -768,7 +756,13 @@ def main(): rnn_lm_model = None if params.method == "rnn-lm": - rnn_lm_model = get_rnn_lm_model(params) + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) if params.rnn_lm_avg == 1: load_checkpoint( f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index f9ce6faec..94e003036 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -41,9 +41,9 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - # 5000 - # 2000 - # 1000 + 5000 + 2000 + 1000 500 ) @@ -277,6 +277,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} ./local/compile_lg.py --lang-dir $lang_dir + done +fi if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then log "Stage 11: Generate LM training data" diff --git a/egs/librispeech/ASR/rnn_lm/__init__.py b/egs/librispeech/ASR/rnn_lm/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/librispeech/ASR/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py similarity index 91% rename from egs/librispeech/ASR/rnn_lm/compute_perplexity.py rename to icefall/rnn_lm/compute_perplexity.py index e754b9534..8ba1b312f 100755 --- a/egs/librispeech/ASR/rnn_lm/compute_perplexity.py +++ b/icefall/rnn_lm/compute_perplexity.py @@ -33,8 +33,8 @@ import torch from rnn_lm.dataset import get_dataloader from rnn_lm.model import RnnLmModel -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import AttributeDict, setup_logger +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, load_averaged_model, setup_logger def get_parser(): @@ -165,17 +165,12 @@ def main(): if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.to(device) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) num_param_requires_grad = sum( diff --git a/egs/librispeech/ASR/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py similarity index 100% rename from egs/librispeech/ASR/rnn_lm/dataset.py rename to icefall/rnn_lm/dataset.py diff --git a/egs/librispeech/ASR/rnn_lm/model.py b/icefall/rnn_lm/model.py similarity index 100% rename from egs/librispeech/ASR/rnn_lm/model.py rename to icefall/rnn_lm/model.py diff --git a/egs/librispeech/ASR/rnn_lm/test_dataset.py b/icefall/rnn_lm/test_dataset.py similarity index 100% rename from egs/librispeech/ASR/rnn_lm/test_dataset.py rename to icefall/rnn_lm/test_dataset.py diff --git a/egs/librispeech/ASR/rnn_lm/test_dataset_ddp.py b/icefall/rnn_lm/test_dataset_ddp.py similarity index 100% rename from egs/librispeech/ASR/rnn_lm/test_dataset_ddp.py rename to icefall/rnn_lm/test_dataset_ddp.py diff --git a/egs/librispeech/ASR/rnn_lm/test_model.py b/icefall/rnn_lm/test_model.py similarity index 98% rename from egs/librispeech/ASR/rnn_lm/test_model.py rename to icefall/rnn_lm/test_model.py index e0876727d..5a216a3fb 100755 --- a/egs/librispeech/ASR/rnn_lm/test_model.py +++ b/icefall/rnn_lm/test_model.py @@ -40,7 +40,7 @@ def test_rnn_lm_model(): ) lengths = torch.tensor([4, 3, 2]) nll_loss = model(x, y, lengths) - + print(nll_loss) """ tensor([[1.1180, 1.3059, 1.2426, 1.7773], [1.4231, 1.2783, 1.7321, 0.0000], diff --git a/egs/librispeech/ASR/rnn_lm/train.py b/icefall/rnn_lm/train.py similarity index 98% rename from egs/librispeech/ASR/rnn_lm/train.py rename to icefall/rnn_lm/train.py index 5df9cfdb2..e20cf39c2 100755 --- a/egs/librispeech/ASR/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -120,13 +120,6 @@ def get_parser(): default=50, ) - parser.add_argument( - "--use-ddp-launch", - type=str2bool, - default=False, - help="True if using torch.distributed.launch", - ) - parser.add_argument( "--lm-data", type=str, @@ -165,7 +158,7 @@ def get_parser(): parser.add_argument( "--num-layers", type=int, - default=4, + default=3, help="Number of RNN layers the model", ) @@ -206,7 +199,7 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 200, "reset_interval": 2000, - "valid_interval": 10000, + "valid_interval": 5000, "env_info": get_env_info(), } )