WIP: Add BPE training code.

Fangjun Kuang 2021-07-29 20:23:52 +08:00
parent bd69e4be32
commit acc63a9172
15 changed files with 1144 additions and 267 deletions

View File

@ -0,0 +1,602 @@
#!/usr/bin/env python3
# This is just at the very beginning ...
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional
import k2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from conformer import Conformer
from transformer import Noam
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
# TODO: add extra arguments and support DDP training.
# Currently, only single GPU training is implemented. Will add
# DDP training once single GPU training is finished.
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used for writing statistics to tensorboard. It
contains the number of batches trained so far across
epochs.
- log_interval: Print training loss if `batch_idx % log_interval` is 0
- valid_interval: Run validation if `batch_idx % valid_interval` is 0
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang/bpe"),
"feature_dim": 80,
"weight_decay": 0.0,
"subsampling_factor": 4,
"start_epoch": 0,
"num_epochs": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"valid_interval": 1000,
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
#
"accum_grad": 1,
"att_rate": 0.7,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
"lr_factor": 5.0,
"warm_step": 80000,
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return the contents of the loaded checkpoint (a dict) if a checkpoint
is loaded; otherwise return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename, model=model, optimizer=optimizer, scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool,
):
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build a decoding graph from a ctc topo and training
transcript. The training transcript is contained in the given `batch`,
while the ctc topo is built when this compiler is instantiated.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is [N, T, C]
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, C, T]
nnet_output = nnet_output.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
# NOTE: We need `encode_supervisions` to sort sequences with
# different durations in decreasing order, which is required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
if params.att_rate != 0.0:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
assert loss.requires_grad == is_training
return loss
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: BpeCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_frames += params.valid_frames
if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
params.valid_loss = tot_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: BpeCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss, averaged over all frames, is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_frames = 0.0 # sum of frames over all batches
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# NOTE: We use reduction="sum", so the loss is summed over all
# utterances in the batch; no normalization is applied to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_value_(model.parameters(), 5.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
params.train_loss = tot_loss / tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params, model=model, optimizer=optimizer, rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()
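For reference, the k2 calls used in `compute_loss` above can be exercised in isolation. Below is a minimal, self-contained sketch (toy shapes and token IDs are made up for illustration) that builds a `DenseFsaVec` and a CTC graph and evaluates `k2.ctc_loss` with the same options as in `get_params`:

import k2
import torch

# Toy example: 1 utterance, 5 frames, 4 output classes (0 is the blank).
log_probs = torch.randn(1, 5, 4).log_softmax(dim=-1)
# Each row is [sequence_idx, start_frame, num_frames]; rows must be sorted
# by num_frames in decreasing order (here there is only one row).
supervision_segments = torch.tensor([[0, 0, 5]], dtype=torch.int32)
dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
# CTC graph for the (made-up) token sequence [1, 2, 3].
decoding_graph = k2.ctc_graph([[1, 2, 3]], modified=False)
loss = k2.ctc_loss(
    decoding_graph=decoding_graph,
    dense_fsa_vec=dense_fsa_vec,
    output_beam=10,
    reduction="sum",
    use_double_scores=True,
)
print(loss)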

View File

@ -189,6 +189,8 @@ class Transformer(nn.Module):
supervision: Supervisions = None,
graph_compiler: object = None,
token_ids: List[int] = None,
sos_id: Optional[int] = None,
eos_id: Optional[int] = None,
) -> Tensor:
"""
Args:
@ -197,6 +199,8 @@ class Transformer(nn.Module):
supervision: Supervision in lhotse format, obtained from batch['supervisions']
graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones)
, graph_compiler.words and graph_compiler.oov
sos_id: sos token id
eos_id: eos token id
Returns:
Tensor: Decoder loss.
@ -206,18 +210,9 @@ class Transformer(nn.Module):
supervision, graph_compiler.lexicon.words, graph_compiler.oov
)
ys_in_pad, ys_out_pad = add_sos_eos(
batch_text,
graph_compiler.L_inv,
self.decoder_num_class - 1,
self.decoder_num_class - 1,
batch_text, graph_compiler.L_inv, sos_id, eos_id,
)
elif token_ids is not None:
# speical token ids:
# <blank> 0
# <UNK> 1
# <sos/eos> self.decoder_num_class - 1
sos_id = self.decoder_num_class - 1
eos_id = self.decoder_num_class - 1
_sos = torch.tensor([sos_id])
_eos = torch.tensor([eos_id])
ys_in = [
@ -259,7 +254,12 @@ class Transformer(nn.Module):
return decoder_loss
def decoder_nll(
self, x: Tensor, encoder_mask: Tensor, token_ids: List[List[int]] = None
self,
x: Tensor,
encoder_mask: Tensor,
token_ids: List[List[int]],
sos_id: int,
eos_id: int,
) -> Tensor:
"""
Args:
@ -273,12 +273,6 @@ class Transformer(nn.Module):
# The common part between this function and decoder_forward could be
# extracted as a separate function.
if token_ids is not None:
# speical token ids:
# <blank> 0
# <UNK> 1
# <sos/eos> self.decoder_num_class - 1
sos_id = self.decoder_num_class - 1
eos_id = self.decoder_num_class - 1
_sos = torch.tensor([sos_id])
_eos = torch.tensor([eos_id])
ys_in = [
@ -866,7 +860,8 @@ class LabelSmoothingLoss(nn.Module):
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
# denom = total if self.normalize_length else batch_size
denom = total if self.normalize_length else 1
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
@ -983,8 +978,8 @@ def generate_square_subsequent_mask(sz: int) -> Tensor:
def add_sos_eos(
ys: List[List[int]],
lexicon: k2.Fsa,
sos: int,
eos: int,
sos_id: int,
eos_id: int,
ignore_id: int = -1,
) -> Tuple[Tensor, Tensor]:
"""Add <sos> and <eos> labels.
@ -992,8 +987,8 @@ def add_sos_eos(
Args:
ys: batch of unpadded target sequences
lexicon: Its labels are words, while its aux_labels are phones.
sos: index of <sos>
eos: index of <eos>
sos_id: index of <sos>
eos_id: index of <eos>
ignore_id: index of padding
Returns:
@ -1001,8 +996,8 @@ def add_sos_eos(
Tensor: Output of transformer decoder. Padded tensor of dimension (batch_size, max_length).
"""
_sos = torch.tensor([sos])
_eos = torch.tensor([eos])
_sos = torch.tensor([sos_id])
_eos = torch.tensor([eos_id])
ys = get_hierarchical_targets(ys, lexicon)
ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, _eos], dim=0) for y in ys]

View File

@ -3,7 +3,7 @@
"""
This script compiles HLG from
- H, the ctc topology, built from phones contained in lexicon.txt
- H, the ctc topology, built from tokens contained in lexicon.txt
- L, the lexicon, built from L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
@ -13,6 +13,7 @@ This script compiles HLG from
The generated HLG is saved in data/lm/HLG.pt (phone based)
or data/lm/HLG_bpe.pt (BPE based)
"""
import logging
from pathlib import Path
import k2
@ -32,44 +33,44 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
print(f"Building ctc_topo. max_token_id: {max_token_id}")
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
print("Loading pre-compiled G_3_gram")
logging.info("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
G = k2.Fsa.from_dict(d)
else:
print("Loading G_3_gram.fst.txt")
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "G_3_gram.pt")
first_token_disambig_id = lexicon.phones["#0"]
first_word_disambig_id = lexicon.words["#0"]
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
print("Intersecting L and G")
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
print(f"LG shape: {LG.shape}")
logging.info(f"LG shape: {LG.shape}")
print("Connecting LG")
logging.info("Connecting LG")
LG = k2.connect(LG)
print(f"LG shape after k2.connect: {LG.shape}")
logging.info(f"LG shape after k2.connect: {LG.shape}")
print(type(LG.aux_labels))
print("Determinizing LG")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
print(type(LG.aux_labels))
logging.info(type(LG.aux_labels))
print("Connecting LG after k2.determinize")
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
print("Removing disambiguation symbols on LG")
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
@ -77,27 +78,27 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
print(f"LG shape after k2.remove_epsilon: {LG.shape}")
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
print("Arc sorting LG")
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
print("Composing H and LG")
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
print("Connecting LG")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
print("Arc sorting LG")
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
print(f"HLG.shape: {HLG.shape}")
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
@ -106,10 +107,10 @@ def phone_based_HLG():
if Path("data/lm/HLG.pt").is_file():
return
print("Compiling phone based HLG")
logging.info("Compiling phone based HLG")
HLG = compile_HLG("data/lang")
print("Saving HLG.pt to data/lm")
logging.info("Saving HLG.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG.pt")
@ -117,9 +118,9 @@ def bpe_based_HLG():
if Path("data/lm/HLG_bpe.pt").is_file():
return
print("Compiling BPE based HLG")
logging.info("Compiling BPE based HLG")
HLG = compile_HLG("data/lang/bpe")
print("Saving HLG_bpe.pt to data/lm")
logging.info("Saving HLG_bpe.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
@ -129,4 +130,10 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
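Once the script has run, the compiled graph can be loaded back the same way G is loaded above. A minimal sketch, assuming data/lm/HLG.pt already exists:

import k2
import torch

d = torch.load("data/lm/HLG.pt")
HLG = k2.Fsa.from_dict(d)
print(HLG.shape)  # (num_states, None) for a single FSA
# HLG.tokens holds the inner labels attached via inner_labels="tokens"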

View File

@ -4,13 +4,13 @@
"""
This script takes as input a lexicon file "data/lang/lexicon.txt"
consisting of words and phones and does the following:
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate phones.txt, the phones table mapping a phone to a unique integer.
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the words table mapping a word to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
@ -29,62 +29,11 @@ from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
Lexicon = List[Tuple[str, List[str]]]
def read_lexicon(filename: str) -> Lexicon:
"""Read a lexicon.txt in `filename`.
Each line in the lexicon contains "word p1 p2 p3 ...".
That is, the first field is a word and the remaining
fields are phones. Fields are separated by space(s).
Args:
filename:
Path to the lexicon.txt
Returns:
A list of tuples., e.g., [('w', ['p1', 'p2']), ('w1', ['p3, 'p4'])]
"""
ans = []
with open(filename, "r", encoding="utf-8") as f:
whitespace = re.compile("[ \t]+")
for line in f:
a = whitespace.split(line.strip(" \t\r\n"))
if len(a) == 0:
continue
if len(a) < 2:
print(f"Found bad line {line} in lexicon file {filename}")
print("Every line is expected to contain at least 2 fields")
sys.exit(1)
word = a[0]
if word == "<eps>":
print(f"Found bad line {line} in lexicon file {filename}")
print("<eps> should not be a valid word")
sys.exit(1)
prons = a[1:]
ans.append((word, prons))
return ans
def write_lexicon(filename: str, lexicon: Lexicon) -> None:
"""Write a lexicon to a file.
Args:
filename:
Path to the lexicon file to be generated.
lexicon:
It can be the return value of :func:`read_lexicon`.
"""
with open(filename, "w", encoding="utf-8") as f:
for word, prons in lexicon:
f.write(f"{word} {' '.join(prons)}\n")
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
@ -105,18 +54,18 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
f.write(f"{sym} {i}\n")
def get_phones(lexicon: Lexicon) -> List[str]:
"""Get phones from a lexicon.
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique phones.
Return a list of unique tokens.
"""
ans = set()
for _, prons in lexicon:
ans.update(prons)
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
@ -138,8 +87,8 @@ def get_words(lexicon: Lexicon) -> List[str]:
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-phone disambiguation symbols #1, #2 and so on
at the ends of phones to ensure that all pronunciations are different,
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
@ -151,30 +100,30 @@ def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbols that appears
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each phone-sequence in the
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, prons in lexicon:
count[" ".join(prons)] += 1
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each phone-sequence, note down
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, prons in lexicon:
prons = prons.copy()
prons.pop()
while prons:
issubseq[" ".join(prons)] = 1
prons.pop()
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the phone sequence is unique and is not a
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same phone-seq
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
@ -183,14 +132,14 @@ def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, prons in lexicon:
phnseq = " ".join(prons)
assert phnseq != ""
if issubseq[phnseq] == 0 and count[phnseq] == 1:
ans.append((word, prons))
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[phnseq]
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
@ -198,9 +147,9 @@ def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[phnseq] = cur_disambig
phnseq += f" #{cur_disambig}"
ans.append((word, phnseq.split()))
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
@ -217,7 +166,7 @@ def generate_id_map(symbols: List[str]) -> Dict[str, int]:
def add_self_loops(
arcs: List[List[Any]], disambig_phone: int, disambig_word: int
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
@ -228,12 +177,15 @@ def add_self_loops(
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_phone:
It is the phone ID of the symbol `#0`.
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
@ -248,37 +200,38 @@ def add_self_loops(
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_phone, disambig_word, 0])
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
phone2id: Dict[str, int],
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_phone: str = "SIL",
sil_token: str = "SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of the word.
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
phone2id:
A dict mapping phones to IDs.
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_phone:
The silence phone.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state.
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
@ -294,48 +247,44 @@ def lexicon_to_fst(
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
assert phone2id["<eps>"] == 0
assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_phone = phone2id[sil_phone]
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_phone, eps, 0])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, prons in lexicon:
assert len(prons) > 0, f"{word} has no pronunciations"
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
prons = [phone2id[i] for i in prons]
tokens = [token2id[i] for i in tokens]
for i in range(len(prons) - 1):
if i == 0:
arcs.append([cur_state, next_state, prons[i], word, 0])
else:
arcs.append([cur_state, next_state, prons[i], eps, 0])
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last phone of this word
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(prons) - 1
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, prons[i], w, no_sil_score])
arcs.append([cur_state, sil_state, prons[i], w, sil_score])
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_phone = phone2id["#0"]
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_phone=disambig_phone,
disambig_word=disambig_word,
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
)
final_state = next_state
@ -354,22 +303,22 @@ def lexicon_to_fst(
def main():
out_dir = Path("data/lang")
lexicon_filename = out_dir / "lexicon.txt"
sil_phone = "SIL"
sil_token = "SIL"
sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
phones = get_phones(lexicon)
tokens = get_tokens(lexicon)
words = get_words(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in phones
phones.append(f"#{i}")
assert disambig not in tokens
tokens.append(f"#{i}")
assert "<eps>" not in phones
phones = ["<eps>"] + phones
assert "<eps>" not in tokens
tokens = ["<eps>"] + tokens
assert "<eps>" not in words
assert "#0" not in words
@ -378,26 +327,26 @@ def main():
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
phone2id = generate_id_map(phones)
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / "phones.txt", phone2id)
write_mapping(out_dir / "tokens.txt", token2id)
write_mapping(out_dir / "words.txt", word2id)
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
phone2id=phone2id,
token2id=token2id,
word2id=word2id,
sil_phone=sil_phone,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
phone2id=phone2id,
token2id=token2id,
word2id=word2id,
sil_phone=sil_phone,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
@ -406,7 +355,7 @@ def main():
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "phones.txt")
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
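As a quick illustration of `add_disambig_symbols`, consider the toy lexicon below (all entries are hypothetical). "a b" occurs twice and "a" is a prefix of "a b", so every entry receives a disambiguation symbol:

from prepare_lang import add_disambig_symbols

lexicon = [("A", ["a", "b"]), ("B", ["a", "b"]), ("C", ["a"])]
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
# lexicon_disambig:
#   [('A', ['a', 'b', '#1']), ('B', ['a', 'b', '#2']), ('C', ['a', '#1'])]
# max_disambig: 2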

View File

@ -3,9 +3,9 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as inputs the following files:
This script takes as inputs the following two files:
- data/lang/bpe/bpe.model,
- data/lang/bpe/tokens.txt (will remove it),
- data/lang/bpe/words.txt
and generates the following files in the directory data/lang/bpe:
@ -14,11 +14,11 @@ and generates the following files in the directory data/lang/bpe:
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- phones.txt
- tokens.txt
"""
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Tuple
import k2
import sentencepiece as spm
@ -28,6 +28,7 @@ from prepare_lang import (
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
@ -48,48 +49,46 @@ def lexicon_to_fst_no_sil(
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state.
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go.
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
assert token2id["<blank>"] == 0
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, prons in lexicon:
assert len(prons) > 0, f"{word} has no pronunciations"
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
prons = [token2id[i] for i in prons]
pieces = [token2id[i] for i in pieces]
for i in range(len(prons) - 1):
if i == 0:
arcs.append([cur_state, next_state, prons[i], word, 0])
else:
arcs.append([cur_state, next_state, prons[i], eps, 0])
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last phone of this word
i = len(prons) - 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, prons[i], w, 0])
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_phone = token2id["#0"]
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_phone=disambig_phone,
disambig_word=disambig_word,
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
)
final_state = next_state
@ -105,7 +104,9 @@ def lexicon_to_fst_no_sil(
return fsa
def generate_lexicon(model_file: str, words: List[str]) -> Lexicon:
def generate_lexicon(
model_file: str, words: List[str]
) -> Tuple[Lexicon, Dict[str, int]]:
"""Generate a lexicon from a BPE model.
Args:
@ -114,8 +115,10 @@ def generate_lexicon(model_file: str, words: List[str]) -> Lexicon:
words:
A list of strings representing words.
Returns:
Return a dict whose keys are words and values are the corresponding
Return a tuple with two elements:
- A dict whose keys are words and values are the corresponding
word pieces.
- A dict representing the token symbol table, mapping tokens to IDs.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(model_file))
@ -126,8 +129,14 @@ def generate_lexicon(model_file: str, words: List[str]) -> Lexicon:
for word, pieces in zip(words, words_pieces):
lexicon.append((word, pieces))
lexicon.append(("<UNK>", ["<UNK>"]))
return lexicon
# The OOV word is <UNK>
lexicon.append(("<UNK>", [sp.id_to_piece(sp.unk_id())]))
token2id: Dict[str, int] = dict()
for i in range(sp.vocab_size()):
token2id[sp.id_to_piece(i)] = i
return lexicon, token2id
def main():
@ -143,34 +152,28 @@ def main():
if w in words:
words.remove(w)
lexicon = generate_lexicon(model_file, words)
# TODO(fangjun): Remove tokens.txt and generate it from the model directly.
#
# We are using it since the IDs we are using in tokens.txt is
# different from the one contained in the model
token_sym_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
lexicon, token_sym_table = generate_lexicon(model_file, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table.add(f"#{i}")
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
token_sym_table.to_file(lang_dir / "phones.txt")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
lexicon, token2id=token_sym_table, word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
@ -184,7 +187,7 @@ def main():
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(lang_dir / "phones.txt")
L.labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
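A sketch of what `generate_lexicon` relies on from sentencepiece; the exact pieces depend on the trained model, so the sample output is only indicative:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang/bpe/bpe.model")
print(sp.encode_as_pieces("HELLO WORLD"))  # e.g. ['▁HE', 'LLO', '▁WORLD']
print(sp.id_to_piece(sp.unk_id()))         # the piece used for the OOV word <UNK>
print(sp.vocab_size())                     # size of the token table built in generate_lexicon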

View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
"""
This script takes as input "data/lang/bpe/train.txt"
and generates "data/lang/bpe/bpe.model".
"""
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
from pathlib import Path
import sentencepiece as spm
import shutil
def main():
model_type = "unigram"
vocab_size = 5000
model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}"
train_text = "data/lang/bpe/train.txt"
character_coverage = 1.0
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
# Note: unk_id is fixed to 2.
# If you change it, you should also change other
# places that are using it.
model_file = Path(model_prefix + ".model")
if not model_file.is_file():
spm.SentencePieceTrainer.train(
input=train_text,
vocab_size=vocab_size,
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
bos_id=-1,
eos_id=-1,
)
sp = spm.SentencePieceProcessor(model_file=str(model_file))
vocab_size = sp.vocab_size()
shutil.copyfile(model_file, "data/lang/bpe/bpe.model")
if __name__ == "__main__":
main()
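A quick sanity check on the trained model, reflecting the ID layout chosen above (<blk> and <sos/eos> are user-defined symbols, unk_id is fixed to 2). This is a sketch, not part of the script:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang/bpe/bpe.model")
assert sp.piece_to_id("<blk>") == 0       # user_defined_symbols[0]
assert sp.piece_to_id("<sos/eos>") == 1   # user_defined_symbols[1]
assert sp.unk_id() == 2                   # matches the comment in the script
print(sp.vocab_size())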

View File

@ -10,14 +10,20 @@ stop_stage=100
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
echo "stage -1: Download LM"
log "stage -1: Download LM"
mkdir -p data/lm
./local/download_lm.py
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "stage 0: Download data"
log "stage 0: Download data"
# If you have pre-downloaded it to /path/to/LibriSpeech,
# you can create a symlink
@ -49,7 +55,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Stage 1: Prepare librispeech manifest"
log "Stage 1: Prepare librispeech manifest"
# We assume that you have downloaded the librispeech corpus
# to data/LibriSpeech
mkdir -p data/manifests
@ -57,7 +63,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Stage 2: Prepare musan manifest"
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
@ -65,19 +71,19 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Stage 3: Compute fbank for librispeech"
log "Stage 3: Compute fbank for librispeech"
mkdir -p data/fbank
./local/compute_fbank_librispeech.py
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Stage 4: Compute fbank for musan"
log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank
./local/compute_fbank_musan.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Stage 5: Prepare phone based lang"
log "Stage 5: Prepare phone based lang"
# TODO: add BPE based lang
mkdir -p data/lang
@ -85,21 +91,37 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
cat - data/lm/librispeech-lexicon.txt |
sort | uniq > data/lang/lexicon.txt
if [ ! -f data/lang/L_disambig.pt ]; then
./local/prepare_lang.py
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "State 6: Prepare BPE based lang"
log "State 6: Prepare BPE based lang"
mkdir -p data/lang/bpe
cp data/lang/words.txt data/lang/bpe/
if [ ! -f data/lang/bpe/train.txt ]; then
log "Generate data for BPE training"
files=$(
find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"
find "data/LibriSpeech/train-clean-360" -name "*.trans.txt"
find "data/LibriSpeech/train-other-500" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > data/lang/bpe/train.txt
fi
python3 ./local/train_bpe_model.py
if [ ! -f data/lang/bpe/L_disambig.pt ]; then
./local/prepare_lang_bpe.py
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "Stage 7: Prepare G"
log "Stage 7: Prepare G"
# We assume you have installed kaldilm. If not, please install
# it using: pip install kaldilm
@ -123,6 +145,6 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Stage 8: Compile HLG"
log "Stage 8: Compile HLG"
python3 ./local/compile_hlg.py
fi

View File

@ -72,7 +72,7 @@ def get_params() -> AttributeDict:
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
"method": "whole-lattice-rescoring",
"method": "1best",
# num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 30,
}
@ -173,7 +173,7 @@ def decode_one_batch(
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
@ -196,7 +196,7 @@ def decode_one_batch(
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans

View File

@ -0,0 +1,74 @@
from pathlib import Path
from typing import List, Union
import k2
import sentencepiece as spm
import torch
class BpeCtcTrainingGraphCompiler(object):
def __init__(
self,
lang_dir: Path,
device: Union[str, torch.device] = "cpu",
sos_token: str = "<sos/eos>",
eos_token: str = "<sos/eos>",
) -> None:
"""
Args:
lang_dir:
This directory is expected to contain the following files:
- bpe.model
- words.txt
device:
It indicates CPU or CUDA.
sos_token:
The word piece that represents sos.
eos_token:
The word piece that represents eos.
"""
lang_dir = Path(lang_dir)
model_file = lang_dir / "bpe.model"
sp = spm.SentencePieceProcessor()
sp.load(str(model_file))
self.sp = sp
self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
self.device = device
self.sos_id = self.sp.piece_to_id(sos_token)
self.eos_id = self.sp.piece_to_id(eos_token)
assert self.sos_id != self.sp.unk_id()
assert self.eos_id != self.sp.unk_id()
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of piece IDs.
Args:
texts:
It is a list of strings. Each string consists of words
separated by space(s). An example containing two strings is given below:
['HELLO ICEFALL', 'HELLO k2']
Returns:
Return a list-of-list of piece IDs.
"""
return self.sp.encode(texts, out_type=int)
def compile(
self, piece_ids: List[List[int]], modified: bool = False,
) -> k2.Fsa:
"""Build a ctc graph from a list-of-list piece IDs.
Args:
piece_ids:
It is a list-of-list integer IDs.
modified:
See :func:`k2.ctc_graph` for its meaning.
Return:
Return an FsaVec, which is the result of composing a
CTC topology with linear FSAs constructed from the given
piece IDs.
"""
return k2.ctc_graph(piece_ids, modified=modified, device=self.device)

View File

@ -8,10 +8,7 @@ from icefall.lexicon import Lexicon
class CtcTrainingGraphCompiler(object):
def __init__(
self,
lexicon: Lexicon,
device: torch.device,
oov: str = "<UNK>",
self, lexicon: Lexicon, device: torch.device, oov: str = "<UNK>",
):
"""
Args:
@ -26,11 +23,11 @@ class CtcTrainingGraphCompiler(object):
L_inv = lexicon.L_inv.to(device)
assert L_inv.requires_grad is False
assert oov in lexicon.words
assert oov in lexicon.word_table
self.L_inv = k2.arc_sort(L_inv)
self.oov_id = lexicon.words[oov]
self.words = lexicon.words
self.oov_id = lexicon.word_table[oov]
self.word_table = lexicon.word_table
max_token_id = max(lexicon.tokens)
ctc_topo = k2.ctc_topo(max_token_id, modified=False)
@ -90,8 +87,8 @@ class CtcTrainingGraphCompiler(object):
for text in texts:
word_ids = []
for word in text.split(" "):
if word in self.words:
word_ids.append(self.words[word])
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
word_ids.append(self.oov_id)
word_ids_list.append(word_ids)

View File

@ -1,12 +1,65 @@
import logging
import re
import sys
from pathlib import Path
from typing import List
from typing import List, Tuple, Union
import k2
import torch
def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
"""Read a lexicon from `filename`.
Each line in the lexicon contains "word p1 p2 p3 ...".
That is, the first field is a word and the remaining
fields are tokens. Fields are separated by space(s).
Args:
filename:
Path to the lexicon.txt
Returns:
A list of tuples, e.g., [('w', ['p1', 'p2']), ('w1', ['p3', 'p4'])]
"""
ans = []
with open(filename, "r", encoding="utf-8") as f:
whitespace = re.compile("[ \t]+")
for line in f:
a = whitespace.split(line.strip(" \t\r\n"))
if len(a) == 0:
continue
if len(a) < 2:
print(f"Found bad line {line} in lexicon file {filename}")
print("Every line is expected to contain at least 2 fields")
sys.exit(1)
word = a[0]
if word == "<eps>":
print(f"Found bad line {line} in lexicon file {filename}")
print("<eps> should not be a valid word")
sys.exit(1)
tokens = a[1:]
ans.append((word, tokens))
return ans
def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
"""Write a lexicon to a file.
Args:
filename:
Path to the lexicon file to be generated.
lexicon:
It can be the return value of :func:`read_lexicon`.
"""
with open(filename, "w", encoding="utf-8") as f:
for word, tokens in lexicon:
f.write(f"{word} {' '.join(tokens)}\n")
class Lexicon(object):
"""Phone based lexicon.
@ -14,14 +67,14 @@ class Lexicon(object):
"""
def __init__(
self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$")
self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Args:
lang_dir:
Path to the lang director. It is expected to contain the following
files:
- phones.txt
- tokens.txt
- words.txt
- L.pt
The above files are produced by the script `prepare.sh`. You
@ -30,11 +83,11 @@ class Lexicon(object):
It contains the pattern for disambiguation symbols.
"""
lang_dir = Path(lang_dir)
self.phones = k2.SymbolTable.from_file(lang_dir / "phones.txt")
self.words = k2.SymbolTable.from_file(lang_dir / "words.txt")
self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
self.word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
if (lang_dir / "Linv.pt").exists():
logging.info("Loading pre-compiled Linv.pt")
logging.info(f"Loading pre-compiled {lang_dir}/Linv.pt")
L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt"))
else:
logging.info("Converting L.pt to Linv.pt")
@ -49,18 +102,92 @@ class Lexicon(object):
@property
def tokens(self) -> List[int]:
"""Return a list of phone IDs excluding those from
"""Return a list of token IDs excluding those from
disambiguation symbols.
Caution:
0 is not a phone ID so it is excluded from the return value.
0 is not a token ID so it is excluded from the return value.
"""
symbols = self.phones.symbols
symbols = self.token_table.symbols
ans = []
for s in symbols:
if not self.disambig_pattern.match(s):
ans.append(self.phones[s])
ans.append(self.token_table[s])
if 0 in ans:
ans.remove(0)
ans.sort()
return ans
class BpeLexicon(Lexicon):
def __init__(
self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Refer to the help information in Lexicon.__init__.
"""
super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)
self.ragged_lexicon = self.convert_lexicon_to_ragged(
lang_dir / "lexicon.txt"
)
def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt:
"""Read a BPE lexicon from file and convert it to a
k2 ragged tensor.
Args:
filename:
Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt
Returns:
A k2 ragged tensor with two axes [word_id][token_id].
"""
disambig_id = self.word_table["#0"]
# We reuse the same words.txt from the phone based lexicon
# so that we can share the same G.fst. Here, we have to
# exclude some words present only in the phone based lexicon.
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
# epsilon is not a word, but it occupies one position
#
row_splits = [0]
token_ids = []
lexicon = read_lexicon(filename)
lexicon = dict(lexicon)
for i in range(disambig_id):
w = self.word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
pieces = lexicon[w]
piece_ids = [self.token_table[k] for k in pieces]
row_splits.append(row_splits[-1] + len(piece_ids))
token_ids.extend(piece_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=cached_tot_size
)
values = torch.tensor(token_ids, dtype=torch.int32)
return k2.RaggedInt(shape, values)
def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt:
"""Convert a list of words to a ragged tensor contained
word piece IDs.
"""
word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32)
ragged, _ = k2.ragged.index(
self.ragged_lexicon,
indexes=word_ids,
need_value_indexes=False,
axis=0,
)
return ragged
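To make the ragged layout built by `convert_lexicon_to_ragged` easier to picture, here is a toy sketch using the same (older-style) k2 ragged API calls as the code above; all IDs are made up:

import k2
import torch

# Row i holds the piece IDs of word i; word 0 (<eps>) gets an empty row.
row_splits = torch.tensor([0, 0, 2, 3], dtype=torch.int32)
values = torch.tensor([21, 35, 47], dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
    row_splits=row_splits, cached_tot_size=int(row_splits[-1])
)
ragged = k2.RaggedInt(shape, values)
# Look up words 2 and 1 (in that order), mirroring words_to_piece_ids:
indexes = torch.tensor([2, 1], dtype=torch.int32)
sub, _ = k2.ragged.index(
    ragged, indexes=indexes, need_value_indexes=False, axis=0
)
# sub.values() -> [47, 21, 35]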

test/test_bpe_graph_compiler.py (new executable file, 25 lines)
View File

@ -0,0 +1,25 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import BpeLexicon
from pathlib import Path
def test():
lang_dir = Path("data/lang/bpe")
if not lang_dir.is_dir():
return
# TODO: generate data for testing
compiler = BpeCtcTrainingGraphCompiler(lang_dir)
ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
fsa = compiler.compile(ids)
lexicon = BpeLexicon(lang_dir)
ids0 = lexicon.words_to_piece_ids(["HELLO"])
assert ids[0] == ids0.values().tolist()
ids1 = lexicon.words_to_piece_ids(["WORLD", "ZZZ"])
assert ids[1] == ids1.values().tolist()

View File

@ -41,7 +41,8 @@ def test_load_checkpoints(checkpoints1):
m.p2 = nn.Parameter(torch.Tensor([0, 0]))
params = load_checkpoint(checkpoints1, m)
assert torch.allclose(m.p1, torch.Tensor([10.0, 20]))
assert params == {"a": 10, "b": 20}
assert params["a"] == 10
assert params["b"] == 20
def test_average_checkpoints(checkpoints1, checkpoints2):

View File

@ -81,8 +81,8 @@ def lexicon():
"""
)
ans = Lexicon.__new__(Lexicon)
ans.phones = L.labels_sym
ans.words = L.aux_labels_sym
ans.token_table = L.labels_sym
ans.word_table = L.aux_labels_sym
ans.L_inv = k2.arc_sort(L.invert_())
ans.disambig_pattern = re.compile(r"^#\d+$")
@ -107,11 +107,11 @@ class TestCtcTrainingGraphCompiler(object):
aux_labels1 = fsa[1].aux_labels[:-1]
aux_labels1 = aux_labels1[aux_labels1 != 0].tolist()
labels0 = [lexicon.phones[i] for i in labels0]
labels1 = [lexicon.phones[i] for i in labels1]
labels0 = [lexicon.token_table[i] for i in labels0]
labels1 = [lexicon.token_table[i] for i in labels1]
aux_labels0 = [lexicon.words[i] for i in aux_labels0]
aux_labels1 = [lexicon.words[i] for i in aux_labels1]
aux_labels0 = [lexicon.word_table[i] for i in aux_labels0]
aux_labels1 = [lexicon.word_table[i] for i in aux_labels1]
assert labels0 == ["b", "a", "r", "f", "o", "o"]
assert aux_labels0 == ["bar", "foo"]
@ -129,11 +129,11 @@ class TestCtcTrainingGraphCompiler(object):
input2 = ["b", "b", "a", "a", "a", "<blk>", "<blk>", "z", "z"]
input2 += ["<blk>", "<blk>", "SPN", "SPN", "<blk>", "<blk>"]
lexicon.phones._id2sym[0] == "<blk>"
lexicon.phones._sym2id["<blk>"] = 0
lexicon.token_table._id2sym[0] = "<blk>"
lexicon.token_table._sym2id["<blk>"] = 0
input1 = [lexicon.phones[i] for i in input1]
input2 = [lexicon.phones[i] for i in input2]
input1 = [lexicon.token_table[i] for i in input1]
input2 = [lexicon.token_table[i] for i in input2]
fsa1 = k2.linear_fsa(input1)
fsa2 = k2.linear_fsa(input2)
@ -147,14 +147,14 @@ class TestCtcTrainingGraphCompiler(object):
aux_labels0 = lattice[0].aux_labels[:-1]
aux_labels0 = aux_labels0[aux_labels0 != 0].tolist()
aux_labels0 = [lexicon.words[i] for i in aux_labels0]
aux_labels0 = [lexicon.word_table[i] for i in aux_labels0]
assert aux_labels0 == ["bar", "foo"]
aux_labels1 = lattice[1].aux_labels[:-1]
aux_labels1 = aux_labels1[aux_labels1 != 0].tolist()
aux_labels1 = [lexicon.words[i] for i in aux_labels1]
aux_labels1 = [lexicon.word_table[i] for i in aux_labels1]
assert aux_labels1 == ["baz", "<UNK>"]
texts = get_texts(lattice)
texts = [[lexicon.words[i] for i in words] for words in texts]
texts = [[lexicon.word_table[i] for i in words] for words in texts]
assert texts == [["bar", "foo"], ["baz", "<UNK>"]]

View File

@ -1,10 +1,12 @@
#!/usr/bin/env python3
from pathlib import Path
import k2
import pytest
import torch
from icefall.lexicon import Lexicon
from icefall.lexicon import BpeLexicon, Lexicon
@pytest.fixture
@ -47,7 +49,7 @@ def lang_dir(tmp_path):
num_aux_labels=1,
)
with open(tmp_path / "phones.txt", "w") as f:
with open(tmp_path / "tokens.txt", "w") as f:
f.write(phone2id)
with open(tmp_path / "words.txt", "w") as f:
f.write(word2id)
@ -60,3 +62,16 @@ def lang_dir(tmp_path):
def test_lexicon(lang_dir):
lexicon = Lexicon(lang_dir)
assert lexicon.tokens == list(range(1, 8))
def test_bpe_lexicon():
lang_dir = Path("data/lang/bpe")
if not lang_dir.is_dir():
return
# TODO: Generate test data for BpeLexicon
lexicon = BpeLexicon(lang_dir)
words = ["<UNK>", "HELLO", "ZZZZ", "WORLD"]
ids = lexicon.words_to_piece_ids(words)
print(ids)
print([lexicon.token_table[i] for i in ids.values().tolist()])