Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit 6faaec22fc: Merge cdd539e55c63d21df91bfad9dd9cccdc306842a9 into 04029871b6a54e35d08116917f88eb7d6ead2d02
@@ -21,8 +21,12 @@ import warnings
from typing import Optional, Tuple

import torch
from conformer_ctc.transformer import (
    Supervisions,
    Transformer,
    encoder_padding_mask,
)
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask


class Conformer(Transformer):
@@ -26,8 +26,9 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from conformer_ctc.asr_datamodule import LibriSpeechAsrDataModule
from conformer_ctc.conformer import Conformer
from conformer_lm.conformer import MaskedLmConformer

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -37,6 +38,7 @@ from icefall.decode import (
    nbest_oracle,
    one_best_decoding,
    rescore_with_attention_decoder,
    rescore_with_conformer_lm,
    rescore_with_n_best_list,
    rescore_with_whole_lattice,
)
@@ -94,7 +96,10 @@ def get_parser():
          is the decoding result.
        - (5) attention-decoder. Extract n paths from the LM rescored
          lattice, the path with the highest score is the decoding result.
        - (6) nbest-oracle. Its WER is the lower bound of any n-best
        - (6) conformer-lm. In addition to attention-decoder rescoring, it
          also uses conformer lm for rescoring. See the model in the
          directory ./conformer_lm
        - (7) nbest-oracle. Its WER is the lower bound of any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring method.
        """,
@@ -106,7 +111,8 @@ def get_parser():
        default=100,
        help="""Number of paths for n-best based decoding method.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        nbest, nbest-rescoring, attention-decoder, conformer-lm,
        and nbest-oracle
        """,
    )

@@ -117,8 +123,8 @@ def get_parser():
        help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kinds of n-best based rescoring.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        A smaller value results in more unique paths.
        nbest, nbest-rescoring, attention-decoder, conformer_lm,
        and nbest-oracle. A smaller value results in more unique paths.
        """,
    )

@@ -147,6 +153,35 @@ def get_parser():
        help="The lang dir",
    )

    parser.add_argument(
        "--conformer-lm-exp-dir",
        type=str,
        default="conformer_lm/exp",
        help="""The conformer lm exp dir.
        Used only when method is conformer_lm.
        """,
    )

    parser.add_argument(
        "--conformer-lm-epoch",
        type=int,
        default=19,
        help="""Used only when method is conformer_lm.
        It specifies the checkpoint to use for the conformer
        lm model.
        """,
    )

    parser.add_argument(
        "--conformer-lm-avg",
        type=int,
        default=1,
        help="""Used only when method is conformer_lm.
        It specifies number of checkpoints to average for
        the conformer lm model.
        """,
    )

    return parser

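For orientation, the following sketch shows how the options added above could be exercised from Python. It is not part of the commit: the module path conformer_ctc.decode and all argument values are illustrative assumptions, and every other option is left at its default.

# Illustrative sketch only; assumes egs/librispeech/ASR is the working directory
# so that conformer_ctc.decode is importable.
from conformer_ctc.decode import get_parser

args = get_parser().parse_args(
    [
        "--method", "conformer-lm",
        "--num-paths", "100",
        "--conformer-lm-exp-dir", "conformer_lm/exp",
        "--conformer-lm-epoch", "19",
        "--conformer-lm-avg", "1",
    ]
)
print(args.method, args.conformer_lm_exp_dir, args.conformer_lm_avg)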
@@ -177,6 +212,7 @@ def get_params() -> AttributeDict:
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    masked_lm_model: Optional[nn.Module],
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
@@ -334,6 +370,7 @@ def decode_one_batch(
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
        "conformer-lm",
    ]

    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@@ -354,7 +391,7 @@ def decode_one_batch(
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    elif params.method == "attention-decoder":
    elif params.method in ("attention-decoder", "conformer-lm"):
        # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
        rescored_lattice = rescore_with_whole_lattice(
            lattice=lattice,
@@ -364,6 +401,7 @@ def decode_one_batch(
        # TODO: pass `lattice` instead of `rescored_lattice` to
        # `rescore_with_attention_decoder`

        if params.method == "attention-decoder":
            best_path_dict = rescore_with_attention_decoder(
                lattice=rescored_lattice,
                num_paths=params.num_paths,
@@ -374,6 +412,21 @@ def decode_one_batch(
                eos_id=eos_id,
                nbest_scale=params.nbest_scale,
            )
        else:
            # It uses:
            #   attention_decoder + conformer_lm
            best_path_dict = rescore_with_conformer_lm(
                lattice=rescored_lattice,
                num_paths=params.num_paths,
                model=model,
                masked_lm_model=masked_lm_model,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                sos_id=sos_id,
                eos_id=eos_id,
                blank_id=0,  # TODO(fangjun): pass it as an argument
                nbest_scale=params.nbest_scale,
            )
    else:
        assert False, f"Unsupported decoding method: {params.method}"

@@ -393,6 +446,7 @@ def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    masked_lm_model: Optional[nn.Module],
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
@@ -449,6 +503,7 @@ def decode_dataset(
        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            masked_lm_model=masked_lm_model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
@@ -584,6 +639,7 @@ def main():
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
        "conformer-lm",
    ):
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
@@ -607,7 +663,11 @@ def main():
        d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
        G = k2.Fsa.from_dict(d).to(device)

        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
        if params.method in [
            "whole-lattice-rescoring",
            "attention-decoder",
            "conformer-lm",
        ]:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
@@ -655,6 +715,38 @@ def main():
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    if params.method == "conformer-lm":
        logging.info("Loading conformer lm model")
        # Note: If the parameters do not match
        # the ones used to save the checkpoint, it will
        # throw while calling `load_state_dict`.
        masked_lm_model = MaskedLmConformer(
            num_classes=num_classes,
            d_model=params.attention_dim,
            nhead=params.nhead,
            num_decoder_layers=params.num_decoder_layers,
        )
        if params.conformer_lm_avg == 1:
            load_checkpoint(
                f"{params.conformer_lm_exp_dir}/epoch-{params.conformer_lm_epoch}.pt",  # noqa
                masked_lm_model,
            )
        else:
            start = params.conformer_lm_epoch - params.conformer_lm_avg + 1
            filenames = []
            for i in range(start, params.conformer_lm_epoch + 1):
                if start >= 0:
                    filenames.append(
                        f"{params.conformer_lm_exp_dir}/epoch-{i}.pt"
                    )
            logging.info(f"averaging {filenames}")
            masked_lm_model.to(device)
            masked_lm_model.load_state_dict(
                average_checkpoints(filenames, device=device)
            )
    else:
        masked_lm_model = None

    librispeech = LibriSpeechAsrDataModule(args)
    # CAUTION: `test_sets` is for displaying only.
    # If you want to skip test-clean, you have to skip
@@ -668,6 +760,7 @@ def main():
        dl=test_dl,
        params=params,
        model=model,
        masked_lm_model=masked_lm_model,
        HLG=HLG,
        H=H,
        bpe_model=bpe_model,
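To make the averaging branch above concrete: with --conformer-lm-epoch 19 and --conformer-lm-avg 5, checkpoints epoch-15.pt through epoch-19.pt are averaged. A small sketch of just the filename arithmetic, not part of the commit (the exp dir is the default from the new options):

# Illustrative only; mirrors the filename arithmetic in the hunk above.
conformer_lm_epoch = 19
conformer_lm_avg = 5
start = conformer_lm_epoch - conformer_lm_avg + 1
filenames = [
    f"conformer_lm/exp/epoch-{i}.pt"
    for i in range(start, conformer_lm_epoch + 1)
    if start >= 0
]
print(filenames)  # conformer_lm/exp/epoch-15.pt ... epoch-19.pt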
@@ -24,7 +24,7 @@ import logging
from pathlib import Path

import torch
from conformer import Conformer
from conformer_ctc.conformer import Conformer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
@@ -27,7 +27,7 @@ import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from conformer import Conformer
from conformer_ctc.conformer import Conformer
from torch.nn.utils.rnn import pad_sequence

from icefall.decode import (
@@ -17,8 +17,7 @@


import torch
from torch.nn.utils.rnn import pad_sequence
from transformer import (
from conformer_ctc.transformer import (
    Transformer,
    add_eos,
    add_sos,
@@ -26,6 +25,7 @@ from transformer import (
    encoder_padding_mask,
    generate_square_subsequent_mask,
)
from torch.nn.utils.rnn import pad_sequence


def test_encoder_padding_mask():
@@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling
from conformer_ctc.subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence

# Note: TorchScript requires Dict/List/etc. to be fully typed.
egs/librispeech/ASR/conformer_lm/asr_datamodule.py  (symbolic link, 1 line)
@@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py
egs/librispeech/ASR/conformer_lm/conformer.py  (new file, 1484 lines)
File diff suppressed because it is too large.
egs/librispeech/ASR/conformer_lm/dataset.py  (new file, 823 lines)
@@ -0,0 +1,823 @@
import torch
import torch.distributed as dist
import k2
import _k2
import logging
import sentencepiece as spm
from pathlib import Path
from typing import Optional, List, Tuple, Union


class LmDataset(torch.utils.data.Dataset):
    """
    Torch dataset for language modeling data.  This is a map-style dataset.
    The indices are integers.
    """
    def __init__(self, sentences: k2.RaggedTensor,
                 words: k2.RaggedTensor):
        super(LmDataset, self).__init__()
        self.sentences = sentences
        self.words = words

    def __len__(self):
        # Total size on axis 0, == num sentences
        return self.sentences.tot_size(0)

    def __getitem__(self, i: int):
        """
        Return the i'th sentence, as a list of ints (representing BPE pieces, without
        bos or eos symbols).
        """
        # in future will just do:
        # return self.words[self.sentences[i]].tolist()

        # It would be nicer if we could just return self.sentences[i].tolist(),
        # but for now that operator on k2.RaggedInt does not support the case where the
        # ragged has only 2 axes.
        row_splits = self.sentences.shape.row_splits(1)
        (begin, end) = row_splits[i:i+2].tolist()
        sentence = self.sentences.data[begin:end]
        sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
        return sentence.data.tolist()


def load_train_test_lm_dataset(archive_fn: Union[str, Path],
                               test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]:
    """
    Returns (train_lm_dataset, test_lm_dataset).
    """

    d = torch.load(archive_fn)
    words = d['words']  # a k2.RaggedTensor with 2 axes, maps from word-ids to sequences of BPE pieces
    sentences = d['data']  # a k2.RaggedTensor

    with torch.random.fork_rng(devices=[]):
        g = torch.manual_seed(0)
        num_sentences = sentences.tot_size(0)
        # probably the generator (g) argument to torch.randperm below is not necessary.
        sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32)
        sentences, _ = sentences.index(sentence_perm, axis=0, need_value_indexes=False)

    num_test_sentences = int(num_sentences * test_proportion)

    axis = 0
    train_sents = sentences.arange(axis, num_test_sentences, num_sentences)
    test_sents = sentences.arange(axis, 0, num_test_sentences)

    return LmDataset(train_sents, words), LmDataset(test_sents, words)

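A hedged usage sketch of the two definitions above (the archive path comes from the commented examples later in this file; it assumes an archive containing the 'words' and 'data' ragged tensors described above):

# Illustrative sketch, not part of the diff.
import dataset

train, test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
print(len(train), len(test))  # number of sentences in each split
print(train[0])               # first sentence, as a list of BPE piece ids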
def mask_and_pad(sentence: List[int],
                 seq_len: int,
                 bos_sym: int,
                 eos_sym: int,
                 blank_sym: int,
                 mask_proportion: float,
                 padding_proportion: float,
                 inv_mask_length: float,
                 unmasked_weight: float) -> Tuple[List[int], List[int], List[int],
                                                  List[float], List[bool], List[bool]]:
    """
    This function contains part of the logic of collate_fn, broken out.  It is responsible
    for inserting masking and padding into the sequence `sentence`.  Most of the arguments
    are documented for `collate_fn` below.
    Other args:
      sentence: The original sentence to be masked and padded.
      seq_len: The desired length of the lists to be returned.
      bos_sym, eos_sym, blank_sym, mask_proportion,
      padding_proportion, inv_mask_length, unmasked_weight: see their documentation
      as args to `collate_fn` below.

    Return: a tuple (src, src_masked, tgt, weight, randomizable, attn_mask), all lists of length `seq_len`,
    where:
      `src` is: [bos] + [the sentence after inserting blanks in place of padding
                after regions to be masked] + [eos] + [blank padding to seq_len].
      `src_masked` is as `src` but the masked regions have their values replaced with blank,
                i.e. they are actually masked.
      `tgt` is: [the original sentence, without masking] + [eos] + [blank] + [blank padding to seq_len]
      `weight` is the weight at the nnet output, which is: `unmasked_weight` for un-masked
                positions, 1.0 for masked and padded positions, and 0.0 for positions that
                correspond to blank-padding after the final [eos].
      `randomizable` is a bool that is True for positions where the symbol
                in `src_masked` is not bos or eos or blank.
      `attn_mask` is a bool that is False for positions in `src` and `src_masked` that
                are between the initial [bos] and final [eos] inclusive; and True for
                positions after the final [eos].
    """
    sent_len = len(sentence)
    assert sent_len + 3 <= seq_len

    for w in sentence:
        assert w not in [bos_sym, eos_sym, blank_sym]

    num_mask = int(torch.binomial(count=torch.tensor([sent_len * 1.0]),
                                  prob=torch.tensor([mask_proportion])).item())
    num_pad = int(torch.poisson(torch.tensor([sent_len * padding_proportion])).item())
    # Ensure the total length after bos, padding of masked sequences, and eos, is
    # no greater than seq_len
    num_pad -= max(0, sent_len + 2 + num_pad - seq_len)

    if num_mask + num_pad == 0:
        num_pad += 1

    # num_split_points is the number of times we split the (masked+padded)
    # region, so the total number of (masking+padding) subsequences will be
    # num_split_points + 1.  If num_mask positions are masked, then the
    # remaining number of words is `sent_len - num_mask`, and any two
    # masked regions must have at least one non-masked word between them,
    # so num_split_points == number of masked regions - 1, must be
    # no greater than `sent_len - num_mask`.  The formula about
    # mask_proportion * inv_mask_length / (1.0 - mask_proportion)
    # is what's required (I think) so that inv_mask_length is the expected
    # length of masked regions.
    num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
                                          prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
    # Somehow this assertion failed, debugging it below.
    assert num_split_points <= sent_len - num_mask

    assert isinstance(num_split_points, int)

    def split_into_subseqs(length: int, num_subseqs: int) -> List[int]:
        """Splits a sequence of `length` items into `num_subseqs` possibly-empty
        subsequences.  The length distributions are geometric, not Poisson, i.e.
        we choose the split locations with uniform probability rather than
        randomly assigning each word to one of the subsequences.  This gives us more
        of the shorter/longer subsequences.
        Require num_subseqs > 0
        """
        boundaries = [0] + sorted(torch.randint(low=0, high=length + 1, size=(num_subseqs - 1,)).tolist()) + [length]
        return [boundaries[i + 1] - boundaries[i] for i in range(num_subseqs)]

    mask_lengths = split_into_subseqs(num_mask, num_split_points + 1)
    pad_lengths = split_into_subseqs(num_pad, num_split_points + 1)
    # mask_pad_lengths contains only the (mask, pad) length pairs for which mask + pad > 0.
    # From this point we only refer to the mask_pad_lengths.
    mask_pad_lengths = [(mask, pad) for (mask, pad) in zip(mask_lengths, pad_lengths) if mask + pad > 0]
    num_subseqs = len(mask_pad_lengths)
    assert num_subseqs > 0

    # Now figure out how to distribute these subsequences throughout the actual
    # sentence.  The subsequences, if there are more than one, must not touch,
    # i.e. there must be an actual word in between each subsequence, where the
    # number of such "mandatory" words equals num_subseqs - 1.  We also have to
    # subtract `num_mask` words, since obviously the masked words cannot separate
    # the masked regions.
    reduced_len = sent_len - num_mask - (num_subseqs - 1)
    assert reduced_len >= 0
    # unmasked_lengths will be the lengths of the un-masked regions between the masked
    # regions.
    unmasked_lengths = split_into_subseqs(reduced_len, num_subseqs + 1)
    for i in range(1, num_subseqs):
        # Unmasked regions between masked regions must have length at least 1,
        # we add 1 to unmasked regions that are not initial/final.
        unmasked_lengths[i] = unmasked_lengths[i] + 1
    assert sum(unmasked_lengths) + sum(mask_lengths) == sent_len

    # src_positions will be: for each position in the masked+padded sentence,
    # the corresponding position in the source sentence `sentence`; or -1
    # if this was padding.
    src_positions = []
    # `masked` will be: for each position in the masked+padded sentence, True if
    # it was masked and False otherwise.  (Note: it is False for padding
    # locations, although this will not matter in the end).
    masked = []

    cur_pos = 0  # current position in source sentence
    for i in range(num_subseqs + 1):
        for j in range(unmasked_lengths[i]):
            src_positions.append(cur_pos)
            masked.append(False)
            cur_pos += 1
        if i < num_subseqs:
            (mask_len, pad_len) = mask_pad_lengths[i]
            for j in range(mask_len):
                src_positions.append(cur_pos)
                masked.append(True)
                cur_pos += 1
            for j in range(pad_len):
                src_positions.append(-1)
                masked.append(False)
    assert cur_pos == len(sentence)

    src = []
    src_masked = []
    tgt = []
    weight = []
    randomizable = []

    src.append(bos_sym)
    src_masked.append(bos_sym)
    randomizable.append(False)
    for i, src_pos in enumerate(src_positions):
        is_masked = masked[i]
        if src_pos >= 0:
            src_word = sentence[src_pos]
            src_masked.append(blank_sym if masked[i] else src_word)
            src.append(src_word)
            tgt.append(src_word)
            weight.append(1.0 if masked[i] else unmasked_weight)
            randomizable.append(not masked[i])
        else:
            # Padding inside a masked region
            src_masked.append(blank_sym)
            src.append(blank_sym)
            tgt.append(blank_sym)
            weight.append(1.0)
            randomizable.append(False)
    src.append(eos_sym)
    src_masked.append(eos_sym)
    tgt.append(eos_sym)
    weight.append(unmasked_weight)
    tgt.append(blank_sym)
    weight.append(0.0)
    randomizable.append(False)

    attn_mask = ([False] * len(src)) + ([True] * (seq_len - len(src)))

    for i in range(seq_len - len(src)):
        src.append(blank_sym)
        src_masked.append(blank_sym)
        tgt.append(blank_sym)
        weight.append(0.0)
        randomizable.append(False)

    return (src, src_masked, tgt, weight, randomizable, attn_mask)


# dataset.mask_and_pad(list(range(10, 20)), seq_len=16, bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, inv_mask_length=0.33, unmasked_weight=0.444)

# dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)

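Building on the commented invocation above, a hedged sketch of the structural invariants mask_and_pad is expected to satisfy (the outputs are random, so only lengths and the masking rule are checked):

# Illustrative check, not part of the diff.
import dataset

src, src_masked, tgt, weight, randomizable, attn_mask = dataset.mask_and_pad(
    list(range(10, 20)), seq_len=16, bos_sym=1, eos_sym=2, blank_sym=0,
    mask_proportion=0.2, padding_proportion=0.2,
    inv_mask_length=0.33, unmasked_weight=0.444)

# All six lists are padded to exactly seq_len entries.
assert all(len(x) == 16 for x in (src, src_masked, tgt, weight, randomizable, attn_mask))
# Masking only ever replaces a symbol with blank; it never invents new symbols.
assert all(m == s or m == 0 for m, s in zip(src_masked, src))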
def collate_fn(sentences: List[List[int]],
               bos_sym: int,
               eos_sym: int,
               blank_sym: int,
               mask_proportion: float = 0.15,
               padding_proportion: float = 0.15,
               randomize_proportion: float = 0.05,
               inv_mask_length: float = 0.25,
               unmasked_weight: float = 0.25,
               debug: bool = False) -> Tuple[torch.Tensor, torch.Tensor,
                                             torch.Tensor, torch.Tensor,
                                             torch.Tensor]:
    """
    Caution, this is not the collate_fn we give directly to the dataloader,
    we give it a lambda: collate_fn=(lambda x: dataset.collate_fn(x, [other args]))
    This formats a list-of-lists-of-int into 5 Tensors, explained below.
    The key thing is that we mask out subsequences of random length within
    these sentences, and force the network to predict the masked-out
    subsequences (which have blanks appended to them to prevent the model
    from knowing the exact length of the sequences it has to predict).
    So it's like BERT but at the level of sequences rather than individual
    words.

    Args:
      bos_sym: the integer id of the beginning-of-sentence symbol, e.g. 2.
            Is allowed to be the same as eos_sym (we are not necessarily
            saying it will work best that way).
      eos_sym: the integer id of the end-of-sentence symbol, e.g. 2.
      blank_sym: the integer id of the blank symbol, e.g. 0 or 1.
      mask_proportion: The proportion of words in each sentence that
            are masked, interpreted as (roughly) the probability of any given
            word being masked, although the masked locations will
            tend to be in contiguous sequences (they are not independent).
      padding_proportion: Like mask_proportion, but determines the
            number of extra, blank symbols that are inserted as padding
            at the end of masked regions (this ensures that the model
            cannot know exactly how many words need to be inserted in
            any given masked region).
      randomize_proportion: The probability with which we replace
            words that were not masked with randomly chosen words.
            Like BERT, this is intended to force the model to predict
            something reasonable at non-masked positions, and to make
            this task harder than simply repeating the input.
      inv_mask_length: This number determines how many separate
            sub-sequences the (masked + padded) proportion of a sentence is split up
            into, interpreted as the inverse of the expected length of
            each *masked* region.
      unmasked_weight: The weight to be applied to the log-likelihoods of
            un-masked positions in sentences (predicting un-masked
            positions is not completely trivial if randomize_proportion > 0).
            Will be reflected in the returned tgt_weights tensor.

    Returns a tuple (masked_src_symbols, src_symbols,
                     tgt_symbols, src_key_padding_mask,
                     tgt_weights),
    all with 2 axes and the same shape: (num_sent, seq_len).
    Their dtypes will be, respectively,
                    (torch.int64, torch.int64,
                     torch.int64, torch.bool,
                     torch.float)
    masked_src_symbols: The sentences, with bos_symbol prepended and eos_symbol
                     appended, masked regions (including padding) replaced with blank,
                     and `randomize_proportion` non-masked symbols replaced with
                     symbols randomly taken from elsewhere in the sentences of this
                     minibatch.  Then padded to a fixed length with blank.
    src_symbols: Like masked_src_symbols, except with the masked symbols replaced
                     with the original symbols (but the padding that follows each
                     masked sub-sequence will still be blank)
    tgt_symbols: The original sentences, with eos_symbol appended, and then
                     padded with blank to the same length as masked_symbols and
                     src_symbols.
    src_key_padding_mask: Masking tensor for masked_src_symbols and src_symbols, to
                     account for all the sentence lengths not being identical
                     (makes each sentence's processing independent of seq_len).
                     Tensor of Bool of shape (num_sent, seq_len), with True
                     for masked positions (these are the blanks that follow the
                     eos_symbol in masked_src_symbols), False for un-masked positions.
    tgt_weights: Weights that will be applied to the log-probabilities at
                     the output of the network.  Will have 1.0 in positions
                     in `tgt_symbols` that were masked (including blank
                     padding at the end of masked regions), `unmasked_weight`
                     in other positions in the original sentences (including
                     terminating eos_symbol); and 0.0 in the remaining positions
                     corresponding to blank padding after the ends of
                     sentences.
    """
    assert blank_sym not in [bos_sym, eos_sym]
    max_sent_len = max([len(s) for s in sentences])
    # logging.info(f"Sentence lengths: {[len(s) for s in sentences]}")

    typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))

    # The following formula gives roughly 1 standard deviation above where we'd
    # expect the maximum sentence length to be with masking and padding... we use
    # this as a hard upper limit, to prevent outliers from affecting the batch
    # size too much.  We use this as the size `seq_len`.
    # The "+ 4" is to ensure there is always room for the BOS, EOS and at least
    # two padding symbols.
    seq_len = max_sent_len + 4 + typical_mask_and_pad + int(typical_mask_and_pad ** 0.5)

    # srcs, srcs_masked, tgts and weights will be lists of the lists returned
    # from `mask_and_pad`, one per sentence.
    srcs = []
    srcs_masked = []
    tgts = []
    weights = []
    randomizables = []
    attn_masks = []
    for s in sentences:
        (src, src_masked, tgt,
         weight, randomizable,
         attn_mask) = mask_and_pad(s, seq_len, bos_sym, eos_sym,
                                   blank_sym, mask_proportion, padding_proportion,
                                   inv_mask_length, unmasked_weight)
        srcs.append(src)
        srcs_masked.append(src_masked)
        tgts.append(tgt)
        weights.append(weight)
        randomizables.append(randomizable)
        attn_masks.append(attn_mask)

    src_symbols = torch.tensor(srcs, dtype=torch.int64)
    masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64)
    tgt_symbols = torch.tensor(tgts, dtype=torch.int64)
    src_key_padding_mask = torch.tensor(attn_masks, dtype=torch.bool)
    tgt_weights = torch.tensor(weights, dtype=torch.float)

    attn_mask_sum = torch.sum(torch.logical_not(src_key_padding_mask), dim=0).tolist()
    while attn_mask_sum[-1] == 0:  # Remove always-masked positions at the end of the lists.
        attn_mask_sum.pop()
    if len(attn_mask_sum) < seq_len:
        seq_len = len(attn_mask_sum)
        (src_symbols, masked_src_symbols,
         tgt_symbols, src_key_padding_mask, tgt_weights) = (src_symbols[:, :seq_len], masked_src_symbols[:, :seq_len],
                                                            tgt_symbols[:, :seq_len], src_key_padding_mask[:, :seq_len],
                                                            tgt_weights[:, :seq_len])

    if randomize_proportion > 0.0:
        randomizable_tensor = torch.tensor(randomizables, dtype=torch.bool)
        randomizable_indexes = torch.nonzero(randomizable_tensor)  # (num_randomizable, 2)
        num_randomizable = randomizable_indexes.shape[0]

        to_randomize_indexes = torch.nonzero(torch.rand(num_randomizable) < randomize_proportion, as_tuple=True)[0]
        num_to_randomize = to_randomize_indexes.numel()

        # older versions of torch don't have tensor_split, so fake a simplified version of it.
        # we'd be calling it as xxx.tensor_split(dim=1) if it really were in torch.
        def tensor_split(t):
            return (t[:, 0], t[:, 1])

        random_src_locations = torch.randperm(num_randomizable)[:num_to_randomize]

        random_symbols = src_symbols[tensor_split(randomizable_indexes[random_src_locations])]
        random_indexes_tuple = tensor_split(randomizable_indexes[to_randomize_indexes])
        src_symbols[random_indexes_tuple] = random_symbols
        masked_src_symbols[random_indexes_tuple] = random_symbols

    # I set this to true and tested with:
    # python3 -c 'import dataset; dataset.collate_fn(sentences=[ list(range(100, 200)), list(range(300, 450)), list(range(500,600))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)'
    # .. and ran a few times to check the values printed looked about right, and that no assertions failed.
    if debug:
        check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym,
                               unmasked_weight,
                               masked_src_symbols, src_symbols,
                               tgt_symbols, src_key_padding_mask, tgt_weights)
    return (masked_src_symbols, src_symbols,
            tgt_symbols, src_key_padding_mask, tgt_weights)

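A hedged sketch of calling collate_fn directly, mirroring the commented test invocation above; passing debug=True routes the result through check_collated_tensors below:

# Illustrative sketch, not part of the diff.
import dataset

(masked_src, src, tgt, src_key_padding_mask, tgt_weights) = dataset.collate_fn(
    sentences=[list(range(10, 20)), list(range(30, 45))],
    bos_sym=1, eos_sym=2, blank_sym=0,
    mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05,
    inv_mask_length=0.33, unmasked_weight=0.444, debug=True)

# All five tensors share the shape (num_sentences, seq_len).
print(masked_src.shape, src.dtype, src_key_padding_mask.dtype, tgt_weights.dtype)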
def check_collated_tensors(sentences: List[List[int]],
                           bos_sym: int,
                           eos_sym: int,
                           blank_sym: int,
                           unmasked_weight: float,
                           masked_src_symbols, src_symbols,
                           tgt_symbols, src_key_padding_mask,
                           tgt_weights):
    """
    This function checks the output of collate_fn, consider it test code.  Please see
    the documentation of collate_fn to understand the args.
    """
    for t in src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights:
        assert t.shape == masked_src_symbols.shape

    tot_positions = src_symbols.numel()

    masked_src_symbols, src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights = (
        masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(),
        src_key_padding_mask.tolist(), tgt_weights.tolist())
    assert len(sentences) == len(masked_src_symbols)

    tot_masked_positions = 0
    tot_padded_positions = 0
    tot_unmasked_positions = 0  # all un-masked, non-blank positions, including eos
    tot_randomized_positions = 0
    num_masked_subseqs = 0
    tot_symbols = 0  # original symbols in sentences, no bos/eos

    assert unmasked_weight > 0.001  # or this test code won't work..

    for i in range(len(sentences)):
        reconstructed_sent = list(filter(lambda x: x not in [bos_sym, eos_sym, blank_sym], tgt_symbols[i]))
        if sentences[i] != reconstructed_sent:
            print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}")
        (masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i],
                                                     tgt_symbols[i], src_key_padding_mask[i], tgt_weights[i])

        assert src[0] == masked_src[0] == bos_sym
        for j in range(len(masked_src)):
            assert masked_src[j] == blank_sym or masked_src[j] == src[j]

            if src[j] not in [bos_sym, eos_sym, blank_sym]:
                tot_symbols += 1

            if j > 0:
                assert (src[j] == eos_sym) == (masked_src[j] == eos_sym) == (tgt[j-1] == eos_sym)
                if masked_src[j] == blank_sym:  # masked or padding of masked subseq, or post-eos padding..
                    assert src[j] == tgt[j - 1]  # masked symbols are not randomized.
                    assert weights[j - 1] in [0.0, 1.0]  # 0.0 for final blank padding
                    if weights[j - 1] == 1.0:  # Not final blank padding...
                        if tgt[j - 1] == blank_sym:
                            tot_padded_positions += 1
                        else:
                            tot_masked_positions += 1
                        if masked_src[j + 1] != blank_sym:
                            num_masked_subseqs += 1
                else:
                    assert weights[j - 1] == 0 or abs(weights[j-1] - unmasked_weight) < 0.001
                    if abs(weights[j - 1] - unmasked_weight) < 0.001:
                        tot_unmasked_positions += 1
                        if tgt[j - 1] != src[j]:
                            tot_randomized_positions += 1

            if src_mask[j]:  # if masked..
                assert src[j] == blank_sym

    assert tot_symbols == sum(len(x) for x in sentences)

    assert tot_unmasked_positions + tot_masked_positions == tot_symbols + len(sentences)

    print(f"{tot_unmasked_positions} + {tot_masked_positions} == {tot_symbols} + {len(sentences)}")
    print(f"tot_symbols / tot_positions = {tot_symbols/tot_positions} (rest is bos,eos,padding)")

    print(f"Masking/tot_symbols = {tot_masked_positions/tot_symbols}, Padding/tot_symbols = {tot_padded_positions/tot_symbols}")
    print(f"Randomization/tot_non_masked_symbols = {tot_randomized_positions/(tot_symbols-tot_masked_positions)}")
    print(f"Mean masking length = {tot_masked_positions/num_masked_subseqs}, Mean padding length = {tot_padded_positions/num_masked_subseqs}")

# This shows some useful code about the BPE encoding.
# import sentencepiece as spm
# sp = spm.SentencePieceProcessor()
# sp.load(bpe_model_fn)  # bpe.model
# sp.GetPieceSize(..)
# sp.Decode(...)
# sp.Encode(...)


# import dataset
# import torch
# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')


# train_dl = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True, collate_fn=(lambda x: train.collate_fn(x)))
# x = iter(train_dl)
# str(next(x))
# '[ [ 10 38 651 593 3 1343 31 780 6 4172 112 788 1696 24 289 24 3 403 6 4493 162 92 71 328 417 217 338 14 5 3 1876 154 21 23 2237 43 3 1535 92 71 2816 7 1031 31 2318 92 2528 4806 14 206 3 954 1373 6 525 4 631 447 2639 ] [ 1014 336 171 209 795 10 16 90 27 787 139 53 45 2817 ] [ 11 980 51 22 1748 14 91 105 363 428 6 8 2887 3305 2525 2297 70 3 4651 6 27 282 335 426 134 292 5 193 3 539 2250 584 127 ] [ 9 3 1858 4 18 2257 4 6 41 748 10 304 7 229 83 2793 4 9 981 7 1484 33 3 103 7 539 5 477 3195 18 64 39 82 1034 6 3 4128 ] [ 17 147 22 7 708 60 133 174 105 4111 4 6 3 1384 65 50 1051 9 2953 6 3 461 180 1142 23 5 36 888 8 131 173 390 78 23 266 2822 715 46 182 65 22 1739 33 3 700 1450 14 233 4 ] [ 80 10 16 67 279 7 1827 264 96 3 187 2851 2108 ] [ 1473 48 106 227 9 160 2011 4 674 ] [ 3 954 762 29 85 228 33 8 940 40 4952 36 486 390 595 3 81 225 6 1440 125 346 134 296 126 419 1017 3824 4 8 179 184 11 33 580 1861 ] [ 30 22 245 15 117 8 2892 28 1204 145 7 3 236 3417 6 3 3839 5 3106 155 198 30 228 2555 46 15 32 41 747 72 9 25 977 ] [ 222 466 6 3157 ] ]'
#
# or:
# import k2
# k2.ragged.to_list(next(x))
# [shows something similar].
#
# You'd really do something like:
# for epoch in range(max_epochs):
#    for minibatch in train_dl:


# .. How to process data?  Suppose we have a sentence like [259, 278, 45, 11, 303, 1319, 34, 15, 396, 3435, 7, 44].
#
# First: we randomly choose one or more starting positions for a masked segment.
#  Each sentence must have at least one masked segment (or there is no contribution to the loss function).
#  We choose to have:
#    num_masked_segments = max(1, len(sent) // 15)
#
# The length of the masked segment (this is the target for prediction), we set to the geometric
# distribution with the probability of success set to 0.3:
#
#  g = torch.distributions.geometric.Geometric(probs=0.3)  # <-- expected value is 3.333
#  Example of sampling:
#  g.sample(sample_shape=torch.Size([10]))
#
# We now randomly compute the location of the masked segments (length computed above) as follows:
# First, the masked segments must be separated by at least one non-masked word (else they would be
# a single segment).  So for n masked segments, there are n-1 words required for minimal separation.
# If tot-length-of-segments + n-1 is greater than the sentence length, we just have the entire
# sentence be masked.  Otherwise, we randomly divide the remaining number of words between the n+1
# positions where they can appear (e.g. for 2 segments, this would be at the start, between the 2 segments,
# and at the end).  This is the multinomial distribution, but we can more easily compute this
# directly using rand() and cutoffs, rather than creating a torch.distributions.Multinomial().
#

# Next we need to compute a random amount of blank padding (>= 0) for each of the masked regions;
# this is done so the model never knows the exact length of the masked region.  We can just use the
# same distribution as for the length of the masked regions, i.e. geometric with success-prob=0.3
# (expected padding length is 3).
#
# At this point we know where the masked regions are and how much padding they have.  We can format
# the result as three lists, of the same length:
#
# sent:            contains the words in the sentence with, in masked
#                  positions, the original (target) words, then with
#                  blank in the blank-padding after masked positions.
#
# sent_augmented:  `sent` with, at a small defined percentage of positions
#                  that were *not* masked, the real token replaced with a
#                  token randomly chosen from the tokens in the minibatch.
#                  (like BERT, we use this type of augmentation, so the model
#                  has to predict the original token).
#
# masked_sent_augmented: List[int], contains the words in `sent_augmented`, except
#                  with masked positions and the blank padding after the masked regions
#                  both replaced with blank.
#
#
# The way these will be processed is as follows:
#
#  masked_sent_in = [bos] + masked_sent_augmented + [eos]  <-- so we know the sentence ended, distinguish it from truncated ones.
#  sent_in = [bos] + sent_augmented + [eos]
#
#  sent_out = sent + [eos] + [eos]  # <-- the predicted targets at each point, although
#                                   # we only really care about this in masked regions.
#                                   # The extra eos is so that the length is the same as
#                                   # masked_sent_in and sent_in.
#
#  out_scale = (masked_sent==blk ? 1.0 : non_masked_scale)  # e.g. non_masked_scale = 1.0 is fine,
#                                   # this is a choice; we can perhaps
#                                   # report these 2 parts of the loss
#                                   # separately though.
#                                   # <-- can also set the last element
#                                   #   of out_scale to a smaller number, since
#                                   #   it's a repeated eos.
#
#
# OK, how do we combine these into a minibatch?  Firstly, we truncate sentences to a maximum
# length, e.g. 128, if `masked_sent_in`/`sent_in` have length longer than that.  We choose randomly
# in each case to truncate the beginning or end, truncating both masked_sent_in/sent_in and sent_out
# from the same side.  Caution: this means that these sentences may lack bos and/or eos symbols.
#
# Next, we combine shorter utterances by appending them (all of: masked_sent_in, sent_in, out_scale)
# as long as doing so would keep the total length under 128.  We then pad (masked_sent_in, sent_in, sent_out, out_scale)
# with: (<blk>,<blk>,<eos>, 0) up to the maximum length of any sentence in the minibatch <- or could use
#
#   # i.e. ones where masked_sent is blank and zeros elsewhere;
#   # this pertains to positions in `sent_out`.
#
#
# torch.distributions.gamma.Gamma(concentration=1.0, rate=1.0/5)


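The comments above sample masked-segment and padding lengths from a geometric distribution; here is a hedged sketch of just that sampling step. The 0.3 success probability and the max(1, len(sent) // 15) rule are quoted from the comments; note that mask_and_pad above actually uses torch.binomial and torch.poisson instead.

# Illustrative sketch of the sampling described in the comments; not part of the diff.
import torch

sent = [259, 278, 45, 11, 303, 1319, 34, 15, 396, 3435, 7, 44]
num_masked_segments = max(1, len(sent) // 15)

g = torch.distributions.geometric.Geometric(probs=0.3)
segment_lengths = g.sample(sample_shape=torch.Size([num_masked_segments]))
padding_lengths = g.sample(sample_shape=torch.Size([num_masked_segments]))
print(num_masked_segments, segment_lengths, padding_lengths)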
class LmBatchSampler(torch.utils.data.Sampler):
    """
    A sampler that returns a batch of integer indexes as a list, intended for use
    with class LmDataset.  The sentences returned in each batch will all be about
    the same size, and the batch size is specified as a number of words (we also
    provide an option that allows you to limit the max memory consumed by transformers).

    Has support for distributed operation.
    """
    def __init__(self, dataset: LmDataset,
                 symbols_per_batch: int,
                 length_ceil: float = 200.0,
                 length_floor: float = 4.0,
                 world_size: Optional[int] = None,
                 rank: Optional[int] = None,
                 seed: int = 0,
                 delay_init: bool = False):
        """
        Constructor documentation:
          dataset: the LmDataset object that we are sampling from.  This
              class does not retain a reference to the LmDataset.
          symbols_per_batch: The number of BPE symbols desired in each minibatch
          length_floor: When the sentence length gets less than about this much,
              the batch size stops increasing inversely with sentence
              length.  Prevents OOM on batches with short sentences.
          length_ceil: After the sentence length gets more than about
              this much, the batch size will start decreasing
              as 1/(sentence-length^2).  This is a mechanism to
              avoid excessive memory consumption in transformers, when
              sentence length gets long.
          world_size: The world size for distributed operation; if None,
              will be worked out from torch.distributed.
          rank: The rank of this sampler/process for distributed operation; if None,
              will be worked out from torch.distributed.
          seed: The random seed
          delay_init: If true, will omit calling self.set_epoch(0) at the
              end of the __init__ function.  In this case the caller
              must call set_epoch(0).  [Setting this option is necessary
              to work with data-loader worker processes plus DDP, since
              set_epoch() will use ddp, which I believe is a no-no prior
              to initializing data-loaders.]
        """
        self.seed = seed
        self.symbols_per_batch = symbols_per_batch
        self.length_floor = length_floor
        self.quadratic_constant = 1.0 / length_ceil
        self._maybe_init_distributed(world_size=world_size, rank=rank)

        # a configuration constant we don't expose.
        self.multiplicative_random_length = 0.05

        # "indexes" is the subset of indexes into LmDataset that this
        # sampler is responsible for (all of them, in the non-distributed case).
        data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32)  # dtype=torch.int32

        word_row_splits = dataset.words.shape.row_splits(1)  # dtype=torch.int32
        word_lengths = word_row_splits[1:] - word_row_splits[:-1]  # dtype=torch.int32

        # the sentences this sampler is responsible for, as sequences of words.
        # It's a ragged tensor of int32
        sentences, _ = dataset.sentences.index(data_indexes, axis=0)

        # sentence_lengths is a k2.RaggedTensor like `sentences`, but with the words replaced
        # with their respective lengths, in BPE pieces.
        sentence_lengths = k2.ragged.index(word_lengths, sentences)
        del sentences  # save memory
        assert isinstance(sentence_lengths, k2.RaggedTensor)

        # convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually
        # support int32.)
        sentence_lengths = sentence_lengths.to(dtype=torch.float32)

        # Convert into a simple tensor of float by adding lengths of words.
        sentence_lengths = sentence_lengths.sum()

        assert isinstance(sentence_lengths, torch.Tensor)
        assert sentence_lengths.dtype == torch.float32

        # self.sentence_lengths is a Tensor with dtype=torch.float32.  It
        # contains the lengths, in BPE tokens, of the sentences that this
        # sampler is responsible for, whose real indexes are in
        # `data_indexes` above (this is not stored, as we know the formula).
        self.sentence_lengths = sentence_lengths

        if not delay_init:
            self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes
    def _sync_sizes(self, device: torch.device = torch.device('cuda')):
        # Calling this on all copies of a DDP setup will sync the sizes so that
        # all copies have the exact same number of batches.  I think
        # this needs to be called with the GPU device, not sure if it would
        # work otherwise.
        if self.world_size > 1:
            min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64)
            dist.all_reduce(min_size, op=dist.ReduceOp.MIN)
            min_size = min_size.to('cpu').item()
            logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}")
            self.batch_indices = self.batch_indices[0:min_size]

    def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
        if world_size is not None:
            assert world_size >= 1
        if rank is not None:
            assert rank >= 0
        if not dist.is_available() or not dist.is_initialized():
            self.world_size = 1 if world_size is None else world_size
            self.rank = 0 if rank is None else rank
            return
        self.world_size = dist.get_world_size() if world_size is None else world_size
        self.rank = dist.get_rank() if rank is None else rank
        assert self.rank < self.world_size
    def set_epoch(self, epoch: int):
        """
        Must be called at the beginning of each epoch, before initializing the DataLoader,
        to re-shuffle the data.  If this is not done, this sampler will give you the same batches
        each time it is called.
        """
        g = torch.manual_seed(self.rank + self.seed + epoch)

        sentence_lengths = (self.sentence_lengths *
                            (1.0 + torch.rand(*self.sentence_lengths.shape, generator=g) * self.multiplicative_random_length))

        # This mechanism regulates the batch size so that we don't get OOM in transformers
        # when the sentences are long.
        sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor

        values, indices = torch.sort(sentence_lengths)  # values,indices dtypes: torch.float,torch.int64

        # map to the original indexes into the dataset (the original sentence
        # indexes), see torch.arange expression in the constructor.  save as
        # int32 just to save a little memory.  self.indices are indexes into the
        # LmDataset, just including the subset of indices that this sampler is
        # responsible for (in terms of rank and world_size), and sorted by
        # length with a small amount of randomization specific to the epoch.
        self.indices = ((indices * self.world_size) + self.rank).to(dtype=torch.int32)

        # now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
        # saying which batch each element of values/indices belongs to.
        batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)

        batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
        batch_boundaries.add_(1)
        self.batch_boundaries = torch.cat((torch.zeros(1, dtype=torch.int32), batch_boundaries), dim=0)

        num_batches = self.batch_boundaries.numel() - 1

        # self.batch_indices is a permutation of [0, 1, ... num_batches -
        # 1]; it determines the order in which we access the batches.  It's
        # necessary to randomize the order of these, to avoid returning batches
        # from shortest to longest sentences.
        self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
        self._sync_sizes()
    def __len__(self):
        return len(self.batch_indices)

    def __iter__(self):
        """
        Iterator that yields lists of indices (i.e., integer indices into the LmDataset)
        """
        for batch_idx in self.batch_indices:
            batch_start = self.batch_boundaries[batch_idx].item()
            batch_end = self.batch_boundaries[batch_idx + 1].item()
            yield self.indices[batch_start:batch_end].tolist()

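Referring back to LmBatchSampler.set_epoch above: batches are cut so that roughly symbols_per_batch "effective" symbols land in each batch, where the effective length is length + length**2 / length_ceil + length_floor (ignoring the small multiplicative jitter). A hedged worked example with the default length_floor=4 and length_ceil=200:

# Illustrative arithmetic only; not part of the diff.
def effective_length(length, length_floor=4.0, length_ceil=200.0):
    return length + (length ** 2) / length_ceil + length_floor

# Short sentences are barely penalized; long ones are penalized quadratically,
# which shrinks the batch and bounds transformer memory use.
print(effective_length(10))   # 10 + 0.5 + 4 = 14.5
print(effective_length(200))  # 200 + 200 + 4 = 404.0
print(effective_length(400))  # 400 + 800 + 4 = 1204.0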
class CollateFn:
    def __init__(self, **kwargs):
        self.extra_args = kwargs

    def __call__(self, sentences: List[List[int]]):
        return collate_fn(sentences, **self.extra_args)



# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
# sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
# a = iter(sampler)
# print(str(next(a)))

# collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True))
# train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
# x = iter(train_dl)
# print(str(next(x)))
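The commented usage above can also be written without the lambda by using the CollateFn helper; a hedged end-to-end sketch (single-process defaults, archive path taken from the comments):

# Illustrative wiring sketch, not part of the diff.
import torch
from dataset import CollateFn, LmBatchSampler, load_train_test_lm_dataset

train, test = load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
sampler = LmBatchSampler(train, symbols_per_batch=5000, world_size=1, rank=0)
train_dl = torch.utils.data.DataLoader(
    train,
    batch_sampler=sampler,
    collate_fn=CollateFn(bos_sym=1, eos_sym=1, blank_sym=0),
    num_workers=2,
)
masked_src, src, tgt, src_key_padding_mask, tgt_weights = next(iter(train_dl))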
egs/librispeech/ASR/conformer_lm/madam.py  (new file, 1256 lines)
File diff suppressed because it is too large.
egs/librispeech/ASR/conformer_lm/test_conformer.py  (new file, 156 lines)
@@ -0,0 +1,156 @@
#!/usr/bin/env python3
# run with:
#  python3 -m pytest test_conformer.py

import torch
import dataset  # from .
from conformer import (
    RelPosTransformerDecoder,
    RelPosTransformerDecoderLayer,
    MaskedLmConformer,
    MaskedLmConformerEncoder,
    MaskedLmConformerEncoderLayer,
    RelPositionMultiheadAttention,
    RelPositionalEncoding,
    generate_square_subsequent_mask,
)

from torch.nn.utils.rnn import pad_sequence


def test_rel_position_multihead_attention():
    # Also tests RelPositionalEncoding
    embed_dim = 256
    num_heads = 4
    T = 25
    N = 4
    C = 256
    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
    rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)

    x = torch.randn(N, T, C)
    # pos_emb = torch.randn(1, 2*T-1, C)
    x, pos_emb = pos_emb_module(x)
    x = x.transpose(0, 1)  # (T, N, C)
    attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)


def test_masked_lm_conformer_encoder_layer():
    # Also tests RelPositionalEncoding
    embed_dim = 256
    num_heads = 4
    T = 25
    N = 4
    C = 256
    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
    encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)

    x = torch.randn(N, T, C)
    x, pos_emb = pos_emb_module(x)
    x = x.transpose(0, 1)  # (T, N, C)
    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
    y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)


def test_masked_lm_conformer_encoder():
    # Also tests RelPositionalEncoding
    embed_dim = 256
    num_heads = 4
    T = 25
    N = 4
    C = 256
    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
    encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
    norm = torch.nn.LayerNorm(embed_dim)
    encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
                                       norm=norm)

    x = torch.randn(N, T, C)
    x, pos_emb = pos_emb_module(x)
    x = x.transpose(0, 1)  # (T, N, C)
    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
    y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)


def test_transformer_decoder_layer_rel_pos():
    embed_dim = 256
    num_heads = 4
    T = 25
    N = 4
    C = 256
    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
    decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)

    x = torch.randn(N, T, C)
    x, pos_emb = pos_emb_module(x)
    x = x.transpose(0, 1)  # (T, N, C)
    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
    attn_mask = generate_square_subsequent_mask(T)
    memory = torch.randn(T, N, C)
    y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)


def test_transformer_decoder_rel_pos():
    embed_dim = 256
    num_heads = 4
    T = 25
    N = 4
    C = 256
    pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
    decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
    decoder_norm = torch.nn.LayerNorm(embed_dim)
    decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6, norm=decoder_norm)

    x = torch.randn(N, T, C)
    x, pos_emb = pos_emb_module(x)
    x = x.transpose(0, 1)  # (T, N, C)
    key_padding_mask = (torch.randn(N, T) > 0.0)  # (N, T)
    attn_mask = generate_square_subsequent_mask(T)
    memory = torch.randn(T, N, C)
    y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)


def test_masked_lm_conformer():

    num_classes = 87
    d_model = 256

    model = MaskedLmConformer(num_classes, d_model)

    N = 31

    (masked_src_symbols, src_symbols,
     tgt_symbols, src_key_padding_mask,
     tgt_weights) = dataset.collate_fn(
         sentences=[list(range(10, 20)), list(range(30, 45)), list(range(50, 68))],
         bos_sym=1, eos_sym=2, blank_sym=0)

    # test forward() of MaskedLmConformer
    memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
    nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
                            src_key_padding_mask)
    print("nll = ", nll)
    loss = (nll * tgt_weights).sum()
    print("loss = ", loss)


def test_generate_square_subsequent_mask():
    s = 5
    mask = generate_square_subsequent_mask(s, torch.device('cpu'))
    inf = float("inf")
    expected_mask = torch.tensor(
        [
            [0.0, -inf, -inf, -inf, -inf],
            [0.0, 0.0, -inf, -inf, -inf],
            [0.0, 0.0, 0.0, -inf, -inf],
            [0.0, 0.0, 0.0, 0.0, -inf],
            [0.0, 0.0, 0.0, 0.0, 0.0],
        ]
    )
    assert torch.all(torch.eq(mask, expected_mask))
egs/librispeech/ASR/conformer_lm/test_dataset.py  (new file, 32 lines)
@@ -0,0 +1,32 @@
import k2
import torch
import _k2
import dataset
import os
from torch import multiprocessing as mp
import torch.distributed as dist

def local_collate_fn(sentences):
    return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)

if __name__ == '__main__':

    # mp.set_start_method('spawn')
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12344"

    dist.init_process_group(backend="nccl", group_name="main",
                            rank=0, world_size=1)

    train, test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
    sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
    print("len(sampler) = ", len(sampler))

    a = iter(sampler)
    print(str(next(a)))

    train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler,
                                           collate_fn=local_collate_fn,
                                           num_workers=2)
    x = iter(train_dl)
    print(str(next(x)))
39 egs/librispeech/ASR/conformer_lm/test_dataset_empty.py Normal file
@@ -0,0 +1,39 @@
import k2
import torch
import _k2
import dataset
from dataset import LmDataset
import os
from torch import multiprocessing as mp
import torch.distributed as dist


def local_collate_fn(sentences):
    return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)


x = _k2.RaggedInt('[[1]]')  # make sure the library is initialized?

if __name__ == '__main__':

    mp.set_start_method('spawn')
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12344"

    dist.init_process_group(backend="nccl", group_name="main",
                            rank=0, world_size=1)

    words = k2.RaggedInt('[[0][1 2]]')
    sentences = k2.RaggedInt('[[1][][][][][]]')

    train = LmDataset(sentences, words)

    sampler = dataset.LmBatchSampler(train, symbols_per_batch=10, world_size=1, rank=0)

    a = iter(sampler)
    print(str(next(a)))

    train_dl = torch.utils.data.DataLoader(train, batch_sampler=sampler,
                                           collate_fn=local_collate_fn,
                                           num_workers=0)
    x = iter(train_dl)
    print(str(next(x)))
623 egs/librispeech/ASR/conformer_lm/train.py Executable file
@@ -0,0 +1,623 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Daniel Povey)
|
||||
#
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
import dataset # from .
|
||||
import madam # from .
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from conformer import MaskedLmConformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from madam import Gloam
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12354,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tensorboard",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
"""Return a dict containing training parameters.
|
||||
|
||||
All training related parameters that are not passed from the commandline
|
||||
is saved in the variable `params`.
|
||||
|
||||
Commandline options are merged into `params` after they are parsed, so
|
||||
you can also access them via `params`.
|
||||
|
||||
Explanation of options saved in `params`:
|
||||
|
||||
- exp_dir: It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
||||
- lr: It specifies the initial learning rate
|
||||
|
||||
- feature_dim: The model input dim. It has to match the one used
|
||||
in computing features.
|
||||
|
||||
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
|
||||
and continue training from that checkpoint.
|
||||
|
||||
- num_epochs: Number of epochs to train.
|
||||
|
||||
- num_valid_batches: Number of batches of validation data to use each
|
||||
time we compute validation loss
|
||||
|
||||
- symbols_per_batch: Number of symbols in each batch (sampler will
|
||||
          choose the number of sentences to satisfy this constraint).
|
||||
|
||||
- best_train_loss: Best training loss so far. It is used to select
|
||||
the model that has the lowest training loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_valid_loss: Best validation loss so far. It is used to select
|
||||
the model that has the lowest validation loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_train_epoch: It is the epoch that has the best training loss.
|
||||
|
||||
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||
|
||||
        - batch_idx_train: Used for writing statistics to tensorboard. It
|
||||
contains number of batches trained so far across
|
||||
epochs.
|
||||
|
||||
        - log_interval: Print training loss if batch_idx % log_interval is 0
|
||||
|
||||
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||
|
||||
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
||||
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
            # exp_3, vs. exp_2, uses 5e-04, not 2e-04, as the max learning rate.
            # exp_4, vs. exp_3, uses the Gloam optimizer.
            # In exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
            # as well as the exponential part.
            # In exp_6, vs. exp_5, we change the decay from 0.85 to 0.9.
|
||||
"exp_dir": Path("conformer_lm/exp_6"),
|
||||
"lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
|
||||
"num_tokens": 5000,
|
||||
"blank_sym": 0,
|
||||
"bos_sym": 1,
|
||||
"eos_sym": 1,
|
||||
"start_epoch": 2,
|
||||
"num_epochs": 20,
|
||||
"num_valid_batches": 200,
|
||||
"symbols_per_batch": 5000,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 10,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000,
|
||||
"beam_size": 10,
|
||||
"accum_grad": 1,
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
"num_decoder_layers": 6,
|
||||
"max_lrate": 5.0e-04
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is positive, it will load the checkpoint from
|
||||
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||
|
||||
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
The learning rate scheduler we are using.
|
||||
Returns:
|
||||
      Return the saved checkpoint contents as a dict, or None if no
      checkpoint was loaded.
|
||||
"""
|
||||
if params.start_epoch <= 0:
|
||||
return
|
||||
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
model: nn.Module,
|
||||
batch: Tuple,
|
||||
is_training: bool,
|
||||
):
|
||||
|
||||
"""
|
||||
Compute training or validation loss given the model and its inputs
|
||||
(this corresponds to log-prob of the targets, with weighting
|
||||
of 1.0 for masked subsequences
|
||||
(including padding blanks), and something smaller, e.g. 0.25,
|
||||
for non-masked positions (this is not totally trivial due to
|
||||
a small amount of randomization of symbols).
|
||||
|
||||
This loss is not normalized; you can divide by batch[4].sum()
|
||||
to get a normalized loss (i.e. divide by soft-count).
|
||||
|
||||
Args:
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
model:
|
||||
The model for training. It is an instance of MaskedLmConformer in our case.
|
||||
batch:
|
||||
A batch of data, actually a tuple of 5 tensors (on the device), as returned
|
||||
by collate_fn in ./dataset.py.
|
||||
is_training:
|
||||
True for training. False for validation. When it is True, this
|
||||
function enables autograd during computation; when it is False, it
|
||||
disables autograd.
|
||||
|
||||
Returns:
|
||||
Returns the loss as a scalar tensor.
|
||||
"""
|
||||
(masked_src_symbols, src_symbols,
|
||||
tgt_symbols, src_key_padding_mask, tgt_weights) = batch
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
|
||||
decoder_nll_func = model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll
|
||||
tgt_nll = decoder_nll_func(memory, pos_emb, src_symbols,
|
||||
tgt_symbols, src_key_padding_mask)
|
||||
loss = (tgt_nll * tgt_weights).sum()
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
return loss
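

def _example_normalized_loss(model: nn.Module, batch: Tuple) -> torch.Tensor:
    # Illustrative sketch only (not used by the recipe): given a `batch`
    # produced by dataset.collate_fn and moved to the right device, compute
    # the loss and normalize it by the soft-count stored in batch[4]
    # (the target weights), as suggested in the compute_loss() docstring.
    loss = compute_loss(model=model, batch=batch, is_training=False)
    return loss / batch[4].sum()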
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
device: torch.device,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> None:
|
||||
"""Run the validation process. The validation loss
|
||||
is saved in `params.valid_loss`.
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
tot_loss = 0.0
|
||||
tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
if batch_idx == params.num_valid_batches:
|
||||
break
|
||||
batch = tuple(x.to(device) for x in batch)
|
||||
|
||||
|
||||
loss = compute_loss(model, batch, is_training=False)
|
||||
num_frames = batch[4].sum()
|
||||
|
||||
assert loss.requires_grad is False
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
num_frames_cpu = num_frames.cpu().item()
|
||||
|
||||
tot_loss += loss_cpu
|
||||
tot_frames += num_frames_cpu
|
||||
|
||||
|
||||
if world_size > 1:
|
||||
s = torch.tensor(
|
||||
[tot_loss, tot_frames],
|
||||
device=loss.device,
|
||||
)
|
||||
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||
(tot_loss, tot_frames) = s.cpu().tolist()
|
||||
|
||||
params.valid_loss = tot_loss / tot_frames
|
||||
|
||||
if params.valid_loss < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = params.valid_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
device: torch.device,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss from the mean of all frames is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
device:
|
||||
The device to use for training (model must be on this device)
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer we are using.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||
"""
|
||||
model.train() # training mode
|
||||
|
||||
tot_loss = 0.0 # sum of losses over all batches
|
||||
tot_frames = 0.0 # sum of frames over all batches
|
||||
|
||||
params.tot_loss = 0.0
|
||||
params.tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch = tuple(x.to(device) for x in batch)
|
||||
|
||||
try:
|
||||
loss = compute_loss(
|
||||
model=model,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total
|
||||
# gradient scale so this should not matter.
|
||||
# clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
except RuntimeError as e:
|
||||
print(f"Error on batch of shape (N,T) = {batch[0].shape}")
|
||||
raise e
|
||||
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
num_frames_cpu = batch[4].sum().cpu().item()
|
||||
|
||||
tot_loss += loss_cpu
|
||||
tot_frames += num_frames_cpu
|
||||
|
||||
params.tot_frames += num_frames_cpu
|
||||
params.tot_loss += loss_cpu
|
||||
|
||||
tot_avg_loss = tot_loss / tot_frames
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, "
|
||||
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||
f"batch shape: {tuple(batch[0].shape)}")
|
||||
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/current_loss",
|
||||
loss_cpu / num_frames_cpu,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_loss",
|
||||
tot_avg_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
||||
tot_loss = 0.0 # sum of losses over all batches
|
||||
tot_frames = 0.0 # sum of frames over all batches
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
compute_validation_loss(
|
||||
device=device,
|
||||
params=params,
|
||||
model=model,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"valid loss {params.valid_loss:.4f},"
|
||||
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||
f"best valid epoch: {params.best_valid_epoch}"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_loss",
|
||||
params.valid_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
|
||||
params.train_loss = params.tot_loss / params.tot_frames
|
||||
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
It is a value between 0 and `world_size-1`, which is
|
||||
passed automatically by `mp.spawn()` in :func:`main`.
|
||||
The node with rank 0 is responsible for saving checkpoint.
|
||||
world_size:
|
||||
Number of GPUs for DDP training.
|
||||
args:
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
fix_random_seed(42)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
logging.info(params)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
num_tokens = params.num_tokens
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
|
||||
|
||||
logging.info("About to create model")
|
||||
model = MaskedLmConformer(
|
||||
num_classes=params.num_tokens,
|
||||
d_model=params.attention_dim,
|
||||
nhead=params.nhead,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
)
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
# Caution: don't forget to do optimizer.set_epoch() with Gloam!
|
||||
# Don't remove this warning!
|
||||
optimizer = Gloam(
|
||||
model.parameters(),
|
||||
max_lrate=params.max_lrate,
|
||||
first_decrease_epoch=1,
|
||||
decay_per_epoch=0.9
|
||||
)
|
||||
|
||||
if checkpoints:
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
|
||||
train,test = dataset.load_train_test_lm_dataset(params.lm_dataset)
|
||||
|
||||
collate_fn=dataset.CollateFn(bos_sym=params.bos_sym,
|
||||
eos_sym=params.eos_sym,
|
||||
blank_sym=params.blank_sym,
|
||||
mask_proportion=0.15,
|
||||
padding_proportion=0.15,
|
||||
randomize_proportion=0.05,
|
||||
inv_mask_length=0.25,
|
||||
unmasked_weight=0.25)
|
||||
|
||||
train_sampler = dataset.LmBatchSampler(train,
|
||||
symbols_per_batch=params.symbols_per_batch,
|
||||
world_size=world_size, rank=rank)
|
||||
test_sampler = dataset.LmBatchSampler(test,
|
||||
symbols_per_batch=params.symbols_per_batch,
|
||||
world_size=world_size, rank=rank)
|
||||
|
||||
train_dl = torch.utils.data.DataLoader(train,
|
||||
batch_sampler=train_sampler,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
valid_dl = torch.utils.data.DataLoader(test,
|
||||
batch_sampler=test_sampler,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_sampler.set_epoch(epoch)
|
||||
optimizer.set_epoch(epoch) # Caution: this is specific to the Gloam
|
||||
# optimizer.
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
train_one_epoch(
|
||||
device=device,
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
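
# Example invocation (a sketch; adjust the number of GPUs and the port
# to your setup):
#
#   cd egs/librispeech/ASR
#   ./conformer_lm/train.py --world-size 4 --master-port 12354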
|
@@ -20,6 +20,11 @@ from typing import Dict, List, Optional, Union
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lm.rescore import (
|
||||
compute_alignment,
|
||||
make_hyp_to_ref_map,
|
||||
prepare_conformer_lm_inputs,
|
||||
)
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
@@ -904,3 +909,198 @@ def rescore_with_attention_decoder(
|
||||
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
|
||||
ans[key] = best_path
|
||||
return ans
|
||||
|
||||
|
||||
def rescore_with_conformer_lm(
|
||||
lattice: k2.Fsa,
|
||||
num_paths: int,
|
||||
model: torch.nn.Module,
|
||||
masked_lm_model: torch.nn.Module,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: Optional[torch.Tensor],
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
blank_id: int,
|
||||
nbest_scale: float = 1.0,
|
||||
ngram_lm_scale: Optional[float] = None,
|
||||
attention_scale: Optional[float] = None,
|
||||
masked_lm_scale: Optional[float] = None,
|
||||
use_double_scores: bool = True,
|
||||
) -> Dict[str, k2.Fsa]:
|
||||
"""This function extracts `num_paths` paths from the given lattice and uses
|
||||
an attention decoder to rescore them. The path with the highest score is
|
||||
the decoding output.
|
||||
|
||||
Args:
|
||||
lattice:
|
||||
An FsaVec with axes [utt][state][arc].
|
||||
num_paths:
|
||||
Number of paths to extract from the given lattice for rescoring.
|
||||
      model:
        A transformer model. See the class "Transformer" in
        conformer_ctc/transformer.py for its interface.
      masked_lm_model:
        The conformer LM used for rescoring. It is an instance of
        MaskedLmConformer; see conformer_lm/conformer.py for its interface.
|
||||
memory:
|
||||
The encoder memory of the given model. It is the output of
|
||||
the last torch.nn.TransformerEncoder layer in the given model.
|
||||
Its shape is `(T, N, C)`.
|
||||
memory_key_padding_mask:
|
||||
The padding mask for memory with shape `(N, T)`.
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
      eos_id:
        The token ID for EOS.
      blank_id:
        The token ID for the blank symbol.
|
||||
nbest_scale:
|
||||
It's the scale applied to `lattice.scores`. A smaller value
|
||||
leads to more unique paths at the risk of missing the correct path.
|
||||
ngram_lm_scale:
|
||||
Optional. It specifies the scale for n-gram LM scores.
|
||||
attention_scale:
|
||||
Optional. It specifies the scale for attention decoder scores.
|
||||
masked_lm_scale:
|
||||
Optional. It specifies the scale for conformer_lm scores.
|
||||
Returns:
|
||||
      A dict of FsaVec, whose key is of the form
      ngram_lm_scale_xxx_attention_scale_xxx_masked_lm_scale_xxx and whose
      value is the best decoding path for each utterance in the lattice.
|
||||
"""
|
||||
nbest = Nbest.from_lattice(
|
||||
lattice=lattice,
|
||||
num_paths=num_paths,
|
||||
use_double_scores=use_double_scores,
|
||||
nbest_scale=nbest_scale,
|
||||
)
|
||||
# nbest.fsa.scores are all 0s at this point
|
||||
|
||||
nbest = nbest.intersect(lattice)
|
||||
# Now nbest.fsa has its scores set.
|
||||
# Also, nbest.fsa inherits the attributes from `lattice`.
|
||||
assert hasattr(nbest.fsa, "lm_scores")
|
||||
|
||||
am_scores = nbest.compute_am_scores()
|
||||
ngram_lm_scores = nbest.compute_lm_scores()
|
||||
|
||||
# The `tokens` attribute is set inside `compile_hlg.py`
|
||||
assert hasattr(nbest.fsa, "tokens")
|
||||
assert isinstance(nbest.fsa.tokens, torch.Tensor)
|
||||
|
||||
path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
|
||||
# the shape of memory is (T, N, C), so we use axis=1 here
|
||||
expanded_memory = memory.index_select(1, path_to_utt_map)
|
||||
|
||||
if memory_key_padding_mask is not None:
|
||||
# The shape of memory_key_padding_mask is (N, T), so we
|
||||
# use axis=0 here.
|
||||
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||
0, path_to_utt_map
|
||||
)
|
||||
else:
|
||||
expanded_memory_key_padding_mask = None
|
||||
|
||||
# remove axis corresponding to states.
|
||||
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
|
||||
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
|
||||
tokens = tokens.remove_values_leq(0)
|
||||
|
||||
alignment = compute_alignment(tokens, nbest.shape)
|
||||
(
|
||||
masked_src_symbols,
|
||||
src_symbols,
|
||||
tgt_symbols,
|
||||
src_key_padding_mask,
|
||||
tgt_weights,
|
||||
) = prepare_conformer_lm_inputs(
|
||||
alignment,
|
||||
bos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
blank_id=blank_id,
|
||||
unmasked_weight=0.0,
|
||||
)
|
||||
|
||||
masked_src_symbols = masked_src_symbols.to(torch.int64)
|
||||
src_symbols = src_symbols.to(torch.int64)
|
||||
tgt_symbols = tgt_symbols.to(torch.int64)
|
||||
|
||||
masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
|
||||
masked_src_symbols, src_key_padding_mask
|
||||
)
|
||||
|
||||
tgt_nll = masked_lm_model.decoder_nll(
|
||||
masked_lm_memory,
|
||||
masked_lm_pos_emb,
|
||||
src_symbols,
|
||||
tgt_symbols,
|
||||
src_key_padding_mask,
|
||||
)
|
||||
|
||||
# nll means negative log-likelihood
|
||||
# ll means log-likelihood
|
||||
tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
|
||||
|
||||
# Note: log-likelihood for those pairs that have identical src/tgt are 0
|
||||
# since their tgt_weights is 0
|
||||
|
||||
# TODO(fangjun): Add documentation about why we do the following
|
||||
tgt_ll_shape_row_ids = make_hyp_to_ref_map(nbest.shape.row_splits(1))
|
||||
tgt_ll_shape = k2.ragged.create_ragged_shape2(
|
||||
row_splits=None,
|
||||
row_ids=tgt_ll_shape_row_ids,
|
||||
cached_tot_size=tgt_ll_shape_row_ids.numel(),
|
||||
)
|
||||
ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
|
||||
|
||||
ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
|
||||
masked_lm_scores = ragged_tgt_ll.max()
|
||||
|
||||
# TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
|
||||
token_ids = tokens.tolist()
|
||||
|
||||
nll = model.decoder_nll(
|
||||
memory=expanded_memory,
|
||||
memory_key_padding_mask=expanded_memory_key_padding_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
assert nll.ndim == 2
|
||||
assert nll.shape[0] == len(token_ids)
|
||||
|
||||
attention_scores = -nll.sum(dim=1)
|
||||
|
||||
if ngram_lm_scale is None:
|
||||
ngram_lm_scale_list = [0.01, 0.05, 0.08]
|
||||
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
else:
|
||||
ngram_lm_scale_list = [ngram_lm_scale]
|
||||
|
||||
if attention_scale is None:
|
||||
attention_scale_list = [0.01, 0.05, 0.08]
|
||||
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
else:
|
||||
attention_scale_list = [attention_scale]
|
||||
|
||||
if masked_lm_scale is None:
|
||||
masked_lm_scale_list = [0.01, 0.05, 0.08]
|
||||
masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
else:
|
||||
masked_lm_scale_list = [masked_lm_scale]
|
||||
|
||||
ans = dict()
|
||||
for n_scale in ngram_lm_scale_list:
|
||||
for a_scale in attention_scale_list:
|
||||
for m_scale in masked_lm_scale_list:
|
||||
tot_scores = (
|
||||
am_scores.values
|
||||
+ n_scale * ngram_lm_scores.values
|
||||
+ a_scale * attention_scores
|
||||
+ m_scale * masked_lm_scores
|
||||
)
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
|
||||
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_masked_lm_scale_{m_scale}" # noqa
|
||||
ans[key] = best_path
|
||||
return ans
|
||||
|
363 icefall/lm/rescore.py Normal file
@@ -0,0 +1,363 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file contains rescoring code for NN LMs, e.g., conformer LM.
|
||||
|
||||
Here are the ideas about preparing the inputs for the conformer LM model
|
||||
from an Nbest object.
|
||||
|
||||
Given an Nbest object `nbest`, we have:
|
||||
- nbest.fsa
|
||||
- nbest.shape, whose axes are [utt][path]
|
||||
|
||||
We can get `tokens` from nbest.fsa. The resulting `tokens` will have
|
||||
2 axes [path][token]. Note, we should remove 0s from `tokens`.
|
||||
|
||||
We can generate the following inputs for the conformer LM model from `tokens`:
|
||||
- masked_src
|
||||
- src
|
||||
- tgt
|
||||
by using `k2.levenshtein_alignment`.
|
||||
|
||||
TODO(fangjun): Add more doc about rescoring with masked conformer-lm.
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
|
||||
def make_key_padding_mask(lengths: torch.Tensor):
|
||||
"""
|
||||
    Create a key padding mask from a 1-D tensor of sequence lengths.
    The returned mask has shape (lengths.numel(), lengths.max()); positions
    that correspond to padding are True.
|
||||
|
||||
>>> make_key_padding_mask(torch.tensor([3, 1, 4]))
|
||||
tensor([[False, False, False, True],
|
||||
[False, True, True, True],
|
||||
[False, False, False, False]])
|
||||
"""
|
||||
assert lengths.dim() == 1
|
||||
|
||||
bs = lengths.numel()
|
||||
max_len = lengths.max().item()
|
||||
device = lengths.device
|
||||
seq_range = torch.arange(0, max_len, device=device)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
|
||||
|
||||
seq_length_expand = lengths.unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
return mask
|
||||
|
||||
|
||||
def concat(
|
||||
ragged: k2.RaggedTensor, value: int, direction: str
|
||||
) -> k2.RaggedTensor:
|
||||
"""Prepend a value to the beginning of each sublist or append a value.
|
||||
to the end of each sublist.
|
||||
|
||||
Args:
|
||||
ragged:
|
||||
A ragged tensor with two axes.
|
||||
value:
|
||||
The value to prepend or append.
|
||||
direction:
|
||||
It can be either "left" or "right". If it is "left", we
|
||||
prepend the value to the beginning of each sublist;
|
||||
if it is "right", we append the value to the end of each
|
||||
sublist.
|
||||
|
||||
Returns:
|
||||
Return a new ragged tensor, whose sublists either start with
|
||||
or end with the given value.
|
||||
|
||||
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
||||
>>> a
|
||||
[ [ 1 3 ] [ 5 ] ]
|
||||
>>> concat(a, value=0, direction="left")
|
||||
[ [ 0 1 3 ] [ 0 5 ] ]
|
||||
>>> concat(a, value=0, direction="right")
|
||||
[ [ 1 3 0 ] [ 5 0 ] ]
|
||||
|
||||
"""
|
||||
dtype = ragged.dtype
|
||||
device = ragged.device
|
||||
|
||||
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
|
||||
pad_values = torch.full(
|
||||
size=(ragged.tot_size(0), 1),
|
||||
fill_value=value,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
pad = k2.RaggedTensor(pad_values)
|
||||
|
||||
if direction == "left":
|
||||
ans = k2.ragged.cat([pad, ragged], axis=1)
|
||||
elif direction == "right":
|
||||
ans = k2.ragged.cat([ragged, pad], axis=1)
|
||||
else:
|
||||
raise ValueError(
|
||||
            f"Unsupported direction: {direction}. "
            'Expect either "left" or "right"'
|
||||
)
|
||||
return ans
|
||||
|
||||
|
||||
def add_bos(ragged: k2.RaggedTensor, bos_id: int) -> k2.RaggedTensor:
|
||||
"""Add BOS to each sublist.
|
||||
|
||||
Args:
|
||||
ragged:
|
||||
A ragged tensor with two axes.
|
||||
bos_id:
|
||||
The ID of the BOS symbol.
|
||||
|
||||
Returns:
|
||||
Return a new ragged tensor, where each sublist starts with BOS.
|
||||
|
||||
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
||||
>>> a
|
||||
[ [ 1 3 ] [ 5 ] ]
|
||||
>>> add_bos(a, bos_id=0)
|
||||
[ [ 0 1 3 ] [ 0 5 ] ]
|
||||
|
||||
"""
|
||||
return concat(ragged, bos_id, direction="left")
|
||||
|
||||
|
||||
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
|
||||
"""Add EOS to each sublist.
|
||||
|
||||
Args:
|
||||
ragged:
|
||||
A ragged tensor with two axes.
|
||||
      eos_id:
|
||||
The ID of the EOS symbol.
|
||||
|
||||
Returns:
|
||||
Return a new ragged tensor, where each sublist ends with EOS.
|
||||
|
||||
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
||||
>>> a
|
||||
[ [ 1 3 ] [ 5 ] ]
|
||||
>>> add_eos(a, eos_id=0)
|
||||
[ [ 1 3 0 ] [ 5 0 ] ]
|
||||
|
||||
"""
|
||||
return concat(ragged, eos_id, direction="right")
|
||||
|
||||
|
||||
def make_hyp_to_ref_map(row_splits: torch.Tensor):
|
||||
"""
|
||||
TODO: Add documentation.
|
||||
|
||||
>>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
|
||||
>>> make_hyp_to_ref_map(row_splits)
|
||||
tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4], dtype=torch.int32)
|
||||
|
||||
"""
|
||||
device = row_splits.device
|
||||
sizes = (row_splits[1:] - row_splits[:-1]).tolist()
|
||||
offsets = row_splits[:-1]
|
||||
|
||||
map_tensor_list = []
|
||||
for size, offset in zip(sizes, offsets):
|
||||
# Explanation of the following operations
|
||||
# assume size is 3, offset is 2
|
||||
# torch.arange() + offset is [2, 3, 4]
|
||||
# expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
|
||||
# t() is [[2, 2, 2], [3, 3, 3], [4, 4, 4]]
|
||||
# reshape() is [2, 2, 2, 3, 3, 3, 4, 4, 4]
|
||||
map_tensor = (
|
||||
(torch.arange(size, dtype=torch.int32, device=device) + offset)
|
||||
.expand(size, size)
|
||||
.t()
|
||||
.reshape(-1)
|
||||
)
|
||||
map_tensor_list.append(map_tensor)
|
||||
|
||||
return torch.cat(map_tensor_list)
|
||||
|
||||
|
||||
def make_repeat_map(row_splits: torch.Tensor):
|
||||
"""
|
||||
TODO: Add documentation.
|
||||
|
||||
>>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
|
||||
>>> make_repeat_map(row_splits)
|
||||
tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 3, 4], dtype=torch.int32)
|
||||
|
||||
"""
|
||||
device = row_splits.device
|
||||
sizes = (row_splits[1:] - row_splits[:-1]).tolist()
|
||||
offsets = row_splits[:-1]
|
||||
|
||||
map_tensor_list = []
|
||||
for size, offset in zip(sizes, offsets):
|
||||
# Explanation of the following operations
|
||||
# assume size is 3, offset is 2
|
||||
# torch.arange() + offset is [2, 3, 4]
|
||||
# expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
|
||||
# reshape() is [2, 3, 4, 2, 3, 4, 2, 3, 4]
|
||||
map_tensor = (
|
||||
(torch.arange(size, dtype=torch.int32, device=device) + offset)
|
||||
.expand(size, size)
|
||||
.reshape(-1)
|
||||
)
|
||||
map_tensor_list.append(map_tensor)
|
||||
|
||||
return torch.cat(map_tensor_list)
|
||||
|
||||
|
||||
def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
|
||||
"""Repeat the number of paths of an utterance to the number that
|
||||
equals to the number of paths in the utterance.
|
||||
|
||||
For instance, if an utterance contains 3 paths: [path1 path2 path3],
|
||||
after repeating, this utterance will contain 9 paths:
|
||||
[path1 path2 path3] [path1 path2 path3] [path1 path2 path3]
|
||||
|
||||
>>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ])
|
||||
>>> tokens
|
||||
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ]
|
||||
>>> make_repeat(tokens)
|
||||
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ] # noqa
|
||||
|
||||
TODO: Add documentation.
|
||||
|
||||
"""
|
||||
assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}"
|
||||
if True:
|
||||
indexes = make_repeat_map(tokens.shape.row_splits(1))
|
||||
return tokens.index(axis=1, indexes=indexes)[0]
|
||||
else:
|
||||
# This branch produces the same result as the above branch.
|
||||
# It's more readable. Will remove it later.
|
||||
repeated = []
|
||||
for p in tokens.tolist():
|
||||
repeated.append(p * len(p))
|
||||
return k2.RaggedTensor(repeated).to(tokens.device)
|
||||
|
||||
|
||||
def compute_alignment(
|
||||
tokens: k2.RaggedTensor,
|
||||
shape: k2.RaggedShape,
|
||||
) -> k2.Fsa:
|
||||
"""
|
||||
    Align each path of every utterance against all paths of the same
    utterance using k2.levenshtein_alignment.
|
||||
|
||||
Args:
|
||||
tokens:
|
||||
A ragged tensor with two axes: [path][token].
|
||||
shape:
|
||||
A ragged shape with two axes: [utt][path]
|
||||
"""
|
||||
assert tokens.tot_size(0) == shape.tot_size(1)
|
||||
device = tokens.device
|
||||
utt_path_shape = shape.compose(tokens.shape)
|
||||
utt_path_token = k2.RaggedTensor(utt_path_shape, tokens.values)
|
||||
utt_path_token_repeated = make_repeat(utt_path_token)
|
||||
path_token_repeated = utt_path_token_repeated.remove_axis(0)
|
||||
|
||||
refs = k2.levenshtein_graph(tokens, device=device)
|
||||
hyps = k2.levenshtein_graph(path_token_repeated, device=device)
|
||||
|
||||
hyp_to_ref_map = make_hyp_to_ref_map(utt_path_shape.row_splits(1))
|
||||
alignment = k2.levenshtein_alignment(
|
||||
refs=refs, hyps=hyps, hyp_to_ref_map=hyp_to_ref_map
|
||||
)
|
||||
return alignment
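
# Usage sketch, mirroring test/lm/test_rescore.py::test_compute_alignment:
#
#   tokens = k2.RaggedTensor([
#       [1, 3, 5, 8], [1, 5, 8], [2, 8, 3, 2],   # 3 paths of utterance 0
#       [2, 3], [2],                             # 2 paths of utterance 1
#   ])
#   shape = k2.RaggedShape("[[x x x] [x x]]")
#   alignment = compute_alignment(tokens, shape)
#   inputs = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20,
#                                        blank_id=0)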
|
||||
|
||||
|
||||
def prepare_conformer_lm_inputs(
|
||||
alignment: k2.Fsa,
|
||||
bos_id: int,
|
||||
eos_id: int,
|
||||
blank_id: int,
|
||||
unmasked_weight: float = 0.25,
|
||||
) -> Tuple[
|
||||
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
|
||||
]:
|
||||
"""
|
||||
TODO: add documentation.
|
||||
|
||||
Args:
|
||||
      alignment:
|
||||
It is computed by :func:`compute_alignment`
|
||||
"""
|
||||
device = alignment.device
|
||||
# alignment.arcs.shape has axes [fsa][state][arc]
|
||||
# we remove axis 1, i.e., state, here
|
||||
labels_shape = alignment.arcs.shape().remove_axis(1)
|
||||
|
||||
masked_src = k2.RaggedTensor(labels_shape, alignment.labels.contiguous())
|
||||
masked_src = masked_src.remove_values_eq(-1)
|
||||
bos_masked_src = add_bos(masked_src, bos_id=bos_id)
|
||||
bos_masked_src_eos = add_eos(bos_masked_src, eos_id=eos_id)
|
||||
bos_masked_src_eos_pad = bos_masked_src_eos.pad(
|
||||
mode="constant", padding_value=blank_id
|
||||
)
|
||||
|
||||
src = k2.RaggedTensor(labels_shape, alignment.hyp_labels)
|
||||
src = src.remove_values_eq(-1)
|
||||
bos_src = add_bos(src, bos_id=bos_id)
|
||||
bos_src_eos = add_eos(bos_src, eos_id=eos_id)
|
||||
bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels)
|
||||
# TODO: Do we need to remove 0s from tgt ?
|
||||
tgt = tgt.remove_values_eq(-1)
|
||||
tgt_eos = add_eos(tgt, eos_id=eos_id)
|
||||
|
||||
# add a blank here since tgt_eos does not start with bos
|
||||
# assume blank id is 0
|
||||
tgt_eos = add_eos(tgt_eos, eos_id=blank_id)
|
||||
|
||||
row_splits = tgt_eos.shape.row_splits(1)
|
||||
lengths = row_splits[1:] - row_splits[:-1]
|
||||
src_key_padding_mask = make_key_padding_mask(lengths)
|
||||
|
||||
tgt_eos_pad = tgt_eos.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
weight = torch.full(
|
||||
(tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
|
||||
fill_value=1,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# find unmasked positions
|
||||
unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1]
|
||||
weight[unmasked_positions] = unmasked_weight
|
||||
|
||||
# set weights for paddings
|
||||
weight[src_key_padding_mask[:, 1:]] = 0
|
||||
zeros = torch.zeros(weight.size(0), 1).to(weight)
|
||||
|
||||
weight = torch.cat((weight, zeros), dim=1)
|
||||
|
||||
# all other positions are assumed to be masked and
|
||||
# have the default weight 1
|
||||
|
||||
return (
|
||||
bos_masked_src_eos_pad,
|
||||
bos_src_eos_pad,
|
||||
tgt_eos_pad,
|
||||
src_key_padding_mask,
|
||||
weight,
|
||||
)
|
145 test/lm/test_rescore.py Executable file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lm.rescore import (
|
||||
add_bos,
|
||||
add_eos,
|
||||
compute_alignment,
|
||||
make_hyp_to_ref_map,
|
||||
make_repeat,
|
||||
make_repeat_map,
|
||||
prepare_conformer_lm_inputs,
|
||||
)
|
||||
|
||||
|
||||
def test_add_bos():
|
||||
bos_id = 100
|
||||
ragged = k2.RaggedTensor([[1, 2], [3], [0]])
|
||||
bos_ragged = add_bos(ragged, bos_id)
|
||||
expected = k2.RaggedTensor([[bos_id, 1, 2], [bos_id, 3], [bos_id, 0]])
|
||||
assert str(bos_ragged) == str(expected)
|
||||
|
||||
|
||||
def test_add_eos():
|
||||
eos_id = 30
|
||||
ragged = k2.RaggedTensor([[1, 2], [3], [], [5, 8, 9]])
|
||||
ragged_eos = add_eos(ragged, eos_id)
|
||||
expected = k2.RaggedTensor(
|
||||
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
|
||||
)
|
||||
assert str(ragged_eos) == str(expected)
|
||||
|
||||
|
||||
def test_pad():
|
||||
bos_id = 10
|
||||
eos_id = 100
|
||||
ragged = k2.RaggedTensor([[1, 2, 3], [5], [9, 8]])
|
||||
bos_ragged = add_bos(ragged, bos_id)
|
||||
bos_ragged_eos = add_eos(bos_ragged, eos_id)
|
||||
blank_id = -1
|
||||
padded = bos_ragged_eos.pad(mode="constant", padding_value=blank_id)
|
||||
expected = torch.tensor(
|
||||
[
|
||||
[bos_id, 1, 2, 3, eos_id],
|
||||
[bos_id, 5, eos_id, blank_id, blank_id],
|
||||
[bos_id, 9, 8, eos_id, blank_id],
|
||||
]
|
||||
).to(padded)
|
||||
assert torch.all(torch.eq(padded, expected))
|
||||
|
||||
|
||||
def test_make_hyp_to_ref_map():
|
||||
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
|
||||
row_splits = a.shape.row_splits(1)
|
||||
repeat_map = make_hyp_to_ref_map(row_splits)
|
||||
# fmt: off
|
||||
expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3,
|
||||
3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map) # noqa
|
||||
# fmt: on
|
||||
assert torch.all(torch.eq(repeat_map, expected))
|
||||
|
||||
|
||||
def test_make_repeat_map():
|
||||
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
|
||||
row_splits = a.shape.row_splits(1)
|
||||
repeat_map = make_repeat_map(row_splits)
|
||||
# fmt: off
|
||||
expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2,
|
||||
3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, # noqa
|
||||
3, 4, 5, 6]).to(repeat_map) # noqa
|
||||
# fmt: on
|
||||
assert torch.all(torch.eq(repeat_map, expected))
|
||||
|
||||
|
||||
def test_make_repeat():
|
||||
# fmt: off
|
||||
a = k2.RaggedTensor([
|
||||
[[1, 3, 5], [2, 6]],
|
||||
[[1, 2, 3, 4], [2], [], [9, 10, 11]],
|
||||
])
|
||||
b = make_repeat(a)
|
||||
expected = k2.RaggedTensor([
|
||||
[[1, 3, 5], [2, 6], [1, 3, 5], [2, 6]],
|
||||
[[1, 2, 3, 4], [2], [], [9, 10, 11],
|
||||
[1, 2, 3, 4], [2], [], [9, 10, 11],
|
||||
[1, 2, 3, 4], [2], [], [9, 10, 11],
|
||||
[1, 2, 3, 4], [2], [], [9, 10, 11]],
|
||||
])
|
||||
# fmt: on
|
||||
assert str(b) == str(expected)
|
||||
|
||||
|
||||
def test_compute_alignment():
|
||||
# fmt: off
|
||||
tokens = k2.RaggedTensor([
|
||||
# utt 0
|
||||
[1, 3, 5, 8], [1, 5, 8], [2, 8, 3, 2],
|
||||
# utt 1
|
||||
[2, 3], [2],
|
||||
])
|
||||
# fmt: on
|
||||
shape = k2.RaggedShape("[[x x x] [x x]]")
|
||||
alignment = compute_alignment(tokens, shape)
|
||||
(
|
||||
masked_src,
|
||||
src,
|
||||
tgt,
|
||||
src_key_padding_mask,
|
||||
weight,
|
||||
) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0)
|
||||
|
||||
# print("masked src", masked_src)
|
||||
# print("src", src)
|
||||
# print("tgt", tgt)
|
||||
# print("src_key_padding_mask", src_key_padding_mask)
|
||||
# print("weight", weight)
|
||||
|
||||
|
||||
def main():
|
||||
test_add_bos()
|
||||
test_add_eos()
|
||||
test_pad()
|
||||
test_make_repeat_map()
|
||||
test_make_hyp_to_ref_map()
|
||||
test_make_repeat()
|
||||
test_compute_alignment()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
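
# To run these tests directly:
#
#   python3 ./test/lm/test_rescore.py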
|