mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Merge cdd539e55c63d21df91bfad9dd9cccdc306842a9 into 04029871b6a54e35d08116917f88eb7d6ead2d02
This commit is contained in:
commit
6faaec22fc
@ -21,8 +21,12 @@ import warnings
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from conformer_ctc.transformer import (
|
||||||
|
Supervisions,
|
||||||
|
Transformer,
|
||||||
|
encoder_padding_mask,
|
||||||
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
|
||||||
|
|
||||||
|
|
||||||
class Conformer(Transformer):
|
class Conformer(Transformer):
|
||||||
|
@ -26,8 +26,9 @@ import k2
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from conformer_ctc.asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer_ctc.conformer import Conformer
|
||||||
|
from conformer_lm.conformer import MaskedLmConformer
|
||||||
|
|
||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
@ -37,6 +38,7 @@ from icefall.decode import (
|
|||||||
nbest_oracle,
|
nbest_oracle,
|
||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_attention_decoder,
|
rescore_with_attention_decoder,
|
||||||
|
rescore_with_conformer_lm,
|
||||||
rescore_with_n_best_list,
|
rescore_with_n_best_list,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
@ -94,7 +96,10 @@ def get_parser():
|
|||||||
is the decoding result.
|
is the decoding result.
|
||||||
- (5) attention-decoder. Extract n paths from the LM rescored
|
- (5) attention-decoder. Extract n paths from the LM rescored
|
||||||
lattice, the path with the highest score is the decoding result.
|
lattice, the path with the highest score is the decoding result.
|
||||||
- (6) nbest-oracle. Its WER is the lower bound of any n-best
|
- (6) conformer-lm. In addition to attention-decoder rescoring, it
|
||||||
|
also uses conformer lm for rescoring. See the model in the
|
||||||
|
directory ./conformer_lm
|
||||||
|
- (7) nbest-oracle. Its WER is the lower bound of any n-best
|
||||||
rescoring method can achieve. Useful for debugging n-best
|
rescoring method can achieve. Useful for debugging n-best
|
||||||
rescoring method.
|
rescoring method.
|
||||||
""",
|
""",
|
||||||
@ -106,7 +111,8 @@ def get_parser():
|
|||||||
default=100,
|
default=100,
|
||||||
help="""Number of paths for n-best based decoding method.
|
help="""Number of paths for n-best based decoding method.
|
||||||
Used only when "method" is one of the following values:
|
Used only when "method" is one of the following values:
|
||||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
nbest, nbest-rescoring, attention-decoder, conformer-lm,
|
||||||
|
and nbest-oracle
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -117,8 +123,8 @@ def get_parser():
|
|||||||
help="""The scale to be applied to `lattice.scores`.
|
help="""The scale to be applied to `lattice.scores`.
|
||||||
It's needed if you use any kinds of n-best based rescoring.
|
It's needed if you use any kinds of n-best based rescoring.
|
||||||
Used only when "method" is one of the following values:
|
Used only when "method" is one of the following values:
|
||||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
nbest, nbest-rescoring, attention-decoder, conformer_lm,
|
||||||
A smaller value results in more unique paths.
|
and nbest-oracle. A smaller value results in more unique paths.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -147,6 +153,35 @@ def get_parser():
|
|||||||
help="The lang dir",
|
help="The lang dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--conformer-lm-exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="conformer_lm/exp",
|
||||||
|
help="""The conformer lm exp dir.
|
||||||
|
Used only when method is conformer_lm.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--conformer-lm-epoch",
|
||||||
|
type=int,
|
||||||
|
default=19,
|
||||||
|
help="""Used only when method is conformer_lm.
|
||||||
|
It specifies the checkpoint to use for the conformer
|
||||||
|
lm model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--conformer-lm-avg",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Used only when method is conformer_lm.
|
||||||
|
It specifies number of checkpoints to average for
|
||||||
|
the conformer lm model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -177,6 +212,7 @@ def get_params() -> AttributeDict:
|
|||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
masked_lm_model: Optional[nn.Module],
|
||||||
HLG: Optional[k2.Fsa],
|
HLG: Optional[k2.Fsa],
|
||||||
H: Optional[k2.Fsa],
|
H: Optional[k2.Fsa],
|
||||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||||
@ -334,6 +370,7 @@ def decode_one_batch(
|
|||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
"whole-lattice-rescoring",
|
"whole-lattice-rescoring",
|
||||||
"attention-decoder",
|
"attention-decoder",
|
||||||
|
"conformer-lm",
|
||||||
]
|
]
|
||||||
|
|
||||||
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||||
@ -354,7 +391,7 @@ def decode_one_batch(
|
|||||||
G_with_epsilon_loops=G,
|
G_with_epsilon_loops=G,
|
||||||
lm_scale_list=lm_scale_list,
|
lm_scale_list=lm_scale_list,
|
||||||
)
|
)
|
||||||
elif params.method == "attention-decoder":
|
elif params.method in ("attention-decoder", "conformer-lm"):
|
||||||
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||||
rescored_lattice = rescore_with_whole_lattice(
|
rescored_lattice = rescore_with_whole_lattice(
|
||||||
lattice=lattice,
|
lattice=lattice,
|
||||||
@ -364,16 +401,32 @@ def decode_one_batch(
|
|||||||
# TODO: pass `lattice` instead of `rescored_lattice` to
|
# TODO: pass `lattice` instead of `rescored_lattice` to
|
||||||
# `rescore_with_attention_decoder`
|
# `rescore_with_attention_decoder`
|
||||||
|
|
||||||
best_path_dict = rescore_with_attention_decoder(
|
if params.method == "attention-decoder":
|
||||||
lattice=rescored_lattice,
|
best_path_dict = rescore_with_attention_decoder(
|
||||||
num_paths=params.num_paths,
|
lattice=rescored_lattice,
|
||||||
model=model,
|
num_paths=params.num_paths,
|
||||||
memory=memory,
|
model=model,
|
||||||
memory_key_padding_mask=memory_key_padding_mask,
|
memory=memory,
|
||||||
sos_id=sos_id,
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
eos_id=eos_id,
|
sos_id=sos_id,
|
||||||
nbest_scale=params.nbest_scale,
|
eos_id=eos_id,
|
||||||
)
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# It uses:
|
||||||
|
# attention_decoder + conformer_lm
|
||||||
|
best_path_dict = rescore_with_conformer_lm(
|
||||||
|
lattice=rescored_lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
model=model,
|
||||||
|
masked_lm_model=masked_lm_model,
|
||||||
|
memory=memory,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
blank_id=0, # TODO(fangjun): pass it as an argument
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert False, f"Unsupported decoding method: {params.method}"
|
assert False, f"Unsupported decoding method: {params.method}"
|
||||||
|
|
||||||
@ -393,6 +446,7 @@ def decode_dataset(
|
|||||||
dl: torch.utils.data.DataLoader,
|
dl: torch.utils.data.DataLoader,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
masked_lm_model: Optional[nn.Module],
|
||||||
HLG: Optional[k2.Fsa],
|
HLG: Optional[k2.Fsa],
|
||||||
H: Optional[k2.Fsa],
|
H: Optional[k2.Fsa],
|
||||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||||
@ -449,6 +503,7 @@ def decode_dataset(
|
|||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
masked_lm_model=masked_lm_model,
|
||||||
HLG=HLG,
|
HLG=HLG,
|
||||||
H=H,
|
H=H,
|
||||||
bpe_model=bpe_model,
|
bpe_model=bpe_model,
|
||||||
@ -584,6 +639,7 @@ def main():
|
|||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
"whole-lattice-rescoring",
|
"whole-lattice-rescoring",
|
||||||
"attention-decoder",
|
"attention-decoder",
|
||||||
|
"conformer-lm",
|
||||||
):
|
):
|
||||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||||
logging.info("Loading G_4_gram.fst.txt")
|
logging.info("Loading G_4_gram.fst.txt")
|
||||||
@ -607,7 +663,11 @@ def main():
|
|||||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
|
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
|
||||||
G = k2.Fsa.from_dict(d).to(device)
|
G = k2.Fsa.from_dict(d).to(device)
|
||||||
|
|
||||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
if params.method in [
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
"attention-decoder",
|
||||||
|
"conformer-lm",
|
||||||
|
]:
|
||||||
# Add epsilon self-loops to G as we will compose
|
# Add epsilon self-loops to G as we will compose
|
||||||
# it with the whole lattice later
|
# it with the whole lattice later
|
||||||
G = k2.add_epsilon_self_loops(G)
|
G = k2.add_epsilon_self_loops(G)
|
||||||
@ -655,6 +715,38 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
if params.method == "conformer-lm":
|
||||||
|
logging.info("Loading conformer lm model")
|
||||||
|
# Note: If the parameters does not match
|
||||||
|
# the one used to save the checkpoint, it will
|
||||||
|
# throw while calling `load_state_dict`.
|
||||||
|
masked_lm_model = MaskedLmConformer(
|
||||||
|
num_classes=num_classes,
|
||||||
|
d_model=params.attention_dim,
|
||||||
|
nhead=params.nhead,
|
||||||
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
|
)
|
||||||
|
if params.conformer_lm_avg == 1:
|
||||||
|
load_checkpoint(
|
||||||
|
f"{params.conformer_lm_exp_dir}/epoch-{params.conformer_lm_epoch}.pt", # noqa
|
||||||
|
masked_lm_model,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
start = params.conformer_lm_epoch - params.conformer_lm_avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.conformer_lm_epoch + 1):
|
||||||
|
if start >= 0:
|
||||||
|
filenames.append(
|
||||||
|
f"{params.conformer_lm_exp_dir}/epoch-{i}.pt"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
masked_lm_model.to(device)
|
||||||
|
masked_lm_model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
masked_lm_model = None
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
# CAUTION: `test_sets` is for displaying only.
|
# CAUTION: `test_sets` is for displaying only.
|
||||||
# If you want to skip test-clean, you have to skip
|
# If you want to skip test-clean, you have to skip
|
||||||
@ -668,6 +760,7 @@ def main():
|
|||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
masked_lm_model=masked_lm_model,
|
||||||
HLG=HLG,
|
HLG=HLG,
|
||||||
H=H,
|
H=H,
|
||||||
bpe_model=bpe_model,
|
bpe_model=bpe_model,
|
||||||
|
@ -24,7 +24,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from conformer_ctc.conformer import Conformer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
@ -27,7 +27,7 @@ import kaldifeat
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from conformer import Conformer
|
from conformer_ctc.conformer import Conformer
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
|
@ -17,8 +17,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from conformer_ctc.transformer import (
|
||||||
from transformer import (
|
|
||||||
Transformer,
|
Transformer,
|
||||||
add_eos,
|
add_eos,
|
||||||
add_sos,
|
add_sos,
|
||||||
@ -26,6 +25,7 @@ from transformer import (
|
|||||||
encoder_padding_mask,
|
encoder_padding_mask,
|
||||||
generate_square_subsequent_mask,
|
generate_square_subsequent_mask,
|
||||||
)
|
)
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
def test_encoder_padding_mask():
|
def test_encoder_padding_mask():
|
||||||
|
@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
from conformer_ctc.subsampling import Conv2dSubsampling, VggSubsampling
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||||
|
1
egs/librispeech/ASR/conformer_lm/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/conformer_lm/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../tdnn_lstm_ctc/asr_datamodule.py
|
1484
egs/librispeech/ASR/conformer_lm/conformer.py
Normal file
1484
egs/librispeech/ASR/conformer_lm/conformer.py
Normal file
File diff suppressed because it is too large
Load Diff
823
egs/librispeech/ASR/conformer_lm/dataset.py
Normal file
823
egs/librispeech/ASR/conformer_lm/dataset.py
Normal file
@ -0,0 +1,823 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import k2
|
||||||
|
import _k2
|
||||||
|
import logging
|
||||||
|
import sentencepiece as spm
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LmDataset(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
Torch dataset for language modeling data. This is a map-style dataset.
|
||||||
|
The indices are integers.
|
||||||
|
"""
|
||||||
|
def __init__(self, sentences: k2.RaggedTensor,
|
||||||
|
words: k2.RaggedTensor):
|
||||||
|
super(LmDataset, self).__init__()
|
||||||
|
self.sentences = sentences
|
||||||
|
self.words = words
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
# Total size on axis 0, == num sentences
|
||||||
|
return self.sentences.tot_size(0)
|
||||||
|
|
||||||
|
def __getitem__(self, i: int):
|
||||||
|
"""
|
||||||
|
Return the i'th sentence, as a list of ints (representing BPE pieces, without
|
||||||
|
bos or eos symbols).
|
||||||
|
"""
|
||||||
|
# in future will just do:
|
||||||
|
#return self.words[self.sentences[i]].tolist()
|
||||||
|
|
||||||
|
# It would be nicer if we could just return self.sentences[i].tolist(),
|
||||||
|
# but for now that operator on k2.RaggedInt does not support when the
|
||||||
|
# ragged has only 2 axes.
|
||||||
|
row_splits = self.sentences.shape.row_splits(1)
|
||||||
|
(begin, end) = row_splits[i:i+2].tolist()
|
||||||
|
sentence = self.sentences.data[begin:end]
|
||||||
|
sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
|
||||||
|
return sentence.data.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def load_train_test_lm_dataset(archive_fn: Union[str,Path],
|
||||||
|
test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]:
|
||||||
|
"""
|
||||||
|
returns (train_lm_dataset, test_lm_dataset)
|
||||||
|
"""
|
||||||
|
|
||||||
|
d = torch.load(archive_fn)
|
||||||
|
words = d['words'] # a k2.RaggedTensor with 2 axes, maps from word-ids to sequences of BPE pieces
|
||||||
|
sentences = d['data'] # a k2.RaggedTensor
|
||||||
|
|
||||||
|
with torch.random.fork_rng(devices=[]):
|
||||||
|
g = torch.manual_seed(0)
|
||||||
|
num_sentences = sentences.tot_size(0)
|
||||||
|
# probably the generator (g) argument to torch.randperm below is not necessary.
|
||||||
|
sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32)
|
||||||
|
sentences, _ = sentences.index(sentence_perm, axis=0, need_value_indexes=False)
|
||||||
|
|
||||||
|
num_test_sentences = int(num_sentences * test_proportion)
|
||||||
|
|
||||||
|
axis=0
|
||||||
|
train_sents = sentences.arange(axis, num_test_sentences, num_sentences)
|
||||||
|
test_sents = sentences.arange(axis, 0, num_test_sentences)
|
||||||
|
|
||||||
|
return LmDataset(train_sents, words), LmDataset(test_sents, words)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_and_pad(sentence: List[int],
|
||||||
|
seq_len: int,
|
||||||
|
bos_sym: int,
|
||||||
|
eos_sym: int,
|
||||||
|
blank_sym: int,
|
||||||
|
mask_proportion: float,
|
||||||
|
padding_proportion: float,
|
||||||
|
inv_mask_length: float,
|
||||||
|
unmasked_weight: float) -> Tuple[List[int], List[int], List[int], List[float]]:
|
||||||
|
"""
|
||||||
|
This function contains part of the logic of collate_fn, broken out. It is responsible
|
||||||
|
for inserting masking and padding into the sequence `sentence`. Most of the arguments
|
||||||
|
are documented for `collate_fn` below.
|
||||||
|
Other args:
|
||||||
|
sentence: The original sentence to be masked and padded.
|
||||||
|
seq_len: The desired length of the lists to be returned
|
||||||
|
bos_sym, eos_sym, blank_sym, mask_proportion,
|
||||||
|
padding_proportion, inv_mask_length, unmasked_weight: see their documentation
|
||||||
|
as args to `collate_fn` below.
|
||||||
|
|
||||||
|
|
||||||
|
Return: a tuple (src, masked_src, tgt, weight, randomizable, attn_mask), all lists of length `seq_len`,
|
||||||
|
where:
|
||||||
|
`src` is: [bos] + [the sentence after inserting blanks in place of padding
|
||||||
|
after regions to be masked] + [eos] + [blank padding to seq_len].
|
||||||
|
`src_masked` is as `src` but the masked regions have their values replaced with blank,
|
||||||
|
i.e. they are actually masked.
|
||||||
|
`tgt` is: [the original sentence, without masking] + [eos] + [blank] + [blank padding to seq_len]
|
||||||
|
`weight` is the weight at the nnet output, which is: `unmasked_weight` for un-masked
|
||||||
|
positions, 1.0 for masked and padded positions, and 0.0 for positions that
|
||||||
|
correspond to blank-padding after the final [eos].
|
||||||
|
`randomizable` is a bool that is True for positions where the symbol in
|
||||||
|
in `src_masked` is not bos or eos or blank.
|
||||||
|
`attn_mask` is a bool that is False for positions in `src` and `src_masked` that
|
||||||
|
are between the initial [bos] and final [eos] inclusive; and True for
|
||||||
|
positions after the final [eos].
|
||||||
|
"""
|
||||||
|
sent_len = len(sentence)
|
||||||
|
assert sent_len + 3 <= seq_len
|
||||||
|
|
||||||
|
for w in sentence:
|
||||||
|
assert w not in [bos_sym, eos_sym, blank_sym]
|
||||||
|
|
||||||
|
num_mask = int(torch.binomial(count=torch.tensor([sent_len * 1.0]),
|
||||||
|
prob=torch.tensor([mask_proportion])).item())
|
||||||
|
num_pad = int(torch.poisson(torch.tensor([sent_len * padding_proportion])).item())
|
||||||
|
# Ensure the total length after bos, padding of masked sequences, and eos, is
|
||||||
|
# no greater than seq_len
|
||||||
|
num_pad -= max(0, sent_len + 2 + num_pad - seq_len)
|
||||||
|
|
||||||
|
if num_mask + num_pad == 0:
|
||||||
|
num_pad += 1
|
||||||
|
|
||||||
|
# num_split_points is the number of times we split the (masked+padded)
|
||||||
|
# region, so the total number of (masking+padding) subsequences will be
|
||||||
|
# num_split_points + 1. If num_mask positions are masked, then the
|
||||||
|
# remaining number of words is `sent_len - num_mask`, and any two
|
||||||
|
# masked regions must have at least one non-masked word between them,
|
||||||
|
# so num_split_points == number of masked regions - 1, must be
|
||||||
|
# no greater than `sent_len - num_mask`. The formula about
|
||||||
|
# mask_proportion * inv_mask_length / (1.0 - mask_proportion)
|
||||||
|
# is what's required (I think) so that inv_mask_length is the expected
|
||||||
|
# length of masked regions.
|
||||||
|
num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
|
||||||
|
prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
|
||||||
|
# Somehow this assertion failed, debugging it below.
|
||||||
|
assert num_split_points <= sent_len - num_mask
|
||||||
|
|
||||||
|
assert isinstance(num_split_points, int)
|
||||||
|
|
||||||
|
def split_into_subseqs(length: int , num_subseqs: int) -> List[int]:
|
||||||
|
"""Splits a sequence of `length` items into `num_subseqs` possibly-empty
|
||||||
|
subsequences. The length distributions are geometric, not Poisson, i.e.
|
||||||
|
we choose the split locations with uniform probability rather than
|
||||||
|
randomly assigning each word to one subsequences. This gives us more
|
||||||
|
shorter/longer subsequences.
|
||||||
|
Require num_subseqs > 0
|
||||||
|
"""
|
||||||
|
boundaries = [0] + sorted(torch.randint(low=0, high=length + 1, size=(num_subseqs - 1,)).tolist()) + [length]
|
||||||
|
return [ boundaries[i + 1] - boundaries[i] for i in range(num_subseqs) ]
|
||||||
|
|
||||||
|
mask_lengths = split_into_subseqs(num_mask, num_split_points + 1)
|
||||||
|
pad_lengths = split_into_subseqs(num_pad, num_split_points + 1)
|
||||||
|
# mask_pad_lengths contains only the (mask, pad) length pairs for which mask + pad > 0.
|
||||||
|
# From this point we only refer to the mask_pad_lengths.
|
||||||
|
mask_pad_lengths = [ (mask, pad) for (mask, pad) in zip(mask_lengths, pad_lengths) if mask+pad > 0 ]
|
||||||
|
num_subseqs = len(mask_pad_lengths)
|
||||||
|
assert num_subseqs > 0
|
||||||
|
|
||||||
|
# Now figure out how to distribute these subsequences throughout the actual
|
||||||
|
# sentence. The subsequences, if there are more than one, must not touch,
|
||||||
|
# i.e. there must be an actual word in between each subsequence, where the
|
||||||
|
# number of such "mandatory" words equals num_subseqs - 1. We also have to
|
||||||
|
# subtract `num_mask` words, since obviously the masked words cannot separate
|
||||||
|
# the masked regions.
|
||||||
|
reduced_len = sent_len - num_mask - (num_subseqs - 1)
|
||||||
|
assert reduced_len >= 0
|
||||||
|
# unmasked_lengths will be the lengths of the un-masked regions between the masked
|
||||||
|
# regions.
|
||||||
|
unmasked_lengths = split_into_subseqs(reduced_len, num_subseqs + 1)
|
||||||
|
for i in range(1, num_subseqs):
|
||||||
|
# Unmasked regions between masked regions must have length at least 1,
|
||||||
|
# we add 1 to unmasked regions that are not initial/final.
|
||||||
|
unmasked_lengths[i] = unmasked_lengths[i] + 1
|
||||||
|
assert sum(unmasked_lengths) + sum(mask_lengths) == sent_len
|
||||||
|
|
||||||
|
|
||||||
|
# src_positions will be: for each position in the masked+padded sentence,
|
||||||
|
# the corresponding position in the source sentence `sentence`; or -1
|
||||||
|
# if this was padding.
|
||||||
|
src_positions = []
|
||||||
|
# `masked` will be: for each position in the masked+padded sentence, True if
|
||||||
|
# it was masked and False otherwise. (Note: it is False for padding
|
||||||
|
# locations, although this will not matter in the end).
|
||||||
|
masked = []
|
||||||
|
|
||||||
|
cur_pos = 0 # current position in source sentence
|
||||||
|
for i in range(num_subseqs + 1):
|
||||||
|
for j in range(unmasked_lengths[i]):
|
||||||
|
src_positions.append(cur_pos)
|
||||||
|
masked.append(False)
|
||||||
|
cur_pos += 1
|
||||||
|
if i < num_subseqs:
|
||||||
|
(mask_len, pad_len) = mask_pad_lengths[i]
|
||||||
|
for j in range(mask_len):
|
||||||
|
src_positions.append(cur_pos)
|
||||||
|
masked.append(True)
|
||||||
|
cur_pos += 1
|
||||||
|
for j in range(pad_len):
|
||||||
|
src_positions.append(-1)
|
||||||
|
masked.append(False)
|
||||||
|
assert cur_pos == len(sentence)
|
||||||
|
|
||||||
|
|
||||||
|
src = []
|
||||||
|
src_masked = []
|
||||||
|
tgt = []
|
||||||
|
weight = []
|
||||||
|
randomizable = []
|
||||||
|
|
||||||
|
src.append(bos_sym)
|
||||||
|
src_masked.append(bos_sym)
|
||||||
|
randomizable.append(False)
|
||||||
|
for i, src_pos in enumerate(src_positions):
|
||||||
|
is_masked = masked[i]
|
||||||
|
if src_pos >= 0:
|
||||||
|
src_word = sentence[src_pos]
|
||||||
|
src_masked.append(blank_sym if masked[i] else src_word)
|
||||||
|
src.append(src_word)
|
||||||
|
tgt.append(src_word)
|
||||||
|
weight.append(1.0 if masked[i] else unmasked_weight)
|
||||||
|
randomizable.append(not masked[i])
|
||||||
|
else:
|
||||||
|
# Padding inside a masked region
|
||||||
|
src_masked.append(blank_sym)
|
||||||
|
src.append(blank_sym)
|
||||||
|
tgt.append(blank_sym)
|
||||||
|
weight.append(1.0)
|
||||||
|
randomizable.append(False)
|
||||||
|
src.append(eos_sym)
|
||||||
|
src_masked.append(eos_sym)
|
||||||
|
tgt.append(eos_sym)
|
||||||
|
weight.append(unmasked_weight)
|
||||||
|
tgt.append(blank_sym)
|
||||||
|
weight.append(0.0)
|
||||||
|
randomizable.append(False)
|
||||||
|
|
||||||
|
attn_mask = ([False] * len(src)) + ([True] * (seq_len - len(src)))
|
||||||
|
|
||||||
|
for i in range(seq_len - len(src)):
|
||||||
|
src.append(blank_sym)
|
||||||
|
src_masked.append(blank_sym)
|
||||||
|
tgt.append(blank_sym)
|
||||||
|
weight.append(0.0)
|
||||||
|
randomizable.append(False)
|
||||||
|
|
||||||
|
return (src, src_masked, tgt, weight, randomizable, attn_mask)
|
||||||
|
|
||||||
|
|
||||||
|
# dataset.mask_and_pad(list(range(10, 20)), seq_len=16, bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, inv_mask_length=0.33, unmasked_weight=0.444)
|
||||||
|
|
||||||
|
# dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)
|
||||||
|
|
||||||
|
def collate_fn(sentences: List[List[int]],
|
||||||
|
bos_sym: int,
|
||||||
|
eos_sym: int,
|
||||||
|
blank_sym: int,
|
||||||
|
mask_proportion: float = 0.15,
|
||||||
|
padding_proportion: float = 0.15,
|
||||||
|
randomize_proportion: float = 0.05,
|
||||||
|
inv_mask_length: float = 0.25,
|
||||||
|
unmasked_weight: float = 0.25,
|
||||||
|
debug: bool = False) -> Tuple[torch.Tensor, torch.Tensor,
|
||||||
|
torch.Tensor, torch.Tensor,
|
||||||
|
torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Caution, this is not the collate_fn we give directly to the dataloader,
|
||||||
|
we give it a lambda: collate_fn=(lambda x: dataset.collate_fn(x, [other args]))
|
||||||
|
This formats a list-of-lists-of-int into 5 Tensors, explained below.
|
||||||
|
The key thing is that we mask out subsequences of random length within
|
||||||
|
these sentences, and force the network to predict the masked-out
|
||||||
|
subsequences (which have blanks appended to them to prevent the model
|
||||||
|
from knowing the exact length of the sequences it has to predict).
|
||||||
|
So it's like BERT but at the level of sequences rather than individual
|
||||||
|
words.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bos_sym: the integer id of the beginning-of-sentence symbol, e.g. 2.
|
||||||
|
Is allowed be the same as eos_sym (we are not necessarily
|
||||||
|
saying it will work best that way).
|
||||||
|
eos_sym: the integer id of the end-of-sentence symbol, e.g. 2.
|
||||||
|
blank_sym: the integer id of the blank symbol, e.g. 0 or 1.
|
||||||
|
mask_proportion: The proportion of words in each sentence that
|
||||||
|
are masked, interpreted as (roughly) the probability of any given
|
||||||
|
word being masked, although the masked locations will
|
||||||
|
tend to be in contiguous sequences (they are not independent).
|
||||||
|
padding_proportion: Like mask_proportion, but determines the
|
||||||
|
number of extra, blank symbols that are inserted as padding
|
||||||
|
at the end of masked regions (this ensures that the model
|
||||||
|
cannot know exactly how many words need to be inserted in
|
||||||
|
any given masked region.
|
||||||
|
randomize_proportion: The probability with which we replace
|
||||||
|
words that were not masked with randomly chosen words.
|
||||||
|
Like BERT, this is intended to force the model to predict
|
||||||
|
something reasonable at non-masked positions, and to make
|
||||||
|
this task harder than simply repeating the input.
|
||||||
|
inv_mask_length: This number determines how many separate
|
||||||
|
sub-sequences the (masked + padded) proportion of a sentence is split up
|
||||||
|
into, interpreted as the inverse of the expected length of
|
||||||
|
each *masked* region.
|
||||||
|
unmasked_weight: The weight to be applied to the log-likelihoods of
|
||||||
|
un-masked positions in sentences (predicting un-masked
|
||||||
|
positions is not completely trivial if randomize_proportion > 0).
|
||||||
|
Will be reflected in the returned tgt_weights tensor.
|
||||||
|
|
||||||
|
Returns a tuple (masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask,
|
||||||
|
tgt_weights),
|
||||||
|
all with 2 axes and the same shape: (num_sent, seq_len).
|
||||||
|
Their dtypes will be, respectively,
|
||||||
|
(torch.int64, torch.int64,
|
||||||
|
torch.int64, torch.bool,
|
||||||
|
torch.float)
|
||||||
|
masked_src_symbols: The sentences, with bos_symbol prepended and eos_symbol
|
||||||
|
appended, masked regions (including padding) replaced with blank,
|
||||||
|
and `randomize_proportion` non-masked symbols replaced with
|
||||||
|
symbols randomly taken from elsewhere in the sentences of this
|
||||||
|
minibatch. Then padded to a fixed length with blank.
|
||||||
|
src_symbols: Like masked_src_symbols, except with the masked symbols replaced
|
||||||
|
with the original symbols (but the padding that follows each
|
||||||
|
masked sub-sequence will still be blank)
|
||||||
|
tgt_symbols: The original sentences, with eos_symbol appended, and then
|
||||||
|
padded with blank to the same length as masked_symbols and
|
||||||
|
src_symbols.
|
||||||
|
src_key_padding_mask: Masking tensor for masked_src_symbols and src_symbols, to
|
||||||
|
account for all the sentence lengths not being identical
|
||||||
|
(makes each sentence's processing independent of seq_len).
|
||||||
|
Tensor of Bool of shape (num_sent, seq_len), with True
|
||||||
|
for masked positions (these are the blanks that follow the
|
||||||
|
eos_symbol in masked_src_symbols), False for un-masked positions.
|
||||||
|
tgt_weights: Weights that will be applied to the log-probabilities at
|
||||||
|
the output of the network. Will have 1.0 in positions
|
||||||
|
in `tgt_symbols` that were masked (including blank
|
||||||
|
padding at the end of masked regions), `unmasked_weight`
|
||||||
|
in other positions in the original sentences (including
|
||||||
|
terminating eos_symbol); and 0.0 in the remaining positions
|
||||||
|
corresponding to blank padding after the ends of
|
||||||
|
sentences.
|
||||||
|
"""
|
||||||
|
assert blank_sym not in [bos_sym, eos_sym]
|
||||||
|
max_sent_len = max([ len(s) for s in sentences])
|
||||||
|
#logging.info(f"Sentence lengths: {[ len(s) for s in sentences]}")
|
||||||
|
|
||||||
|
typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))
|
||||||
|
|
||||||
|
# The following formula gives roughly 1 standard deviation above where we'd
|
||||||
|
# expect the maximum sentence length to be with masking and padding.. we use
|
||||||
|
# this as a hard upper limit, to prevent outliers from affecting the batch
|
||||||
|
# size too much. We use this as the size `seq_len`.
|
||||||
|
# The "+ 4" is to ensure there is always room for the BOS, EOS and at least
|
||||||
|
# two padding symbols.
|
||||||
|
seq_len = max_sent_len + 4 + typical_mask_and_pad + int(typical_mask_and_pad ** 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
# srcs, srcs_masked, tgts and weights will be lists of the lists returned
|
||||||
|
# from `mask_and_pad`, one per sentence.
|
||||||
|
srcs = []
|
||||||
|
srcs_masked = []
|
||||||
|
tgts = []
|
||||||
|
weights = []
|
||||||
|
randomizables = []
|
||||||
|
attn_masks = []
|
||||||
|
for s in sentences:
|
||||||
|
(src, src_masked, tgt,
|
||||||
|
weight, randomizable,
|
||||||
|
attn_mask) = mask_and_pad(s, seq_len, bos_sym, eos_sym,
|
||||||
|
blank_sym, mask_proportion, padding_proportion,
|
||||||
|
inv_mask_length, unmasked_weight)
|
||||||
|
srcs.append(src)
|
||||||
|
srcs_masked.append(src_masked)
|
||||||
|
tgts.append(tgt)
|
||||||
|
weights.append(weight)
|
||||||
|
randomizables.append(randomizable)
|
||||||
|
attn_masks.append(attn_mask)
|
||||||
|
|
||||||
|
src_symbols = torch.tensor(srcs, dtype=torch.int64)
|
||||||
|
masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64)
|
||||||
|
tgt_symbols = torch.tensor(tgts, dtype=torch.int64)
|
||||||
|
src_key_padding_mask = torch.tensor(attn_masks, dtype=torch.bool)
|
||||||
|
tgt_weights = torch.tensor(weights, dtype=torch.float)
|
||||||
|
|
||||||
|
attn_mask_sum = torch.sum(torch.logical_not(src_key_padding_mask), dim=0).tolist()
|
||||||
|
while attn_mask_sum[-1] == 0: # Remove always-masked positions at the endof the lists.
|
||||||
|
attn_mask_sum.pop()
|
||||||
|
if len(attn_mask_sum) < seq_len:
|
||||||
|
seq_len = len(attn_mask_sum)
|
||||||
|
(src_symbols, masked_src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
|
||||||
|
tgt_symbols[:,:seq_len], src_key_padding_mask[:,:seq_len],
|
||||||
|
tgt_weights[:,:seq_len])
|
||||||
|
|
||||||
|
if randomize_proportion > 0.0:
|
||||||
|
randomizable_tensor = torch.tensor(randomizables, dtype=torch.bool)
|
||||||
|
randomizable_indexes = torch.nonzero(randomizable_tensor) # (num_randomizable, 2)
|
||||||
|
num_randomizable = randomizable_indexes.shape[0]
|
||||||
|
|
||||||
|
to_randomize_indexes = torch.nonzero(torch.rand(num_randomizable) < randomize_proportion, as_tuple=True)[0]
|
||||||
|
num_to_randomize = to_randomize_indexes.numel()
|
||||||
|
|
||||||
|
# older versions of torch don't have tensor_split, so fake a simplified version of it.
|
||||||
|
# we'd be calling it as xxx.tensor_split(dim=1) if really in torc.
|
||||||
|
def tensor_split(t):
|
||||||
|
return (t[:,0], t[:,1])
|
||||||
|
|
||||||
|
random_src_locations = torch.randperm(num_randomizable)[:num_to_randomize]
|
||||||
|
|
||||||
|
random_symbols = src_symbols[tensor_split(randomizable_indexes[random_src_locations])]
|
||||||
|
random_indexes_tuple= tensor_split(randomizable_indexes[to_randomize_indexes])
|
||||||
|
src_symbols[random_indexes_tuple] = random_symbols
|
||||||
|
masked_src_symbols[random_indexes_tuple] = random_symbols
|
||||||
|
|
||||||
|
|
||||||
|
# I set this to true and tested with:
|
||||||
|
# python3 -c 'import dataset; dataset.collate_fn(sentences=[ list(range(100, 200)), list(range(300, 450)), list(range(500,600))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)'
|
||||||
|
#.. and ran a few times to check the values printed looked about right, and that no assertions failed.
|
||||||
|
if debug:
|
||||||
|
check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym,
|
||||||
|
unmasked_weight,
|
||||||
|
masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask, tgt_weights)
|
||||||
|
return (masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask, tgt_weights)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def check_collated_tensors(sentences: List[List[int]],
|
||||||
|
bos_sym: int,
|
||||||
|
eos_sym: int,
|
||||||
|
blank_sym: int,
|
||||||
|
unmasked_weight: float,
|
||||||
|
masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask,
|
||||||
|
tgt_weights):
|
||||||
|
"""
|
||||||
|
This function checks the output of collate_fn, consider it test code. Please see
|
||||||
|
the documentation of collate_fn to understand the args.
|
||||||
|
"""
|
||||||
|
for t in src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights:
|
||||||
|
assert t.shape == masked_src_symbols.shape
|
||||||
|
|
||||||
|
tot_positions = src_symbols.numel()
|
||||||
|
|
||||||
|
masked_src_symbols, src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights = (
|
||||||
|
masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(),
|
||||||
|
src_key_padding_mask.tolist(), tgt_weights.tolist())
|
||||||
|
assert len(sentences) == len(masked_src_symbols)
|
||||||
|
|
||||||
|
tot_masked_positions = 0
|
||||||
|
tot_padded_positions = 0
|
||||||
|
tot_unmasked_positions = 0 # all un-masked, non-blank postions, including eos
|
||||||
|
tot_randomized_positions = 0
|
||||||
|
num_masked_subseqs = 0
|
||||||
|
tot_symbols = 0 # original symbols in sentences, no bos/eos
|
||||||
|
|
||||||
|
assert unmasked_weight > 0.001 # or this test code won't work..
|
||||||
|
|
||||||
|
for i in range(len(sentences)):
|
||||||
|
reconstructed_sent = list(filter(lambda x: x not in [bos_sym,eos_sym,blank_sym], tgt_symbols[i]))
|
||||||
|
if sentences[i] != reconstructed_sent:
|
||||||
|
print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}")
|
||||||
|
(masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i],
|
||||||
|
tgt_symbols[i], src_key_padding_mask[i], tgt_weights[i])
|
||||||
|
|
||||||
|
assert src[0] == masked_src[0] == bos_sym
|
||||||
|
for j in range(len(masked_src)):
|
||||||
|
assert masked_src[j] == blank_sym or masked_src[j] == src[j]
|
||||||
|
|
||||||
|
if src[j] not in [bos_sym, eos_sym, blank_sym]:
|
||||||
|
tot_symbols += 1
|
||||||
|
|
||||||
|
if j > 0:
|
||||||
|
assert (src[j] == eos_sym) == (masked_src[j] == eos_sym) == (tgt[j-1] == eos_sym)
|
||||||
|
if masked_src[j] == blank_sym: # masked or padding of masked subseq, or post-eos padding..
|
||||||
|
assert src[j] == tgt[j - 1] # masked symbols are not randomized.
|
||||||
|
assert weights[j - 1] in [0.0, 1.0] # 0.0 for final blank padding
|
||||||
|
if weights[j - 1] == 1.0: # Not final blank padding...
|
||||||
|
if tgt[j - 1] == blank_sym:
|
||||||
|
tot_padded_positions += 1
|
||||||
|
else:
|
||||||
|
tot_masked_positions += 1
|
||||||
|
if masked_src[j + 1] != blank_sym:
|
||||||
|
num_masked_subseqs += 1
|
||||||
|
else:
|
||||||
|
assert weights[j - 1] == 0 or abs(weights[j-1] - unmasked_weight) < 0.001
|
||||||
|
if abs(weights[j - 1]-unmasked_weight) < 0.001:
|
||||||
|
tot_unmasked_positions += 1
|
||||||
|
if tgt[j - 1] != src[j]:
|
||||||
|
tot_randomized_positions += 1
|
||||||
|
|
||||||
|
if src_mask[j]: # if masked..
|
||||||
|
assert src[j] == blank_sym
|
||||||
|
|
||||||
|
assert tot_symbols == sum(len(x) for x in sentences)
|
||||||
|
|
||||||
|
assert tot_unmasked_positions + tot_masked_positions == tot_symbols + len(sentences)
|
||||||
|
|
||||||
|
print(f"{tot_unmasked_positions} + {tot_masked_positions} == {tot_symbols} + {len(sentences)}")
|
||||||
|
print(f"tot_symbols / tot_positions = {tot_symbols/tot_positions} (rest is bos,eos,padding)")
|
||||||
|
|
||||||
|
print(f"Masking/tot_symbols = {tot_masked_positions/tot_symbols}, Padding/tot_symbols = {tot_padded_positions/tot_symbols}")
|
||||||
|
print(f"Randomization/tot_non_masked_symbols = {tot_randomized_positions/(tot_symbols-tot_masked_positions)}")
|
||||||
|
print(f"Mean masking length = {tot_masked_positions/num_masked_subseqs}, Mean padding length = {tot_padded_positions/num_masked_subseqs}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# This shows some useful code about the BPE encoding.
|
||||||
|
# import sentencepiece as spm
|
||||||
|
# sp = spm.SentencePieceProcessor()
|
||||||
|
# sp.load(bpe_model_fn) # bpe.model
|
||||||
|
# sp.GetPieceSize(..)
|
||||||
|
# sp.Decode(...)
|
||||||
|
# sp.Encode(...)
|
||||||
|
|
||||||
|
|
||||||
|
# import dataset
|
||||||
|
# import torch
|
||||||
|
# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
||||||
|
|
||||||
|
|
||||||
|
# train_dl = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True, collate_fn=(lambda x: train.collate_fn(x)))
|
||||||
|
# x = iter(train_dl)
|
||||||
|
# str(next(x))
|
||||||
|
# '[ [ 10 38 651 593 3 1343 31 780 6 4172 112 788 1696 24 289 24 3 403 6 4493 162 92 71 328 417 217 338 14 5 3 1876 154 21 23 2237 43 3 1535 92 71 2816 7 1031 31 2318 92 2528 4806 14 206 3 954 1373 6 525 4 631 447 2639 ] [ 1014 336 171 209 795 10 16 90 27 787 139 53 45 2817 ] [ 11 980 51 22 1748 14 91 105 363 428 6 8 2887 3305 2525 2297 70 3 4651 6 27 282 335 426 134 292 5 193 3 539 2250 584 127 ] [ 9 3 1858 4 18 2257 4 6 41 748 10 304 7 229 83 2793 4 9 981 7 1484 33 3 103 7 539 5 477 3195 18 64 39 82 1034 6 3 4128 ] [ 17 147 22 7 708 60 133 174 105 4111 4 6 3 1384 65 50 1051 9 2953 6 3 461 180 1142 23 5 36 888 8 131 173 390 78 23 266 2822 715 46 182 65 22 1739 33 3 700 1450 14 233 4 ] [ 80 10 16 67 279 7 1827 264 96 3 187 2851 2108 ] [ 1473 48 106 227 9 160 2011 4 674 ] [ 3 954 762 29 85 228 33 8 940 40 4952 36 486 390 595 3 81 225 6 1440 125 346 134 296 126 419 1017 3824 4 8 179 184 11 33 580 1861 ] [ 30 22 245 15 117 8 2892 28 1204 145 7 3 236 3417 6 3 3839 5 3106 155 198 30 228 2555 46 15 32 41 747 72 9 25 977 ] [ 222 466 6 3157 ] ]'
|
||||||
|
#
|
||||||
|
# or:
|
||||||
|
# import k2
|
||||||
|
# k2.ragged.to_list(next(x))
|
||||||
|
# [shows something similar].
|
||||||
|
#
|
||||||
|
# You'd really do something like:
|
||||||
|
# for epoch in range(max_epochs):
|
||||||
|
# for minibatch in train_dl:
|
||||||
|
|
||||||
|
|
||||||
|
# .. How to process data? Suppose we have a sentence like [259, 278, 45, 11, 303, 1319, 34, 15, 396, 3435, 7, 44].
|
||||||
|
#
|
||||||
|
# First: we randomly choose one or more starting positins for a masked segment.
|
||||||
|
# Each sentence must have at least one masked segment (or there is no contribution to the loss function).
|
||||||
|
# We choose to have:
|
||||||
|
# num_masked_segments = max(1, len(sent) // 15)
|
||||||
|
#
|
||||||
|
# The length of the masked segment (this is the target for prediction), we set to the geometric
|
||||||
|
# distribution with the probability of success set to 3:
|
||||||
|
#
|
||||||
|
# g = torch.distributions.geometric.Geometric(probs=0.3) # <-- expected value is 3.333
|
||||||
|
# Example of sampling:
|
||||||
|
# g.sample(sample_shape=torch.Size([10]))
|
||||||
|
#
|
||||||
|
# We now we randomly compute the location of the masked segments (length computed above) as follows:
|
||||||
|
# First, the masked segments must be separated by at least one non-masked word (else they would be
|
||||||
|
# a single segment). So for n masked segments, there are n-1 words required for minimal separation.
|
||||||
|
# If tot-length-of-segments + n-1 is greater than the sentence length, we just have the entire
|
||||||
|
# sentence be masked. Otherwise, we randomly divide the remaining number of words between the n+1
|
||||||
|
# positions where they can appear (e.g. for 2 segments, this would be at the start, between the 2 segments,
|
||||||
|
# and at the end). This is the multinomial distribution, but we can more easily compute this
|
||||||
|
# directly using rand() and cutoffs, rather than creating a torch.distributions.Multinomial().
|
||||||
|
#
|
||||||
|
|
||||||
|
# Next we need to compute a random amount of blank padding (>= 0) for each of the masked regions;
|
||||||
|
# this is done so the model never knows the exact length of the masked region. We can just use the
|
||||||
|
# same distribution as for the length of the masked regions, i.e. geometric with success-prob=0.3
|
||||||
|
# (expected padding length is 3).
|
||||||
|
#
|
||||||
|
# At this point we know where the masked regions are and how much padding they have. We can format
|
||||||
|
# the result as three lists, of the same length:
|
||||||
|
#
|
||||||
|
# sent: contains the words in the sentence with, in masked
|
||||||
|
# positions, the original (target) words, then with
|
||||||
|
# blank in the blank-padding after masked positions.
|
||||||
|
#
|
||||||
|
# sent_augmented: `sent` with, at a small defined percentage of positions
|
||||||
|
# that were *not* masked, the real token replaced with a
|
||||||
|
# token randomly chosen from the tokens in the minibatch.
|
||||||
|
# (like BERT, we use this type of augmentation, so the model
|
||||||
|
# has to predict the original token).
|
||||||
|
#
|
||||||
|
# masked_sent_augmented: List[int], contains the words in `sent_augmented`, except
|
||||||
|
# with masked positions and the blank padding after the masked regions
|
||||||
|
# both replaced with blank.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# The way these will be processed is as follows:
|
||||||
|
#
|
||||||
|
# masked_sent_in = [bos] + masked_sent_augmented + [eos] <-- so we know the sentence ended, distinguish it from truncated ones.
|
||||||
|
# sent_in = [bos] + sent_augmented + [eos]
|
||||||
|
#
|
||||||
|
# sent_out = sent + [eos] + [eos] #<--- the predicted targets at each point, although
|
||||||
|
# # we only really care about this in masked regions.
|
||||||
|
# # The extra eos is so that the length is the same as
|
||||||
|
# # masked_sent_in and sent_in.
|
||||||
|
#
|
||||||
|
# out_scale = (masked_sent==blk ? 1.0 : non_masked_scale) # e.g. non_masked_scale = 1.0 is fine,
|
||||||
|
# # this is a choice; we can perhaps
|
||||||
|
# # report these 2 parts of the loss
|
||||||
|
# # separately though.
|
||||||
|
# # <-- can also set the last element
|
||||||
|
# # of out_scale to a smaller number, since
|
||||||
|
# # it's a repeated eos.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# OK, how do we combine these into a minibatch? Firstly, we truncate sentences to a maximum
|
||||||
|
# length, e.g. 128, if `masked_sent_in`/`sent_in` have length longer than that. We choose randomly
|
||||||
|
# in each case to truncate the beginning or end, truncating both masked_sent_in/sent_in and sent_out
|
||||||
|
# from the same side. Caution: this means that these sentences may lack bos and/or eos symbols.
|
||||||
|
#
|
||||||
|
# Next, we combine shorter utterances by appending them ( all of: masked_sent_in, sent_in, out_scale)
|
||||||
|
# as long as doing so would keep the total length under 128. We then pad (masked_sent_in, sent_in, sent_out, out_scale)
|
||||||
|
# with: (<blk>,<blk>,<eos>, 0) up to the maximum length of any sentence in the minibatch <- or could use
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# # i.e. ones where masked_sent is blank and zeros elsewhere;
|
||||||
|
# # this pertains to positions in `sent_out`.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# torch.distributions.gamma.Gamma(concentration=1.0, rate=1.0/5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LmBatchSampler(torch.utils.data.Sampler):
|
||||||
|
"""
|
||||||
|
A sampler that returns a batch of integer indexes as a list, intended for use
|
||||||
|
with class LmDataset. The sentences returned in each batch will all be about
|
||||||
|
the same size, and the batch size is specified as a number of words (we also
|
||||||
|
provide an option that allows you to limit the max memory consumed by transformers)
|
||||||
|
|
||||||
|
Has support for distributed operation.
|
||||||
|
"""
|
||||||
|
def __init__(self, dataset: LmDataset,
|
||||||
|
symbols_per_batch: int,
|
||||||
|
length_ceil: float = 200.0,
|
||||||
|
length_floor: float = 4.0,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: int = None,
|
||||||
|
seed: int = 0,
|
||||||
|
delay_init: bool = False):
|
||||||
|
"""
|
||||||
|
Constructor documentation:
|
||||||
|
dataset: the LmDataset object that we are sampling from. This
|
||||||
|
class does not retain a reference to the LmDataset.
|
||||||
|
symbols_per_batch: The number of BPE symbols desired in each minibatch
|
||||||
|
length_floor: When the sentence length gets less than about this much,
|
||||||
|
the batch size stops increasing inversely with sentence
|
||||||
|
length. Prevent OOM on batches with short sentences.
|
||||||
|
length_ceil: After the sentence length gets more than about
|
||||||
|
this much, the batch size will start decreasing
|
||||||
|
as 1/(sentence-length^2). This is a mechanism to
|
||||||
|
avoid excessive memory consumption in transformers, when
|
||||||
|
sentence length gets long.
|
||||||
|
world_size: The world size for distributed operation; if None,
|
||||||
|
will be worked out from torch.distributed.
|
||||||
|
rank: The rank of this sampler/process for distributed operation; if None,
|
||||||
|
will be worked out from torch.distributed.
|
||||||
|
seed: The random seed
|
||||||
|
delay_init: If true, will omit calling self.set_epoch(0) at the
|
||||||
|
end of the __init__ function. In this case the caller
|
||||||
|
must call set_epoch(0). [Setting this option is necessary
|
||||||
|
to work with data-loader worker processes plus DDP, since
|
||||||
|
set_epoch() will use ddp, which I believe is a no-no prior
|
||||||
|
to initializing data-loaders.]
|
||||||
|
"""
|
||||||
|
self.seed = seed
|
||||||
|
self.symbols_per_batch = symbols_per_batch
|
||||||
|
self.length_floor = length_floor
|
||||||
|
self.quadratic_constant = 1.0 / length_ceil
|
||||||
|
self._maybe_init_distributed(world_size=world_size, rank=rank)
|
||||||
|
|
||||||
|
# a configuration constant we don't expose.
|
||||||
|
self.multiplicative_random_length = 0.05
|
||||||
|
|
||||||
|
# "indexes" is the subset of indexes into LmDataset that this
|
||||||
|
# sampler is reponsible for (all of them, in the non-distributed case).
|
||||||
|
data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32
|
||||||
|
|
||||||
|
word_row_splits = dataset.words.shape.row_splits(1) # dtype=torch.int32
|
||||||
|
word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32
|
||||||
|
|
||||||
|
# the sentences this sampler is responsible for, as sequences of words.
|
||||||
|
# It's a ragged tensor of int32
|
||||||
|
sentences, _ = dataset.sentences.index(data_indexes, axis=0)
|
||||||
|
|
||||||
|
# sentence_lengths is a k2.RaggedTensor like `sentences`, but with the words replaced
|
||||||
|
# with their respective lengths, in BPE pieces.
|
||||||
|
sentence_lengths = k2.ragged.index(word_lengths, sentences)
|
||||||
|
del sentences # save memory
|
||||||
|
assert isinstance(sentence_lengths, k2.RaggedTensor)
|
||||||
|
|
||||||
|
# convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually
|
||||||
|
# support int32.)
|
||||||
|
sentence_lengths = sentence_lengths.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
# Convert into a simple tensor of float by adding lengths of words.
|
||||||
|
sentence_lengths = sentence_lengths.sum()
|
||||||
|
|
||||||
|
assert isinstance(sentence_lengths, torch.Tensor)
|
||||||
|
assert sentence_lengths.dtype == torch.float32
|
||||||
|
|
||||||
|
# self.sentence_lengths is a Tensor with dtype=torch.float32. It
|
||||||
|
# contains the lengths, in BPE tokens, of the sentences that this
|
||||||
|
# sampler is responsible for, whose real indexes are in
|
||||||
|
# `data_indexes` above (this is not stored, as we know the formula).
|
||||||
|
self.sentence_lengths = sentence_lengths
|
||||||
|
|
||||||
|
if not delay_init:
|
||||||
|
self.set_epoch(0) # this is responsible for setting self.sorted_data_indexes
|
||||||
|
|
||||||
|
def _sync_sizes(self, device: torch.device = torch.device('cuda')):
|
||||||
|
# Calling this on all copies of a DDP setup will sync the sizes so that
|
||||||
|
# all copies have the exact same number of batches. I think
|
||||||
|
# this needs to be called with the GPU device, not sure if it would
|
||||||
|
# work otherwise.
|
||||||
|
if self.world_size > 1:
|
||||||
|
min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64)
|
||||||
|
dist.all_reduce(min_size, op=dist.ReduceOp.MIN)
|
||||||
|
min_size = min_size.to('cpu').item()
|
||||||
|
logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}")
|
||||||
|
self.batch_indices = self.batch_indices[0:min_size]
|
||||||
|
|
||||||
|
def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
|
||||||
|
if world_size is not None:
|
||||||
|
assert world_size >= 1
|
||||||
|
if rank is not None:
|
||||||
|
assert rank >= 0
|
||||||
|
if not dist.is_available() or not dist.is_initialized():
|
||||||
|
self.world_size = 1 if world_size is None else world_size
|
||||||
|
self.rank = 0 if rank is None else rank
|
||||||
|
return
|
||||||
|
self.world_size = dist.get_world_size() if world_size is None else world_size
|
||||||
|
self.rank = dist.get_rank() if rank is None else rank
|
||||||
|
assert self.rank < self.world_size
|
||||||
|
|
||||||
|
|
||||||
|
def set_epoch(self, epoch: int):
|
||||||
|
"""
|
||||||
|
Must be called at the beginning of each epoch, before initializing the DataLoader,
|
||||||
|
to re-shuffle the data. If this is not done, this sampler will give you the same batches
|
||||||
|
each time it is called.
|
||||||
|
"""
|
||||||
|
g = torch.manual_seed(self.rank + self.seed + epoch)
|
||||||
|
|
||||||
|
sentence_lengths = (self.sentence_lengths *
|
||||||
|
(1.0 + torch.rand(*self.sentence_lengths.shape, generator=g) * self.multiplicative_random_length))
|
||||||
|
|
||||||
|
# This mechanism regulates the batch size so that we don't get OOM in transformers
|
||||||
|
# when the sentences are long.
|
||||||
|
sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor
|
||||||
|
|
||||||
|
values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64
|
||||||
|
|
||||||
|
# map to the original indexes into the dataset (the original sentence
|
||||||
|
# indexes), see torch.arange expression in the constructor. save as
|
||||||
|
# int32 just to save a little memory. self.indices are indexes into the
|
||||||
|
# LmDataset, just including the subset of indices that this sampler is
|
||||||
|
# responsible for (in terms of rank and world_size), and sorted by
|
||||||
|
# length with a small amount of randomization specific to the epoch.
|
||||||
|
self.indices = ((indices * self.world_size) + self.rank).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
# now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
|
||||||
|
# saying which batch each element of values/indices belongs to.
|
||||||
|
batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
|
||||||
|
batch_boundaries.add_(1)
|
||||||
|
self.batch_boundaries = torch.cat((torch.zeros(1, dtype=torch.int32), batch_boundaries), dim=0)
|
||||||
|
|
||||||
|
num_batches = self.batch_boundaries.numel() - 1
|
||||||
|
|
||||||
|
# self.batch_indices is a permutation of [0, 1, ... num_batches -
|
||||||
|
# 1]; it determines the order in which we access the batches. It's
|
||||||
|
# necessary to randomize the order of these, to avoid returning batches
|
||||||
|
# from shortest to longest sentences.
|
||||||
|
self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
|
||||||
|
self._sync_sizes()
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.batch_indices)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
"""
|
||||||
|
Iterator that yields lists of indices (i.e., integer indices into the LmDataset)
|
||||||
|
"""
|
||||||
|
for batch_idx in self.batch_indices:
|
||||||
|
batch_start = self.batch_boundaries[batch_idx].item()
|
||||||
|
batch_end = self.batch_boundaries[batch_idx + 1].item()
|
||||||
|
yield self.indices[batch_start:batch_end].tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class CollateFn:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.extra_args = kwargs
|
||||||
|
|
||||||
|
def __call__(self, sentences: List[List[int]]):
|
||||||
|
return collate_fn(sentences, **self.extra_args)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
||||||
|
# sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
|
||||||
|
# a = iter(sampler)
|
||||||
|
# print(str(next(a)))
|
||||||
|
|
||||||
|
# collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True))
|
||||||
|
# train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
|
||||||
|
# x = iter(train_dl)
|
||||||
|
# print(str(next(x)))
|
1256
egs/librispeech/ASR/conformer_lm/madam.py
Normal file
1256
egs/librispeech/ASR/conformer_lm/madam.py
Normal file
File diff suppressed because it is too large
Load Diff
156
egs/librispeech/ASR/conformer_lm/test_conformer.py
Normal file
156
egs/librispeech/ASR/conformer_lm/test_conformer.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# run with:
|
||||||
|
# python3 -m pytest test_conformer.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import dataset # from .
|
||||||
|
from conformer import (
|
||||||
|
RelPosTransformerDecoder,
|
||||||
|
RelPosTransformerDecoderLayer,
|
||||||
|
MaskedLmConformer,
|
||||||
|
MaskedLmConformerEncoder,
|
||||||
|
MaskedLmConformerEncoderLayer,
|
||||||
|
RelPositionMultiheadAttention,
|
||||||
|
RelPositionalEncoding,
|
||||||
|
generate_square_subsequent_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def test_rel_position_multihead_attention():
|
||||||
|
# Also tests RelPositionalEncoding
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
#pos_emb = torch.randn(1, 2*T-1, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)
|
||||||
|
|
||||||
|
|
||||||
|
def test_masked_lm_conformer_encoder_layer():
|
||||||
|
# Also tests RelPositionalEncoding
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
|
||||||
|
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||||
|
y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def test_masked_lm_conformer_encoder():
|
||||||
|
# Also tests RelPositionalEncoding
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
|
||||||
|
norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
|
||||||
|
norm=norm)
|
||||||
|
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||||
|
y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transformer_decoder_layer_rel_pos():
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
|
||||||
|
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||||
|
attn_mask = generate_square_subsequent_mask(T)
|
||||||
|
memory = torch.randn(T, N, C)
|
||||||
|
y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_transformer_decoder_rel_pos():
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
|
||||||
|
decoder_norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6, norm=decoder_norm)
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||||
|
attn_mask = generate_square_subsequent_mask(T)
|
||||||
|
memory = torch.randn(T, N, C)
|
||||||
|
y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def test_masked_lm_conformer():
|
||||||
|
|
||||||
|
num_classes = 87
|
||||||
|
d_model = 256
|
||||||
|
|
||||||
|
model = MaskedLmConformer(num_classes,d_model)
|
||||||
|
|
||||||
|
|
||||||
|
N = 31
|
||||||
|
|
||||||
|
|
||||||
|
(masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask,
|
||||||
|
tgt_weights) = dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45)), list(range(50,68))], bos_sym=1, eos_sym=2,
|
||||||
|
blank_sym=0)
|
||||||
|
|
||||||
|
# test forward() of MaskedLmConformer
|
||||||
|
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
|
||||||
|
nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
|
||||||
|
src_key_padding_mask)
|
||||||
|
print("nll = ", nll)
|
||||||
|
loss = (nll * tgt_weights).sum()
|
||||||
|
print("loss = ", loss)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_square_subsequent_mask():
|
||||||
|
s = 5
|
||||||
|
mask = generate_square_subsequent_mask(s, torch.device('cpu'))
|
||||||
|
inf = float("inf")
|
||||||
|
expected_mask = torch.tensor(
|
||||||
|
[
|
||||||
|
[0.0, -inf, -inf, -inf, -inf],
|
||||||
|
[0.0, 0.0, -inf, -inf, -inf],
|
||||||
|
[0.0, 0.0, 0.0, -inf, -inf],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, -inf],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert torch.all(torch.eq(mask, expected_mask))
|
32
egs/librispeech/ASR/conformer_lm/test_dataset.py
Normal file
32
egs/librispeech/ASR/conformer_lm/test_dataset.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import _k2
|
||||||
|
import dataset
|
||||||
|
import os
|
||||||
|
from torch import multiprocessing as mp
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
def local_collate_fn(sentences):
|
||||||
|
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
#mp.set_start_method('spawn')
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = "12344"
|
||||||
|
|
||||||
|
dist.init_process_group(backend="nccl", group_name="main",
|
||||||
|
rank=0, world_size=1)
|
||||||
|
|
||||||
|
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
||||||
|
sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
|
||||||
|
print("len(sampler) = ", len(sampler))
|
||||||
|
|
||||||
|
a = iter(sampler)
|
||||||
|
print(str(next(a)))
|
||||||
|
|
||||||
|
train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler,
|
||||||
|
collate_fn=local_collate_fn,
|
||||||
|
num_workers=2)
|
||||||
|
x = iter(train_dl)
|
||||||
|
print(str(next(x)))
|
39
egs/librispeech/ASR/conformer_lm/test_dataset_empty.py
Normal file
39
egs/librispeech/ASR/conformer_lm/test_dataset_empty.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import _k2
|
||||||
|
import dataset
|
||||||
|
from dataset import LmDataset
|
||||||
|
import os
|
||||||
|
from torch import multiprocessing as mp
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
def local_collate_fn(sentences):
|
||||||
|
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
|
||||||
|
|
||||||
|
x = _k2.RaggedInt('[[1]]') # make sure library initialized?
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
mp.set_start_method('spawn')
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = "12344"
|
||||||
|
|
||||||
|
dist.init_process_group(backend="nccl", group_name="main",
|
||||||
|
rank=0, world_size=1)
|
||||||
|
|
||||||
|
words = k2.RaggedInt('[[0][1 2]]')
|
||||||
|
sentences = k2.RaggedInt('[[1][][][][][]]')
|
||||||
|
|
||||||
|
train = LmDataset(sentences, words)
|
||||||
|
|
||||||
|
|
||||||
|
sampler = dataset.LmBatchSampler(train, symbols_per_batch=10, world_size=1, rank=0)
|
||||||
|
|
||||||
|
a = iter(sampler)
|
||||||
|
print(str(next(a)))
|
||||||
|
|
||||||
|
train_dl = torch.utils.data.DataLoader(train, batch_sampler=sampler,
|
||||||
|
collate_fn=local_collate_fn,
|
||||||
|
num_workers=0)
|
||||||
|
x = iter(train_dl)
|
||||||
|
print(str(next(x)))
|
623
egs/librispeech/ASR/conformer_lm/train.py
Executable file
623
egs/librispeech/ASR/conformer_lm/train.py
Executable file
@ -0,0 +1,623 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Daniel Povey)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import dataset # from .
|
||||||
|
import madam # from .
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
from conformer import MaskedLmConformer
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from madam import Gloam
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
setup_logger,
|
||||||
|
str2bool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--world-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of GPUs for DDP training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--master-port",
|
||||||
|
type=int,
|
||||||
|
default=12354,
|
||||||
|
help="Master port to use for DDP training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tensorboard",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Should various information be logged in tensorboard.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
"""Return a dict containing training parameters.
|
||||||
|
|
||||||
|
All training related parameters that are not passed from the commandline
|
||||||
|
is saved in the variable `params`.
|
||||||
|
|
||||||
|
Commandline options are merged into `params` after they are parsed, so
|
||||||
|
you can also access them via `params`.
|
||||||
|
|
||||||
|
Explanation of options saved in `params`:
|
||||||
|
|
||||||
|
- exp_dir: It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
|
||||||
|
- lr: It specifies the initial learning rate
|
||||||
|
|
||||||
|
- feature_dim: The model input dim. It has to match the one used
|
||||||
|
in computing features.
|
||||||
|
|
||||||
|
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
|
||||||
|
and continue training from that checkpoint.
|
||||||
|
|
||||||
|
- num_epochs: Number of epochs to train.
|
||||||
|
|
||||||
|
- num_valid_batches: Number of batches of validation data to use each
|
||||||
|
time we compute validation loss
|
||||||
|
|
||||||
|
- symbols_per_batch: Number of symbols in each batch (sampler will
|
||||||
|
choose the number of sentences to satisfy this contraint).
|
||||||
|
|
||||||
|
- best_train_loss: Best training loss so far. It is used to select
|
||||||
|
the model that has the lowest training loss. It is
|
||||||
|
updated during the training.
|
||||||
|
|
||||||
|
- best_valid_loss: Best validation loss so far. It is used to select
|
||||||
|
the model that has the lowest validation loss. It is
|
||||||
|
updated during the training.
|
||||||
|
|
||||||
|
- best_train_epoch: It is the epoch that has the best training loss.
|
||||||
|
|
||||||
|
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||||
|
|
||||||
|
- batch_idx_train: Used to writing statistics to tensorboard. It
|
||||||
|
contains number of batches trained so far across
|
||||||
|
epochs.
|
||||||
|
|
||||||
|
- log_interval: Print training loss if batch_idx % log_interval` is 0
|
||||||
|
|
||||||
|
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||||
|
|
||||||
|
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
||||||
|
|
||||||
|
"""
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
# exp_3, vs. exp_2, is using 5e-04 not 2d-04 as max learning rate.
|
||||||
|
# exp_4, vs. exp_3, is using the Gloam optimizer with
|
||||||
|
# in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
|
||||||
|
# as well as the exponential part.
|
||||||
|
# exp_6, we change the decay from 0.85 to 0.9.
|
||||||
|
"exp_dir": Path("conformer_lm/exp_6"),
|
||||||
|
"lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
|
||||||
|
"num_tokens": 5000,
|
||||||
|
"blank_sym": 0,
|
||||||
|
"bos_sym": 1,
|
||||||
|
"eos_sym": 1,
|
||||||
|
"start_epoch": 2,
|
||||||
|
"num_epochs": 20,
|
||||||
|
"num_valid_batches": 200,
|
||||||
|
"symbols_per_batch": 5000,
|
||||||
|
"best_train_loss": float("inf"),
|
||||||
|
"best_valid_loss": float("inf"),
|
||||||
|
"best_train_epoch": -1,
|
||||||
|
"best_valid_epoch": -1,
|
||||||
|
"batch_idx_train": 0,
|
||||||
|
"log_interval": 10,
|
||||||
|
"reset_interval": 200,
|
||||||
|
"valid_interval": 3000,
|
||||||
|
"beam_size": 10,
|
||||||
|
"accum_grad": 1,
|
||||||
|
"attention_dim": 512,
|
||||||
|
"nhead": 8,
|
||||||
|
"num_decoder_layers": 6,
|
||||||
|
"max_lrate": 5.0e-04
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint_if_available(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Load checkpoint from file.
|
||||||
|
|
||||||
|
If params.start_epoch is positive, it will load the checkpoint from
|
||||||
|
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||||
|
|
||||||
|
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||||
|
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||||
|
and `best_valid_loss` in `params`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
The return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The training model.
|
||||||
|
optimizer:
|
||||||
|
The optimizer that we are using.
|
||||||
|
scheduler:
|
||||||
|
The learning rate scheduler we are using.
|
||||||
|
Returns:
|
||||||
|
Return None.
|
||||||
|
"""
|
||||||
|
if params.start_epoch <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||||
|
saved_params = load_checkpoint(
|
||||||
|
filename,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
keys = [
|
||||||
|
"best_train_epoch",
|
||||||
|
"best_valid_epoch",
|
||||||
|
"batch_idx_train",
|
||||||
|
"best_train_loss",
|
||||||
|
"best_valid_loss",
|
||||||
|
]
|
||||||
|
for k in keys:
|
||||||
|
params[k] = saved_params[k]
|
||||||
|
|
||||||
|
return saved_params
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||||
|
rank: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Save model, optimizer, scheduler and training stats to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The training model.
|
||||||
|
"""
|
||||||
|
if rank != 0:
|
||||||
|
return
|
||||||
|
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||||
|
save_checkpoint_impl(
|
||||||
|
filename=filename,
|
||||||
|
model=model,
|
||||||
|
params=params,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.best_train_epoch == params.cur_epoch:
|
||||||
|
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||||
|
copyfile(src=filename, dst=best_train_filename)
|
||||||
|
|
||||||
|
if params.best_valid_epoch == params.cur_epoch:
|
||||||
|
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||||
|
copyfile(src=filename, dst=best_valid_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
model: nn.Module,
|
||||||
|
batch: Tuple,
|
||||||
|
is_training: bool,
|
||||||
|
):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Compute training or validation loss given the model and its inputs
|
||||||
|
(this corresponds to log-prob of the targets, with weighting
|
||||||
|
of 1.0 for masked subsequences
|
||||||
|
(including padding blanks), and something smaller, e.g. 0.25,
|
||||||
|
for non-masked positions (this is not totally trivial due to
|
||||||
|
a small amount of randomization of symbols).
|
||||||
|
|
||||||
|
This loss is not normalized; you can divide by batch[4].sum()
|
||||||
|
to get a normalized loss (i.e. divide by soft-count).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
Parameters for training. See :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The model for training. It is an instance of MaskedLmConformer in our case.
|
||||||
|
batch:
|
||||||
|
A batch of data, actually a tuple of 5 tensors (on the device), as returned
|
||||||
|
by collate_fn in ./dataset.py.
|
||||||
|
is_training:
|
||||||
|
True for training. False for validation. When it is True, this
|
||||||
|
function enables autograd during computation; when it is False, it
|
||||||
|
disables autograd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Returns the loss as a scalar tensor.
|
||||||
|
"""
|
||||||
|
(masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask, tgt_weights) = batch
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(is_training):
|
||||||
|
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
|
||||||
|
decoder_nll_func = model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll
|
||||||
|
tgt_nll = decoder_nll_func(memory, pos_emb, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask)
|
||||||
|
loss = (tgt_nll * tgt_weights).sum()
|
||||||
|
|
||||||
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def compute_validation_loss(
|
||||||
|
device: torch.device,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
|
world_size: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""Run the validation process. The validation loss
|
||||||
|
is saved in `params.valid_loss`.
|
||||||
|
"""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
tot_loss = 0.0
|
||||||
|
tot_frames = 0.0
|
||||||
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
|
if batch_idx == params.num_valid_batches:
|
||||||
|
break
|
||||||
|
batch = tuple(x.to(device) for x in batch)
|
||||||
|
|
||||||
|
|
||||||
|
loss = compute_loss(model, batch, is_training=False)
|
||||||
|
num_frames = batch[4].sum()
|
||||||
|
|
||||||
|
assert loss.requires_grad is False
|
||||||
|
|
||||||
|
loss_cpu = loss.detach().cpu().item()
|
||||||
|
num_frames_cpu = num_frames.cpu().item()
|
||||||
|
|
||||||
|
tot_loss += loss_cpu
|
||||||
|
tot_frames += num_frames_cpu
|
||||||
|
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
s = torch.tensor(
|
||||||
|
[tot_loss, tot_frames],
|
||||||
|
device=loss.device,
|
||||||
|
)
|
||||||
|
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||||
|
(tot_loss, tot_frames) = s.cpu().tolist()
|
||||||
|
|
||||||
|
params.valid_loss = tot_loss / tot_frames
|
||||||
|
|
||||||
|
if params.valid_loss < params.best_valid_loss:
|
||||||
|
params.best_valid_epoch = params.cur_epoch
|
||||||
|
params.best_valid_loss = params.valid_loss
|
||||||
|
|
||||||
|
|
||||||
|
def train_one_epoch(
|
||||||
|
device: torch.device,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
train_dl: torch.utils.data.DataLoader,
|
||||||
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
|
world_size: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""Train the model for one epoch.
|
||||||
|
|
||||||
|
The training loss from the mean of all frames is saved in
|
||||||
|
`params.train_loss`. It runs the validation process every
|
||||||
|
`params.valid_interval` batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device:
|
||||||
|
The device to use for training (model must be on this device)
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The model for training.
|
||||||
|
optimizer:
|
||||||
|
The optimizer we are using.
|
||||||
|
train_dl:
|
||||||
|
Dataloader for the training dataset.
|
||||||
|
valid_dl:
|
||||||
|
Dataloader for the validation dataset.
|
||||||
|
tb_writer:
|
||||||
|
Writer to write log messages to tensorboard.
|
||||||
|
world_size:
|
||||||
|
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||||
|
"""
|
||||||
|
model.train() # training mode
|
||||||
|
|
||||||
|
tot_loss = 0.0 # sum of losses over all batches
|
||||||
|
tot_frames = 0.0 # sum of frames over all batches
|
||||||
|
|
||||||
|
params.tot_loss = 0.0
|
||||||
|
params.tot_frames = 0.0
|
||||||
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
|
params.batch_idx_train += 1
|
||||||
|
batch = tuple(x.to(device) for x in batch)
|
||||||
|
|
||||||
|
try:
|
||||||
|
loss = compute_loss(
|
||||||
|
model=model,
|
||||||
|
batch=batch,
|
||||||
|
is_training=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
# We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total
|
||||||
|
# gradient scale so this should not matter.
|
||||||
|
# clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
|
optimizer.step()
|
||||||
|
except RuntimeError as e:
|
||||||
|
print(f"Error on batch of shape (N,T) = {batch[0].shape}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
loss_cpu = loss.detach().cpu().item()
|
||||||
|
num_frames_cpu = batch[4].sum().cpu().item()
|
||||||
|
|
||||||
|
tot_loss += loss_cpu
|
||||||
|
tot_frames += num_frames_cpu
|
||||||
|
|
||||||
|
params.tot_frames += num_frames_cpu
|
||||||
|
params.tot_loss += loss_cpu
|
||||||
|
|
||||||
|
tot_avg_loss = tot_loss / tot_frames
|
||||||
|
|
||||||
|
if batch_idx % params.log_interval == 0:
|
||||||
|
logging.info(
|
||||||
|
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||||
|
f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, "
|
||||||
|
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||||
|
f"batch shape: {tuple(batch[0].shape)}")
|
||||||
|
|
||||||
|
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/current_loss",
|
||||||
|
loss_cpu / num_frames_cpu,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/tot_avg_loss",
|
||||||
|
tot_avg_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
||||||
|
tot_loss = 0.0 # sum of losses over all batches
|
||||||
|
tot_frames = 0.0 # sum of frames over all batches
|
||||||
|
|
||||||
|
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||||
|
compute_validation_loss(
|
||||||
|
device=device,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
valid_dl=valid_dl,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
model.train()
|
||||||
|
logging.info(
|
||||||
|
f"Epoch {params.cur_epoch}, "
|
||||||
|
f"valid loss {params.valid_loss:.4f},"
|
||||||
|
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||||
|
f"best valid epoch: {params.best_valid_epoch}"
|
||||||
|
)
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/valid_loss",
|
||||||
|
params.valid_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
|
||||||
|
params.train_loss = params.tot_loss / params.tot_frames
|
||||||
|
|
||||||
|
if params.train_loss < params.best_train_loss:
|
||||||
|
params.best_train_epoch = params.cur_epoch
|
||||||
|
params.best_train_loss = params.train_loss
|
||||||
|
|
||||||
|
|
||||||
|
def run(rank, world_size, args):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
rank:
|
||||||
|
It is a value between 0 and `world_size-1`, which is
|
||||||
|
passed automatically by `mp.spawn()` in :func:`main`.
|
||||||
|
The node with rank 0 is responsible for saving checkpoint.
|
||||||
|
world_size:
|
||||||
|
Number of GPUs for DDP training.
|
||||||
|
args:
|
||||||
|
The return value of get_parser().parse_args()
|
||||||
|
"""
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
fix_random_seed(42)
|
||||||
|
if world_size > 1:
|
||||||
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||||
|
logging.info("Training started")
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
if args.tensorboard and rank == 0:
|
||||||
|
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||||
|
else:
|
||||||
|
tb_writer = None
|
||||||
|
|
||||||
|
num_tokens = params.num_tokens
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = MaskedLmConformer(
|
||||||
|
num_classes=params.num_tokens,
|
||||||
|
d_model=params.attention_dim,
|
||||||
|
nhead=params.nhead,
|
||||||
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
if world_size > 1:
|
||||||
|
model = DDP(model, device_ids=[rank])
|
||||||
|
|
||||||
|
# Caution: don't forget to do optimizer.set_epoch() with Gloam!
|
||||||
|
# Don't remove this warning!
|
||||||
|
optimizer = Gloam(
|
||||||
|
model.parameters(),
|
||||||
|
max_lrate=params.max_lrate,
|
||||||
|
first_decrease_epoch=1,
|
||||||
|
decay_per_epoch=0.9
|
||||||
|
)
|
||||||
|
|
||||||
|
if checkpoints:
|
||||||
|
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||||
|
|
||||||
|
train,test = dataset.load_train_test_lm_dataset(params.lm_dataset)
|
||||||
|
|
||||||
|
collate_fn=dataset.CollateFn(bos_sym=params.bos_sym,
|
||||||
|
eos_sym=params.eos_sym,
|
||||||
|
blank_sym=params.blank_sym,
|
||||||
|
mask_proportion=0.15,
|
||||||
|
padding_proportion=0.15,
|
||||||
|
randomize_proportion=0.05,
|
||||||
|
inv_mask_length=0.25,
|
||||||
|
unmasked_weight=0.25)
|
||||||
|
|
||||||
|
train_sampler = dataset.LmBatchSampler(train,
|
||||||
|
symbols_per_batch=params.symbols_per_batch,
|
||||||
|
world_size=world_size, rank=rank)
|
||||||
|
test_sampler = dataset.LmBatchSampler(test,
|
||||||
|
symbols_per_batch=params.symbols_per_batch,
|
||||||
|
world_size=world_size, rank=rank)
|
||||||
|
|
||||||
|
train_dl = torch.utils.data.DataLoader(train,
|
||||||
|
batch_sampler=train_sampler,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
valid_dl = torch.utils.data.DataLoader(test,
|
||||||
|
batch_sampler=test_sampler,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
train_sampler.set_epoch(epoch)
|
||||||
|
optimizer.set_epoch(epoch) # Caution: this is specific to the Gloam
|
||||||
|
# optimizer.
|
||||||
|
|
||||||
|
cur_lr = optimizer._rate
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||||
|
|
||||||
|
params.cur_epoch = epoch
|
||||||
|
|
||||||
|
train_one_epoch(
|
||||||
|
device=device,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
train_dl=train_dl,
|
||||||
|
valid_dl=valid_dl,
|
||||||
|
tb_writer=tb_writer,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_checkpoint(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
cleanup_dist()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
world_size = args.world_size
|
||||||
|
assert world_size >= 1
|
||||||
|
if world_size > 1:
|
||||||
|
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||||
|
else:
|
||||||
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -20,6 +20,11 @@ from typing import Dict, List, Optional, Union
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from icefall.lm.rescore import (
|
||||||
|
compute_alignment,
|
||||||
|
make_hyp_to_ref_map,
|
||||||
|
prepare_conformer_lm_inputs,
|
||||||
|
)
|
||||||
from icefall.utils import get_texts
|
from icefall.utils import get_texts
|
||||||
|
|
||||||
|
|
||||||
@ -904,3 +909,198 @@ def rescore_with_attention_decoder(
|
|||||||
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
|
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
|
||||||
ans[key] = best_path
|
ans[key] = best_path
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def rescore_with_conformer_lm(
|
||||||
|
lattice: k2.Fsa,
|
||||||
|
num_paths: int,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
masked_lm_model: torch.nn.Module,
|
||||||
|
memory: torch.Tensor,
|
||||||
|
memory_key_padding_mask: Optional[torch.Tensor],
|
||||||
|
sos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
blank_id: int,
|
||||||
|
nbest_scale: float = 1.0,
|
||||||
|
ngram_lm_scale: Optional[float] = None,
|
||||||
|
attention_scale: Optional[float] = None,
|
||||||
|
masked_lm_scale: Optional[float] = None,
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
) -> Dict[str, k2.Fsa]:
|
||||||
|
"""This function extracts `num_paths` paths from the given lattice and uses
|
||||||
|
an attention decoder to rescore them. The path with the highest score is
|
||||||
|
the decoding output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lattice:
|
||||||
|
An FsaVec with axes [utt][state][arc].
|
||||||
|
num_paths:
|
||||||
|
Number of paths to extract from the given lattice for rescoring.
|
||||||
|
model:
|
||||||
|
A transformer model. See the class "Transformer" in
|
||||||
|
conformer_ctc/transformer.py for its interface.
|
||||||
|
memory:
|
||||||
|
The encoder memory of the given model. It is the output of
|
||||||
|
the last torch.nn.TransformerEncoder layer in the given model.
|
||||||
|
Its shape is `(T, N, C)`.
|
||||||
|
memory_key_padding_mask:
|
||||||
|
The padding mask for memory with shape `(N, T)`.
|
||||||
|
sos_id:
|
||||||
|
The token ID for SOS.
|
||||||
|
eos_id:
|
||||||
|
The token ID for EOS.
|
||||||
|
nbest_scale:
|
||||||
|
It's the scale applied to `lattice.scores`. A smaller value
|
||||||
|
leads to more unique paths at the risk of missing the correct path.
|
||||||
|
ngram_lm_scale:
|
||||||
|
Optional. It specifies the scale for n-gram LM scores.
|
||||||
|
attention_scale:
|
||||||
|
Optional. It specifies the scale for attention decoder scores.
|
||||||
|
masked_lm_scale:
|
||||||
|
Optional. It specifies the scale for conformer_lm scores.
|
||||||
|
Returns:
|
||||||
|
A dict of FsaVec, whose key contains a string
|
||||||
|
ngram_lm_scale_attention_scale and the value is the
|
||||||
|
best decoding path for each utterance in the lattice.
|
||||||
|
"""
|
||||||
|
nbest = Nbest.from_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
nbest_scale=nbest_scale,
|
||||||
|
)
|
||||||
|
# nbest.fsa.scores are all 0s at this point
|
||||||
|
|
||||||
|
nbest = nbest.intersect(lattice)
|
||||||
|
# Now nbest.fsa has its scores set.
|
||||||
|
# Also, nbest.fsa inherits the attributes from `lattice`.
|
||||||
|
assert hasattr(nbest.fsa, "lm_scores")
|
||||||
|
|
||||||
|
am_scores = nbest.compute_am_scores()
|
||||||
|
ngram_lm_scores = nbest.compute_lm_scores()
|
||||||
|
|
||||||
|
# The `tokens` attribute is set inside `compile_hlg.py`
|
||||||
|
assert hasattr(nbest.fsa, "tokens")
|
||||||
|
assert isinstance(nbest.fsa.tokens, torch.Tensor)
|
||||||
|
|
||||||
|
path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
|
||||||
|
# the shape of memory is (T, N, C), so we use axis=1 here
|
||||||
|
expanded_memory = memory.index_select(1, path_to_utt_map)
|
||||||
|
|
||||||
|
if memory_key_padding_mask is not None:
|
||||||
|
# The shape of memory_key_padding_mask is (N, T), so we
|
||||||
|
# use axis=0 here.
|
||||||
|
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||||
|
0, path_to_utt_map
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
expanded_memory_key_padding_mask = None
|
||||||
|
|
||||||
|
# remove axis corresponding to states.
|
||||||
|
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
|
||||||
|
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
|
||||||
|
tokens = tokens.remove_values_leq(0)
|
||||||
|
|
||||||
|
alignment = compute_alignment(tokens, nbest.shape)
|
||||||
|
(
|
||||||
|
masked_src_symbols,
|
||||||
|
src_symbols,
|
||||||
|
tgt_symbols,
|
||||||
|
src_key_padding_mask,
|
||||||
|
tgt_weights,
|
||||||
|
) = prepare_conformer_lm_inputs(
|
||||||
|
alignment,
|
||||||
|
bos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
blank_id=blank_id,
|
||||||
|
unmasked_weight=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
masked_src_symbols = masked_src_symbols.to(torch.int64)
|
||||||
|
src_symbols = src_symbols.to(torch.int64)
|
||||||
|
tgt_symbols = tgt_symbols.to(torch.int64)
|
||||||
|
|
||||||
|
masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
|
||||||
|
masked_src_symbols, src_key_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
tgt_nll = masked_lm_model.decoder_nll(
|
||||||
|
masked_lm_memory,
|
||||||
|
masked_lm_pos_emb,
|
||||||
|
src_symbols,
|
||||||
|
tgt_symbols,
|
||||||
|
src_key_padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# nll means negative log-likelihood
|
||||||
|
# ll means log-likelihood
|
||||||
|
tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
|
||||||
|
|
||||||
|
# Note: log-likelihood for those pairs that have identical src/tgt are 0
|
||||||
|
# since their tgt_weights is 0
|
||||||
|
|
||||||
|
# TODO(fangjun): Add documentation about why we do the following
|
||||||
|
tgt_ll_shape_row_ids = make_hyp_to_ref_map(nbest.shape.row_splits(1))
|
||||||
|
tgt_ll_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=None,
|
||||||
|
row_ids=tgt_ll_shape_row_ids,
|
||||||
|
cached_tot_size=tgt_ll_shape_row_ids.numel(),
|
||||||
|
)
|
||||||
|
ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
|
||||||
|
|
||||||
|
ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
|
||||||
|
masked_lm_scores = ragged_tgt_ll.max()
|
||||||
|
|
||||||
|
# TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
|
||||||
|
token_ids = tokens.tolist()
|
||||||
|
|
||||||
|
nll = model.decoder_nll(
|
||||||
|
memory=expanded_memory,
|
||||||
|
memory_key_padding_mask=expanded_memory_key_padding_mask,
|
||||||
|
token_ids=token_ids,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
)
|
||||||
|
assert nll.ndim == 2
|
||||||
|
assert nll.shape[0] == len(token_ids)
|
||||||
|
|
||||||
|
attention_scores = -nll.sum(dim=1)
|
||||||
|
|
||||||
|
if ngram_lm_scale is None:
|
||||||
|
ngram_lm_scale_list = [0.01, 0.05, 0.08]
|
||||||
|
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||||
|
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||||
|
else:
|
||||||
|
ngram_lm_scale_list = [ngram_lm_scale]
|
||||||
|
|
||||||
|
if attention_scale is None:
|
||||||
|
attention_scale_list = [0.01, 0.05, 0.08]
|
||||||
|
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||||
|
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||||
|
else:
|
||||||
|
attention_scale_list = [attention_scale]
|
||||||
|
|
||||||
|
if masked_lm_scale is None:
|
||||||
|
masked_lm_scale_list = [0.01, 0.05, 0.08]
|
||||||
|
masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||||
|
masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||||
|
else:
|
||||||
|
masked_lm_scale_list = [masked_lm_scale]
|
||||||
|
|
||||||
|
ans = dict()
|
||||||
|
for n_scale in ngram_lm_scale_list:
|
||||||
|
for a_scale in attention_scale_list:
|
||||||
|
for m_scale in masked_lm_scale_list:
|
||||||
|
tot_scores = (
|
||||||
|
am_scores.values
|
||||||
|
+ n_scale * ngram_lm_scores.values
|
||||||
|
+ a_scale * attention_scores
|
||||||
|
+ m_scale * masked_lm_scores
|
||||||
|
)
|
||||||
|
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||||
|
max_indexes = ragged_tot_scores.argmax()
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
|
|
||||||
|
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_masked_lm_scale_{m_scale}" # noqa
|
||||||
|
ans[key] = best_path
|
||||||
|
return ans
|
||||||
|
363
icefall/lm/rescore.py
Normal file
363
icefall/lm/rescore.py
Normal file
@ -0,0 +1,363 @@
|
|||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file contains rescoring code for NN LMs, e.g., conformer LM.
|
||||||
|
|
||||||
|
Here are the ideas about preparing the inputs for the conformer LM model
|
||||||
|
from an Nbest object.
|
||||||
|
|
||||||
|
Given an Nbest object `nbest`, we have:
|
||||||
|
- nbest.fsa
|
||||||
|
- nbest.shape, whose axes are [utt][path]
|
||||||
|
|
||||||
|
We can get `tokens` from nbest.fsa. The resulting `tokens` will have
|
||||||
|
2 axes [path][token]. Note, we should remove 0s from `tokens`.
|
||||||
|
|
||||||
|
We can generate the following inputs for the conformer LM model from `tokens`:
|
||||||
|
- masked_src
|
||||||
|
- src
|
||||||
|
- tgt
|
||||||
|
by using `k2.levenshtein_alignment`.
|
||||||
|
|
||||||
|
TODO(fangjun): Add more doc about rescoring with masked conformer-lm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_key_padding_mask(lengths: torch.Tensor):
|
||||||
|
"""
|
||||||
|
TODO: add documentation
|
||||||
|
|
||||||
|
>>> make_key_padding_mask(torch.tensor([3, 1, 4]))
|
||||||
|
tensor([[False, False, False, True],
|
||||||
|
[False, True, True, True],
|
||||||
|
[False, False, False, False]])
|
||||||
|
"""
|
||||||
|
assert lengths.dim() == 1
|
||||||
|
|
||||||
|
bs = lengths.numel()
|
||||||
|
max_len = lengths.max().item()
|
||||||
|
device = lengths.device
|
||||||
|
seq_range = torch.arange(0, max_len, device=device)
|
||||||
|
seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
|
||||||
|
|
||||||
|
seq_length_expand = lengths.unsqueeze(-1)
|
||||||
|
mask = seq_range_expand >= seq_length_expand
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def concat(
|
||||||
|
ragged: k2.RaggedTensor, value: int, direction: str
|
||||||
|
) -> k2.RaggedTensor:
|
||||||
|
"""Prepend a value to the beginning of each sublist or append a value.
|
||||||
|
to the end of each sublist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ragged:
|
||||||
|
A ragged tensor with two axes.
|
||||||
|
value:
|
||||||
|
The value to prepend or append.
|
||||||
|
direction:
|
||||||
|
It can be either "left" or "right". If it is "left", we
|
||||||
|
prepend the value to the beginning of each sublist;
|
||||||
|
if it is "right", we append the value to the end of each
|
||||||
|
sublist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a new ragged tensor, whose sublists either start with
|
||||||
|
or end with the given value.
|
||||||
|
|
||||||
|
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
||||||
|
>>> a
|
||||||
|
[ [ 1 3 ] [ 5 ] ]
|
||||||
|
>>> concat(a, value=0, direction="left")
|
||||||
|
[ [ 0 1 3 ] [ 0 5 ] ]
|
||||||
|
>>> concat(a, value=0, direction="right")
|
||||||
|
[ [ 1 3 0 ] [ 5 0 ] ]
|
||||||
|
|
||||||
|
"""
|
||||||
|
dtype = ragged.dtype
|
||||||
|
device = ragged.device
|
||||||
|
|
||||||
|
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
|
||||||
|
pad_values = torch.full(
|
||||||
|
size=(ragged.tot_size(0), 1),
|
||||||
|
fill_value=value,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
pad = k2.RaggedTensor(pad_values)
|
||||||
|
|
||||||
|
if direction == "left":
|
||||||
|
ans = k2.ragged.cat([pad, ragged], axis=1)
|
||||||
|
elif direction == "right":
|
||||||
|
ans = k2.ragged.cat([ragged, pad], axis=1)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f'Unsupported direction: {direction}. " \
|
||||||
|
"Expect either "left" or "right"'
|
||||||
|
)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def add_bos(ragged: k2.RaggedTensor, bos_id: int) -> k2.RaggedTensor:
|
||||||
|
"""Add BOS to each sublist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ragged:
|
||||||
|
A ragged tensor with two axes.
|
||||||
|
bos_id:
|
||||||
|
The ID of the BOS symbol.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a new ragged tensor, where each sublist starts with BOS.
|
||||||
|
|
||||||
|
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
||||||
|
>>> a
|
||||||
|
[ [ 1 3 ] [ 5 ] ]
|
||||||
|
>>> add_bos(a, bos_id=0)
|
||||||
|
[ [ 0 1 3 ] [ 0 5 ] ]
|
||||||
|
|
||||||
|
"""
|
||||||
|
return concat(ragged, bos_id, direction="left")
|
||||||
|
|
||||||
|
|
||||||
|
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
|
||||||
|
"""Add EOS to each sublist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ragged:
|
||||||
|
A ragged tensor with two axes.
|
||||||
|
bos_id:
|
||||||
|
The ID of the EOS symbol.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a new ragged tensor, where each sublist ends with EOS.
|
||||||
|
|
||||||
|
>>> a = k2.RaggedTensor([[1, 3], [5]])
|
||||||
|
>>> a
|
||||||
|
[ [ 1 3 ] [ 5 ] ]
|
||||||
|
>>> add_eos(a, eos_id=0)
|
||||||
|
[ [ 1 3 0 ] [ 5 0 ] ]
|
||||||
|
|
||||||
|
"""
|
||||||
|
return concat(ragged, eos_id, direction="right")
|
||||||
|
|
||||||
|
|
||||||
|
def make_hyp_to_ref_map(row_splits: torch.Tensor):
|
||||||
|
"""
|
||||||
|
TODO: Add documentation.
|
||||||
|
|
||||||
|
>>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
|
||||||
|
>>> make_hyp_to_ref_map(row_splits)
|
||||||
|
tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4], dtype=torch.int32)
|
||||||
|
|
||||||
|
"""
|
||||||
|
device = row_splits.device
|
||||||
|
sizes = (row_splits[1:] - row_splits[:-1]).tolist()
|
||||||
|
offsets = row_splits[:-1]
|
||||||
|
|
||||||
|
map_tensor_list = []
|
||||||
|
for size, offset in zip(sizes, offsets):
|
||||||
|
# Explanation of the following operations
|
||||||
|
# assume size is 3, offset is 2
|
||||||
|
# torch.arange() + offset is [2, 3, 4]
|
||||||
|
# expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
|
||||||
|
# t() is [[2, 2, 2], [3, 3, 3], [4, 4, 4]]
|
||||||
|
# reshape() is [2, 2, 2, 3, 3, 3, 4, 4, 4]
|
||||||
|
map_tensor = (
|
||||||
|
(torch.arange(size, dtype=torch.int32, device=device) + offset)
|
||||||
|
.expand(size, size)
|
||||||
|
.t()
|
||||||
|
.reshape(-1)
|
||||||
|
)
|
||||||
|
map_tensor_list.append(map_tensor)
|
||||||
|
|
||||||
|
return torch.cat(map_tensor_list)
|
||||||
|
|
||||||
|
|
||||||
|
def make_repeat_map(row_splits: torch.Tensor):
|
||||||
|
"""
|
||||||
|
TODO: Add documentation.
|
||||||
|
|
||||||
|
>>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
|
||||||
|
>>> make_repeat_map(row_splits)
|
||||||
|
tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 3, 4], dtype=torch.int32)
|
||||||
|
|
||||||
|
"""
|
||||||
|
device = row_splits.device
|
||||||
|
sizes = (row_splits[1:] - row_splits[:-1]).tolist()
|
||||||
|
offsets = row_splits[:-1]
|
||||||
|
|
||||||
|
map_tensor_list = []
|
||||||
|
for size, offset in zip(sizes, offsets):
|
||||||
|
# Explanation of the following operations
|
||||||
|
# assume size is 3, offset is 2
|
||||||
|
# torch.arange() + offset is [2, 3, 4]
|
||||||
|
# expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
|
||||||
|
# reshape() is [2, 3, 4, 2, 3, 4, 2, 3, 4]
|
||||||
|
map_tensor = (
|
||||||
|
(torch.arange(size, dtype=torch.int32, device=device) + offset)
|
||||||
|
.expand(size, size)
|
||||||
|
.reshape(-1)
|
||||||
|
)
|
||||||
|
map_tensor_list.append(map_tensor)
|
||||||
|
|
||||||
|
return torch.cat(map_tensor_list)
|
||||||
|
|
||||||
|
|
||||||
|
def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
|
||||||
|
"""Repeat the number of paths of an utterance to the number that
|
||||||
|
equals to the number of paths in the utterance.
|
||||||
|
|
||||||
|
For instance, if an utterance contains 3 paths: [path1 path2 path3],
|
||||||
|
after repeating, this utterance will contain 9 paths:
|
||||||
|
[path1 path2 path3] [path1 path2 path3] [path1 path2 path3]
|
||||||
|
|
||||||
|
>>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ])
|
||||||
|
>>> tokens
|
||||||
|
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ]
|
||||||
|
>>> make_repeat(tokens)
|
||||||
|
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ] # noqa
|
||||||
|
|
||||||
|
TODO: Add documentation.
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}"
|
||||||
|
if True:
|
||||||
|
indexes = make_repeat_map(tokens.shape.row_splits(1))
|
||||||
|
return tokens.index(axis=1, indexes=indexes)[0]
|
||||||
|
else:
|
||||||
|
# This branch produces the same result as the above branch.
|
||||||
|
# It's more readable. Will remove it later.
|
||||||
|
repeated = []
|
||||||
|
for p in tokens.tolist():
|
||||||
|
repeated.append(p * len(p))
|
||||||
|
return k2.RaggedTensor(repeated).to(tokens.device)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_alignment(
|
||||||
|
tokens: k2.RaggedTensor,
|
||||||
|
shape: k2.RaggedShape,
|
||||||
|
) -> k2.Fsa:
|
||||||
|
"""
|
||||||
|
TODO: Add documentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens:
|
||||||
|
A ragged tensor with two axes: [path][token].
|
||||||
|
shape:
|
||||||
|
A ragged shape with two axes: [utt][path]
|
||||||
|
"""
|
||||||
|
assert tokens.tot_size(0) == shape.tot_size(1)
|
||||||
|
device = tokens.device
|
||||||
|
utt_path_shape = shape.compose(tokens.shape)
|
||||||
|
utt_path_token = k2.RaggedTensor(utt_path_shape, tokens.values)
|
||||||
|
utt_path_token_repeated = make_repeat(utt_path_token)
|
||||||
|
path_token_repeated = utt_path_token_repeated.remove_axis(0)
|
||||||
|
|
||||||
|
refs = k2.levenshtein_graph(tokens, device=device)
|
||||||
|
hyps = k2.levenshtein_graph(path_token_repeated, device=device)
|
||||||
|
|
||||||
|
hyp_to_ref_map = make_hyp_to_ref_map(utt_path_shape.row_splits(1))
|
||||||
|
alignment = k2.levenshtein_alignment(
|
||||||
|
refs=refs, hyps=hyps, hyp_to_ref_map=hyp_to_ref_map
|
||||||
|
)
|
||||||
|
return alignment
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_conformer_lm_inputs(
|
||||||
|
alignment: k2.Fsa,
|
||||||
|
bos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
blank_id: int,
|
||||||
|
unmasked_weight: float = 0.25,
|
||||||
|
) -> Tuple[
|
||||||
|
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
TODO: add documentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alignments:
|
||||||
|
It is computed by :func:`compute_alignment`
|
||||||
|
"""
|
||||||
|
device = alignment.device
|
||||||
|
# alignment.arcs.shape has axes [fsa][state][arc]
|
||||||
|
# we remove axis 1, i.e., state, here
|
||||||
|
labels_shape = alignment.arcs.shape().remove_axis(1)
|
||||||
|
|
||||||
|
masked_src = k2.RaggedTensor(labels_shape, alignment.labels.contiguous())
|
||||||
|
masked_src = masked_src.remove_values_eq(-1)
|
||||||
|
bos_masked_src = add_bos(masked_src, bos_id=bos_id)
|
||||||
|
bos_masked_src_eos = add_eos(bos_masked_src, eos_id=eos_id)
|
||||||
|
bos_masked_src_eos_pad = bos_masked_src_eos.pad(
|
||||||
|
mode="constant", padding_value=blank_id
|
||||||
|
)
|
||||||
|
|
||||||
|
src = k2.RaggedTensor(labels_shape, alignment.hyp_labels)
|
||||||
|
src = src.remove_values_eq(-1)
|
||||||
|
bos_src = add_bos(src, bos_id=bos_id)
|
||||||
|
bos_src_eos = add_eos(bos_src, eos_id=eos_id)
|
||||||
|
bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id)
|
||||||
|
|
||||||
|
tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels)
|
||||||
|
# TODO: Do we need to remove 0s from tgt ?
|
||||||
|
tgt = tgt.remove_values_eq(-1)
|
||||||
|
tgt_eos = add_eos(tgt, eos_id=eos_id)
|
||||||
|
|
||||||
|
# add a blank here since tgt_eos does not start with bos
|
||||||
|
# assume blank id is 0
|
||||||
|
tgt_eos = add_eos(tgt_eos, eos_id=blank_id)
|
||||||
|
|
||||||
|
row_splits = tgt_eos.shape.row_splits(1)
|
||||||
|
lengths = row_splits[1:] - row_splits[:-1]
|
||||||
|
src_key_padding_mask = make_key_padding_mask(lengths)
|
||||||
|
|
||||||
|
tgt_eos_pad = tgt_eos.pad(mode="constant", padding_value=blank_id)
|
||||||
|
|
||||||
|
weight = torch.full(
|
||||||
|
(tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
|
||||||
|
fill_value=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# find unmasked positions
|
||||||
|
unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1]
|
||||||
|
weight[unmasked_positions] = unmasked_weight
|
||||||
|
|
||||||
|
# set weights for paddings
|
||||||
|
weight[src_key_padding_mask[:, 1:]] = 0
|
||||||
|
zeros = torch.zeros(weight.size(0), 1).to(weight)
|
||||||
|
|
||||||
|
weight = torch.cat((weight, zeros), dim=1)
|
||||||
|
|
||||||
|
# all other positions are assumed to be masked and
|
||||||
|
# have the default weight 1
|
||||||
|
|
||||||
|
return (
|
||||||
|
bos_masked_src_eos_pad,
|
||||||
|
bos_src_eos_pad,
|
||||||
|
tgt_eos_pad,
|
||||||
|
src_key_padding_mask,
|
||||||
|
weight,
|
||||||
|
)
|
145
test/lm/test_rescore.py
Executable file
145
test/lm/test_rescore.py
Executable file
@ -0,0 +1,145 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from icefall.lm.rescore import (
|
||||||
|
add_bos,
|
||||||
|
add_eos,
|
||||||
|
compute_alignment,
|
||||||
|
make_hyp_to_ref_map,
|
||||||
|
make_repeat,
|
||||||
|
make_repeat_map,
|
||||||
|
prepare_conformer_lm_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_bos():
|
||||||
|
bos_id = 100
|
||||||
|
ragged = k2.RaggedTensor([[1, 2], [3], [0]])
|
||||||
|
bos_ragged = add_bos(ragged, bos_id)
|
||||||
|
expected = k2.RaggedTensor([[bos_id, 1, 2], [bos_id, 3], [bos_id, 0]])
|
||||||
|
assert str(bos_ragged) == str(expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_eos():
|
||||||
|
eos_id = 30
|
||||||
|
ragged = k2.RaggedTensor([[1, 2], [3], [], [5, 8, 9]])
|
||||||
|
ragged_eos = add_eos(ragged, eos_id)
|
||||||
|
expected = k2.RaggedTensor(
|
||||||
|
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
|
||||||
|
)
|
||||||
|
assert str(ragged_eos) == str(expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pad():
|
||||||
|
bos_id = 10
|
||||||
|
eos_id = 100
|
||||||
|
ragged = k2.RaggedTensor([[1, 2, 3], [5], [9, 8]])
|
||||||
|
bos_ragged = add_bos(ragged, bos_id)
|
||||||
|
bos_ragged_eos = add_eos(bos_ragged, eos_id)
|
||||||
|
blank_id = -1
|
||||||
|
padded = bos_ragged_eos.pad(mode="constant", padding_value=blank_id)
|
||||||
|
expected = torch.tensor(
|
||||||
|
[
|
||||||
|
[bos_id, 1, 2, 3, eos_id],
|
||||||
|
[bos_id, 5, eos_id, blank_id, blank_id],
|
||||||
|
[bos_id, 9, 8, eos_id, blank_id],
|
||||||
|
]
|
||||||
|
).to(padded)
|
||||||
|
assert torch.all(torch.eq(padded, expected))
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_hyp_to_ref_map():
|
||||||
|
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
|
||||||
|
row_splits = a.shape.row_splits(1)
|
||||||
|
repeat_map = make_hyp_to_ref_map(row_splits)
|
||||||
|
# fmt: off
|
||||||
|
expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3,
|
||||||
|
3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map) # noqa
|
||||||
|
# fmt: on
|
||||||
|
assert torch.all(torch.eq(repeat_map, expected))
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_repeat_map():
|
||||||
|
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
|
||||||
|
row_splits = a.shape.row_splits(1)
|
||||||
|
repeat_map = make_repeat_map(row_splits)
|
||||||
|
# fmt: off
|
||||||
|
expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2,
|
||||||
|
3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, # noqa
|
||||||
|
3, 4, 5, 6]).to(repeat_map) # noqa
|
||||||
|
# fmt: on
|
||||||
|
assert torch.all(torch.eq(repeat_map, expected))
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_repeat():
|
||||||
|
# fmt: off
|
||||||
|
a = k2.RaggedTensor([
|
||||||
|
[[1, 3, 5], [2, 6]],
|
||||||
|
[[1, 2, 3, 4], [2], [], [9, 10, 11]],
|
||||||
|
])
|
||||||
|
b = make_repeat(a)
|
||||||
|
expected = k2.RaggedTensor([
|
||||||
|
[[1, 3, 5], [2, 6], [1, 3, 5], [2, 6]],
|
||||||
|
[[1, 2, 3, 4], [2], [], [9, 10, 11],
|
||||||
|
[1, 2, 3, 4], [2], [], [9, 10, 11],
|
||||||
|
[1, 2, 3, 4], [2], [], [9, 10, 11],
|
||||||
|
[1, 2, 3, 4], [2], [], [9, 10, 11]],
|
||||||
|
])
|
||||||
|
# fmt: on
|
||||||
|
assert str(b) == str(expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_alignment():
|
||||||
|
# fmt: off
|
||||||
|
tokens = k2.RaggedTensor([
|
||||||
|
# utt 0
|
||||||
|
[1, 3, 5, 8], [1, 5, 8], [2, 8, 3, 2],
|
||||||
|
# utt 1
|
||||||
|
[2, 3], [2],
|
||||||
|
])
|
||||||
|
# fmt: on
|
||||||
|
shape = k2.RaggedShape("[[x x x] [x x]]")
|
||||||
|
alignment = compute_alignment(tokens, shape)
|
||||||
|
(
|
||||||
|
masked_src,
|
||||||
|
src,
|
||||||
|
tgt,
|
||||||
|
src_key_padding_mask,
|
||||||
|
weight,
|
||||||
|
) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0)
|
||||||
|
|
||||||
|
# print("masked src", masked_src)
|
||||||
|
# print("src", src)
|
||||||
|
# print("tgt", tgt)
|
||||||
|
# print("src_key_padding_mask", src_key_padding_mask)
|
||||||
|
# print("weight", weight)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_add_bos()
|
||||||
|
test_add_eos()
|
||||||
|
test_pad()
|
||||||
|
test_make_repeat_map()
|
||||||
|
test_make_hyp_to_ref_map()
|
||||||
|
test_make_repeat()
|
||||||
|
test_compute_alignment()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user