diff --git a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
index 0a23f0d9d..1b6991bcd 100644
--- a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
@@ -183,11 +183,13 @@ class LibriSpeechAsrDataModule:
             "--train-manifest",
             type=str,
             default="librispeech_cuts_train-clean-100.jsonl.gz",
-            help="Train manifest file."
+            help="Train manifest file.",
         )
 
     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
     ) -> DataLoader:
         """
         Args:
@@ -268,11 +270,14 @@ class LibriSpeechAsrDataModule:
 
         logging.info("About to create dev dataset")
         validate = K2SpeechRecognitionDataset(
-            cut_transforms=transforms, return_cuts=self.args.return_cuts,
+            cut_transforms=transforms,
+            return_cuts=self.args.return_cuts,
         )
         valid_sampler = DynamicBucketingSampler(
-            cuts_valid, max_duration=self.args.max_duration, shuffle=False,
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
         )
 
         logging.info("About to create dev dataloader")
@@ -293,11 +298,16 @@ class LibriSpeechAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts, max_duration=self.args.max_duration, shuffle=False,
+            cuts,
+            max_duration=self.args.max_duration,
+            shuffle=False,
         )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
-            test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
         )
         return test_dl
@@ -311,9 +321,7 @@ class LibriSpeechAsrDataModule:
     @lru_cache()
     def train_clean_100_cuts(self) -> CutSet:
         logging.info("About to get train-clean-100 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / self.args.train_manifest
-        )
+        return load_manifest_lazy(self.args.manifest_dir / self.args.train_manifest)
 
     @lru_cache()
     def train_all_shuf_cuts(self) -> CutSet:
diff --git a/egs/librispeech/WSASR/conformer_ctc2/conformer.py b/egs/librispeech/WSASR/conformer_ctc2/conformer.py
index 43a9872f8..db4821d37 100644
--- a/egs/librispeech/WSASR/conformer_ctc2/conformer.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/conformer.py
@@ -92,7 +92,7 @@ class Conformer(Transformer):
         if self.subsampling_factor == 4:
             self.encoder_embed = Conv2dSubsampling(num_features, d_model)
         elif self.subsampling_factor == 2:
-            self.encoder_embed = Conv2dSubsampling2(num_features, d_model) 
+            self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode.py b/egs/librispeech/WSASR/conformer_ctc2/decode.py
index 2d550a520..3fa045533 100755
--- a/egs/librispeech/WSASR/conformer_ctc2/decode.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/decode.py
@@ -32,19 +32,16 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 
-from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-)
+from icefall.decode import get_lattice, one_best_decoding
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
+from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -62,7 +59,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--otc-token", type=str, default="", help="OTC token",
+        "--otc-token",
+        type=str,
+        default="",
+        help="OTC token",
     )
 
     parser.add_argument(
@@ -137,11 +137,17 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--exp-dir", type=str, default="conformer_ctc2/exp", help="The experiment dir",
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="The experiment dir",
     )
 
     parser.add_argument(
-        "--lang-dir", type=str, default="data/lang_bpe_200", help="The lang dir",
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_200",
+        help="The lang dir",
     )
 
     parser.add_argument(
@@ -345,7 +351,11 @@ def decode_one_batch(
         return {key: hyps}
 
     if params.method == "ctc-greedy-search":
-        hyps, _ = ctc_greedy_search(nnet_output, memory, memory_key_padding_mask,)
+        hyps, _ = ctc_greedy_search(
+            nnet_output,
+            memory,
+            memory_key_padding_mask,
+        )
 
         # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
         hyps = bpe_model.decode(hyps)
@@ -557,7 +567,11 @@ def main():
 
     if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
         HLG = None
-        H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device,)
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
         bpe_model = spm.SentencePieceProcessor()
         bpe_model.load(str(params.lang_dir / "bpe.model"))
     else:
diff --git a/egs/librispeech/WSASR/conformer_ctc2/subsampling.py b/egs/librispeech/WSASR/conformer_ctc2/subsampling.py
index f56b32683..2ba802866 100644
--- a/egs/librispeech/WSASR/conformer_ctc2/subsampling.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/subsampling.py
@@ -120,6 +120,7 @@ class Conv2dSubsampling(torch.nn.Module):
         x = self.out_balancer(x)
         return x
 
+
 class Conv2dSubsampling2(torch.nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py
index 347de5e7d..7b85b8b89 100755
--- a/egs/librispeech/WSASR/conformer_ctc2/train.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/train.py
@@ -66,24 +66,24 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
 from icefall import diagnostics
-from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
+from icefall.decode import one_best_decoding
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions_otc,
+    get_texts,
     setup_logger,
     str2bool,
-    get_texts,
 )
-from icefall.decode import one_best_decoding
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -94,7 +94,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -112,7 +115,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs", type=int, default=20, help="Number of epochs to train.",
+        "--num-epochs",
+        type=int,
+        default=20,
+        help="Number of epochs to train.",
     )
 
     parser.add_argument(
@@ -255,7 +261,18 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--otc-token", type=str, default="_", help="OTC token",
+        "--otc-token",
+        type=str,
+        default="_",
+        help="OTC token",
+    )
+
+    parser.add_argument(
+        "--otc-granularity",
+        type=str,
+        choices=["word", "subword"],
+        default="word",
+        help="OTC granularity",
     )
 
     parser.add_argument(
@@ -374,7 +391,7 @@ def get_params() -> AttributeDict:
             "log_interval": 1,
             "reset_interval": 200,
             "valid_interval": 800,  # For the 100h subset, use 800
-            "alignment_interval": 100,
+            "alignment_interval": 25,
             # parameters for conformer
             "feature_dim": 768,
             "subsampling_factor": 2,
@@ -585,9 +602,14 @@ def compute_loss(
         allow_self_loop_arc=params.allow_self_loop_arc,
         bypass_weight=bypass_weight,
         self_loop_weight=self_loop_weight,
+        otc_granularity=params.otc_granularity,
     )
 
-    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments, allow_truncate=3,)
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=3,
+    )
 
     otc_loss = k2.ctc_loss(
         decoding_graph=decoding_graph,
@@ -627,18 +649,22 @@ def compute_loss(
             utt_id = utt_ids[index]
 
             lattice = k2.intersect_dense(
-                decoding_graph, dense_fsa_vec, params.beam_size,
+                decoding_graph,
+                dense_fsa_vec,
+                params.beam_size,
             )
             best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores,
+                lattice=lattice,
+                use_double_scores=params.use_double_scores,
             )
 
             hyp_ids = get_texts(best_path)[index]
             hyp_text_list = [graph_compiler.token_table[i] for i in hyp_ids]
-            hyp_text = " ".join(hyp_text_list)
+            hyp_text = "".join(hyp_text_list).replace("▁", " ")
 
             logging.info(f"[utterance id]: {utt_id}")
             logging.info(f"[verbatim text]: {verbatim_text}")
             logging.info(f"[best alignment]: {hyp_text}")
+            logging.info(bypass_weight)
 
     return loss, info
@@ -770,7 +796,9 @@ def train_one_epoch(
                 and params.batch_idx_train % params.average_period == 0
             ):
                 update_averaged_model(
-                    params=params, model_cur=model, model_avg=model_avg,
+                    params=params,
+                    model_cur=model,
+                    model_avg=model_avg,
                 )
 
         if (
@@ -790,7 +818,9 @@ def train_one_epoch(
                 rank=rank,
             )
             remove_checkpoints(
-                out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
             )
 
         if batch_idx % params.log_interval == 0:
diff --git a/egs/librispeech/WSASR/conformer_ctc2/transformer.py b/egs/librispeech/WSASR/conformer_ctc2/transformer.py
index be71a1b49..41e6cd357 100644
--- a/egs/librispeech/WSASR/conformer_ctc2/transformer.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/transformer.py
@@ -909,7 +909,9 @@ class Noam(object):
 
 
 def encoder_padding_mask(
-    max_len: int, subsampling_factor: Optional[int] = 4, supervisions: Optional[Supervisions] = None
+    max_len: int,
+    subsampling_factor: Optional[int] = 4,
+    supervisions: Optional[Supervisions] = None,
 ) -> Optional[torch.Tensor]:
     """Make mask tensor containing indexes of padded part.
