Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-14 12:32:20 +00:00)

commit 6f5d63492a
parent 59b7140766

Refactoring.
@@ -114,7 +114,10 @@ class Transformer(nn.Module):
             norm=encoder_norm,
         )
 
-        self.encoder_output_layer = nn.Linear(d_model, num_classes)
+        # TODO(fangjun): remove dropout
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )
 
         if num_decoder_layers > 0:
             self.decoder_num_class = (
@@ -325,6 +328,7 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
         # extracted as a separate function.
+
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
         ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
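Aside (not part of the diff): the three lines above build the decoder input. A minimal sketch with invented token ids, assuming sos_id == eos_id == 1 as with the BPE models used in these recipes:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    token_ids = [[5, 9, 3], [7]]
    ys_in = [[1] + utt for utt in token_ids]  # what add_sos(token_ids, sos_id=1) yields
    ys_in = [torch.tensor(y) for y in ys_in]
    ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=1)
    # ys_in_pad:
    # tensor([[1, 5, 9, 3],
    #         [1, 7, 1, 1]])

Padding with eos_id means shorter rows are filled with the EOS symbol, which the decoder padding mask later ignores.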
egs/librispeech/ASR/conformer_mmi/asr_datamodule.py — new file, 354 lines
@@ -0,0 +1,354 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union

from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
    BucketingSampler,
    CutConcatenate,
    CutMix,
    K2SpeechRecognitionDataset,
    PrecomputedFeatures,
    SingleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader

from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool


class LibriSpeechAsrDataModule(DataModule):
    """
    DataModule for k2 ASR experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - augmentation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        super().add_arguments(parser)
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--full-libri",
            type=str2bool,
            default=True,
            help="When enabled, use 960h LibriSpeech. "
            "Otherwise, use 100h subset.",
        )
        group.add_argument(
            "--feature-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the BucketingSampler"
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=1.0,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

    def train_dataloaders(self) -> DataLoader:
        logging.info("About to get train cuts")
        cuts_train = self.train_cuts()

        logging.info("About to get Musan cuts")
        cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")

        logging.info("About to create train dataset")
        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = [
            SpecAugment(
                num_frame_masks=2,
                features_mask_size=27,
                num_feature_masks=2,
                frames_mask_size=100,
            )
        ]

        train = K2SpeechRecognitionDataset(
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
            # Drop feats to be on the safe side.
            train = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                ),
                input_transforms=input_transforms,
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using BucketingSampler.")
            train_sampler = BucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                bucket_method="equal_duration",
                drop_last=True,
            )
        else:
            logging.info("Using SingleCutSampler.")
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )

        return train_dl

    def valid_dataloaders(self) -> DataLoader:
        logging.info("About to get dev cuts")
        cuts_valid = self.valid_cuts()

        transforms = []
        if self.args.concatenate_cuts:
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                ),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = SingleCutSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        cuts = self.test_cuts()
        is_list = isinstance(cuts, list)
        test_loaders = []
        if not is_list:
            cuts = [cuts]

        for cuts_test in cuts:
            logging.debug("About to create test dataset")
            test = K2SpeechRecognitionDataset(
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                )
                if self.args.on_the_fly_feats
                else PrecomputedFeatures(),
                return_cuts=self.args.return_cuts,
            )
            sampler = SingleCutSampler(
                cuts_test, max_duration=self.args.max_duration
            )
            logging.debug("About to create test dataloader")
            test_dl = DataLoader(
                test, batch_size=None, sampler=sampler, num_workers=1
            )
            test_loaders.append(test_dl)

        if is_list:
            return test_loaders
        else:
            return test_loaders[0]

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        cuts_train = load_manifest(
            self.args.feature_dir / "cuts_train-clean-100.json.gz"
        )
        if self.args.full_libri:
            cuts_train = (
                cuts_train
                + load_manifest(
                    self.args.feature_dir / "cuts_train-clean-360.json.gz"
                )
                + load_manifest(
                    self.args.feature_dir / "cuts_train-other-500.json.gz"
                )
            )
        return cuts_train

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        cuts_valid = load_manifest(
            self.args.feature_dir / "cuts_dev-clean.json.gz"
        ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
        return cuts_valid

    @lru_cache()
    def test_cuts(self) -> List[CutSet]:
        test_sets = ["test-clean", "test-other"]
        cuts = []
        for test_set in test_sets:
            logging.debug("About to get test cuts")
            cuts.append(
                load_manifest(
                    self.args.feature_dir / f"cuts_{test_set}.json.gz"
                )
            )
        return cuts
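Aside (not part of the diff): a sketch of how a script is expected to consume this datamodule. The flag values are illustrative, and we assume the DataModule base class stores the parsed args as self.args, as the methods above rely on:

    import argparse

    parser = argparse.ArgumentParser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args(["--full-libri", "false", "--max-duration", "200"])

    librispeech = LibriSpeechAsrDataModule(args)
    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()

    for batch in train_dl:
        feature = batch["inputs"]             # (N, T, 80) fbank features
        supervisions = batch["supervisions"]  # texts plus frame offsets/lengths
        break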
@@ -1,7 +1,20 @@
 #!/usr/bin/env python3
 
 # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import math
 import warnings
@@ -43,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        is_espnet_structure: bool = False,
         use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
@@ -70,12 +82,10 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
-            is_espnet_structure,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
-        self.is_espnet_structure = is_espnet_structure
-        if self.normalize_before and self.is_espnet_structure:
+        if self.normalize_before:
             self.after_norm = nn.LayerNorm(d_model)
         else:
             # Note: TorchScript detects that self.after_norm could be used inside forward()
@@ -88,7 +98,7 @@ class Conformer(Transformer):
         """
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -110,7 +120,7 @@ class Conformer(Transformer):
         mask = mask.to(x.device)
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
 
-        if self.normalize_before and self.is_espnet_structure:
+        if self.normalize_before:
             x = self.after_norm(x)
 
         return x, mask
@@ -144,11 +154,10 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
-        is_espnet_structure: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
+            d_model, nhead, dropout=0.0
         )
 
         self.feed_forward = nn.Sequential(
@@ -394,7 +403,7 @@ class RelPositionalEncoding(torch.nn.Module):
             :,
             self.pe.size(1) // 2
             - x.size(1)
-            + 1 : self.pe.size(1) // 2
+            + 1 : self.pe.size(1) // 2  # noqa E203
             + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
@@ -421,7 +430,6 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
-        is_espnet_structure: bool = False,
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
@@ -444,8 +452,6 @@ class RelPositionMultiheadAttention(nn.Module):
 
         self._reset_parameters()
 
-        self.is_espnet_structure = is_espnet_structure
-
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.0)
@@ -675,9 +681,6 @@ class RelPositionMultiheadAttention(nn.Module):
             _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-        if not self.is_espnet_structure:
-            q = q * scaling
-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
@@ -770,14 +773,9 @@ class RelPositionMultiheadAttention(nn.Module):
         )  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd)
 
-        if not self.is_espnet_structure:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            )  # (batch, head, time1, time2)
-        else:
-            attn_output_weights = (
-                matrix_ac + matrix_bd
-            ) * scaling  # (batch, head, time1, time2)
+        attn_output_weights = (
+            matrix_ac + matrix_bd
+        ) * scaling  # (batch, head, time1, time2)
 
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
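Aside (not part of the diff): with is_espnet_structure gone, the ESPnet-style scoring is the only remaining code path: queries are no longer pre-scaled, and the combined content/position scores are scaled once after summation. A small stand-in with random tensors of the shapes named in the comments above:

    import torch

    batch, head, time1, head_dim = 2, 4, 5, 16
    scaling = float(head_dim) ** -0.5

    matrix_ac = torch.randn(batch, head, time1, time1)  # content-based term
    matrix_bd = torch.randn(batch, head, time1, time1)  # position-based term, after rel_shift
    attn_output_weights = (matrix_ac + matrix_bd) * scaling
    attn_output_weights = attn_output_weights.view(batch * head, time1, -1)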
@@ -1,8 +1,20 @@
 #!/usr/bin/env python3
 
 # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-# (still working in progress)
-
 import argparse
 import logging
@@ -13,14 +25,15 @@ from typing import Dict, List, Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 
-from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
+    nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
@@ -32,6 +45,7 @@ from icefall.utils import (
     get_texts,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
 
@@ -44,51 +58,111 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=9,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=1,
+        default=20,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (1) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (2) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (5) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (6) nbest-oracle. Its WER is the lower bound of any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring method.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--lattice-score-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--export",
+        type=str2bool,
+        default=False,
+        help="""When enabled, the averaged model is saved to
+        conformer_mmi/exp/pretrained.pt. Note: only model.state_dict() is saved.
+        pretrained.pt contains a dict {"model": model.state_dict()},
+        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_mmi/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="The lang dir",
+    )
+
     return parser
 
 
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_mmi/exp"),
-            "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
             "feature_dim": 80,
             "nhead": 8,
             "attention_dim": 512,
-            "subsampling_factor": 4,
             "num_decoder_layers": 6,
-            "vgg_frontend": False,
-            "is_espnet_structure": True,
-            "use_feat_batchnorm": True,
+            # parameters for decoding
             "search_beam": 20,
             "output_beam": 8,
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
-            # Possible values for method:
-            # - 1best
-            # - nbest
-            # - nbest-rescoring
-            # - whole-lattice-rescoring
-            # - attention-decoder
-            # "method": "whole-lattice-rescoring",
-            "method": "1best",
-            # num_paths is used when method is "nbest", "nbest-rescoring",
-            # and attention-decoder
-            "num_paths": 100,
         }
     )
     return params
@@ -99,7 +173,7 @@ def decode_one_batch(
     model: nn.Module,
     HLG: k2.Fsa,
     batch: dict,
-    lexicon: Lexicon,
+    word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
@@ -133,8 +207,8 @@ def decode_one_batch(
       It is the return value from iterating
       `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
       for the format of the `batch`.
-    lexicon:
-      It contains word symbol table.
+    word_table:
+      The word symbol table.
     sos_id:
       The token ID of the SOS.
     eos_id:
@@ -151,12 +225,12 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
 
     supervisions = batch["supervisions"]
 
     nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is [N, T, C]
+    # nnet_output is (N, T, C)
 
     supervision_segments = torch.stack(
         (
@@ -178,6 +252,24 @@ def decode_one_batch(
         subsampling_factor=params.subsampling_factor,
     )
 
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            lattice_score_scale=params.lattice_score_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        return {key: hyps}
+
     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
             best_path = one_best_decoding(
@@ -189,11 +281,12 @@ def decode_one_batch(
                 lattice=lattice,
                 num_paths=params.num_paths,
                 use_double_scores=params.use_double_scores,
+                lattice_score_scale=params.lattice_score_scale,
             )
-            key = f"no_rescore-{params.num_paths}"
+            key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
 
         hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
         return {key: hyps}
 
     assert params.method in [
@@ -202,7 +295,8 @@ def decode_one_batch(
         "attention-decoder",
     ]
 
-    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
     lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
 
     if params.method == "nbest-rescoring":
@@ -211,16 +305,23 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
+            lattice_score_scale=params.lattice_score_scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
        )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
         )
+        # TODO: pass `lattice` instead of `rescored_lattice` to
+        # `rescore_with_attention_decoder`
 
         best_path_dict = rescore_with_attention_decoder(
             lattice=rescored_lattice,
@@ -230,15 +331,20 @@ def decode_one_batch(
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
             eos_id=eos_id,
+            lattice_score_scale=params.lattice_score_scale,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
 
     ans = dict()
-    for lm_scale_str, best_path in best_path_dict.items():
-        hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
-        ans[lm_scale_str] = hyps
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        for lm_scale in lm_scale_list:
+            ans[lm_scale_str] = [[] * lattice.shape[0]]
     return ans
 
 
@@ -247,7 +353,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     HLG: k2.Fsa,
-    lexicon: Lexicon,
+    word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
@@ -263,8 +369,8 @@ def decode_dataset(
         The neural model.
       HLG:
         The decoding graph.
-      lexicon:
-        It contains word symbol table.
+      word_table:
+        It is the word symbol table.
      sos_id:
        The token ID for SOS.
      eos_id:
@@ -283,7 +389,11 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -294,7 +404,7 @@ def decode_dataset(
             model=model,
             HLG=HLG,
             batch=batch,
-            lexicon=lexicon,
+            word_table=word_table,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
@@ -312,10 +422,10 @@ def decode_dataset(
         num_cuts += len(batch["supervisions"]["text"])
 
         if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )
     return results
 
@@ -374,8 +484,10 @@ def main():
 
     params = get_params()
     params.update(vars(args))
+    params.exp_dir = Path(params.exp_dir)
+    params.lang_dir = Path(params.lang_dir)
 
-    setup_logger(f"{params.exp_dir}/log/log-decode")
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
     logging.info("Decoding started")
     logging.info(params)
 
@@ -389,7 +501,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    graph_compiler = BpeMmiTrainingGraphCompiler(
+    graph_compiler = BpeCtcTrainingGraphCompiler(
         params.lang_dir,
         device=device,
         sos_token="<sos/eos>",
@@ -398,7 +510,9 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -429,7 +543,7 @@ def main():
         torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
     else:
         logging.info("Loading pre-compiled G_4_gram.pt")
-        d = torch.load(params.lm_dir / "G_4_gram.pt")
+        d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
        G = k2.Fsa.from_dict(d).to(device)
 
     if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
@@ -453,7 +567,6 @@ def main():
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
         vgg_frontend=params.vgg_frontend,
-        is_espnet_structure=params.is_espnet_structure,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
 
@@ -468,6 +581,13 @@ def main():
         logging.info(f"averaging {filenames}")
         model.load_state_dict(average_checkpoints(filenames))
 
+    if params.export:
+        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
+        return
+
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
@@ -487,7 +607,7 @@ def main():
         params=params,
         model=model,
         HLG=HLG,
-        lexicon=lexicon,
+        word_table=lexicon.word_table,
         G=G,
         sos_id=sos_id,
         eos_id=eos_id,
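Aside (not part of the diff): --epoch/--avg select which checkpoints feed average_checkpoints(), and --export writes the averaged weights out. A rough sketch of the assumed averaging semantics (a parameter-wise mean over each checkpoint's "model" entry; the real icefall.checkpoint.average_checkpoints may differ in details):

    import torch

    def average_checkpoints_sketch(filenames):
        n = len(filenames)
        avg = None
        for f in filenames:
            state = torch.load(f, map_location="cpu")["model"]
            if avg is None:
                avg = {k: v.clone().to(torch.float64) for k, v in state.items()}
            else:
                for k in avg:
                    avg[k] += state[k].to(torch.float64)
        return {k: v.div_(n) for k, v in avg.items()}

    # decode.py feeds this the --avg consecutive checkpoints ending at the
    # one selected by --epoch, then (with --export) saves
    # {"model": model.state_dict()} to <exp-dir>/pretrained.pt.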
@@ -1,15 +1,26 @@
-# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
-# Apache 2.0
+# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import math
 from typing import Dict, List, Optional, Tuple
 
-import k2
 import torch
 import torch.nn as nn
 from subsampling import Conv2dSubsampling, VggSubsampling
 
-from icefall.utils import get_texts
 from torch.nn.utils.rnn import pad_sequence
 
 # Note: TorchScript requires Dict/List/etc. to be fully typed.
@@ -72,8 +83,8 @@ class Transformer(nn.Module):
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
-        # self.encoder_embed converts the input of shape [N, T, num_classes]
-        # to the shape [N, T//subsampling_factor, d_model].
+        # self.encoder_embed converts the input of shape (N, T, num_classes)
+        # to the shape (N, T//subsampling_factor, d_model).
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_classes -> d_model
@@ -103,10 +114,15 @@ class Transformer(nn.Module):
             norm=encoder_norm,
         )
 
-        self.encoder_output_layer = nn.Linear(d_model, num_classes)
+        # TODO(fangjun): remove dropout
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )
 
         if num_decoder_layers > 0:
-            self.decoder_num_class = self.num_classes
+            self.decoder_num_class = (
+                self.num_classes
+            )  # bpe model already has sos/eos symbol
 
             self.decoder_embed = nn.Embedding(
                 num_embeddings=self.decoder_num_class, embedding_dim=d_model
@@ -146,7 +162,7 @@ class Transformer(nn.Module):
         """
         Args:
           x:
-            The input tensor. Its shape is [N, T, C].
+            The input tensor. Its shape is (N, T, C).
           supervision:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -155,17 +171,17 @@ class Transformer(nn.Module):
 
         Returns:
           Return a tuple containing 3 tensors:
-            - CTC output for ctc decoding. Its shape is [N, T, C]
-            - Encoder output with shape [T, N, C]. It can be used as key and
+            - CTC output for ctc decoding. Its shape is (N, T, C)
+            - Encoder output with shape (T, N, C). It can be used as key and
               value for the decoder.
            - Encoder output padding mask. It can be used as
-              memory_key_padding_mask for the decoder. Its shape is [N, T].
+              memory_key_padding_mask for the decoder. Its shape is (N, T).
              It is None if `supervision` is None.
        """
        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
        encoder_memory, memory_key_padding_mask = self.run_encoder(
            x, supervision
        )
@@ -179,7 +195,7 @@ class Transformer(nn.Module):
 
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -190,8 +206,8 @@ class Transformer(nn.Module):
             padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
-            - The encoder output, with shape [T, N, C]
-            - encoder padding mask, with shape [N, T].
+            - The encoder output, with shape (T, N, C)
+            - encoder padding mask, with shape (N, T).
              The mask is None if `supervisions` is None.
              It is used as memory key padding mask in the decoder.
        """
@@ -209,11 +225,11 @@ class Transformer(nn.Module):
         Args:
           x:
             The output tensor from the transformer encoder.
-            Its shape is [T, N, C]
+            Its shape is (T, N, C)
 
         Returns:
           Return a tensor that can be used for CTC decoding.
-          Its shape is [N, T, C]
+          Its shape is (N, T, C)
         """
         x = self.encoder_output_layer(x)
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -231,7 +247,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -296,7 +312,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -312,6 +328,7 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
         # extracted as a separate function.
+
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
         ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
@@ -329,6 +346,9 @@ class Transformer(nn.Module):
         )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
         tgt_key_padding_mask[:, 0] = False
 
         tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
@@ -634,13 +654,13 @@ class PositionalEncoding(nn.Module):
     def extend_pe(self, x: torch.Tensor) -> None:
         """Extend the time t in the positional encoding if required.
 
-        The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
-        is [N, T, d_model]. If T > T1, then we change the shape of self.pe
-        to [N, T, d_model]. Otherwise, nothing is done.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (N, T, d_model). Otherwise, nothing is done.
 
         Args:
           x:
-            It is a tensor of shape [N, T, C].
+            It is a tensor of shape (N, T, C).
         Returns:
           Return None.
         """
@@ -658,7 +678,7 @@ class PositionalEncoding(nn.Module):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0)
-        # Now pe is of shape [1, T, d_model], where T is x.size(1)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -667,10 +687,10 @@ class PositionalEncoding(nn.Module):
 
         Args:
           x:
-            Its shape is [N, T, C]
+            Its shape is (N, T, C)
 
         Returns:
-          Return a tensor of shape [N, T, C]
+          Return a tensor of shape (N, T, C)
         """
         self.extend_pe(x)
         x = x * self.xscale + self.pe[:, : x.size(1), :]
@@ -766,7 +786,8 @@ class Noam(object):
 
 class LabelSmoothingLoss(nn.Module):
     """
-    Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w)
+    Label-smoothing loss. KL-divergence between
+    q_{smoothed ground truth prob.}(w)
     and p_{prob. computed by model}(w) is minimized.
     Modified from
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa
@@ -851,7 +872,8 @@ def encoder_padding_mask(
         frames, before subsampling)
 
     Returns:
-        Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices.
+        Tensor: Mask tensor of dimension (batch_size, input_length),
+        True denote the masked indices.
     """
     if supervisions is None:
         return None
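Aside (not part of the diff): the LabelSmoothingLoss docstring above only names the objective, so here is a compact, generic sketch of the quantity being minimized (not the exact icefall implementation):

    import torch
    import torch.nn.functional as F

    def label_smoothing_loss(logits, target, smoothing=0.1):
        # logits: (N, C) raw scores; target: (N,) class indices.
        num_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)
        # q: the smoothed ground-truth distribution
        q = torch.full_like(log_probs, smoothing / (num_classes - 1))
        q.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
        # KL(q || p) differs from this cross entropy only by H(q), a constant
        return -(q * log_probs).sum(dim=-1).mean()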
@@ -8,8 +8,8 @@ for LM training with the help of a lexicon.
 If the lexicon contains phones, the resulting LM will be a phone LM; If the
 lexicon contains word pieces, the resulting LM will be a word piece LM.

-If a word has multiple pronunciations, the one that appears last in the lexicon
-is used.
+If a word has multiple pronunciations, the one that appears first in the lexicon
+is kept; others are removed.

 If the input transcript is:

@@ -20,8 +20,8 @@ If the input transcript is:
 and if the lexicon is

     <UNK> SPN
-    hello h e l l o
     hello h e l l o 2
+    hello h e l l o
     world w o r l d
     zoo z o o

@@ -32,10 +32,11 @@ Then the output is
     SPN z o o w o r l d SPN
 """

-from pathlib import Path
-from typing import Dict

 import argparse
+from pathlib import Path
+from typing import Dict, List

+from generate_unique_lexicon import filter_multiple_pronunications

 from icefall.lexicon import read_lexicon

@@ -57,7 +58,9 @@ def get_args():
     return parser.parse_args()


-def process_line(lexicon: Dict[str, str], line: str, oov_token: str) -> None:
+def process_line(
+    lexicon: Dict[str, List[str]], line: str, oov_token: str
+) -> None:
     """
     Args:
       lexicon:
@@ -86,7 +89,11 @@ def main():
     assert Path(args.transcript).is_file()
     assert len(args.oov) > 0

-    lexicon = dict(read_lexicon(args.lexicon))
+    # Only the first pronunciation of a word is kept
+    lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
+
+    lexicon = dict(lexicon)

     assert args.oov in lexicon

     oov_token = lexicon[args.oov]
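To make the intent concrete, here is a hypothetical reimplementation of process_line run on the docstring's example; the input word "fox" is made up, everything else follows the documented behavior (OOV words map to the pronunciation of the OOV word, e.g. ["SPN"] for "<UNK>").

from typing import Dict, List

def process_line(
    lexicon: Dict[str, List[str]], line: str, oov_token: List[str]
) -> None:
    # Replace each word with its pronunciation tokens; OOVs get oov_token.
    tokens: List[str] = []
    for word in line.split():
        tokens.extend(lexicon.get(word, oov_token))
    print(" ".join(tokens))

lexicon = {"zoo": ["z", "o", "o"], "world": ["w", "o", "r", "l", "d"]}
process_line(lexicon, "fox zoo world fox", oov_token=["SPN"])
# prints: SPN z o o w o r l d SPN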
egs/librispeech/ASR/local/generate_unique_lexicon.py (new executable file, 100 lines)
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file takes as input a lexicon.txt and output a new lexicon,
+in which each word has a unique pronunciation.
+
+The way to do this is to keep only the first pronunciation of a word
+in lexicon.txt.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+from icefall.lexicon import read_lexicon, write_lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain a file lexicon.txt.
+        This file will generate a new file uniq_lexicon.txt
+        in it.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def filter_multiple_pronunications(
+    lexicon: List[Tuple[str, List[str]]]
+) -> List[Tuple[str, List[str]]]:
+    """Remove multiple pronunciations of words from a lexicon.
+
+    If a word has more than one pronunciation in the lexicon, only
+    the first one is kept, while other pronunciations are removed
+    from the lexicon.
+
+    Args:
+      lexicon:
+        The input lexicon, containing a list of (word, [p1, p2, ..., pn]),
+        where "p1, p2, ..., pn" are the pronunciations of the "word".
+    Returns:
+      Return a new lexicon where each word has a unique pronunciation.
+    """
+    seen = set()
+    ans = []
+
+    for word, tokens in lexicon:
+        if word in seen:
+            continue
+        seen.add(word)
+        ans.append((word, tokens))
+    return ans
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    lexicon_filename = lang_dir / "lexicon.txt"
+
+    in_lexicon = read_lexicon(lexicon_filename)
+
+    out_lexicon = filter_multiple_pronunications(in_lexicon)
+
+    write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon)
+
+    logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}")
+    logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
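A quick illustration of the dedup behavior of the function defined above, with a toy lexicon (the entries are made up):

lexicon = [
    ("hello", ["h", "e", "l", "l", "o", "2"]),
    ("hello", ["h", "e", "l", "l", "o"]),
    ("world", ["w", "o", "r", "l", "d"]),
]
# Only the first pronunciation of "hello" survives.
assert filter_multiple_pronunications(lexicon) == [
    ("hello", ["h", "e", "l", "l", "o", "2"]),
    ("world", ["w", "o", "r", "l", "d"]),
]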
@@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following:

 5. Generate L_disambig.pt, in k2 format.
 """
+import argparse
 import math
 from collections import defaultdict
 from pathlib import Path
@@ -42,10 +43,37 @@ import k2
 import torch

 from icefall.lexicon import read_lexicon, write_lexicon
+from icefall.utils import str2bool

 Lexicon = List[Tuple[str, List[str]]]


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain a file lexicon.txt.
+        Generated files by this script are saved into this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--debug",
+        type=str2bool,
+        default=False,
+        help="""True for debugging, which will generate
+        a visualization of the lexicon FST.
+
+        Caution: If your lexicon contains hundreds of thousands
+        of lines, please set it to False!
+        """,
+    )
+
+    return parser.parse_args()
+
+
 def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
     """Write a symbol to ID mapping to a file.

@@ -315,8 +343,9 @@ def lexicon_to_fst(


 def main():
-    out_dir = Path("data/lang_phone")
-    lexicon_filename = out_dir / "lexicon.txt"
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+    lexicon_filename = lang_dir / "lexicon.txt"
     sil_token = "SIL"
     sil_prob = 0.5

@@ -344,9 +373,9 @@ def main():
     token2id = generate_id_map(tokens)
     word2id = generate_id_map(words)

-    write_mapping(out_dir / "tokens.txt", token2id)
-    write_mapping(out_dir / "words.txt", word2id)
-    write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+    write_mapping(lang_dir / "tokens.txt", token2id)
+    write_mapping(lang_dir / "words.txt", word2id)
+    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

     L = lexicon_to_fst(
         lexicon,
@@ -364,17 +393,20 @@ def main():
         sil_prob=sil_prob,
         need_self_loops=True,
     )
-    torch.save(L.as_dict(), out_dir / "L.pt")
-    torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+    torch.save(L.as_dict(), lang_dir / "L.pt")
+    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

-    if False:
-        # Just for debugging, will remove it
-        L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
-        L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
-        L_disambig.labels_sym = L.labels_sym
-        L_disambig.aux_labels_sym = L.aux_labels_sym
-        L.draw(out_dir / "L.png", title="L")
-        L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+    if args.debug:
+        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+        L.labels_sym = labels_sym
+        L.aux_labels_sym = aux_labels_sym
+        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
+
+        L_disambig.labels_sym = labels_sym
+        L_disambig.aux_labels_sym = aux_labels_sym
+        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")


 if __name__ == "__main__":
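str2bool, imported above from icefall.utils, is the usual argparse helper for boolean flags; a sketch assuming the standard pattern (the real implementation may differ in detail):

import argparse

def str2bool(v) -> bool:
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")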
@@ -49,6 +49,8 @@ from prepare_lang import (
     write_mapping,
 )

+from icefall.utils import str2bool
+

 def lexicon_to_fst_no_sil(
     lexicon: Lexicon,
@@ -169,6 +171,20 @@ def get_args():
         """,
     )

+    parser.add_argument(
+        "--debug",
+        type=str2bool,
+        default=False,
+        help="""True for debugging, which will generate
+        a visualization of the lexicon FST.
+
+        Caution: If your lexicon contains hundreds of thousands
+        of lines, please set it to False!
+
+        See "test/test_bpe_lexicon.py" for usage.
+        """,
+    )
+
     return parser.parse_args()


@@ -221,6 +237,18 @@ def main():
     torch.save(L.as_dict(), lang_dir / "L.pt")
     torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

+    if args.debug:
+        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+        L.labels_sym = labels_sym
+        L.aux_labels_sym = aux_labels_sym
+        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
+
+        L_disambig.labels_sym = labels_sym
+        L_disambig.aux_labels_sym = aux_labels_sym
+        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
+

 if __name__ == "__main__":
     main()
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 # You can install sentencepiece via:
 #
 #     pip install sentencepiece
@@ -37,10 +38,17 @@ def get_args():
         "--lang-dir",
         type=str,
         help="""Input and output directory.
-        It should contain the training corpus: train.txt.
+        It should contain the training corpus: transcript_words.txt.
         The generated bpe.model is saved to this directory.
         """,
     )

+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="Training transcript.",
+    )
+
     parser.add_argument(
         "--vocab-size",
         type=int,
@@ -58,7 +66,7 @@ def main():
     model_type = "unigram"

     model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
-    train_text = f"{lang_dir}/train.txt"
+    train_text = args.transcript
     character_coverage = 1.0
     input_sentence_size = 100000000
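The variables above feed a sentencepiece training call, roughly like the following sketch (condensed; the paths are placeholders and the actual script may pass additional options):

import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="transcript_words.txt",         # train_text = args.transcript
    model_prefix="lang_dir/unigram_500",  # f"{lang_dir}/{model_type}_{vocab_size}"
    model_type="unigram",                 # model_type
    vocab_size=500,                       # vocab_size
    character_coverage=1.0,
    input_sentence_size=100000000,
)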
@@ -40,9 +40,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  5000
-  2000
-  1000
+  # 5000
+  # 2000
+  # 1000
   500
 )

@@ -116,14 +116,15 @@ fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
-  mkdir -p data/lang_phone
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir

   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
     cat - $dl_dir/lm/librispeech-lexicon.txt |
-    sort | uniq > data/lang_phone/lexicon.txt
+    sort | uniq > $lang_dir/lexicon.txt

-  if [ ! -f data/lang_phone/L_disambig.pt ]; then
-    ./local/prepare_lang.py
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
   fi
 fi

@@ -138,7 +139,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     # so that the two can share G.pt later.
     cp data/lang_phone/words.txt $lang_dir

-    if [ ! -f $lang_dir/train.txt ]; then
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
       log "Generate data for BPE training"
       files=$(
         find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
@@ -147,12 +148,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
       )
       for f in ${files[@]}; do
         cat $f | cut -d " " -f 2-
-      done > $lang_dir/train.txt
+      done > $lang_dir/transcript_words.txt
     fi

     ./local/train_bpe_model.py \
       --lang-dir $lang_dir \
-      --vocab-size $vocab_size
+      --vocab-size $vocab_size \
+      --transcript $lang_dir/transcript_words.txt

     if [ ! -f $lang_dir/L_disambig.pt ]; then
       ./local/prepare_lang_bpe.py --lang-dir $lang_dir
@@ -166,18 +168,18 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}

-    if [ ! -f $lang_dir/corpus.txt ]; then
-      ./local/convert_transcript_to_corpus.py \
-        --lexicon data/lang_bpe/lexicon.txt \
-        --transcript data/lang_bpe/train.txt \
+    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+      ./local/convert_transcript_words_to_tokens.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/transcript_words.txt \
         --oov "<UNK>" \
-        > $lang_dir/corpus.txt
+        > $lang_dir/transcript_tokens.txt
     fi

     if [ ! -f $lang_dir/P.arpa ]; then
       ./shared/make_kn_lm.py \
         -ngram-order 2 \
-        -text $lang_dir/corpus.txt \
+        -text $lang_dir/transcript_tokens.txt \
         -lm $lang_dir/P.arpa
     fi

@@ -226,4 +228,4 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   done
 fi

-cd data && ln -sfv lang_bpe_5000 lang_bpe
+cd data && ln -sfv lang_bpe_500 lang_bpe
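For clarity, the `cut -d " " -f 2-` loop in stage 6 strips the leading utterance ID from each line of the *.trans.txt files; the same step in Python would look roughly like this (illustrative only, file names are placeholders):

from pathlib import Path
from typing import Iterable

def collect_transcripts(trans_files: Iterable[str], out_file: str) -> None:
    # Drop the first field (the utterance ID), keep the words.
    with open(out_file, "w") as out:
        for p in trans_files:
            for line in Path(p).read_text().splitlines():
                out.write(line.split(" ", 1)[1] + "\n")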
@@ -34,14 +34,10 @@ class BpeCtcTrainingGraphCompiler(object):
         """
         Args:
           lang_dir:
-            This directory is expected to contain the following files::
+            This directory is expected to contain the following files:

             - bpe.model
             - words.txt

-            The above files are produced by the script `prepare.sh`. You
-            should have run that before running the training code.
-
           device:
             It indicates CPU or CUDA.
           sos_token:
@@ -1,178 +0,0 @@
-import logging
-from pathlib import Path
-from typing import List, Tuple, Union
-
-import k2
-import sentencepiece as spm
-import torch
-
-from icefall.lexicon import Lexicon
-
-
-class BpeMmiTrainingGraphCompiler(object):
-    def __init__(
-        self,
-        lang_dir: Path,
-        device: Union[str, torch.device] = "cpu",
-        sos_token: str = "<sos/eos>",
-        eos_token: str = "<sos/eos>",
-    ) -> None:
-        """
-        Args:
-          lang_dir:
-            Path to the lang directory. It is expected to contain the
-            following files::
-
-                - tokens.txt
-                - words.txt
-                - bpe.model
-                - P.fst.txt
-
-            The above files are generated by the script `prepare.sh`. You
-            should have run it before running the training code.
-
-          device:
-            It indicates CPU or CUDA.
-          sos_token:
-            The word piece that represents sos.
-          eos_token:
-            The word piece that represents eos.
-        """
-        self.lang_dir = Path(lang_dir)
-        self.lexicon = Lexicon(lang_dir)
-        self.device = device
-        self.load_sentence_piece_model()
-        self.build_ctc_topo_P()
-
-        self.sos_id = self.sp.piece_to_id(sos_token)
-        self.eos_id = self.sp.piece_to_id(eos_token)
-
-        assert self.sos_id != self.sp.unk_id()
-        assert self.eos_id != self.sp.unk_id()
-
-    def load_sentence_piece_model(self) -> None:
-        """Load the pre-trained sentencepiece model
-        from self.lang_dir/bpe.model.
-        """
-        model_file = self.lang_dir / "bpe.model"
-        sp = spm.SentencePieceProcessor()
-        sp.load(str(model_file))
-        self.sp = sp
-
-    def build_ctc_topo_P(self):
-        """Built ctc_topo_P, the composition result of
-        ctc_topo and P, where P is a pre-trained bigram
-        word piece LM.
-        """
-        # Note: there is no need to save a pre-compiled P and ctc_topo
-        # as it is very fast to generate them.
-        logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
-        with open(self.lang_dir / "P.fst.txt") as f:
-            # P is not an acceptor because there is
-            # a back-off state, whose incoming arcs
-            # have label #0 and aux_label 0 (i.e., <eps>).
-            P = k2.Fsa.from_openfst(f.read(), acceptor=False)
-
-        first_token_disambig_id = self.lexicon.token_table["#0"]
-
-        # P.aux_labels is not needed in later computations, so
-        # remove it here.
-        del P.aux_labels
-        # CAUTION: The following line is crucial.
-        # Arcs entering the back-off state have label equal to #0.
-        # We have to change it to 0 here.
-        P.labels[P.labels >= first_token_disambig_id] = 0
-
-        P = k2.remove_epsilon(P)
-        P = k2.arc_sort(P)
-        P = P.to(self.device)
-        # Add epsilon self-loops to P because we want the
-        # following operation "k2.intersect" to run on GPU.
-        P_with_self_loops = k2.add_epsilon_self_loops(P)
-
-        max_token_id = max(self.lexicon.tokens)
-        logging.info(
-            f"Building modified ctc_topo. max_token_id: {max_token_id}"
-        )
-        # CAUTION: We have to use a modifed version of CTC topo.
-        # Otherwise, the resulting ctc_topo_P is so large that it gets
-        # stuck in k2.intersect_dense_pruned() or it gets OOM in
-        # k2.intersect_dense()
-        ctc_topo = k2.ctc_topo(max_token_id, modified=True, device=self.device)
-
-        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
-
-        logging.info("Building ctc_topo_P")
-        ctc_topo_P = k2.intersect(
-            ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
-        ).invert()
-
-        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
-
-    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
-        """Convert a list of texts to a list-of-list of piece IDs.
-
-        Args:
-          texts:
-            A list of transcripts. Within a transcript words are
-            separated by spaces. An example input is::
-
-                ['HELLO ICEFALL', 'HELLO k2']
-        Returns:
-          Return a list-of-list of piece IDs.
-        """
-        return self.sp.encode(texts, out_type=int)
-
-    def compile(
-        self, texts: List[str], replicate_den: bool = True
-    ) -> Tuple[k2.Fsa, k2.Fsa]:
-        """Create numerator and denominator graphs from transcripts.
-
-        Args:
-          texts:
-            A list of transcripts. Within a transcript words are
-            separated by spaces. An example input is::
-
-                ["HELLO icefall", "HALLO WELT"]
-
-          replicate_den:
-            If True, the returned den_graph is replicated to match the number
-            of FSAs in the returned num_graph; if False, the returned den_graph
-            contains only a single FSA
-        Returns:
-          A tuple (num_graphs, den_graphs), where
-
-          - `num_graphs` is the numerator graph. It is an FsaVec with
-            shape `(len(texts), None, None)`.
-
-          - `den_graphs` is the denominator graph. It is an FsaVec with the
-            same shape of the `num_graph` if replicate_den is True;
-            otherwise, it is an FsaVec containing only a single FSA.
-        """
-        token_ids = self.texts_to_ids(texts)
-        token_fsas = k2.linear_fsa(token_ids, device=self.device)
-
-        token_fsas_with_self_loops = k2.add_epsilon_self_loops(token_fsas)
-
-        # NOTE: Use treat_epsilons_specially=False so that k2.compose
-        # can be run on GPU
-        num_graphs = k2.compose(
-            self.ctc_topo_P,
-            token_fsas_with_self_loops,
-            treat_epsilons_specially=False,
-        )
-        # num_graphs may not be connected and
-        # not be topologically sorted after k2.compose
-        num_graphs = k2.connect(num_graphs)
-        num_graphs = k2.top_sort(num_graphs)
-
-        ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P.detach()])
-        if replicate_den:
-            indexes = torch.zeros(
-                len(texts), dtype=torch.int32, device=self.device
-            )
-            den_graphs = k2.index_fsa(ctc_topo_P_vec, indexes)
-        else:
-            den_graphs = ctc_topo_P_vec
-
-        return num_graphs, den_graphs
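The numerator construction in the deleted compile() starts from linear FSAs over token IDs; a toy run of those k2 calls in isolation (the token IDs are made up):

import k2

token_ids = [[2, 5, 5], [3]]
token_fsas = k2.linear_fsa(token_ids)  # FsaVec with one linear FSA per utterance
token_fsas = k2.add_epsilon_self_loops(token_fsas)
print(token_fsas.shape)  # (2, None, None)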
@@ -84,6 +84,69 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
             f.write(f"{word} {' '.join(tokens)}\n")


+def convert_lexicon_to_ragged(
+    filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
+) -> k2.RaggedTensor:
+    """Read a lexicon and convert it to a ragged tensor.
+
+    The ragged tensor has two axes: [word][token].
+
+    Caution:
+      We assume that each word has a unique pronunciation.
+
+    Args:
+      filename:
+        Filename of the lexicon. It has a format that can be read
+        by :func:`read_lexicon`.
+      word_table:
+        The word symbol table.
+      token_table:
+        The token symbol table.
+    Returns:
+      A k2 ragged tensor with two axes [word][token].
+    """
+    disambig_id = word_table["#0"]
+    # We reuse the same words.txt from the phone based lexicon
+    # so that we can share the same G.fst. Here, we have to
+    # exclude some words present only in the phone based lexicon.
+    excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
+
+    # epsilon is not a word, but it occupies a position
+    #
+    row_splits = [0]
+    token_ids_list = []
+
+    lexicon_tmp = read_lexicon(filename)
+    lexicon = dict(lexicon_tmp)
+    if len(lexicon_tmp) != len(lexicon):
+        raise RuntimeError(
+            "It's assumed that each word has a unique pronunciation"
+        )
+
+    for i in range(disambig_id):
+        w = word_table[i]
+        if w in excluded_words:
+            row_splits.append(row_splits[-1])
+            continue
+        tokens = lexicon[w]
+        token_ids = [token_table[k] for k in tokens]
+
+        row_splits.append(row_splits[-1] + len(token_ids))
+        token_ids_list.extend(token_ids)
+
+    cached_tot_size = row_splits[-1]
+    row_splits = torch.tensor(row_splits, dtype=torch.int32)
+
+    shape = k2.ragged.create_ragged_shape2(
+        row_splits,
+        None,
+        cached_tot_size,
+    )
+    values = torch.tensor(token_ids_list, dtype=torch.int32)
+
+    return k2.RaggedTensor(shape, values)
+
+
 class Lexicon(object):
     """Phone based lexicon."""

@@ -96,12 +159,10 @@ class Lexicon(object):
         Args:
           lang_dir:
             Path to the lang directory. It is expected to contain the following
-            files::
-
+            files:

             - tokens.txt
             - words.txt
             - L.pt

             The above files are produced by the script `prepare.sh`. You
             should have run that before running the training code.
           disambig_pattern:
@@ -121,7 +182,7 @@ class Lexicon(object):
             torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")

         # We save L_inv instead of L because it will be used to intersect with
-        # transcript, both of whose labels are word IDs.
+        # transcript FSAs, both of whose labels are word IDs.
         self.L_inv = L_inv
         self.disambig_pattern = disambig_pattern

@@ -144,69 +205,66 @@ class Lexicon(object):
         return ans


-class BpeLexicon(Lexicon):
+class UniqLexicon(Lexicon):
     def __init__(
         self,
         lang_dir: Path,
+        uniq_filename: str = "uniq_lexicon.txt",
         disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Refer to the help information in Lexicon.__init__.
+
+        uniq_filename: It is assumed to be inside the given `lang_dir`.
+
+        Each word in the lexicon is assumed to have a unique pronunciation.
         """
+        lang_dir = Path(lang_dir)
         super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)

-        self.ragged_lexicon = self.convert_lexicon_to_ragged(
-            lang_dir / "lexicon.txt"
+        self.ragged_lexicon = convert_lexicon_to_ragged(
+            filename=lang_dir / uniq_filename,
+            word_table=self.word_table,
+            token_table=self.token_table,
         )
+        # TODO: should we move it to a certain device ?

-    def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
-        """Read a BPE lexicon from file and convert it to a
-        k2 ragged tensor.
-
-        Args:
-          filename:
-            Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt
-        Returns:
-          A k2 ragged tensor with two axes [word_id]
-        """
-        disambig_id = self.word_table["#0"]
-        # We reuse the same words.txt from the phone based lexicon
-        # so that we can share the same G.fst. Here, we have to
-        # exclude some words present only in the phone based lexicon.
-        excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
-
-        # epsilon is not a word, but it occupies on position
-        #
-        row_splits = [0]
-        token_ids = []
-
-        lexicon = read_lexicon(filename)
-        lexicon = dict(lexicon)
-
-        for i in range(disambig_id):
-            w = self.word_table[i]
-            if w in excluded_words:
-                row_splits.append(row_splits[-1])
-                continue
-            pieces = lexicon[w]
-            piece_ids = [self.token_table[k] for k in pieces]
-
-            row_splits.append(row_splits[-1] + len(piece_ids))
-            token_ids.extend(piece_ids)
-
-        cached_tot_size = row_splits[-1]
-        row_splits = torch.tensor(row_splits, dtype=torch.int32)
-
-        shape = k2.ragged.create_ragged_shape2(
-            row_splits=row_splits, cached_tot_size=cached_tot_size
-        )
-        values = torch.tensor(token_ids, dtype=torch.int32)
-
-        return k2.RaggedTensor(shape, values)
-
-    def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
-        """Convert a list of words to a ragged tensor contained
-        word piece IDs.
+    def texts_to_token_ids(
+        self, texts: List[str], oov: str = "<UNK>"
+    ) -> k2.RaggedTensor:
+        """
+        Args:
+          texts:
+            A list of transcripts. Each transcript contains space(s)
+            separated words. An example texts is::
+
+                ['HELLO k2', 'HELLO icefall']
+          oov:
+            The OOV word. If a word in `texts` is not in the lexicon, it is
+            replaced with `oov`.
+        Returns:
+          Return a ragged int tensor with 2 axes [utterance][token_id]
+        """
+        oov_id = self.word_table[oov]
+
+        word_ids_list = []
+        for text in texts:
+            word_ids = []
+            for word in text.split():
+                if word in self.word_table:
+                    word_ids.append(self.word_table[word])
+                else:
+                    word_ids.append(oov_id)
+            word_ids_list.append(word_ids)
+        ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
+        ans = self.ragged_lexicon.index(ragged_indexes)
+        ans = ans.remove_axis(ans.num_axes - 2)
+        return ans
+
+    def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
+        """Convert a list of words to a ragged tensor containing token IDs.
+
+        We assume there are no OOVs in "words".
         """
         word_ids = [self.word_table[w] for w in words]
         word_ids = torch.tensor(word_ids, dtype=torch.int32)
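The ragged lookup used by texts_to_token_ids can be seen in isolation; a toy sketch mirroring the calls above (the IDs are made up):

import torch
import k2

# ragged_lexicon[w] holds the token IDs of word w.
ragged_lexicon = k2.RaggedTensor([[5, 7], [3], [9, 9, 2]], dtype=torch.int32)
# Two utterances, given as word IDs.
indexes = k2.RaggedTensor([[0, 2], [1]], dtype=torch.int32)
ans = ragged_lexicon.index(indexes)      # axes: [utterance][word][token]
ans = ans.remove_axis(ans.num_axes - 2)  # axes: [utterance][token]
print(ans)  # [[5 7 9 9 2], [3]]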
icefall/mmi_graph_compiler.py (new file, 216 lines)
@@ -0,0 +1,216 @@
+import logging
+from pathlib import Path
+from typing import Iterable, List, Tuple, Union
+
+import k2
+import torch
+
+from icefall.lexicon import UniqLexicon
+
+
+class MmiTrainingGraphCompiler(object):
+    def __init__(
+        self,
+        lang_dir: Path,
+        uniq_filename: str = "uniq_lexicon.txt",
+        device: Union[str, torch.device] = "cpu",
+        oov: str = "<UNK>",
+    ):
+        """
+        Args:
+          lang_dir:
+            Path to the lang directory. It is expected to contain the
+            following files::
+
+                - tokens.txt
+                - words.txt
+                - P.fst.txt
+
+            The above files are generated by the script `prepare.sh`. You
+            should have run it before running the training code.
+          uniq_filename:
+            File name to the lexicon in which every word has exactly one
+            pronunciation. We assume this file is inside the given `lang_dir`.
+
+          device:
+            It indicates CPU or CUDA.
+          oov:
+            Out of vocabulary word. When a word in the transcript
+            does not exist in the lexicon, it is replaced with `oov`.
+        """
+        self.lang_dir = Path(lang_dir)
+        self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename)
+        self.device = torch.device(device)
+
+        self.L_inv = self.lexicon.L_inv.to(self.device)
+
+        self.oov_id = self.lexicon.word_table[oov]
+
+        self.build_ctc_topo_P()
+
+    def build_ctc_topo_P(self):
+        """Built ctc_topo_P, the composition result of
+        ctc_topo and P, where P is a pre-trained bigram
+        word piece LM.
+        """
+        # Note: there is no need to save a pre-compiled P and ctc_topo
+        # as it is very fast to generate them.
+        logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
+        with open(self.lang_dir / "P.fst.txt") as f:
+            # P is not an acceptor because there is
+            # a back-off state, whose incoming arcs
+            # have label #0 and aux_label 0 (i.e., <eps>).
+            P = k2.Fsa.from_openfst(f.read(), acceptor=False)
+
+        first_token_disambig_id = self.lexicon.token_table["#0"]
+
+        # P.aux_labels is not needed in later computations, so
+        # remove it here.
+        del P.aux_labels
+        # CAUTION: The following line is crucial.
+        # Arcs entering the back-off state have label equal to #0.
+        # We have to change it to 0 here.
+        P.labels[P.labels >= first_token_disambig_id] = 0
+
+        P = k2.remove_epsilon(P)
+        P = k2.arc_sort(P)
+        P = P.to(self.device)
+        # Add epsilon self-loops to P because we want the
+        # following operation "k2.intersect" to run on GPU.
+        P_with_self_loops = k2.add_epsilon_self_loops(P)
+
+        max_token_id = max(self.lexicon.tokens)
+        logging.info(
+            f"Building ctc_topo (modified=False). max_token_id: {max_token_id}"
+        )
+        ctc_topo = k2.ctc_topo(max_token_id, modified=False, device=self.device)
+
+        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
+
+        logging.info("Building ctc_topo_P")
+        ctc_topo_P = k2.intersect(
+            ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
+        ).invert()
+
+        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
+
+    def compile(
+        self, texts: Iterable[str], replicate_den: bool = True
+    ) -> Tuple[k2.Fsa, k2.Fsa]:
+        """Create numerator and denominator graphs from transcripts
+        and the bigram phone LM.
+
+        Args:
+          texts:
+            A list of transcripts. Within a transcript, words are
+            separated by spaces. An example `texts` is given below::
+
+                ["Hello icefall", "LF-MMI training with icefall using k2"]
+
+          replicate_den:
+            If True, the returned den_graph is replicated to match the number
+            of FSAs in the returned num_graph; if False, the returned den_graph
+            contains only a single FSA
+        Returns:
+          A tuple (num_graph, den_graph), where
+
+          - `num_graph` is the numerator graph. It is an FsaVec with
+            shape `(len(texts), None, None)`.
+
+          - `den_graph` is the denominator graph. It is an FsaVec
+            with the same shape of the `num_graph` if replicate_den is
+            True; otherwise, it is an FsaVec containing only a single FSA.
+        """
+        transcript_fsa = self.build_transcript_fsa(texts)
+
+        # remove word IDs from transcript_fsa since it is not needed
+        del transcript_fsa.aux_labels
+        # NOTE: You can comment out the above statement
+        # if you want to run test/test_mmi_graph_compiler.py
+
+        transcript_fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
+            transcript_fsa
+        )
+
+        transcript_fsa_with_self_loops = k2.arc_sort(
+            transcript_fsa_with_self_loops
+        )
+
+        num = k2.compose(
+            self.ctc_topo_P,
+            transcript_fsa_with_self_loops,
+            treat_epsilons_specially=False,
+        )
+
+        # CAUTION: Due to the presence of P,
+        # the resulting `num` may not be connected
+        num = k2.connect(num)
+
+        num = k2.arc_sort(num)
+
+        ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
+        if replicate_den:
+            indexes = torch.zeros(
+                len(texts), dtype=torch.int32, device=self.device
+            )
+            den = k2.index_fsa(ctc_topo_P_vec, indexes)
+        else:
+            den = ctc_topo_P_vec
+
+        return num, den
+
+    def build_transcript_fsa(self, texts: List[str]) -> k2.Fsa:
+        """Convert transcripts to an FsaVec with the help of a lexicon
+        and word symbol table.
+
+        Args:
+          texts:
+            Each element is a transcript containing words separated by space(s).
+            For instance, it may be 'HELLO icefall', which contains
+            two words.
+
+        Returns:
+          Return an FST (FsaVec) corresponding to the transcript.
+          Its `labels` is token IDs and `aux_labels` is word IDs.
+        """
+        word_ids_list = []
+        for text in texts:
+            word_ids = []
+            for word in text.split():
+                if word in self.lexicon.word_table:
+                    word_ids.append(self.lexicon.word_table[word])
+                else:
+                    word_ids.append(self.oov_id)
+            word_ids_list.append(word_ids)
+
+        fsa = k2.linear_fsa(word_ids_list, self.device)
+        fsa = k2.add_epsilon_self_loops(fsa)
+
+        # The reason to use `invert_()` at the end is as follows:
+        #
+        # (1) The `labels` of L_inv is word IDs and `aux_labels` is token IDs
+        # (2) `fsa.labels` is word IDs
+        # (3) after intersection, the `labels` is still word IDs
+        # (4) after `invert_()`, the `labels` is token IDs
+        #     and `aux_labels` is word IDs
+        transcript_fsa = k2.intersect(
+            self.L_inv, fsa, treat_epsilons_specially=False
+        ).invert_()
+        transcript_fsa = k2.arc_sort(transcript_fsa)
+        return transcript_fsa
+
+    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
+        """Convert a list of texts to a list-of-list of piece IDs.
+
+        Args:
+          texts:
+            It is a list of strings. Each string consists of space(s)
+            separated words. An example containing two strings is given below:
+
+                ['HELLO ICEFALL', 'HELLO k2']
+            We assume it contains no OOVs. Otherwise, it will raise an
+            exception.
+        Returns:
+          Return a list-of-list of token IDs.
+        """
+        return self.lexicon.texts_to_token_ids(texts).tolist()
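Typical use of the new compiler during training might look like this (the lang dir and texts are illustrative; prepare.sh must have produced the files listed in the docstring):

from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

graph_compiler = MmiTrainingGraphCompiler("data/lang_bpe_500", device="cpu")
texts = ["HELLO WORLD", "MMI TRAINING"]
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
# num_graphs has one FSA per utterance; den_graphs matches its shape
# because replicate_den=True.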
@@ -19,14 +19,16 @@ import argparse
 import logging
 import os
 import subprocess
+import sys
 from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, TextIO, Tuple, Union
+from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union

 import k2
 import kaldialign
+import lhotse
 import torch
 import torch.distributed as dist

@@ -132,17 +134,82 @@ def setup_logger(
     logging.getLogger("").addHandler(console)


-def get_env_info():
-    """
-    TODO:
-    """
+def get_git_sha1():
+    git_commit = (
+        subprocess.run(
+            ["git", "rev-parse", "--short", "HEAD"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    dirty_commit = (
+        len(
+            subprocess.run(
+                ["git", "diff", "--shortstat"],
+                check=True,
+                stdout=subprocess.PIPE,
+            )
+            .stdout.decode()
+            .rstrip("\n")
+            .strip()
+        )
+        > 0
+    )
+    git_commit = (
+        git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+    )
+    return git_commit
+
+
+def get_git_date():
+    git_date = (
+        subprocess.run(
+            ["git", "log", "-1", "--format=%ad", "--date=local"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    return git_date
+
+
+def get_git_branch_name():
+    git_date = (
+        subprocess.run(
+            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    return git_date
+
+
+def get_env_info() -> Dict[str, Any]:
+    """Get the environment information."""
     return {
-        "k2-git-sha1": None,
-        "k2-version": None,
-        "lhotse-version": None,
-        "torch-version": None,
-        "icefall-sha1": None,
-        "icefall-version": None,
+        "k2-version": k2.version.__version__,
+        "k2-build-type": k2.version.__build_type__,
+        "k2-with-cuda": k2.with_cuda,
+        "k2-git-sha1": k2.version.__git_sha1__,
+        "k2-git-date": k2.version.__git_date__,
+        "lhotse-version": lhotse.__version__,
+        "torch-cuda-available": torch.cuda.is_available(),
+        "torch-cuda-version": torch.version.cuda,
+        "python-version": sys.version[:3],
+        "icefall-git-branch": get_git_branch_name(),
+        "icefall-git-sha1": get_git_sha1(),
+        "icefall-git-date": get_git_date(),
+        "icefall-path": str(Path(__file__).resolve().parent.parent),
+        "k2-path": str(Path(k2.__file__).resolve()),
+        "lhotse-path": str(Path(lhotse.__file__).resolve()),
     }
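The three git helpers above repeat the same subprocess boilerplate; one possible refactor (not part of this commit) is a small wrapper:

import subprocess

def run_git(*args: str) -> str:
    # Run a git command and return its stripped stdout.
    return (
        subprocess.run(["git", *args], check=True, stdout=subprocess.PIPE)
        .stdout.decode()
        .strip()
    )

# e.g. run_git("rev-parse", "--short", "HEAD")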
@@ -19,20 +19,21 @@
 from pathlib import Path

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
-from icefall.lexicon import BpeLexicon
+from icefall.lexicon import UniqLexicon

+ICEFALL_DIR = Path(__file__).resolve().parent.parent
+

 def test():
-    lang_dir = Path("data/lang/bpe")
+    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe"
     if not lang_dir.is_dir():
         return
-    # TODO: generate data for testing

     compiler = BpeCtcTrainingGraphCompiler(lang_dir)
     ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
     compiler.compile(ids)

-    lexicon = BpeLexicon(lang_dir)
+    lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt")
     ids0 = lexicon.words_to_piece_ids(["HELLO"])
     assert ids[0] == ids0.values().tolist()
@@ -1,30 +0,0 @@
-#!/usr/bin/env python3
-
-import copy
-import logging
-from pathlib import Path
-
-import k2
-import torch
-
-from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
-
-
-def test_bpe_mmi_graph_compiler():
-    lang_dir = Path("data/lang_bpe")
-    if lang_dir.is_dir() is False:
-        return
-    device = torch.device("cpu")
-    compiler = BpeMmiTrainingGraphCompiler(lang_dir, device=device)
-
-    texts = ["HELLO WORLD", "MMI TRAINING"]
-
-    num_graphs, den_graphs = compiler.compile(texts)
-    num_graphs.labels_sym = compiler.lexicon.token_table
-    num_graphs.aux_labels_sym = copy.deepcopy(compiler.lexicon.token_table)
-    num_graphs.aux_labels_sym._id2sym[0] = "<eps>"
-    num_graphs[0].draw("num_graphs_0.svg", title="HELLO WORLD")
-    num_graphs[1].draw("num_graphs_1.svg", title="HELLO WORLD")
-    print(den_graphs.shape)
-    print(den_graphs[0].shape)
-    print(den_graphs[0].num_arcs)
test/test_lexicon.py (mode changed: normal file → executable file, 173 lines changed)
@@ -14,80 +14,135 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+You can run this file in one of the two ways:
+
+    (1) cd icefall; pytest test/test_lexicon.py
+    (2) cd icefall; ./test/test_lexicon.py
+"""
+
+
+import os
+import shutil
+import sys
 from pathlib import Path

 import k2
-import pytest
-import torch
+import sentencepiece as spm

-from icefall.lexicon import BpeLexicon, Lexicon
+from icefall.lexicon import UniqLexicon

+TMP_DIR = "/tmp/icefall-test-lexicon"
+USING_PYTEST = "pytest" in sys.modules
+ICEFALL_DIR = Path(__file__).resolve().parent.parent


-@pytest.fixture
-def lang_dir(tmp_path):
-    phone2id = """
-        <eps> 0
-        a 1
-        b 2
-        f 3
-        o 4
-        r 5
-        z 6
-        SPN 7
-        #0 8
-    """
-    word2id = """
-        <eps> 0
-        foo 1
-        bar 2
-        baz 3
-        <UNK> 4
-        #0 5
-    """
+def generate_test_data():
+    Path(TMP_DIR).mkdir(exist_ok=True)
+    sentences = """
+        cat tac cat cat
+        at
+        tac at ta at at
+        at cat ct ct ta
+        cat cat cat cat
+        at at at at at at at
+    """

-    L = k2.Fsa.from_str(
-        """
-            0 0 7 4 0
-            0 7 -1 -1 0
-            0 1 3 1 0
-            0 3 2 2 0
-            0 5 2 3 0
-            1 2 4 0 0
-            2 0 4 0 0
-            3 4 1 0 0
-            4 0 5 0 0
-            5 6 1 0 0
-            6 0 6 0 0
-            7
-        """,
-        num_aux_labels=1,
-    )
+    transcript = Path(TMP_DIR) / "transcript_words.txt"
+    with open(transcript, "w") as f:
+        for line in sentences.strip().split("\n"):
+            f.write(f"{line}\n")

+    words = """
+        <eps> 0
+        <UNK> 1
+        at 2
+        cat 3
+        ct 4
+        ta 5
+        tac 6
+        #0 7
+        <s> 8
+        </s> 9
+    """
+    word_txt = Path(TMP_DIR) / "words.txt"
+    with open(word_txt, "w") as f:
+        for line in words.strip().split("\n"):
+            f.write(f"{line}\n")
+
+    vocab_size = 8
+
+    os.system(
+        f"""
+        cd {ICEFALL_DIR}/egs/librispeech/ASR
+
+        ./local/train_bpe_model.py \
+          --lang-dir {TMP_DIR} \
+          --vocab-size {vocab_size} \
+          --transcript {transcript}
+
+        ./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 1
+        """
     )

-    with open(tmp_path / "tokens.txt", "w") as f:
-        f.write(phone2id)
-    with open(tmp_path / "words.txt", "w") as f:
-        f.write(word2id)

-    torch.save(L.as_dict(), tmp_path / "L.pt")
+def delete_test_data():
+    shutil.rmtree(TMP_DIR)

-    return tmp_path


-def test_lexicon(lang_dir):
-    lexicon = Lexicon(lang_dir)
-    assert lexicon.tokens == list(range(1, 8))
+def uniq_lexicon_test():
+    lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="lexicon.txt")
+
+    # case 1: No OOV
+    texts = ["cat cat", "at ct", "at tac cat"]
+    token_ids = lexicon.texts_to_token_ids(texts)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(f"{TMP_DIR}/bpe.model")
+
+    expected_token_ids: List[List[int]] = sp.encode(texts, out_type=int)
+    assert token_ids.tolist() == expected_token_ids
+
+    # case 2: With OOV
+    texts = ["ca"]
+    token_ids = lexicon.texts_to_token_ids(texts)
+    expected_token_ids = sp.encode(texts, out_type=int)
+    assert token_ids.tolist() != expected_token_ids
+    # Note: sentencepiece breaks "ca" into "_ c a"
+    # But there is no word "ca" in the lexicon, so our
+    # implementation returns the id of "<UNK>"
+    print(token_ids, expected_token_ids)
+    assert token_ids.tolist() == [[sp.unk_id()]]
+
+    # case 3: With OOV
+    texts = ["foo"]
+    token_ids = lexicon.texts_to_token_ids(texts)
+    expected_token_ids = sp.encode(texts, out_type=int)
+    print(token_ids)
+    print(expected_token_ids)
+
+    # test ragged lexicon
+    ragged_lexicon = lexicon.ragged_lexicon.tolist()
+    word_disambig_id = lexicon.word_table["#0"]
+    for i in range(2, word_disambig_id):
+        piece_id = ragged_lexicon[i]
+        word = lexicon.word_table[i]
+        assert word == sp.decode(piece_id)
+        assert piece_id == sp.encode(word)


-def test_bpe_lexicon():
-    lang_dir = Path("data/lang/bpe")
-    if not lang_dir.is_dir():
-        return
-    # TODO: Generate test data for BpeLexicon
+def test_main():
+    generate_test_data()

-    lexicon = BpeLexicon(lang_dir)
-    words = ["<UNK>", "HELLO", "ZZZZ", "WORLD"]
-    ids = lexicon.words_to_piece_ids(words)
-    print(ids)
-    print([lexicon.token_table[i] for i in ids.values().tolist()])
+    uniq_lexicon_test()
+
+    if USING_PYTEST:
+        delete_test_data()
+
+
+def main():
+    test_main()
+
+
+if __name__ == "__main__" and not USING_PYTEST:
+    main()
test/test_mmi_graph_compiler.py (new executable file, 196 lines)
@@ -0,0 +1,196 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+You can run this file in one of the two ways:
+
+    (1) cd icefall; pytest test/test_mmi_graph_compiler.py
+    (2) cd icefall; ./test/test_mmi_graph_compiler.py
+"""
+
+import copy
+import os
+import shutil
+import sys
+from pathlib import Path
+
+import k2
+import sentencepiece as spm
+
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+
+TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
+USING_PYTEST = "pytest" in sys.modules
+ICEFALL_DIR = Path(__file__).resolve().parent.parent
+
+
+def generate_test_data():
+    Path(TMP_DIR).mkdir(exist_ok=True)
+    sentences = """
+        cat tac cat cat
+        at at cat at cat cat
+        tac at ta at at
+        at cat ct ct ta ct ct cat tac
+        cat cat cat cat
+        at at at at at at at
+    """
+
+    transcript = Path(TMP_DIR) / "transcript_words.txt"
+    with open(transcript, "w") as f:
+        for line in sentences.strip().split("\n"):
+            f.write(f"{line}\n")
+
+    words = """
+        <eps> 0
+        <UNK> 1
+        at 2
+        cat 3
+        ct 4
+        ta 5
+        tac 6
+        #0 7
+        <s> 8
+        </s> 9
+    """
+    word_txt = Path(TMP_DIR) / "words.txt"
+    with open(word_txt, "w") as f:
+        for line in words.strip().split("\n"):
+            f.write(f"{line}\n")
+
+    vocab_size = 8
+
+    os.system(
+        f"""
+        cd {ICEFALL_DIR}/egs/librispeech/ASR
+
+        ./local/train_bpe_model.py \
+          --lang-dir {TMP_DIR} \
+          --vocab-size {vocab_size} \
+          --transcript {transcript}
+
+        ./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 0
+
+        ./local/convert_transcript_words_to_tokens.py \
+          --lexicon {TMP_DIR}/lexicon.txt \
+          --transcript {transcript} \
+          --oov "<UNK>" \
+          > {TMP_DIR}/transcript_tokens.txt
+
+        ./shared/make_kn_lm.py \
+          -ngram-order 2 \
+          -text {TMP_DIR}/transcript_tokens.txt \
+          -lm {TMP_DIR}/P.arpa
+
+        python3 -m kaldilm \
+          --read-symbol-table="{TMP_DIR}/tokens.txt" \
+          --disambig-symbol='#0' \
+          --max-order=2 \
+          {TMP_DIR}/P.arpa > {TMP_DIR}/P.fst.txt
+        """
+    )
+
+
+def delete_test_data():
+    shutil.rmtree(TMP_DIR)
+
+
+def mmi_graph_compiler_test():
+    # Caution:
+    # You have to uncomment
+    #   del transcript_fsa.aux_labels
+    # in mmi_graph_compiler.py
+    # to see the correct aux_labels in *.svg
+    graph_compiler = MmiTrainingGraphCompiler(
+        lang_dir=TMP_DIR, uniq_filename="lexicon.txt"
+    )
+    print(graph_compiler.device)
+    L_inv = graph_compiler.L_inv
+    L = k2.invert(L_inv)
+
+    L.labels_sym = graph_compiler.lexicon.token_table
+    L.aux_labels_sym = graph_compiler.lexicon.word_table
+    L.draw(f"{TMP_DIR}/L.svg", title="L")
+
+    L_inv.labels_sym = graph_compiler.lexicon.word_table
+    L_inv.aux_labels_sym = graph_compiler.lexicon.token_table
+    L_inv.draw(f"{TMP_DIR}/L_inv.svg", title="L")
+
+    ctc_topo_P = graph_compiler.ctc_topo_P
+    ctc_topo_P.labels_sym = copy.deepcopy(graph_compiler.lexicon.token_table)
+    ctc_topo_P.labels_sym._id2sym[0] = "<blk>"
+    ctc_topo_P.labels_sym._sym2id["<blk>"] = 0
+    ctc_topo_P.aux_labels_sym = graph_compiler.lexicon.token_table
+    ctc_topo_P.draw(f"{TMP_DIR}/ctc_topo_P.svg", title="ctc_topo_P")
+
+    print(ctc_topo_P.num_arcs)
+    print(k2.connect(ctc_topo_P).num_arcs)
+
+    with open(str(TMP_DIR) + "/P.fst.txt") as f:
+        # P is not an acceptor because there is
+        # a back-off state, whose incoming arcs
+        # have label #0 and aux_label 0 (i.e., <eps>).
+        P = k2.Fsa.from_openfst(f.read(), acceptor=False)
+    P.labels_sym = graph_compiler.lexicon.token_table
+    P.aux_labels_sym = graph_compiler.lexicon.token_table
+    P.draw(f"{TMP_DIR}/P.svg", title="P")
+
+    ctc_topo = k2.ctc_topo(max(graph_compiler.lexicon.tokens), False)
+    ctc_topo.labels_sym = ctc_topo_P.labels_sym
+    ctc_topo.aux_labels_sym = graph_compiler.lexicon.token_table
+    ctc_topo.draw(f"{TMP_DIR}/ctc_topo.svg", title="ctc_topo")
+    print("p num arcs", P.num_arcs)
+    print("ctc_topo num arcs", ctc_topo.num_arcs)
+    print("ctc_topo_P num arcs", ctc_topo_P.num_arcs)
+
+    texts = ["cat at ct", "at ta", "cat tac"]
+    transcript_fsa = graph_compiler.build_transcript_fsa(texts)
+    transcript_fsa[0].draw(f"{TMP_DIR}/cat_at_ct.svg", title="cat_at_ct")
+    transcript_fsa[1].draw(f"{TMP_DIR}/at_ta.svg", title="at_ta")
+    transcript_fsa[2].draw(f"{TMP_DIR}/cat_tac.svg", title="cat_tac")
+
+    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
+    num_graphs[0].draw(f"{TMP_DIR}/num_cat_at_ct.svg", title="num_cat_at_ct")
+    num_graphs[1].draw(f"{TMP_DIR}/num_at_ta.svg", title="num_at_ta")
+    num_graphs[2].draw(f"{TMP_DIR}/num_cat_tac.svg", title="num_cat_tac")
+
+    den_graphs[0].draw(f"{TMP_DIR}/den_cat_at_ct.svg", title="den_cat_at_ct")
+    den_graphs[2].draw(f"{TMP_DIR}/den_cat_tac.svg", title="den_cat_tac")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(f"{TMP_DIR}/bpe.model")
+
+    texts = ["cat at cat", "at tac"]
+    token_ids = graph_compiler.texts_to_ids(texts)
+    expected_token_ids = sp.encode(texts)
+    assert token_ids == expected_token_ids
+
+
+def test_main():
+    generate_test_data()
+
+    mmi_graph_compiler_test()
+
+    if USING_PYTEST:
+        delete_test_data()
+
+
+def main():
+    test_main()
+
+
+if __name__ == "__main__" and not USING_PYTEST:
+    main()