diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py new file mode 100755 index 000000000..118e7e717 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -0,0 +1,602 @@ +#!/usr/bin/env python3 + +# This is just at the very beginning ... + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional + +import k2 +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from conformer import Conformer +from transformer import Noam + +from lhotse.utils import fix_random_seed +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_value_ +from torch.optim.lr_scheduler import StepLR +from torch.utils.tensorboard import SummaryWriter + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dataset.librispeech import LibriSpeechAsrDataModule +from icefall.dist import cleanup_dist, setup_dist +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + # TODO: add extra arguments and support DDP training. + # Currently, only single GPU training is implemented. Will add + # DDP training once single GPU training is finished. + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + is saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - exp_dir: It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + + - lang_dir: It contains language related input files such as + "lexicon.txt" + + - lr: It specifies the initial learning rate + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - weight_decay: The weight_decay for the optimizer. + + - subsampling_factor: The subsampling factor for the model. + + - start_epoch: If it is not zero, load checkpoint `start_epoch-1` + and continue training from that checkpoint. + + - num_epochs: Number of epochs to train. + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. 
It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval` is 0 + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + """ + params = AttributeDict( + { + "exp_dir": Path("conformer_ctc/exp"), + "lang_dir": Path("data/lang/bpe"), + "feature_dim": 80, + "weight_decay": 0.0, + "subsampling_factor": 4, + "start_epoch": 0, + "num_epochs": 10, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 10, + "valid_interval": 1000, + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # + "accum_grad": 1, + "att_rate": 0.7, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + "is_espnet_structure": True, + "mmi_loss": False, + "use_feat_batchnorm": True, + "lr_factor": 5.0, + "warm_step": 80000, + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, model=model, optimizer=optimizer, scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +): + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. 
+ model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is [N, T, C] + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + token_ids = graph_compiler.texts_to_ids(texts) + + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + att_loss = model.decoder_forward( + encoder_memory, + memory_mask, + token_ids=token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + + # train_frames and valid_frames are used for printing. + if is_training: + params.train_frames = supervision_segments[:, 2].sum().item() + else: + params.valid_frames = supervision_segments[:, 2].sum().item() + + assert loss.requires_grad == is_training + + return loss + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> None: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. 
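+
+    It also updates `params.best_valid_loss` and `params.best_valid_epoch`
+    when the current validation loss is the lowest one seen so far.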
+ """ + model.eval() + + tot_loss = 0.0 + tot_frames = 0.0 + for batch_idx, batch in enumerate(valid_dl): + loss = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + + loss_cpu = loss.detach().cpu().item() + tot_loss += loss_cpu + tot_frames += params.valid_frames + + if world_size > 1: + s = torch.tensor([tot_loss, tot_frames], device=loss.device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + s = s.cpu().tolist() + tot_loss = s[0] + tot_frames = s[1] + + params.valid_loss = tot_loss / tot_frames + + if params.valid_loss < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = params.valid_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = 0.0 # sum of losses over all batches + tot_frames = 0.0 # sum of frames over all batches + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_value_(model.parameters(), 5.0) + optimizer.step() + + loss_cpu = loss.detach().cpu().item() + + tot_frames += params.train_frames + tot_loss += loss_cpu + tot_avg_loss = tot_loss / tot_frames + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"batch avg loss {loss_cpu/params.train_frames:.4f}, " + f"total avg loss: {tot_avg_loss:.4f}, " + f"batch size: {batch_size}" + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," + f" best valid loss: {params.best_valid_loss:.4f} " + f"best valid epoch: {params.best_valid_epoch}" + ) + + params.train_loss = tot_loss / tot_frames + + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. 
+ world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + is_espnet_structure=params.is_espnet_structure, + mmi_loss=params.mmi_loss, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + librispeech = LibriSpeechAsrDataModule(args) + train_dl = librispeech.train_dataloaders() + valid_dl = librispeech.valid_dataloaders() + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, model=model, optimizer=optimizer, rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index fc748a252..1df16e346 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -189,6 +189,8 @@ class Transformer(nn.Module): supervision: Supervisions = None, graph_compiler: object = None, token_ids: List[int] = None, + sos_id: Optional[int] = None, + eos_id: Optional[int] = None, ) -> Tensor: """ Args: @@ -197,6 +199,8 @@ class Transformer(nn.Module): supervision: 
Supervison in lhotse format, get from batch['supervisions'] graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones) , graph_compiler.words and graph_compiler.oov + sos_id: sos token id + eos_id: eos token id Returns: Tensor: Decoder loss. @@ -206,18 +210,9 @@ class Transformer(nn.Module): supervision, graph_compiler.lexicon.words, graph_compiler.oov ) ys_in_pad, ys_out_pad = add_sos_eos( - batch_text, - graph_compiler.L_inv, - self.decoder_num_class - 1, - self.decoder_num_class - 1, + batch_text, graph_compiler.L_inv, sos_id, eos_id, ) elif token_ids is not None: - # speical token ids: - # 0 - # 1 - # self.decoder_num_class - 1 - sos_id = self.decoder_num_class - 1 - eos_id = self.decoder_num_class - 1 _sos = torch.tensor([sos_id]) _eos = torch.tensor([eos_id]) ys_in = [ @@ -259,7 +254,12 @@ class Transformer(nn.Module): return decoder_loss def decoder_nll( - self, x: Tensor, encoder_mask: Tensor, token_ids: List[List[int]] = None + self, + x: Tensor, + encoder_mask: Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, ) -> Tensor: """ Args: @@ -273,12 +273,6 @@ class Transformer(nn.Module): # The common part between this fuction and decoder_forward could be # extracted as a seperated function. if token_ids is not None: - # speical token ids: - # 0 - # 1 - # self.decoder_num_class - 1 - sos_id = self.decoder_num_class - 1 - eos_id = self.decoder_num_class - 1 _sos = torch.tensor([sos_id]) _eos = torch.tensor([eos_id]) ys_in = [ @@ -866,7 +860,8 @@ class LabelSmoothingLoss(nn.Module): target = target.masked_fill(ignore, 0) # avoid -1 index true_dist.scatter_(1, target.unsqueeze(1), self.confidence) kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) - denom = total if self.normalize_length else batch_size + # denom = total if self.normalize_length else batch_size + denom = total if self.normalize_length else 1 return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom @@ -983,8 +978,8 @@ def generate_square_subsequent_mask(sz: int) -> Tensor: def add_sos_eos( ys: List[List[int]], lexicon: k2.Fsa, - sos: int, - eos: int, + sos_id: int, + eos_id: int, ignore_id: int = -1, ) -> Tuple[Tensor, Tensor]: """Add and labels. @@ -992,8 +987,8 @@ def add_sos_eos( Args: ys: batch of unpadded target sequences lexicon: Its labels are words, while its aux_labels are phones. - sos: index of - eos: index of + sos_id: index of + eos_id: index of ignore_id: index of padding Returns: @@ -1001,8 +996,8 @@ def add_sos_eos( Tensor: Output of transformer decoder. padded tensor of dimention (batch_size, max_length). 
""" - _sos = torch.tensor([sos]) - _eos = torch.tensor([eos]) + _sos = torch.tensor([sos_id]) + _eos = torch.tensor([eos_id]) ys = get_hierarchical_targets(ys, lexicon) ys_in = [torch.cat([_sos, y], dim=0) for y in ys] ys_out = [torch.cat([y, _eos], dim=0) for y in ys] diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index b962be552..605d72dae 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -3,7 +3,7 @@ """ This script compiles HLG from - - H, the ctc topology, built from phones contained in lexicon.txt + - H, the ctc topology, built from tokens contained in lexicon.txt - L, the lexicon, built from L_disambig.pt Caution: We use a lexicon that contains disambiguation symbols @@ -13,6 +13,7 @@ This script compiles HLG from The generated HLG is saved in data/lm/HLG.pt (phone based) or data/lm/HLG_bpe.pt (BPE based) """ +import logging from pathlib import Path import k2 @@ -32,44 +33,44 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: """ lexicon = Lexicon(lang_dir) max_token_id = max(lexicon.tokens) - print(f"Building ctc_topo. max_token_id: {max_token_id}") + logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) if Path("data/lm/G_3_gram.pt").is_file(): - print("Loading pre-compiled G_3_gram") + logging.info("Loading pre-compiled G_3_gram") d = torch.load("data/lm/G_3_gram.pt") G = k2.Fsa.from_dict(d) else: - print("Loading G_3_gram.fst.txt") + logging.info("Loading G_3_gram.fst.txt") with open("data/lm/G_3_gram.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) torch.save(G.as_dict(), "G_3_gram.pt") - first_token_disambig_id = lexicon.phones["#0"] - first_word_disambig_id = lexicon.words["#0"] + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] L = k2.arc_sort(L) G = k2.arc_sort(G) - print("Intersecting L and G") + logging.info("Intersecting L and G") LG = k2.compose(L, G) - print(f"LG shape: {LG.shape}") + logging.info(f"LG shape: {LG.shape}") - print("Connecting LG") + logging.info("Connecting LG") LG = k2.connect(LG) - print(f"LG shape after k2.connect: {LG.shape}") + logging.info(f"LG shape after k2.connect: {LG.shape}") - print(type(LG.aux_labels)) - print("Determinizing LG") + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") LG = k2.determinize(LG) - print(type(LG.aux_labels)) + logging.info(type(LG.aux_labels)) - print("Connecting LG after k2.determinize") + logging.info("Connecting LG after k2.determinize") LG = k2.connect(LG) - print("Removing disambiguation symbols on LG") + logging.info("Removing disambiguation symbols on LG") LG.labels[LG.labels >= first_token_disambig_id] = 0 @@ -77,27 +78,27 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) - print(f"LG shape after k2.remove_epsilon: {LG.shape}") + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") LG = k2.connect(LG) LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) - print("Arc sorting LG") + logging.info("Arc sorting LG") LG = k2.arc_sort(LG) - print("Composing H and LG") + logging.info("Composing H and LG") # CAUTION: The name of the inner_labels is fixed # to `tokens`. If you want to change it, please # also change other places in icefall that are using # it. 
HLG = k2.compose(H, LG, inner_labels="tokens") - print("Connecting LG") + logging.info("Connecting LG") HLG = k2.connect(HLG) - print("Arc sorting LG") + logging.info("Arc sorting LG") HLG = k2.arc_sort(HLG) - print(f"HLG.shape: {HLG.shape}") + logging.info(f"HLG.shape: {HLG.shape}") return HLG @@ -106,10 +107,10 @@ def phone_based_HLG(): if Path("data/lm/HLG.pt").is_file(): return - print("Compiling phone based HLG") + logging.info("Compiling phone based HLG") HLG = compile_HLG("data/lang") - print("Saving HLG.pt to data/lm") + logging.info("Saving HLG.pt to data/lm") torch.save(HLG.as_dict(), "data/lm/HLG.pt") @@ -117,9 +118,9 @@ def bpe_based_HLG(): if Path("data/lm/HLG_bpe.pt").is_file(): return - print("Compiling BPE based HLG") + logging.info("Compiling BPE based HLG") HLG = compile_HLG("data/lang/bpe") - print("Saving HLG_bpe.pt to data/lm") + logging.info("Saving HLG_bpe.pt to data/lm") torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt") @@ -129,4 +130,10 @@ def main(): if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py index 9945a5006..b9d13f5bb 100755 --- a/egs/librispeech/ASR/local/prepare_lang.py +++ b/egs/librispeech/ASR/local/prepare_lang.py @@ -4,13 +4,13 @@ """ This script takes as input a lexicon file "data/lang/lexicon.txt" -consisting of words and phones and does the following: +consisting of words and tokens (i.e., phones) and does the following: 1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt -2. Generate phones.txt, the phones table mapping a phone to a unique integer. +2. Generate tokens.txt, the token table mapping a token to a unique integer. -3. Generate words.txt, the words table mapping a word to a unique integer. +3. Generate words.txt, the word table mapping a word to a unique integer. 4. Generate L.pt, in k2 format. It can be loaded by @@ -29,62 +29,11 @@ from typing import Any, Dict, List, Tuple import k2 import torch +from icefall.lexicon import read_lexicon, write_lexicon + Lexicon = List[Tuple[str, List[str]]] -def read_lexicon(filename: str) -> Lexicon: - """Read a lexicon.txt in `filename`. - - Each line in the lexicon contains "word p1 p2 p3 ...". - That is, the first field is a word and the remaining - fields are phones. Fields are separated by space(s). - - Args: - filename: - Path to the lexicon.txt - - Returns: - A list of tuples., e.g., [('w', ['p1', 'p2']), ('w1', ['p3, 'p4'])] - """ - ans = [] - - with open(filename, "r", encoding="utf-8") as f: - whitespace = re.compile("[ \t]+") - for line in f: - a = whitespace.split(line.strip(" \t\r\n")) - if len(a) == 0: - continue - - if len(a) < 2: - print(f"Found bad line {line} in lexicon file {filename}") - print("Every line is expected to contain at least 2 fields") - sys.exit(1) - word = a[0] - if word == "": - print(f"Found bad line {line} in lexicon file {filename}") - print(" should not be a valid word") - sys.exit(1) - - prons = a[1:] - ans.append((word, prons)) - - return ans - - -def write_lexicon(filename: str, lexicon: Lexicon) -> None: - """Write a lexicon to a file. - - Args: - filename: - Path to the lexicon file to be generated. - lexicon: - It can be the return value of :func:`read_lexicon`. 
- """ - with open(filename, "w", encoding="utf-8") as f: - for word, prons in lexicon: - f.write(f"{word} {' '.join(prons)}\n") - - def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: """Write a symbol to ID mapping to a file. @@ -105,18 +54,18 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: f.write(f"{sym} {i}\n") -def get_phones(lexicon: Lexicon) -> List[str]: - """Get phones from a lexicon. +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. Args: lexicon: It is the return value of :func:`read_lexicon`. Returns: - Return a list of unique phones. + Return a list of unique tokens. """ ans = set() - for _, prons in lexicon: - ans.update(prons) + for _, tokens in lexicon: + ans.update(tokens) sorted_ans = sorted(list(ans)) return sorted_ans @@ -138,8 +87,8 @@ def get_words(lexicon: Lexicon) -> List[str]: def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-phone disambiguation symbols #1, #2 and so on - at the ends of phones to ensure that all pronunciations are different, + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, and that none is a prefix of another. See also add_lex_disambig.pl from kaldi. @@ -151,30 +100,30 @@ def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: Return a tuple with two elements: - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbols that appears + - The ID of the max disambiguation symbol that appears in the lexicon """ - # (1) Work out the count of each phone-sequence in the + # (1) Work out the count of each token-sequence in the # lexicon. count = defaultdict(int) - for _, prons in lexicon: - count[" ".join(prons)] += 1 + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 - # (2) For each left sub-sequence of each phone-sequence, note down + # (2) For each left sub-sequence of each token-sequence, note down # that it exists (for identifying prefixes of longer strings). issubseq = defaultdict(int) - for _, prons in lexicon: - prons = prons.copy() - prons.pop() - while prons: - issubseq[" ".join(prons)] = 1 - prons.pop() + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() # (3) For each entry in the lexicon: - # if the phone sequence is unique and is not a + # if the token sequence is unique and is not a # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same phone-seq + # Else output #1, or #2, #3, ... if the same token-seq # has already been assigned a disambig symbol. 
ans = [] @@ -183,14 +132,14 @@ def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: max_disambig = first_allowed_disambig - 1 last_used_disambig_symbol_of = defaultdict(int) - for word, prons in lexicon: - phnseq = " ".join(prons) - assert phnseq != "" - if issubseq[phnseq] == 0 and count[phnseq] == 1: - ans.append((word, prons)) + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) continue - cur_disambig = last_used_disambig_symbol_of[phnseq] + cur_disambig = last_used_disambig_symbol_of[tokenseq] if cur_disambig == 0: cur_disambig = first_allowed_disambig else: @@ -198,9 +147,9 @@ def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: if cur_disambig > max_disambig: max_disambig = cur_disambig - last_used_disambig_symbol_of[phnseq] = cur_disambig - phnseq += f" #{cur_disambig}" - ans.append((word, phnseq.split())) + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) return ans, max_disambig @@ -217,7 +166,7 @@ def generate_id_map(symbols: List[str]) -> Dict[str, int]: def add_self_loops( - arcs: List[List[Any]], disambig_phone: int, disambig_word: int + arcs: List[List[Any]], disambig_token: int, disambig_word: int ) -> List[List[Any]]: """Adds self-loops to states of an FST to propagate disambiguation symbols through it. They are added on each state with non-epsilon output symbols @@ -228,12 +177,15 @@ def add_self_loops( This function uses k2 style FSTs and it does not need to add self-loops to the final state. + The input label of a self-loop is `disambig_token`, while the output + label is `disambig_word`. + Args: arcs: A list-of-list. The sublist contains `[src_state, dest_state, label, aux_label, score]` - disambig_phone: - It is the phone ID of the symbol `#0`. + disambig_token: + It is the token ID of the symbol `#0`. disambig_word: It is the word ID of the symbol `#0`. @@ -248,37 +200,38 @@ def add_self_loops( ans = [] for s in states_needs_self_loops: - ans.append([s, s, disambig_phone, disambig_word, 0]) + ans.append([s, s, disambig_token, disambig_word, 0]) return arcs + ans def lexicon_to_fst( lexicon: Lexicon, - phone2id: Dict[str, int], + token2id: Dict[str, int], word2id: Dict[str, int], - sil_phone: str = "SIL", + sil_token: str = "SIL", sil_prob: float = 0.5, need_self_loops: bool = False, ) -> k2.Fsa: """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of the word. + the beginning and end of each word. Args: lexicon: The input lexicon. See also :func:`read_lexicon` - phone2id: - A dict mapping phones to IDs. + token2id: + A dict mapping tokens to IDs. word2id: A dict mapping words to IDs. - sil_phone: - The silence phone. + sil_token: + The silence token. sil_prob: The probability for adding a silence at the beginning and end of the word. need_self_loops: If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. Returns: Return an instance of `k2.Fsa` representing the given lexicon. """ @@ -294,48 +247,44 @@ def lexicon_to_fst( next_state = 3 # the next un-allocated state, will be incremented as we go. 
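+    # Layout of the FST built below: start_state reaches loop_state either
+    # directly (score no_sil_score) or through sil_state, which emits the
+    # silence token (score sil_score). Every word starts from loop_state
+    # and, after its last token, goes back to loop_state or to sil_state,
+    # i.e., an optional silence may follow each word.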
arcs = [] - assert phone2id[""] == 0 + assert token2id[""] == 0 assert word2id[""] == 0 eps = 0 - sil_phone = phone2id[sil_phone] + sil_token = token2id[sil_token] arcs.append([start_state, loop_state, eps, eps, no_sil_score]) arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_phone, eps, 0]) + arcs.append([sil_state, loop_state, sil_token, eps, 0]) - for word, prons in lexicon: - assert len(prons) > 0, f"{word} has no pronunciations" + for word, tokens in lexicon: + assert len(tokens) > 0, f"{word} has no pronunciations" cur_state = loop_state word = word2id[word] - prons = [phone2id[i] for i in prons] + tokens = [token2id[i] for i in tokens] - for i in range(len(prons) - 1): - if i == 0: - arcs.append([cur_state, next_state, prons[i], word, 0]) - else: - arcs.append([cur_state, next_state, prons[i], eps, 0]) + for i in range(len(tokens) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, tokens[i], w, 0]) cur_state = next_state next_state += 1 - # now for the last phone of this word + # now for the last token of this word # It has two out-going arcs, one to the loop state, # the other one to the sil_state. - i = len(prons) - 1 + i = len(tokens) - 1 w = word if i == 0 else eps - arcs.append([cur_state, loop_state, prons[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, prons[i], w, sil_score]) + arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) + arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) if need_self_loops: - disambig_phone = phone2id["#0"] + disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, - disambig_phone=disambig_phone, - disambig_word=disambig_word, + arcs, disambig_token=disambig_token, disambig_word=disambig_word, ) final_state = next_state @@ -354,22 +303,22 @@ def lexicon_to_fst( def main(): out_dir = Path("data/lang") lexicon_filename = out_dir / "lexicon.txt" - sil_phone = "SIL" + sil_token = "SIL" sil_prob = 0.5 lexicon = read_lexicon(lexicon_filename) - phones = get_phones(lexicon) + tokens = get_tokens(lexicon) words = get_words(lexicon) lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) for i in range(max_disambig + 1): disambig = f"#{i}" - assert disambig not in phones - phones.append(f"#{i}") + assert disambig not in tokens + tokens.append(f"#{i}") - assert "" not in phones - phones = [""] + phones + assert "" not in tokens + tokens = [""] + tokens assert "" not in words assert "#0" not in words @@ -378,26 +327,26 @@ def main(): words = [""] + words + ["#0", "", ""] - phone2id = generate_id_map(phones) + token2id = generate_id_map(tokens) word2id = generate_id_map(words) - write_mapping(out_dir / "phones.txt", phone2id) + write_mapping(out_dir / "tokens.txt", token2id) write_mapping(out_dir / "words.txt", word2id) write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) L = lexicon_to_fst( lexicon, - phone2id=phone2id, + token2id=token2id, word2id=word2id, - sil_phone=sil_phone, + sil_token=sil_token, sil_prob=sil_prob, ) L_disambig = lexicon_to_fst( lexicon_disambig, - phone2id=phone2id, + token2id=token2id, word2id=word2id, - sil_phone=sil_phone, + sil_token=sil_token, sil_prob=sil_prob, need_self_loops=True, ) @@ -406,7 +355,7 @@ def main(): if False: # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "phones.txt") + L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") 
L_disambig.labels_sym = L.labels_sym L_disambig.aux_labels_sym = L.aux_labels_sym diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index f70279cf4..0c3e9ede5 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -3,9 +3,9 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) """ -This script takes as inputs the following files: +This script takes as inputs the following two files: + - data/lang/bpe/bpe.model, - - data/lang/bpe/tokens.txt (will remove it), - data/lang/bpe/words.txt and generates the following files in the directory data/lang/bpe: @@ -14,11 +14,11 @@ and generates the following files in the directory data/lang/bpe: - lexicon_disambig.txt - L.pt - L_disambig.pt - - phones.txt + - tokens.txt """ from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Tuple import k2 import sentencepiece as spm @@ -28,6 +28,7 @@ from prepare_lang import ( add_disambig_symbols, add_self_loops, write_lexicon, + write_mapping, ) @@ -48,48 +49,46 @@ def lexicon_to_fst_no_sil( A dict mapping words to IDs. need_self_loops: If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. Returns: Return an instance of `k2.Fsa` representing the given lexicon. """ loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go. + next_state = 1 # the next un-allocated state, will be incremented as we go arcs = [] - assert token2id[""] == 0 + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 assert word2id[""] == 0 eps = 0 - for word, prons in lexicon: - assert len(prons) > 0, f"{word} has no pronunciations" + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" cur_state = loop_state word = word2id[word] - prons = [token2id[i] for i in prons] + pieces = [token2id[i] for i in pieces] - for i in range(len(prons) - 1): - if i == 0: - arcs.append([cur_state, next_state, prons[i], word, 0]) - else: - arcs.append([cur_state, next_state, prons[i], eps, 0]) + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) cur_state = next_state next_state += 1 - # now for the last phone of this word - i = len(prons) - 1 + # now for the last piece of this word + i = len(pieces) - 1 w = word if i == 0 else eps - arcs.append([cur_state, loop_state, prons[i], w, 0]) + arcs.append([cur_state, loop_state, pieces[i], w, 0]) if need_self_loops: - disambig_phone = token2id["#0"] + disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, - disambig_phone=disambig_phone, - disambig_word=disambig_word, + arcs, disambig_token=disambig_token, disambig_word=disambig_word, ) final_state = next_state @@ -105,7 +104,9 @@ def lexicon_to_fst_no_sil( return fsa -def generate_lexicon(model_file: str, words: List[str]) -> Lexicon: +def generate_lexicon( + model_file: str, words: List[str] +) -> Tuple[Lexicon, Dict[str, int]]: """Generate a lexicon from a BPE model. Args: @@ -114,8 +115,10 @@ def generate_lexicon(model_file: str, words: List[str]) -> Lexicon: words: A list of strings representing words. 
Returns: - Return a dict whose keys are words and values are the corresponding - word pieces. + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. """ sp = spm.SentencePieceProcessor() sp.load(str(model_file)) @@ -126,8 +129,14 @@ def generate_lexicon(model_file: str, words: List[str]) -> Lexicon: for word, pieces in zip(words, words_pieces): lexicon.append((word, pieces)) - lexicon.append(("", [""])) - return lexicon + # The OOV word is + lexicon.append(("", [sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = dict() + for i in range(sp.vocab_size()): + token2id[sp.id_to_piece(i)] = i + + return lexicon, token2id def main(): @@ -143,34 +152,28 @@ def main(): if w in words: words.remove(w) - lexicon = generate_lexicon(model_file, words) - - # TODO(fangjun): Remove tokens.txt and generate it from the model directly. - # - # We are using it since the IDs we are using in tokens.txt is - # different from the one contained in the model - token_sym_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + lexicon, token_sym_table = generate_lexicon(model_file, words) lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + next_token_id = max(token_sym_table.values()) + 1 for i in range(max_disambig + 1): disambig = f"#{i}" assert disambig not in token_sym_table - token_sym_table.add(f"#{i}") + token_sym_table[disambig] = next_token_id + next_token_id += 1 word_sym_table.add("#0") word_sym_table.add("") word_sym_table.add("") - token_sym_table.to_file(lang_dir / "phones.txt") + write_mapping(lang_dir / "tokens.txt", token_sym_table) write_lexicon(lang_dir / "lexicon.txt", lexicon) write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, + lexicon, token2id=token_sym_table, word2id=word_sym_table, ) L_disambig = lexicon_to_fst_no_sil( @@ -184,7 +187,7 @@ def main(): if False: # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(lang_dir / "phones.txt") + L.labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") L_disambig.labels_sym = L.labels_sym L_disambig.aux_labels_sym = L.aux_labels_sym diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..b5c6c7541 --- /dev/null +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +""" +This script takes as input "data/lang/bpe/train.txt" +and generates "data/lang/bpe/bep.model". +""" + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +from pathlib import Path + +import sentencepiece as spm + +import shutil + + +def main(): + model_type = "unigram" + vocab_size = 5000 + model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}" + train_text = "data/lang/bpe/train.txt" + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. 
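+    #
+    # bos_id=-1 and eos_id=-1 below tell sentencepiece not to reserve its
+    # own BOS/EOS pieces, so the special symbols in the resulting model are
+    # just the user_defined_symbols above plus the unknown symbol at unk_id.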
+ + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + + sp = spm.SentencePieceProcessor(model_file=str(model_file)) + vocab_size = sp.vocab_size() + + shutil.copyfile(model_file, "data/lang/bpe/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index b73d0e71f..406527b71 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -10,14 +10,20 @@ stop_stage=100 mkdir -p data +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - echo "stage -1: Download LM" + log "stage -1: Download LM" mkdir -p data/lm ./local/download_lm.py fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - echo "stage 0: Download data" + log "stage 0: Download data" # If you have pre-downloaded it to /path/to/LibriSpeech, # you can create a symlink @@ -49,7 +55,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - echo "Stage 1: Prepare librispeech manifest" + log "Stage 1: Prepare librispeech manifest" # We assume that you have downloaded the librispeech corpus # to data/LibriSpeech mkdir -p data/manifests @@ -57,7 +63,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - echo "Stage 2: Prepare musan manifest" + log "Stage 2: Prepare musan manifest" # We assume that you have downloaded the musan corpus # to data/musan mkdir -p data/manifests @@ -65,19 +71,19 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - echo "Stage 3: Compute fbank for librispeech" + log "Stage 3: Compute fbank for librispeech" mkdir -p data/fbank ./local/compute_fbank_librispeech.py fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - echo "Stage 4: Compute fbank for musan" + log "Stage 4: Compute fbank for musan" mkdir -p data/fbank ./local/compute_fbank_musan.py fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - echo "Stage 5: Prepare phone based lang" + log "Stage 5: Prepare phone based lang" # TODO: add BPE based lang mkdir -p data/lang @@ -85,21 +91,37 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then cat - data/lm/librispeech-lexicon.txt | sort | uniq > data/lang/lexicon.txt - ./local/prepare_lang.py + if [ ! -f data/lang/L_disambig.pt ]; then + ./local/prepare_lang.py + fi fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - echo "State 6: Prepare BPE based lang" + log "State 6: Prepare BPE based lang" mkdir -p data/lang/bpe cp data/lang/words.txt data/lang/bpe/ + if [ ! -f data/lang/bpe/train.txt ]; then + log "Generate data for BPE training" + files=$( + find "data/LibriSpeech/train-clean-100" -name "*.trans.txt" + find "data/LibriSpeech/train-clean-360" -name "*.trans.txt" + find "data/LibriSpeech/train-other-500" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > data/lang/bpe/train.txt + fi + + python3 ./local/train_bpe_model.py + if [ ! 
-f data/lang/bpe/L_disambig.pt ]; then ./local/prepare_lang_bpe.py fi fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - echo "Stage 7: Prepare G" + log "Stage 7: Prepare G" # We assume you have install kaldilm, if not, please install # it using: pip install kaldilm @@ -123,6 +145,6 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - echo "Stage 8: Compile HLG" + log "Stage 8: Compile HLG" python3 ./local/compile_hlg.py fi diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 2a29190c9..2c45b4e31 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -72,7 +72,7 @@ def get_params() -> AttributeDict: # - nbest # - nbest-rescoring # - whole-lattice-rescoring - "method": "whole-lattice-rescoring", + "method": "1best", # num_paths is used when method is "nbest" and "nbest-rescoring" "num_paths": 30, } @@ -173,7 +173,7 @@ def decode_one_batch( ) key = f"no_rescore-{params.num_paths}" hyps = get_texts(best_path) - hyps = [[lexicon.words[i] for i in ids] for ids in hyps] + hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] return {key: hyps} assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] @@ -196,7 +196,7 @@ def decode_one_batch( ans = dict() for lm_scale_str, best_path in best_path_dict.items(): hyps = get_texts(best_path) - hyps = [[lexicon.words[i] for i in ids] for ids in hyps] + hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] ans[lm_scale_str] = hyps return ans diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py new file mode 100644 index 000000000..e22cf4edc --- /dev/null +++ b/icefall/bpe_graph_compiler.py @@ -0,0 +1,74 @@ +from pathlib import Path +from typing import List, Union + +import k2 +import sentencepiece as spm +import torch + + +class BpeCtcTrainingGraphCompiler(object): + def __init__( + self, + lang_dir: Path, + device: Union[str, torch.device] = "cpu", + sos_token: str = "", + eos_token: str = "", + ) -> None: + """ + Args: + lang_dir: + This directory is expected to contain the following files: + + - bpe.model + - words.txt + device: + It indicates CPU or CUDA. + sos_token: + The word piece that represents sos. + eos_token: + The word piece that represents eos. + """ + lang_dir = Path(lang_dir) + model_file = lang_dir / "bpe.model" + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + self.sp = sp + self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + self.device = device + + self.sos_id = self.sp.piece_to_id(sos_token) + self.eos_id = self.sp.piece_to_id(eos_token) + + assert self.sos_id != self.sp.unk_id() + assert self.eos_id != self.sp.unk_id() + + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of piece IDs. + + Args: + texts: + It is a list of strings. Each string consists of space(s) + separated words. An example containing two strings is given below: + + ['HELLO ICEFALL', 'HELLO k2'] + Returns: + Return a list-of-list of piece IDs. + """ + return self.sp.encode(texts, out_type=int) + + def compile( + self, piece_ids: List[List[int]], modified: bool = False, + ) -> k2.Fsa: + """Build a ctc graph from a list-of-list piece IDs. + + Args: + piece_ids: + It is a list-of-list integer IDs. + modified: + See :func:`k2.ctc_graph` for its meaning. 
+ Return: + Return an FsaVec, which is the result of composing a + CTC topology with linear FSAs constructed from the given + piece IDs. + """ + return k2.ctc_graph(piece_ids, modified=modified, device=self.device) diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index 8d5d136b7..f7ba3cdaf 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -8,10 +8,7 @@ from icefall.lexicon import Lexicon class CtcTrainingGraphCompiler(object): def __init__( - self, - lexicon: Lexicon, - device: torch.device, - oov: str = "", + self, lexicon: Lexicon, device: torch.device, oov: str = "", ): """ Args: @@ -26,11 +23,11 @@ class CtcTrainingGraphCompiler(object): L_inv = lexicon.L_inv.to(device) assert L_inv.requires_grad is False - assert oov in lexicon.words + assert oov in lexicon.word_table self.L_inv = k2.arc_sort(L_inv) - self.oov_id = lexicon.words[oov] - self.words = lexicon.words + self.oov_id = lexicon.word_table[oov] + self.word_table = lexicon.word_table max_token_id = max(lexicon.tokens) ctc_topo = k2.ctc_topo(max_token_id, modified=False) @@ -90,8 +87,8 @@ class CtcTrainingGraphCompiler(object): for text in texts: word_ids = [] for word in text.split(" "): - if word in self.words: - word_ids.append(self.words[word]) + if word in self.word_table: + word_ids.append(self.word_table[word]) else: word_ids.append(self.oov_id) word_ids_list.append(word_ids) diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 46cea1941..3b52c70c9 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -1,12 +1,65 @@ import logging import re from pathlib import Path -from typing import List +from typing import List, Tuple, Union import k2 import torch +def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]: + """Read a lexicon from `filename`. + + Each line in the lexicon contains "word p1 p2 p3 ...". + That is, the first field is a word and the remaining + fields are tokens. Fields are separated by space(s). + + Args: + filename: + Path to the lexicon.txt + + Returns: + A list of tuples., e.g., [('w', ['p1', 'p2']), ('w1', ['p3, 'p4'])] + """ + ans = [] + + with open(filename, "r", encoding="utf-8") as f: + whitespace = re.compile("[ \t]+") + for line in f: + a = whitespace.split(line.strip(" \t\r\n")) + if len(a) == 0: + continue + + if len(a) < 2: + print(f"Found bad line {line} in lexicon file {filename}") + print("Every line is expected to contain at least 2 fields") + sys.exit(1) + word = a[0] + if word == "": + print(f"Found bad line {line} in lexicon file {filename}") + print(" should not be a valid word") + sys.exit(1) + + tokens = a[1:] + ans.append((word, tokens)) + + return ans + + +def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None: + """Write a lexicon to a file. + + Args: + filename: + Path to the lexicon file to be generated. + lexicon: + It can be the return value of :func:`read_lexicon`. + """ + with open(filename, "w", encoding="utf-8") as f: + for word, tokens in lexicon: + f.write(f"{word} {' '.join(tokens)}\n") + + class Lexicon(object): """Phone based lexicon. @@ -14,14 +67,14 @@ class Lexicon(object): """ def __init__( - self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$") + self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"), ): """ Args: lang_dir: Path to the lang director. It is expected to contain the following files: - - phones.txt + - tokens.txt - words.txt - L.pt The above files are produced by the script `prepare.sh`. 
You @@ -30,11 +83,11 @@ class Lexicon(object): It contains the pattern for disambiguation symbols. """ lang_dir = Path(lang_dir) - self.phones = k2.SymbolTable.from_file(lang_dir / "phones.txt") - self.words = k2.SymbolTable.from_file(lang_dir / "words.txt") + self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt") if (lang_dir / "Linv.pt").exists(): - logging.info("Loading pre-compiled Linv.pt") + logging.info(f"Loading pre-compiled {lang_dir}/Linv.pt") L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt")) else: logging.info("Converting L.pt to Linv.pt") @@ -49,18 +102,92 @@ class Lexicon(object): @property def tokens(self) -> List[int]: - """Return a list of phone IDs excluding those from + """Return a list of token IDs excluding those from disambiguation symbols. Caution: - 0 is not a phone ID so it is excluded from the return value. + 0 is not a token ID so it is excluded from the return value. """ - symbols = self.phones.symbols + symbols = self.token_table.symbols ans = [] for s in symbols: if not self.disambig_pattern.match(s): - ans.append(self.phones[s]) + ans.append(self.token_table[s]) if 0 in ans: ans.remove(0) ans.sort() return ans + + +class BpeLexicon(Lexicon): + def __init__( + self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"), + ): + """ + Refer to the help information in Lexicon.__init__. + """ + super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern) + + self.ragged_lexicon = self.convert_lexicon_to_ragged( + lang_dir / "lexicon.txt" + ) + + def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt: + """Read a BPE lexicon from file and convert it to a + k2 ragged tensor. + + Args: + filename: + Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt + Returns: + A k2 ragged tensor with two axes [word_id] + """ + disambig_id = self.word_table["#0"] + # We reuse the same words.txt from the phone based lexicon + # so that we can share the same G.fst. Here, we have to + # exclude some words present only in the phone based lexicon. + excluded_words = ["", "!SIL", ""] + + # epsilon is not a word, but it occupies on position + # + row_splits = [0] + token_ids = [] + + lexicon = read_lexicon(filename) + lexicon = dict(lexicon) + + for i in range(disambig_id): + w = self.word_table[i] + if w in excluded_words: + row_splits.append(row_splits[-1]) + continue + pieces = lexicon[w] + piece_ids = [self.token_table[k] for k in pieces] + + row_splits.append(row_splits[-1] + len(piece_ids)) + token_ids.extend(piece_ids) + + cached_tot_size = row_splits[-1] + row_splits = torch.tensor(row_splits, dtype=torch.int32) + + shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=cached_tot_size + ) + values = torch.tensor(token_ids, dtype=torch.int32) + + return k2.RaggedInt(shape, values) + + def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt: + """Convert a list of words to a ragged tensor contained + word piece IDs. 
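+
+        For example, ["HELLO", "WORLD"] might be mapped to a ragged
+        tensor like [ [13 42] [97] ]; the actual piece IDs depend on
+        the trained BPE model.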
+ """ + word_ids = [self.word_table[w] for w in words] + word_ids = torch.tensor(word_ids, dtype=torch.int32) + + ragged, _ = k2.ragged.index( + self.ragged_lexicon, + indexes=word_ids, + need_value_indexes=False, + axis=0, + ) + return ragged diff --git a/test/test_bpe_graph_compiler.py b/test/test_bpe_graph_compiler.py new file mode 100755 index 000000000..7b941e5a7 --- /dev/null +++ b/test/test_bpe_graph_compiler.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.lexicon import BpeLexicon +from pathlib import Path + + +def test(): + lang_dir = Path("data/lang/bpe") + if not lang_dir.is_dir(): + return + # TODO: generate data for testing + + compiler = BpeCtcTrainingGraphCompiler(lang_dir) + ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"]) + fsa = compiler.compile(ids) + + lexicon = BpeLexicon(lang_dir) + ids0 = lexicon.words_to_piece_ids(["HELLO"]) + assert ids[0] == ids0.values().tolist() + + ids1 = lexicon.words_to_piece_ids(["WORLD", "ZZZ"]) + assert ids[1] == ids1.values().tolist() diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py index 7894dc61f..343768957 100644 --- a/test/test_checkpoint.py +++ b/test/test_checkpoint.py @@ -41,7 +41,8 @@ def test_load_checkpoints(checkpoints1): m.p2 = nn.Parameter(torch.Tensor([0, 0])) params = load_checkpoint(checkpoints1, m) assert torch.allclose(m.p1, torch.Tensor([10.0, 20])) - assert params == {"a": 10, "b": 20} + assert params["a"] == 10 + assert params["b"] == 20 def test_average_checkpoints(checkpoints1, checkpoints2): diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py index a053d5f4d..4083d79ac 100644 --- a/test/test_graph_compiler.py +++ b/test/test_graph_compiler.py @@ -81,8 +81,8 @@ def lexicon(): """ ) ans = Lexicon.__new__(Lexicon) - ans.phones = L.labels_sym - ans.words = L.aux_labels_sym + ans.token_table = L.labels_sym + ans.word_table = L.aux_labels_sym ans.L_inv = k2.arc_sort(L.invert_()) ans.disambig_pattern = re.compile(r"^#\d+$") @@ -107,11 +107,11 @@ class TestCtcTrainingGraphCompiler(object): aux_labels1 = fsa[1].aux_labels[:-1] aux_labels1 = aux_labels1[aux_labels1 != 0].tolist() - labels0 = [lexicon.phones[i] for i in labels0] - labels1 = [lexicon.phones[i] for i in labels1] + labels0 = [lexicon.token_table[i] for i in labels0] + labels1 = [lexicon.token_table[i] for i in labels1] - aux_labels0 = [lexicon.words[i] for i in aux_labels0] - aux_labels1 = [lexicon.words[i] for i in aux_labels1] + aux_labels0 = [lexicon.word_table[i] for i in aux_labels0] + aux_labels1 = [lexicon.word_table[i] for i in aux_labels1] assert labels0 == ["b", "a", "r", "f", "o", "o"] assert aux_labels0 == ["bar", "foo"] @@ -129,11 +129,11 @@ class TestCtcTrainingGraphCompiler(object): input2 = ["b", "b", "a", "a", "a", "", "", "z", "z"] input2 += ["", "", "SPN", "SPN", "", ""] - lexicon.phones._id2sym[0] == "" - lexicon.phones._sym2id[""] = 0 + lexicon.token_table._id2sym[0] == "" + lexicon.token_table._sym2id[""] = 0 - input1 = [lexicon.phones[i] for i in input1] - input2 = [lexicon.phones[i] for i in input2] + input1 = [lexicon.token_table[i] for i in input1] + input2 = [lexicon.token_table[i] for i in input2] fsa1 = k2.linear_fsa(input1) fsa2 = k2.linear_fsa(input2) @@ -147,14 +147,14 @@ class TestCtcTrainingGraphCompiler(object): aux_labels0 = lattice[0].aux_labels[:-1] aux_labels0 = aux_labels0[aux_labels0 != 0].tolist() - aux_labels0 = [lexicon.words[i] for i in 
aux_labels0] + aux_labels0 = [lexicon.word_table[i] for i in aux_labels0] assert aux_labels0 == ["bar", "foo"] aux_labels1 = lattice[1].aux_labels[:-1] aux_labels1 = aux_labels1[aux_labels1 != 0].tolist() - aux_labels1 = [lexicon.words[i] for i in aux_labels1] + aux_labels1 = [lexicon.word_table[i] for i in aux_labels1] assert aux_labels1 == ["baz", ""] texts = get_texts(lattice) - texts = [[lexicon.words[i] for i in words] for words in texts] + texts = [[lexicon.word_table[i] for i in words] for words in texts] assert texts == [["bar", "foo"], ["baz", ""]] diff --git a/test/test_lexicon.py b/test/test_lexicon.py index b1b823f98..b1284d98a 100644 --- a/test/test_lexicon.py +++ b/test/test_lexicon.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 +from pathlib import Path + import k2 import pytest import torch -from icefall.lexicon import Lexicon +from icefall.lexicon import BpeLexicon, Lexicon @pytest.fixture @@ -47,7 +49,7 @@ def lang_dir(tmp_path): num_aux_labels=1, ) - with open(tmp_path / "phones.txt", "w") as f: + with open(tmp_path / "tokens.txt", "w") as f: f.write(phone2id) with open(tmp_path / "words.txt", "w") as f: f.write(word2id) @@ -60,3 +62,16 @@ def lang_dir(tmp_path): def test_lexicon(lang_dir): lexicon = Lexicon(lang_dir) assert lexicon.tokens == list(range(1, 8)) + + +def test_bpe_lexicon(): + lang_dir = Path("data/lang/bpe") + if not lang_dir.is_dir(): + return + # TODO: Generate test data for BpeLexicon + + lexicon = BpeLexicon(lang_dir) + words = ["", "HELLO", "ZZZZ", "WORLD"] + ids = lexicon.words_to_piece_ids(words) + print(ids) + print([lexicon.token_table[i] for i in ids.values().tolist()])
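For reference, a minimal usage sketch of the new BpeCtcTrainingGraphCompiler added in icefall/bpe_graph_compiler.py, assuming ./prepare.sh has already produced data/lang/bpe/bpe.model and data/lang/bpe/words.txt; the transcript "HELLO WORLD" is only an illustration.

    import k2

    from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler

    # data/lang/bpe is expected to contain bpe.model and words.txt,
    # produced by ./prepare.sh (stage 6).
    compiler = BpeCtcTrainingGraphCompiler("data/lang/bpe", device="cpu")

    texts = ["HELLO WORLD"]                   # space-separated transcripts
    piece_ids = compiler.texts_to_ids(texts)  # list-of-list of BPE piece IDs
    graph = compiler.compile(piece_ids)       # FsaVec: CTC topology composed
                                              # with linear FSAs of piece_ids
    assert isinstance(graph, k2.Fsa)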