diff --git a/egs/librispeech/WSASR/local/compute_ssl_librispeech.py b/egs/librispeech/WSASR/local/compute_ssl_librispeech.py
index 47ba4f49c..f405c468c 100755
--- a/egs/librispeech/WSASR/local/compute_ssl_librispeech.py
+++ b/egs/librispeech/WSASR/local/compute_ssl_librispeech.py
@@ -29,7 +29,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, S3PRLSSL, S3PRLSSLConfig, NumpyFilesWriter
+from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -71,9 +71,7 @@ def compute_ssl_librispeech():
         dataset_parts,
     )
 
-    extractor = S3PRLSSL(
-        S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda")
-    )
+    extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda"))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -95,9 +93,7 @@ def compute_ssl_librispeech():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
diff --git a/egs/librispeech/WSASR/local/get_words_from_lexicon.py b/egs/librispeech/WSASR/local/get_words_from_lexicon.py
index fc8e2e7d1..0cc740b36 100755
--- a/egs/librispeech/WSASR/local/get_words_from_lexicon.py
+++ b/egs/librispeech/WSASR/local/get_words_from_lexicon.py
@@ -18,7 +18,9 @@ def get_args():
     )
 
     parser.add_argument(
-        "--otc-token", type=str, help="OTC token to be added to words.txt",
+        "--otc-token",
+        type=str,
+        help="OTC token to be added to words.txt",
     )
 
     return parser.parse_args()
