diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py
index a00664a99..08287d686 100644
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py
@@ -1,7 +1,20 @@
 #!/usr/bin/env python3
-
 # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import math
 import warnings
@@ -396,7 +409,7 @@ class RelPositionalEncoding(torch.nn.Module):
             :,
             self.pe.size(1) // 2
             - x.size(1)
-            + 1 : self.pe.size(1) // 2
+            + 1 : self.pe.size(1) // 2  # noqa E203
             + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py
index c3354c0a3..676e4bf6a 100755
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/decode.py
@@ -1,8 +1,20 @@
 #!/usr/bin/env python3
-
 # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-# (still working in progress)
 
 import argparse
 import logging
@@ -45,28 +57,63 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=9,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
-        default=1,
+        default=20,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (1) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (2) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (5) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (6) nbest-oracle. Its WER is the lower bound that any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring methods.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
     parser.add_argument(
         "--lattice-score-scale",
         type=float,
         default=1.0,
-        help="The scale to be applied to `lattice.scores`."
-        "It's needed if you use any kinds of n-best based rescoring. "
-        "Currently, it is used when the decoding method is: nbest, "
-        "nbest-rescoring, attention-decoder, and nbest-oracle. "
-        "A smaller value results in more unique paths.",
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
     )
     return parser
@@ -92,21 +139,6 @@ def get_params() -> AttributeDict:
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
-            # Possible values for method:
-            #  - 1best
-            #  - nbest
-            #  - nbest-rescoring
-            #  - whole-lattice-rescoring
-            #  - attention-decoder
-            #  - nbest-oracle
-            #  "method": "nbest",
-            #  "method": "nbest-rescoring",
-            #  "method": "whole-lattice-rescoring",
-            "method": "attention-decoder",
-            #  "method": "nbest-oracle",
-            # num_paths is used when method is "nbest", "nbest-rescoring",
-            # attention-decoder, and nbest-oracle
-            "num_paths": 100,
         }
     )
     return params
@@ -117,7 +149,7 @@ def decode_one_batch(
     model: nn.Module,
     HLG: k2.Fsa,
     batch: dict,
-    lexicon: Lexicon,
+    word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
@@ -151,8 +183,8 @@
       It is the return value from iterating
       `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
       for the format of the `batch`.
-    lexicon:
-      It contains word symbol table.
+    word_table:
+      The word symbol table.
     sos_id:
       The token ID of the SOS.
     eos_id:
@@ -205,7 +237,7 @@
             lattice=lattice,
             num_paths=params.num_paths,
             ref_texts=supervisions["text"],
-            lexicon=lexicon,
+            word_table=word_table,
             scale=params.lattice_score_scale,
         )
 
@@ -225,7 +257,7 @@
             key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
 
         hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
         return {key: hyps}
 
     assert params.method in [
@@ -271,7 +303,7 @@
     ans = dict()
     for lm_scale_str, best_path in best_path_dict.items():
         hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
         ans[lm_scale_str] = hyps
 
     return ans
@@ -281,7 +313,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     HLG: k2.Fsa,
-    lexicon: Lexicon,
+    word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
@@ -297,8 +329,8 @@
       The neural model.
     HLG:
       The decoding graph.
-    lexicon:
-      It contains word symbol table.
+    word_table:
+      It is the word symbol table.
     sos_id:
       The token ID for SOS.
     eos_id:
@@ -332,7 +364,7 @@ def decode_dataset(
             model=model,
             HLG=HLG,
             batch=batch,
-            lexicon=lexicon,
+            word_table=word_table,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
@@ -528,7 +560,7 @@ def main():
             params=params,
             model=model,
             HLG=HLG,
-            lexicon=lexicon,
+            word_table=lexicon.word_table,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
deleted file mode 100755
index c63616d28..000000000
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
+++ /dev/null
@@ -1,350 +0,0 @@
-#!/usr/bin/env python3
-
-import argparse
-import logging
-import math
-from typing import List
-
-import k2
-import kaldifeat
-import torch
-import torchaudio
-from conformer import Conformer
-from torch.nn.utils.rnn import pad_sequence
-
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_attention_decoder,
-    rescore_with_whole_lattice,
-)
-from icefall.utils import AttributeDict, get_texts
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--checkpoint",
-        type=str,
-        required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
-    )
-
-    parser.add_argument(
-        "--words-file",
-        type=str,
-        required=True,
-        help="Path to words.txt",
-    )
-
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
-
-    parser.add_argument(
-        "--method",
-        type=str,
-        default="1best",
-        help="""Decoding method.
-        Possible values are:
-        (1) 1best - Use the best path as decoding output. Only
-            the transformer encoder output is used for decoding.
-            We call it HLG decoding.
-        (2) whole-lattice-rescoring - Use an LM to rescore the
-            decoding lattice and then use 1best to decode the
-            rescored lattice.
-            We call it HLG decoding + n-gram LM rescoring.
-        (3) attention-decoder - Extract n paths from he rescored
-            lattice and use the transformer attention decoder for
-            rescoring.
-            We call it HLG decoding + n-gram LM rescoring + attention
-            decoder rescoring.
-        """,
-    )
-
-    parser.add_argument(
-        "--G",
-        type=str,
-        help="""An LM for rescoring.
-        Used only when method is
-        whole-lattice-rescoring or attention-decoder.
-        It's usually a 4-gram LM.
-        """,
-    )
-
-    parser.add_argument(
-        "--num-paths",
-        type=int,
-        default=100,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies the size of n-best list.""",
-    )
-
-    parser.add_argument(
-        "--ngram-lm-scale",
-        type=float,
-        default=1.3,
-        help="""
-        Used only when method is whole-lattice-rescoring and attention-decoder.
-        It specifies the scale for n-gram LM scores.
-        (Note: You need to tune it on a dataset.)
-        """,
-    )
-
-    parser.add_argument(
-        "--attention-decoder-scale",
-        type=float,
-        default=1.2,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies the scale for attention decoder scores.
-        (Note: You need to tune it on a dataset.)
-        """,
-    )
-
-    parser.add_argument(
-        "--lattice-score-scale",
-        type=float,
-        default=0.5,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies the scale for lattice.scores when
-        extracting n-best lists. A smaller value results in
-        more unique number of paths with the risk of missing
-        the best path.
-        """,
-    )
-
-    parser.add_argument(
-        "--sos-id",
-        type=float,
-        default=1,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies ID for the SOS token.
-        """,
-    )
-
-    parser.add_argument(
-        "--eos-id",
-        type=float,
-        default=1,
-        help="""
-        Used only when method is attention-decoder.
-        It specifies ID for the EOS token.
-        """,
-    )
-
-    parser.add_argument(
-        "sound_files",
-        type=str,
-        nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
-    )
-
-    return parser
-
-
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "feature_dim": 80,
-            "nhead": 8,
-            "num_classes": 5000,
-            "sample_rate": 16000,
-            "attention_dim": 512,
-            "subsampling_factor": 4,
-            "num_decoder_layers": 6,
-            "vgg_frontend": False,
-            "is_espnet_structure": True,
-            "mmi_loss": False,
-            "use_feat_batchnorm": True,
-            "search_beam": 20,
-            "output_beam": 8,
-            "min_active_states": 30,
-            "max_active_states": 10000,
-            "use_double_scores": True,
-        }
-    )
-    return params
-
-
-def read_sound_files(
-    filenames: List[str], expected_sample_rate: float
-) -> List[torch.Tensor]:
-    """Read a list of sound files into a list 1-D float32 torch tensors.
-    Args:
-      filenames:
-        A list of sound filenames.
-      expected_sample_rate:
-        The expected sample rate of the sound files.
-    Returns:
-      Return a list of 1-D float32 torch tensors.
-    """
-    ans = []
-    for f in filenames:
-        wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
-        # We use only the first channel
-        ans.append(wave[0])
-    return ans
-
-
-def main():
-    parser = get_parser()
-    args = parser.parse_args()
-
-    params = get_params()
-    params.update(vars(args))
-    logging.info(f"{params}")
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    logging.info(f"device: {device}")
-
-    logging.info("Creating model")
-    model = Conformer(
-        num_features=params.feature_dim,
-        nhead=params.nhead,
-        d_model=params.attention_dim,
-        num_classes=params.num_classes,
-        subsampling_factor=params.subsampling_factor,
-        num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=params.vgg_frontend,
-        is_espnet_structure=params.is_espnet_structure,
-        mmi_loss=params.mmi_loss,
-        use_feat_batchnorm=params.use_feat_batchnorm,
-    )
-
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
-    model.load_state_dict(checkpoint["model"])
-    model.to(device)
-    model.eval()
-
-    logging.info(f"Loading HLG from {params.HLG}")
-    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
-    HLG = HLG.to(device)
-    if not hasattr(HLG, "lm_scores"):
-        # For whole-lattice-rescoring and attention-decoder
-        HLG.lm_scores = HLG.scores.clone()
-
-    if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
-        logging.info(f"Loading G from {params.G}")
-        G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
-        G = G.to(device)
-        # Add epsilon self-loops to G as we will compose
-        # it with the whole lattice later
-        G = k2.add_epsilon_self_loops(G)
-        G = k2.arc_sort(G)
-        G.lm_scores = G.scores.clone()
-
-    logging.info("Constructing Fbank computer")
-    opts = kaldifeat.FbankOptions()
-    opts.device = device
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = params.sample_rate
-    opts.mel_opts.num_bins = params.feature_dim
-
-    fbank = kaldifeat.Fbank(opts)
-
-    logging.info(f"Reading sound files: {params.sound_files}")
-    waves = read_sound_files(
-        filenames=params.sound_files, expected_sample_rate=params.sample_rate
-    )
-    waves = [w.to(device) for w in waves]
-
-    logging.info(f"Decoding started")
-    features = fbank(waves)
-
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
-
-    # Note: We don't use key padding mask for attention during decoding
-    with torch.no_grad():
-        nnet_output, memory, memory_key_padding_mask = model(features)
-
-    batch_size = nnet_output.shape[0]
-    supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
-        dtype=torch.int32,
-    )
-
-    lattice = get_lattice(
-        nnet_output=nnet_output,
-        HLG=HLG,
-        supervision_segments=supervision_segments,
-        search_beam=params.search_beam,
-        output_beam=params.output_beam,
-        min_active_states=params.min_active_states,
-        max_active_states=params.max_active_states,
-        subsampling_factor=params.subsampling_factor,
-    )
-
-    if params.method == "1best":
-        logging.info("Use HLG decoding")
-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
-    elif params.method == "whole-lattice-rescoring":
-        logging.info("Use HLG decoding + LM rescoring")
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice,
-            G_with_epsilon_loops=G,
-            lm_scale_list=[params.ngram_lm_scale],
-        )
-        best_path = next(iter(best_path_dict.values()))
-    elif params.method == "attention-decoder":
-        logging.info("Use HLG + LM rescoring + attention decoder rescoring")
-        rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
-        )
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=params.sos_id,
-            eos_id=params.eos_id,
-            scale=params.lattice_score_scale,
-            ngram_lm_scale=params.ngram_lm_scale,
-            attention_scale=params.attention_decoder_scale,
-        )
-        best_path = next(iter(best_path_dict.values()))
-
-    hyps = get_texts(best_path)
-    word_sym_table = k2.SymbolTable.from_file(params.words_file)
-    hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-
-    s = "\n"
-    for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
-    logging.info(s)
-
-    logging.info(f"Decoding Done")
-
-
-if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
new file mode 120000
index 000000000..cd27e4304
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/pretrained.py
@@ -0,0 +1 @@
+../conformer_ctc/pretrained.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py
index 5c3e1222e..720ed6c22 100644
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/subsampling.py
@@ -1,3 +1,20 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 import torch
 import torch.nn as nn
 
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py
deleted file mode 100755
index 937845d77..000000000
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_subsampling.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/usr/bin/env python3
-
-from subsampling import Conv2dSubsampling
-from subsampling import VggSubsampling
-import torch
-
-
-def test_conv2d_subsampling():
-    N = 3
-    odim = 2
-
-    for T in range(7, 19):
-        for idim in range(7, 20):
-            model = Conv2dSubsampling(idim=idim, odim=odim)
-            x = torch.empty(N, T, idim)
-            y = model(x)
-            assert y.shape[0] == N
-            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
-            assert y.shape[2] == odim
-
-
-def test_vgg_subsampling():
-    N = 3
-    odim = 2
-
-    for T in range(7, 19):
-        for idim in range(7, 20):
-            model = VggSubsampling(idim=idim, odim=odim)
-            x = torch.empty(N, T, idim)
-            y = model(x)
-            assert y.shape[0] == N
-            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
-            assert y.shape[2] == odim
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py
deleted file mode 100644
index 08e680607..000000000
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/test_transformer.py
+++ /dev/null
@@ -1,89 +0,0 @@
-#!/usr/bin/env python3
-
-import torch
-from transformer import (
-    Transformer,
-    encoder_padding_mask,
-    generate_square_subsequent_mask,
-    decoder_padding_mask,
-    add_sos,
-    add_eos,
-)
-
-from torch.nn.utils.rnn import pad_sequence
-
-
-def test_encoder_padding_mask():
-    supervisions = {
-        "sequence_idx": torch.tensor([0, 1, 2]),
-        "start_frame": torch.tensor([0, 0, 0]),
-        "num_frames": torch.tensor([18, 7, 13]),
-    }
-
-    max_len = ((18 - 1) // 2 - 1) // 2
-    mask = encoder_padding_mask(max_len, supervisions)
-    expected_mask = torch.tensor(
-        [
-            [False, False, False],  # ((18 - 1)//2 - 1)//2 = 3,
-            [False, True, True],  # ((7 - 1)//2 - 1)//2 = 1,
-            [False, False, True],  # ((13 - 1)//2 - 1)//2 = 2,
-        ]
-    )
-    assert torch.all(torch.eq(mask, expected_mask))
-
-
-def test_transformer():
-    num_features = 40
-    num_classes = 87
-    model = Transformer(num_features=num_features, num_classes=num_classes)
-
-    N = 31
-
-    for T in range(7, 30):
-        x = torch.rand(N, T, num_features)
-        y, _, _ = model(x)
-        assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
-
-
-def test_generate_square_subsequent_mask():
-    s = 5
-    mask = generate_square_subsequent_mask(s)
-    inf = float("inf")
-    expected_mask = torch.tensor(
-        [
-            [0.0, -inf, -inf, -inf, -inf],
-            [0.0, 0.0, -inf, -inf, -inf],
-            [0.0, 0.0, 0.0, -inf, -inf],
-            [0.0, 0.0, 0.0, 0.0, -inf],
-            [0.0, 0.0, 0.0, 0.0, 0.0],
-        ]
-    )
-    assert torch.all(torch.eq(mask, expected_mask))
-
-
-def test_decoder_padding_mask():
-    x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
-    y = pad_sequence(x, batch_first=True, padding_value=-1)
-    mask = decoder_padding_mask(y, ignore_id=-1)
-    expected_mask = torch.tensor(
-        [
-            [False, False, True],
-            [False, True, True],
-            [False, False, False],
-        ]
-    )
-    assert torch.all(torch.eq(mask, expected_mask))
-
-
-def test_add_sos():
-    x = [[1, 2], [3], [2, 5, 8]]
-    y = add_sos(x, sos_id=0)
-    expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
-    assert y == expected_y
-
-
-def test_add_eos():
-    x = [[1, 2], [3], [2, 5, 8]]
-    y = add_eos(x, eos_id=0)
-    expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
-    assert y == expected_y
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
index 795a2ab57..b0dbe72ad 100755
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/train.py
@@ -1,6 +1,20 @@
 #!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-# This is just at the very beginning ...
 
 import argparse
 import logging
@@ -60,6 +74,23 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=35,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
     return parser
 
 
@@ -89,11 +120,6 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor:  The subsampling factor for the model.
 
-        - start_epoch:  If it is not zero, load checkpoint `start_epoch-1`
-          and continue training from that checkpoint.
-
-        - num_epochs:  Number of epochs to train.
-
         - best_train_loss: Best training loss so far. It is used to
           select the model that has the lowest training loss. It is
           updated during the training.
@@ -124,13 +150,11 @@
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc_embedding_scale/exp"),
+            "exp_dir": Path("conformer_ctc/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "weight_decay": 1e-6,
             "subsampling_factor": 4,
-            "start_epoch": 0,
-            "num_epochs": 20,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py
index f237ff8e3..74e61b645 100644
--- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/transformer.py
@@ -1,5 +1,19 @@
-# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import math
 from typing import Dict, List, Optional, Tuple
@@ -641,7 +655,7 @@ class PositionalEncoding(nn.Module):
         """
         super().__init__()
         self.d_model = d_model
-        self.pos_scale = 1. / math.sqrt(self.d_model)
+        self.pos_scale = 1.0 / math.sqrt(self.d_model)
         self.dropout = nn.Dropout(p=dropout)
         self.pe = None
 
@@ -780,7 +794,8 @@ class Noam(object):
 
 class LabelSmoothingLoss(nn.Module):
     """
-    Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w)
+    Label-smoothing loss. KL-divergence between
+    q_{smoothed ground truth prob.}(w)
     and p_{prob. computed by model}(w) is minimized.
     Modified from
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa
@@ -865,7 +880,8 @@ def encoder_padding_mask(
         frames, before subsampling)
 
     Returns:
-        Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices.
+        Tensor: Mask tensor of dimension (batch_size, input_length),
+        True denote the masked indices.
     """
     if supervisions is None:
         return None
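Note on the decoding/training options moved by this patch: the decode.py and train.py hunks relocate values that used to be hard-coded in get_params() ("method", "num_paths", "num_epochs", "start_epoch") into get_parser(). Below is a minimal sketch, not part of the patch, of how the two halves are expected to combine; it assumes the usual icefall pattern of merging parsed CLI arguments into the AttributeDict returned by get_params() via params.update(vars(args)), the same idiom visible in the deleted pretrained.py above. The import line is illustrative only.

# Minimal sketch (assumption, not part of the diff): decode.py merges CLI
# arguments into the AttributeDict from get_params().
from decode import get_parser, get_params  # illustrative import

args = get_parser().parse_args(
    ["--epoch", "34", "--avg", "20", "--method", "attention-decoder", "--num-paths", "100"]
)

params = get_params()
params.update(vars(args))  # CLI flags now supply method/num_paths instead of get_params()

print(params.method)     # attention-decoder
print(params.num_paths)  # 100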