From 6f5d63492a32f4e48b60ae6d71dc529ddefef22d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 24 Sep 2021 19:45:08 +0800 Subject: [PATCH] Refactoring. --- .../ASR/conformer_ctc/transformer.py | 6 +- .../ASR/conformer_mmi/asr_datamodule.py | 354 ++++++++++++++++++ .../ASR/conformer_mmi/conformer.py | 48 ++- egs/librispeech/ASR/conformer_mmi/decode.py | 222 ++++++++--- .../ASR/conformer_mmi/transformer.py | 84 +++-- ... => convert_transcript_words_to_tokens.py} | 23 +- .../ASR/local/generate_unique_lexicon.py | 100 +++++ egs/librispeech/ASR/local/prepare_lang.py | 62 ++- egs/librispeech/ASR/local/prepare_lang_bpe.py | 28 ++ egs/librispeech/ASR/local/train_bpe_model.py | 12 +- egs/librispeech/ASR/prepare.sh | 36 +- icefall/bpe_graph_compiler.py | 6 +- icefall/bpe_mmi_graph_compiler.py | 178 --------- icefall/lexicon.py | 158 +++++--- icefall/mmi_graph_compiler.py | 216 +++++++++++ icefall/utils.py | 89 ++++- test/test_bpe_graph_compiler.py | 9 +- test/test_bpe_mmi_graph_compiler.py | 30 -- test/test_lexicon.py | 173 ++++++--- test/test_mmi_graph_compiler.py | 196 ++++++++++ 20 files changed, 1543 insertions(+), 487 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_mmi/asr_datamodule.py rename egs/librispeech/ASR/local/{convert_transcript_to_corpus.py => convert_transcript_words_to_tokens.py} (83%) create mode 100755 egs/librispeech/ASR/local/generate_unique_lexicon.py delete mode 100644 icefall/bpe_mmi_graph_compiler.py create mode 100644 icefall/mmi_graph_compiler.py delete mode 100644 test/test_bpe_mmi_graph_compiler.py mode change 100644 => 100755 test/test_lexicon.py create mode 100755 test/test_mmi_graph_compiler.py diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index f1d7cbbbc..68a4ff65c 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -114,7 +114,10 @@ class Transformer(nn.Module): norm=encoder_norm, ) - self.encoder_output_layer = nn.Linear(d_model, num_classes) + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) if num_decoder_layers > 0: self.decoder_num_class = ( @@ -325,6 +328,7 @@ class Transformer(nn.Module): """ # The common part between this function and decoder_forward could be # extracted as a separate function. + ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) diff --git a/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py b/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py new file mode 100644 index 000000000..8290e71d1 --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py @@ -0,0 +1,354 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import List, Union + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( + BucketingSampler, + CutConcatenate, + CutMix, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool + + +class LibriSpeechAsrDataModule(DataModule): + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + super().add_arguments(parser) + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + group.add_argument( + "--feature-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + def train_dataloaders(self) -> DataLoader: + logging.info("About to get train cuts") + cuts_train = self.train_cuts() + + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") + + logging.info("About to create train dataset") + transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [ + SpecAugment( + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ] + + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using BucketingSampler.") + train_sampler = BucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self) -> DataLoader: + logging.info("About to get dev cuts") + cuts_valid = self.valid_cuts() + + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = SingleCutSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: + cuts = self.test_cuts() + is_list = isinstance(cuts, list) + test_loaders = [] + if not is_list: + cuts = [cuts] + + for cuts_test in cuts: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = SingleCutSampler( + cuts_test, max_duration=self.args.max_duration + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, batch_size=None, sampler=sampler, num_workers=1 + ) + test_loaders.append(test_dl) + + if is_list: + return test_loaders + else: + return test_loaders[0] + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest( + self.args.feature_dir / "cuts_train-clean-100.json.gz" + ) + if self.args.full_libri: + cuts_train = ( + cuts_train + + load_manifest( + self.args.feature_dir / "cuts_train-clean-360.json.gz" + ) + + load_manifest( + self.args.feature_dir / "cuts_train-other-500.json.gz" + ) + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest( + self.args.feature_dir / "cuts_dev-clean.json.gz" + ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") + return cuts_valid + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + test_sets = ["test-clean", "test-other"] + cuts = [] + for test_set in test_sets: + logging.debug("About to get test cuts") + cuts.append( + load_manifest( + self.args.feature_dir / f"cuts_{test_set}.json.gz" + ) + ) + return cuts diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index ac49b7b1c..b19b94db1 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -1,7 +1,20 @@ #!/usr/bin/env python3 - # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# Apache 2.0 +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import warnings @@ -43,7 +56,6 @@ class Conformer(Transformer): cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, - is_espnet_structure: bool = False, use_feat_batchnorm: bool = False, ) -> None: super(Conformer, self).__init__( @@ -70,12 +82,10 @@ class Conformer(Transformer): dropout, cnn_module_kernel, normalize_before, - is_espnet_structure, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.normalize_before = normalize_before - self.is_espnet_structure = is_espnet_structure - if self.normalize_before and self.is_espnet_structure: + if self.normalize_before: self.after_norm = nn.LayerNorm(d_model) else: # Note: TorchScript detects that self.after_norm could be used inside forward() @@ -88,7 +98,7 @@ class Conformer(Transformer): """ Args: x: - The model input. Its shape is [N, T, C]. + The model input. Its shape is (N, T, C). supervisions: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -110,7 +120,7 @@ class Conformer(Transformer): mask = mask.to(x.device) x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) - if self.normalize_before and self.is_espnet_structure: + if self.normalize_before: x = self.after_norm(x) return x, mask @@ -144,11 +154,10 @@ class ConformerEncoderLayer(nn.Module): dropout: float = 0.1, cnn_module_kernel: int = 31, normalize_before: bool = True, - is_espnet_structure: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure + d_model, nhead, dropout=0.0 ) self.feed_forward = nn.Sequential( @@ -394,7 +403,7 @@ class RelPositionalEncoding(torch.nn.Module): :, self.pe.size(1) // 2 - x.size(1) - + 1 : self.pe.size(1) // 2 + + 1 : self.pe.size(1) // 2 # noqa E203 + x.size(1), ] return self.dropout(x), self.dropout(pos_emb) @@ -421,7 +430,6 @@ class RelPositionMultiheadAttention(nn.Module): embed_dim: int, num_heads: int, dropout: float = 0.0, - is_espnet_structure: bool = False, ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -444,8 +452,6 @@ class RelPositionMultiheadAttention(nn.Module): self._reset_parameters() - self.is_espnet_structure = is_espnet_structure - def _reset_parameters(self) -> None: nn.init.xavier_uniform_(self.in_proj.weight) nn.init.constant_(self.in_proj.bias, 0.0) @@ -675,9 +681,6 @@ class RelPositionMultiheadAttention(nn.Module): _b = _b[_start:] v = nn.functional.linear(value, _w, _b) - if not self.is_espnet_structure: - q = q * scaling - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -770,14 +773,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - if not self.is_espnet_structure: - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) - else: - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, -1 diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index 6030d13e1..dc2e449c2 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -1,8 +1,20 @@ #!/usr/bin/env python3 - # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -# (still working in progress) import argparse import logging @@ -13,14 +25,15 @@ from typing import Dict, List, Optional, Tuple import k2 import torch import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer -from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.decode import ( get_lattice, nbest_decoding, + nbest_oracle, one_best_decoding, rescore_with_attention_decoder, rescore_with_n_best_list, @@ -32,6 +45,7 @@ from icefall.utils import ( get_texts, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -44,51 +58,111 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=9, + default=34, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, - default=1, + default=20, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + """, + ) + + parser.add_argument( + "--lattice-score-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + conformer_mmi/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_mmi/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe", + help="The lang dir", + ) + return parser def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("conformer_mmi/exp"), - "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, "feature_dim": 80, "nhead": 8, "attention_dim": 512, - "subsampling_factor": 4, "num_decoder_layers": 6, - "vgg_frontend": False, - "is_espnet_structure": True, - "use_feat_batchnorm": True, + # parameters for decoding "search_beam": 20, "output_beam": 8, "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, - # Possible values for method: - # - 1best - # - nbest - # - nbest-rescoring - # - whole-lattice-rescoring - # - attention-decoder - # "method": "whole-lattice-rescoring", - "method": "1best", - # num_paths is used when method is "nbest", "nbest-rescoring", - # and attention-decoder - "num_paths": 100, } ) return params @@ -99,7 +173,7 @@ def decode_one_batch( model: nn.Module, HLG: k2.Fsa, batch: dict, - lexicon: Lexicon, + word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, @@ -133,8 +207,8 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. - lexicon: - It contains word symbol table. + word_table: + The word symbol table. sos_id: The token ID of the SOS. eos_id: @@ -151,12 +225,12 @@ def decode_one_batch( feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) supervisions = batch["supervisions"] nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) supervision_segments = torch.stack( ( @@ -178,6 +252,24 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + lattice_score_scale=params.lattice_score_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa + return {key: hyps} + if params.method in ["1best", "nbest"]: if params.method == "1best": best_path = one_best_decoding( @@ -189,11 +281,12 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, + lattice_score_scale=params.lattice_score_scale, ) - key = f"no_rescore-{params.num_paths}" + key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] + hyps = [[word_table[i] for i in ids] for ids in hyps] return {key: hyps} assert params.method in [ @@ -202,7 +295,8 @@ def decode_one_batch( "attention-decoder", ] - lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] if params.method == "nbest-rescoring": @@ -211,16 +305,23 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, + lattice_score_scale=params.lattice_score_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, ) elif params.method == "attention-decoder": # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, ) + # TODO: pass `lattice` instead of `rescored_lattice` to + # `rescore_with_attention_decoder` best_path_dict = rescore_with_attention_decoder( lattice=rescored_lattice, @@ -230,15 +331,20 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, + lattice_score_scale=params.lattice_score_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" ans = dict() - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + for lm_scale in lm_scale_list: + ans[lm_scale_str] = [[] * lattice.shape[0]] return ans @@ -247,7 +353,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, HLG: k2.Fsa, - lexicon: Lexicon, + word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, @@ -263,8 +369,8 @@ def decode_dataset( The neural model. HLG: The decoding graph. - lexicon: - It contains word symbol table. + word_table: + It is the word symbol table. sos_id: The token ID for SOS. eos_id: @@ -283,7 +389,11 @@ def decode_dataset( results = [] num_cuts = 0 - tot_num_cuts = len(dl.dataset.cuts) + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -294,7 +404,7 @@ def decode_dataset( model=model, HLG=HLG, batch=batch, - lexicon=lexicon, + word_table=word_table, G=G, sos_id=sos_id, eos_id=eos_id, @@ -312,10 +422,10 @@ def decode_dataset( num_cuts += len(batch["supervisions"]["text"]) if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info( - f"batch {batch_idx}, cuts processed until now is " - f"{num_cuts}/{tot_num_cuts} " - f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + f"batch {batch_str}, cuts processed until now is {num_cuts}" ) return results @@ -374,8 +484,10 @@ def main(): params = get_params() params.update(vars(args)) + params.exp_dir = Path(params.exp_dir) + params.lang_dir = Path(params.lang_dir) - setup_logger(f"{params.exp_dir}/log/log-decode") + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") logging.info("Decoding started") logging.info(params) @@ -389,7 +501,7 @@ def main(): logging.info(f"device: {device}") - graph_compiler = BpeMmiTrainingGraphCompiler( + graph_compiler = BpeCtcTrainingGraphCompiler( params.lang_dir, device=device, sos_token="", @@ -398,7 +510,9 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -429,7 +543,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") G = k2.Fsa.from_dict(d).to(device) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: @@ -453,7 +567,6 @@ def main(): subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=params.vgg_frontend, - is_espnet_structure=params.is_espnet_structure, use_feat_batchnorm=params.use_feat_batchnorm, ) @@ -468,6 +581,13 @@ def main(): logging.info(f"averaging {filenames}") model.load_state_dict(average_checkpoints(filenames)) + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) @@ -487,7 +607,7 @@ def main(): params=params, model=model, HLG=HLG, - lexicon=lexicon, + word_table=lexicon.word_table, G=G, sos_id=sos_id, eos_id=eos_id, diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index fd1a082e7..68a4ff65c 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -1,15 +1,26 @@ -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# Apache 2.0 +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math from typing import Dict, List, Optional, Tuple -import k2 import torch import torch.nn as nn from subsampling import Conv2dSubsampling, VggSubsampling - -from icefall.utils import get_texts from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. @@ -72,8 +83,8 @@ class Transformer(nn.Module): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - # self.encoder_embed converts the input of shape [N, T, num_classes] - # to the shape [N, T//subsampling_factor, d_model]. + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_classes -> d_model @@ -103,10 +114,15 @@ class Transformer(nn.Module): norm=encoder_norm, ) - self.encoder_output_layer = nn.Linear(d_model, num_classes) + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) if num_decoder_layers > 0: - self.decoder_num_class = self.num_classes + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol self.decoder_embed = nn.Embedding( num_embeddings=self.decoder_num_class, embedding_dim=d_model @@ -146,7 +162,7 @@ class Transformer(nn.Module): """ Args: x: - The input tensor. Its shape is [N, T, C]. + The input tensor. Its shape is (N, T, C). supervision: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -155,17 +171,17 @@ class Transformer(nn.Module): Returns: Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is [N, T, C] - - Encoder output with shape [T, N, C]. It can be used as key and + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and value for the decoder. - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is [N, T]. + memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) encoder_memory, memory_key_padding_mask = self.run_encoder( x, supervision ) @@ -179,7 +195,7 @@ class Transformer(nn.Module): Args: x: - The model input. Its shape is [N, T, C]. + The model input. Its shape is (N, T, C). supervisions: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -190,8 +206,8 @@ class Transformer(nn.Module): padding mask for the decoder. Returns: Return a tuple with two tensors: - - The encoder output, with shape [T, N, C] - - encoder padding mask, with shape [N, T]. + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). The mask is None if `supervisions` is None. It is used as memory key padding mask in the decoder. """ @@ -209,11 +225,11 @@ class Transformer(nn.Module): Args: x: The output tensor from the transformer encoder. - Its shape is [T, N, C] + Its shape is (T, N, C) Returns: Return a tensor that can be used for CTC decoding. - Its shape is [N, T, C] + Its shape is (N, T, C) """ x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -231,7 +247,7 @@ class Transformer(nn.Module): """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -296,7 +312,7 @@ class Transformer(nn.Module): """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -312,6 +328,7 @@ class Transformer(nn.Module): """ # The common part between this function and decoder_forward could be # extracted as a separate function. + ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) @@ -329,6 +346,9 @@ class Transformer(nn.Module): ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. tgt_key_padding_mask[:, 0] = False tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) @@ -634,13 +654,13 @@ class PositionalEncoding(nn.Module): def extend_pe(self, x: torch.Tensor) -> None: """Extend the time t in the positional encoding if required. - The shape of `self.pe` is [1, T1, d_model]. The shape of the input x - is [N, T, d_model]. If T > T1, then we change the shape of self.pe - to [N, T, d_model]. Otherwise, nothing is done. + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. Args: x: - It is a tensor of shape [N, T, C]. + It is a tensor of shape (N, T, C). Returns: Return None. """ @@ -658,7 +678,7 @@ class PositionalEncoding(nn.Module): pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) - # Now pe is of shape [1, T, d_model], where T is x.size(1) + # Now pe is of shape (1, T, d_model), where T is x.size(1) self.pe = pe.to(device=x.device, dtype=x.dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -667,10 +687,10 @@ class PositionalEncoding(nn.Module): Args: x: - Its shape is [N, T, C] + Its shape is (N, T, C) Returns: - Return a tensor of shape [N, T, C] + Return a tensor of shape (N, T, C) """ self.extend_pe(x) x = x * self.xscale + self.pe[:, : x.size(1), :] @@ -766,7 +786,8 @@ class Noam(object): class LabelSmoothingLoss(nn.Module): """ - Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) + Label-smoothing loss. KL-divergence between + q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w) is minimized. Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa @@ -851,7 +872,8 @@ def encoder_padding_mask( frames, before subsampling) Returns: - Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. + Tensor: Mask tensor of dimension (batch_size, input_length), + True denote the masked indices. """ if supervisions is None: return None diff --git a/egs/librispeech/ASR/local/convert_transcript_to_corpus.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py similarity index 83% rename from egs/librispeech/ASR/local/convert_transcript_to_corpus.py rename to egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index bb02dac58..133499c8b 100755 --- a/egs/librispeech/ASR/local/convert_transcript_to_corpus.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -8,8 +8,8 @@ for LM training with the help of a lexicon. If the lexicon contains phones, the resulting LM will be a phone LM; If the lexicon contains word pieces, the resulting LM will be a word piece LM. -If a word has multiple pronunciations, the one that appears last in the lexicon -is used. +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. If the input transcript is: @@ -20,8 +20,8 @@ If the input transcript is: and if the lexicon is SPN - hello h e l l o hello h e l l o 2 + hello h e l l o world w o r l d zoo z o o @@ -32,10 +32,11 @@ Then the output is SPN z o o w o r l d SPN """ -from pathlib import Path -from typing import Dict - import argparse +from pathlib import Path +from typing import Dict, List + +from generate_unique_lexicon import filter_multiple_pronunications from icefall.lexicon import read_lexicon @@ -57,7 +58,9 @@ def get_args(): return parser.parse_args() -def process_line(lexicon: Dict[str, str], line: str, oov_token: str) -> None: +def process_line( + lexicon: Dict[str, List[str]], line: str, oov_token: str +) -> None: """ Args: lexicon: @@ -86,7 +89,11 @@ def main(): assert Path(args.transcript).is_file() assert len(args.oov) > 0 - lexicon = dict(read_lexicon(args.lexicon)) + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + assert args.oov in lexicon oov_token = lexicon[args.oov] diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py new file mode 100755 index 000000000..566c0743d --- /dev/null +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input a lexicon.txt and output a new lexicon, +in which each word has a unique pronunciation. + +The way to do this is to keep only the first pronunciation of a word +in lexicon.txt. +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +from icefall.lexicon import read_lexicon, write_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + This file will generate a new file uniq_lexicon.txt + in it. + """, + ) + + return parser.parse_args() + + +def filter_multiple_pronunications( + lexicon: List[Tuple[str, List[str]]] +) -> List[Tuple[str, List[str]]]: + """Remove multiple pronunciations of words from a lexicon. + + If a word has more than one pronunciation in the lexicon, only + the first one is kept, while other pronunciations are removed + from the lexicon. + + Args: + lexicon: + The input lexicon, containing a list of (word, [p1, p2, ..., pn]), + where "p1, p2, ..., pn" are the pronunciations of the "word". + Returns: + Return a new lexicon where each word has a unique pronunciation. + """ + seen = set() + ans = [] + + for word, tokens in lexicon: + if word in seen: + continue + seen.add(word) + ans.append((word, tokens)) + return ans + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + lexicon_filename = lang_dir / "lexicon.txt" + + in_lexicon = read_lexicon(lexicon_filename) + + out_lexicon = filter_multiple_pronunications(in_lexicon) + + write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) + + logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") + logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py index 0880019b3..d913756a1 100755 --- a/egs/librispeech/ASR/local/prepare_lang.py +++ b/egs/librispeech/ASR/local/prepare_lang.py @@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following: 5. Generate L_disambig.pt, in k2 format. """ +import argparse import math from collections import defaultdict from pathlib import Path @@ -42,10 +43,37 @@ import k2 import torch from icefall.lexicon import read_lexicon, write_lexicon +from icefall.utils import str2bool Lexicon = List[Tuple[str, List[str]]] +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + """, + ) + + return parser.parse_args() + + def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: """Write a symbol to ID mapping to a file. @@ -315,8 +343,9 @@ def lexicon_to_fst( def main(): - out_dir = Path("data/lang_phone") - lexicon_filename = out_dir / "lexicon.txt" + args = get_args() + lang_dir = Path(args.lang_dir) + lexicon_filename = lang_dir / "lexicon.txt" sil_token = "SIL" sil_prob = 0.5 @@ -344,9 +373,9 @@ def main(): token2id = generate_id_map(tokens) word2id = generate_id_map(words) - write_mapping(out_dir / "tokens.txt", token2id) - write_mapping(out_dir / "words.txt", word2id) - write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) + write_mapping(lang_dir / "tokens.txt", token2id) + write_mapping(lang_dir / "words.txt", word2id) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) L = lexicon_to_fst( lexicon, @@ -364,17 +393,20 @@ def main(): sil_prob=sil_prob, need_self_loops=True, ) - torch.save(L.as_dict(), out_dir / "L.pt") - torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt") + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(out_dir / "L.png", title="L") - L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index 39d347661..cf32f308d 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -49,6 +49,8 @@ from prepare_lang import ( write_mapping, ) +from icefall.utils import str2bool + def lexicon_to_fst_no_sil( lexicon: Lexicon, @@ -169,6 +171,20 @@ def get_args(): """, ) + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + return parser.parse_args() @@ -221,6 +237,18 @@ def main(): torch.save(L.as_dict(), lang_dir / "L.pt") torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index 3c3ecdcae..bc5812810 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + # You can install sentencepiece via: # # pip install sentencepiece @@ -37,10 +38,17 @@ def get_args(): "--lang-dir", type=str, help="""Input and output directory. - It should contain the training corpus: train.txt. + It should contain the training corpus: transcript_words.txt. The generated bpe.model is saved to this directory. """, ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + parser.add_argument( "--vocab-size", type=int, @@ -58,7 +66,7 @@ def main(): model_type = "unigram" model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = f"{lang_dir}/train.txt" + train_text = args.transcript character_coverage = 1.0 input_sentence_size = 100000000 diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 564f0d067..1965dc491 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -40,9 +40,9 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - 5000 - 2000 - 1000 + # 5000 + # 2000 + # 1000 500 ) @@ -116,14 +116,15 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Prepare phone based lang" - mkdir -p data/lang_phone + lang_dir=data/lang_phone + mkdir -p $lang_dir (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | cat - $dl_dir/lm/librispeech-lexicon.txt | - sort | uniq > data/lang_phone/lexicon.txt + sort | uniq > $lang_dir/lexicon.txt - if [ ! -f data/lang_phone/L_disambig.pt ]; then - ./local/prepare_lang.py + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir fi fi @@ -138,7 +139,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then # so that the two can share G.pt later. cp data/lang_phone/words.txt $lang_dir - if [ ! -f $lang_dir/train.txt ]; then + if [ ! -f $lang_dir/transcript_words.txt ]; then log "Generate data for BPE training" files=$( find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt" @@ -147,12 +148,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then ) for f in ${files[@]}; do cat $f | cut -d " " -f 2- - done > $lang_dir/train.txt + done > $lang_dir/transcript_words.txt fi ./local/train_bpe_model.py \ --lang-dir $lang_dir \ - --vocab-size $vocab_size + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt if [ ! -f $lang_dir/L_disambig.pt ]; then ./local/prepare_lang_bpe.py --lang-dir $lang_dir @@ -166,18 +168,18 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} - if [ ! -f $lang_dir/corpus.txt ]; then - ./local/convert_transcript_to_corpus.py \ - --lexicon data/lang_bpe/lexicon.txt \ - --transcript data/lang_bpe/train.txt \ + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ --oov "" \ - > $lang_dir/corpus.txt + > $lang_dir/transcript_tokens.txt fi if [ ! -f $lang_dir/P.arpa ]; then ./shared/make_kn_lm.py \ -ngram-order 2 \ - -text $lang_dir/corpus.txt \ + -text $lang_dir/transcript_tokens.txt \ -lm $lang_dir/P.arpa fi @@ -226,4 +228,4 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then done fi -cd data && ln -sfv lang_bpe_5000 lang_bpe +cd data && ln -sfv lang_bpe_500 lang_bpe diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py index 813b15f76..e76b7ea32 100644 --- a/icefall/bpe_graph_compiler.py +++ b/icefall/bpe_graph_compiler.py @@ -34,14 +34,10 @@ class BpeCtcTrainingGraphCompiler(object): """ Args: lang_dir: - This directory is expected to contain the following files:: + This directory is expected to contain the following files: - bpe.model - words.txt - - The above files are produced by the script `prepare.sh`. You - should have run that before running the training code. - device: It indicates CPU or CUDA. sos_token: diff --git a/icefall/bpe_mmi_graph_compiler.py b/icefall/bpe_mmi_graph_compiler.py deleted file mode 100644 index 83bc9846f..000000000 --- a/icefall/bpe_mmi_graph_compiler.py +++ /dev/null @@ -1,178 +0,0 @@ -import logging -from pathlib import Path -from typing import List, Tuple, Union - -import k2 -import sentencepiece as spm -import torch - -from icefall.lexicon import Lexicon - - -class BpeMmiTrainingGraphCompiler(object): - def __init__( - self, - lang_dir: Path, - device: Union[str, torch.device] = "cpu", - sos_token: str = "", - eos_token: str = "", - ) -> None: - """ - Args: - lang_dir: - Path to the lang directory. It is expected to contain the - following files:: - - - tokens.txt - - words.txt - - bpe.model - - P.fst.txt - - The above files are generated by the script `prepare.sh`. You - should have run it before running the training code. - - device: - It indicates CPU or CUDA. - sos_token: - The word piece that represents sos. - eos_token: - The word piece that represents eos. - """ - self.lang_dir = Path(lang_dir) - self.lexicon = Lexicon(lang_dir) - self.device = device - self.load_sentence_piece_model() - self.build_ctc_topo_P() - - self.sos_id = self.sp.piece_to_id(sos_token) - self.eos_id = self.sp.piece_to_id(eos_token) - - assert self.sos_id != self.sp.unk_id() - assert self.eos_id != self.sp.unk_id() - - def load_sentence_piece_model(self) -> None: - """Load the pre-trained sentencepiece model - from self.lang_dir/bpe.model. - """ - model_file = self.lang_dir / "bpe.model" - sp = spm.SentencePieceProcessor() - sp.load(str(model_file)) - self.sp = sp - - def build_ctc_topo_P(self): - """Built ctc_topo_P, the composition result of - ctc_topo and P, where P is a pre-trained bigram - word piece LM. - """ - # Note: there is no need to save a pre-compiled P and ctc_topo - # as it is very fast to generate them. - logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}") - with open(self.lang_dir / "P.fst.txt") as f: - # P is not an acceptor because there is - # a back-off state, whose incoming arcs - # have label #0 and aux_label 0 (i.e., ). - P = k2.Fsa.from_openfst(f.read(), acceptor=False) - - first_token_disambig_id = self.lexicon.token_table["#0"] - - # P.aux_labels is not needed in later computations, so - # remove it here. - del P.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - P.labels[P.labels >= first_token_disambig_id] = 0 - - P = k2.remove_epsilon(P) - P = k2.arc_sort(P) - P = P.to(self.device) - # Add epsilon self-loops to P because we want the - # following operation "k2.intersect" to run on GPU. - P_with_self_loops = k2.add_epsilon_self_loops(P) - - max_token_id = max(self.lexicon.tokens) - logging.info( - f"Building modified ctc_topo. max_token_id: {max_token_id}" - ) - # CAUTION: We have to use a modifed version of CTC topo. - # Otherwise, the resulting ctc_topo_P is so large that it gets - # stuck in k2.intersect_dense_pruned() or it gets OOM in - # k2.intersect_dense() - ctc_topo = k2.ctc_topo(max_token_id, modified=True, device=self.device) - - ctc_topo_inv = k2.arc_sort(ctc_topo.invert_()) - - logging.info("Building ctc_topo_P") - ctc_topo_P = k2.intersect( - ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False - ).invert() - - self.ctc_topo_P = k2.arc_sort(ctc_topo_P) - - def texts_to_ids(self, texts: List[str]) -> List[List[int]]: - """Convert a list of texts to a list-of-list of piece IDs. - - Args: - texts: - A list of transcripts. Within a transcript words are - separated by spaces. An example input is:: - - ['HELLO ICEFALL', 'HELLO k2'] - Returns: - Return a list-of-list of piece IDs. - """ - return self.sp.encode(texts, out_type=int) - - def compile( - self, texts: List[str], replicate_den: bool = True - ) -> Tuple[k2.Fsa, k2.Fsa]: - """Create numerator and denominator graphs from transcripts. - - Args: - texts: - A list of transcripts. Within a transcript words are - separated by spaces. An example input is:: - - ["HELLO icefall", "HALLO WELT"] - - replicate_den: - If True, the returned den_graph is replicated to match the number - of FSAs in the returned num_graph; if False, the returned den_graph - contains only a single FSA - Returns: - A tuple (num_graphs, den_graphs), where - - - `num_graphs` is the numerator graph. It is an FsaVec with - shape `(len(texts), None, None)`. - - - `den_graphs` is the denominator graph. It is an FsaVec with the - same shape of the `num_graph` if replicate_den is True; - otherwise, it is an FsaVec containing only a single FSA. - """ - token_ids = self.texts_to_ids(texts) - token_fsas = k2.linear_fsa(token_ids, device=self.device) - - token_fsas_with_self_loops = k2.add_epsilon_self_loops(token_fsas) - - # NOTE: Use treat_epsilons_specially=False so that k2.compose - # can be run on GPU - num_graphs = k2.compose( - self.ctc_topo_P, - token_fsas_with_self_loops, - treat_epsilons_specially=False, - ) - # num_graphs may not be connected and - # not be topologically sorted after k2.compose - num_graphs = k2.connect(num_graphs) - num_graphs = k2.top_sort(num_graphs) - - ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P.detach()]) - if replicate_den: - indexes = torch.zeros( - len(texts), dtype=torch.int32, device=self.device - ) - den_graphs = k2.index_fsa(ctc_topo_P_vec, indexes) - else: - den_graphs = ctc_topo_P_vec - - return num_graphs, den_graphs diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 1378d79fb..80bd7c1ee 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -84,6 +84,69 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None: f.write(f"{word} {' '.join(tokens)}\n") +def convert_lexicon_to_ragged( + filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable +) -> k2.RaggedTensor: + """Read a lexicon and convert it to a ragged tensor. + + The ragged tensor has two axes: [word][token]. + + Caution: + We assume that each word has a unique pronunciation. + + Args: + filename: + Filename of the lexicon. It has a format that can be read + by :func:`read_lexicon`. + word_table: + The word symbol table. + token_table: + The token symbol table. + Returns: + A k2 ragged tensor with two axes [word][token]. + """ + disambig_id = word_table["#0"] + # We reuse the same words.txt from the phone based lexicon + # so that we can share the same G.fst. Here, we have to + # exclude some words present only in the phone based lexicon. + excluded_words = ["", "!SIL", ""] + + # epsilon is not a word, but it occupies a position + # + row_splits = [0] + token_ids_list = [] + + lexicon_tmp = read_lexicon(filename) + lexicon = dict(lexicon_tmp) + if len(lexicon_tmp) != len(lexicon): + raise RuntimeError( + "It's assumed that each word has a unique pronunciation" + ) + + for i in range(disambig_id): + w = word_table[i] + if w in excluded_words: + row_splits.append(row_splits[-1]) + continue + tokens = lexicon[w] + token_ids = [token_table[k] for k in tokens] + + row_splits.append(row_splits[-1] + len(token_ids)) + token_ids_list.extend(token_ids) + + cached_tot_size = row_splits[-1] + row_splits = torch.tensor(row_splits, dtype=torch.int32) + + shape = k2.ragged.create_ragged_shape2( + row_splits, + None, + cached_tot_size, + ) + values = torch.tensor(token_ids_list, dtype=torch.int32) + + return k2.RaggedTensor(shape, values) + + class Lexicon(object): """Phone based lexicon.""" @@ -96,12 +159,10 @@ class Lexicon(object): Args: lang_dir: Path to the lang directory. It is expected to contain the following - files:: - + files: - tokens.txt - words.txt - L.pt - The above files are produced by the script `prepare.sh`. You should have run that before running the training code. disambig_pattern: @@ -121,7 +182,7 @@ class Lexicon(object): torch.save(L_inv.as_dict(), lang_dir / "Linv.pt") # We save L_inv instead of L because it will be used to intersect with - # transcript, both of whose labels are word IDs. + # transcript FSAs, both of whose labels are word IDs. self.L_inv = L_inv self.disambig_pattern = disambig_pattern @@ -144,69 +205,66 @@ class Lexicon(object): return ans -class BpeLexicon(Lexicon): +class UniqLexicon(Lexicon): def __init__( self, lang_dir: Path, + uniq_filename: str = "uniq_lexicon.txt", disambig_pattern: str = re.compile(r"^#\d+$"), ): """ Refer to the help information in Lexicon.__init__. + + uniq_filename: It is assumed to be inside the given `lang_dir`. + + Each word in the lexicon is assumed to have a unique pronunciation. """ + lang_dir = Path(lang_dir) super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern) - self.ragged_lexicon = self.convert_lexicon_to_ragged( - lang_dir / "lexicon.txt" + self.ragged_lexicon = convert_lexicon_to_ragged( + filename=lang_dir / uniq_filename, + word_table=self.word_table, + token_table=self.token_table, ) + # TODO: should we move it to a certain device ? - def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor: - """Read a BPE lexicon from file and convert it to a - k2 ragged tensor. - - Args: - filename: - Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt - Returns: - A k2 ragged tensor with two axes [word_id] + def texts_to_token_ids( + self, texts: List[str], oov: str = "" + ) -> k2.RaggedTensor: """ - disambig_id = self.word_table["#0"] - # We reuse the same words.txt from the phone based lexicon - # so that we can share the same G.fst. Here, we have to - # exclude some words present only in the phone based lexicon. - excluded_words = ["", "!SIL", ""] + Args: + texts: + A list of transcripts. Each transcript contains space(s) + separated words. An example texts is:: - # epsilon is not a word, but it occupies on position - # - row_splits = [0] - token_ids = [] + ['HELLO k2', 'HELLO icefall'] + oov: + The OOV word. If a word in `texts` is not in the lexicon, it is + replaced with `oov`. + Returns: + Return a ragged int tensor with 2 axes [utterance][token_id] + """ + oov_id = self.word_table[oov] - lexicon = read_lexicon(filename) - lexicon = dict(lexicon) + word_ids_list = [] + for text in texts: + word_ids = [] + for word in text.split(): + if word in self.word_table: + word_ids.append(self.word_table[word]) + else: + word_ids.append(oov_id) + word_ids_list.append(word_ids) + ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32) + ans = self.ragged_lexicon.index(ragged_indexes) + ans = ans.remove_axis(ans.num_axes - 2) + return ans - for i in range(disambig_id): - w = self.word_table[i] - if w in excluded_words: - row_splits.append(row_splits[-1]) - continue - pieces = lexicon[w] - piece_ids = [self.token_table[k] for k in pieces] + def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor: + """Convert a list of words to a ragged tensor containing token IDs. - row_splits.append(row_splits[-1] + len(piece_ids)) - token_ids.extend(piece_ids) - - cached_tot_size = row_splits[-1] - row_splits = torch.tensor(row_splits, dtype=torch.int32) - - shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=cached_tot_size - ) - values = torch.tensor(token_ids, dtype=torch.int32) - - return k2.RaggedTensor(shape, values) - - def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor: - """Convert a list of words to a ragged tensor contained - word piece IDs. + We assume there are no OOVs in "words". """ word_ids = [self.word_table[w] for w in words] word_ids = torch.tensor(word_ids, dtype=torch.int32) diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py new file mode 100644 index 000000000..43f2a092a --- /dev/null +++ b/icefall/mmi_graph_compiler.py @@ -0,0 +1,216 @@ +import logging +from pathlib import Path +from typing import Iterable, List, Tuple, Union + +import k2 +import torch + +from icefall.lexicon import UniqLexicon + + +class MmiTrainingGraphCompiler(object): + def __init__( + self, + lang_dir: Path, + uniq_filename: str = "uniq_lexicon.txt", + device: Union[str, torch.device] = "cpu", + oov: str = "", + ): + """ + Args: + lang_dir: + Path to the lang directory. It is expected to contain the + following files:: + + - tokens.txt + - words.txt + - P.fst.txt + + The above files are generated by the script `prepare.sh`. You + should have run it before running the training code. + uniq_filename: + File name to the lexicon in which every word has exactly one + pronunciation. We assume this file is inside the given `lang_dir`. + + device: + It indicates CPU or CUDA. + oov: + Out of vocabulary word. When a word in the transcript + does not exist in the lexicon, it is replaced with `oov`. + """ + self.lang_dir = Path(lang_dir) + self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename) + self.device = torch.device(device) + + self.L_inv = self.lexicon.L_inv.to(self.device) + + self.oov_id = self.lexicon.word_table[oov] + + self.build_ctc_topo_P() + + def build_ctc_topo_P(self): + """Built ctc_topo_P, the composition result of + ctc_topo and P, where P is a pre-trained bigram + word piece LM. + """ + # Note: there is no need to save a pre-compiled P and ctc_topo + # as it is very fast to generate them. + logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}") + with open(self.lang_dir / "P.fst.txt") as f: + # P is not an acceptor because there is + # a back-off state, whose incoming arcs + # have label #0 and aux_label 0 (i.e., ). + P = k2.Fsa.from_openfst(f.read(), acceptor=False) + + first_token_disambig_id = self.lexicon.token_table["#0"] + + # P.aux_labels is not needed in later computations, so + # remove it here. + del P.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + P.labels[P.labels >= first_token_disambig_id] = 0 + + P = k2.remove_epsilon(P) + P = k2.arc_sort(P) + P = P.to(self.device) + # Add epsilon self-loops to P because we want the + # following operation "k2.intersect" to run on GPU. + P_with_self_loops = k2.add_epsilon_self_loops(P) + + max_token_id = max(self.lexicon.tokens) + logging.info( + f"Building ctc_topo (modified=False). max_token_id: {max_token_id}" + ) + ctc_topo = k2.ctc_topo(max_token_id, modified=False, device=self.device) + + ctc_topo_inv = k2.arc_sort(ctc_topo.invert_()) + + logging.info("Building ctc_topo_P") + ctc_topo_P = k2.intersect( + ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False + ).invert() + + self.ctc_topo_P = k2.arc_sort(ctc_topo_P) + + def compile( + self, texts: Iterable[str], replicate_den: bool = True + ) -> Tuple[k2.Fsa, k2.Fsa]: + """Create numerator and denominator graphs from transcripts + and the bigram phone LM. + + Args: + texts: + A list of transcripts. Within a transcript, words are + separated by spaces. An example `texts` is given below:: + + ["Hello icefall", "LF-MMI training with icefall using k2"] + + replicate_den: + If True, the returned den_graph is replicated to match the number + of FSAs in the returned num_graph; if False, the returned den_graph + contains only a single FSA + Returns: + A tuple (num_graph, den_graph), where + + - `num_graph` is the numerator graph. It is an FsaVec with + shape `(len(texts), None, None)`. + + - `den_graph` is the denominator graph. It is an FsaVec + with the same shape of the `num_graph` if replicate_den is + True; otherwise, it is an FsaVec containing only a single FSA. + """ + transcript_fsa = self.build_transcript_fsa(texts) + + # remove word IDs from transcript_fsa since it is not needed + del transcript_fsa.aux_labels + # NOTE: You can comment out the above statement + # if you want to run test/test_mmi_graph_compiler.py + + transcript_fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops( + transcript_fsa + ) + + transcript_fsa_with_self_loops = k2.arc_sort( + transcript_fsa_with_self_loops + ) + + num = k2.compose( + self.ctc_topo_P, + transcript_fsa_with_self_loops, + treat_epsilons_specially=False, + ) + + # CAUTION: Due to the presence of P, + # the resulting `num` may not be connected + num = k2.connect(num) + + num = k2.arc_sort(num) + + ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P]) + if replicate_den: + indexes = torch.zeros( + len(texts), dtype=torch.int32, device=self.device + ) + den = k2.index_fsa(ctc_topo_P_vec, indexes) + else: + den = ctc_topo_P_vec + + return num, den + + def build_transcript_fsa(self, texts: List[str]) -> k2.Fsa: + """Convert transcripts to an FsaVec with the help of a lexicon + and word symbol table. + + Args: + texts: + Each element is a transcript containing words separated by space(s). + For instance, it may be 'HELLO icefall', which contains + two words. + + Returns: + Return an FST (FsaVec) corresponding to the transcript. + Its `labels` is token IDs and `aux_labels` is word IDs. + """ + word_ids_list = [] + for text in texts: + word_ids = [] + for word in text.split(): + if word in self.lexicon.word_table: + word_ids.append(self.lexicon.word_table[word]) + else: + word_ids.append(self.oov_id) + word_ids_list.append(word_ids) + + fsa = k2.linear_fsa(word_ids_list, self.device) + fsa = k2.add_epsilon_self_loops(fsa) + + # The reason to use `invert_()` at the end is as follows: + # + # (1) The `labels` of L_inv is word IDs and `aux_labels` is token IDs + # (2) `fsa.labels` is word IDs + # (3) after intersection, the `labels` is still word IDs + # (4) after `invert_()`, the `labels` is token IDs + # and `aux_labels` is word IDs + transcript_fsa = k2.intersect( + self.L_inv, fsa, treat_epsilons_specially=False + ).invert_() + transcript_fsa = k2.arc_sort(transcript_fsa) + return transcript_fsa + + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of piece IDs. + + Args: + texts: + It is a list of strings. Each string consists of space(s) + separated words. An example containing two strings is given below: + + ['HELLO ICEFALL', 'HELLO k2'] + We assume it contains no OOVs. Otherwise, it will raise an + exception. + Returns: + Return a list-of-list of token IDs. + """ + return self.lexicon.texts_to_token_ids(texts).tolist() diff --git a/icefall/utils.py b/icefall/utils.py index 23b4dd6c7..1c4dceb0b 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -19,14 +19,16 @@ import argparse import logging import os import subprocess +import sys from collections import defaultdict from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, TextIO, Tuple, Union +from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union import k2 import kaldialign +import lhotse import torch import torch.distributed as dist @@ -132,17 +134,82 @@ def setup_logger( logging.getLogger("").addHandler(console) -def get_env_info(): - """ - TODO: - """ +def get_git_sha1(): + git_commit = ( + subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + dirty_commit = ( + len( + subprocess.run( + ["git", "diff", "--shortstat"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + > 0 + ) + git_commit = ( + git_commit + "-dirty" if dirty_commit else git_commit + "-clean" + ) + return git_commit + + +def get_git_date(): + git_date = ( + subprocess.run( + ["git", "log", "-1", "--format=%ad", "--date=local"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + return git_date + + +def get_git_branch_name(): + git_date = ( + subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + return git_date + + +def get_env_info() -> Dict[str, Any]: + """Get the environment information.""" return { - "k2-git-sha1": None, - "k2-version": None, - "lhotse-version": None, - "torch-version": None, - "icefall-sha1": None, - "icefall-version": None, + "k2-version": k2.version.__version__, + "k2-build-type": k2.version.__build_type__, + "k2-with-cuda": k2.with_cuda, + "k2-git-sha1": k2.version.__git_sha1__, + "k2-git-date": k2.version.__git_date__, + "lhotse-version": lhotse.__version__, + "torch-cuda-available": torch.cuda.is_available(), + "torch-cuda-version": torch.version.cuda, + "python-version": sys.version[:3], + "icefall-git-branch": get_git_branch_name(), + "icefall-git-sha1": get_git_sha1(), + "icefall-git-date": get_git_date(), + "icefall-path": str(Path(__file__).resolve().parent.parent), + "k2-path": str(Path(k2.__file__).resolve()), + "lhotse-path": str(Path(lhotse.__file__).resolve()), } diff --git a/test/test_bpe_graph_compiler.py b/test/test_bpe_graph_compiler.py index e58c4f1c6..6c9073c4c 100755 --- a/test/test_bpe_graph_compiler.py +++ b/test/test_bpe_graph_compiler.py @@ -19,20 +19,21 @@ from pathlib import Path from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.lexicon import BpeLexicon +from icefall.lexicon import UniqLexicon + +ICEFALL_DIR = Path(__file__).resolve().parent.parent def test(): - lang_dir = Path("data/lang/bpe") + lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe" if not lang_dir.is_dir(): return - # TODO: generate data for testing compiler = BpeCtcTrainingGraphCompiler(lang_dir) ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"]) compiler.compile(ids) - lexicon = BpeLexicon(lang_dir) + lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt") ids0 = lexicon.words_to_piece_ids(["HELLO"]) assert ids[0] == ids0.values().tolist() diff --git a/test/test_bpe_mmi_graph_compiler.py b/test/test_bpe_mmi_graph_compiler.py deleted file mode 100644 index c6009d69b..000000000 --- a/test/test_bpe_mmi_graph_compiler.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python3 - -import copy -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler - - -def test_bpe_mmi_graph_compiler(): - lang_dir = Path("data/lang_bpe") - if lang_dir.is_dir() is False: - return - device = torch.device("cpu") - compiler = BpeMmiTrainingGraphCompiler(lang_dir, device=device) - - texts = ["HELLO WORLD", "MMI TRAINING"] - - num_graphs, den_graphs = compiler.compile(texts) - num_graphs.labels_sym = compiler.lexicon.token_table - num_graphs.aux_labels_sym = copy.deepcopy(compiler.lexicon.token_table) - num_graphs.aux_labels_sym._id2sym[0] = "" - num_graphs[0].draw("num_graphs_0.svg", title="HELLO WORLD") - num_graphs[1].draw("num_graphs_1.svg", title="HELLO WORLD") - print(den_graphs.shape) - print(den_graphs[0].shape) - print(den_graphs[0].num_arcs) diff --git a/test/test_lexicon.py b/test/test_lexicon.py old mode 100644 new mode 100755 index 6801b3a89..2a16db226 --- a/test/test_lexicon.py +++ b/test/test_lexicon.py @@ -14,80 +14,135 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +You can run this file in one of the two ways: + + (1) cd icefall; pytest test/test_lexicon.py + (2) cd icefall; ./test/test_lexicon.py +""" +import os +import shutil +import sys from pathlib import Path import k2 -import pytest -import torch +import sentencepiece as spm -from icefall.lexicon import BpeLexicon, Lexicon +from icefall.lexicon import UniqLexicon + +TMP_DIR = "/tmp/icefall-test-lexicon" +USING_PYTEST = "pytest" in sys.modules +ICEFALL_DIR = Path(__file__).resolve().parent.parent -@pytest.fixture -def lang_dir(tmp_path): - phone2id = """ - 0 - a 1 - b 2 - f 3 - o 4 - r 5 - z 6 - SPN 7 - #0 8 - """ - word2id = """ - 0 - foo 1 - bar 2 - baz 3 - 4 - #0 5 +def generate_test_data(): + Path(TMP_DIR).mkdir(exist_ok=True) + sentences = """ +cat tac cat cat +at +tac at ta at at +at cat ct ct ta +cat cat cat cat +at at at at at at at """ - L = k2.Fsa.from_str( - """ - 0 0 7 4 0 - 0 7 -1 -1 0 - 0 1 3 1 0 - 0 3 2 2 0 - 0 5 2 3 0 - 1 2 4 0 0 - 2 0 4 0 0 - 3 4 1 0 0 - 4 0 5 0 0 - 5 6 1 0 0 - 6 0 6 0 0 - 7 - """, - num_aux_labels=1, + transcript = Path(TMP_DIR) / "transcript_words.txt" + with open(transcript, "w") as f: + for line in sentences.strip().split("\n"): + f.write(f"{line}\n") + + words = """ + 0 + 1 +at 2 +cat 3 +ct 4 +ta 5 +tac 6 +#0 7 + 8 + 9 +""" + word_txt = Path(TMP_DIR) / "words.txt" + with open(word_txt, "w") as f: + for line in words.strip().split("\n"): + f.write(f"{line}\n") + + vocab_size = 8 + + os.system( + f""" +cd {ICEFALL_DIR}/egs/librispeech/ASR + +./local/train_bpe_model.py \ + --lang-dir {TMP_DIR} \ + --vocab-size {vocab_size} \ + --transcript {transcript} + +./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 1 +""" ) - with open(tmp_path / "tokens.txt", "w") as f: - f.write(phone2id) - with open(tmp_path / "words.txt", "w") as f: - f.write(word2id) - torch.save(L.as_dict(), tmp_path / "L.pt") - - return tmp_path +def delete_test_data(): + shutil.rmtree(TMP_DIR) -def test_lexicon(lang_dir): - lexicon = Lexicon(lang_dir) - assert lexicon.tokens == list(range(1, 8)) +def uniq_lexicon_test(): + lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="lexicon.txt") + + # case 1: No OOV + texts = ["cat cat", "at ct", "at tac cat"] + token_ids = lexicon.texts_to_token_ids(texts) + + sp = spm.SentencePieceProcessor() + sp.load(f"{TMP_DIR}/bpe.model") + + expected_token_ids: List[List[int]] = sp.encode(texts, out_type=int) + assert token_ids.tolist() == expected_token_ids + + # case 2: With OOV + texts = ["ca"] + token_ids = lexicon.texts_to_token_ids(texts) + expected_token_ids = sp.encode(texts, out_type=int) + assert token_ids.tolist() != expected_token_ids + # Note: sentencepiece breaks "ca" into "_ c a" + # But there is no word "ca" in the lexicon, so our + # implementation returns the id of "" + print(token_ids, expected_token_ids) + assert token_ids.tolist() == [[sp.unk_id()]] + + # case 3: With OOV + texts = ["foo"] + token_ids = lexicon.texts_to_token_ids(texts) + expected_token_ids = sp.encode(texts, out_type=int) + print(token_ids) + print(expected_token_ids) + + # test ragged lexicon + ragged_lexicon = lexicon.ragged_lexicon.tolist() + word_disambig_id = lexicon.word_table["#0"] + for i in range(2, word_disambig_id): + piece_id = ragged_lexicon[i] + word = lexicon.word_table[i] + assert word == sp.decode(piece_id) + assert piece_id == sp.encode(word) -def test_bpe_lexicon(): - lang_dir = Path("data/lang/bpe") - if not lang_dir.is_dir(): - return - # TODO: Generate test data for BpeLexicon +def test_main(): + generate_test_data() - lexicon = BpeLexicon(lang_dir) - words = ["", "HELLO", "ZZZZ", "WORLD"] - ids = lexicon.words_to_piece_ids(words) - print(ids) - print([lexicon.token_table[i] for i in ids.values().tolist()]) + uniq_lexicon_test() + + if USING_PYTEST: + delete_test_data() + + +def main(): + test_main() + + +if __name__ == "__main__" and not USING_PYTEST: + main() diff --git a/test/test_mmi_graph_compiler.py b/test/test_mmi_graph_compiler.py new file mode 100755 index 000000000..653c57b59 --- /dev/null +++ b/test/test_mmi_graph_compiler.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +You can run this file in one of the two ways: + + (1) cd icefall; pytest test/test_mmi_graph_compiler.py + (2) cd icefall; ./test/test_mmi_graph_compiler.py +""" + +import copy +import os +import shutil +import sys +from pathlib import Path + +import k2 +import sentencepiece as spm + +from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler + +TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler" +USING_PYTEST = "pytest" in sys.modules +ICEFALL_DIR = Path(__file__).resolve().parent.parent + + +def generate_test_data(): + Path(TMP_DIR).mkdir(exist_ok=True) + sentences = """ +cat tac cat cat +at at cat at cat cat +tac at ta at at +at cat ct ct ta ct ct cat tac +cat cat cat cat +at at at at at at at + """ + + transcript = Path(TMP_DIR) / "transcript_words.txt" + with open(transcript, "w") as f: + for line in sentences.strip().split("\n"): + f.write(f"{line}\n") + + words = """ + 0 + 1 +at 2 +cat 3 +ct 4 +ta 5 +tac 6 +#0 7 + 8 + 9 +""" + word_txt = Path(TMP_DIR) / "words.txt" + with open(word_txt, "w") as f: + for line in words.strip().split("\n"): + f.write(f"{line}\n") + + vocab_size = 8 + + os.system( + f""" +cd {ICEFALL_DIR}/egs/librispeech/ASR + +./local/train_bpe_model.py \ + --lang-dir {TMP_DIR} \ + --vocab-size {vocab_size} \ + --transcript {transcript} + +./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 0 + +./local/convert_transcript_words_to_tokens.py \ +--lexicon {TMP_DIR}/lexicon.txt \ +--transcript {transcript} \ +--oov "" \ +> {TMP_DIR}/transcript_tokens.txt + +./shared/make_kn_lm.py \ +-ngram-order 2 \ +-text {TMP_DIR}/transcript_tokens.txt \ +-lm {TMP_DIR}/P.arpa + +python3 -m kaldilm \ +--read-symbol-table="{TMP_DIR}/tokens.txt" \ +--disambig-symbol='#0' \ +--max-order=2 \ +{TMP_DIR}/P.arpa > {TMP_DIR}/P.fst.txt +""" + ) + + +def delete_test_data(): + shutil.rmtree(TMP_DIR) + + +def mmi_graph_compiler_test(): + # Caution: + # You have to uncomment + # del transcript_fsa.aux_labels + # in mmi_graph_compiler.py + # to see the correct aux_labels in *.svg + graph_compiler = MmiTrainingGraphCompiler( + lang_dir=TMP_DIR, uniq_filename="lexicon.txt" + ) + print(graph_compiler.device) + L_inv = graph_compiler.L_inv + L = k2.invert(L_inv) + + L.labels_sym = graph_compiler.lexicon.token_table + L.aux_labels_sym = graph_compiler.lexicon.word_table + L.draw(f"{TMP_DIR}/L.svg", title="L") + + L_inv.labels_sym = graph_compiler.lexicon.word_table + L_inv.aux_labels_sym = graph_compiler.lexicon.token_table + L_inv.draw(f"{TMP_DIR}/L_inv.svg", title="L") + + ctc_topo_P = graph_compiler.ctc_topo_P + ctc_topo_P.labels_sym = copy.deepcopy(graph_compiler.lexicon.token_table) + ctc_topo_P.labels_sym._id2sym[0] = "" + ctc_topo_P.labels_sym._sym2id[""] = 0 + ctc_topo_P.aux_labels_sym = graph_compiler.lexicon.token_table + ctc_topo_P.draw(f"{TMP_DIR}/ctc_topo_P.svg", title="ctc_topo_P") + + print(ctc_topo_P.num_arcs) + print(k2.connect(ctc_topo_P).num_arcs) + + with open(str(TMP_DIR) + "/P.fst.txt") as f: + # P is not an acceptor because there is + # a back-off state, whose incoming arcs + # have label #0 and aux_label 0 (i.e., ). + P = k2.Fsa.from_openfst(f.read(), acceptor=False) + P.labels_sym = graph_compiler.lexicon.token_table + P.aux_labels_sym = graph_compiler.lexicon.token_table + P.draw(f"{TMP_DIR}/P.svg", title="P") + + ctc_topo = k2.ctc_topo(max(graph_compiler.lexicon.tokens), False) + ctc_topo.labels_sym = ctc_topo_P.labels_sym + ctc_topo.aux_labels_sym = graph_compiler.lexicon.token_table + ctc_topo.draw(f"{TMP_DIR}/ctc_topo.svg", title="ctc_topo") + print("p num arcs", P.num_arcs) + print("ctc_topo num arcs", ctc_topo.num_arcs) + print("ctc_topo_P num arcs", ctc_topo_P.num_arcs) + + texts = ["cat at ct", "at ta", "cat tac"] + transcript_fsa = graph_compiler.build_transcript_fsa(texts) + transcript_fsa[0].draw(f"{TMP_DIR}/cat_at_ct.svg", title="cat_at_ct") + transcript_fsa[1].draw(f"{TMP_DIR}/at_ta.svg", title="at_ta") + transcript_fsa[2].draw(f"{TMP_DIR}/cat_tac.svg", title="cat_tac") + + num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True) + num_graphs[0].draw(f"{TMP_DIR}/num_cat_at_ct.svg", title="num_cat_at_ct") + num_graphs[1].draw(f"{TMP_DIR}/num_at_ta.svg", title="num_at_ta") + num_graphs[2].draw(f"{TMP_DIR}/num_cat_tac.svg", title="num_cat_tac") + + den_graphs[0].draw(f"{TMP_DIR}/den_cat_at_ct.svg", title="den_cat_at_ct") + den_graphs[2].draw(f"{TMP_DIR}/den_cat_tac.svg", title="den_cat_tac") + + sp = spm.SentencePieceProcessor() + sp.load(f"{TMP_DIR}/bpe.model") + + texts = ["cat at cat", "at tac"] + token_ids = graph_compiler.texts_to_ids(texts) + expected_token_ids = sp.encode(texts) + assert token_ids == expected_token_ids + + +def test_main(): + generate_test_data() + + mmi_graph_compiler_test() + + if USING_PYTEST: + delete_test_data() + + +def main(): + test_main() + + +if __name__ == "__main__" and not USING_PYTEST: + main()