RNN decoding refactoring

Erwan 2022-06-16 11:17:51 +02:00
parent 00b1c291a6
commit 71a9c33bca
4 changed files with 118 additions and 45 deletions

View File

@@ -30,7 +30,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.checkpoint import load_checkpoint
from icefall.decode import (
get_lattice,
nbest_decoding,
@@ -46,8 +46,10 @@ from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@@ -174,19 +176,47 @@ def get_parser():
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True share the weights between the input embedding layer and the
last output linear layer
""",
)
return parser
def get_rnn_lm_model(params: AttributeDict):
from rnn_lm.model import RnnLmModel
# TODO: Pass the following options from command-line
rnn_lm_model = RnnLmModel(
vocab_size=params.num_classes,
embedding_dim=1024,
hidden_dim=1024,
num_layers=2,
tie_weights=False,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
return rnn_lm_model
@@ -727,20 +757,16 @@ def main():
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
rnn_lm_model = None
if params.method == "rnn-lm":
rnn_lm_model = get_rnn_lm_model(params)
if params.rnn_lm_avg == 1:
@@ -748,19 +774,16 @@ def main():
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
else:
start = params.rnn_lm_epoch - params.rnn_lm_avg + 1
filenames = []
for i in range(start, params.rnn_lm_epoch + 1):
if start >= 0:
filenames.append(f"{params.rnn_lm_exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
rnn_lm_model.to(device)
rnn_lm_model.load_state_dict(
average_checkpoints(filenames, device=device)
)
else:
rnn_lm_model = None
rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
)
rnn_lm_model.eval()
librispeech = LibriSpeechAsrDataModule(args)
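
As a rough illustration of how the new options flow into the model (the num_classes value and the AttributeDict construction below are assumptions for the example, not code from this commit):

params = AttributeDict(
    {
        "num_classes": 500,  # assumed BPE vocabulary size
        "rnn_lm_embedding_dim": 2048,
        "rnn_lm_hidden_dim": 2048,
        "rnn_lm_num_layers": 4,
        "rnn_lm_tie_weights": False,
    }
)
# get_rnn_lm_model reads the rnn-lm-* options from params and builds RnnLmModel
rnn_lm_model = get_rnn_lm_model(params)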

View File

@@ -206,7 +206,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 200,
"reset_interval": 2000,
"valid_interval": 30000,
"valid_interval": 10000,
"env_info": get_env_info(),
}
)
@@ -539,6 +539,7 @@ def run(rank, world_size, args):
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)

View File

@@ -22,6 +22,33 @@ import torch
from icefall.utils import add_eos, add_sos, get_texts
DEFAULT_LM_SCALE = [
0.01,
0.05,
0.08,
0.1,
0.3,
0.5,
0.6,
0.7,
0.9,
1.0,
1.1,
1.2,
1.3,
1.5,
1.7,
1.9,
2.0,
2.1,
2.2,
2.3,
2.5,
3.0,
4.0,
5.0,
]
def _intersect_device(
a_fsas: k2.Fsa,
@@ -1082,28 +1109,17 @@ def rescore_with_rnn_lm(
rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
ngram_lm_scale_list = DEFAULT_LM_SCALE
attention_scale_list = DEFAULT_LM_SCALE
rnn_lm_scale_list = DEFAULT_LM_SCALE
if ngram_lm_scale:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
if attention_scale:
attention_scale_list = [attention_scale]
if rnn_lm_scale is None:
rnn_lm_scale_list = [0.01, 0.05, 0.08]
rnn_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
rnn_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
rnn_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
if rnn_lm_scale:
rnn_lm_scale_list = [rnn_lm_scale]
ans = dict()

View File

@@ -35,6 +35,8 @@ import torch.distributed as dist
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import average_checkpoints
Pathlike = Union[str, Path]
@@ -807,3 +809,34 @@ def optim_step_and_measure_param_change(
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
relative_change[n] = delta.item()
return relative_change
def load_averaged_model(
model_dir: str,
model: torch.nn.Module,
epoch: int,
avg: int,
device: torch.device,
):
"""
Load a model whose parameters are the average of the last `avg` epoch checkpoints.
:param model_dir: a str of the experiment directory
:param model: a torch.nn.Module instance
:param epoch: the last epoch to load from
:param avg: how many checkpoints to average over
:param device: move the model to this device
:return: the averaged model
"""
# start cannot be negative
start = max(epoch - avg + 1, 0)
filenames = [f"{model_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
return model
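
A minimal usage sketch for the new helper, assuming checkpoints named epoch-<n>.pt in the experiment directory (the directory, epoch, and avg values are examples, not from this commit):

rnn_lm_model = load_averaged_model(
    model_dir="rnn_lm/exp",
    model=rnn_lm_model,
    epoch=29,
    avg=5,
    device=torch.device("cpu"),
)
# start = max(29 - 5 + 1, 0) = 25, so epoch-25.pt through epoch-29.pt are averaged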