mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
RNN decoding refactoring
This commit is contained in:
parent
00b1c291a6
commit
71a9c33bca
@ -30,7 +30,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
@ -46,8 +46,10 @@ from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
load_averaged_model,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
@ -174,19 +176,47 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-embedding-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Embedding dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-hidden-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Hidden dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-num-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rnn-lm-tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_rnn_lm_model(params: AttributeDict):
|
||||
from rnn_lm.model import RnnLmModel
|
||||
|
||||
# TODO: Pass the following options from command-line
|
||||
rnn_lm_model = RnnLmModel(
|
||||
vocab_size=params.num_classes,
|
||||
embedding_dim=1024,
|
||||
hidden_dim=1024,
|
||||
num_layers=2,
|
||||
tie_weights=False,
|
||||
embedding_dim=params.rnn_lm_embedding_dim,
|
||||
hidden_dim=params.rnn_lm_hidden_dim,
|
||||
num_layers=params.rnn_lm_num_layers,
|
||||
tie_weights=params.rnn_lm_tie_weights,
|
||||
)
|
||||
return rnn_lm_model
|
||||
|
||||
@ -727,20 +757,16 @@ def main():
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
model = load_averaged_model(
|
||||
params.exp_dir, model, params.epoch, params.avg, device
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
rnn_lm_model = None
|
||||
if params.method == "rnn-lm":
|
||||
rnn_lm_model = get_rnn_lm_model(params)
|
||||
if params.rnn_lm_avg == 1:
|
||||
@ -748,19 +774,16 @@ def main():
|
||||
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||
rnn_lm_model,
|
||||
)
|
||||
else:
|
||||
start = params.rnn_lm_epoch - params.rnn_lm_avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.rnn_lm_epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.rnn_lm_exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
rnn_lm_model.to(device)
|
||||
rnn_lm_model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device)
|
||||
else:
|
||||
rnn_lm_model = load_averaged_model(
|
||||
params.rnn_lm_exp_dir,
|
||||
rnn_lm_model,
|
||||
params.rnn_lm_epoch,
|
||||
params.rnn_lm_avg,
|
||||
device,
|
||||
)
|
||||
else:
|
||||
rnn_lm_model = None
|
||||
rnn_lm_model.eval()
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
|
@ -206,7 +206,7 @@ def get_params() -> AttributeDict:
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 200,
|
||||
"reset_interval": 2000,
|
||||
"valid_interval": 30000,
|
||||
"valid_interval": 10000,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -539,6 +539,7 @@ def run(rank, world_size, args):
|
||||
embedding_dim=params.embedding_dim,
|
||||
hidden_dim=params.hidden_dim,
|
||||
num_layers=params.num_layers,
|
||||
tie_weights=params.tie_weights,
|
||||
)
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
@ -22,6 +22,33 @@ import torch
|
||||
|
||||
from icefall.utils import add_eos, add_sos, get_texts
|
||||
|
||||
DEFAULT_LM_SCALE = [
|
||||
0.01,
|
||||
0.05,
|
||||
0.08,
|
||||
0.1,
|
||||
0.3,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.9,
|
||||
1.0,
|
||||
1.1,
|
||||
1.2,
|
||||
1.3,
|
||||
1.5,
|
||||
1.7,
|
||||
1.9,
|
||||
2.0,
|
||||
2.1,
|
||||
2.2,
|
||||
2.3,
|
||||
2.5,
|
||||
3.0,
|
||||
4.0,
|
||||
5.0,
|
||||
]
|
||||
|
||||
|
||||
def _intersect_device(
|
||||
a_fsas: k2.Fsa,
|
||||
@ -1082,28 +1109,17 @@ def rescore_with_rnn_lm(
|
||||
|
||||
rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)
|
||||
|
||||
if ngram_lm_scale is None:
|
||||
ngram_lm_scale_list = [0.01, 0.05, 0.08]
|
||||
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
|
||||
else:
|
||||
ngram_lm_scale_list = DEFAULT_LM_SCALE
|
||||
attention_scale_list = DEFAULT_LM_SCALE
|
||||
rnn_lm_scale_list = DEFAULT_LM_SCALE
|
||||
|
||||
if ngram_lm_scale:
|
||||
ngram_lm_scale_list = [ngram_lm_scale]
|
||||
|
||||
if attention_scale is None:
|
||||
attention_scale_list = [0.01, 0.05, 0.08]
|
||||
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
|
||||
else:
|
||||
if attention_scale:
|
||||
attention_scale_list = [attention_scale]
|
||||
|
||||
if rnn_lm_scale is None:
|
||||
rnn_lm_scale_list = [0.01, 0.05, 0.08]
|
||||
rnn_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
rnn_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
rnn_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
|
||||
else:
|
||||
if rnn_lm_scale:
|
||||
rnn_lm_scale_list = [rnn_lm_scale]
|
||||
|
||||
ans = dict()
|
||||
|
@ -35,6 +35,8 @@ import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall.checkpoint import average_checkpoints
|
||||
|
||||
Pathlike = Union[str, Path]
|
||||
|
||||
|
||||
@ -806,4 +808,35 @@ def optim_step_and_measure_param_change(
|
||||
p_orig = old_parameters[n]
|
||||
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
|
||||
relative_change[n] = delta.item()
|
||||
return relative_change
|
||||
return relative_change
|
||||
|
||||
|
||||
def load_averaged_model(
|
||||
model_dir: str,
|
||||
model: torch.nn.Module,
|
||||
epoch: int,
|
||||
avg: int,
|
||||
device: torch.device,
|
||||
):
|
||||
"""
|
||||
Load a model which is the average of all checkpoints
|
||||
|
||||
:param model_dir: a str of the experiment directory
|
||||
:param model: a torch.nn.Module instance
|
||||
|
||||
:param epoch: the last epoch to load from
|
||||
:param avg: how many models to average from
|
||||
:param device: move model to this device
|
||||
|
||||
:return: A model averaged
|
||||
"""
|
||||
|
||||
# start cannot be negative
|
||||
start = max(epoch - avg + 1, 0)
|
||||
filenames = [f"{model_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]
|
||||
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
return model
|
||||
|
Loading…
x
Reference in New Issue
Block a user