Merge cdd539e55c63d21df91bfad9dd9cccdc306842a9 into 04029871b6a54e35d08116917f88eb7d6ead2d02

This commit is contained in:
Fangjun Kuang 2021-11-09 15:09:11 +08:00 committed by GitHub
commit 6faaec22fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 5243 additions and 24 deletions

View File

@ -21,8 +21,12 @@ import warnings
from typing import Optional, Tuple
import torch
from conformer_ctc.transformer import (
Supervisions,
Transformer,
encoder_padding_mask,
)
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):

View File

@ -26,8 +26,9 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from conformer_ctc.asr_datamodule import LibriSpeechAsrDataModule
from conformer_ctc.conformer import Conformer
from conformer_lm.conformer import MaskedLmConformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -37,6 +38,7 @@ from icefall.decode import (
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_conformer_lm,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
@ -94,7 +96,10 @@ def get_parser():
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best
- (6) conformer-lm. In addition to attention-decoder rescoring, it
also uses a conformer LM for rescoring. See the model in the
directory ./conformer_lm
- (7) nbest-oracle. Its WER is a lower bound on what any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring methods.
""",
@ -106,7 +111,8 @@ def get_parser():
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
nbest, nbest-rescoring, attention-decoder, conformer-lm,
and nbest-oracle
""",
)
@ -117,8 +123,8 @@ def get_parser():
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
A smaller value results in more unique paths.
nbest, nbest-rescoring, attention-decoder, conformer-lm,
and nbest-oracle. A smaller value results in more unique paths.
""",
)
@ -147,6 +153,35 @@ def get_parser():
help="The lang dir",
)
parser.add_argument(
"--conformer-lm-exp-dir",
type=str,
default="conformer_lm/exp",
help="""The conformer lm exp dir.
Used only when method is conformer-lm.
""",
)
parser.add_argument(
"--conformer-lm-epoch",
type=int,
default=19,
help="""Used only when method is conformer_lm.
It specifies the checkpoint to use for the conformer
lm model.
""",
)
parser.add_argument(
"--conformer-lm-avg",
type=int,
default=1,
help="""Used only when method is conformer_lm.
It specifies number of checkpoints to average for
the conformer lm model.
""",
)
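# For example (illustrative): --conformer-lm-epoch 19 --conformer-lm-avg 5 averages
# the checkpoints epoch-15.pt ... epoch-19.pt found in --conformer-lm-exp-dir
# (see the averaging logic in main() below).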
return parser
@ -177,6 +212,7 @@ def get_params() -> AttributeDict:
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
masked_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@ -334,6 +370,7 @@ def decode_one_batch(
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"conformer-lm",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@ -354,7 +391,7 @@ def decode_one_batch(
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
elif params.method in ("attention-decoder", "conformer-lm"):
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
@ -364,16 +401,32 @@ def decode_one_batch(
# TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder`
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
if params.method == "attention-decoder":
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
else:
# It uses:
# attention_decoder + conformer_lm
best_path_dict = rescore_with_conformer_lm(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
masked_lm_model=masked_lm_model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
blank_id=0, # TODO(fangjun): pass it as an argument
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
@ -393,6 +446,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
masked_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@ -449,6 +503,7 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
masked_lm_model=masked_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
@ -584,6 +639,7 @@ def main():
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"conformer-lm",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
@ -607,7 +663,11 @@ def main():
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
if params.method in [
"whole-lattice-rescoring",
"attention-decoder",
"conformer-lm",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
@ -655,6 +715,38 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
if params.method == "conformer-lm":
logging.info("Loading conformer lm model")
# Note: If the parameters do not match
# those used to save the checkpoint,
# `load_state_dict` will raise an error.
masked_lm_model = MaskedLmConformer(
num_classes=num_classes,
d_model=params.attention_dim,
nhead=params.nhead,
num_decoder_layers=params.num_decoder_layers,
)
if params.conformer_lm_avg == 1:
load_checkpoint(
f"{params.conformer_lm_exp_dir}/epoch-{params.conformer_lm_epoch}.pt", # noqa
masked_lm_model,
)
else:
start = params.conformer_lm_epoch - params.conformer_lm_avg + 1
filenames = []
for i in range(start, params.conformer_lm_epoch + 1):
if i >= 0:
filenames.append(
f"{params.conformer_lm_exp_dir}/epoch-{i}.pt"
)
logging.info(f"averaging {filenames}")
masked_lm_model.to(device)
masked_lm_model.load_state_dict(
average_checkpoints(filenames, device=device)
)
else:
masked_lm_model = None
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
@ -668,6 +760,7 @@ def main():
dl=test_dl,
params=params,
model=model,
masked_lm_model=masked_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,

View File

@ -24,7 +24,7 @@ import logging
from pathlib import Path
import torch
from conformer import Conformer
from conformer_ctc.conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon

View File

@ -27,7 +27,7 @@ import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from conformer import Conformer
from conformer_ctc.conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (

View File

@ -17,8 +17,7 @@
import torch
from torch.nn.utils.rnn import pad_sequence
from transformer import (
from conformer_ctc.transformer import (
Transformer,
add_eos,
add_sos,
@ -26,6 +25,7 @@ from transformer import (
encoder_padding_mask,
generate_square_subsequent_mask,
)
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask():

View File

@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling
from conformer_ctc.subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.

View File

@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py

File diff suppressed because it is too large

View File

@ -0,0 +1,823 @@
import torch
import torch.distributed as dist
import k2
import _k2
import logging
import sentencepiece as spm
from pathlib import Path
from typing import Optional, List, Tuple, Union
class LmDataset(torch.utils.data.Dataset):
"""
Torch dataset for language modeling data. This is a map-style dataset.
The indices are integers.
"""
def __init__(self, sentences: k2.RaggedTensor,
words: k2.RaggedTensor):
super(LmDataset, self).__init__()
self.sentences = sentences
self.words = words
def __len__(self):
# Total size on axis 0, == num sentences
return self.sentences.tot_size(0)
def __getitem__(self, i: int):
"""
Return the i'th sentence, as a list of ints (representing BPE pieces, without
bos or eos symbols).
"""
# in future will just do:
#return self.words[self.sentences[i]].tolist()
# It would be nicer if we could just return self.sentences[i].tolist(),
# but for now that operator on k2.RaggedInt does not support ragged
# tensors with only 2 axes.
row_splits = self.sentences.shape.row_splits(1)
(begin, end) = row_splits[i:i+2].tolist()
sentence = self.sentences.data[begin:end]
sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
return sentence.data.tolist()
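# Illustrative sketch: constructing a tiny LmDataset by hand, assuming a k2 version
# whose RaggedTensor supports string construction and the .shape/.data/.index usage above;
# the symbol ids are arbitrary.
#   words = k2.RaggedTensor('[ [5 7] [9] [11 12 13] ]')  # word-id -> BPE pieces
#   sentences = k2.RaggedTensor('[ [0 2] [1] ]')         # sentences as word-ids
#   ds = LmDataset(sentences, words)
#   assert len(ds) == 2
#   ds[0]  # -> [5, 7, 11, 12, 13], i.e. the pieces of word 0 followed by word 2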
def load_train_test_lm_dataset(archive_fn: Union[str,Path],
test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]:
"""
returns (train_lm_dataset, test_lm_dataset)
"""
d = torch.load(archive_fn)
words = d['words'] # a k2.RaggedTensor with 2 axes, maps from word-ids to sequences of BPE pieces
sentences = d['data'] # a k2.RaggedTensor
with torch.random.fork_rng(devices=[]):
g = torch.manual_seed(0)
num_sentences = sentences.tot_size(0)
# probably the generator (g) argument to torch.randperm below is not necessary.
sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32)
sentences, _ = sentences.index(sentence_perm, axis=0, need_value_indexes=False)
num_test_sentences = int(num_sentences * test_proportion)
axis=0
train_sents = sentences.arange(axis, num_test_sentences, num_sentences)
test_sents = sentences.arange(axis, 0, num_test_sentences)
return LmDataset(train_sents, words), LmDataset(test_sents, words)
def mask_and_pad(sentence: List[int],
seq_len: int,
bos_sym: int,
eos_sym: int,
blank_sym: int,
mask_proportion: float,
padding_proportion: float,
inv_mask_length: float,
unmasked_weight: float) -> Tuple[List[int], List[int], List[int], List[float]]:
"""
This function contains part of the logic of collate_fn, broken out. It is responsible
for inserting masking and padding into the sequence `sentence`. Most of the arguments
are documented for `collate_fn` below.
Other args:
sentence: The original sentence to be masked and padded.
seq_len: The desired length of the lists to be returned
bos_sym, eos_sym, blank_sym, mask_proportion,
padding_proportion, inv_mask_length, unmasked_weight: see their documentation
as args to `collate_fn` below.
Return: a tuple (src, src_masked, tgt, weight, randomizable, attn_mask), all lists of length `seq_len`,
where:
`src` is: [bos] + [the sentence after inserting blanks in place of padding
after regions to be masked] + [eos] + [blank padding to seq_len].
`src_masked` is as `src` but the masked regions have their values replaced with blank,
i.e. they are actually masked.
`tgt` is: [the original sentence, without masking] + [eos] + [blank] + [blank padding to seq_len]
`weight` is the weight at the nnet output, which is: `unmasked_weight` for un-masked
positions, 1.0 for masked and padded positions, and 0.0 for positions that
correspond to blank-padding after the final [eos].
`randomizable` is a bool that is True for positions where the symbol
in `src_masked` is not bos or eos or blank.
`attn_mask` is a bool that is False for positions in `src` and `src_masked` that
are between the initial [bos] and final [eos] inclusive; and True for
positions after the final [eos].
"""
sent_len = len(sentence)
assert sent_len + 3 <= seq_len
for w in sentence:
assert w not in [bos_sym, eos_sym, blank_sym]
num_mask = int(torch.binomial(count=torch.tensor([sent_len * 1.0]),
prob=torch.tensor([mask_proportion])).item())
num_pad = int(torch.poisson(torch.tensor([sent_len * padding_proportion])).item())
# Ensure the total length after bos, padding of masked sequences, and eos, is
# no greater than seq_len
num_pad -= max(0, sent_len + 2 + num_pad - seq_len)
if num_mask + num_pad == 0:
num_pad += 1
# num_split_points is the number of times we split the (masked+padded)
# region, so the total number of (masking+padding) subsequences will be
# num_split_points + 1. If num_mask positions are masked, then the
# remaining number of words is `sent_len - num_mask`, and any two
# masked regions must have at least one non-masked word between them,
# so num_split_points == number of masked regions - 1, must be
# no greater than `sent_len - num_mask`. The formula about
# mask_proportion * inv_mask_length / (1.0 - mask_proportion)
# is what's required (I think) so that 1.0 / inv_mask_length is the expected
# length of the masked regions.
num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
# Somehow this assertion failed, debugging it below.
assert num_split_points <= sent_len - num_mask
assert isinstance(num_split_points, int)
def split_into_subseqs(length: int , num_subseqs: int) -> List[int]:
"""Splits a sequence of `length` items into `num_subseqs` possibly-empty
subsequences. The length distributions are geometric, not Poisson, i.e.
we choose the split locations with uniform probability rather than
randomly assigning each word to one subsequence. This gives us a wider
spread of subsequence lengths (more very short and very long ones).
Require num_subseqs > 0
"""
boundaries = [0] + sorted(torch.randint(low=0, high=length + 1, size=(num_subseqs - 1,)).tolist()) + [length]
return [ boundaries[i + 1] - boundaries[i] for i in range(num_subseqs) ]
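# e.g. split_into_subseqs(5, 3) might return [2, 0, 3] (illustrative; the output is
# random): `num_subseqs` possibly-empty lengths that always sum to `length`.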
mask_lengths = split_into_subseqs(num_mask, num_split_points + 1)
pad_lengths = split_into_subseqs(num_pad, num_split_points + 1)
# mask_pad_lengths contains only the (mask, pad) length pairs for which mask + pad > 0.
# From this point we only refer to the mask_pad_lengths.
mask_pad_lengths = [ (mask, pad) for (mask, pad) in zip(mask_lengths, pad_lengths) if mask+pad > 0 ]
num_subseqs = len(mask_pad_lengths)
assert num_subseqs > 0
# Now figure out how to distribute these subsequences throughout the actual
# sentence. The subsequences, if there are more than one, must not touch,
# i.e. there must be an actual word in between each subsequence, where the
# number of such "mandatory" words equals num_subseqs - 1. We also have to
# subtract `num_mask` words, since obviously the masked words cannot separate
# the masked regions.
reduced_len = sent_len - num_mask - (num_subseqs - 1)
assert reduced_len >= 0
# unmasked_lengths will be the lengths of the un-masked regions between the masked
# regions.
unmasked_lengths = split_into_subseqs(reduced_len, num_subseqs + 1)
for i in range(1, num_subseqs):
# Unmasked regions between masked regions must have length at least 1,
# we add 1 to unmasked regions that are not initial/final.
unmasked_lengths[i] = unmasked_lengths[i] + 1
assert sum(unmasked_lengths) + sum(mask_lengths) == sent_len
# src_positions will be: for each position in the masked+padded sentence,
# the corresponding position in the source sentence `sentence`; or -1
# if this was padding.
src_positions = []
# `masked` will be: for each position in the masked+padded sentence, True if
# it was masked and False otherwise. (Note: it is False for padding
# locations, although this will not matter in the end).
masked = []
cur_pos = 0 # current position in source sentence
for i in range(num_subseqs + 1):
for j in range(unmasked_lengths[i]):
src_positions.append(cur_pos)
masked.append(False)
cur_pos += 1
if i < num_subseqs:
(mask_len, pad_len) = mask_pad_lengths[i]
for j in range(mask_len):
src_positions.append(cur_pos)
masked.append(True)
cur_pos += 1
for j in range(pad_len):
src_positions.append(-1)
masked.append(False)
assert cur_pos == len(sentence)
src = []
src_masked = []
tgt = []
weight = []
randomizable = []
src.append(bos_sym)
src_masked.append(bos_sym)
randomizable.append(False)
for i, src_pos in enumerate(src_positions):
is_masked = masked[i]
if src_pos >= 0:
src_word = sentence[src_pos]
src_masked.append(blank_sym if masked[i] else src_word)
src.append(src_word)
tgt.append(src_word)
weight.append(1.0 if masked[i] else unmasked_weight)
randomizable.append(not masked[i])
else:
# Padding inside a masked region
src_masked.append(blank_sym)
src.append(blank_sym)
tgt.append(blank_sym)
weight.append(1.0)
randomizable.append(False)
src.append(eos_sym)
src_masked.append(eos_sym)
tgt.append(eos_sym)
weight.append(unmasked_weight)
tgt.append(blank_sym)
weight.append(0.0)
randomizable.append(False)
attn_mask = ([False] * len(src)) + ([True] * (seq_len - len(src)))
for i in range(seq_len - len(src)):
src.append(blank_sym)
src_masked.append(blank_sym)
tgt.append(blank_sym)
weight.append(0.0)
randomizable.append(False)
return (src, src_masked, tgt, weight, randomizable, attn_mask)
# dataset.mask_and_pad(list(range(10, 20)), seq_len=16, bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, inv_mask_length=0.33, unmasked_weight=0.444)
# dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)
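# One possible outcome of the mask_and_pad example above (illustrative; the output is
# random, only the invariants hold): all six returned lists have length seq_len=16;
# `src` starts with bos (1) and contains eos (2) right after the last real word or
# padding blank; masked positions of `src_masked` are blank (0); `weight` is 1.0 at
# masked/padded positions, 0.444 at un-masked positions (and at the eos), and 0.0 on
# the trailing blank padding, which is also where `attn_mask` is True.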
def collate_fn(sentences: List[List[int]],
bos_sym: int,
eos_sym: int,
blank_sym: int,
mask_proportion: float = 0.15,
padding_proportion: float = 0.15,
randomize_proportion: float = 0.05,
inv_mask_length: float = 0.25,
unmasked_weight: float = 0.25,
debug: bool = False) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
torch.Tensor]:
"""
Caution, this is not the collate_fn we give directly to the dataloader,
we give it a lambda: collate_fn=(lambda x: dataset.collate_fn(x, [other args]))
This formats a list-of-lists-of-int into 5 Tensors, explained below.
The key thing is that we mask out subsequences of random length within
these sentences, and force the network to predict the masked-out
subsequences (which have blanks appended to them to prevent the model
from knowing the exact length of the sequences it has to predict).
So it's like BERT but at the level of sequences rather than individual
words.
Args:
bos_sym: the integer id of the beginning-of-sentence symbol, e.g. 2.
Is allowed to be the same as eos_sym (we are not necessarily
saying it will work best that way).
eos_sym: the integer id of the end-of-sentence symbol, e.g. 2.
blank_sym: the integer id of the blank symbol, e.g. 0 or 1.
mask_proportion: The proportion of words in each sentence that
are masked, interpreted as (roughly) the probability of any given
word being masked, although the masked locations will
tend to be in contiguous sequences (they are not independent).
padding_proportion: Like mask_proportion, but determines the
number of extra, blank symbols that are inserted as padding
at the end of masked regions (this ensures that the model
cannot know exactly how many words need to be inserted in
any given masked region).
randomize_proportion: The probability with which we replace
words that were not masked with randomly chosen words.
Like BERT, this is intended to force the model to predict
something reasonable at non-masked positions, and to make
this task harder than simply repeating the input.
inv_mask_length: This number determines how many separate
sub-sequences the (masked + padded) proportion of a sentence is split up
into, interpreted as the inverse of the expected length of
each *masked* region.
unmasked_weight: The weight to be applied to the log-likelihoods of
un-masked positions in sentences (predicting un-masked
positions is not completely trivial if randomize_proportion > 0).
Will be reflected in the returned tgt_weights tensor.
Returns a tuple (masked_src_symbols, src_symbols,
tgt_symbols, src_key_padding_mask,
tgt_weights),
all with 2 axes and the same shape: (num_sent, seq_len).
Their dtypes will be, respectively,
(torch.int64, torch.int64,
torch.int64, torch.bool,
torch.float)
masked_src_symbols: The sentences, with bos_symbol prepended and eos_symbol
appended, masked regions (including padding) replaced with blank,
and `randomize_proportion` non-masked symbols replaced with
symbols randomly taken from elsewhere in the sentences of this
minibatch. Then padded to a fixed length with blank.
src_symbols: Like masked_src_symbols, except with the masked symbols replaced
with the original symbols (but the padding that follows each
masked sub-sequence will still be blank)
tgt_symbols: The original sentences, with eos_symbol appended, and then
padded with blank to the same length as masked_symbols and
src_symbols.
src_key_padding_mask: Masking tensor for masked_src_symbols and src_symbols, to
account for all the sentence lengths not being identical
(makes each sentence's processing independent of seq_len).
Tensor of Bool of shape (num_sent, seq_len), with True
for masked positions (these are the blanks that follow the
eos_symbol in masked_src_symbols), False for un-masked positions.
tgt_weights: Weights that will be applied to the log-probabilities at
the output of the network. Will have 1.0 in positions
in `tgt_symbols` that were masked (including blank
padding at the end of masked regions), `unmasked_weight`
in other positions in the original sentences (including
terminating eos_symbol); and 0.0 in the remaining positions
corresponding to blank padding after the ends of
sentences.
"""
assert blank_sym not in [bos_sym, eos_sym]
max_sent_len = max([ len(s) for s in sentences])
#logging.info(f"Sentence lengths: {[ len(s) for s in sentences]}")
typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))
# The following formula gives roughly 1 standard deviation above where we'd
expect the maximum sentence length to be with masking and padding. We use
# this as a hard upper limit, to prevent outliers from affecting the batch
# size too much. We use this as the size `seq_len`.
# The "+ 4" is to ensure there is always room for the BOS, EOS and at least
# two padding symbols.
seq_len = max_sent_len + 4 + typical_mask_and_pad + int(typical_mask_and_pad ** 0.5)
# srcs, srcs_masked, tgts and weights will be lists of the lists returned
# from `mask_and_pad`, one per sentence.
srcs = []
srcs_masked = []
tgts = []
weights = []
randomizables = []
attn_masks = []
for s in sentences:
(src, src_masked, tgt,
weight, randomizable,
attn_mask) = mask_and_pad(s, seq_len, bos_sym, eos_sym,
blank_sym, mask_proportion, padding_proportion,
inv_mask_length, unmasked_weight)
srcs.append(src)
srcs_masked.append(src_masked)
tgts.append(tgt)
weights.append(weight)
randomizables.append(randomizable)
attn_masks.append(attn_mask)
src_symbols = torch.tensor(srcs, dtype=torch.int64)
masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64)
tgt_symbols = torch.tensor(tgts, dtype=torch.int64)
src_key_padding_mask = torch.tensor(attn_masks, dtype=torch.bool)
tgt_weights = torch.tensor(weights, dtype=torch.float)
attn_mask_sum = torch.sum(torch.logical_not(src_key_padding_mask), dim=0).tolist()
while attn_mask_sum[-1] == 0: # Remove always-masked positions at the end of the lists.
attn_mask_sum.pop()
if len(attn_mask_sum) < seq_len:
seq_len = len(attn_mask_sum)
(src_symbols, masked_src_symbols,
tgt_symbols, src_key_padding_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
tgt_symbols[:,:seq_len], src_key_padding_mask[:,:seq_len],
tgt_weights[:,:seq_len])
if randomize_proportion > 0.0:
randomizable_tensor = torch.tensor(randomizables, dtype=torch.bool)
randomizable_indexes = torch.nonzero(randomizable_tensor) # (num_randomizable, 2)
num_randomizable = randomizable_indexes.shape[0]
to_randomize_indexes = torch.nonzero(torch.rand(num_randomizable) < randomize_proportion, as_tuple=True)[0]
num_to_randomize = to_randomize_indexes.numel()
# older versions of torch don't have tensor_split, so fake a simplified version of it.
# we'd be calling it as xxx.tensor_split(dim=1) if it were really in torch.
def tensor_split(t):
return (t[:,0], t[:,1])
random_src_locations = torch.randperm(num_randomizable)[:num_to_randomize]
random_symbols = src_symbols[tensor_split(randomizable_indexes[random_src_locations])]
random_indexes_tuple= tensor_split(randomizable_indexes[to_randomize_indexes])
src_symbols[random_indexes_tuple] = random_symbols
masked_src_symbols[random_indexes_tuple] = random_symbols
# I set this to true and tested with:
# python3 -c 'import dataset; dataset.collate_fn(sentences=[ list(range(100, 200)), list(range(300, 450)), list(range(500,600))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)'
#.. and ran a few times to check the values printed looked about right, and that no assertions failed.
if debug:
check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym,
unmasked_weight,
masked_src_symbols, src_symbols,
tgt_symbols, src_key_padding_mask, tgt_weights)
return (masked_src_symbols, src_symbols,
tgt_symbols, src_key_padding_mask, tgt_weights)
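# Minimal usage sketch (illustrative; the symbol ids are arbitrary and seq_len depends
# on the random masking):
#   (masked_src, src, tgt, src_key_padding_mask, weights) = collate_fn(
#       sentences=[[10, 11, 12], [20, 21, 22, 23]],
#       bos_sym=1, eos_sym=2, blank_sym=0)
#   # All five tensors share the shape (2, seq_len); masked regions of masked_src are 0,
#   # and src_key_padding_mask is True only on the blank padding after each sentence's eos.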
def check_collated_tensors(sentences: List[List[int]],
bos_sym: int,
eos_sym: int,
blank_sym: int,
unmasked_weight: float,
masked_src_symbols, src_symbols,
tgt_symbols, src_key_padding_mask,
tgt_weights):
"""
This function checks the output of collate_fn, consider it test code. Please see
the documentation of collate_fn to understand the args.
"""
for t in src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights:
assert t.shape == masked_src_symbols.shape
tot_positions = src_symbols.numel()
masked_src_symbols, src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights = (
masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(),
src_key_padding_mask.tolist(), tgt_weights.tolist())
assert len(sentences) == len(masked_src_symbols)
tot_masked_positions = 0
tot_padded_positions = 0
tot_unmasked_positions = 0 # all un-masked, non-blank positions, including eos
tot_randomized_positions = 0
num_masked_subseqs = 0
tot_symbols = 0 # original symbols in sentences, no bos/eos
assert unmasked_weight > 0.001 # or this test code won't work..
for i in range(len(sentences)):
reconstructed_sent = list(filter(lambda x: x not in [bos_sym,eos_sym,blank_sym], tgt_symbols[i]))
if sentences[i] != reconstructed_sent:
print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}")
(masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i],
tgt_symbols[i], src_key_padding_mask[i], tgt_weights[i])
assert src[0] == masked_src[0] == bos_sym
for j in range(len(masked_src)):
assert masked_src[j] == blank_sym or masked_src[j] == src[j]
if src[j] not in [bos_sym, eos_sym, blank_sym]:
tot_symbols += 1
if j > 0:
assert (src[j] == eos_sym) == (masked_src[j] == eos_sym) == (tgt[j-1] == eos_sym)
if masked_src[j] == blank_sym: # masked or padding of masked subseq, or post-eos padding..
assert src[j] == tgt[j - 1] # masked symbols are not randomized.
assert weights[j - 1] in [0.0, 1.0] # 0.0 for final blank padding
if weights[j - 1] == 1.0: # Not final blank padding...
if tgt[j - 1] == blank_sym:
tot_padded_positions += 1
else:
tot_masked_positions += 1
if masked_src[j + 1] != blank_sym:
num_masked_subseqs += 1
else:
assert weights[j - 1] == 0 or abs(weights[j-1] - unmasked_weight) < 0.001
if abs(weights[j - 1]-unmasked_weight) < 0.001:
tot_unmasked_positions += 1
if tgt[j - 1] != src[j]:
tot_randomized_positions += 1
if src_mask[j]: # if masked..
assert src[j] == blank_sym
assert tot_symbols == sum(len(x) for x in sentences)
assert tot_unmasked_positions + tot_masked_positions == tot_symbols + len(sentences)
print(f"{tot_unmasked_positions} + {tot_masked_positions} == {tot_symbols} + {len(sentences)}")
print(f"tot_symbols / tot_positions = {tot_symbols/tot_positions} (rest is bos,eos,padding)")
print(f"Masking/tot_symbols = {tot_masked_positions/tot_symbols}, Padding/tot_symbols = {tot_padded_positions/tot_symbols}")
print(f"Randomization/tot_non_masked_symbols = {tot_randomized_positions/(tot_symbols-tot_masked_positions)}")
print(f"Mean masking length = {tot_masked_positions/num_masked_subseqs}, Mean padding length = {tot_padded_positions/num_masked_subseqs}")
# This shows some useful code about the BPE encoding.
# import sentencepiece as spm
# sp = spm.SentencePieceProcessor()
# sp.load(bpe_model_fn) # bpe.model
# sp.GetPieceSize(..)
# sp.Decode(...)
# sp.Encode(...)
# import dataset
# import torch
# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
# train_dl = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True, collate_fn=(lambda x: train.collate_fn(x)))
# x = iter(train_dl)
# str(next(x))
# '[ [ 10 38 651 593 3 1343 31 780 6 4172 112 788 1696 24 289 24 3 403 6 4493 162 92 71 328 417 217 338 14 5 3 1876 154 21 23 2237 43 3 1535 92 71 2816 7 1031 31 2318 92 2528 4806 14 206 3 954 1373 6 525 4 631 447 2639 ] [ 1014 336 171 209 795 10 16 90 27 787 139 53 45 2817 ] [ 11 980 51 22 1748 14 91 105 363 428 6 8 2887 3305 2525 2297 70 3 4651 6 27 282 335 426 134 292 5 193 3 539 2250 584 127 ] [ 9 3 1858 4 18 2257 4 6 41 748 10 304 7 229 83 2793 4 9 981 7 1484 33 3 103 7 539 5 477 3195 18 64 39 82 1034 6 3 4128 ] [ 17 147 22 7 708 60 133 174 105 4111 4 6 3 1384 65 50 1051 9 2953 6 3 461 180 1142 23 5 36 888 8 131 173 390 78 23 266 2822 715 46 182 65 22 1739 33 3 700 1450 14 233 4 ] [ 80 10 16 67 279 7 1827 264 96 3 187 2851 2108 ] [ 1473 48 106 227 9 160 2011 4 674 ] [ 3 954 762 29 85 228 33 8 940 40 4952 36 486 390 595 3 81 225 6 1440 125 346 134 296 126 419 1017 3824 4 8 179 184 11 33 580 1861 ] [ 30 22 245 15 117 8 2892 28 1204 145 7 3 236 3417 6 3 3839 5 3106 155 198 30 228 2555 46 15 32 41 747 72 9 25 977 ] [ 222 466 6 3157 ] ]'
#
# or:
# import k2
# k2.ragged.to_list(next(x))
# [shows something similar].
#
# You'd really do something like:
# for epoch in range(max_epochs):
# for minibatch in train_dl:
# .. How to process data? Suppose we have a sentence like [259, 278, 45, 11, 303, 1319, 34, 15, 396, 3435, 7, 44].
#
# First: we randomly choose one or more starting positions for a masked segment.
# Each sentence must have at least one masked segment (or there is no contribution to the loss function).
# We choose to have:
# num_masked_segments = max(1, len(sent) // 15)
#
# The length of the masked segment (this is the target for prediction) is drawn from the geometric
# distribution with the probability of success set to 0.3:
#
# g = torch.distributions.geometric.Geometric(probs=0.3) # <-- expected value is 3.333
# Example of sampling:
# g.sample(sample_shape=torch.Size([10]))
#
# We now randomly compute the locations of the masked segments (lengths computed above) as follows:
# First, the masked segments must be separated by at least one non-masked word (else they would be
# a single segment). So for n masked segments, there are n-1 words required for minimal separation.
# If tot-length-of-segments + n-1 is greater than the sentence length, we just have the entire
# sentence be masked. Otherwise, we randomly divide the remaining number of words between the n+1
# positions where they can appear (e.g. for 2 segments, this would be at the start, between the 2 segments,
# and at the end). This is the multinomial distribution, but we can more easily compute this
# directly using rand() and cutoffs, rather than creating a torch.distributions.Multinomial().
#
# Next we need to compute a random amount of blank padding (>= 0) for each of the masked regions;
# this is done so the model never knows the exact length of the masked region. We can just use the
# same distribution as for the length of the masked regions, i.e. geometric with success-prob=0.3
# (expected padding length is 3).
#
# At this point we know where the masked regions are and how much padding they have. We can format
# the result as three lists, of the same length:
#
# sent: contains the words in the sentence with, in masked
# positions, the original (target) words, then with
# blank in the blank-padding after masked positions.
#
# sent_augmented: `sent` with, at a small defined percentage of positions
# that were *not* masked, the real token replaced with a
# token randomly chosen from the tokens in the minibatch.
# (like BERT, we use this type of augmentation, so the model
# has to predict the original token).
#
# masked_sent_augmented: List[int], contains the words in `sent_augmented`, except
# with masked positions and the blank padding after the masked regions
# both replaced with blank.
#
#
#
# The way these will be processed is as follows:
#
# masked_sent_in = [bos] + masked_sent_augmented + [eos] <-- so we know the sentence ended, distinguish it from truncated ones.
# sent_in = [bos] + sent_augmented + [eos]
#
# sent_out = sent + [eos] + [eos] #<--- the predicted targets at each point, although
# # we only really care about this in masked regions.
# # The extra eos is so that the length is the same as
# # masked_sent_in and sent_in.
#
# out_scale = (masked_sent==blk ? 1.0 : non_masked_scale) # e.g. non_masked_scale = 1.0 is fine,
# # this is a choice; we can perhaps
# # report these 2 parts of the loss
# # separately though.
# # <-- can also set the last element
# # of out_scale to a smaller number, since
# # it's a repeated eos.
#
#
# OK, how do we combine these into a minibatch? Firstly, we truncate sentences to a maximum
# length, e.g. 128, if `masked_sent_in`/`sent_in` have length longer than that. We choose randomly
# in each case to truncate the beginning or end, truncating both masked_sent_in/sent_in and sent_out
# from the same side. Caution: this means that these sentences may lack bos and/or eos symbols.
#
# Next, we combine shorter utterances by appending them ( all of: masked_sent_in, sent_in, out_scale)
# as long as doing so would keep the total length under 128. We then pad (masked_sent_in, sent_in, sent_out, out_scale)
# with: (<blk>,<blk>,<eos>, 0) up to the maximum length of any sentence in the minibatch <- or could use
#
# # i.e. ones where masked_sent is blank and zeros elsewhere;
# # this pertains to positions in `sent_out`.
#
# torch.distributions.gamma.Gamma(concentration=1.0, rate=1.0/5)
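# Sketch of the sampling primitives mentioned in the notes above (illustrative only;
# the actual implementation in mask_and_pad() uses binomial/Poisson draws instead):
#   g = torch.distributions.geometric.Geometric(probs=0.3)
#   sent_len = 12
#   num_masked_segments = max(1, sent_len // 15)
#   masked_lens = g.sample(torch.Size([num_masked_segments]))  # masked-segment lengths
#   pad_lens = g.sample(torch.Size([num_masked_segments]))     # blank padding per segment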
class LmBatchSampler(torch.utils.data.Sampler):
"""
A sampler that returns a batch of integer indexes as a list, intended for use
with class LmDataset. The sentences returned in each batch will all be about
the same size, and the batch size is specified as a number of words (we also
provide an option that allows you to limit the max memory consumed by transformers)
Has support for distributed operation.
"""
def __init__(self, dataset: LmDataset,
symbols_per_batch: int,
length_ceil: float = 200.0,
length_floor: float = 4.0,
world_size: Optional[int] = None,
rank: int = None,
seed: int = 0,
delay_init: bool = False):
"""
Constructor documentation:
dataset: the LmDataset object that we are sampling from. This
class does not retain a reference to the LmDataset.
symbols_per_batch: The number of BPE symbols desired in each minibatch
length_floor: When the sentence length gets less than about this much,
the batch size stops increasing inversely with sentence
length. This prevents OOM on batches with short sentences.
length_ceil: After the sentence length gets more than about
this much, the batch size will start decreasing
as 1/(sentence-length^2). This is a mechanism to
avoid excessive memory consumption in transformers, when
sentence length gets long.
world_size: The world size for distributed operation; if None,
will be worked out from torch.distributed.
rank: The rank of this sampler/process for distributed operation; if None,
will be worked out from torch.distributed.
seed: The random seed
delay_init: If true, will omit calling self.set_epoch(0) at the
end of the __init__ function. In this case the caller
must call set_epoch(0). [Setting this option is necessary
to work with data-loader worker processes plus DDP, since
set_epoch() will use ddp, which I believe is a no-no prior
to initializing data-loaders.]
"""
self.seed = seed
self.symbols_per_batch = symbols_per_batch
self.length_floor = length_floor
self.quadratic_constant = 1.0 / length_ceil
self._maybe_init_distributed(world_size=world_size, rank=rank)
# a configuration constant we don't expose.
self.multiplicative_random_length = 0.05
# "indexes" is the subset of indexes into LmDataset that this
# sampler is responsible for (all of them, in the non-distributed case).
data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32
word_row_splits = dataset.words.shape.row_splits(1) # dtype=torch.int32
word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32
# the sentences this sampler is responsible for, as sequences of words.
# It's a ragged tensor of int32
sentences, _ = dataset.sentences.index(data_indexes, axis=0)
# sentence_lengths is a k2.RaggedTensor like `sentences`, but with the words replaced
# with their respective lengths, in BPE pieces.
sentence_lengths = k2.ragged.index(word_lengths, sentences)
del sentences # save memory
assert isinstance(sentence_lengths, k2.RaggedTensor)
# convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually
# support int32.)
sentence_lengths = sentence_lengths.to(dtype=torch.float32)
# Convert into a simple tensor of float by adding lengths of words.
sentence_lengths = sentence_lengths.sum()
assert isinstance(sentence_lengths, torch.Tensor)
assert sentence_lengths.dtype == torch.float32
# self.sentence_lengths is a Tensor with dtype=torch.float32. It
# contains the lengths, in BPE tokens, of the sentences that this
# sampler is responsible for, whose real indexes are in
# `data_indexes` above (this is not stored, as we know the formula).
self.sentence_lengths = sentence_lengths
if not delay_init:
self.set_epoch(0) # this is responsible for setting self.sorted_data_indexes
def _sync_sizes(self, device: torch.device = torch.device('cuda')):
# Calling this on all copies of a DDP setup will sync the sizes so that
# all copies have the exact same number of batches. I think
# this needs to be called with the GPU device, not sure if it would
# work otherwise.
if self.world_size > 1:
min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64)
dist.all_reduce(min_size, op=dist.ReduceOp.MIN)
min_size = min_size.to('cpu').item()
logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}")
self.batch_indices = self.batch_indices[0:min_size]
def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
if world_size is not None:
assert world_size >= 1
if rank is not None:
assert rank >= 0
if not dist.is_available() or not dist.is_initialized():
self.world_size = 1 if world_size is None else world_size
self.rank = 0 if rank is None else rank
return
self.world_size = dist.get_world_size() if world_size is None else world_size
self.rank = dist.get_rank() if rank is None else rank
assert self.rank < self.world_size
def set_epoch(self, epoch: int):
"""
Must be called at the beginning of each epoch, before initializing the DataLoader,
to re-shuffle the data. If this is not done, this sampler will give you the same batches
each time it is called.
"""
g = torch.manual_seed(self.rank + self.seed + epoch)
sentence_lengths = (self.sentence_lengths *
(1.0 + torch.rand(*self.sentence_lengths.shape, generator=g) * self.multiplicative_random_length))
# This mechanism regulates the batch size so that we don't get OOM in transformers
# when the sentences are long.
sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor
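# So a batch holds roughly symbols_per_batch / (len + len**2 / length_ceil + length_floor)
# sentences of (BPE) length `len`: inversely proportional to length for typical sentences,
# and shrinking quadratically for very long ones.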
values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64
# map to the original indexes into the dataset (the original sentence
# indexes), see torch.arange expression in the constructor. save as
# int32 just to save a little memory. self.indices are indexes into the
# LmDataset, just including the subset of indices that this sampler is
# responsible for (in terms of rank and world_size), and sorted by
# length with a small amount of randomization specific to the epoch.
self.indices = ((indices * self.world_size) + self.rank).to(dtype=torch.int32)
# now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
# saying which batch each element of values/indices belongs to.
batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
batch_boundaries.add_(1)
self.batch_boundaries = torch.cat((torch.zeros(1, dtype=torch.int32), batch_boundaries), dim=0)
num_batches = self.batch_boundaries.numel() - 1
# self.batch_indices is a permutation of [0, 1, ... num_batches -
# 1]; it determines the order in which we access the batches. It's
# necessary to randomize the order of these, to avoid returning batches
# from shortest to longest sentences.
self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
self._sync_sizes()
def __len__(self):
return len(self.batch_indices)
def __iter__(self):
"""
Iterator that yields lists of indices (i.e., integer indices into the LmDataset)
"""
for batch_idx in self.batch_indices:
batch_start = self.batch_boundaries[batch_idx].item()
batch_end = self.batch_boundaries[batch_idx + 1].item()
yield self.indices[batch_start:batch_end].tolist()
class CollateFn:
def __init__(self, **kwargs):
self.extra_args = kwargs
def __call__(self, sentences: List[List[int]]):
return collate_fn(sentences, **self.extra_args)
# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
# sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
# a = iter(sampler)
# print(str(next(a)))
# collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True))
# train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
# x = iter(train_dl)
# print(str(next(x)))

File diff suppressed because it is too large

View File

@ -0,0 +1,156 @@
#!/usr/bin/env python3
# run with:
# python3 -m pytest test_conformer.py
import torch
import dataset # from .
from conformer import (
RelPosTransformerDecoder,
RelPosTransformerDecoderLayer,
MaskedLmConformer,
MaskedLmConformerEncoder,
MaskedLmConformerEncoderLayer,
RelPositionMultiheadAttention,
RelPositionalEncoding,
generate_square_subsequent_mask,
)
from torch.nn.utils.rnn import pad_sequence
def test_rel_position_multihead_attention():
# Also tests RelPositionalEncoding
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
x = torch.randn(N, T, C)
#pos_emb = torch.randn(1, 2*T-1, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)
def test_masked_lm_conformer_encoder_layer():
# Also tests RelPositionalEncoding
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)
def test_masked_lm_conformer_encoder():
# Also tests RelPositionalEncoding
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
norm = torch.nn.LayerNorm(embed_dim)
encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
norm=norm)
x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)
def test_transformer_decoder_layer_rel_pos():
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
attn_mask = generate_square_subsequent_mask(T)
memory = torch.randn(T, N, C)
y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
def test_transformer_decoder_rel_pos():
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
decoder_norm = torch.nn.LayerNorm(embed_dim)
decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6, norm=decoder_norm)
x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
attn_mask = generate_square_subsequent_mask(T)
memory = torch.randn(T, N, C)
y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
def test_masked_lm_conformer():
num_classes = 87
d_model = 256
model = MaskedLmConformer(num_classes,d_model)
N = 31
(masked_src_symbols, src_symbols,
tgt_symbols, src_key_padding_mask,
tgt_weights) = dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45)), list(range(50,68))], bos_sym=1, eos_sym=2,
blank_sym=0)
# test forward() of MaskedLmConformer
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
src_key_padding_mask)
print("nll = ", nll)
loss = (nll * tgt_weights).sum()
print("loss = ", loss)
def test_generate_square_subsequent_mask():
s = 5
mask = generate_square_subsequent_mask(s, torch.device('cpu'))
inf = float("inf")
expected_mask = torch.tensor(
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0],
]
)
assert torch.all(torch.eq(mask, expected_mask))

View File

@ -0,0 +1,32 @@
import k2
import torch
import _k2
import dataset
import os
from torch import multiprocessing as mp
import torch.distributed as dist
def local_collate_fn(sentences):
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)
if __name__ == '__main__':
#mp.set_start_method('spawn')
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12344"
dist.init_process_group(backend="nccl", group_name="main",
rank=0, world_size=1)
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
print("len(sampler) = ", len(sampler))
a = iter(sampler)
print(str(next(a)))
train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler,
collate_fn=local_collate_fn,
num_workers=2)
x = iter(train_dl)
print(str(next(x)))

View File

@ -0,0 +1,39 @@
import k2
import torch
import _k2
import dataset
from dataset import LmDataset
import os
from torch import multiprocessing as mp
import torch.distributed as dist
def local_collate_fn(sentences):
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
x = _k2.RaggedInt('[[1]]') # make sure library initialized?
if __name__ == '__main__':
mp.set_start_method('spawn')
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12344"
dist.init_process_group(backend="nccl", group_name="main",
rank=0, world_size=1)
words = k2.RaggedInt('[[0][1 2]]')
sentences = k2.RaggedInt('[[1][][][][][]]')
train = LmDataset(sentences, words)
sampler = dataset.LmBatchSampler(train, symbols_per_batch=10, world_size=1, rank=0)
a = iter(sampler)
print(str(next(a)))
train_dl = torch.utils.data.DataLoader(train, batch_sampler=sampler,
collate_fn=local_collate_fn,
num_workers=0)
x = iter(train_dl)
print(str(next(x)))

View File

@ -0,0 +1,623 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Daniel Povey)
#
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import dataset # from .
import madam # from .
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from conformer import MaskedLmConformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from madam import Gloam
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.utils import (
AttributeDict,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- num_valid_batches: Number of batches of validation data to use each
time we compute validation loss
- symbols_per_batch: Number of symbols in each batch (sampler will
choose the number of sentences to satisfy this constraint).
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used for writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if `batch_idx % log_interval` is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
"""
params = AttributeDict(
{
# exp_3, vs. exp_2, is using 5e-04 not 2e-04 as max learning rate.
# exp_4, vs. exp_3, is using the Gloam optimizer with
# in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
# as well as the exponential part.
# exp_6, we change the decay from 0.85 to 0.9.
"exp_dir": Path("conformer_lm/exp_6"),
"lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
"num_tokens": 5000,
"blank_sym": 0,
"bos_sym": 1,
"eos_sym": 1,
"start_epoch": 2,
"num_epochs": 20,
"num_valid_batches": 200,
"symbols_per_batch": 5000,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"reset_interval": 200,
"valid_interval": 3000,
"beam_size": 10,
"accum_grad": 1,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
"max_lrate": 5.0e-04
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
model: nn.Module,
batch: Tuple,
is_training: bool,
):
"""
Compute training or validation loss given the model and its inputs.
(This corresponds to the log-prob of the targets, with a weighting
of 1.0 for masked subsequences, including padding blanks, and
something smaller, e.g. 0.25, for non-masked positions; the latter
is not totally trivial due to a small amount of randomization of
symbols.)
This loss is not normalized; you can divide by batch[4].sum()
to get a normalized loss (i.e. divide by soft-count).
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of MaskedLmConformer in our case.
batch:
A batch of data, actually a tuple of 5 tensors (on the device), as returned
by collate_fn in ./dataset.py.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
Returns:
Returns the loss as a scalar tensor.
"""
(masked_src_symbols, src_symbols,
tgt_symbols, src_key_padding_mask, tgt_weights) = batch
with torch.set_grad_enabled(is_training):
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
decoder_nll_func = model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll
tgt_nll = decoder_nll_func(memory, pos_emb, src_symbols,
tgt_symbols, src_key_padding_mask)
loss = (tgt_nll * tgt_weights).sum()
assert loss.requires_grad == is_training
return loss
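
# Below is a minimal sketch (not part of the training loop) of how a
# normalized loss could be obtained from compute_loss(); `model` and `batch`
# are assumed to be set up as in train_one_epoch() below:
#
#   loss = compute_loss(model=model, batch=batch, is_training=False)
#   normalized_loss = loss / batch[4].sum()  # divide by the soft-count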
def compute_validation_loss(
device: torch.device,
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl):
if batch_idx == params.num_valid_batches:
break
batch = tuple(x.to(device) for x in batch)
loss = compute_loss(model, batch, is_training=False)
num_frames = batch[4].sum()
assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
num_frames_cpu = num_frames.cpu().item()
tot_loss += loss_cpu
tot_frames += num_frames_cpu
if world_size > 1:
s = torch.tensor(
[tot_loss, tot_frames],
device=loss.device,
)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
(tot_loss, tot_frames) = s.cpu().tolist()
params.valid_loss = tot_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
def train_one_epoch(
device: torch.device,
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
    The average training loss over all frames is saved in
    `params.train_loss`. The validation process runs every
    `params.valid_interval` batches.
Args:
device:
The device to use for training (model must be on this device)
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train() # training mode
tot_loss = 0.0 # sum of losses over all batches
tot_frames = 0.0 # sum of frames over all batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch = tuple(x.to(device) for x in batch)
try:
loss = compute_loss(
model=model,
batch=batch,
is_training=True,
)
optimizer.zero_grad()
loss.backward()
# We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total
# gradient scale so this should not matter.
# clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
print(f"Error on batch of shape (N,T) = {batch[0].shape}")
raise e
loss_cpu = loss.detach().cpu().item()
num_frames_cpu = batch[4].sum().cpu().item()
tot_loss += loss_cpu
tot_frames += num_frames_cpu
params.tot_frames += num_frames_cpu
params.tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch shape: {tuple(batch[0].shape)}")
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / num_frames_cpu,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0.0 # sum of losses over all batches
tot_frames = 0.0 # sum of frames over all batches
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
device=device,
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, "
f"valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
)
params.train_loss = params.tot_loss / params.tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
num_tokens = params.num_tokens
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info("About to create model")
model = MaskedLmConformer(
num_classes=params.num_tokens,
d_model=params.attention_dim,
nhead=params.nhead,
num_decoder_layers=params.num_decoder_layers,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
# Caution: don't forget to do optimizer.set_epoch() with Gloam!
# Don't remove this warning!
optimizer = Gloam(
model.parameters(),
max_lrate=params.max_lrate,
first_decrease_epoch=1,
decay_per_epoch=0.9
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
    train, test = dataset.load_train_test_lm_dataset(params.lm_dataset)
    collate_fn = dataset.CollateFn(bos_sym=params.bos_sym,
eos_sym=params.eos_sym,
blank_sym=params.blank_sym,
mask_proportion=0.15,
padding_proportion=0.15,
randomize_proportion=0.05,
inv_mask_length=0.25,
unmasked_weight=0.25)
train_sampler = dataset.LmBatchSampler(train,
symbols_per_batch=params.symbols_per_batch,
world_size=world_size, rank=rank)
test_sampler = dataset.LmBatchSampler(test,
symbols_per_batch=params.symbols_per_batch,
world_size=world_size, rank=rank)
train_dl = torch.utils.data.DataLoader(train,
batch_sampler=train_sampler,
collate_fn=collate_fn)
valid_dl = torch.utils.data.DataLoader(test,
batch_sampler=test_sampler,
collate_fn=collate_fn)
for epoch in range(params.start_epoch, params.num_epochs):
train_sampler.set_epoch(epoch)
optimizer.set_epoch(epoch) # Caution: this is specific to the Gloam
# optimizer.
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
device=device,
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@@ -20,6 +20,11 @@ from typing import Dict, List, Optional, Union
import k2
import torch
from icefall.lm.rescore import (
compute_alignment,
make_hyp_to_ref_map,
prepare_conformer_lm_inputs,
)
from icefall.utils import get_texts
@@ -904,3 +909,198 @@ def rescore_with_attention_decoder(
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_path
return ans
def rescore_with_conformer_lm(
lattice: k2.Fsa,
num_paths: int,
model: torch.nn.Module,
masked_lm_model: torch.nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
blank_id: int,
nbest_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
masked_lm_scale: Optional[float] = None,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""This function extracts `num_paths` paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest score is
the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
      model:
        A transformer model. See the class "Transformer" in
        conformer_ctc/transformer.py for its interface.
      masked_lm_model:
        A MaskedLmConformer model. See the class "MaskedLmConformer"
        in conformer_lm/conformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `(T, N, C)`.
memory_key_padding_mask:
The padding mask for memory with shape `(N, T)`.
sos_id:
The token ID for SOS.
      eos_id:
        The token ID for EOS.
      blank_id:
        The token ID for the blank symbol.
nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
      masked_lm_scale:
        Optional. It specifies the scale for conformer LM scores.
      use_double_scores:
        True to use double precision during score computation;
        False to use single precision.
    Returns:
      A dict of FsaVec, whose key encodes the ngram_lm_scale,
      attention_scale, and masked_lm_scale that were used, and whose value
      is the best decoding path for each utterance in the lattice.
"""
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set.
# Also, nbest.fsa inherits the attributes from `lattice`.
assert hasattr(nbest.fsa, "lm_scores")
am_scores = nbest.compute_am_scores()
ngram_lm_scores = nbest.compute_lm_scores()
# The `tokens` attribute is set inside `compile_hlg.py`
assert hasattr(nbest.fsa, "tokens")
assert isinstance(nbest.fsa.tokens, torch.Tensor)
path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
# the shape of memory is (T, N, C), so we use axis=1 here
expanded_memory = memory.index_select(1, path_to_utt_map)
if memory_key_padding_mask is not None:
# The shape of memory_key_padding_mask is (N, T), so we
# use axis=0 here.
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_utt_map
)
else:
expanded_memory_key_padding_mask = None
# remove axis corresponding to states.
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
tokens = tokens.remove_values_leq(0)
alignment = compute_alignment(tokens, nbest.shape)
(
masked_src_symbols,
src_symbols,
tgt_symbols,
src_key_padding_mask,
tgt_weights,
) = prepare_conformer_lm_inputs(
alignment,
bos_id=sos_id,
eos_id=eos_id,
blank_id=blank_id,
unmasked_weight=0.0,
)
masked_src_symbols = masked_src_symbols.to(torch.int64)
src_symbols = src_symbols.to(torch.int64)
tgt_symbols = tgt_symbols.to(torch.int64)
masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
masked_src_symbols, src_key_padding_mask
)
tgt_nll = masked_lm_model.decoder_nll(
masked_lm_memory,
masked_lm_pos_emb,
src_symbols,
tgt_symbols,
src_key_padding_mask,
)
# nll means negative log-likelihood
# ll means log-likelihood
tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
# Note: log-likelihood for those pairs that have identical src/tgt are 0
# since their tgt_weights is 0
# TODO(fangjun): Add documentation about why we do the following
tgt_ll_shape_row_ids = make_hyp_to_ref_map(nbest.shape.row_splits(1))
tgt_ll_shape = k2.ragged.create_ragged_shape2(
row_splits=None,
row_ids=tgt_ll_shape_row_ids,
cached_tot_size=tgt_ll_shape_row_ids.numel(),
)
ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
masked_lm_scores = ragged_tgt_ll.max()
# TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
token_ids = tokens.tolist()
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
attention_scale_list = [attention_scale]
if masked_lm_scale is None:
masked_lm_scale_list = [0.01, 0.05, 0.08]
masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
masked_lm_scale_list = [masked_lm_scale]
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
for m_scale in masked_lm_scale_list:
tot_scores = (
am_scores.values
+ n_scale * ngram_lm_scores.values
+ a_scale * attention_scores
+ m_scale * masked_lm_scores
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_masked_lm_scale_{m_scale}" # noqa
ans[key] = best_path
return ans
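
# A usage sketch of the dict returned by rescore_with_conformer_lm(); the
# scale values in the key below are hypothetical, and get_texts() is
# imported from icefall.utils at the top of this file:
#
#   ans = rescore_with_conformer_lm(lattice, num_paths, model, masked_lm_model,
#                                   memory, memory_key_padding_mask,
#                                   sos_id, eos_id, blank_id)
#   best_path = ans["ngram_lm_scale_1.0_attention_scale_1.0_masked_lm_scale_1.0"]
#   hyps = get_texts(best_path)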

363
icefall/lm/rescore.py Normal file
View File

@@ -0,0 +1,363 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains rescoring code for NN LMs, e.g., conformer LM.
Here are the ideas about preparing the inputs for the conformer LM model
from an Nbest object.
Given an Nbest object `nbest`, we have:
- nbest.fsa
- nbest.shape, whose axes are [utt][path]
We can get `tokens` from nbest.fsa. The resulting `tokens` will have
2 axes [path][token]. Note, we should remove 0s from `tokens`.
We can generate the following inputs for the conformer LM model from `tokens`:
- masked_src
- src
- tgt
by using `k2.levenshtein_alignment`.
TODO(fangjun): Add more doc about rescoring with masked conformer-lm.
"""
from typing import Tuple
import k2
import torch
def make_key_padding_mask(lengths: torch.Tensor):
"""
TODO: add documentation
>>> make_key_padding_mask(torch.tensor([3, 1, 4]))
tensor([[False, False, False, True],
[False, True, True, True],
[False, False, False, False]])
"""
assert lengths.dim() == 1
bs = lengths.numel()
max_len = lengths.max().item()
device = lengths.device
seq_range = torch.arange(0, max_len, device=device)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
def concat(
ragged: k2.RaggedTensor, value: int, direction: str
) -> k2.RaggedTensor:
"""Prepend a value to the beginning of each sublist or append a value.
to the end of each sublist.
Args:
ragged:
A ragged tensor with two axes.
value:
The value to prepend or append.
direction:
It can be either "left" or "right". If it is "left", we
prepend the value to the beginning of each sublist;
if it is "right", we append the value to the end of each
sublist.
Returns:
Return a new ragged tensor, whose sublists either start with
or end with the given value.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> concat(a, value=0, direction="left")
[ [ 0 1 3 ] [ 0 5 ] ]
>>> concat(a, value=0, direction="right")
[ [ 1 3 0 ] [ 5 0 ] ]
"""
dtype = ragged.dtype
device = ragged.device
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
pad_values = torch.full(
size=(ragged.tot_size(0), 1),
fill_value=value,
device=device,
dtype=dtype,
)
pad = k2.RaggedTensor(pad_values)
if direction == "left":
ans = k2.ragged.cat([pad, ragged], axis=1)
elif direction == "right":
ans = k2.ragged.cat([ragged, pad], axis=1)
else:
        raise ValueError(
            f"Unsupported direction: {direction}. "
            'Expect either "left" or "right"'
        )
return ans
def add_bos(ragged: k2.RaggedTensor, bos_id: int) -> k2.RaggedTensor:
"""Add BOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
bos_id:
The ID of the BOS symbol.
Returns:
Return a new ragged tensor, where each sublist starts with BOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_bos(a, bos_id=0)
[ [ 0 1 3 ] [ 0 5 ] ]
"""
return concat(ragged, bos_id, direction="left")
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
"""Add EOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
      eos_id:
The ID of the EOS symbol.
Returns:
Return a new ragged tensor, where each sublist ends with EOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_eos(a, eos_id=0)
[ [ 1 3 0 ] [ 5 0 ] ]
"""
return concat(ragged, eos_id, direction="right")
def make_hyp_to_ref_map(row_splits: torch.Tensor):
"""
TODO: Add documentation.
>>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
>>> make_hyp_to_ref_map(row_splits)
tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4], dtype=torch.int32)
"""
device = row_splits.device
sizes = (row_splits[1:] - row_splits[:-1]).tolist()
offsets = row_splits[:-1]
map_tensor_list = []
for size, offset in zip(sizes, offsets):
# Explanation of the following operations
# assume size is 3, offset is 2
# torch.arange() + offset is [2, 3, 4]
# expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
# t() is [[2, 2, 2], [3, 3, 3], [4, 4, 4]]
# reshape() is [2, 2, 2, 3, 3, 3, 4, 4, 4]
map_tensor = (
(torch.arange(size, dtype=torch.int32, device=device) + offset)
.expand(size, size)
.t()
.reshape(-1)
)
map_tensor_list.append(map_tensor)
return torch.cat(map_tensor_list)
def make_repeat_map(row_splits: torch.Tensor):
"""
TODO: Add documentation.
>>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
>>> make_repeat_map(row_splits)
tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 3, 4], dtype=torch.int32)
"""
device = row_splits.device
sizes = (row_splits[1:] - row_splits[:-1]).tolist()
offsets = row_splits[:-1]
map_tensor_list = []
for size, offset in zip(sizes, offsets):
# Explanation of the following operations
# assume size is 3, offset is 2
# torch.arange() + offset is [2, 3, 4]
# expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
# reshape() is [2, 3, 4, 2, 3, 4, 2, 3, 4]
map_tensor = (
(torch.arange(size, dtype=torch.int32, device=device) + offset)
.expand(size, size)
.reshape(-1)
)
map_tensor_list.append(map_tensor)
return torch.cat(map_tensor_list)
def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
"""Repeat the number of paths of an utterance to the number that
equals to the number of paths in the utterance.
For instance, if an utterance contains 3 paths: [path1 path2 path3],
after repeating, this utterance will contain 9 paths:
[path1 path2 path3] [path1 path2 path3] [path1 path2 path3]
>>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ])
>>> tokens
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ]
>>> make_repeat(tokens)
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ] # noqa
"""
assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}"
if True:
indexes = make_repeat_map(tokens.shape.row_splits(1))
return tokens.index(axis=1, indexes=indexes)[0]
else:
# This branch produces the same result as the above branch.
# It's more readable. Will remove it later.
repeated = []
for p in tokens.tolist():
repeated.append(p * len(p))
return k2.RaggedTensor(repeated).to(tokens.device)
def compute_alignment(
tokens: k2.RaggedTensor,
shape: k2.RaggedShape,
) -> k2.Fsa:
"""
TODO: Add documentation.
Args:
tokens:
A ragged tensor with two axes: [path][token].
shape:
A ragged shape with two axes: [utt][path]
"""
assert tokens.tot_size(0) == shape.tot_size(1)
device = tokens.device
utt_path_shape = shape.compose(tokens.shape)
utt_path_token = k2.RaggedTensor(utt_path_shape, tokens.values)
utt_path_token_repeated = make_repeat(utt_path_token)
path_token_repeated = utt_path_token_repeated.remove_axis(0)
refs = k2.levenshtein_graph(tokens, device=device)
hyps = k2.levenshtein_graph(path_token_repeated, device=device)
hyp_to_ref_map = make_hyp_to_ref_map(utt_path_shape.row_splits(1))
alignment = k2.levenshtein_alignment(
refs=refs, hyps=hyps, hyp_to_ref_map=hyp_to_ref_map
)
return alignment
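
# A minimal usage sketch of compute_alignment(), mirroring
# test/lm/test_rescore.py: `tokens` has axes [path][token] and `shape` has
# axes [utt][path]; the alignment is then passed to
# prepare_conformer_lm_inputs() below.
#
#   tokens = k2.RaggedTensor([
#       [1, 3, 5, 8], [1, 5, 8], [2, 8, 3, 2],  # 3 paths for utterance 0
#       [2, 3], [2],                            # 2 paths for utterance 1
#   ])
#   shape = k2.RaggedShape("[[x x x] [x x]]")
#   alignment = compute_alignment(tokens, shape)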
def prepare_conformer_lm_inputs(
alignment: k2.Fsa,
bos_id: int,
eos_id: int,
blank_id: int,
unmasked_weight: float = 0.25,
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
"""
    Prepare the inputs of the conformer LM model from the given alignment.
    Args:
      alignment:
        It is computed by :func:`compute_alignment`.
      bos_id:
        The token ID of the BOS symbol.
      eos_id:
        The token ID of the EOS symbol.
      blank_id:
        The token ID of the blank symbol.
      unmasked_weight:
        The weight assigned to unmasked positions. Masked positions get a
        weight of 1 and padding positions get a weight of 0.
    Returns:
      Return a tuple of 5 tensors:
      (masked_src, src, tgt, src_key_padding_mask, weight).
    """
device = alignment.device
# alignment.arcs.shape has axes [fsa][state][arc]
# we remove axis 1, i.e., state, here
labels_shape = alignment.arcs.shape().remove_axis(1)
masked_src = k2.RaggedTensor(labels_shape, alignment.labels.contiguous())
masked_src = masked_src.remove_values_eq(-1)
bos_masked_src = add_bos(masked_src, bos_id=bos_id)
bos_masked_src_eos = add_eos(bos_masked_src, eos_id=eos_id)
bos_masked_src_eos_pad = bos_masked_src_eos.pad(
mode="constant", padding_value=blank_id
)
src = k2.RaggedTensor(labels_shape, alignment.hyp_labels)
src = src.remove_values_eq(-1)
bos_src = add_bos(src, bos_id=bos_id)
bos_src_eos = add_eos(bos_src, eos_id=eos_id)
bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id)
tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels)
# TODO: Do we need to remove 0s from tgt ?
tgt = tgt.remove_values_eq(-1)
tgt_eos = add_eos(tgt, eos_id=eos_id)
# add a blank here since tgt_eos does not start with bos
# assume blank id is 0
tgt_eos = add_eos(tgt_eos, eos_id=blank_id)
row_splits = tgt_eos.shape.row_splits(1)
lengths = row_splits[1:] - row_splits[:-1]
src_key_padding_mask = make_key_padding_mask(lengths)
tgt_eos_pad = tgt_eos.pad(mode="constant", padding_value=blank_id)
weight = torch.full(
(tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
fill_value=1,
dtype=torch.float32,
device=device,
)
# find unmasked positions
unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1]
weight[unmasked_positions] = unmasked_weight
# set weights for paddings
weight[src_key_padding_mask[:, 1:]] = 0
zeros = torch.zeros(weight.size(0), 1).to(weight)
weight = torch.cat((weight, zeros), dim=1)
# all other positions are assumed to be masked and
# have the default weight 1
return (
bos_masked_src_eos_pad,
bos_src_eos_pad,
tgt_eos_pad,
src_key_padding_mask,
weight,
)
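
# A sketch (assuming a trained MaskedLmConformer `masked_lm_model`) of how the
# returned tensors are consumed; it mirrors rescore_with_conformer_lm() in
# icefall/decode.py:
#
#   memory, pos_emb = masked_lm_model(
#       masked_src.to(torch.int64), src_key_padding_mask
#   )
#   nll = masked_lm_model.decoder_nll(
#       memory, pos_emb, src.to(torch.int64), tgt.to(torch.int64),
#       src_key_padding_mask
#   )
#   ll = -(nll * weight).sum(dim=-1)  # per-path log-likelihood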

145
test/lm/test_rescore.py Executable file
View File

@@ -0,0 +1,145 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
from icefall.lm.rescore import (
add_bos,
add_eos,
compute_alignment,
make_hyp_to_ref_map,
make_repeat,
make_repeat_map,
prepare_conformer_lm_inputs,
)
def test_add_bos():
bos_id = 100
ragged = k2.RaggedTensor([[1, 2], [3], [0]])
bos_ragged = add_bos(ragged, bos_id)
expected = k2.RaggedTensor([[bos_id, 1, 2], [bos_id, 3], [bos_id, 0]])
assert str(bos_ragged) == str(expected)
def test_add_eos():
eos_id = 30
ragged = k2.RaggedTensor([[1, 2], [3], [], [5, 8, 9]])
ragged_eos = add_eos(ragged, eos_id)
expected = k2.RaggedTensor(
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
)
assert str(ragged_eos) == str(expected)
def test_pad():
bos_id = 10
eos_id = 100
ragged = k2.RaggedTensor([[1, 2, 3], [5], [9, 8]])
bos_ragged = add_bos(ragged, bos_id)
bos_ragged_eos = add_eos(bos_ragged, eos_id)
blank_id = -1
padded = bos_ragged_eos.pad(mode="constant", padding_value=blank_id)
expected = torch.tensor(
[
[bos_id, 1, 2, 3, eos_id],
[bos_id, 5, eos_id, blank_id, blank_id],
[bos_id, 9, 8, eos_id, blank_id],
]
).to(padded)
assert torch.all(torch.eq(padded, expected))
def test_make_hyp_to_ref_map():
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
row_splits = a.shape.row_splits(1)
repeat_map = make_hyp_to_ref_map(row_splits)
# fmt: off
expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3,
3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map) # noqa
# fmt: on
assert torch.all(torch.eq(repeat_map, expected))
def test_make_repeat_map():
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
row_splits = a.shape.row_splits(1)
repeat_map = make_repeat_map(row_splits)
# fmt: off
expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2,
3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, # noqa
3, 4, 5, 6]).to(repeat_map) # noqa
# fmt: on
assert torch.all(torch.eq(repeat_map, expected))
def test_make_repeat():
# fmt: off
a = k2.RaggedTensor([
[[1, 3, 5], [2, 6]],
[[1, 2, 3, 4], [2], [], [9, 10, 11]],
])
b = make_repeat(a)
expected = k2.RaggedTensor([
[[1, 3, 5], [2, 6], [1, 3, 5], [2, 6]],
[[1, 2, 3, 4], [2], [], [9, 10, 11],
[1, 2, 3, 4], [2], [], [9, 10, 11],
[1, 2, 3, 4], [2], [], [9, 10, 11],
[1, 2, 3, 4], [2], [], [9, 10, 11]],
])
# fmt: on
assert str(b) == str(expected)
def test_compute_alignment():
# fmt: off
tokens = k2.RaggedTensor([
# utt 0
[1, 3, 5, 8], [1, 5, 8], [2, 8, 3, 2],
# utt 1
[2, 3], [2],
])
# fmt: on
shape = k2.RaggedShape("[[x x x] [x x]]")
alignment = compute_alignment(tokens, shape)
(
masked_src,
src,
tgt,
src_key_padding_mask,
weight,
) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0)
# print("masked src", masked_src)
# print("src", src)
# print("tgt", tgt)
# print("src_key_padding_mask", src_key_padding_mask)
# print("weight", weight)
def main():
test_add_bos()
test_add_eos()
test_pad()
test_make_repeat_map()
test_make_hyp_to_ref_map()
test_make_repeat()
test_compute_alignment()
if __name__ == "__main__":
main()