diff --git a/egs/librispeech/WSASR/local/make_error_cutset.py b/egs/librispeech/WSASR/local/make_error_cutset.py
index 50ac41bd3..8463a380e 100755
--- a/egs/librispeech/WSASR/local/make_error_cutset.py
+++ b/egs/librispeech/WSASR/local/make_error_cutset.py
@@ -7,10 +7,11 @@ import random
 from pathlib import Path
 from typing import List
 
-from icefall.utils import str2bool
 from lhotse import CutSet, load_manifest
 from lhotse.cut.base import Cut
 
+from icefall.utils import str2bool
+
 
 def get_args():
     parser = argparse.ArgumentParser(
@@ -23,23 +24,36 @@ def get_args():
     )
 
     parser.add_argument(
-        "--words-file", type=str, help="words.txt file",
+        "--words-file",
+        type=str,
+        help="words.txt file",
     )
 
     parser.add_argument(
-        "--otc-token", type=str, help="OTC token in words.txt",
+        "--otc-token",
+        type=str,
+        help="OTC token in words.txt",
     )
 
     parser.add_argument(
-        "--sub-error-rate", type=float, default=0.0, help="Substitution error rate",
+        "--sub-error-rate",
+        type=float,
+        default=0.0,
+        help="Substitution error rate",
     )
 
     parser.add_argument(
-        "--ins-error-rate", type=float, default=0.0, help="Insertion error rate",
+        "--ins-error-rate",
+        type=float,
+        default=0.0,
+        help="Insertion error rate",
     )
 
     parser.add_argument(
-        "--del-error-rate", type=float, default=0.0, help="Deletion error rate",
+        "--del-error-rate",
+        type=float,
+        default=0.0,
+        help="Deletion error rate",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/WSASR/local/prepare_lang.py b/egs/librispeech/WSASR/local/prepare_lang.py
index 60905133e..d913756a1 100755
--- a/egs/librispeech/WSASR/local/prepare_lang.py
+++ b/egs/librispeech/WSASR/local/prepare_lang.py
@@ -324,7 +324,9 @@ def lexicon_to_fst(
         disambig_token = token2id["#0"]
         disambig_word = word2id["#0"]
         arcs = add_self_loops(
-            arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+            arcs,
+            disambig_token=disambig_token,
+            disambig_word=disambig_word,
         )
 
     final_state = next_state
diff --git a/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py b/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py
index c9e22c426..415bdff6f 100755
--- a/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py
+++ b/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py
@@ -109,7 +109,9 @@ def lexicon_to_fst_no_sil(
         disambig_token = token2id["#0"]
         disambig_word = word2id["#0"]
         arcs = add_self_loops(
-            arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+            arcs,
+            disambig_token=disambig_token,
+            disambig_word=disambig_word,
         )
 
     final_state = next_state
@@ -126,7 +128,10 @@ def lexicon_to_fst_no_sil(
 
 
 def generate_otc_lexicon(
-    model_file: str, words: List[str], oov: str, otc_token: str,
+    model_file: str,
+    words: List[str],
+    oov: str,
+    otc_token: str,
 ) -> Tuple[Lexicon, Dict[str, int]]:
     """Generate a lexicon from a BPE model.
@@ -188,7 +193,10 @@ def get_args():
     )
 
     parser.add_argument(
-        "--otc-token", type=str, default="", help="The OTC token in lexicon.",
+        "--otc-token",
+        type=str,
+        default="",
+        help="The OTC token in lexicon.",
     )
 
     parser.add_argument(
@@ -256,7 +264,9 @@ def main():
     write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
 
     L = lexicon_to_fst_no_sil(
-        lexicon, token2id=token_sym_table, word2id=word_sym_table,
+        lexicon,
+        token2id=token_sym_table,
+        word2id=word_sym_table,
     )
 
     L_disambig = lexicon_to_fst_no_sil(
diff --git a/icefall/otc_graph_compiler.py b/icefall/otc_graph_compiler.py
index 3b823becb..c7bc79ea0 100644
--- a/icefall/otc_graph_compiler.py
+++ b/icefall/otc_graph_compiler.py
@@ -38,7 +38,6 @@ class OtcTrainingGraphCompiler(object):
         initial_self_loop_weight: float = 0.0,
         bypass_weight_decay: float = 0.0,
         self_loop_weight_decay: float = 0.0,
-
     ) -> None:
         """
         Args:
@@ -93,7 +92,11 @@ class OtcTrainingGraphCompiler(object):
         return max_token_id
 
     def make_arc(
-        self, from_state: int, to_state: int, symbol: Union[str, int], weight: float,
+        self,
+        from_state: int,
+        to_state: int,
+        symbol: Union[str, int],
+        weight: float,
     ):
         return f"{from_state} {to_state} {symbol} {weight}"
@@ -132,7 +135,7 @@ class OtcTrainingGraphCompiler(object):
             Whether to add bypass arc to training graph for substitution and
             insertion errors (wrong or extra words in the transcript).
           allow_self_loop_arc:
-            Whether to add self-loop arc to training graph for deletion 
+            Whether to add self-loop arc to training graph for deletion
             errors (missing words in the transcript).
          bypass_weight:
            Weight associated with bypass arc.
@@ -140,7 +143,7 @@ class OtcTrainingGraphCompiler(object):
             Weight associated with self-loop arc.
           otc_granularity:
             Use OTC token to model word or subword.
-        
+
         Return:
           Return an FsaVec, which is the result of composing a CTC topology
           with OTC FSAs constructed from the given texts.
@@ -161,7 +164,9 @@ class OtcTrainingGraphCompiler(object):
         fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop)
 
         graph = k2.compose(
-            self.ctc_topo, fsa_with_self_loop, treat_epsilons_specially=False,
+            self.ctc_topo,
+            fsa_with_self_loop,
+            treat_epsilons_specially=False,
         )
         assert graph.requires_grad is False
@@ -201,7 +206,10 @@ class OtcTrainingGraphCompiler(object):
 
             if allow_self_loop_arc:
                 self_loop_arc = self.make_arc(
-                    cur_state, cur_state, otc_token_id, self_loop_weight,
+                    cur_state,
+                    cur_state,
+                    otc_token_id,
+                    self_loop_weight,
                 )
                 arcs.append(self_loop_arc)
@@ -225,7 +233,10 @@ class OtcTrainingGraphCompiler(object):
 
             if allow_self_loop_arc:
                 self_loop_arc = self.make_arc(
-                    cur_state, cur_state, otc_token_id, self_loop_weight,
+                    cur_state,
+                    cur_state,
+                    otc_token_id,
+                    self_loop_weight,
                 )
                 arcs.append(self_loop_arc)
diff --git a/icefall/utils.py b/icefall/utils.py
index 2671bb8c5..16148449d 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -262,6 +262,7 @@ def get_texts(
     else:
         return aux_labels.tolist()
 
+
 def encode_supervisions_otc(
     supervisions: dict,
     subsampling_factor: int,