RNN decoding refactoring

This commit is contained in:
Erwan 2022-06-16 11:17:51 +02:00
parent 00b1c291a6
commit 71a9c33bca
4 changed files with 118 additions and 45 deletions

View File

@ -30,7 +30,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import load_checkpoint
from icefall.decode import ( from icefall.decode import (
get_lattice, get_lattice,
nbest_decoding, nbest_decoding,
@ -46,8 +46,10 @@ from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
get_texts, get_texts,
load_averaged_model,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
str2bool,
write_error_stats, write_error_stats,
) )
@ -174,19 +176,47 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True share the weights between the input embedding layer and the
last output linear layer
""",
)
return parser return parser
def get_rnn_lm_model(params: AttributeDict): def get_rnn_lm_model(params: AttributeDict):
from rnn_lm.model import RnnLmModel from rnn_lm.model import RnnLmModel
# TODO: Pass the following options from command-line
rnn_lm_model = RnnLmModel( rnn_lm_model = RnnLmModel(
vocab_size=params.num_classes, vocab_size=params.num_classes,
embedding_dim=1024, embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=1024, hidden_dim=params.rnn_lm_hidden_dim,
num_layers=2, num_layers=params.rnn_lm_num_layers,
tie_weights=False, tie_weights=params.rnn_lm_tie_weights,
) )
return rnn_lm_model return rnn_lm_model
@ -727,20 +757,16 @@ def main():
if params.avg == 1: if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else: else:
start = params.epoch - params.avg + 1 model = load_averaged_model(
filenames = [] params.exp_dir, model, params.epoch, params.avg, device
for i in range(start, params.epoch + 1): )
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device) model.to(device)
model.eval() model.eval()
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
rnn_lm_model = None
if params.method == "rnn-lm": if params.method == "rnn-lm":
rnn_lm_model = get_rnn_lm_model(params) rnn_lm_model = get_rnn_lm_model(params)
if params.rnn_lm_avg == 1: if params.rnn_lm_avg == 1:
@ -748,19 +774,16 @@ def main():
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model, rnn_lm_model,
) )
else:
start = params.rnn_lm_epoch - params.rnn_lm_avg + 1
filenames = []
for i in range(start, params.rnn_lm_epoch + 1):
if start >= 0:
filenames.append(f"{params.rnn_lm_exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
rnn_lm_model.to(device) rnn_lm_model.to(device)
rnn_lm_model.load_state_dict( else:
average_checkpoints(filenames, device=device) rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
) )
else: rnn_lm_model.eval()
rnn_lm_model = None
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)

View File

@ -206,7 +206,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0, "batch_idx_train": 0,
"log_interval": 200, "log_interval": 200,
"reset_interval": 2000, "reset_interval": 2000,
"valid_interval": 30000, "valid_interval": 10000,
"env_info": get_env_info(), "env_info": get_env_info(),
} }
) )
@ -539,6 +539,7 @@ def run(rank, world_size, args):
embedding_dim=params.embedding_dim, embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim, hidden_dim=params.hidden_dim,
num_layers=params.num_layers, num_layers=params.num_layers,
tie_weights=params.tie_weights,
) )
checkpoints = load_checkpoint_if_available(params=params, model=model) checkpoints = load_checkpoint_if_available(params=params, model=model)

View File

@ -22,6 +22,33 @@ import torch
from icefall.utils import add_eos, add_sos, get_texts from icefall.utils import add_eos, add_sos, get_texts
# Default grid of scale values. rescore_with_rnn_lm uses this list as the
# starting candidate set for the n-gram LM, attention and RNN LM scales.
# The values are grouped in rows by magnitude for readability; the order
# is ascending and must be kept as is.
DEFAULT_LM_SCALE = [
    0.01, 0.05, 0.08,
    0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0,
    1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0,
    2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0,
]
def _intersect_device( def _intersect_device(
a_fsas: k2.Fsa, a_fsas: k2.Fsa,
@ -1082,28 +1109,17 @@ def rescore_with_rnn_lm(
rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)
if ngram_lm_scale is None: ngram_lm_scale_list = DEFAULT_LM_SCALE
ngram_lm_scale_list = [0.01, 0.05, 0.08] attention_scale_list = DEFAULT_LM_SCALE
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] rnn_lm_scale_list = DEFAULT_LM_SCALE
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] if ngram_lm_scale:
else:
ngram_lm_scale_list = [ngram_lm_scale] ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None: if attention_scale:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
attention_scale_list = [attention_scale] attention_scale_list = [attention_scale]
if rnn_lm_scale is None: if rnn_lm_scale:
rnn_lm_scale_list = [0.01, 0.05, 0.08]
rnn_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
rnn_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
rnn_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
rnn_lm_scale_list = [rnn_lm_scale] rnn_lm_scale_list = [rnn_lm_scale]
ans = dict() ans = dict()

View File

@ -35,6 +35,8 @@ import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import average_checkpoints
Pathlike = Union[str, Path] Pathlike = Union[str, Path]
@ -806,4 +808,35 @@ def optim_step_and_measure_param_change(
p_orig = old_parameters[n] p_orig = old_parameters[n]
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig) delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
relative_change[n] = delta.item() relative_change[n] = delta.item()
return relative_change return relative_change
def load_averaged_model(
    model_dir: str,
    model: torch.nn.Module,
    epoch: int,
    avg: int,
    device: torch.device,
) -> torch.nn.Module:
    """
    Load into ``model`` the average of several epoch checkpoints.

    The checkpoints ``{model_dir}/epoch-{i}.pt`` for ``i`` in
    ``[max(epoch - avg + 1, 0), epoch]`` are averaged with
    ``average_checkpoints`` and the resulting state dict is loaded
    into ``model``.

    :param model_dir: a str of the experiment directory
    :param model: a torch.nn.Module instance
    :param epoch: the last epoch to load from (must be >= 0)
    :param avg: how many models to average from (must be >= 1)
    :param device: move model to this device
    :return: the same ``model`` instance, holding the averaged weights
    :raises ValueError: if ``avg`` < 1 or ``epoch`` < 0, since no
        checkpoint would be selected
    """
    # Without these checks the filename list below would be empty and the
    # failure would only surface later, inside average_checkpoints.
    if avg < 1:
        raise ValueError(f"avg must be >= 1, got {avg}")
    if epoch < 0:
        raise ValueError(f"epoch must be >= 0, got {epoch}")

    # start cannot be negative; when the requested window reaches below
    # epoch 0, fewer than `avg` checkpoints get averaged.
    start = max(epoch - avg + 1, 0)
    if epoch - avg + 1 < 0:
        logging.warning(
            f"requested avg={avg} but only {epoch - start + 1} "
            f"checkpoints are in range; averaging epochs {start}..{epoch}"
        )

    filenames = [f"{model_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]
    logging.info(f"averaging {filenames}")

    model.to(device)
    model.load_state_dict(average_checkpoints(filenames, device=device))

    return model