Add RNNLM for rescoring.

Fangjun Kuang 2021-11-24 14:52:49 +08:00
parent 774f6643cd
commit 8792dae99e
13 changed files with 2564 additions and 9 deletions

View File

@@ -38,6 +38,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
+ rescore_with_rnn_lm,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
@@ -94,7 +95,9 @@ def get_parser():
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- - (6) nbest-oracle. Its WER is the lower bound of any n-best
+ - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
+ you have trained an RNN LM using ./rnn_lm/train.py
+ - (7) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
@@ -106,7 +109,7 @@ def get_parser():
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
- nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+ nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
""",
)
@@ -117,7 +120,7 @@ def get_parser():
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
- nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+ nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
A smaller value results in more unique paths.
""",
)
@@ -151,14 +154,55 @@ def get_parser():
"--lm-dir",
type=str,
default="data/lm",
- help="""The LM dir.
+ help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)
+ parser.add_argument(
+ "--rnn-lm-exp-dir",
+ type=str,
+ default="rnn_lm/exp",
+ help="""Used only when --method is rnn-lm.
+ It specifies the path to RNN LM exp dir.
+ """,
+ )
+ parser.add_argument(
+ "--rnn-lm-epoch",
+ type=int,
+ default=7,
+ help="""Used only when --method is rnn-lm.
+ It specifies the checkpoint to use.
+ """,
+ )
+ parser.add_argument(
+ "--rnn-lm-avg",
+ type=int,
+ default=2,
+ help="""Used only when --method is rnn-lm.
+ It specifies the number of checkpoints to average.
+ """,
+ )
return parser
+ def get_rnn_lm_model(params: AttributeDict):
+ from rnn_lm.model import RnnLmModel
+ # TODO: Pass the following options from command-line
+ rnn_lm_model = RnnLmModel(
+ vocab_size=params.num_classes,
+ embedding_dim=1024,
+ hidden_dim=1024,
+ num_layers=2,
+ tie_weights=False,
+ )
+ return rnn_lm_model
def get_params() -> AttributeDict:
params = AttributeDict(
{
@@ -185,6 +229,7 @@ def get_params() -> AttributeDict:
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
+ rnn_lm_model: nn.Module,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@@ -217,6 +262,8 @@ def decode_one_batch(
model:
The neural model.
+ rnn_lm_model:
+ The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
@@ -342,6 +389,7 @@ def decode_one_batch(
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
+ "rnn-lm",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@@ -369,8 +417,6 @@ def decode_one_batch(
G_with_epsilon_loops=G,
lm_scale_list=None,
)
- # TODO: pass `lattice` instead of `rescored_lattice` to
- # `rescore_with_attention_decoder`
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
@@ -382,6 +428,26 @@ def decode_one_batch(
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
+ elif params.method == "rnn-lm":
+ # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+ rescored_lattice = rescore_with_whole_lattice(
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=None,
+ )
+ best_path_dict = rescore_with_rnn_lm(
+ lattice=rescored_lattice,
+ num_paths=params.num_paths,
+ rnn_lm_model=rnn_lm_model,
+ model=model,
+ memory=memory,
+ memory_key_padding_mask=memory_key_padding_mask,
+ sos_id=sos_id,
+ eos_id=eos_id,
+ blank_id=0,
+ nbest_scale=params.nbest_scale,
+ )
else:
assert False, f"Unsupported decoding method: {params.method}"
@@ -400,6 +466,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
+ rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@@ -417,6 +484,8 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
+ rnn_lm_model:
+ The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
@@ -456,6 +525,7 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
+ rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
@@ -504,7 +574,7 @@ def save_results(
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
- if params.method == "attention-decoder":
+ if params.method in ("attention-decoder", "rnn-lm"):
# Set it to False since there are too many logs.
enable_log = False
else:
@@ -580,6 +650,10 @@ def main():
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
+ params.num_classes = num_classes
+ params.sos_id = sos_id
+ params.eos_id = eos_id
if params.method == "ctc-decoding":
HLG = None
H = k2.ctc_topo(
@@ -604,6 +678,7 @@ def main():
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
+ "rnn-lm",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
@@ -635,7 +710,11 @@ def main():
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
- if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+ if params.method in [
+ "whole-lattice-rescoring",
+ "attention-decoder",
+ "rnn-lm",
+ ]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
@@ -683,6 +762,27 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
+ if params.method == "rnn-lm":
+ rnn_lm_model = get_rnn_lm_model(params)
+ if params.rnn_lm_avg == 1:
+ load_checkpoint(
+ f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
+ rnn_lm_model,
+ )
+ else:
+ start = params.rnn_lm_epoch - params.rnn_lm_avg + 1
+ filenames = []
+ for i in range(start, params.rnn_lm_epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.rnn_lm_exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ rnn_lm_model.to(device)
+ rnn_lm_model.load_state_dict(
+ average_checkpoints(filenames, device=device)
+ )
+ else:
+ rnn_lm_model = None
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
@@ -696,6 +796,7 @@ def main():
dl=test_dl,
params=params,
model=model,
+ rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
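For reference, the new method can then be invoked along these lines (illustrative values only; the RNN LM options must match the model you trained with ./rnn_lm/train.py, and the script path assumes the conformer_ctc recipe this diff touches):

./conformer_ctc/decode.py \
--method rnn-lm \
--num-paths 100 \
--rnn-lm-exp-dir rnn_lm/exp \
--rnn-lm-epoch 7 \
--rnn-lm-avg 2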

View File

View File

@@ -0,0 +1,228 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./rnn_lm/compute_perplexity.py \
--epoch 4 \
--avg 2 \
--lm-data ./data/bpe_500/sorted_lm_data-test.pt
"""
import argparse
import logging
import math
from pathlib import Path
import torch
from rnn_lm.dataset import get_dataloader
from rnn_lm.model import RnnLmModel
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=49,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="The experiment dir",
)
parser.add_argument(
"--lm-data",
type=str,
help="Path to the LM test data for computing perplexity",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
help="Number of RNN layers the model",
)
parser.add_argument(
"--max-sent-len",
type=int,
default=100,
help="Number of RNN layers the model",
)
parser.add_argument(
"--sos-id",
type=int,
default=1,
help="SOS ID",
)
parser.add_argument(
"--eos-id",
type=int,
default=1,
help="EOS ID",
)
parser.add_argument(
"--blank-id",
type=int,
default=0,
help="Blank ID",
)
return parser
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lm_data = Path(args.lm_data)
params = AttributeDict(vars(args))
print(params)
setup_logger(f"{params.exp_dir}/log-ppl/")
logging.info("Computing perplexity started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
num_param_requires_grad = sum(
[p.numel() for p in model.parameters() if p.requires_grad]
)
logging.info(f"Number of model parameters: {num_param}")
logging.info(
f"Number of model parameters (requires_grad): "
f"{num_param_requires_grad} "
f"({num_param_requires_grad/num_param_requires_grad*100}%)"
)
logging.info(f"Loading LM test data from {params.lm_data}")
test_dl = get_dataloader(
filename=params.lm_data,
is_distributed=False,
params=params,
)
tot_loss = 0.0
num_tokens = 0
num_sentences = 0
for batch_idx, batch in enumerate(test_dl):
x, y, sentence_lengths = batch
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum().cpu().item()
tot_loss += loss
num_tokens += sentence_lengths.sum().cpu().item()
num_sentences += x.size(0)
ppl = math.exp(tot_loss / num_tokens)
logging.info(
f"total nll: {tot_loss}, num tokens: {num_tokens}, "
f"num sentences: {num_sentences}, ppl: {ppl:.3f}"
)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
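The perplexity printed above is the exponentiated average per-token negative log-likelihood. A minimal sketch of the arithmetic, with toy numbers rather than actual outputs:

import math

tot_loss = 3000.0  # sum of per-token NLL over the whole test set
num_tokens = 1000  # tokens before padding, as accumulated above
ppl = math.exp(tot_loss / num_tokens)  # exp(3.0), roughly 20.09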

View File

@@ -0,0 +1,317 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import k2
import torch
from icefall.utils import AttributeDict
class LmDataset(torch.utils.data.Dataset):
def __init__(
self,
sentences: k2.RaggedTensor,
words: k2.RaggedTensor,
sentence_lengths: torch.Tensor,
max_sent_len: int,
batch_size: int,
):
"""
Args:
sentences:
A ragged tensor of dtype torch.int32 with 2 axes [sentence][word].
words:
A ragged tensor of dtype torch.int32 with 2 axes [word][token].
sentence_lengths:
A 1-D tensor of dtype torch.int32 containing number of tokens
of each sentence.
max_sent_len:
Maximum sentence length. It is used to change the batch size
dynamically. In general, we try to keep the product of
"max_sent_len in a batch" and "num_of_sent in a batch" being
a constant.
batch_size:
The expected batch size. It is changed dynamically according
to the "max_sent_len".
See `../local/prepare_lm_training_data.py` for how `sentences` and
`words` are generated. We assume that `sentences` are sorted by length.
See `../local/sort_lm_training_data.py`.
"""
super().__init__()
self.sentences = sentences
self.words = words
sentence_lengths = sentence_lengths.tolist()
assert batch_size > 0, batch_size
assert max_sent_len > 1, max_sent_len
batch_indexes = []
num_sentences = sentences.dim0
cur = 0
while cur < num_sentences:
sz = sentence_lengths[cur] // max_sent_len + 1
# Suppose the current sentence has 3 * max_sent_len tokens. In the
# worst case, the subsequent sentences in this batch have the same
# number of tokens, so we reduce the batch size to keep the number
# of tokens per batch roughly constant.
actual_batch_size = batch_size // sz + 1
actual_batch_size = min(actual_batch_size, batch_size)
end = cur + actual_batch_size
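# Worked example: with batch_size=4 and max_sent_len=3, a leading
# sentence of 7 tokens gives sz = 7 // 3 + 1 = 3, so
# actual_batch_size = min(4 // 3 + 1, 4) = 2, i.e. batches that
# start with long sentences contain fewer sentences.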
end = min(end, num_sentences)
this_batch_indexes = torch.arange(cur, end).tolist()
batch_indexes.append(this_batch_indexes)
cur = end
assert batch_indexes[-1][-1] == num_sentences - 1
self.batch_indexes = k2.RaggedTensor(batch_indexes)
def __len__(self) -> int:
"""Return number of batches in this dataset"""
return self.batch_indexes.dim0
def __getitem__(self, i: int) -> k2.RaggedTensor:
"""Get the i'th batch in this dataset
Return a ragged tensor with 2 axes [sentence][token].
"""
assert 0 <= i < len(self), i
# indexes is a 1-D tensor containing sentence indexes
indexes = self.batch_indexes[i]
# sentence_words is a ragged tensor with 2 axes
# [sentence][word]
sentence_words = self.sentences[indexes]
# in case indexes contains only 1 entry, the returned
# sentence_words is a 1-D tensor, we have to convert
# it to a ragged tensor
if isinstance(sentence_words, torch.Tensor):
sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0))
# sentence_word_tokens is a ragged tensor with 3 axes
# [sentence][word][token]
sentence_word_tokens = self.words.index(sentence_words)
assert sentence_word_tokens.num_axes == 3
sentence_tokens = sentence_word_tokens.remove_axis(1)
return sentence_tokens
def concat(
ragged: k2.RaggedTensor, value: int, direction: str
) -> k2.RaggedTensor:
"""Prepend a value to the beginning of each sublist or append a value.
to the end of each sublist.
Args:
ragged:
A ragged tensor with two axes.
value:
The value to prepend or append.
direction:
It can be either "left" or "right". If it is "left", we
prepend the value to the beginning of each sublist;
if it is "right", we append the value to the end of each
sublist.
Returns:
Return a new ragged tensor, whose sublists either start with
or end with the given value.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> concat(a, value=0, direction="left")
[ [ 0 1 3 ] [ 0 5 ] ]
>>> concat(a, value=0, direction="right")
[ [ 1 3 0 ] [ 5 0 ] ]
"""
dtype = ragged.dtype
device = ragged.device
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
pad_values = torch.full(
size=(ragged.tot_size(0), 1),
fill_value=value,
device=device,
dtype=dtype,
)
pad = k2.RaggedTensor(pad_values)
if direction == "left":
ans = k2.ragged.cat([pad, ragged], axis=1)
elif direction == "right":
ans = k2.ragged.cat([ragged, pad], axis=1)
else:
raise ValueError(
f"Unsupported direction: {direction}. "
'Expect either "left" or "right"'
)
return ans
def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
"""Add SOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
sos_id:
The ID of the SOS symbol.
Returns:
Return a new ragged tensor, where each sublist starts with SOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_sos(a, sos_id=0)
[ [ 0 1 3 ] [ 0 5 ] ]
"""
return concat(ragged, sos_id, direction="left")
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
"""Add EOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
eos_id:
The ID of the EOS symbol.
Returns:
Return a new ragged tensor, where each sublist ends with EOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_eos(a, eos_id=0)
[ [ 1 3 0 ] [ 5 0 ] ]
"""
return concat(ragged, eos_id, direction="right")
class LmDatasetCollate:
def __init__(self, sos_id: int, eos_id: int, blank_id: int):
"""
Args:
sos_id:
Token ID of the SOS symbol.
eos_id:
Token ID of the EOS symbol.
blank_id:
Token ID of the blank symbol.
"""
self.sos_id = sos_id
self.eos_id = eos_id
self.blank_id = blank_id
def __call__(
self, batch: List[k2.RaggedTensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return a tuple containing 3 tensors:
- x, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence starting with `self.sos_id`. It is padded to
the max sentence length with `self.blank_id`.
- y, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence ending with `self.eos_id` before padding.
Then it is padded to the max sentence length with
`self.blank_id`.
- lengths, a 2-D tensor of dtype torch.int32, containing the number of
tokens of each sentence before padding.
"""
# The batching stuff has already been done in LmDataset
assert len(batch) == 1
sentence_tokens = batch[0]
row_splits = sentence_tokens.shape.row_splits(1)
sentence_token_lengths = row_splits[1:] - row_splits[:-1]
sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
x = sentence_tokens_with_sos.pad(
mode="constant", padding_value=self.blank_id
)
y = sentence_tokens_with_eos.pad(
mode="constant", padding_value=self.blank_id
)
sentence_token_lengths += 1 # plus 1 since we added a SOS
return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
def get_dataloader(
filename: str,
is_distributed: bool,
params: AttributeDict,
) -> torch.utils.data.DataLoader:
"""Get dataloader for LM training.
Args:
filename:
Path to the file containing LM data. The file is assumed to
be generated by `../local/sort_lm_training_data.py`.
is_distributed:
True if using DDP training. False otherwise.
params:
Set `get_params()` from `rnn_lm/train.py`
Returns:
Return a dataloader containing the LM data.
"""
lm_data = torch.load(filename)
words = lm_data["words"]
sentences = lm_data["sentences"]
sentence_lengths = lm_data["sentence_lengths"]
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=params.max_sent_len,
batch_size=params.batch_size,
)
if is_distributed:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=True, drop_last=False
)
else:
sampler = None
collate_fn = LmDatasetCollate(
sos_id=params.sos_id,
eos_id=params.eos_id,
blank_id=params.blank_id,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
collate_fn=collate_fn,
sampler=sampler,
shuffle=sampler is None,
num_workers=2,
)
return dataloader
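To make the collate contract concrete, here is a small sketch (illustrative token IDs; sos/eos/blank follow the conventions used in the tests further below):

import k2
from rnn_lm.dataset import LmDatasetCollate

collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
batch = [k2.RaggedTensor([[3, 6, 2], [5]])]  # one pre-batched ragged tensor
x, y, lengths = collate_fn(batch)
# x: [[1, 3, 6, 2], [1, 5, 0, 0]] - each row starts with SOS, padded with blank
# y: [[3, 6, 2, -1], [5, -1, 0, 0]] - each row ends with EOS before padding
# lengths: [4, 2] - token counts including the added SOS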

View File

@@ -0,0 +1,120 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
import torch.nn.functional as F
from icefall.utils import make_pad_mask
class RnnLmModel(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
hidden_dim: int,
num_layers: int,
tie_weights: bool = False,
):
"""
Args:
vocab_size:
Vocabulary size of BPE model.
embedding_dim:
Input embedding dimension.
hidden_dim:
Hidden dimension of RNN layers.
num_layers:
Number of RNN layers.
tie_weights:
True to share the weights between the input embedding layer and the
last output linear layer. See https://arxiv.org/abs/1608.05859
and https://arxiv.org/abs/1611.01462
"""
super().__init__()
self.input_embedding = torch.nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
)
self.rnn = torch.nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
)
self.output_linear = torch.nn.Linear(
in_features=hidden_dim, out_features=vocab_size
)
self.vocab_size = vocab_size
if tie_weights:
logging.info("Tying weights")
assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim)
self.output_linear.weight = self.input_embedding.weight
else:
logging.info("Not tying weights")
def forward(
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
) -> torch.Tensor:
"""
Args:
x:
A 2-D tensor with shape (N, L). Each row
contains token IDs for a sentence and starts with the SOS token.
y:
A shifted version of `x`, with EOS appended.
lengths:
A 1-D tensor of shape (N,). It contains the sentence lengths
before padding.
Returns:
Return a 2-D tensor of shape (N, L) containing negative log-likelihood
loss values. Note: Loss values for padding positions are set to 0.
"""
assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
assert lengths.ndim == 1, lengths.ndim
assert x.shape == y.shape, (x.shape, y.shape)
batch_size = x.size(0)
assert lengths.size(0) == batch_size, (lengths.size(0), batch_size)
# embedding is of shape (N, L, embedding_dim)
embedding = self.input_embedding(x)
# Note: We use batch_first==True
rnn_out, _ = self.rnn(embedding)
logits = self.output_linear(rnn_out)
# Note: No need to use `log_softmax()` here
# since F.cross_entropy() expects unnormalized logits
# nll_loss is of shape (N*L,)
# nll -> negative log-likelihood
nll_loss = F.cross_entropy(
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
)
# Set loss values for padding positions to 0
mask = make_pad_mask(lengths).reshape(-1)
nll_loss.masked_fill_(mask, 0)
nll_loss = nll_loss.reshape(batch_size, -1)
return nll_loss
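A minimal usage sketch (toy token IDs mirroring the conventions in rnn_lm/test_model.py further below; shapes and values are illustrative):

import torch
from rnn_lm.model import RnnLmModel

model = RnnLmModel(vocab_size=4, embedding_dim=10, hidden_dim=10, num_layers=2)
x = torch.tensor([[1, 3, 2, 2], [1, 2, 2, 0]])  # rows start with SOS (=1), padded with blank (=0)
y = torch.tensor([[3, 2, 2, 1], [2, 2, 1, 0]])  # shifted, ends with EOS (=1) before padding
lengths = torch.tensor([4, 3])
nll = model(x, y, lengths)     # shape (2, 4); padding positions are 0
sentence_nll = nll.sum(dim=1)  # total NLL per sentence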

View File

@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
from rnn_lm.dataset import LmDataset, LmDatasetCollate
def main():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
print(dataset.sentences)
print(dataset.words)
print(dataset.batch_indexes)
print(len(dataset))
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, collate_fn=collate_fn
)
for i in dataloader:
print(i)
# I've checked the output manually; the output is as expected.
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import k2
import torch
import torch.multiprocessing as mp
from rnn_lm.dataset import LmDataset, LmDatasetCollate
from torch import distributed as dist
def generate_data():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
return sentences, words, sentence_lengths
def run(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12352"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
sentences, words, sentence_lengths = generate_data()
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=True, drop_last=False
)
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
collate_fn=collate_fn,
sampler=sampler,
shuffle=False,
)
for i in dataloader:
print(f"rank: {rank}", i)
dist.destroy_process_group()
def main():
world_size = 2
mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from rnn_lm.model import RnnLmModel
def test_rnn_lm_model():
vocab_size = 4
model = RnnLmModel(
vocab_size=vocab_size, embedding_dim=10, hidden_dim=10, num_layers=2
)
x = torch.tensor(
[
[1, 3, 2, 2],
[1, 2, 2, 0],
[1, 2, 0, 0],
]
)
y = torch.tensor(
[
[3, 2, 2, 1],
[2, 2, 1, 0],
[2, 1, 0, 0],
]
)
lengths = torch.tensor([4, 3, 2])
nll_loss = model(x, y, lengths)
print(nll_loss)
"""
tensor([[1.1180, 1.3059, 1.2426, 1.7773],
[1.4231, 1.2783, 1.7321, 0.0000],
[1.4231, 1.6752, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
"""
def test_rnn_lm_model_tie_weights():
model = RnnLmModel(
vocab_size=10,
embedding_dim=10,
hidden_dim=10,
num_layers=2,
tie_weights=True,
)
assert model.input_embedding.weight is model.output_linear.weight
def main():
test_rnn_lm_model()
test_rnn_lm_model_tie_weights()
if __name__ == "__main__":
torch.manual_seed(20211122)
main()

View File

@@ -0,0 +1,607 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./rnn_lm/train.py \
--start-epoch 0 \
--num-epochs 20 \
--batch-size 200
If you want to use DDP training, e.g., a single node with 4 GPUs,
use:
python -m torch.distributed.launch \
--use_env \
--nproc_per_node 4 \
./rnn_lm/train.py \
--use-ddp-launch true \
--start-epoch 0 \
--num-epochs 10 \
--batch-size 200
"""
import argparse
import logging
import math
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from lhotse.utils import fix_random_seed
from rnn_lm.dataset import get_dataloader
from rnn_lm.model import RnnLmModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import (
cleanup_dist,
get_local_rank,
get_rank,
get_world_size,
setup_dist,
)
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_env_info,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=10,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
exp_dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp_small",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, logs, etc, are saved
""",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
)
parser.add_argument(
"--use-ddp-launch",
type=str2bool,
default=False,
help="True if using torch.distributed.launch",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters."""
params = AttributeDict(
{
# LM training/validation data
"lm_data": "data/lm_training_bpe_500/sorted_lm_data.pt",
"lm_data_valid": "data/lm_training_bpe_500/sorted_lm_data-valid.pt",
"max_sent_len": 200,
"sos_id": 1,
"eos_id": 1,
"blank_id": 0,
# model related
#
# vocab size of the BPE model
"vocab_size": 500,
"embedding_dim": 1024,
"hidden_dim": 1024,
"num_layers": 2,
#
"lr": 1e-3,
"weight_decay": 1e-6,
#
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 200,
"reset_interval": 2000,
"valid_interval": 30000,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
model: nn.Module,
x: torch.Tensor,
y: torch.Tensor,
sentence_lengths: torch.Tensor,
is_training: bool,
) -> Tuple[torch.Tensor, MetricsTracker]:
"""Compute the negative log-likelihood loss given a model and its input.
Args:
model:
The NN model, e.g., RnnLmModel.
x:
A 2-D tensor. Each row contains BPE token IDs for a sentence. Also,
each row starts with SOS ID.
y:
A 2-D tensor. Each row is a shifted version of the corresponding row
in `x` but ends with an EOS ID (before padding).
sentence_lengths:
A 1-D tensor containing number of tokens of each sentence
before padding.
is_training:
True for training. False for validation.
"""
with torch.set_grad_enabled(is_training):
device = model.device
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum()
num_tokens = sentence_lengths.sum().item()
loss_info = MetricsTracker()
# Note: Due to how MetricsTracker() is designed,
# we use "frames" instead of "num_tokens" as a key here
loss_info["frames"] = num_tokens
loss_info["loss"] = loss.detach().item()
return loss, loss_info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all sentences is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
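# Multiplying by (1 - 1/reset_interval) before adding makes tot_loss a
# decaying sum, i.e. roughly an aggregate over the most recent
# `reset_interval` batches.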
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] "
f"tot_loss[{tot_loss}, ppl: {tot_ppl}], "
f"batch size: {batch_size}"
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/current_ppl", this_batch_ppl, params.batch_idx_train
)
tb_writer.add_scalar(
"train/tot_ppl", tot_ppl, params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
logging.info(
f"Epoch {params.cur_epoch}, validation: {valid_info}, "
f"ppl: {valid_ppl}"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/valid_ppl", valid_ppl, params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
if params.use_ddp_launch:
local_rank = get_local_rank()
else:
local_rank = rank
logging.warning(
f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
)
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port, params.use_ddp_launch)
setup_logger(
f"{params.exp_dir}/log/log-train", rank=rank, world_size=world_size
)
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", local_rank)
logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
logging.info("About to create model")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[local_rank])
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
logging.info("Load optimizer state_dict from checkpoint")
optimizer.load_state_dict(checkpoints["optimizer"])
logging.info(f"Loading LM training data from {params.lm_data}")
train_dl = get_dataloader(
filename=params.lm_data,
is_distributed=world_size > 1,
params=params,
)
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
valid_dl = get_dataloader(
filename=params.lm_data_valid,
is_distributed=world_size > 1,
params=params,
)
# Note: No learning rate scheduler is used here
for epoch in range(params.start_epoch, params.num_epochs):
if world_size > 1:
train_dl.sampler.set_epoch(epoch)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
if args.use_ddp_launch:
# for torch.distributed.launch
rank = get_rank()
world_size = get_world_size()
print(f"rank: {rank}, world_size: {world_size}")
# The following is a hack, as the default log level
# is warning
logging.info = logging.warning
run(rank=rank, world_size=world_size, args=args)
return
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,607 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./rnn_lm/train.py \
--start-epoch 0 \
--num-epochs 20 \
--batch-size 200
If you want to use DDP training, e.g., a single node with 4 GPUs,
use:
python -m torch.distributed.launch \
--use_env \
--nproc_per_node 4 \
./rnn_lm/train.py \
--use-ddp-launch true \
--start-epoch 0 \
--num-epochs 10 \
--batch-size 200
"""
import argparse
import logging
import math
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from lhotse.utils import fix_random_seed
from rnn_lm.dataset import get_dataloader
from rnn_lm.model import RnnLmModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import (
cleanup_dist,
get_local_rank,
get_rank,
get_world_size,
setup_dist,
)
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_env_info,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=10,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
exp_dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, logs, etc, are saved
""",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
)
parser.add_argument(
"--use-ddp-launch",
type=str2bool,
default=False,
help="True if using torch.distributed.launch",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters."""
params = AttributeDict(
{
# LM training/validation data
"lm_data": "data/lm_training_bpe_500/sorted_lm_data.pt",
"lm_data_valid": "data/lm_training_bpe_500/sorted_lm_data-valid.pt",
"max_sent_len": 200,
"sos_id": 1,
"eos_id": 1,
"blank_id": 0,
# model related
#
# vocab size of the BPE model
"vocab_size": 500,
"embedding_dim": 2048,
"hidden_dim": 2048,
"num_layers": 4,
#
"lr": 1e-3,
"weight_decay": 1e-6,
#
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 200,
"reset_interval": 2000,
"valid_interval": 30000,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
model: nn.Module,
x: torch.Tensor,
y: torch.Tensor,
sentence_lengths: torch.Tensor,
is_training: bool,
) -> Tuple[torch.Tensor, MetricsTracker]:
"""Compute the negative log-likelihood loss given a model and its input.
Args:
model:
The NN model, e.g., RnnLmModel.
x:
A 2-D tensor. Each row contains BPE token IDs for a sentence. Also,
each row starts with SOS ID.
y:
A 2-D tensor. Each row is a shifted version of the corresponding row
in `x` but ends with an EOS ID (before padding).
sentence_lengths:
A 1-D tensor containing number of tokens of each sentence
before padding.
is_training:
True for training. False for validation.
"""
with torch.set_grad_enabled(is_training):
device = model.device
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum()
num_tokens = sentence_lengths.sum().item()
loss_info = MetricsTracker()
# Note: Due to how MetricsTracker() is designed,
# we use "frames" instead of "num_tokens" as a key here
loss_info["frames"] = num_tokens
loss_info["loss"] = loss.detach().item()
return loss, loss_info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all sentences is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] "
f"tot_loss[{tot_loss}, ppl: {tot_ppl}], "
f"batch size: {batch_size}"
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/current_ppl", this_batch_ppl, params.batch_idx_train
)
tb_writer.add_scalar(
"train/tot_ppl", tot_ppl, params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
logging.info(
f"Epoch {params.cur_epoch}, validation: {valid_info}, "
f"ppl: {valid_ppl}"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/valid_ppl", valid_ppl, params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
if params.use_ddp_launch:
local_rank = get_local_rank()
else:
local_rank = rank
logging.warning(
f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
)
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port, params.use_ddp_launch)
setup_logger(
f"{params.exp_dir}/log/log-train", rank=rank, world_size=world_size
)
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", local_rank)
logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
logging.info("About to create model")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[local_rank])
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
logging.info("Load optimizer state_dict from checkpoint")
optimizer.load_state_dict(checkpoints["optimizer"])
logging.info(f"Loading LM training data from {params.lm_data}")
train_dl = get_dataloader(
filename=params.lm_data,
is_distributed=world_size > 1,
params=params,
)
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
valid_dl = get_dataloader(
filename=params.lm_data_valid,
is_distributed=world_size > 1,
params=params,
)
# Note: No learning rate scheduler is used here
for epoch in range(params.start_epoch, params.num_epochs):
if world_size > 1:
train_dl.sampler.set_epoch(epoch)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
if args.use_ddp_launch:
        # for torch.distributed.launch
rank = get_rank()
world_size = get_world_size()
print(f"rank: {rank}, world_size: {world_size}")
        # The following is a hack, as the default log level
        # is WARNING
logging.info = logging.warning
run(rank=rank, world_size=world_size, args=args)
return
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
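# A hypothetical invocation (the flag names below are assumptions based on
# the parameters this script reads; get_parser() is not shown in this hunk):
#
#   python rnn_lm/train.py --world-size 1 --exp-dir rnn_lm/exp
#
# With --world-size N > 1, main() spawns N processes via mp.spawn(), one
# per GPU; with --use-ddp-launch, the rank and world size are taken from
# torch.distributed.launch instead.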

View File

@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Union
import k2
import torch
from icefall.utils import add_eos, add_sos, get_texts
def _intersect_device(
@ -903,3 +903,172 @@ def rescore_with_attention_decoder(
            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
            ans[key] = best_path
    return ans
def rescore_with_rnn_lm(
lattice: k2.Fsa,
num_paths: int,
rnn_lm_model: torch.nn.Module,
model: torch.nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
blank_id: int,
nbest_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
rnn_lm_scale: Optional[float] = None,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""This function extracts `num_paths` paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest score is
the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
      num_paths:
        Number of paths to extract from the given lattice for rescoring.
      rnn_lm_model:
        An RNN LM. See the class "RnnLmModel" in rnn_lm/model.py for its
        interface.
      model:
        A transformer model. See the class "Transformer" in
        conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `(T, N, C)`.
memory_key_padding_mask:
The padding mask for memory with shape `(N, T)`.
sos_id:
The token ID for SOS.
      eos_id:
        The token ID for EOS.
      blank_id:
        The token ID of the blank symbol. It is used as the padding value
        when batching token sequences for the RNN LM.
nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
rnn_lm_scale:
Optional. It specifies the scale for RNN LM scores.
Returns:
      A dict of FsaVec, whose key is a string of the form
      `ngram_lm_scale_xxx_attention_scale_xxx_rnn_lm_scale_xxx` and whose
      value is the best decoding path for each utterance in the lattice.
"""
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set.
# Also, nbest.fsa inherits the attributes from `lattice`.
assert hasattr(nbest.fsa, "lm_scores")
am_scores = nbest.compute_am_scores()
ngram_lm_scores = nbest.compute_lm_scores()
# The `tokens` attribute is set inside `compile_hlg.py`
assert hasattr(nbest.fsa, "tokens")
assert isinstance(nbest.fsa.tokens, torch.Tensor)
path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
# the shape of memory is (T, N, C), so we use axis=1 here
expanded_memory = memory.index_select(1, path_to_utt_map)
if memory_key_padding_mask is not None:
# The shape of memory_key_padding_mask is (N, T), so we
# use axis=0 here.
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_utt_map
)
else:
expanded_memory_key_padding_mask = None
# remove axis corresponding to states.
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
tokens = tokens.remove_values_leq(0)
token_ids = tokens.tolist()
if len(token_ids) == 0:
print("Warning: rescore_with_attention_decoder(): empty token-ids")
return None
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
# Now for RNN LM
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64)
y_tokens = y_tokens.to(torch.int64)
sentence_lengths = sentence_lengths.to(torch.int64)
rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths)
assert rnn_lm_nll.ndim == 2
assert rnn_lm_nll.shape[0] == len(token_ids)
rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
attention_scale_list = [attention_scale]
if rnn_lm_scale is None:
rnn_lm_scale_list = [0.01, 0.05, 0.08]
rnn_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
rnn_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
rnn_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
rnn_lm_scale_list = [rnn_lm_scale]
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
for r_scale in rnn_lm_scale_list:
tot_scores = (
am_scores.values
+ n_scale * ngram_lm_scores.values
+ a_scale * attention_scores
+ r_scale * rnn_lm_scores
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_rnn_lm_scale_{r_scale}" # noqa
ans[key] = best_path
return ans
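# A minimal usage sketch (not part of this commit): `lattice`, `memory`,
# `memory_key_padding_mask` and the token IDs below are assumed to come
# from the decoding script, as with rescore_with_attention_decoder.
#
#   best_path_dict = rescore_with_rnn_lm(
#       lattice=lattice,
#       num_paths=100,
#       rnn_lm_model=rnn_lm_model,
#       model=model,
#       memory=memory,
#       memory_key_padding_mask=memory_key_padding_mask,
#       sos_id=sos_id,
#       eos_id=eos_id,
#       blank_id=0,
#       nbest_scale=0.5,
#   )
#   for key, best_path in best_path_dict.items():
#       hyps = get_texts(best_path)  # list of token-ID lists per utterance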

View File

@ -637,3 +637,128 @@ class MetricsTracker(collections.defaultdict):
""" """
for k, v in self.norm_items(): for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx) tb_writer.add_scalar(prefix + k, v, batch_idx)
def concat(
ragged: k2.RaggedTensor, value: int, direction: str
) -> k2.RaggedTensor:
"""Prepend a value to the beginning of each sublist or append a value.
to the end of each sublist.
Args:
ragged:
A ragged tensor with two axes.
value:
The value to prepend or append.
direction:
It can be either "left" or "right". If it is "left", we
prepend the value to the beginning of each sublist;
if it is "right", we append the value to the end of each
sublist.
Returns:
Return a new ragged tensor, whose sublists either start with
or end with the given value.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> concat(a, value=0, direction="left")
[ [ 0 1 3 ] [ 0 5 ] ]
>>> concat(a, value=0, direction="right")
[ [ 1 3 0 ] [ 5 0 ] ]
"""
dtype = ragged.dtype
device = ragged.device
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
pad_values = torch.full(
size=(ragged.tot_size(0), 1),
fill_value=value,
device=device,
dtype=dtype,
)
pad = k2.RaggedTensor(pad_values)
if direction == "left":
ans = k2.ragged.cat([pad, ragged], axis=1)
elif direction == "right":
ans = k2.ragged.cat([ragged, pad], axis=1)
else:
raise ValueError(
            f"Unsupported direction: {direction}. "
            'Expect either "left" or "right"'
)
return ans
def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
"""Add SOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
sos_id:
The ID of the SOS symbol.
Returns:
Return a new ragged tensor, where each sublist starts with SOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_sos(a, sos_id=0)
[ [ 0 1 3 ] [ 0 5 ] ]
"""
return concat(ragged, sos_id, direction="left")
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
"""Add EOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
eos_id:
The ID of the EOS symbol.
Returns:
Return a new ragged tensor, where each sublist ends with EOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_eos(a, eos_id=0)
[ [ 1 3 0 ] [ 5 0 ] ]
"""
return concat(ragged, eos_id, direction="right")
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
>>> lengths = torch.tensor([1, 3, 2, 5])
>>> make_pad_mask(lengths)
tensor([[False, True, True, True, True],
[False, False, False, True, True],
[False, False, True, True, True],
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = lengths.max()
n = lengths.size(0)
    expanded_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
    return expanded_lengths >= lengths.unsqueeze(1)
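# A small illustration (illustrative values only) of combining
# make_pad_mask with masked_fill to ignore padded positions when
# summing per-sentence scores:
#
#   lengths = torch.tensor([2, 3])
#   nll = torch.ones(2, 3)
#   nll = nll.masked_fill(make_pad_mask(lengths), 0.0)
#   nll.sum(dim=1)  # tensor([2., 3.])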

View File

@ -22,9 +22,12 @@ import torch
from icefall.utils import (
    AttributeDict,
    add_eos,
    add_sos,
    encode_supervisions,
    get_env_info,
    get_texts,
    make_pad_mask,
)
@ -130,3 +133,35 @@ def test_attribute_dict():
def test_get_env_info():
    s = get_env_info()
    print(s)
def test_make_pad_mask():
lengths = torch.tensor([1, 3, 2])
mask = make_pad_mask(lengths)
expected = torch.tensor(
[
[False, True, True],
[False, False, False],
[False, False, True],
]
)
assert torch.all(torch.eq(mask, expected))
assert (~expected).sum() == lengths.sum()
def test_add_sos():
sos_id = 100
ragged = k2.RaggedTensor([[1, 2], [3], [0]])
sos_ragged = add_sos(ragged, sos_id)
expected = k2.RaggedTensor([[sos_id, 1, 2], [sos_id, 3], [sos_id, 0]])
assert str(sos_ragged) == str(expected)
def test_add_eos():
eos_id = 30
ragged = k2.RaggedTensor([[1, 2], [3], [], [5, 8, 9]])
ragged_eos = add_eos(ragged, eos_id)
expected = k2.RaggedTensor(
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
)
assert str(ragged_eos) == str(expected)
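def test_concat():
    # A sketch of an additional test (not in this commit) covering both
    # directions of `concat`; it assumes `concat` is imported from
    # icefall.utils alongside add_sos and add_eos.
    ragged = k2.RaggedTensor([[1, 2], [3]])
    left = concat(ragged, value=0, direction="left")
    right = concat(ragged, value=0, direction="right")
    assert str(left) == str(k2.RaggedTensor([[0, 1, 2], [0, 3]]))
    assert str(right) == str(k2.RaggedTensor([[1, 2, 0], [3, 0]]))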