diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index d3952d3b1..a00664a99 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -84,7 +84,7 @@ class Conformer(Transformer):
         # and throws an error without this change.
         self.after_norm = identity
 
-    def encode(
+    def run_encoder(
         self, x: Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
@@ -802,7 +802,8 @@ class RelPositionMultiheadAttention(nn.Module):
                 bsz, num_heads, tgt_len, src_len
             )
             attn_output_weights = attn_output_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
             )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
@@ -872,7 +873,12 @@ class ConvolutionModule(nn.Module):
         )
         self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
-            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
         )
         self.activation = Swish()
 
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 9ebb76fa1..0611814f6 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -15,6 +15,7 @@ import torch
 import torch.nn as nn
 from conformer import Conformer
 
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
@@ -85,7 +86,7 @@ def get_params() -> AttributeDict:
             #  - whole-lattice-rescoring
             #  - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "1best",
+            "method": "attention-decoder",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
             "num_paths": 100,
@@ -100,6 +101,8 @@ def decode_one_batch(
     HLG: k2.Fsa,
     batch: dict,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -133,6 +136,10 @@ def decode_one_batch(
         for the format of the `batch`.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -222,8 +229,8 @@ def decode_one_batch(
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=lexicon.sos_id,
-            eos_id=lexicon.eos_id,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -242,6 +249,8 @@ def decode_dataset(
     model: nn.Module,
     HLG: k2.Fsa,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
     """Decode dataset.
@@ -257,6 +266,10 @@ def decode_dataset(
         The decoding graph.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -284,6 +297,8 @@ def decode_dataset(
             batch=batch,
             lexicon=lexicon,
             G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )
 
         for lm_scale, hyps in hyps_dict.items():
@@ -364,6 +379,15 @@ def main():
 
     logging.info(f"device: {device}")
 
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
     HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -456,6 +480,8 @@ def main():
         HLG=HLG,
         lexicon=lexicon,
         G=G,
+        sos_id=sos_id,
+        eos_id=eos_id,
     )
 
     save_results(
diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py
index a6569e8d7..08e680607 100644
--- a/egs/librispeech/ASR/conformer_ctc/test_transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py
@@ -1,7 +1,16 @@
 #!/usr/bin/env python3
 
 import torch
-from transformer import Transformer, encoder_padding_mask
+from transformer import (
+    Transformer,
+    encoder_padding_mask,
+    generate_square_subsequent_mask,
+    decoder_padding_mask,
+    add_sos,
+    add_eos,
+)
+
+from torch.nn.utils.rnn import pad_sequence
 
 
 def test_encoder_padding_mask():
@@ -34,3 +43,47 @@ def test_transformer():
     x = torch.rand(N, T, num_features)
     y, _, _ = model(x)
     assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
+
+
+def test_generate_square_subsequent_mask():
+    s = 5
+    mask = generate_square_subsequent_mask(s)
+    inf = float("inf")
+    expected_mask = torch.tensor(
+        [
+            [0.0, -inf, -inf, -inf, -inf],
+            [0.0, 0.0, -inf, -inf, -inf],
+            [0.0, 0.0, 0.0, -inf, -inf],
+            [0.0, 0.0, 0.0, 0.0, -inf],
+            [0.0, 0.0, 0.0, 0.0, 0.0],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_decoder_padding_mask():
+    x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
+    y = pad_sequence(x, batch_first=True, padding_value=-1)
+    mask = decoder_padding_mask(y, ignore_id=-1)
+    expected_mask = torch.tensor(
+        [
+            [False, False, True],
+            [False, True, True],
+            [False, False, False],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_add_sos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_sos(x, sos_id=0)
+    expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
+    assert y == expected_y
+
+
+def test_add_eos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_eos(x, eos_id=0)
+    expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
+    assert y == expected_y
diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index b2123b8fc..a974be4e0 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -10,6 +10,7 @@ import torch.nn as nn
 from subsampling import Conv2dSubsampling, VggSubsampling
 
 from icefall.utils import get_texts
+from torch.nn.utils.rnn import pad_sequence
 
 # Note: TorchScript requires Dict/List/etc. to be fully typed.
 Supervisions = Dict[str, torch.Tensor]
@@ -177,14 +178,17 @@ class Transformer(nn.Module):
         x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
         x = self.feat_batchnorm(x)
         x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
-        encoder_memory, memory_key_padding_mask = self.encode(x, supervision)
-        x = self.encoder_output(encoder_memory)
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
+        x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask
 
-    def encode(
+    def run_encoder(
         self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """
+        """Run the transformer encoder.
+
         Args:
           x:
             The model input. Its shape is [N, T, C].
@@ -194,8 +198,8 @@
             CAUTION: It contains length information, i.e., start and number
            of frames, before subsampling
            It is read directly from the batch, without any sorting. It is used
-            to compute encoder padding mask, which is used as memory key padding
-            mask for the decoder.
+            to compute the encoder padding mask, which is used as memory key
+            padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
             - The encoder output, with shape [T, N, C]
@@ -212,7 +216,7 @@
 
         return x, mask
 
-    def encoder_output(self, x: torch.Tensor) -> torch.Tensor:
+    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
           x:
@@ -232,46 +236,16 @@
         self,
         memory: torch.Tensor,
         memory_key_padding_mask: torch.Tensor,
-        supervision: Optional[Supervisions] = None,
-        L_inv: Optional[k2.Fsa] = None,
-        word_table: Optional[k2.SymbolTable] = None,
-        oov_str: Optional[str] = None,
-        token_ids: List[List[int]] = None,
-        sos_id: Optional[int] = None,
-        eos_id: Optional[int] = None,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
     ) -> torch.Tensor:
         """
-        Note:
-          If phone based lexicon is used, the following arguments are required:
-
-            - supervision
-            - L_inv
-            - word_table
-            - oov_str
-
-          If BPE based lexicon is used, the following arguments are required:
-
-            - token_ids
-            - sos_id
-            - eos_id
-
         Args:
           memory:
             It's the output of the encoder with shape [T, N, C]
          memory_key_padding_mask:
             The padding mask from the encoder.
-          supervision:
-            Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
-            (CAUTION: It contains length information, i.e., start and number of
-            frames, before subsampling)
-          L_inv:
-            It is an FSA with labels being word IDs and aux_labels being
-            token IDs (e.g., phone IDs or word piece IDs).
-          word_table:
-            Word table providing mapping between words and IDs.
-          oov_str:
-            The OOV word, e.g., '<UNK>'
           token_ids:
             A list-of-list IDs. Each sublist contains IDs for an utterance.
            The IDs can be either phone IDs or word piece IDs.
@@ -284,29 +258,13 @@
             A scalar, the **sum** of label smoothing loss over utterances
            in the batch without any normalization.
""" - if supervision is not None and word_table is not None: - batch_text = get_normal_transcripts( - supervision, word_table, oov_str - ) - ys_in_pad, ys_out_pad = add_sos_eos( - batch_text, - L_inv, - sos_id, - eos_id, - ) - elif token_ids is not None: - _sos = torch.tensor([sos_id]) - _eos = torch.tensor([eos_id]) - ys_in = [ - torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids - ] - ys_out = [ - torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids - ] - ys_in_pad = pad_list(ys_in, eos_id) - ys_out_pad = pad_list(ys_out, -1) - else: - raise ValueError("Invalid input for decoder self attention") + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) device = memory.device ys_in_pad = ys_in_pad.to(device) @@ -316,6 +274,8 @@ class Transformer(nn.Module): device ) + # TODO: Use eos_id as ignore_id. + # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) @@ -362,19 +322,14 @@ class Transformer(nn.Module): """ # The common part between this function and decoder_forward could be # extracted as a separate function. - if token_ids is not None: - _sos = torch.tensor([sos_id]) - _eos = torch.tensor([eos_id]) - ys_in = [ - torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids - ] - ys_out = [ - torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids - ] - ys_in_pad = pad_list(ys_in, eos_id) - ys_out_pad = pad_list(ys_out, -1) - else: - raise ValueError("Invalid input for decoder self attention") + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) @@ -384,6 +339,8 @@ class Transformer(nn.Module): device ) + # TODO: Use eos_id as ignore_id. + # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) @@ -948,8 +905,8 @@ def decoder_padding_mask( ) -> torch.Tensor: """Generate a length mask for input. - The masked position are filled with bool(True), - Unmasked positions are filled with bool(False). + The masked position are filled with True, + Unmasked positions are filled with False. Args: ys_pad: @@ -965,45 +922,16 @@ def decoder_padding_mask( return ys_mask -def get_normal_transcripts( - supervision: Supervisions, words: k2.SymbolTable, oov: str = "" -) -> List[List[int]]: - """Get normal transcripts (1 input recording has 1 transcript) - from lhotse cut format. - - Achieved by concatenating the transcripts corresponding to the - same recording. - - Args: - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - words: - The word symbol table. - oov: - Out of vocabulary word. 
-
-    Returns:
-        List[List[int]]: List of concatenated transcripts, length is batch_size
-    """
-
-    texts = [
-        [token if token in words else oov for token in text.split(" ")]
-        for text in supervision["text"]
-    ]
-    texts_ids = [[words[token] for token in text] for text in texts]
-
-    batch_text = [
-        [] for _ in range(int(supervision["sequence_idx"].max().item()) + 1)
-    ]
-    for sequence_idx, text in zip(supervision["sequence_idx"], texts_ids):
-        batch_text[sequence_idx] = batch_text[sequence_idx] + text
-
-    return batch_text
-
-
 def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
     """Generate a square mask for the sequence. The masked positions are
     filled with float('-inf'). Unmasked positions are filled with float(0.0).
+    The mask can be used for masked self-attention.
+
+    For instance, if sz is 3, it returns::
+
+        tensor([[0., -inf, -inf],
+                [0., 0., -inf],
+                [0., 0., 0.]])
 
     Args:
       sz: mask size
@@ -1020,115 +948,41 @@
     return mask
 
 
-def add_sos_eos(
-    ys: List[List[int]],
-    L_inv: k2.Fsa,
-    sos_id: int,
-    eos_id: int,
-    ignore_id: int = -1,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Add <sos> and <eos> labels.
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+    """Prepend sos_id to each utterance.
 
     Args:
-      ys:
-        Batch of unpadded target sequences (i.e., word IDs)
-      L_inv:
-        Its labels are words, while its aux_labels are tokens.
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
       sos_id:
-        index of <sos>
+        The ID of the SOS token.
+
+    Returns:
+      A new list-of-list, where each sublist starts
+      with the SOS ID.
+    """
+    ans = []
+    for utt in token_ids:
+        ans.append([sos_id] + utt)
+    return ans
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+    """Append eos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
       eos_id:
-        index of <eos>
-      ignore_id:
-        value for padding
+        The ID of the EOS token.
 
-    Returns:
-      Return a tuple containing two tensors:
-      - Input of transformer decoder.
-        Padded tensor of dimension (batch_size, max_length).
-      - Output of transformer decoder.
-        Padded tensor of dimension (batch_size, max_length).
+    Returns:
+      A new list-of-list, where each sublist ends
+      with the EOS ID.
     """
-
-    _sos = torch.tensor([sos_id])
-    _eos = torch.tensor([eos_id])
-    ys = get_hierarchical_targets(ys, L_inv)
-    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
-    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
-    return pad_list(ys_in, eos_id), pad_list(ys_out, ignore_id)
-
-
-def pad_list(ys: List[torch.Tensor], pad_value: float) -> torch.Tensor:
-    """Perform padding for the list of tensors.
-
-    Args:
-      ys:
-        List of tensors. len(ys) = batch_size.
-      pad_value:
-        Value for padding.
-
-    Returns:
-      Tensor: Padded tensor (batch_size, max_length, `*`).
-
-    Examples:
-        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
-        >>> x
-        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
-        >>> pad_list(x, 0)
-        tensor([[1., 1., 1., 1.],
-                [1., 1., 0., 0.],
-                [1., 0., 0., 0.]])
-
-    """
-    n_batch = len(ys)
-    max_len = max(x.size(0) for x in ys)
-    pad = ys[0].new_full((n_batch, max_len, *ys[0].size()[1:]), pad_value)
-
-    for i in range(n_batch):
-        pad[i, : ys[i].size(0)] = ys[i]
-
-    return pad
-
-
-def get_hierarchical_targets(
-    ys: List[List[int]], L_inv: Optional[k2.Fsa] = None
-) -> List[torch.Tensor]:
-    """Get hierarchical transcripts (i.e., phone level transcripts) from
-    transcripts (i.e., word level transcripts).
-
-    Args:
-      ys:
-        Word level transcripts. Each sublist is a transcript of an utterance.
-      L_inv:
-        Its labels are words, while its aux_labels are tokens.
-
-    Returns:
-      List[torch.Tensor]:
-        Token level transcripts.
-    """
-
-    if L_inv is None:
-        return [torch.tensor(y) for y in ys]
-
-    device = L_inv.device
-
-    transcripts = k2.create_fsa_vec(
-        [k2.linear_fsa(x, device=device) for x in ys]
-    )
-    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)
-
-    transcripts_lexicon = k2.intersect(
-        L_inv, transcripts_with_self_loops, treat_epsilons_specially=False
-    )
-    # Don't call invert_() above because we want to return phone IDs,
-    # which is the `aux_labels` of transcripts_lexicon
-    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
-    transcripts_lexicon = k2.top_sort(transcripts_lexicon)
-
-    transcripts_lexicon = k2.shortest_path(
-        transcripts_lexicon, use_double_scores=True
-    )
-
-    ys = get_texts(transcripts_lexicon)
-    ys = [torch.tensor(y) for y in ys]
-
-    return ys
+    ans = []
+    for utt in token_ids:
+        ans.append(utt + [eos_id])
+    return ans
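
Note on the target-preparation refactor in transformer.py: the patch drops the
phone-lexicon path (get_normal_transcripts / add_sos_eos / pad_list /
get_hierarchical_targets) and builds decoder inputs and outputs from BPE token
IDs alone, via the new add_sos / add_eos helpers plus
torch.nn.utils.rnn.pad_sequence. Below is a minimal, self-contained sketch of
that path. The helper bodies are one-line equivalents of the ones added by the
patch; the token IDs and the shared SOS/EOS value are made-up example data
(sos_id == eos_id mirrors the single "<sos/eos>" token used by
BpeCtcTrainingGraphCompiler in decode.py).

    from typing import List

    import torch
    from torch.nn.utils.rnn import pad_sequence


    def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
        # Prepend sos_id to every utterance (equivalent to the patched helper).
        return [[sos_id] + utt for utt in token_ids]


    def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
        # Append eos_id to every utterance (equivalent to the patched helper).
        return [utt + [eos_id] for utt in token_ids]


    # Hypothetical batch of two utterances of different lengths.
    token_ids = [[2, 3, 5], [7, 8]]
    sos_id = eos_id = 1

    # Decoder input: SOS-prefixed, padded with eos_id.
    ys_in = [torch.tensor(y) for y in add_sos(token_ids, sos_id=sos_id)]
    ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
    # tensor([[1, 2, 3, 5],
    #         [1, 7, 8, 1]])

    # Decoder output: EOS-suffixed, padded with -1 (positions to ignore).
    ys_out = [torch.tensor(y) for y in add_eos(token_ids, eos_id=eos_id)]
    ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
    # tensor([[ 2,  3,  5,  1],
    #         [ 7,  8,  1, -1]])

Padding the decoder input with eos_id is harmless (it only repeats EOS after
EOS), while the -1 padding in the output marks positions excluded from the
label smoothing loss; the TODO comments in the patch suggest eventually
reusing eos_id as the ignore_id of decoder_padding_mask as well.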