Refactoring.

Fangjun Kuang 2021-09-24 19:45:08 +08:00
parent 59b7140766
commit 6f5d63492a
20 changed files with 1543 additions and 487 deletions

View File

@@ -114,7 +114,10 @@ class Transformer(nn.Module):
             norm=encoder_norm,
         )

-        self.encoder_output_layer = nn.Linear(d_model, num_classes)
+        # TODO(fangjun): remove dropout
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )

         if num_decoder_layers > 0:
             self.decoder_num_class = (
@@ -325,6 +328,7 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
         # extracted as a separate function.
+
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
         ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
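For reference, a minimal sketch (not part of the diff) of what the new encoder_output_layer computes; the shapes and dropout value below are illustrative assumptions:

import torch
import torch.nn as nn

d_model, num_classes, dropout = 512, 500, 0.1
encoder_output_layer = nn.Sequential(
    nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
)

x = torch.randn(100, 8, d_model)   # (T, N, C) encoder output
logits = encoder_output_layer(x)   # dropout is active only in training mode
print(logits.shape)                # torch.Size([100, 8, 500])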

View File

@ -0,0 +1,354 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class LibriSpeechAsrDataModule(DataModule):
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. "
"Otherwise, use 100h subset.",
)
group.add_argument(
"--feature-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the BucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
logging.info("About to create train dataset")
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = [
SpecAugment(
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
]
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = SingleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
cuts = self.test_cuts()
is_list = isinstance(cuts, list)
test_loaders = []
if not is_list:
cuts = [cuts]
for cuts_test in cuts:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
)
test_loaders.append(test_dl)
if is_list:
return test_loaders
else:
return test_loaders[0]
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz"
)
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
return cuts_valid
@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test-clean", "test-other"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
)
)
return cuts
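A hedged usage sketch for the data module above, as it might appear in a training or decoding script; the argument values are illustrative assumptions, not defaults mandated by the recipe:

import argparse

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--full-libri", "false", "--max-duration", "100"])

librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()

for batch in train_dl:
    feature = batch["inputs"]              # (N, T, C) fbank features
    supervisions = batch["supervisions"]   # texts, cut ids, etc.
    break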

View File

@@ -1,7 +1,20 @@
 #!/usr/bin/env python3

 # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 import math
 import warnings
@@ -43,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        is_espnet_structure: bool = False,
         use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
@@ -70,12 +82,10 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
-            is_espnet_structure,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
-        self.is_espnet_structure = is_espnet_structure
-        if self.normalize_before and self.is_espnet_structure:
+        if self.normalize_before:
             self.after_norm = nn.LayerNorm(d_model)
         else:
             # Note: TorchScript detects that self.after_norm could be used inside forward()
@@ -88,7 +98,7 @@ class Conformer(Transformer):
         """
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -110,7 +120,7 @@ class Conformer(Transformer):
         mask = mask.to(x.device)
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)

-        if self.normalize_before and self.is_espnet_structure:
+        if self.normalize_before:
             x = self.after_norm(x)

         return x, mask
@@ -144,11 +154,10 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
-        is_espnet_structure: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
+            d_model, nhead, dropout=0.0
         )

         self.feed_forward = nn.Sequential(
@@ -394,7 +403,7 @@ class RelPositionalEncoding(torch.nn.Module):
             :,
             self.pe.size(1) // 2
             - x.size(1)
-            + 1 : self.pe.size(1) // 2
+            + 1 : self.pe.size(1) // 2  # noqa E203
             + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
@@ -421,7 +430,6 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
-        is_espnet_structure: bool = False,
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
@@ -444,8 +452,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self._reset_parameters()

-        self.is_espnet_structure = is_espnet_structure
-
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.0)
@@ -675,9 +681,6 @@ class RelPositionMultiheadAttention(nn.Module):
             _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)

-        if not self.is_espnet_structure:
-            q = q * scaling
-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
@@ -770,11 +773,6 @@ class RelPositionMultiheadAttention(nn.Module):
         )  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd)

-        if not self.is_espnet_structure:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            )  # (batch, head, time1, time2)
-        else:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            ) * scaling  # (batch, head, time1, time2)
+        attn_output_weights = (
+            matrix_ac + matrix_bd
+        ) * scaling  # (batch, head, time1, time2)
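The net effect of dropping is_espnet_structure is that the attention weights are always computed the "ESPnet way": the scaling is applied once, after the content term (matrix_ac) and the relative-position term (matrix_bd) are summed. A toy sketch with assumed shapes, not taken from the diff:

import torch

batch, head, time1, time2, head_dim = 2, 4, 6, 6, 64
scaling = float(head_dim) ** -0.5

matrix_ac = torch.randn(batch, head, time1, time2)
matrix_bd = torch.randn(batch, head, time1, time2)

attn_output_weights = (matrix_ac + matrix_bd) * scaling  # (batch, head, time1, time2)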

View File

@ -1,8 +1,20 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# (still working in progress)
import argparse import argparse
import logging import logging
@ -13,14 +25,15 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer from conformer import Conformer
from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import ( from icefall.decode import (
get_lattice, get_lattice,
nbest_decoding, nbest_decoding,
nbest_oracle,
one_best_decoding, one_best_decoding,
rescore_with_attention_decoder, rescore_with_attention_decoder,
rescore_with_n_best_list, rescore_with_n_best_list,
@ -32,6 +45,7 @@ from icefall.utils import (
get_texts, get_texts,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
str2bool,
write_error_stats, write_error_stats,
) )
@ -44,51 +58,111 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
default=9, default=34,
help="It specifies the checkpoint to use for decoding." help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.", "Note: Epoch counts from 0.",
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=1, default=20,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. ",
) )
parser.add_argument(
"--method",
type=str,
default="attention-decoder",
help="""Decoding method.
Supported values are:
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
""",
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_mmi/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_mmi/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe",
help="The lang dir",
)
return parser return parser
def get_params() -> AttributeDict: def get_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"exp_dir": Path("conformer_mmi/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"), "lm_dir": Path("data/lm"),
# parameters for conformer
"subsampling_factor": 4,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"feature_dim": 80, "feature_dim": 80,
"nhead": 8, "nhead": 8,
"attention_dim": 512, "attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6, "num_decoder_layers": 6,
"vgg_frontend": False, # parameters for decoding
"is_espnet_structure": True,
"use_feat_batchnorm": True,
"search_beam": 20, "search_beam": 20,
"output_beam": 8, "output_beam": 8,
"min_active_states": 30, "min_active_states": 30,
"max_active_states": 10000, "max_active_states": 10000,
"use_double_scores": True, "use_double_scores": True,
# Possible values for method:
# - 1best
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
# - attention-decoder
# "method": "whole-lattice-rescoring",
"method": "1best",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
"num_paths": 100,
} }
) )
return params return params
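The new --lattice-score-scale option only matters for the n-best based methods (nbest, nbest-rescoring, attention-decoder, nbest-oracle). Conceptually, lattice.scores are scaled down before paths are sampled, which flattens the distribution and yields more unique paths. The sketch below shows that mechanism only; it is not the exact icefall helper, and the function name is made up:

import k2

def sample_nbest_paths(lattice: k2.Fsa, num_paths: int, scale: float):
    saved_scores = lattice.scores.clone()
    lattice.scores *= scale    # smaller scale -> flatter distribution -> more diverse paths
    paths = k2.random_paths(lattice, use_double_scores=True, num_paths=num_paths)
    lattice.scores = saved_scores    # restore the original scores
    return paths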
@ -99,7 +173,7 @@ def decode_one_batch(
model: nn.Module, model: nn.Module,
HLG: k2.Fsa, HLG: k2.Fsa,
batch: dict, batch: dict,
lexicon: Lexicon, word_table: k2.SymbolTable,
sos_id: int, sos_id: int,
eos_id: int, eos_id: int,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
@ -133,8 +207,8 @@ def decode_one_batch(
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`. for the format of the `batch`.
lexicon: word_table:
It contains word symbol table. The word symbol table.
sos_id: sos_id:
The token ID of the SOS. The token ID of the SOS.
eos_id: eos_id:
@ -151,12 +225,12 @@ def decode_one_batch(
feature = batch["inputs"] feature = batch["inputs"]
assert feature.ndim == 3 assert feature.ndim == 3
feature = feature.to(device) feature = feature.to(device)
# at entry, feature is [N, T, C] # at entry, feature is (N, T, C)
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is [N, T, C] # nnet_output is (N, T, C)
supervision_segments = torch.stack( supervision_segments = torch.stack(
( (
@ -178,6 +252,24 @@ def decode_one_batch(
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
) )
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is only slightly worse than that of rescored lattices.
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
lattice_score_scale=params.lattice_score_scale,
oov="<UNK>",
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
return {key: hyps}
if params.method in ["1best", "nbest"]: if params.method in ["1best", "nbest"]:
if params.method == "1best": if params.method == "1best":
best_path = one_best_decoding( best_path = one_best_decoding(
@ -189,11 +281,12 @@ def decode_one_batch(
lattice=lattice, lattice=lattice,
num_paths=params.num_paths, num_paths=params.num_paths,
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
lattice_score_scale=params.lattice_score_scale,
) )
key = f"no_rescore-{params.num_paths}" key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps} return {key: hyps}
assert params.method in [ assert params.method in [
@ -202,7 +295,8 @@ def decode_one_batch(
"attention-decoder", "attention-decoder",
] ]
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring": if params.method == "nbest-rescoring":
@ -211,16 +305,23 @@ def decode_one_batch(
G=G, G=G,
num_paths=params.num_paths, num_paths=params.num_paths,
lm_scale_list=lm_scale_list, lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale,
) )
elif params.method == "whole-lattice-rescoring": elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice( best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
) )
elif params.method == "attention-decoder": elif params.method == "attention-decoder":
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice( rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
) )
# TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder`
best_path_dict = rescore_with_attention_decoder( best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice, lattice=rescored_lattice,
@ -230,15 +331,20 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id, sos_id=sos_id,
eos_id=eos_id, eos_id=eos_id,
lattice_score_scale=params.lattice_score_scale,
) )
else: else:
assert False, f"Unsupported decoding method: {params.method}" assert False, f"Unsupported decoding method: {params.method}"
ans = dict() ans = dict()
if best_path_dict is not None:
for lm_scale_str, best_path in best_path_dict.items(): for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps ans[lm_scale_str] = hyps
else:
for lm_scale in lm_scale_list:
ans[lm_scale_str] = [[] * lattice.shape[0]]
return ans return ans
@ -247,7 +353,7 @@ def decode_dataset(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
HLG: k2.Fsa, HLG: k2.Fsa,
lexicon: Lexicon, word_table: k2.SymbolTable,
sos_id: int, sos_id: int,
eos_id: int, eos_id: int,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
@ -263,8 +369,8 @@ def decode_dataset(
The neural model. The neural model.
HLG: HLG:
The decoding graph. The decoding graph.
lexicon: word_table:
It contains word symbol table. It is the word symbol table.
sos_id: sos_id:
The token ID for SOS. The token ID for SOS.
eos_id: eos_id:
@ -283,7 +389,11 @@ def decode_dataset(
results = [] results = []
num_cuts = 0 num_cuts = 0
tot_num_cuts = len(dl.dataset.cuts)
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
@ -294,7 +404,7 @@ def decode_dataset(
model=model, model=model,
HLG=HLG, HLG=HLG,
batch=batch, batch=batch,
lexicon=lexicon, word_table=word_table,
G=G, G=G,
sos_id=sos_id, sos_id=sos_id,
eos_id=eos_id, eos_id=eos_id,
@ -312,10 +422,10 @@ def decode_dataset(
num_cuts += len(batch["supervisions"]["text"]) num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(
f"batch {batch_idx}, cuts processed until now is " f"batch {batch_str}, cuts processed until now is {num_cuts}"
f"{num_cuts}/{tot_num_cuts} "
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
) )
return results return results
@ -374,8 +484,10 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
params.exp_dir = Path(params.exp_dir)
params.lang_dir = Path(params.lang_dir)
setup_logger(f"{params.exp_dir}/log/log-decode") setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started") logging.info("Decoding started")
logging.info(params) logging.info(params)
@ -389,7 +501,7 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
graph_compiler = BpeMmiTrainingGraphCompiler( graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir, params.lang_dir,
device=device, device=device,
sos_token="<sos/eos>", sos_token="<sos/eos>",
@ -398,7 +510,9 @@ def main():
sos_id = graph_compiler.sos_id sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id eos_id = graph_compiler.eos_id
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt")) HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device) HLG = HLG.to(device)
assert HLG.requires_grad is False assert HLG.requires_grad is False
@ -429,7 +543,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else: else:
logging.info("Loading pre-compiled G_4_gram.pt") logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt") d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device) G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]: if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
@ -453,7 +567,6 @@ def main():
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers, num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
use_feat_batchnorm=params.use_feat_batchnorm, use_feat_batchnorm=params.use_feat_batchnorm,
) )
@ -468,6 +581,13 @@ def main():
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames)) model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device) model.to(device)
model.eval() model.eval()
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
@ -487,7 +607,7 @@ def main():
params=params, params=params,
model=model, model=model,
HLG=HLG, HLG=HLG,
lexicon=lexicon, word_table=lexicon.word_table,
G=G, G=G,
sos_id=sos_id, sos_id=sos_id,
eos_id=eos_id, eos_id=eos_id,
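When --export is given, the script saves the averaged model as a dict {"model": model.state_dict()} under the default experiment dir. A minimal loading sketch; rebuilding the network itself is left as a comment since it needs the same hyper-parameters as get_params():

import torch

checkpoint = torch.load("conformer_mmi/exp/pretrained.pt", map_location="cpu")
state_dict = checkpoint["model"]
# model.load_state_dict(state_dict)  # `model` must be a Conformer built with
#                                    # the same arguments as in main()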

View File

@ -1,15 +1,26 @@
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) # Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
# Apache 2.0 #
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling from subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import get_texts
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed. # Note: TorchScript requires Dict/List/etc. to be fully typed.
@ -72,8 +83,8 @@ class Transformer(nn.Module):
if subsampling_factor != 4: if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.") raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape [N, T, num_classes] # self.encoder_embed converts the input of shape (N, T, num_classes)
# to the shape [N, T//subsampling_factor, d_model]. # to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously: # That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor # (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_classes -> d_model # (2) embedding: num_classes -> d_model
@ -103,10 +114,15 @@ class Transformer(nn.Module):
norm=encoder_norm, norm=encoder_norm,
) )
self.encoder_output_layer = nn.Linear(d_model, num_classes) # TODO(fangjun): remove dropout
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
)
if num_decoder_layers > 0: if num_decoder_layers > 0:
self.decoder_num_class = self.num_classes self.decoder_num_class = (
self.num_classes
) # bpe model already has sos/eos symbol
self.decoder_embed = nn.Embedding( self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model num_embeddings=self.decoder_num_class, embedding_dim=d_model
@ -146,7 +162,7 @@ class Transformer(nn.Module):
""" """
Args: Args:
x: x:
The input tensor. Its shape is [N, T, C]. The input tensor. Its shape is (N, T, C).
supervision: supervision:
Supervision in lhotse format. Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@ -155,17 +171,17 @@ class Transformer(nn.Module):
Returns: Returns:
Return a tuple containing 3 tensors: Return a tuple containing 3 tensors:
- CTC output for ctc decoding. Its shape is [N, T, C] - CTC output for ctc decoding. Its shape is (N, T, C)
- Encoder output with shape [T, N, C]. It can be used as key and - Encoder output with shape (T, N, C). It can be used as key and
value for the decoder. value for the decoder.
- Encoder output padding mask. It can be used as - Encoder output padding mask. It can be used as
memory_key_padding_mask for the decoder. Its shape is [N, T]. memory_key_padding_mask for the decoder. Its shape is (N, T).
It is None if `supervision` is None. It is None if `supervision` is None.
""" """
if self.use_feat_batchnorm: if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x) x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder( encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision x, supervision
) )
@ -179,7 +195,7 @@ class Transformer(nn.Module):
Args: Args:
x: x:
The model input. Its shape is [N, T, C]. The model input. Its shape is (N, T, C).
supervisions: supervisions:
Supervision in lhotse format. Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@ -190,8 +206,8 @@ class Transformer(nn.Module):
padding mask for the decoder. padding mask for the decoder.
Returns: Returns:
Return a tuple with two tensors: Return a tuple with two tensors:
- The encoder output, with shape [T, N, C] - The encoder output, with shape (T, N, C)
- encoder padding mask, with shape [N, T]. - encoder padding mask, with shape (N, T).
The mask is None if `supervisions` is None. The mask is None if `supervisions` is None.
It is used as memory key padding mask in the decoder. It is used as memory key padding mask in the decoder.
""" """
@ -209,11 +225,11 @@ class Transformer(nn.Module):
Args: Args:
x: x:
The output tensor from the transformer encoder. The output tensor from the transformer encoder.
Its shape is [T, N, C] Its shape is (T, N, C)
Returns: Returns:
Return a tensor that can be used for CTC decoding. Return a tensor that can be used for CTC decoding.
Its shape is [N, T, C] Its shape is (N, T, C)
""" """
x = self.encoder_output_layer(x) x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@ -231,7 +247,7 @@ class Transformer(nn.Module):
""" """
Args: Args:
memory: memory:
It's the output of the encoder with shape [T, N, C] It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask: memory_key_padding_mask:
The padding mask from the encoder. The padding mask from the encoder.
token_ids: token_ids:
@ -296,7 +312,7 @@ class Transformer(nn.Module):
""" """
Args: Args:
memory: memory:
It's the output of the encoder with shape [T, N, C] It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask: memory_key_padding_mask:
The padding mask from the encoder. The padding mask from the encoder.
token_ids: token_ids:
@ -312,6 +328,7 @@ class Transformer(nn.Module):
""" """
# The common part between this function and decoder_forward could be # The common part between this function and decoder_forward could be
# extracted as a separate function. # extracted as a separate function.
ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in] ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
@ -329,6 +346,9 @@ class Transformer(nn.Module):
) )
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
# We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
@ -634,13 +654,13 @@ class PositionalEncoding(nn.Module):
def extend_pe(self, x: torch.Tensor) -> None: def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required. """Extend the time t in the positional encoding if required.
The shape of `self.pe` is [1, T1, d_model]. The shape of the input x The shape of `self.pe` is (1, T1, d_model). The shape of the input x
is [N, T, d_model]. If T > T1, then we change the shape of self.pe is (N, T, d_model). If T > T1, then we change the shape of self.pe
to [N, T, d_model]. Otherwise, nothing is done. to (N, T, d_model). Otherwise, nothing is done.
Args: Args:
x: x:
It is a tensor of shape [N, T, C]. It is a tensor of shape (N, T, C).
Returns: Returns:
Return None. Return None.
""" """
@ -658,7 +678,7 @@ class PositionalEncoding(nn.Module):
pe[:, 0::2] = torch.sin(position * div_term) pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term) pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) pe = pe.unsqueeze(0)
# Now pe is of shape [1, T, d_model], where T is x.size(1) # Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype) self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -667,10 +687,10 @@ class PositionalEncoding(nn.Module):
Args: Args:
x: x:
Its shape is [N, T, C] Its shape is (N, T, C)
Returns: Returns:
Return a tensor of shape [N, T, C] Return a tensor of shape (N, T, C)
""" """
self.extend_pe(x) self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :] x = x * self.xscale + self.pe[:, : x.size(1), :]
@ -766,7 +786,8 @@ class Noam(object):
class LabelSmoothingLoss(nn.Module): class LabelSmoothingLoss(nn.Module):
""" """
Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) Label-smoothing loss. KL-divergence between
q_{smoothed ground truth prob.}(w)
and p_{prob. computed by model}(w) is minimized. and p_{prob. computed by model}(w) is minimized.
Modified from Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
@ -851,7 +872,8 @@ def encoder_padding_mask(
frames, before subsampling) frames, before subsampling)
Returns: Returns:
Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. Tensor: Mask tensor of dimension (batch_size, input_length),
True denote the masked indices.
""" """
if supervisions is None: if supervisions is None:
return None return None
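The new comment about resetting the first column of the decoder padding mask can be made concrete with a toy example. The token IDs below are made up; the point is that with a BPE lang dir, sos and eos share the same ID, so the leading sos token would otherwise look like padding:

import torch
from torch.nn.utils.rnn import pad_sequence

sos_id = eos_id = 1
token_ids = [[7, 8, 9], [7, 8]]
ys_in = [torch.tensor([sos_id] + utt) for utt in token_ids]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
# tensor([[1, 7, 8, 9],
#         [1, 7, 8, 1]])

tgt_key_padding_mask = ys_in_pad == eos_id
# tensor([[ True, False, False, False],
#         [ True, False, False,  True]])
tgt_key_padding_mask[:, 0] = False  # keep the sos position unmasked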

View File

@ -8,8 +8,8 @@ for LM training with the help of a lexicon.
If the lexicon contains phones, the resulting LM will be a phone LM; If the If the lexicon contains phones, the resulting LM will be a phone LM; If the
lexicon contains word pieces, the resulting LM will be a word piece LM. lexicon contains word pieces, the resulting LM will be a word piece LM.
If a word has multiple pronunciations, the one that appears last in the lexicon If a word has multiple pronunciations, the one that appears first in the lexicon
is used. is kept; others are removed.
If the input transcript is: If the input transcript is:
@ -20,8 +20,8 @@ If the input transcript is:
and if the lexicon is and if the lexicon is
<UNK> SPN <UNK> SPN
hello h e l l o
hello h e l l o 2 hello h e l l o 2
hello h e l l o
world w o r l d world w o r l d
zoo z o o zoo z o o
@ -32,10 +32,11 @@ Then the output is
SPN z o o w o r l d SPN SPN z o o w o r l d SPN
""" """
from pathlib import Path
from typing import Dict
import argparse import argparse
from pathlib import Path
from typing import Dict, List
from generate_unique_lexicon import filter_multiple_pronunications
from icefall.lexicon import read_lexicon from icefall.lexicon import read_lexicon
@ -57,7 +58,9 @@ def get_args():
return parser.parse_args() return parser.parse_args()
def process_line(lexicon: Dict[str, str], line: str, oov_token: str) -> None: def process_line(
lexicon: Dict[str, List[str]], line: str, oov_token: str
) -> None:
""" """
Args: Args:
lexicon: lexicon:
@ -86,7 +89,11 @@ def main():
assert Path(args.transcript).is_file() assert Path(args.transcript).is_file()
assert len(args.oov) > 0 assert len(args.oov) > 0
lexicon = dict(read_lexicon(args.lexicon)) # Only the first pronunciation of a word is kept
lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
lexicon = dict(lexicon)
assert args.oov in lexicon assert args.oov in lexicon
oov_token = lexicon[args.oov] oov_token = lexicon[args.oov]
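A small sketch (with made-up data) of the conversion this script performs after the change: each word in a transcript line is replaced by its single remaining pronunciation, and out-of-vocabulary words fall back to the pronunciation of the OOV token:

lexicon = {
    "<UNK>": ["SPN"],
    "hello": ["h", "e", "l", "l", "o", "2"],
    "world": ["w", "o", "r", "l", "d"],
    "zoo": ["z", "o", "o"],
}
oov_tokens = lexicon["<UNK>"]

line = "hello yet zoo"
tokens = []
for word in line.split():
    tokens.extend(lexicon.get(word, oov_tokens))
print(" ".join(tokens))  # h e l l o 2 SPN z o o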

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file takes as input a lexicon.txt and outputs a new lexicon,
in which each word has a unique pronunciation.
The way to do this is to keep only the first pronunciation of a word
in lexicon.txt.
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
from icefall.lexicon import read_lexicon, write_lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain a file lexicon.txt.
This file will generate a new file uniq_lexicon.txt
in it.
""",
)
return parser.parse_args()
def filter_multiple_pronunications(
lexicon: List[Tuple[str, List[str]]]
) -> List[Tuple[str, List[str]]]:
"""Remove multiple pronunciations of words from a lexicon.
If a word has more than one pronunciation in the lexicon, only
the first one is kept, while other pronunciations are removed
from the lexicon.
Args:
lexicon:
The input lexicon, containing a list of (word, [p1, p2, ..., pn]),
where "p1, p2, ..., pn" are the pronunciations of the "word".
Returns:
Return a new lexicon where each word has a unique pronunciation.
"""
seen = set()
ans = []
for word, tokens in lexicon:
if word in seen:
continue
seen.add(word)
ans.append((word, tokens))
return ans
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
in_lexicon = read_lexicon(lexicon_filename)
out_lexicon = filter_multiple_pronunications(in_lexicon)
write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon)
logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}")
logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
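A toy example (not from the repo) of how filter_multiple_pronunications behaves on a lexicon with duplicate entries:

lexicon = [
    ("hello", ["h", "e", "l", "l", "o", "2"]),
    ("hello", ["h", "e", "l", "l", "o"]),
    ("world", ["w", "o", "r", "l", "d"]),
]
print(filter_multiple_pronunications(lexicon))
# [('hello', ['h', 'e', 'l', 'l', 'o', '2']), ('world', ['w', 'o', 'r', 'l', 'd'])]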

View File

@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following:
5. Generate L_disambig.pt, in k2 format. 5. Generate L_disambig.pt, in k2 format.
""" """
import argparse
import math import math
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
@ -42,10 +43,37 @@ import k2
import torch import torch
from icefall.lexicon import read_lexicon, write_lexicon from icefall.lexicon import read_lexicon, write_lexicon
from icefall.utils import str2bool
Lexicon = List[Tuple[str, List[str]]] Lexicon = List[Tuple[str, List[str]]]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain a file lexicon.txt.
Generated files by this script are saved into this directory.
""",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
""",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file. """Write a symbol to ID mapping to a file.
@ -315,8 +343,9 @@ def lexicon_to_fst(
def main(): def main():
out_dir = Path("data/lang_phone") args = get_args()
lexicon_filename = out_dir / "lexicon.txt" lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
sil_token = "SIL" sil_token = "SIL"
sil_prob = 0.5 sil_prob = 0.5
@ -344,9 +373,9 @@ def main():
token2id = generate_id_map(tokens) token2id = generate_id_map(tokens)
word2id = generate_id_map(words) word2id = generate_id_map(words)
write_mapping(out_dir / "tokens.txt", token2id) write_mapping(lang_dir / "tokens.txt", token2id)
write_mapping(out_dir / "words.txt", word2id) write_mapping(lang_dir / "words.txt", word2id)
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst( L = lexicon_to_fst(
lexicon, lexicon,
@ -364,17 +393,20 @@ def main():
sil_prob=sil_prob, sil_prob=sil_prob,
need_self_loops=True, need_self_loops=True,
) )
torch.save(L.as_dict(), out_dir / "L.pt") torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt") torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if False: if args.debug:
# Just for debugging, will remove it labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym L.labels_sym = labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym L.aux_labels_sym = aux_labels_sym
L.draw(out_dir / "L.png", title="L") L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__": if __name__ == "__main__":
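A hedged sketch of consuming the saved FSTs; it mirrors how HLG.pt and G_4_gram.pt are loaded in decode.py above, and the data/lang_phone path is the one used by prepare.sh:

import k2
import torch

L = k2.Fsa.from_dict(torch.load("data/lang_phone/L.pt", map_location="cpu"))
L.labels_sym = k2.SymbolTable.from_file("data/lang_phone/tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file("data/lang_phone/words.txt")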

View File

@ -49,6 +49,8 @@ from prepare_lang import (
write_mapping, write_mapping,
) )
from icefall.utils import str2bool
def lexicon_to_fst_no_sil( def lexicon_to_fst_no_sil(
lexicon: Lexicon, lexicon: Lexicon,
@ -169,6 +171,20 @@ def get_args():
""", """,
) )
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
See "test/test_bpe_lexicon.py" for usage.
""",
)
return parser.parse_args() return parser.parse_args()
@ -221,6 +237,18 @@ def main():
torch.save(L.as_dict(), lang_dir / "L.pt") torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# You can install sentencepiece via: # You can install sentencepiece via:
# #
# pip install sentencepiece # pip install sentencepiece
@ -37,10 +38,17 @@ def get_args():
"--lang-dir", "--lang-dir",
type=str, type=str,
help="""Input and output directory. help="""Input and output directory.
It should contain the training corpus: train.txt. It should contain the training corpus: transcript_words.txt.
The generated bpe.model is saved to this directory. The generated bpe.model is saved to this directory.
""", """,
) )
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument( parser.add_argument(
"--vocab-size", "--vocab-size",
type=int, type=int,
@ -58,7 +66,7 @@ def main():
model_type = "unigram" model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
train_text = f"{lang_dir}/train.txt" train_text = args.transcript
character_coverage = 1.0 character_coverage = 1.0
input_sentence_size = 100000000 input_sentence_size = 100000000
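A hedged sketch of the sentencepiece call this script builds up after the change. Only the parameters visible in the diff are shown; the real script may pass additional options, and the paths assume vocab_size=500:

import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="data/lang_bpe_500/transcript_words.txt",   # --transcript
    model_prefix="data/lang_bpe_500/unigram_500",     # f"{lang_dir}/{model_type}_{vocab_size}"
    model_type="unigram",
    vocab_size=500,
    character_coverage=1.0,
    input_sentence_size=100000000,
)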

View File

@ -40,9 +40,9 @@ dl_dir=$PWD/download
# It will generate data/lang_bpe_xxx, # It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy # data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=( vocab_sizes=(
5000 # 5000
2000 # 2000
1000 # 1000
500 500
) )
@ -116,14 +116,15 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang" log "Stage 5: Prepare phone based lang"
mkdir -p data/lang_phone lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) | (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/librispeech-lexicon.txt | cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > data/lang_phone/lexicon.txt sort | uniq > $lang_dir/lexicon.txt
if [ ! -f data/lang_phone/L_disambig.pt ]; then if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py ./local/prepare_lang.py --lang-dir $lang_dir
fi fi
fi fi
@ -138,7 +139,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
# so that the two can share G.pt later. # so that the two can share G.pt later.
cp data/lang_phone/words.txt $lang_dir cp data/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/train.txt ]; then if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training" log "Generate data for BPE training"
files=$( files=$(
find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt" find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
@ -147,12 +148,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
) )
for f in ${files[@]}; do for f in ${files[@]}; do
cat $f | cut -d " " -f 2- cat $f | cut -d " " -f 2-
done > $lang_dir/train.txt done > $lang_dir/transcript_words.txt
fi fi
./local/train_bpe_model.py \ ./local/train_bpe_model.py \
--lang-dir $lang_dir \ --lang-dir $lang_dir \
--vocab-size $vocab_size --vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir ./local/prepare_lang_bpe.py --lang-dir $lang_dir
@ -166,18 +168,18 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
for vocab_size in ${vocab_sizes[@]}; do for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size} lang_dir=data/lang_bpe_${vocab_size}
if [ ! -f $lang_dir/corpus.txt ]; then if [ ! -f $lang_dir/transcript_tokens.txt ]; then
./local/convert_transcript_to_corpus.py \ ./local/convert_transcript_words_to_tokens.py \
--lexicon data/lang_bpe/lexicon.txt \ --lexicon $lang_dir/lexicon.txt \
--transcript data/lang_bpe/train.txt \ --transcript $lang_dir/transcript_words.txt \
--oov "<UNK>" \ --oov "<UNK>" \
> $lang_dir/corpus.txt > $lang_dir/transcript_tokens.txt
fi fi
if [ ! -f $lang_dir/P.arpa ]; then if [ ! -f $lang_dir/P.arpa ]; then
./shared/make_kn_lm.py \ ./shared/make_kn_lm.py \
-ngram-order 2 \ -ngram-order 2 \
-text $lang_dir/corpus.txt \ -text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/P.arpa -lm $lang_dir/P.arpa
fi fi
@ -226,4 +228,4 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
done done
fi fi
cd data && ln -sfv lang_bpe_5000 lang_bpe cd data && ln -sfv lang_bpe_500 lang_bpe

View File

@ -34,14 +34,10 @@ class BpeCtcTrainingGraphCompiler(object):
""" """
Args: Args:
lang_dir: lang_dir:
This directory is expected to contain the following files:: This directory is expected to contain the following files:
- bpe.model - bpe.model
- words.txt - words.txt
The above files are produced by the script `prepare.sh`. You
should have run that before running the training code.
device: device:
It indicates CPU or CUDA. It indicates CPU or CUDA.
sos_token: sos_token:
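For context, this is how the compiler is constructed elsewhere in this commit (see decode.py above); the eos_token value is assumed to match sos_token, and lang_dir must contain bpe.model and words.txt:

import torch
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
graph_compiler = BpeCtcTrainingGraphCompiler(
    "data/lang_bpe",
    device=device,
    sos_token="<sos/eos>",
    eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id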

View File

@ -1,178 +0,0 @@
import logging
from pathlib import Path
from typing import List, Tuple, Union
import k2
import sentencepiece as spm
import torch
from icefall.lexicon import Lexicon
class BpeMmiTrainingGraphCompiler(object):
def __init__(
self,
lang_dir: Path,
device: Union[str, torch.device] = "cpu",
sos_token: str = "<sos/eos>",
eos_token: str = "<sos/eos>",
) -> None:
"""
Args:
lang_dir:
Path to the lang directory. It is expected to contain the
following files::
- tokens.txt
- words.txt
- bpe.model
- P.fst.txt
The above files are generated by the script `prepare.sh`. You
should have run it before running the training code.
device:
It indicates CPU or CUDA.
sos_token:
The word piece that represents sos.
eos_token:
The word piece that represents eos.
"""
self.lang_dir = Path(lang_dir)
self.lexicon = Lexicon(lang_dir)
self.device = device
self.load_sentence_piece_model()
self.build_ctc_topo_P()
self.sos_id = self.sp.piece_to_id(sos_token)
self.eos_id = self.sp.piece_to_id(eos_token)
assert self.sos_id != self.sp.unk_id()
assert self.eos_id != self.sp.unk_id()
def load_sentence_piece_model(self) -> None:
"""Load the pre-trained sentencepiece model
from self.lang_dir/bpe.model.
"""
model_file = self.lang_dir / "bpe.model"
sp = spm.SentencePieceProcessor()
sp.load(str(model_file))
self.sp = sp
def build_ctc_topo_P(self):
"""Built ctc_topo_P, the composition result of
ctc_topo and P, where P is a pre-trained bigram
word piece LM.
"""
# Note: there is no need to save a pre-compiled P and ctc_topo
# as it is very fast to generate them.
logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
with open(self.lang_dir / "P.fst.txt") as f:
# P is not an acceptor because there is
# a back-off state, whose incoming arcs
# have label #0 and aux_label 0 (i.e., <eps>).
P = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_token_disambig_id = self.lexicon.token_table["#0"]
# P.aux_labels is not needed in later computations, so
# remove it here.
del P.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
P.labels[P.labels >= first_token_disambig_id] = 0
P = k2.remove_epsilon(P)
P = k2.arc_sort(P)
P = P.to(self.device)
# Add epsilon self-loops to P because we want the
# following operation "k2.intersect" to run on GPU.
P_with_self_loops = k2.add_epsilon_self_loops(P)
max_token_id = max(self.lexicon.tokens)
logging.info(
f"Building modified ctc_topo. max_token_id: {max_token_id}"
)
# CAUTION: We have to use a modifed version of CTC topo.
# Otherwise, the resulting ctc_topo_P is so large that it gets
# stuck in k2.intersect_dense_pruned() or it gets OOM in
# k2.intersect_dense()
ctc_topo = k2.ctc_topo(max_token_id, modified=True, device=self.device)
ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
logging.info("Building ctc_topo_P")
ctc_topo_P = k2.intersect(
ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
).invert()
self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of piece IDs.
Args:
texts:
A list of transcripts. Within a transcript words are
separated by spaces. An example input is::
['HELLO ICEFALL', 'HELLO k2']
Returns:
Return a list-of-list of piece IDs.
"""
return self.sp.encode(texts, out_type=int)
def compile(
self, texts: List[str], replicate_den: bool = True
) -> Tuple[k2.Fsa, k2.Fsa]:
"""Create numerator and denominator graphs from transcripts.
Args:
texts:
A list of transcripts. Within a transcript words are
separated by spaces. An example input is::
["HELLO icefall", "HALLO WELT"]
replicate_den:
If True, the returned den_graph is replicated to match the number
of FSAs in the returned num_graph; if False, the returned den_graph
contains only a single FSA
Returns:
A tuple (num_graphs, den_graphs), where
- `num_graphs` is the numerator graph. It is an FsaVec with
shape `(len(texts), None, None)`.
- `den_graphs` is the denominator graph. It is an FsaVec with the
same shape of the `num_graph` if replicate_den is True;
otherwise, it is an FsaVec containing only a single FSA.
"""
token_ids = self.texts_to_ids(texts)
token_fsas = k2.linear_fsa(token_ids, device=self.device)
token_fsas_with_self_loops = k2.add_epsilon_self_loops(token_fsas)
# NOTE: Use treat_epsilons_specially=False so that k2.compose
# can be run on GPU
num_graphs = k2.compose(
self.ctc_topo_P,
token_fsas_with_self_loops,
treat_epsilons_specially=False,
)
# num_graphs may not be connected and
# not be topologically sorted after k2.compose
num_graphs = k2.connect(num_graphs)
num_graphs = k2.top_sort(num_graphs)
ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P.detach()])
if replicate_den:
indexes = torch.zeros(
len(texts), dtype=torch.int32, device=self.device
)
den_graphs = k2.index_fsa(ctc_topo_P_vec, indexes)
else:
den_graphs = ctc_topo_P_vec
return num_graphs, den_graphs

View File

@@ -84,6 +84,69 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
f.write(f"{word} {' '.join(tokens)}\n") f.write(f"{word} {' '.join(tokens)}\n")
def convert_lexicon_to_ragged(
filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
) -> k2.RaggedTensor:
"""Read a lexicon and convert it to a ragged tensor.
The ragged tensor has two axes: [word][token].
Caution:
We assume that each word has a unique pronunciation.
Args:
filename:
Filename of the lexicon. It has a format that can be read
by :func:`read_lexicon`.
word_table:
The word symbol table.
token_table:
The token symbol table.
Returns:
A k2 ragged tensor with two axes [word][token].
"""
disambig_id = word_table["#0"]
# We reuse the same words.txt from the phone based lexicon
# so that we can share the same G.fst. Here, we have to
# exclude some words present only in the phone based lexicon.
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
# epsilon is not a word, but it occupies a position
#
row_splits = [0]
token_ids_list = []
lexicon_tmp = read_lexicon(filename)
lexicon = dict(lexicon_tmp)
if len(lexicon_tmp) != len(lexicon):
raise RuntimeError(
"It's assumed that each word has a unique pronunciation"
)
for i in range(disambig_id):
w = word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
tokens = lexicon[w]
token_ids = [token_table[k] for k in tokens]
row_splits.append(row_splits[-1] + len(token_ids))
token_ids_list.extend(token_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits,
None,
cached_tot_size,
)
values = torch.tensor(token_ids_list, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
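A minimal sketch of calling the new helper, assuming a prepared lang dir whose lexicon.txt gives each word a single pronunciation; the path is illustrative:

from pathlib import Path

from icefall.lexicon import Lexicon, convert_lexicon_to_ragged

lang_dir = Path("data/lang_bpe")
lexicon = Lexicon(lang_dir)
ragged = convert_lexicon_to_ragged(
    filename=lang_dir / "lexicon.txt",
    word_table=lexicon.word_table,
    token_table=lexicon.token_table,
)
# Two axes [word][token]: row i holds the token IDs of the word whose ID is i;
# <eps>, !SIL and <SPOKEN_NOISE> get empty rows.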
class Lexicon(object): class Lexicon(object):
"""Phone based lexicon.""" """Phone based lexicon."""
@@ -96,12 +159,10 @@ class Lexicon(object):
Args: Args:
lang_dir: lang_dir:
Path to the lang directory. It is expected to contain the following Path to the lang directory. It is expected to contain the following
files:: files:
- tokens.txt - tokens.txt
- words.txt - words.txt
- L.pt - L.pt
The above files are produced by the script `prepare.sh`. You The above files are produced by the script `prepare.sh`. You
should have run that before running the training code. should have run that before running the training code.
disambig_pattern: disambig_pattern:
@@ -121,7 +182,7 @@ class Lexicon(object):
torch.save(L_inv.as_dict(), lang_dir / "Linv.pt") torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
# We save L_inv instead of L because it will be used to intersect with # We save L_inv instead of L because it will be used to intersect with
# transcript, both of whose labels are word IDs. # transcript FSAs, both of whose labels are word IDs.
self.L_inv = L_inv self.L_inv = L_inv
self.disambig_pattern = disambig_pattern self.disambig_pattern = disambig_pattern
@@ -144,69 +205,66 @@ class Lexicon(object):
return ans return ans
class BpeLexicon(Lexicon): class UniqLexicon(Lexicon):
def __init__( def __init__(
self, self,
lang_dir: Path, lang_dir: Path,
uniq_filename: str = "uniq_lexicon.txt",
disambig_pattern: str = re.compile(r"^#\d+$"), disambig_pattern: str = re.compile(r"^#\d+$"),
): ):
""" """
Refer to the help information in Lexicon.__init__. Refer to the help information in Lexicon.__init__.
uniq_filename: It is assumed to be inside the given `lang_dir`.
Each word in the lexicon is assumed to have a unique pronunciation.
""" """
lang_dir = Path(lang_dir)
super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern) super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)
self.ragged_lexicon = self.convert_lexicon_to_ragged( self.ragged_lexicon = convert_lexicon_to_ragged(
lang_dir / "lexicon.txt" filename=lang_dir / uniq_filename,
word_table=self.word_table,
token_table=self.token_table,
) )
# TODO: should we move it to a certain device ?
def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor: def texts_to_token_ids(
"""Read a BPE lexicon from file and convert it to a self, texts: List[str], oov: str = "<UNK>"
k2 ragged tensor. ) -> k2.RaggedTensor:
Args:
filename:
Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt
Returns:
A k2 ragged tensor with two axes [word_id]
""" """
disambig_id = self.word_table["#0"] Args:
# We reuse the same words.txt from the phone based lexicon texts:
# so that we can share the same G.fst. Here, we have to A list of transcripts. Each transcript contains space(s)
# exclude some words present only in the phone based lexicon. separated words. An example texts is::
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
# epsilon is not a word, but it occupies on position ['HELLO k2', 'HELLO icefall']
# oov:
row_splits = [0] The OOV word. If a word in `texts` is not in the lexicon, it is
token_ids = [] replaced with `oov`.
Returns:
Return a ragged int tensor with 2 axes [utterance][token_id]
"""
oov_id = self.word_table[oov]
lexicon = read_lexicon(filename) word_ids_list = []
lexicon = dict(lexicon) for text in texts:
word_ids = []
for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
word_ids.append(oov_id)
word_ids_list.append(word_ids)
ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
ans = self.ragged_lexicon.index(ragged_indexes)
ans = ans.remove_axis(ans.num_axes - 2)
return ans
for i in range(disambig_id): def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
w = self.word_table[i] """Convert a list of words to a ragged tensor containing token IDs.
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
pieces = lexicon[w]
piece_ids = [self.token_table[k] for k in pieces]
row_splits.append(row_splits[-1] + len(piece_ids)) We assume there are no OOVs in "words".
token_ids.extend(piece_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=cached_tot_size
)
values = torch.tensor(token_ids, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor contained
word piece IDs.
""" """
word_ids = [self.word_table[w] for w in words] word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32) word_ids = torch.tensor(word_ids, dtype=torch.int32)
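A minimal sketch of the new UniqLexicon API, assuming a prepared lang dir; the same calls are exercised in test/test_lexicon.py below, and the path here is illustrative:

from pathlib import Path

from icefall.lexicon import UniqLexicon

lexicon = UniqLexicon(Path("data/lang_bpe"), uniq_filename="lexicon.txt")
token_ids = lexicon.texts_to_token_ids(["HELLO WORLD", "HELLO k2"])
print(token_ids.tolist())  # ragged, axes [utterance][token]
# A word missing from the lexicon is mapped to the ID of "<UNK>"
# instead of being re-segmented by sentencepiece.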

View File

@@ -0,0 +1,216 @@
import logging
from pathlib import Path
from typing import Iterable, List, Tuple, Union
import k2
import torch
from icefall.lexicon import UniqLexicon
class MmiTrainingGraphCompiler(object):
def __init__(
self,
lang_dir: Path,
uniq_filename: str = "uniq_lexicon.txt",
device: Union[str, torch.device] = "cpu",
oov: str = "<UNK>",
):
"""
Args:
lang_dir:
Path to the lang directory. It is expected to contain the
following files::
- tokens.txt
- words.txt
- P.fst.txt
The above files are generated by the script `prepare.sh`. You
should have run it before running the training code.
uniq_filename:
File name to the lexicon in which every word has exactly one
pronunciation. We assume this file is inside the given `lang_dir`.
device:
It indicates CPU or CUDA.
oov:
Out of vocabulary word. When a word in the transcript
does not exist in the lexicon, it is replaced with `oov`.
"""
self.lang_dir = Path(lang_dir)
self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename)
self.device = torch.device(device)
self.L_inv = self.lexicon.L_inv.to(self.device)
self.oov_id = self.lexicon.word_table[oov]
self.build_ctc_topo_P()
def build_ctc_topo_P(self):
"""Built ctc_topo_P, the composition result of
ctc_topo and P, where P is a pre-trained bigram
word piece LM.
"""
# Note: there is no need to save a pre-compiled P and ctc_topo
# as it is very fast to generate them.
logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
with open(self.lang_dir / "P.fst.txt") as f:
# P is not an acceptor because there is
# a back-off state, whose incoming arcs
# have label #0 and aux_label 0 (i.e., <eps>).
P = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_token_disambig_id = self.lexicon.token_table["#0"]
# P.aux_labels is not needed in later computations, so
# remove it here.
del P.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
P.labels[P.labels >= first_token_disambig_id] = 0
P = k2.remove_epsilon(P)
P = k2.arc_sort(P)
P = P.to(self.device)
# Add epsilon self-loops to P because we want the
# following operation "k2.intersect" to run on GPU.
P_with_self_loops = k2.add_epsilon_self_loops(P)
max_token_id = max(self.lexicon.tokens)
logging.info(
f"Building ctc_topo (modified=False). max_token_id: {max_token_id}"
)
ctc_topo = k2.ctc_topo(max_token_id, modified=False, device=self.device)
ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
logging.info("Building ctc_topo_P")
ctc_topo_P = k2.intersect(
ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
).invert()
self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
def compile(
self, texts: Iterable[str], replicate_den: bool = True
) -> Tuple[k2.Fsa, k2.Fsa]:
"""Create numerator and denominator graphs from transcripts
and the bigram word piece LM P.
Args:
texts:
A list of transcripts. Within a transcript, words are
separated by spaces. An example `texts` is given below::
["Hello icefall", "LF-MMI training with icefall using k2"]
replicate_den:
If True, the returned den_graph is replicated to match the number
of FSAs in the returned num_graph; if False, the returned den_graph
contains only a single FSA
Returns:
A tuple (num_graph, den_graph), where
- `num_graph` is the numerator graph. It is an FsaVec with
shape `(len(texts), None, None)`.
- `den_graph` is the denominator graph. It is an FsaVec
with the same shape of the `num_graph` if replicate_den is
True; otherwise, it is an FsaVec containing only a single FSA.
"""
transcript_fsa = self.build_transcript_fsa(texts)
# remove word IDs from transcript_fsa since they are not needed
del transcript_fsa.aux_labels
# NOTE: You can comment out the above statement
# if you want to run test/test_mmi_graph_compiler.py
transcript_fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
transcript_fsa
)
transcript_fsa_with_self_loops = k2.arc_sort(
transcript_fsa_with_self_loops
)
num = k2.compose(
self.ctc_topo_P,
transcript_fsa_with_self_loops,
treat_epsilons_specially=False,
)
# CAUTION: Due to the presence of P,
# the resulting `num` may not be connected
num = k2.connect(num)
num = k2.arc_sort(num)
ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
if replicate_den:
indexes = torch.zeros(
len(texts), dtype=torch.int32, device=self.device
)
den = k2.index_fsa(ctc_topo_P_vec, indexes)
else:
den = ctc_topo_P_vec
return num, den
def build_transcript_fsa(self, texts: List[str]) -> k2.Fsa:
"""Convert transcripts to an FsaVec with the help of a lexicon
and word symbol table.
Args:
texts:
Each element is a transcript containing words separated by space(s).
For instance, it may be 'HELLO icefall', which contains
two words.
Returns:
Return an FST (FsaVec) corresponding to the transcript.
Its `labels` is token IDs and `aux_labels` is word IDs.
"""
word_ids_list = []
for text in texts:
word_ids = []
for word in text.split():
if word in self.lexicon.word_table:
word_ids.append(self.lexicon.word_table[word])
else:
word_ids.append(self.oov_id)
word_ids_list.append(word_ids)
fsa = k2.linear_fsa(word_ids_list, self.device)
fsa = k2.add_epsilon_self_loops(fsa)
# The reason to use `invert_()` at the end is as follows:
#
# (1) The `labels` of L_inv is word IDs and `aux_labels` is token IDs
# (2) `fsa.labels` is word IDs
# (3) after intersection, the `labels` is still word IDs
# (4) after `invert_()`, the `labels` is token IDs
# and `aux_labels` is word IDs
transcript_fsa = k2.intersect(
self.L_inv, fsa, treat_epsilons_specially=False
).invert_()
transcript_fsa = k2.arc_sort(transcript_fsa)
return transcript_fsa
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of piece IDs.
Args:
texts:
It is a list of strings. Each string consists of space(s)
separated words. An example containing two strings is given below:
['HELLO ICEFALL', 'HELLO k2']
We assume it contains no OOVs. Otherwise, it will raise an
exception.
Returns:
Return a list-of-list of token IDs.
"""
return self.lexicon.texts_to_token_ids(texts).tolist()
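A minimal usage sketch, assuming a lang dir that contains tokens.txt, words.txt, P.fst.txt and uniq_lexicon.txt, mirroring test/test_mmi_graph_compiler.py below:

from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

graph_compiler = MmiTrainingGraphCompiler(
    lang_dir="data/lang_bpe",
    uniq_filename="uniq_lexicon.txt",
    device="cpu",
)
texts = ["HELLO WORLD", "MMI TRAINING WITH ICEFALL"]
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
# num_graphs: FsaVec with shape (len(texts), None, None)
# den_graphs: one copy of ctc_topo_P per utterance because replicate_den=True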

View File

@@ -19,14 +19,16 @@ import argparse
import logging import logging
import os import os
import subprocess import subprocess
import sys
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
import k2 import k2
import kaldialign import kaldialign
import lhotse
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -132,17 +134,82 @@ def setup_logger(
logging.getLogger("").addHandler(console) logging.getLogger("").addHandler(console)
def get_env_info(): def get_git_sha1():
""" git_commit = (
TODO: subprocess.run(
""" ["git", "rev-parse", "--short", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
dirty_commit = (
len(
subprocess.run(
["git", "diff", "--shortstat"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
> 0
)
git_commit = (
git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
)
return git_commit
def get_git_date():
git_date = (
subprocess.run(
["git", "log", "-1", "--format=%ad", "--date=local"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_date
def get_git_branch_name():
git_branch = (
subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_branch
def get_env_info() -> Dict[str, Any]:
"""Get the environment information."""
return { return {
"k2-git-sha1": None, "k2-version": k2.version.__version__,
"k2-version": None, "k2-build-type": k2.version.__build_type__,
"lhotse-version": None, "k2-with-cuda": k2.with_cuda,
"torch-version": None, "k2-git-sha1": k2.version.__git_sha1__,
"icefall-sha1": None, "k2-git-date": k2.version.__git_date__,
"icefall-version": None, "lhotse-version": lhotse.__version__,
"torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3],
"icefall-git-branch": get_git_branch_name(),
"icefall-git-sha1": get_git_sha1(),
"icefall-git-date": get_git_date(),
"icefall-path": str(Path(__file__).resolve().parent.parent),
"k2-path": str(Path(k2.__file__).resolve()),
"lhotse-path": str(Path(lhotse.__file__).resolve()),
} }
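A minimal sketch of consuming the returned dict; logging it alongside the training params is an assumed usage and is not shown in this commit:

import logging

from icefall.utils import get_env_info

env_info = get_env_info()
logging.info(f"k2 {env_info['k2-version']}, lhotse {env_info['lhotse-version']}")
logging.info(f"CUDA available: {env_info['torch-cuda-available']}")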

View File

@@ -19,20 +19,21 @@
from pathlib import Path from pathlib import Path
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import BpeLexicon from icefall.lexicon import UniqLexicon
ICEFALL_DIR = Path(__file__).resolve().parent.parent
def test(): def test():
lang_dir = Path("data/lang/bpe") lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe"
if not lang_dir.is_dir(): if not lang_dir.is_dir():
return return
# TODO: generate data for testing
compiler = BpeCtcTrainingGraphCompiler(lang_dir) compiler = BpeCtcTrainingGraphCompiler(lang_dir)
ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"]) ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
compiler.compile(ids) compiler.compile(ids)
lexicon = BpeLexicon(lang_dir) lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt")
ids0 = lexicon.words_to_piece_ids(["HELLO"]) ids0 = lexicon.words_to_piece_ids(["HELLO"])
assert ids[0] == ids0.values().tolist() assert ids[0] == ids0.values().tolist()

View File

@@ -1,30 +0,0 @@
#!/usr/bin/env python3
import copy
import logging
from pathlib import Path
import k2
import torch
from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
def test_bpe_mmi_graph_compiler():
lang_dir = Path("data/lang_bpe")
if lang_dir.is_dir() is False:
return
device = torch.device("cpu")
compiler = BpeMmiTrainingGraphCompiler(lang_dir, device=device)
texts = ["HELLO WORLD", "MMI TRAINING"]
num_graphs, den_graphs = compiler.compile(texts)
num_graphs.labels_sym = compiler.lexicon.token_table
num_graphs.aux_labels_sym = copy.deepcopy(compiler.lexicon.token_table)
num_graphs.aux_labels_sym._id2sym[0] = "<eps>"
num_graphs[0].draw("num_graphs_0.svg", title="HELLO WORLD")
num_graphs[1].draw("num_graphs_1.svg", title="HELLO WORLD")
print(den_graphs.shape)
print(den_graphs[0].shape)
print(den_graphs[0].num_arcs)

173
test/test_lexicon.py Normal file → Executable file
View File

@@ -14,80 +14,135 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
You can run this file in one of the two ways:
(1) cd icefall; pytest test/test_lexicon.py
(2) cd icefall; ./test/test_lexicon.py
"""
import os
import shutil
import sys
from pathlib import Path from pathlib import Path
import k2 import k2
import pytest import sentencepiece as spm
import torch
from icefall.lexicon import BpeLexicon, Lexicon from icefall.lexicon import UniqLexicon
TMP_DIR = "/tmp/icefall-test-lexicon"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent
@pytest.fixture def generate_test_data():
def lang_dir(tmp_path): Path(TMP_DIR).mkdir(exist_ok=True)
phone2id = """ sentences = """
<eps> 0 cat tac cat cat
a 1 at
b 2 tac at ta at at
f 3 at cat ct ct ta
o 4 cat cat cat cat
r 5 at at at at at at at
z 6
SPN 7
#0 8
"""
word2id = """
<eps> 0
foo 1
bar 2
baz 3
<UNK> 4
#0 5
""" """
L = k2.Fsa.from_str( transcript = Path(TMP_DIR) / "transcript_words.txt"
""" with open(transcript, "w") as f:
0 0 7 4 0 for line in sentences.strip().split("\n"):
0 7 -1 -1 0 f.write(f"{line}\n")
0 1 3 1 0
0 3 2 2 0 words = """
0 5 2 3 0 <eps> 0
1 2 4 0 0 <UNK> 1
2 0 4 0 0 at 2
3 4 1 0 0 cat 3
4 0 5 0 0 ct 4
5 6 1 0 0 ta 5
6 0 6 0 0 tac 6
7 #0 7
""", <s> 8
num_aux_labels=1, </s> 9
"""
word_txt = Path(TMP_DIR) / "words.txt"
with open(word_txt, "w") as f:
for line in words.strip().split("\n"):
f.write(f"{line}\n")
vocab_size = 8
os.system(
f"""
cd {ICEFALL_DIR}/egs/librispeech/ASR
./local/train_bpe_model.py \
--lang-dir {TMP_DIR} \
--vocab-size {vocab_size} \
--transcript {transcript}
./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 1
"""
) )
with open(tmp_path / "tokens.txt", "w") as f:
f.write(phone2id)
with open(tmp_path / "words.txt", "w") as f:
f.write(word2id)
torch.save(L.as_dict(), tmp_path / "L.pt") def delete_test_data():
shutil.rmtree(TMP_DIR)
return tmp_path
def test_lexicon(lang_dir): def uniq_lexicon_test():
lexicon = Lexicon(lang_dir) lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="lexicon.txt")
assert lexicon.tokens == list(range(1, 8))
# case 1: No OOV
texts = ["cat cat", "at ct", "at tac cat"]
token_ids = lexicon.texts_to_token_ids(texts)
sp = spm.SentencePieceProcessor()
sp.load(f"{TMP_DIR}/bpe.model")
expected_token_ids: List[List[int]] = sp.encode(texts, out_type=int)
assert token_ids.tolist() == expected_token_ids
# case 2: With OOV
texts = ["ca"]
token_ids = lexicon.texts_to_token_ids(texts)
expected_token_ids = sp.encode(texts, out_type=int)
assert token_ids.tolist() != expected_token_ids
# Note: sentencepiece breaks "ca" into "_ c a"
# But there is no word "ca" in the lexicon, so our
# implementation returns the id of "<UNK>"
print(token_ids, expected_token_ids)
assert token_ids.tolist() == [[sp.unk_id()]]
# case 3: With OOV
texts = ["foo"]
token_ids = lexicon.texts_to_token_ids(texts)
expected_token_ids = sp.encode(texts, out_type=int)
print(token_ids)
print(expected_token_ids)
# test ragged lexicon
ragged_lexicon = lexicon.ragged_lexicon.tolist()
word_disambig_id = lexicon.word_table["#0"]
for i in range(2, word_disambig_id):
piece_id = ragged_lexicon[i]
word = lexicon.word_table[i]
assert word == sp.decode(piece_id)
assert piece_id == sp.encode(word)
def test_bpe_lexicon(): def test_main():
lang_dir = Path("data/lang/bpe") generate_test_data()
if not lang_dir.is_dir():
return
# TODO: Generate test data for BpeLexicon
lexicon = BpeLexicon(lang_dir) uniq_lexicon_test()
words = ["<UNK>", "HELLO", "ZZZZ", "WORLD"]
ids = lexicon.words_to_piece_ids(words) if USING_PYTEST:
print(ids) delete_test_data()
print([lexicon.token_table[i] for i in ids.values().tolist()])
def main():
test_main()
if __name__ == "__main__" and not USING_PYTEST:
main()

196
test/test_mmi_graph_compiler.py Executable file
View File

@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
You can run this file in one of the two ways:
(1) cd icefall; pytest test/test_mmi_graph_compiler.py
(2) cd icefall; ./test/test_mmi_graph_compiler.py
"""
import copy
import os
import shutil
import sys
from pathlib import Path
import k2
import sentencepiece as spm
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent
def generate_test_data():
Path(TMP_DIR).mkdir(exist_ok=True)
sentences = """
cat tac cat cat
at at cat at cat cat
tac at ta at at
at cat ct ct ta ct ct cat tac
cat cat cat cat
at at at at at at at
"""
transcript = Path(TMP_DIR) / "transcript_words.txt"
with open(transcript, "w") as f:
for line in sentences.strip().split("\n"):
f.write(f"{line}\n")
words = """
<eps> 0
<UNK> 1
at 2
cat 3
ct 4
ta 5
tac 6
#0 7
<s> 8
</s> 9
"""
word_txt = Path(TMP_DIR) / "words.txt"
with open(word_txt, "w") as f:
for line in words.strip().split("\n"):
f.write(f"{line}\n")
vocab_size = 8
os.system(
f"""
cd {ICEFALL_DIR}/egs/librispeech/ASR
./local/train_bpe_model.py \
--lang-dir {TMP_DIR} \
--vocab-size {vocab_size} \
--transcript {transcript}
./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 0
./local/convert_transcript_words_to_tokens.py \
--lexicon {TMP_DIR}/lexicon.txt \
--transcript {transcript} \
--oov "<UNK>" \
> {TMP_DIR}/transcript_tokens.txt
./shared/make_kn_lm.py \
-ngram-order 2 \
-text {TMP_DIR}/transcript_tokens.txt \
-lm {TMP_DIR}/P.arpa
python3 -m kaldilm \
--read-symbol-table="{TMP_DIR}/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
{TMP_DIR}/P.arpa > {TMP_DIR}/P.fst.txt
"""
)
def delete_test_data():
shutil.rmtree(TMP_DIR)
def mmi_graph_compiler_test():
# Caution:
# You have to comment out
# del transcript_fsa.aux_labels
# in mmi_graph_compiler.py
# to see the correct aux_labels in *.svg
graph_compiler = MmiTrainingGraphCompiler(
lang_dir=TMP_DIR, uniq_filename="lexicon.txt"
)
print(graph_compiler.device)
L_inv = graph_compiler.L_inv
L = k2.invert(L_inv)
L.labels_sym = graph_compiler.lexicon.token_table
L.aux_labels_sym = graph_compiler.lexicon.word_table
L.draw(f"{TMP_DIR}/L.svg", title="L")
L_inv.labels_sym = graph_compiler.lexicon.word_table
L_inv.aux_labels_sym = graph_compiler.lexicon.token_table
L_inv.draw(f"{TMP_DIR}/L_inv.svg", title="L")
ctc_topo_P = graph_compiler.ctc_topo_P
ctc_topo_P.labels_sym = copy.deepcopy(graph_compiler.lexicon.token_table)
ctc_topo_P.labels_sym._id2sym[0] = "<blk>"
ctc_topo_P.labels_sym._sym2id["<blk>"] = 0
ctc_topo_P.aux_labels_sym = graph_compiler.lexicon.token_table
ctc_topo_P.draw(f"{TMP_DIR}/ctc_topo_P.svg", title="ctc_topo_P")
print(ctc_topo_P.num_arcs)
print(k2.connect(ctc_topo_P).num_arcs)
with open(str(TMP_DIR) + "/P.fst.txt") as f:
# P is not an acceptor because there is
# a back-off state, whose incoming arcs
# have label #0 and aux_label 0 (i.e., <eps>).
P = k2.Fsa.from_openfst(f.read(), acceptor=False)
P.labels_sym = graph_compiler.lexicon.token_table
P.aux_labels_sym = graph_compiler.lexicon.token_table
P.draw(f"{TMP_DIR}/P.svg", title="P")
ctc_topo = k2.ctc_topo(max(graph_compiler.lexicon.tokens), False)
ctc_topo.labels_sym = ctc_topo_P.labels_sym
ctc_topo.aux_labels_sym = graph_compiler.lexicon.token_table
ctc_topo.draw(f"{TMP_DIR}/ctc_topo.svg", title="ctc_topo")
print("p num arcs", P.num_arcs)
print("ctc_topo num arcs", ctc_topo.num_arcs)
print("ctc_topo_P num arcs", ctc_topo_P.num_arcs)
texts = ["cat at ct", "at ta", "cat tac"]
transcript_fsa = graph_compiler.build_transcript_fsa(texts)
transcript_fsa[0].draw(f"{TMP_DIR}/cat_at_ct.svg", title="cat_at_ct")
transcript_fsa[1].draw(f"{TMP_DIR}/at_ta.svg", title="at_ta")
transcript_fsa[2].draw(f"{TMP_DIR}/cat_tac.svg", title="cat_tac")
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
num_graphs[0].draw(f"{TMP_DIR}/num_cat_at_ct.svg", title="num_cat_at_ct")
num_graphs[1].draw(f"{TMP_DIR}/num_at_ta.svg", title="num_at_ta")
num_graphs[2].draw(f"{TMP_DIR}/num_cat_tac.svg", title="num_cat_tac")
den_graphs[0].draw(f"{TMP_DIR}/den_cat_at_ct.svg", title="den_cat_at_ct")
den_graphs[2].draw(f"{TMP_DIR}/den_cat_tac.svg", title="den_cat_tac")
sp = spm.SentencePieceProcessor()
sp.load(f"{TMP_DIR}/bpe.model")
texts = ["cat at cat", "at tac"]
token_ids = graph_compiler.texts_to_ids(texts)
expected_token_ids = sp.encode(texts)
assert token_ids == expected_token_ids
def test_main():
generate_test_data()
mmi_graph_compiler_test()
if USING_PYTEST:
delete_test_data()
def main():
test_main()
if __name__ == "__main__" and not USING_PYTEST:
main()