Mirror of https://github.com/k2-fsa/icefall.git

Commit 6f5d63492a (parent 59b7140766): Refactoring.
@@ -114,7 +114,10 @@ class Transformer(nn.Module):
             norm=encoder_norm,
         )
 
-        self.encoder_output_layer = nn.Linear(d_model, num_classes)
+        # TODO(fangjun): remove dropout
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )
 
         if num_decoder_layers > 0:
             self.decoder_num_class = (
@@ -325,6 +328,7 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
        # extracted as a separate function.
+
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
         ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
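For context, a minimal sketch of what the refactored output layer computes: dropout is applied to the encoder output before the final projection to the vocabulary size. The sizes below are illustrative, not taken from the recipe.

import torch
import torch.nn as nn

# Illustrative sizes; the recipe's actual d_model/num_classes come from
# the training configuration.
d_model, num_classes, dropout = 512, 500, 0.1

encoder_output_layer = nn.Sequential(
    nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
)

x = torch.randn(8, 100, d_model)   # (N, T, C) encoder output
logits = encoder_output_layer(x)   # (N, T, num_classes)
assert logits.shape == (8, 100, num_classes)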
egs/librispeech/ASR/conformer_mmi/asr_datamodule.py (new file, 354 lines)
@@ -0,0 +1,354 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import List, Union
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse.dataset import (
+    BucketingSampler,
+    CutConcatenate,
+    CutMix,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.dataset.datamodule import DataModule
+from icefall.utils import str2bool
+
+
+class LibriSpeechAsrDataModule(DataModule):
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        super().add_arguments(parser)
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+        group.add_argument(
+            "--full-libri",
+            type=str2bool,
+            default=True,
+            help="When enabled, use 960h LibriSpeech. "
+            "Otherwise, use 100h subset.",
+        )
+        group.add_argument(
+            "--feature-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200.0,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the BucketingSampler"
+            "(you might want to increase it for larger datasets).",
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=2,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+    def train_dataloaders(self) -> DataLoader:
+        logging.info("About to get train cuts")
+        cuts_train = self.train_cuts()
+
+        logging.info("About to get Musan cuts")
+        cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
+
+        logging.info("About to create train dataset")
+        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = [
+            SpecAugment(
+                num_frame_masks=2,
+                features_mask_size=27,
+                num_feature_masks=2,
+                frames_mask_size=100,
+            )
+        ]
+
+        train = K2SpeechRecognitionDataset(
+            cut_transforms=transforms,
+            input_transforms=input_transforms,
+            return_cuts=self.args.return_cuts,
+        )
+
+        if self.args.on_the_fly_feats:
+            # NOTE: the PerturbSpeed transform should be added only if we
+            # remove it from data prep stage.
+            # Add on-the-fly speed perturbation; since originally it would
+            # have increased epoch size by 3, we will apply prob 2/3 and use
+            # 3x more epochs.
+            # Speed perturbation probably should come first before
+            # concatenation, but in principle the transforms order doesn't have
+            # to be strict (e.g. could be randomized)
+            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
+            # Drop feats to be on the safe side.
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
+
+        if self.args.bucketing_sampler:
+            logging.info("Using BucketingSampler.")
+            train_sampler = BucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                bucket_method="equal_duration",
+                drop_last=True,
+            )
+        else:
+            logging.info("Using SingleCutSampler.")
+            train_sampler = SingleCutSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+            )
+        logging.info("About to create train dataloader")
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self) -> DataLoader:
+        logging.info("About to get dev cuts")
+        cuts_valid = self.valid_cuts()
+
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
+            )
+        valid_sampler = SingleCutSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
+        cuts = self.test_cuts()
+        is_list = isinstance(cuts, list)
+        test_loaders = []
+        if not is_list:
+            cuts = [cuts]
+
+        for cuts_test in cuts:
+            logging.debug("About to create test dataset")
+            test = K2SpeechRecognitionDataset(
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                )
+                if self.args.on_the_fly_feats
+                else PrecomputedFeatures(),
+                return_cuts=self.args.return_cuts,
+            )
+            sampler = SingleCutSampler(
+                cuts_test, max_duration=self.args.max_duration
+            )
+            logging.debug("About to create test dataloader")
+            test_dl = DataLoader(
+                test, batch_size=None, sampler=sampler, num_workers=1
+            )
+            test_loaders.append(test_dl)
+
+        if is_list:
+            return test_loaders
+        else:
+            return test_loaders[0]
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train cuts")
+        cuts_train = load_manifest(
+            self.args.feature_dir / "cuts_train-clean-100.json.gz"
+        )
+        if self.args.full_libri:
+            cuts_train = (
+                cuts_train
+                + load_manifest(
+                    self.args.feature_dir / "cuts_train-clean-360.json.gz"
+                )
+                + load_manifest(
+                    self.args.feature_dir / "cuts_train-other-500.json.gz"
+                )
+            )
+        return cuts_train
+
+    @lru_cache()
+    def valid_cuts(self) -> CutSet:
+        logging.info("About to get dev cuts")
+        cuts_valid = load_manifest(
+            self.args.feature_dir / "cuts_dev-clean.json.gz"
+        ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
+        return cuts_valid
+
+    @lru_cache()
+    def test_cuts(self) -> List[CutSet]:
+        test_sets = ["test-clean", "test-other"]
+        cuts = []
+        for test_set in test_sets:
+            logging.debug("About to get test cuts")
+            cuts.append(
+                load_manifest(
+                    self.args.feature_dir / f"cuts_{test_set}.json.gz"
+                )
+            )
+        return cuts
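The data module above is driven entirely by argparse options. A hedged sketch of how a training script might wire it up, assuming the DataModule base class stores the parsed args on self.args (as icefall did at the time); the flag values are illustrative:

import argparse

from asr_datamodule import LibriSpeechAsrDataModule

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
# Illustrative flags; any of the options registered above can be passed.
args = parser.parse_args(["--full-libri", "false", "--max-duration", "200"])

datamodule = LibriSpeechAsrDataModule(args)
train_dl = datamodule.train_dataloaders()
valid_dl = datamodule.valid_dataloaders()

for batch in train_dl:
    features = batch["inputs"]            # (N, T, 80) fbank features
    supervisions = batch["supervisions"]  # lhotse supervision dict
    break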
@@ -1,7 +1,20 @@
 #!/usr/bin/env python3
 
 # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 
 import math
 import warnings
@@ -43,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        is_espnet_structure: bool = False,
         use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
@@ -70,12 +82,10 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
-            is_espnet_structure,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
-        self.is_espnet_structure = is_espnet_structure
-        if self.normalize_before and self.is_espnet_structure:
+        if self.normalize_before:
             self.after_norm = nn.LayerNorm(d_model)
         else:
             # Note: TorchScript detects that self.after_norm could be used inside forward()
@@ -88,7 +98,7 @@ class Conformer(Transformer):
         """
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -110,7 +120,7 @@ class Conformer(Transformer):
         mask = mask.to(x.device)
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
 
-        if self.normalize_before and self.is_espnet_structure:
+        if self.normalize_before:
             x = self.after_norm(x)
 
         return x, mask
@@ -144,11 +154,10 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
-        is_espnet_structure: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
+            d_model, nhead, dropout=0.0
         )
 
         self.feed_forward = nn.Sequential(
@@ -394,7 +403,7 @@ class RelPositionalEncoding(torch.nn.Module):
             :,
             self.pe.size(1) // 2
             - x.size(1)
-            + 1 : self.pe.size(1) // 2
+            + 1 : self.pe.size(1) // 2  # noqa E203
            + x.size(1),
        ]
        return self.dropout(x), self.dropout(pos_emb)
@@ -421,7 +430,6 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
-        is_espnet_structure: bool = False,
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
@@ -444,8 +452,6 @@ class RelPositionMultiheadAttention(nn.Module):
 
         self._reset_parameters()
 
-        self.is_espnet_structure = is_espnet_structure
-
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.0)
@@ -675,9 +681,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-        if not self.is_espnet_structure:
-            q = q * scaling
-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
@@ -770,14 +773,9 @@ class RelPositionMultiheadAttention(nn.Module):
         )  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd)
 
-        if not self.is_espnet_structure:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            )  # (batch, head, time1, time2)
-        else:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            ) * scaling  # (batch, head, time1, time2)
+        attn_output_weights = (
+            matrix_ac + matrix_bd
+        ) * scaling  # (batch, head, time1, time2)
 
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
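The last hunk above collapses the two scaling branches into one: with the ESPnet structure now the only code path, the attention scaling is applied once to the sum of the content term (matrix_ac) and the position term (matrix_bd) instead of pre-scaling q. A small self-contained sketch of the surviving computation; the tensor sizes are illustrative:

import torch

batch, head, time1, time2, head_dim = 2, 4, 10, 10, 64
scaling = float(head_dim) ** -0.5  # 1/sqrt(d_k)

matrix_ac = torch.randn(batch, head, time1, time2)  # content-content scores
matrix_bd = torch.randn(batch, head, time1, time2)  # content-position scores

# The refactored code keeps only this branch: scale after summing both terms.
attn_output_weights = (matrix_ac + matrix_bd) * scaling
attn_output_weights = attn_output_weights.view(batch * head, time1, -1)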
@@ -1,8 +1,20 @@
 #!/usr/bin/env python3
 
 # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 # (still working in progress)
 
 import argparse
 import logging
@@ -13,14 +25,15 @@ from typing import Dict, List, Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 
-from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
+    nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
@@ -32,6 +45,7 @@ from icefall.utils import (
     get_texts,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
 
@@ -44,51 +58,111 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=9,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=1,
+        default=20,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (1) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (2) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (5) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (6) nbest-oracle. Its WER is the lower bound of any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring method.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--lattice-score-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--export",
+        type=str2bool,
+        default=False,
+        help="""When enabled, the averaged model is saved to
+        conformer_mmi/exp/pretrained.pt. Note: only model.state_dict() is saved.
+        pretrained.pt contains a dict {"model": model.state_dict()},
+        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_mmi/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="The lang dir",
+    )
+
     return parser
 
 
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_mmi/exp"),
-            "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             # parameters for conformer
+            "subsampling_factor": 4,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
             "feature_dim": 80,
             "nhead": 8,
             "attention_dim": 512,
-            "subsampling_factor": 4,
             "num_decoder_layers": 6,
-            "vgg_frontend": False,
-            "is_espnet_structure": True,
-            "use_feat_batchnorm": True,
             # parameters for decoding
             "search_beam": 20,
             "output_beam": 8,
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
-            # Possible values for method:
-            # - 1best
-            # - nbest
-            # - nbest-rescoring
-            # - whole-lattice-rescoring
-            # - attention-decoder
-            # "method": "whole-lattice-rescoring",
-            "method": "1best",
-            # num_paths is used when method is "nbest", "nbest-rescoring",
-            # and attention-decoder
-            "num_paths": 100,
         }
     )
     return params
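The new --lattice-score-scale option controls how peaked the lattice score distribution is when n-best paths are sampled. A sketch of the idea, assuming paths are drawn with k2.random_paths as in icefall's n-best utilities (the helper name is hypothetical): scores are scaled before sampling, so a smaller scale flattens the distribution and yields more unique paths.

import k2

def sample_nbest_paths(lattice: k2.Fsa, num_paths: int, scale: float):
    # A sketch: scale a copy of the scores so sampling is less dominated
    # by the single best path; smaller scale -> more path diversity.
    saved_scores = lattice.scores.clone()
    lattice.scores = saved_scores * scale
    # random_paths returns a ragged tensor of arc indexes, one list per path.
    paths = k2.random_paths(
        lattice, use_double_scores=True, num_paths=num_paths
    )
    lattice.scores = saved_scores  # restore the original scores
    return paths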
@@ -99,7 +173,7 @@ def decode_one_batch(
     model: nn.Module,
     HLG: k2.Fsa,
     batch: dict,
-    lexicon: Lexicon,
+    word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
@@ -133,8 +207,8 @@ def decode_one_batch(
       It is the return value from iterating
       `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
       for the format of the `batch`.
-    lexicon:
-      It contains word symbol table.
+    word_table:
+      The word symbol table.
     sos_id:
       The token ID of the SOS.
     eos_id:
@@ -151,12 +225,12 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
 
     supervisions = batch["supervisions"]
 
     nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is [N, T, C]
+    # nnet_output is (N, T, C)
 
     supervision_segments = torch.stack(
         (
@@ -178,6 +252,24 @@ def decode_one_batch(
         subsampling_factor=params.subsampling_factor,
     )
 
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            lattice_score_scale=params.lattice_score_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        return {key: hyps}
+
     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
             best_path = one_best_decoding(
@@ -189,11 +281,12 @@ def decode_one_batch(
                 lattice=lattice,
                 num_paths=params.num_paths,
                 use_double_scores=params.use_double_scores,
+                lattice_score_scale=params.lattice_score_scale,
             )
-            key = f"no_rescore-{params.num_paths}"
+            key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
 
         hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
         return {key: hyps}
 
     assert params.method in [
@@ -202,7 +295,8 @@ def decode_one_batch(
         "whole-lattice-rescoring",
         "attention-decoder",
     ]
 
-    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
 
     if params.method == "nbest-rescoring":
@@ -211,16 +305,23 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
+            lattice_score_scale=params.lattice_score_scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
         )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
        )
+        # TODO: pass `lattice` instead of `rescored_lattice` to
+        # `rescore_with_attention_decoder`
 
         best_path_dict = rescore_with_attention_decoder(
             lattice=rescored_lattice,
@@ -230,15 +331,20 @@ def decode_one_batch(
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
             eos_id=eos_id,
+            lattice_score_scale=params.lattice_score_scale,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
 
     ans = dict()
-    for lm_scale_str, best_path in best_path_dict.items():
-        hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
-        ans[lm_scale_str] = hyps
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        for lm_scale in lm_scale_list:
+            ans[lm_scale_str] = [[] * lattice.shape[0]]
     return ans
 
 
@@ -247,7 +353,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     HLG: k2.Fsa,
-    lexicon: Lexicon,
+    word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
@@ -263,8 +369,8 @@ def decode_dataset(
       The neural model.
     HLG:
       The decoding graph.
-    lexicon:
-      It contains word symbol table.
+    word_table:
+      It is the word symbol table.
     sos_id:
       The token ID for SOS.
     eos_id:
@@ -283,7 +389,11 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
 
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -294,7 +404,7 @@ def decode_dataset(
             model=model,
             HLG=HLG,
             batch=batch,
-            lexicon=lexicon,
+            word_table=word_table,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
@@ -312,10 +422,10 @@ def decode_dataset(
         num_cuts += len(batch["supervisions"]["text"])
 
         if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )
     return results
 
@@ -374,8 +484,10 @@ def main():
 
     params = get_params()
     params.update(vars(args))
+    params.exp_dir = Path(params.exp_dir)
+    params.lang_dir = Path(params.lang_dir)
 
-    setup_logger(f"{params.exp_dir}/log/log-decode")
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
     logging.info("Decoding started")
     logging.info(params)
 
@@ -389,7 +501,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    graph_compiler = BpeMmiTrainingGraphCompiler(
+    graph_compiler = BpeCtcTrainingGraphCompiler(
         params.lang_dir,
         device=device,
         sos_token="<sos/eos>",
@@ -398,7 +510,9 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -429,7 +543,7 @@ def main():
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
             G = k2.Fsa.from_dict(d).to(device)
 
         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
@@ -453,7 +567,6 @@ def main():
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
         vgg_frontend=params.vgg_frontend,
-        is_espnet_structure=params.is_espnet_structure,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
 
@@ -468,6 +581,13 @@ def main():
         logging.info(f"averaging {filenames}")
         model.load_state_dict(average_checkpoints(filenames))
 
+    if params.export:
+        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
+        return
+
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
@@ -487,7 +607,7 @@ def main():
         params=params,
         model=model,
         HLG=HLG,
-        lexicon=lexicon,
+        word_table=lexicon.word_table,
         G=G,
         sos_id=sos_id,
         eos_id=eos_id,
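Two related details in the hunks above: --export saves only a {"model": state_dict} dict, and both HLG.pt and G_4_gram.pt are now loaded with map_location="cpu" before being moved to the device. A hedged sketch of both patterns with a stand-in module (the path is illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for the averaged Conformer model

# --export saves only the state_dict, wrapped in a dict under "model",
# which matches what icefall.checkpoint.load_checkpoint() expects.
torch.save({"model": model.state_dict()}, "pretrained.pt")

# map_location="cpu" mirrors the HLG/G loading fix: tensors serialized on
# one GPU load safely on any machine, and are moved to the device afterwards.
state = torch.load("pretrained.pt", map_location="cpu")
model.load_state_dict(state["model"])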
@@ -1,15 +1,26 @@
-# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+# Copyright    2021 University of Chinese Academy of Sciences (author: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 
 import math
 from typing import Dict, List, Optional, Tuple
 
 import k2
 import torch
 import torch.nn as nn
 from subsampling import Conv2dSubsampling, VggSubsampling
 
-from icefall.utils import get_texts
 from torch.nn.utils.rnn import pad_sequence
 
 # Note: TorchScript requires Dict/List/etc. to be fully typed.
@@ -72,8 +83,8 @@ class Transformer(nn.Module):
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
-        # self.encoder_embed converts the input of shape [N, T, num_classes]
-        # to the shape [N, T//subsampling_factor, d_model].
+        # self.encoder_embed converts the input of shape (N, T, num_classes)
+        # to the shape (N, T//subsampling_factor, d_model).
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_classes -> d_model
@@ -103,10 +114,15 @@ class Transformer(nn.Module):
             norm=encoder_norm,
         )
 
-        self.encoder_output_layer = nn.Linear(d_model, num_classes)
+        # TODO(fangjun): remove dropout
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )
 
         if num_decoder_layers > 0:
-            self.decoder_num_class = self.num_classes
+            self.decoder_num_class = (
+                self.num_classes
+            )  # bpe model already has sos/eos symbol
 
             self.decoder_embed = nn.Embedding(
                 num_embeddings=self.decoder_num_class, embedding_dim=d_model
@@ -146,7 +162,7 @@ class Transformer(nn.Module):
         """
         Args:
           x:
-            The input tensor. Its shape is [N, T, C].
+            The input tensor. Its shape is (N, T, C).
           supervision:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -155,17 +171,17 @@ class Transformer(nn.Module):
 
         Returns:
           Return a tuple containing 3 tensors:
-            - CTC output for ctc decoding. Its shape is [N, T, C]
-            - Encoder output with shape [T, N, C]. It can be used as key and
+            - CTC output for ctc decoding. Its shape is (N, T, C)
+            - Encoder output with shape (T, N, C). It can be used as key and
              value for the decoder.
            - Encoder output padding mask. It can be used as
-             memory_key_padding_mask for the decoder. Its shape is [N, T].
+             memory_key_padding_mask for the decoder. Its shape is (N, T).
              It is None if `supervision` is None.
        """
        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
        encoder_memory, memory_key_padding_mask = self.run_encoder(
            x, supervision
        )
@@ -179,7 +195,7 @@ class Transformer(nn.Module):
 
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -190,8 +206,8 @@ class Transformer(nn.Module):
             padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
-            - The encoder output, with shape [T, N, C]
-            - encoder padding mask, with shape [N, T].
+            - The encoder output, with shape (T, N, C)
+            - encoder padding mask, with shape (N, T).
              The mask is None if `supervisions` is None.
              It is used as memory key padding mask in the decoder.
        """
@@ -209,11 +225,11 @@ class Transformer(nn.Module):
         Args:
           x:
             The output tensor from the transformer encoder.
-            Its shape is [T, N, C]
+            Its shape is (T, N, C)
 
         Returns:
           Return a tensor that can be used for CTC decoding.
-          Its shape is [N, T, C]
+          Its shape is (N, T, C)
         """
         x = self.encoder_output_layer(x)
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -231,7 +247,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -296,7 +312,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -312,6 +328,7 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
        # extracted as a separate function.
+
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
         ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
@@ -329,6 +346,9 @@ class Transformer(nn.Module):
         )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
 
         tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
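The last hunk works around sos_id == eos_id: every row of ys_in_pad starts with sos_id, so a mask computed by comparing against eos_id would wrongly mask the first position. A minimal sketch of the effect, assuming decoder_padding_mask marks positions equal to ignore_id as in the recipe:

import torch

sos_id = eos_id = 1
ys_in_pad = torch.tensor(
    [[1, 5, 7, 1],   # "1" is sos at column 0, eos-padding afterwards
     [1, 6, 1, 1]]
)

tgt_key_padding_mask = ys_in_pad == eos_id  # True where padding
# Without the fix, column 0 (the sos symbol) is masked too:
#   tensor([[ True, False, False,  True],
#           [ True, False,  True,  True]])
tgt_key_padding_mask[:, 0] = False          # sos must stay visible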
@@ -634,13 +654,13 @@ class PositionalEncoding(nn.Module):
     def extend_pe(self, x: torch.Tensor) -> None:
         """Extend the time t in the positional encoding if required.
 
-        The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
-        is [N, T, d_model]. If T > T1, then we change the shape of self.pe
-        to [N, T, d_model]. Otherwise, nothing is done.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (N, T, d_model). Otherwise, nothing is done.
 
         Args:
           x:
-            It is a tensor of shape [N, T, C].
+            It is a tensor of shape (N, T, C).
         Returns:
           Return None.
         """
@@ -658,7 +678,7 @@ class PositionalEncoding(nn.Module):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0)
-        # Now pe is of shape [1, T, d_model], where T is x.size(1)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -667,10 +687,10 @@ class PositionalEncoding(nn.Module):
 
         Args:
           x:
-            Its shape is [N, T, C]
+            Its shape is (N, T, C)
 
         Returns:
-          Return a tensor of shape [N, T, C]
+          Return a tensor of shape (N, T, C)
         """
         self.extend_pe(x)
         x = x * self.xscale + self.pe[:, : x.size(1), :]
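For reference, a compact standalone sketch of the sinusoidal table that extend_pe builds, matching the (1, T, d_model) layout described above (an illustration, not the class itself):

import math
import torch

def make_pe(T: int, d_model: int) -> torch.Tensor:
    position = torch.arange(0, T, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / d_model)
    )
    pe = torch.zeros(T, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
    return pe.unsqueeze(0)  # (1, T, d_model), as in extend_pe

pe = make_pe(T=100, d_model=512)
assert pe.shape == (1, 100, 512)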
@@ -766,7 +786,8 @@ class Noam(object):
 
 class LabelSmoothingLoss(nn.Module):
     """
-    Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w)
+    Label-smoothing loss. KL-divergence between
+    q_{smoothed ground truth prob.}(w)
     and p_{prob. computed by model}(w) is minimized.
     Modified from
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa
@@ -851,7 +872,8 @@ def encoder_padding_mask(
         frames, before subsampling)
 
     Returns:
-        Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices.
+        Tensor: Mask tensor of dimension (batch_size, input_length),
+        True denote the masked indices.
     """
     if supervisions is None:
         return None
@@ -8,8 +8,8 @@ for LM training with the help of a lexicon.
 If the lexicon contains phones, the resulting LM will be a phone LM; If the
 lexicon contains word pieces, the resulting LM will be a word piece LM.
 
-If a word has multiple pronunciations, the one that appears last in the lexicon
-is used.
+If a word has multiple pronunciations, the one that appears first in the lexicon
+is kept; others are removed.
 
 If the input transcript is:
 
@@ -20,8 +20,8 @@ If the input transcript is:
 and if the lexicon is
 
     <UNK> SPN
-    hello h e l l o
     hello h e l l o 2
+    hello h e l l o
     world w o r l d
     zoo z o o
 
@@ -32,10 +32,11 @@ Then the output is
     SPN z o o w o r l d SPN
 """
 
-from pathlib import Path
-from typing import Dict
-
 import argparse
+from pathlib import Path
+from typing import Dict, List
 
+from generate_unique_lexicon import filter_multiple_pronunications
+
 from icefall.lexicon import read_lexicon
 
@@ -57,7 +58,9 @@ def get_args():
     return parser.parse_args()
 
 
-def process_line(lexicon: Dict[str, str], line: str, oov_token: str) -> None:
+def process_line(
+    lexicon: Dict[str, List[str]], line: str, oov_token: str
+) -> None:
     """
     Args:
       lexicon:
@@ -86,7 +89,11 @@ def main():
     assert Path(args.transcript).is_file()
     assert len(args.oov) > 0
 
-    lexicon = dict(read_lexicon(args.lexicon))
+    # Only the first pronunciation of a word is kept
+    lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
+
+    lexicon = dict(lexicon)
 
     assert args.oov in lexicon
 
     oov_token = lexicon[args.oov]
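A sketch of what process_line does with the lexicon dict built in main; the helper name and exact output formatting are hypothetical, but the word-by-word lookup with an OOV fallback follows the shown logic:

from typing import Dict, List

def tokens_for_line(
    lexicon: Dict[str, List[str]], line: str, oov_token: List[str]
) -> str:
    # Look up each word; unknown words fall back to the OOV pronunciation.
    out = []
    for word in line.strip().split():
        out.extend(lexicon.get(word, oov_token))
    return " ".join(out)

lexicon = {"hello": ["h", "e", "l", "l", "o"], "zoo": ["z", "o", "o"]}
print(tokens_for_line(lexicon, "hello zoo foo", oov_token=["SPN"]))
# -> "h e l l o z o o SPN"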
egs/librispeech/ASR/local/generate_unique_lexicon.py (new executable file, 100 lines)
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file takes as input a lexicon.txt and output a new lexicon,
+in which each word has a unique pronunciation.
+
+The way to do this is to keep only the first pronunciation of a word
+in lexicon.txt.
+"""
+
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+from icefall.lexicon import read_lexicon, write_lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain a file lexicon.txt.
+        This file will generate a new file uniq_lexicon.txt
+        in it.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def filter_multiple_pronunications(
+    lexicon: List[Tuple[str, List[str]]]
+) -> List[Tuple[str, List[str]]]:
+    """Remove multiple pronunciations of words from a lexicon.
+
+    If a word has more than one pronunciation in the lexicon, only
+    the first one is kept, while other pronunciations are removed
+    from the lexicon.
+
+    Args:
+      lexicon:
+        The input lexicon, containing a list of (word, [p1, p2, ..., pn]),
+        where "p1, p2, ..., pn" are the pronunciations of the "word".
+    Returns:
+      Return a new lexicon where each word has a unique pronunciation.
+    """
+    seen = set()
+    ans = []
+
+    for word, tokens in lexicon:
+        if word in seen:
+            continue
+        seen.add(word)
+        ans.append((word, tokens))
+    return ans
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    lexicon_filename = lang_dir / "lexicon.txt"
+
+    in_lexicon = read_lexicon(lexicon_filename)
+
+    out_lexicon = filter_multiple_pronunications(in_lexicon)
+
+    write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon)
+
+    logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}")
+    logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
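Since filter_multiple_pronunications keeps only the first pronunciation per word, a quick usage example:

from generate_unique_lexicon import filter_multiple_pronunications

lexicon = [
    ("hello", ["h", "e", "l", "l", "o", "2"]),
    ("hello", ["h", "e", "l", "l", "o"]),
    ("world", ["w", "o", "r", "l", "d"]),
]
print(filter_multiple_pronunications(lexicon))
# [('hello', ['h', 'e', 'l', 'l', 'o', '2']),
#  ('world', ['w', 'o', 'r', 'l', 'd'])]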
@@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 
 5. Generate L_disambig.pt, in k2 format.
 """
+import argparse
 import math
 from collections import defaultdict
 from pathlib import Path
@@ -42,10 +43,37 @@ import k2
 import torch
 
 from icefall.lexicon import read_lexicon, write_lexicon
+from icefall.utils import str2bool
 
 Lexicon = List[Tuple[str, List[str]]]
 
 
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain a file lexicon.txt.
+        Generated files by this script are saved into this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--debug",
+        type=str2bool,
+        default=False,
+        help="""True for debugging, which will generate
+        a visualization of the lexicon FST.
+
+        Caution: If your lexicon contains hundreds of thousands
+        of lines, please set it to False!
+        """,
+    )
+
+    return parser.parse_args()
+
+
 def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
     """Write a symbol to ID mapping to a file.
 
@@ -315,8 +343,9 @@ def lexicon_to_fst(
 
 
 def main():
-    out_dir = Path("data/lang_phone")
-    lexicon_filename = out_dir / "lexicon.txt"
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+    lexicon_filename = lang_dir / "lexicon.txt"
     sil_token = "SIL"
     sil_prob = 0.5
 
@@ -344,9 +373,9 @@ def main():
     token2id = generate_id_map(tokens)
     word2id = generate_id_map(words)
 
-    write_mapping(out_dir / "tokens.txt", token2id)
-    write_mapping(out_dir / "words.txt", word2id)
-    write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+    write_mapping(lang_dir / "tokens.txt", token2id)
+    write_mapping(lang_dir / "words.txt", word2id)
+    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
 
     L = lexicon_to_fst(
         lexicon,
@@ -364,17 +393,20 @@ def main():
         sil_prob=sil_prob,
         need_self_loops=True,
     )
-    torch.save(L.as_dict(), out_dir / "L.pt")
-    torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+    torch.save(L.as_dict(), lang_dir / "L.pt")
+    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
 
-    if False:
-        # Just for debugging, will remove it
-        L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
-        L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
-        L_disambig.labels_sym = L.labels_sym
-        L_disambig.aux_labels_sym = L.aux_labels_sym
-        L.draw(out_dir / "L.png", title="L")
-        L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+    if args.debug:
+        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+        L.labels_sym = labels_sym
+        L.aux_labels_sym = aux_labels_sym
+        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
+
+        L_disambig.labels_sym = labels_sym
+        L_disambig.aux_labels_sym = aux_labels_sym
+        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
 
 
 if __name__ == "__main__":
@@ -49,6 +49,8 @@ from prepare_lang import (
     write_mapping,
 )
 
+from icefall.utils import str2bool
+
 
 def lexicon_to_fst_no_sil(
     lexicon: Lexicon,
@@ -169,6 +171,20 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--debug",
+        type=str2bool,
+        default=False,
+        help="""True for debugging, which will generate
+        a visualization of the lexicon FST.
+
+        Caution: If your lexicon contains hundreds of thousands
+        of lines, please set it to False!
+
+        See "test/test_bpe_lexicon.py" for usage.
+        """,
+    )
+
     return parser.parse_args()
 
 
@@ -221,6 +237,18 @@ def main():
     torch.save(L.as_dict(), lang_dir / "L.pt")
     torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
 
+    if args.debug:
+        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+        L.labels_sym = labels_sym
+        L.aux_labels_sym = aux_labels_sym
+        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
+
+        L_disambig.labels_sym = labels_sym
+        L_disambig.aux_labels_sym = aux_labels_sym
+        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
+
 
 if __name__ == "__main__":
     main()
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 # You can install sentencepiece via:
 #
 #  pip install sentencepiece
@@ -37,10 +38,17 @@ def get_args():
         "--lang-dir",
         type=str,
         help="""Input and output directory.
-        It should contain the training corpus: train.txt.
+        It should contain the training corpus: transcript_words.txt.
         The generated bpe.model is saved to this directory.
         """,
     )
 
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="Training transcript.",
+    )
+
     parser.add_argument(
         "--vocab-size",
         type=int,
@@ -58,7 +66,7 @@ def main():
     model_type = "unigram"
 
     model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
-    train_text = f"{lang_dir}/train.txt"
+    train_text = args.transcript
     character_coverage = 1.0
     input_sentence_size = 100000000
 
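train_bpe_model.py now takes the transcript path explicitly. A hedged sketch of the underlying sentencepiece call using the variables shown in the hunk; the keyword arguments follow the sentencepiece Python API, and the paths are illustrative:

import sentencepiece as spm

lang_dir = "data/lang_bpe_500"
vocab_size = 500
model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"

spm.SentencePieceTrainer.train(
    input=f"{lang_dir}/transcript_words.txt",  # --transcript from the diff
    vocab_size=vocab_size,
    model_type=model_type,
    model_prefix=model_prefix,
    character_coverage=1.0,
    input_sentence_size=100000000,
)
# Produces {model_prefix}.model and {model_prefix}.vocab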
@@ -40,9 +40,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  5000
-  2000
-  1000
+  # 5000
+  # 2000
+  # 1000
  500
 )
 
@@ -116,14 +116,15 @@
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
-  mkdir -p data/lang_phone
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
 
   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
     cat - $dl_dir/lm/librispeech-lexicon.txt |
-    sort | uniq > data/lang_phone/lexicon.txt
+    sort | uniq > $lang_dir/lexicon.txt
 
-  if [ ! -f data/lang_phone/L_disambig.pt ]; then
-    ./local/prepare_lang.py
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
   fi
 fi
 
@@ -138,7 +139,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   # so that the two can share G.pt later.
   cp data/lang_phone/words.txt $lang_dir
 
-  if [ ! -f $lang_dir/train.txt ]; then
+  if [ ! -f $lang_dir/transcript_words.txt ]; then
     log "Generate data for BPE training"
     files=$(
       find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
@@ -147,12 +148,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     )
     for f in ${files[@]}; do
       cat $f | cut -d " " -f 2-
-    done > $lang_dir/train.txt
+    done > $lang_dir/transcript_words.txt
   fi
 
   ./local/train_bpe_model.py \
     --lang-dir $lang_dir \
-    --vocab-size $vocab_size
+    --vocab-size $vocab_size \
+    --transcript $lang_dir/transcript_words.txt
 
   if [ ! -f $lang_dir/L_disambig.pt ]; then
     ./local/prepare_lang_bpe.py --lang-dir $lang_dir
@@ -166,18 +168,18 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
 
-    if [ ! -f $lang_dir/corpus.txt ]; then
-      ./local/convert_transcript_to_corpus.py \
-        --lexicon data/lang_bpe/lexicon.txt \
-        --transcript data/lang_bpe/train.txt \
+    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+      ./local/convert_transcript_words_to_tokens.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/transcript_words.txt \
         --oov "<UNK>" \
-        > $lang_dir/corpus.txt
+        > $lang_dir/transcript_tokens.txt
     fi
 
     if [ ! -f $lang_dir/P.arpa ]; then
       ./shared/make_kn_lm.py \
         -ngram-order 2 \
-        -text $lang_dir/corpus.txt \
+        -text $lang_dir/transcript_tokens.txt \
         -lm $lang_dir/P.arpa
     fi
 
@@ -226,4 +228,4 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   done
 fi
 
-cd data && ln -sfv lang_bpe_5000 lang_bpe
+cd data && ln -sfv lang_bpe_500 lang_bpe
@@ -34,14 +34,10 @@ class BpeCtcTrainingGraphCompiler(object):
         """
         Args:
           lang_dir:
-            This directory is expected to contain the following files::
+            This directory is expected to contain the following files:
 
             - bpe.model
             - words.txt
 
-            The above files are produced by the script `prepare.sh`. You
-            should have run that before running the training code.
-
           device:
             It indicates CPU or CUDA.
           sos_token:
icefall/bpe_mmi_graph_compiler.py (deleted by this commit; its former contents follow)

@@ -1,178 +0,0 @@
import logging
from pathlib import Path
from typing import List, Tuple, Union

import k2
import sentencepiece as spm
import torch

from icefall.lexicon import Lexicon


class BpeMmiTrainingGraphCompiler(object):
    def __init__(
        self,
        lang_dir: Path,
        device: Union[str, torch.device] = "cpu",
        sos_token: str = "<sos/eos>",
        eos_token: str = "<sos/eos>",
    ) -> None:
        """
        Args:
          lang_dir:
            Path to the lang directory. It is expected to contain the
            following files::

                - tokens.txt
                - words.txt
                - bpe.model
                - P.fst.txt

            The above files are generated by the script `prepare.sh`. You
            should have run it before running the training code.
          device:
            It indicates CPU or CUDA.
          sos_token:
            The word piece that represents sos.
          eos_token:
            The word piece that represents eos.
        """
        self.lang_dir = Path(lang_dir)
        self.lexicon = Lexicon(lang_dir)
        self.device = device
        self.load_sentence_piece_model()
        self.build_ctc_topo_P()

        self.sos_id = self.sp.piece_to_id(sos_token)
        self.eos_id = self.sp.piece_to_id(eos_token)

        assert self.sos_id != self.sp.unk_id()
        assert self.eos_id != self.sp.unk_id()

    def load_sentence_piece_model(self) -> None:
        """Load the pre-trained sentencepiece model
        from self.lang_dir/bpe.model.
        """
        model_file = self.lang_dir / "bpe.model"
        sp = spm.SentencePieceProcessor()
        sp.load(str(model_file))
        self.sp = sp

    def build_ctc_topo_P(self):
        """Build ctc_topo_P, the composition result of
        ctc_topo and P, where P is a pre-trained bigram
        word piece LM.
        """
        # Note: there is no need to save a pre-compiled P and ctc_topo
        # as it is very fast to generate them.
        logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
        with open(self.lang_dir / "P.fst.txt") as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label 0 (i.e., <eps>).
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        first_token_disambig_id = self.lexicon.token_table["#0"]

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION: The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_token_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        P = P.to(self.device)
        # Add epsilon self-loops to P because we want the
        # following operation "k2.intersect" to run on GPU.
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        max_token_id = max(self.lexicon.tokens)
        logging.info(
            f"Building modified ctc_topo. max_token_id: {max_token_id}"
        )
        # CAUTION: We have to use a modified version of the CTC topo.
        # Otherwise, the resulting ctc_topo_P is so large that it gets
        # stuck in k2.intersect_dense_pruned() or it gets OOM in
        # k2.intersect_dense()
        ctc_topo = k2.ctc_topo(max_token_id, modified=True, device=self.device)

        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

        logging.info("Building ctc_topo_P")
        ctc_topo_P = k2.intersect(
            ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
        ).invert()

        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)

    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of piece IDs.

        Args:
          texts:
            A list of transcripts. Within a transcript words are
            separated by spaces. An example input is::

                ['HELLO ICEFALL', 'HELLO k2']
        Returns:
          Return a list-of-list of piece IDs.
        """
        return self.sp.encode(texts, out_type=int)

    def compile(
        self, texts: List[str], replicate_den: bool = True
    ) -> Tuple[k2.Fsa, k2.Fsa]:
        """Create numerator and denominator graphs from transcripts.

        Args:
          texts:
            A list of transcripts. Within a transcript words are
            separated by spaces. An example input is::

                ["HELLO icefall", "HALLO WELT"]

          replicate_den:
            If True, the returned den_graph is replicated to match the number
            of FSAs in the returned num_graph; if False, the returned
            den_graph contains only a single FSA.
        Returns:
          A tuple (num_graphs, den_graphs), where

            - `num_graphs` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graphs` is the denominator graph. It is an FsaVec with the
              same shape as `num_graphs` if replicate_den is True;
              otherwise, it is an FsaVec containing only a single FSA.
        """
        token_ids = self.texts_to_ids(texts)
        token_fsas = k2.linear_fsa(token_ids, device=self.device)

        token_fsas_with_self_loops = k2.add_epsilon_self_loops(token_fsas)

        # NOTE: Use treat_epsilons_specially=False so that k2.compose
        # can be run on GPU
        num_graphs = k2.compose(
            self.ctc_topo_P,
            token_fsas_with_self_loops,
            treat_epsilons_specially=False,
        )
        # num_graphs may not be connected and
        # not be topologically sorted after k2.compose
        num_graphs = k2.connect(num_graphs)
        num_graphs = k2.top_sort(num_graphs)

        ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P.detach()])
        if replicate_den:
            indexes = torch.zeros(
                len(texts), dtype=torch.int32, device=self.device
            )
            den_graphs = k2.index_fsa(ctc_topo_P_vec, indexes)
        else:
            den_graphs = ctc_topo_P_vec

        return num_graphs, den_graphs
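The `P.labels[P.labels >= first_token_disambig_id] = 0` statement above survives unchanged in the new MmiTrainingGraphCompiler below, so a tiny illustration may help. A minimal sketch (not part of the commit) with a made-up token table in which IDs 1 and 2 are real tokens and ID 3 stands in for #0:

import k2

# Toy acceptor standing in for P: state 1 is the back-off state, and the
# arcs entering it carry the disambig label (3 here, playing the role of #0).
s = """
0 1 3 0.0
0 2 2 0.0
1 0 1 -0.5
2 1 3 0.0
2 3 -1 0.0
3
"""
P = k2.Fsa.from_str(s)
first_token_disambig_id = 3

# Same statement as in build_ctc_topo_P: relabel the back-off arcs as
# epsilon (0) so that remove_epsilon can drop them.
P.labels[P.labels >= first_token_disambig_id] = 0
P = k2.remove_epsilon(P)
P = k2.arc_sort(P)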
icefall/lexicon.py

@@ -84,6 +84,69 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
             f.write(f"{word} {' '.join(tokens)}\n")


+def convert_lexicon_to_ragged(
+    filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
+) -> k2.RaggedTensor:
+    """Read a lexicon and convert it to a ragged tensor.
+
+    The ragged tensor has two axes: [word][token].
+
+    Caution:
+      We assume that each word has a unique pronunciation.
+
+    Args:
+      filename:
+        Filename of the lexicon. It has a format that can be read
+        by :func:`read_lexicon`.
+      word_table:
+        The word symbol table.
+      token_table:
+        The token symbol table.
+    Returns:
+      A k2 ragged tensor with two axes [word][token].
+    """
+    disambig_id = word_table["#0"]
+    # We reuse the same words.txt from the phone based lexicon
+    # so that we can share the same G.fst. Here, we have to
+    # exclude some words present only in the phone based lexicon.
+    excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
+
+    # epsilon is not a word, but it occupies a position
+    #
+    row_splits = [0]
+    token_ids_list = []
+
+    lexicon_tmp = read_lexicon(filename)
+    lexicon = dict(lexicon_tmp)
+    if len(lexicon_tmp) != len(lexicon):
+        raise RuntimeError(
+            "It's assumed that each word has a unique pronunciation"
+        )
+
+    for i in range(disambig_id):
+        w = word_table[i]
+        if w in excluded_words:
+            row_splits.append(row_splits[-1])
+            continue
+        tokens = lexicon[w]
+        token_ids = [token_table[k] for k in tokens]
+
+        row_splits.append(row_splits[-1] + len(token_ids))
+        token_ids_list.extend(token_ids)
+
+    cached_tot_size = row_splits[-1]
+    row_splits = torch.tensor(row_splits, dtype=torch.int32)
+
+    shape = k2.ragged.create_ragged_shape2(
+        row_splits,
+        None,
+        cached_tot_size,
+    )
+    values = torch.tensor(token_ids_list, dtype=torch.int32)
+
+    return k2.RaggedTensor(shape, values)


 class Lexicon(object):
     """Phone based lexicon."""
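Because the row_splits/values encoding is easy to get wrong, here is a tiny self-contained illustration (not part of the commit; the IDs are made up) of the shape this function builds; excluded words simply get empty rows:

import k2
import torch

# Rows: 0 -> <eps> (excluded, empty), 1 -> a word with tokens [3, 5],
# 2 -> !SIL (excluded, empty), 3 -> a word with tokens [7, 1, 2].
row_splits = torch.tensor([0, 0, 2, 2, 5], dtype=torch.int32)
values = torch.tensor([3, 5, 7, 1, 2], dtype=torch.int32)

shape = k2.ragged.create_ragged_shape2(row_splits, None, 5)
ragged = k2.RaggedTensor(shape, values)
print(ragged)  # two axes [word][token]; empty rows for the excluded words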
@@ -96,12 +159,10 @@ class Lexicon(object):
         Args:
           lang_dir:
             Path to the lang directory. It is expected to contain the following
-            files::
+            files:

                 - tokens.txt
                 - words.txt
                 - L.pt

             The above files are produced by the script `prepare.sh`. You
             should have run that before running the training code.
           disambig_pattern:
@@ -121,7 +182,7 @@ class Lexicon(object):
             torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")

         # We save L_inv instead of L because it will be used to intersect with
-        # transcript, both of whose labels are word IDs.
+        # transcript FSAs, both of whose labels are word IDs.
         self.L_inv = L_inv
         self.disambig_pattern = disambig_pattern

@@ -144,69 +205,66 @@ class Lexicon(object):
         return ans


-class BpeLexicon(Lexicon):
+class UniqLexicon(Lexicon):
     def __init__(
         self,
         lang_dir: Path,
+        uniq_filename: str = "uniq_lexicon.txt",
         disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
+        """
+        Refer to the help information in Lexicon.__init__.
+
+        uniq_filename: It is assumed to be inside the given `lang_dir`.
+
+        Each word in the lexicon is assumed to have a unique pronunciation.
+        """
+        lang_dir = Path(lang_dir)
         super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)

-        self.ragged_lexicon = self.convert_lexicon_to_ragged(
-            lang_dir / "lexicon.txt"
+        self.ragged_lexicon = convert_lexicon_to_ragged(
+            filename=lang_dir / uniq_filename,
+            word_table=self.word_table,
+            token_table=self.token_table,
         )
+        # TODO: should we move it to a certain device?

-    def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
-        """Read a BPE lexicon from file and convert it to a
-        k2 ragged tensor.
-
-        Args:
-          filename:
-            Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt
-        Returns:
-          A k2 ragged tensor with two axes [word_id]
-        """
-        disambig_id = self.word_table["#0"]
-        # We reuse the same words.txt from the phone based lexicon
-        # so that we can share the same G.fst. Here, we have to
-        # exclude some words present only in the phone based lexicon.
-        excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
-
-        # epsilon is not a word, but it occupies on position
-        #
-        row_splits = [0]
-        token_ids = []
-
-        lexicon = read_lexicon(filename)
-        lexicon = dict(lexicon)
-
-        for i in range(disambig_id):
-            w = self.word_table[i]
-            if w in excluded_words:
-                row_splits.append(row_splits[-1])
-                continue
-            pieces = lexicon[w]
-            piece_ids = [self.token_table[k] for k in pieces]
-
-            row_splits.append(row_splits[-1] + len(piece_ids))
-            token_ids.extend(piece_ids)
-
-        cached_tot_size = row_splits[-1]
-        row_splits = torch.tensor(row_splits, dtype=torch.int32)
-
-        shape = k2.ragged.create_ragged_shape2(
-            row_splits=row_splits, cached_tot_size=cached_tot_size
-        )
-        values = torch.tensor(token_ids, dtype=torch.int32)
-
-        return k2.RaggedTensor(shape, values)
-
-    def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
-        """Convert a list of words to a ragged tensor contained
-        word piece IDs.
+    def texts_to_token_ids(
+        self, texts: List[str], oov: str = "<UNK>"
+    ) -> k2.RaggedTensor:
+        """
+        Args:
+          texts:
+            A list of transcripts. Each transcript contains space(s)
+            separated words. An example `texts` is::
+
+                ['HELLO k2', 'HELLO icefall']
+          oov:
+            The OOV word. If a word in `texts` is not in the lexicon, it is
+            replaced with `oov`.
+        Returns:
+          Return a ragged int tensor with 2 axes [utterance][token_id]
+        """
+        oov_id = self.word_table[oov]
+
+        word_ids_list = []
+        for text in texts:
+            word_ids = []
+            for word in text.split():
+                if word in self.word_table:
+                    word_ids.append(self.word_table[word])
+                else:
+                    word_ids.append(oov_id)
+            word_ids_list.append(word_ids)
+        ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
+        ans = self.ragged_lexicon.index(ragged_indexes)
+        ans = ans.remove_axis(ans.num_axes - 2)
+        return ans
+
+    def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
+        """Convert a list of words to a ragged tensor containing token IDs.
+
         We assume there are no OOVs in "words".
         """
         word_ids = [self.word_table[w] for w in words]
         word_ids = torch.tensor(word_ids, dtype=torch.int32)
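Typical usage of the new UniqLexicon (a sketch, not part of the commit, assuming a lang dir whose lexicon.txt maps every word to exactly one token sequence; the updated test/test_lexicon.py below exercises the same calls):

from icefall.lexicon import UniqLexicon

lexicon = UniqLexicon("data/lang_bpe", uniq_filename="lexicon.txt")
token_ids = lexicon.texts_to_token_ids(["HELLO ICEFALL", "HELLO k2"])
print(token_ids.tolist())  # one list of token IDs per utterance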
216 icefall/mmi_graph_compiler.py Normal file

@@ -0,0 +1,216 @@
import logging
from pathlib import Path
from typing import Iterable, List, Tuple, Union

import k2
import torch

from icefall.lexicon import UniqLexicon


class MmiTrainingGraphCompiler(object):
    def __init__(
        self,
        lang_dir: Path,
        uniq_filename: str = "uniq_lexicon.txt",
        device: Union[str, torch.device] = "cpu",
        oov: str = "<UNK>",
    ):
        """
        Args:
          lang_dir:
            Path to the lang directory. It is expected to contain the
            following files::

                - tokens.txt
                - words.txt
                - P.fst.txt

            The above files are generated by the script `prepare.sh`. You
            should have run it before running the training code.
          uniq_filename:
            File name of the lexicon in which every word has exactly one
            pronunciation. We assume this file is inside the given `lang_dir`.
          device:
            It indicates CPU or CUDA.
          oov:
            Out of vocabulary word. When a word in the transcript
            does not exist in the lexicon, it is replaced with `oov`.
        """
        self.lang_dir = Path(lang_dir)
        self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename)
        self.device = torch.device(device)

        self.L_inv = self.lexicon.L_inv.to(self.device)

        self.oov_id = self.lexicon.word_table[oov]

        self.build_ctc_topo_P()

    def build_ctc_topo_P(self):
        """Build ctc_topo_P, the composition result of
        ctc_topo and P, where P is a pre-trained bigram
        word piece LM.
        """
        # Note: there is no need to save a pre-compiled P and ctc_topo
        # as it is very fast to generate them.
        logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
        with open(self.lang_dir / "P.fst.txt") as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label 0 (i.e., <eps>).
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        first_token_disambig_id = self.lexicon.token_table["#0"]

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION: The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_token_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        P = P.to(self.device)
        # Add epsilon self-loops to P because we want the
        # following operation "k2.intersect" to run on GPU.
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        max_token_id = max(self.lexicon.tokens)
        logging.info(
            f"Building ctc_topo (modified=False). max_token_id: {max_token_id}"
        )
        ctc_topo = k2.ctc_topo(max_token_id, modified=False, device=self.device)

        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

        logging.info("Building ctc_topo_P")
        ctc_topo_P = k2.intersect(
            ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
        ).invert()

        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)

    def compile(
        self, texts: Iterable[str], replicate_den: bool = True
    ) -> Tuple[k2.Fsa, k2.Fsa]:
        """Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces. An example `texts` is given below::

                ["Hello icefall", "LF-MMI training with icefall using k2"]

          replicate_den:
            If True, the returned den_graph is replicated to match the number
            of FSAs in the returned num_graph; if False, the returned
            den_graph contains only a single FSA.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec
              with the same shape as `num_graph` if replicate_den is
              True; otherwise, it is an FsaVec containing only a single FSA.
        """
        transcript_fsa = self.build_transcript_fsa(texts)

        # remove word IDs from transcript_fsa since they are not needed
        del transcript_fsa.aux_labels
        # NOTE: You can comment out the above statement
        # if you want to run test/test_mmi_graph_compiler.py

        transcript_fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            transcript_fsa
        )

        transcript_fsa_with_self_loops = k2.arc_sort(
            transcript_fsa_with_self_loops
        )

        num = k2.compose(
            self.ctc_topo_P,
            transcript_fsa_with_self_loops,
            treat_epsilons_specially=False,
        )

        # CAUTION: Due to the presence of P,
        # the resulting `num` may not be connected
        num = k2.connect(num)

        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
        if replicate_den:
            indexes = torch.zeros(
                len(texts), dtype=torch.int32, device=self.device
            )
            den = k2.index_fsa(ctc_topo_P_vec, indexes)
        else:
            den = ctc_topo_P_vec

        return num, den

    def build_transcript_fsa(self, texts: List[str]) -> k2.Fsa:
        """Convert transcripts to an FsaVec with the help of a lexicon
        and word symbol table.

        Args:
          texts:
            Each element is a transcript containing words separated by
            space(s). For instance, it may be 'HELLO icefall', which
            contains two words.

        Returns:
          Return an FST (FsaVec) corresponding to the transcript.
          Its `labels` are token IDs and `aux_labels` are word IDs.
        """
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split():
                if word in self.lexicon.word_table:
                    word_ids.append(self.lexicon.word_table[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)

        fsa = k2.linear_fsa(word_ids_list, self.device)
        fsa = k2.add_epsilon_self_loops(fsa)

        # The reason to use `invert_()` at the end is as follows:
        #
        # (1) The `labels` of L_inv is word IDs and `aux_labels` is token IDs
        # (2) `fsa.labels` is word IDs
        # (3) after intersection, the `labels` is still word IDs
        # (4) after `invert_()`, the `labels` is token IDs
        #     and `aux_labels` is word IDs
        transcript_fsa = k2.intersect(
            self.L_inv, fsa, treat_epsilons_specially=False
        ).invert_()
        transcript_fsa = k2.arc_sort(transcript_fsa)
        return transcript_fsa

    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of piece IDs.

        Args:
          texts:
            It is a list of strings. Each string consists of space(s)
            separated words. An example containing two strings is given
            below::

                ['HELLO ICEFALL', 'HELLO k2']

            We assume it contains no OOVs. Otherwise, it will raise an
            exception.
        Returns:
          Return a list-of-list of token IDs.
        """
        return self.lexicon.texts_to_token_ids(texts).tolist()
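For orientation (a sketch, not part of the commit): the num/den graphs returned by compile() are typically consumed roughly as below to form an LF-MMI loss. The tensor names, the output_beam value, and the function itself are assumptions, not icefall's actual loss wrapper:

import k2


def mmi_loss_sketch(graph_compiler, nnet_output, supervision_segments, texts):
    # nnet_output: (N, T, C) log-probs; supervision_segments: (N, 3) rows of
    # [utterance_index, start_frame, num_frames]; texts: the transcripts.
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=10.0)

    num_scores = num_lats.get_tot_scores(
        log_semiring=True, use_double_scores=True
    )
    den_scores = den_lats.get_tot_scores(
        log_semiring=True, use_double_scores=True
    )

    # Maximize num - den, i.e., minimize den - num.
    return (den_scores - num_scores).sum()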
icefall/utils.py

@@ -19,14 +19,16 @@ import argparse
 import logging
 import os
+import subprocess
 import sys
 from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, TextIO, Tuple, Union
+from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union

 import k2
 import kaldialign
+import lhotse
 import torch
 import torch.distributed as dist

@@ -132,17 +134,82 @@ def setup_logger(
     logging.getLogger("").addHandler(console)


-def get_env_info():
-    """
-    TODO:
-    """
+def get_git_sha1():
+    git_commit = (
+        subprocess.run(
+            ["git", "rev-parse", "--short", "HEAD"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    dirty_commit = (
+        len(
+            subprocess.run(
+                ["git", "diff", "--shortstat"],
+                check=True,
+                stdout=subprocess.PIPE,
+            )
+            .stdout.decode()
+            .rstrip("\n")
+            .strip()
+        )
+        > 0
+    )
+    git_commit = (
+        git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+    )
+    return git_commit
+
+
+def get_git_date():
+    git_date = (
+        subprocess.run(
+            ["git", "log", "-1", "--format=%ad", "--date=local"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    return git_date
+
+
+def get_git_branch_name():
+    git_branch = (
+        subprocess.run(
+            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    return git_branch
+
+
+def get_env_info() -> Dict[str, Any]:
+    """Get the environment information."""
     return {
-        "k2-git-sha1": None,
-        "k2-version": None,
-        "lhotse-version": None,
-        "torch-version": None,
-        "icefall-sha1": None,
-        "icefall-version": None,
+        "k2-version": k2.version.__version__,
+        "k2-build-type": k2.version.__build_type__,
+        "k2-with-cuda": k2.with_cuda,
+        "k2-git-sha1": k2.version.__git_sha1__,
+        "k2-git-date": k2.version.__git_date__,
+        "lhotse-version": lhotse.__version__,
+        "torch-cuda-available": torch.cuda.is_available(),
+        "torch-cuda-version": torch.version.cuda,
+        "python-version": sys.version[:3],
+        "icefall-git-branch": get_git_branch_name(),
+        "icefall-git-sha1": get_git_sha1(),
+        "icefall-git-date": get_git_date(),
+        "icefall-path": str(Path(__file__).resolve().parent.parent),
+        "k2-path": str(Path(k2.__file__).resolve()),
+        "lhotse-path": str(Path(lhotse.__file__).resolve()),
     }
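A quick way to exercise the new function (a sketch; json formatting is just one option, with default=str guarding any value json cannot serialize directly):

import json

from icefall.utils import get_env_info

print(json.dumps(get_env_info(), indent=2, default=str))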
test/test_bpe_graph_compiler.py

@@ -19,20 +19,21 @@
 from pathlib import Path

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
-from icefall.lexicon import BpeLexicon
+from icefall.lexicon import UniqLexicon
+
+ICEFALL_DIR = Path(__file__).resolve().parent.parent


 def test():
-    lang_dir = Path("data/lang/bpe")
+    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe"
     if not lang_dir.is_dir():
         return
+    # TODO: generate data for testing

     compiler = BpeCtcTrainingGraphCompiler(lang_dir)
     ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
     compiler.compile(ids)

-    lexicon = BpeLexicon(lang_dir)
+    lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt")
     ids0 = lexicon.words_to_piece_ids(["HELLO"])
     assert ids[0] == ids0.values().tolist()
test/test_bpe_mmi_graph_compiler.py (deleted by this commit; its former contents follow)

@@ -1,30 +0,0 @@
#!/usr/bin/env python3

import copy
import logging
from pathlib import Path

import k2
import torch

from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler


def test_bpe_mmi_graph_compiler():
    lang_dir = Path("data/lang_bpe")
    if lang_dir.is_dir() is False:
        return
    device = torch.device("cpu")
    compiler = BpeMmiTrainingGraphCompiler(lang_dir, device=device)

    texts = ["HELLO WORLD", "MMI TRAINING"]

    num_graphs, den_graphs = compiler.compile(texts)
    num_graphs.labels_sym = compiler.lexicon.token_table
    num_graphs.aux_labels_sym = copy.deepcopy(compiler.lexicon.token_table)
    num_graphs.aux_labels_sym._id2sym[0] = "<eps>"
    num_graphs[0].draw("num_graphs_0.svg", title="HELLO WORLD")
    num_graphs[1].draw("num_graphs_1.svg", title="HELLO WORLD")
    print(den_graphs.shape)
    print(den_graphs[0].shape)
    print(den_graphs[0].num_arcs)
173 test/test_lexicon.py Normal file → Executable file

@@ -14,80 +14,135 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+You can run this file in one of the two ways:
+
+    (1) cd icefall; pytest test/test_lexicon.py
+    (2) cd icefall; ./test/test_lexicon.py
+"""


+import os
+import shutil
+import sys
 from pathlib import Path

 import k2
-import pytest
+import sentencepiece as spm
 import torch

-from icefall.lexicon import BpeLexicon, Lexicon
+from icefall.lexicon import UniqLexicon
+
+TMP_DIR = "/tmp/icefall-test-lexicon"
+USING_PYTEST = "pytest" in sys.modules
+ICEFALL_DIR = Path(__file__).resolve().parent.parent


-@pytest.fixture
-def lang_dir(tmp_path):
-    phone2id = """
-        <eps> 0
-        a 1
-        b 2
-        f 3
-        o 4
-        r 5
-        z 6
-        SPN 7
-        #0 8
-    """
-    word2id = """
-        <eps> 0
-        foo 1
-        bar 2
-        baz 3
-        <UNK> 4
-        #0 5
-    """
-
-    L = k2.Fsa.from_str(
-        """
-            0 0 7 4 0
-            0 7 -1 -1 0
-            0 1 3 1 0
-            0 3 2 2 0
-            0 5 2 3 0
-            1 2 4 0 0
-            2 0 4 0 0
-            3 4 1 0 0
-            4 0 5 0 0
-            5 6 1 0 0
-            6 0 6 0 0
-            7
-        """,
-        num_aux_labels=1,
-    )
-
-    with open(tmp_path / "tokens.txt", "w") as f:
-        f.write(phone2id)
-    with open(tmp_path / "words.txt", "w") as f:
-        f.write(word2id)
-
-    torch.save(L.as_dict(), tmp_path / "L.pt")
-
-    return tmp_path
+def generate_test_data():
+    Path(TMP_DIR).mkdir(exist_ok=True)
+    sentences = """
+        cat tac cat cat
+        at
+        tac at ta at at
+        at cat ct ct ta
+        cat cat cat cat
+        at at at at at at at
+    """
+
+    transcript = Path(TMP_DIR) / "transcript_words.txt"
+    with open(transcript, "w") as f:
+        for line in sentences.strip().split("\n"):
+            f.write(f"{line}\n")
+
+    words = """
+        <eps> 0
+        <UNK> 1
+        at 2
+        cat 3
+        ct 4
+        ta 5
+        tac 6
+        #0 7
+        <s> 8
+        </s> 9
+    """
+    word_txt = Path(TMP_DIR) / "words.txt"
+    with open(word_txt, "w") as f:
+        for line in words.strip().split("\n"):
+            f.write(f"{line}\n")
+
+    vocab_size = 8
+
+    os.system(
+        f"""
+        cd {ICEFALL_DIR}/egs/librispeech/ASR
+
+        ./local/train_bpe_model.py \
+          --lang-dir {TMP_DIR} \
+          --vocab-size {vocab_size} \
+          --transcript {transcript}
+
+        ./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 1
+        """
+    )
+
+
+def delete_test_data():
+    shutil.rmtree(TMP_DIR)


-def test_lexicon(lang_dir):
-    lexicon = Lexicon(lang_dir)
-    assert lexicon.tokens == list(range(1, 8))
+def uniq_lexicon_test():
+    lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="lexicon.txt")
+
+    # case 1: No OOV
+    texts = ["cat cat", "at ct", "at tac cat"]
+    token_ids = lexicon.texts_to_token_ids(texts)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(f"{TMP_DIR}/bpe.model")
+
+    expected_token_ids: List[List[int]] = sp.encode(texts, out_type=int)
+    assert token_ids.tolist() == expected_token_ids
+
+    # case 2: With OOV
+    texts = ["ca"]
+    token_ids = lexicon.texts_to_token_ids(texts)
+    expected_token_ids = sp.encode(texts, out_type=int)
+    assert token_ids.tolist() != expected_token_ids
+    # Note: sentencepiece breaks "ca" into "_ c a"
+    # But there is no word "ca" in the lexicon, so our
+    # implementation returns the id of "<UNK>"
+    print(token_ids, expected_token_ids)
+    assert token_ids.tolist() == [[sp.unk_id()]]
+
+    # case 3: With OOV
+    texts = ["foo"]
+    token_ids = lexicon.texts_to_token_ids(texts)
+    expected_token_ids = sp.encode(texts, out_type=int)
+    print(token_ids)
+    print(expected_token_ids)
+
+    # test ragged lexicon
+    ragged_lexicon = lexicon.ragged_lexicon.tolist()
+    word_disambig_id = lexicon.word_table["#0"]
+    for i in range(2, word_disambig_id):
+        piece_id = ragged_lexicon[i]
+        word = lexicon.word_table[i]
+        assert word == sp.decode(piece_id)
+        assert piece_id == sp.encode(word)


-def test_bpe_lexicon():
-    lang_dir = Path("data/lang/bpe")
-    if not lang_dir.is_dir():
-        return
-    # TODO: Generate test data for BpeLexicon
-
-    lexicon = BpeLexicon(lang_dir)
-    words = ["<UNK>", "HELLO", "ZZZZ", "WORLD"]
-    ids = lexicon.words_to_piece_ids(words)
-    print(ids)
-    print([lexicon.token_table[i] for i in ids.values().tolist()])
+def test_main():
+    generate_test_data()
+
+    uniq_lexicon_test()
+
+    if USING_PYTEST:
+        delete_test_data()
+
+
+def main():
+    test_main()
+
+
+if __name__ == "__main__" and not USING_PYTEST:
+    main()
196 test/test_mmi_graph_compiler.py Executable file

@@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
You can run this file in one of the two ways:

    (1) cd icefall; pytest test/test_mmi_graph_compiler.py
    (2) cd icefall; ./test/test_mmi_graph_compiler.py
"""

import copy
import os
import shutil
import sys
from pathlib import Path

import k2
import sentencepiece as spm

from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent


def generate_test_data():
    Path(TMP_DIR).mkdir(exist_ok=True)
    sentences = """
        cat tac cat cat
        at at cat at cat cat
        tac at ta at at
        at cat ct ct ta ct ct cat tac
        cat cat cat cat
        at at at at at at at
    """

    transcript = Path(TMP_DIR) / "transcript_words.txt"
    with open(transcript, "w") as f:
        for line in sentences.strip().split("\n"):
            f.write(f"{line}\n")

    words = """
        <eps> 0
        <UNK> 1
        at 2
        cat 3
        ct 4
        ta 5
        tac 6
        #0 7
        <s> 8
        </s> 9
    """
    word_txt = Path(TMP_DIR) / "words.txt"
    with open(word_txt, "w") as f:
        for line in words.strip().split("\n"):
            f.write(f"{line}\n")

    vocab_size = 8

    os.system(
        f"""
        cd {ICEFALL_DIR}/egs/librispeech/ASR

        ./local/train_bpe_model.py \
          --lang-dir {TMP_DIR} \
          --vocab-size {vocab_size} \
          --transcript {transcript}

        ./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 0

        ./local/convert_transcript_words_to_tokens.py \
          --lexicon {TMP_DIR}/lexicon.txt \
          --transcript {transcript} \
          --oov "<UNK>" \
          > {TMP_DIR}/transcript_tokens.txt

        ./shared/make_kn_lm.py \
          -ngram-order 2 \
          -text {TMP_DIR}/transcript_tokens.txt \
          -lm {TMP_DIR}/P.arpa

        python3 -m kaldilm \
          --read-symbol-table="{TMP_DIR}/tokens.txt" \
          --disambig-symbol='#0' \
          --max-order=2 \
          {TMP_DIR}/P.arpa > {TMP_DIR}/P.fst.txt
        """
    )


def delete_test_data():
    shutil.rmtree(TMP_DIR)


def mmi_graph_compiler_test():
    # Caution:
    # You have to comment out
    #   del transcript_fsa.aux_labels
    # in mmi_graph_compiler.py
    # to see the correct aux_labels in *.svg
    graph_compiler = MmiTrainingGraphCompiler(
        lang_dir=TMP_DIR, uniq_filename="lexicon.txt"
    )
    print(graph_compiler.device)
    L_inv = graph_compiler.L_inv
    L = k2.invert(L_inv)

    L.labels_sym = graph_compiler.lexicon.token_table
    L.aux_labels_sym = graph_compiler.lexicon.word_table
    L.draw(f"{TMP_DIR}/L.svg", title="L")

    L_inv.labels_sym = graph_compiler.lexicon.word_table
    L_inv.aux_labels_sym = graph_compiler.lexicon.token_table
    L_inv.draw(f"{TMP_DIR}/L_inv.svg", title="L_inv")

    ctc_topo_P = graph_compiler.ctc_topo_P
    ctc_topo_P.labels_sym = copy.deepcopy(graph_compiler.lexicon.token_table)
    ctc_topo_P.labels_sym._id2sym[0] = "<blk>"
    ctc_topo_P.labels_sym._sym2id["<blk>"] = 0
    ctc_topo_P.aux_labels_sym = graph_compiler.lexicon.token_table
    ctc_topo_P.draw(f"{TMP_DIR}/ctc_topo_P.svg", title="ctc_topo_P")

    print(ctc_topo_P.num_arcs)
    print(k2.connect(ctc_topo_P).num_arcs)

    with open(str(TMP_DIR) + "/P.fst.txt") as f:
        # P is not an acceptor because there is
        # a back-off state, whose incoming arcs
        # have label #0 and aux_label 0 (i.e., <eps>).
        P = k2.Fsa.from_openfst(f.read(), acceptor=False)
    P.labels_sym = graph_compiler.lexicon.token_table
    P.aux_labels_sym = graph_compiler.lexicon.token_table
    P.draw(f"{TMP_DIR}/P.svg", title="P")

    ctc_topo = k2.ctc_topo(max(graph_compiler.lexicon.tokens), False)
    ctc_topo.labels_sym = ctc_topo_P.labels_sym
    ctc_topo.aux_labels_sym = graph_compiler.lexicon.token_table
    ctc_topo.draw(f"{TMP_DIR}/ctc_topo.svg", title="ctc_topo")
    print("p num arcs", P.num_arcs)
    print("ctc_topo num arcs", ctc_topo.num_arcs)
    print("ctc_topo_P num arcs", ctc_topo_P.num_arcs)

    texts = ["cat at ct", "at ta", "cat tac"]
    transcript_fsa = graph_compiler.build_transcript_fsa(texts)
    transcript_fsa[0].draw(f"{TMP_DIR}/cat_at_ct.svg", title="cat_at_ct")
    transcript_fsa[1].draw(f"{TMP_DIR}/at_ta.svg", title="at_ta")
    transcript_fsa[2].draw(f"{TMP_DIR}/cat_tac.svg", title="cat_tac")

    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
    num_graphs[0].draw(f"{TMP_DIR}/num_cat_at_ct.svg", title="num_cat_at_ct")
    num_graphs[1].draw(f"{TMP_DIR}/num_at_ta.svg", title="num_at_ta")
    num_graphs[2].draw(f"{TMP_DIR}/num_cat_tac.svg", title="num_cat_tac")

    den_graphs[0].draw(f"{TMP_DIR}/den_cat_at_ct.svg", title="den_cat_at_ct")
    den_graphs[2].draw(f"{TMP_DIR}/den_cat_tac.svg", title="den_cat_tac")

    sp = spm.SentencePieceProcessor()
    sp.load(f"{TMP_DIR}/bpe.model")

    texts = ["cat at cat", "at tac"]
    token_ids = graph_compiler.texts_to_ids(texts)
    expected_token_ids = sp.encode(texts)
    assert token_ids == expected_token_ids


def test_main():
    generate_test_data()

    mmi_graph_compiler_test()

    if USING_PYTEST:
        delete_test_data()


def main():
    test_main()


if __name__ == "__main__" and not USING_PYTEST:
    main()