Add OTC-related scripts using phones as units instead of BPEs (#1602)

* add otc related scripts using phone instead of bpe

This commit is contained in: parent 25cabb7663, commit 9a17f4ce41
egs/librispeech/WSASR/conformer_ctc2/decode_phone.py (new executable file, 592 lines)
@@ -0,0 +1,592 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
#                                            Fangjun Kuang,
#                                            Quandong Wang)
#           2023 Johns Hopkins University (Author: Dongji Gao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    get_texts,
    load_averaged_model,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--otc-token",
        type=str,
        default="<star>",
        help="OTC token",
    )

    parser.add_argument(
        "--blank-bias",
        type=float,
        default=0,
        help="bias (log-prob) added to blank token during decoding",
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=20,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=5,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="ctc-greedy-search",
        help="""Decoding method.
        Supported values are:
        - (0) 1best. Extract the best path from the decoding lattice as the
          decoding result.
        """,
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=0,
        help="""Number of decoder layers of the transformer decoder.
        Setting this to 0 will not create the decoder at all (pure CTC model).
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="conformer_ctc2/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_phone",
        help="The lang dir",
    )

    parser.add_argument(
        "--lm-dir",
        type=str,
        default="data/lm",
        help="""The n-gram LM dir.
        It should contain either G_4_gram.pt or G_4_gram.fst.txt
        """,
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            # parameters for conformer
            "subsampling_factor": 4,
            "feature_dim": 80,
            "nhead": 8,
            "dim_feedforward": 2048,
            "encoder_dim": 512,
            "num_encoder_layers": 12,
            # parameters for decoding
            "search_beam": 20,
            "output_beam": 8,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            "env_info": get_env_info(),
        }
    )
    return params


def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
    # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != 0:
            new_hyp.append(hyp[cur])
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: k2.Fsa,
    batch: dict,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for decoding. For example,
           if no rescoring is used, the key is the string `no_rescore`.
           If LM rescoring is used, the key is the string `lm_scale_xxx`,
           where `xxx` is the value of `lm_scale`. An example key is
           `lm_scale_0.7`
    - value: It contains the decoding result. `len(value)` equals the
             batch size. `value[i]` is the decoding result for the i-th
             utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.

        - If params.method is "1best", it uses 1-best decoding without
          LM rescoring.

      model:
        The neural model.
      HLG:
        The decoding graph.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict. Note: If it decodes to nothing, then return None.
    """
    device = HLG.device
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]

    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
    # nnet_output is (N, T, C)
    nnet_output[:, :, 0] += params.blank_bias

    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            torch.div(
                supervisions["start_frame"],
                params.subsampling_factor,
                rounding_mode="trunc",
            ),
            torch.div(
                supervisions["num_frames"],
                params.subsampling_factor,
                rounding_mode="trunc",
            ),
        ),
        1,
    ).to(torch.int32)

    decoding_graph = HLG

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor + 2,
    )

    if params.method in ["1best"]:
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        key = "no_rescore"

        hyps = get_texts(best_path)
        hyps = [[word_table[i] for i in ids] for ids in hyps]

        return {key: hyps}
    else:
        assert False, f"Unsupported decoding method: {params.method}"


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: k2.Fsa,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      HLG:
        The decoding graph.
      word_table:
        It is the word symbol table.
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return a dict, whose key may be "no-rescore" if no LM rescoring
      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference transcript, and the predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            batch=batch,
            word_table=word_table,
            G=G,
        )

        if hyps_dict is not None:
            for lm_scale, hyps in hyps_dict.items():
                this_batch = []
                assert len(hyps) == len(texts)
                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                    ref_words = ref_text.split()
                    this_batch.append((cut_id, ref_words, hyp_words))

                results[lm_scale].extend(this_batch)
        else:
            assert len(results) > 0, "It should not decode to empty in the first batch!"
            this_batch = []
            hyp_words = []
            for ref_text in texts:
                ref_words = ref_text.split()
                this_batch.append((ref_words, hyp_words))

            for lm_scale in results.keys():
                results[lm_scale].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    if params.method in ("attention-decoder", "rnn-lm"):
        # Set it to False since there are too many logs.
        enable_log = False
    else:
        enable_log = True
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=enable_log
            )
            test_set_wers[key] = wer

        if enable_log:
            logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
    args.lm_dir = Path(args.lm_dir)

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    # remove otc_token from decoding units
    max_token_id = len(lexicon.tokens) - 1
    num_classes = max_token_id + 1  # +1 for the blank

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    params.num_classes = num_classes

    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
    HLG = HLG.to(device)
    assert HLG.requires_grad is False

    if not hasattr(HLG, "lm_scores"):
        HLG.lm_scores = HLG.scores.clone()

    G = None

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.encoder_dim,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        num_encoder_layers=params.num_encoder_layers,
        num_decoder_layers=params.num_decoder_layers,
    )

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    test_dl = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dl):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            HLG=HLG,
            word_table=lexicon.word_table,
        )

        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

    logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
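The helper remove_duplicates_and_blank in decode_phone.py above implements the usual CTC collapse rule: merge consecutive repeated IDs, then drop the blank (ID 0). A minimal sketch of the same rule on a toy frame-level hypothesis, assuming blank ID 0; this is an illustration, not code from the commit:

# Toy frame-level IDs as a CTC model might emit them (0 is the blank).
frame_ids = [0, 3, 3, 0, 0, 5, 5, 5, 2, 0]

collapsed = []
prev = None
for i in frame_ids:
    if i != prev and i != 0:  # keep a symbol only when it changes and is not blank
        collapsed.append(i)
    prev = i

print(collapsed)  # [3, 5, 2] -- the same result remove_duplicates_and_blank returns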
egs/librispeech/WSASR/conformer_ctc2/train_phone.py (new executable file, 1124 lines)
(File diff suppressed because it is too large.)
egs/librispeech/WSASR/local/download_lm.py (new executable file, 146 lines)
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file downloads the following LibriSpeech LM files:

    - 3-gram.pruned.1e-7.arpa.gz
    - 4-gram.arpa.gz
    - librispeech-vocab.txt
    - librispeech-lexicon.txt
    - librispeech-lm-norm.txt.gz

from http://www.openslr.org/resources/11
and save them in the user provided directory.

Files are not re-downloaded if they already exist.

Usage:
    ./local/download_lm.py --out-dir ./download/lm
"""

import argparse
import gzip
import logging
import os
import shutil
from pathlib import Path

from tqdm.auto import tqdm


# This function is copied from lhotse
def tqdm_urlretrieve_hook(t):
    """Wraps tqdm instance.
    Don't forget to close() or __exit__()
    the tqdm instance once you're done with it (easiest using `with` syntax).
    Example
    -------
    >>> from urllib.request import urlretrieve
    >>> with tqdm(...) as t:
    ...     reporthook = tqdm_urlretrieve_hook(t)
    ...     urlretrieve(..., reporthook=reporthook)

    Source: https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """
    last_b = [0]

    def update_to(b=1, bsize=1, tsize=None):
        """
        b  : int, optional
            Number of blocks transferred so far [default: 1].
        bsize  : int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize  : int, optional
            Total size (in tqdm units). If [default: None] or -1,
            remains unchanged.
        """
        if tsize not in (None, -1):
            t.total = tsize
        displayed = t.update((b - last_b[0]) * bsize)
        last_b[0] = b
        return displayed

    return update_to


# This function is copied from lhotse
def urlretrieve_progress(url, filename=None, data=None, desc=None):
    """
    Works exactly like urllib.request.urlretrieve, but attaches a tqdm hook to
    display a progress bar of the download.
    Use "desc" argument to display a user-readable string that informs what is
    being downloaded.
    """
    from urllib.request import urlretrieve

    with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=desc) as t:
        reporthook = tqdm_urlretrieve_hook(t)
        return urlretrieve(url=url, filename=filename, reporthook=reporthook, data=data)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--out-dir", type=str, help="Output directory.")

    args = parser.parse_args()
    return args


def main(out_dir: str):
    url = "http://www.openslr.org/resources/11"
    out_dir = Path(out_dir)

    files_to_download = (
        "3-gram.pruned.1e-7.arpa.gz",
        "4-gram.arpa.gz",
        "librispeech-vocab.txt",
        "librispeech-lexicon.txt",
        "librispeech-lm-norm.txt.gz",
    )

    for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
        filename = out_dir / f
        if filename.is_file() is False:
            urlretrieve_progress(
                f"{url}/{f}",
                filename=filename,
                desc=f"Downloading {filename}",
            )
        else:
            logging.info(f"{filename} already exists - skipping")

        if ".gz" in str(filename):
            unzipped = Path(os.path.splitext(filename)[0])
            if unzipped.is_file() is False:
                with gzip.open(filename, "rb") as f_in:
                    with open(unzipped, "wb") as f_out:
                        shutil.copyfileobj(f_in, f_out)
            else:
                logging.info(f"{unzipped} already exist - skipping")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    logging.info(f"out_dir: {args.out_dir}")

    main(out_dir=args.out_dir)
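The tqdm hook above follows urllib's reporthook contract: urlretrieve repeatedly calls the hook with (block_number, block_size, total_size) and the hook advances the progress bar. A small sketch, not part of the commit, that drives the hook by hand with a fake 10-block download, assuming tqdm_urlretrieve_hook from the file above is in scope:

from tqdm.auto import tqdm

with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc="fake") as t:
    hook = tqdm_urlretrieve_hook(t)  # returns update_to(b, bsize, tsize)
    for block in range(1, 11):
        hook(block, 1024, 10 * 1024)  # the bar reaches 10 KiB after the last call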
egs/librispeech/WSASR/local/prepare_otc_lang.py (new executable file, 469 lines)
@@ -0,0 +1,469 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#           2024 Johns Hopkins University (author: Dongji Gao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:

1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt

2. Generate tokens.txt, the token table mapping a token to a unique integer.

3. Generate words.txt, the word table mapping a word to a unique integer.

4. Generate L.pt, in k2 format. It can be loaded by

        d = torch.load("L.pt")
        lexicon = k2.Fsa.from_dict(d)

5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import logging
import math
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import k2
import torch

from icefall.lexicon import write_lexicon
from icefall.utils import str2bool

Lexicon = List[Tuple[str, List[str]]]


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain a file lexicon.txt.
        Generated files by this script are saved into this directory.
        """,
    )

    parser.add_argument(
        "--otc-token",
        type=str,
        default="<star>",
        help="The OTC token in lexicon",
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="""True for debugging, which will generate
        a visualization of the lexicon FST.

        Caution: If your lexicon contains hundreds of thousands
        of lines, please set it to False!
        """,
    )

    return parser.parse_args()


def read_lexicon(
    filename: str,
) -> List[Tuple[str, List[str]]]:
    """Read a lexicon from `filename`.

    Each line in the lexicon contains "word p1 p2 p3 ...".
    That is, the first field is a word and the remaining
    fields are tokens. Fields are separated by space(s).

    Args:
      filename:
        Path to the lexicon.txt

    Returns:
      A list of tuples, e.g., [('w', ['p1', 'p2']), ('w1', ['p3', 'p4'])]
    """
    ans = []

    with open(filename, "r", encoding="utf-8") as f:
        whitespace = re.compile("[ \t]+")
        for line in f:
            a = whitespace.split(line.strip(" \t\r\n"))
            if len(a) == 0:
                continue

            if len(a) < 2:
                logging.info(f"Found bad line {line} in lexicon file {filename}")
                logging.info("Every line is expected to contain at least 2 fields")
                continue
            word = a[0]
            if word == "<eps>":
                logging.info(f"Found bad line {line} in lexicon file {filename}")
                logging.info("<eps> should not be a valid word")
                continue

            tokens = a[1:]
            ans.append((word, tokens))

    return ans


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
    """Write a symbol to ID mapping to a file.

    Note:
      No need to implement `read_mapping` as it can be done
      through :func:`k2.SymbolTable.from_file`.

    Args:
      filename:
        Filename to save the mapping.
      sym2id:
        A dict mapping symbols to IDs.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for sym, i in sym2id.items():
            f.write(f"{sym} {i}\n")


def get_tokens(lexicon: Lexicon) -> List[str]:
    """Get tokens from a lexicon.

    Args:
      lexicon:
        It is the return value of :func:`read_lexicon`.
    Returns:
      Return a list of unique tokens.
    """
    ans = set()
    for _, tokens in lexicon:
        ans.update(tokens)
    sorted_ans = sorted(list(ans))
    return sorted_ans


def get_words(lexicon: Lexicon) -> List[str]:
    """Get words from a lexicon.

    Args:
      lexicon:
        It is the return value of :func:`read_lexicon`.
    Returns:
      Return a list of unique words.
    """
    ans = set()
    for word, _ in lexicon:
        ans.add(word)
    sorted_ans = sorted(list(ans))
    return sorted_ans


def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
    """It adds pseudo-token disambiguation symbols #1, #2 and so on
    at the ends of tokens to ensure that all pronunciations are different,
    and that none is a prefix of another.

    See also add_lex_disambig.pl from kaldi.

    Args:
      lexicon:
        It is returned by :func:`read_lexicon`.
    Returns:
      Return a tuple with two elements:

        - The output lexicon with disambiguation symbols
        - The ID of the max disambiguation symbol that appears
          in the lexicon
    """

    # (1) Work out the count of each token-sequence in the
    # lexicon.
    count = defaultdict(int)
    for _, tokens in lexicon:
        count[" ".join(tokens)] += 1

    # (2) For each left sub-sequence of each token-sequence, note down
    # that it exists (for identifying prefixes of longer strings).
    issubseq = defaultdict(int)
    for _, tokens in lexicon:
        tokens = tokens.copy()
        tokens.pop()
        while tokens:
            issubseq[" ".join(tokens)] = 1
            tokens.pop()

    # (3) For each entry in the lexicon:
    # if the token sequence is unique and is not a
    # prefix of another word, no disambig symbol.
    # Else output #1, or #2, #3, ... if the same token-seq
    # has already been assigned a disambig symbol.
    ans = []

    # We start with #1 since #0 has its own purpose
    first_allowed_disambig = 1
    max_disambig = first_allowed_disambig - 1
    last_used_disambig_symbol_of = defaultdict(int)

    for word, tokens in lexicon:
        tokenseq = " ".join(tokens)
        assert tokenseq != ""
        if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
            ans.append((word, tokens))
            continue

        cur_disambig = last_used_disambig_symbol_of[tokenseq]
        if cur_disambig == 0:
            cur_disambig = first_allowed_disambig
        else:
            cur_disambig += 1

        if cur_disambig > max_disambig:
            max_disambig = cur_disambig
        last_used_disambig_symbol_of[tokenseq] = cur_disambig
        tokenseq += f" #{cur_disambig}"
        ans.append((word, tokenseq.split()))
    return ans, max_disambig


def generate_id_map(
    symbols: List[str],
) -> Dict[str, int]:
    """Generate ID maps, i.e., map a symbol to a unique ID.

    Args:
      symbols:
        A list of unique symbols.
    Returns:
      A dict containing the mapping between symbols and IDs.
    """
    return {sym: i for i, sym in enumerate(symbols)}


def add_self_loops(
    arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
    """Adds self-loops to states of an FST to propagate disambiguation symbols
    through it. They are added on each state with non-epsilon output symbols
    on at least one arc out of the state.

    See also fstaddselfloops.pl from Kaldi. One difference is that
    Kaldi uses OpenFst style FSTs and it has multiple final states.
    This function uses k2 style FSTs and it does not need to add self-loops
    to the final state.

    The input label of a self-loop is `disambig_token`, while the output
    label is `disambig_word`.

    Args:
      arcs:
        A list-of-list. The sublist contains
        `[src_state, dest_state, label, aux_label, score]`
      disambig_token:
        It is the token ID of the symbol `#0`.
      disambig_word:
        It is the word ID of the symbol `#0`.

    Return:
      Return new `arcs` containing self-loops.
    """
    states_needs_self_loops = set()
    for arc in arcs:
        src, dst, ilabel, olabel, score = arc
        if olabel != 0:
            states_needs_self_loops.add(src)

    ans = []
    for s in states_needs_self_loops:
        ans.append([s, s, disambig_token, disambig_word, 0])

    return arcs + ans


def lexicon_to_fst(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    sil_token: str = "SIL",
    sil_prob: float = 0.5,
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format) with optional silence at
    the beginning and end of each word.

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      sil_token:
        The silence token.
      sil_prob:
        The probability for adding a silence at the beginning and end
        of the word.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    assert sil_prob > 0.0 and sil_prob < 1.0
    # CAUTION: we use score, i.e, negative cost.
    sil_score = math.log(sil_prob)
    no_sil_score = math.log(1.0 - sil_prob)

    start_state = 0
    loop_state = 1  # words enter and leave from here
    sil_state = 2  # words terminate here when followed by silence; this state
    # has a silence transition to loop_state.
    next_state = 3  # the next un-allocated state, will be incremented as we go.
    arcs = []

    assert token2id["<eps>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0

    sil_token = token2id[sil_token]

    arcs.append([start_state, loop_state, eps, eps, no_sil_score])
    arcs.append([start_state, sil_state, eps, eps, sil_score])
    arcs.append([sil_state, loop_state, sil_token, eps, 0])

    for word, tokens in lexicon:
        assert len(tokens) > 0, f"{word} has no pronunciations"
        cur_state = loop_state

        word = word2id[word]
        tokens = [token2id[i] for i in tokens]

        for i in range(len(tokens) - 1):
            w = word if i == 0 else eps
            arcs.append([cur_state, next_state, tokens[i], w, 0])

            cur_state = next_state
            next_state += 1

        # now for the last token of this word
        # It has two out-going arcs, one to the loop state,
        # the other one to the sil_state.
        i = len(tokens) - 1
        w = word if i == 0 else eps
        arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
        arcs.append([cur_state, sil_state, tokens[i], w, sil_score])

    if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
        arcs = add_self_loops(
            arcs,
            disambig_token=disambig_token,
            disambig_word=disambig_word,
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    lexicon_filename = lang_dir / "lexicon.txt"
    otc_token = args.otc_token
    sil_token = "SIL"
    sil_prob = 0.5

    lexicon = read_lexicon(lexicon_filename)
    tokens = get_tokens(lexicon)
    words = get_words(lexicon)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    lexicon.append((otc_token, [otc_token]))
    tokens.append(otc_token)
    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in tokens
        tokens.append(f"#{i}")

    assert "<eps>" not in tokens
    tokens = ["<eps>"] + tokens

    assert "<eps>" not in words
    assert "#0" not in words
    assert "<s>" not in words
    assert "</s>" not in words

    words = ["<eps>"] + words + [otc_token, "#0", "<s>", "</s>"]

    token2id = generate_id_map(tokens)
    word2id = generate_id_map(words)

    write_mapping(lang_dir / "tokens.txt", token2id)
    write_mapping(lang_dir / "words.txt", word2id)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst(
        lexicon,
        token2id=token2id,
        word2id=word2id,
        sil_token=sil_token,
        sil_prob=sil_prob,
    )

    L_disambig = lexicon_to_fst(
        lexicon_disambig,
        token2id=token2id,
        word2id=word2id,
        sil_token=sil_token,
        sil_prob=sil_prob,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if args.debug:
        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")

        L.labels_sym = labels_sym
        L.aux_labels_sym = aux_labels_sym
        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")

        L_disambig.labels_sym = labels_sym
        L_disambig.aux_labels_sym = aux_labels_sym
        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")


if __name__ == "__main__":
    main()
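add_disambig_symbols above mirrors Kaldi's add_lex_disambig.pl: homophones and pronunciations that are prefixes of longer ones get #1, #2, ... appended. A toy check, not part of the commit, assuming the function above is in scope:

toy_lexicon = [
    ("READ", ["R", "EH", "D"]),  # same pronunciation as RED
    ("RED", ["R", "EH", "D"]),
    ("I", ["AY"]),               # prefix of ICE's pronunciation
    ("ICE", ["AY", "S"]),
]

disambig, max_disambig = add_disambig_symbols(toy_lexicon)
print(disambig)
# [('READ', ['R', 'EH', 'D', '#1']), ('RED', ['R', 'EH', 'D', '#2']),
#  ('I', ['AY', '#1']), ('ICE', ['AY', 'S'])]
print(max_disambig)  # 2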
@@ -30,7 +30,8 @@ stop_stage=100
 # - librispeech-lm-norm.txt.gz
 #
 otc_token="<star>"
-feature_type="ssl"
+# ssl or fbank
+feature_type="fbank"

 dl_dir=$PWD/download
 manifests_dir="data/manifests"
@@ -40,9 +41,6 @@ lm_dir="data/lm"
 perturb_speed=false


-# ssl or fbank
-
-. ./cmd.sh
 . shared/parse_options.sh || exit 1

 # vocab size for sentence piece models.
@@ -192,7 +190,23 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Prepare G"
+  log "Stage 5: Prepare phone based lang"
+  lang_dir="data/lang_phone"
+  mkdir -p ${lang_dir}
+
+  if [ ! -f $lang_dir/lexicon.txt ]; then
+    (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+      cat - $dl_dir/lm/librispeech-lexicon.txt |
+      sort | uniq > $lang_dir/lexicon.txt
+  fi
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_otc_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare G"
   # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm

@@ -216,18 +230,30 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   fi
 fi

-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Compile HLG"
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compile HLG"
   # Note If ./local/compile_hlg.py throws OOM,
   # please switch to the following command
   #
   # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone

   for vocab_size in ${vocab_sizes[@]}; do
-    bpe_lang_dir="data/lang_bpe_${vocab_size}"
+    lang_dir="data/lang_bpe_${vocab_size}"
     echo "LM DIR: ${lm_dir}"
     ./local/compile_hlg.py \
       --lm-dir "${lm_dir}" \
       --lang-dir "${bpe_lang_dir}"
   done
 fi
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Compile HLG"
+  # Note If ./local/compile_hlg.py throws OOM,
+  # please switch to the following command
+  #
+  # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
+  lang_dir="data/lang_phone"
+  echo "LM DIR: ${lm_dir}"
+  ./local/compile_hlg.py \
+    --lm-dir "${lm_dir}" \
+    --lang-dir "${lang_dir}"
+fi
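The new Stage 5 above builds data/lang_phone/lexicon.txt by prepending three special entries (!SIL, <SPOKEN_NOISE>, <UNK>) to the downloaded librispeech-lexicon.txt and then runs ./local/prepare_otc_lang.py on it. A rough sketch of how such a merged file parses into (word, phones) pairs; the pronunciations shown are illustrative placeholders, the real ones come from the downloaded lexicon:

sample = """!SIL SIL
<SPOKEN_NOISE> SPN
<UNK> SPN
ABANDON AH0 B AE1 N D AH0 N
ABATE AH0 B EY1 T
"""

for line in sample.strip().splitlines():
    word, *phones = line.split()
    print(word, phones)
# read_lexicon() in prepare_otc_lang.py returns the same pairs,
# e.g. ('!SIL', ['SIL']) and ('ABATE', ['AH0', 'B', 'EY1', 'T']).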
egs/librispeech/WSASR/shared (new symbolic link)
@@ -0,0 +1 @@
../../../icefall/shared/
icefall/otc_phone_graph_compiler.py (new file, 232 lines)
@@ -0,0 +1,232 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#           2023 Johns Hopkins University (author: Dongji Gao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path
from typing import List, Union

import k2
import torch

from icefall.lexicon import Lexicon
from icefall.utils import str2bool


class OtcPhoneTrainingGraphCompiler(object):
    def __init__(
        self,
        lexicon: Lexicon,
        otc_token: str,
        oov: str = "<UNK>",
        device: Union[str, torch.device] = "cpu",
        initial_bypass_weight: float = 0.0,
        initial_self_loop_weight: float = 0.0,
        bypass_weight_decay: float = 0.0,
        self_loop_weight_decay: float = 0.0,
    ) -> None:
        """
        Args:
          lexicon:
            It is built from `data/lang/lexicon.txt`.
          otc_token:
            The special token in OTC that represents all non-blank tokens.
          device:
            It indicates CPU or CUDA.
        """
        self.device = device
        L_inv = lexicon.L_inv.to(self.device)
        assert L_inv.requires_grad is False
        assert oov in lexicon.word_table

        self.L_inv = k2.arc_sort(L_inv)
        self.oov_id = lexicon.word_table[oov]
        self.otc_id = lexicon.word_table[otc_token]
        self.word_table = lexicon.word_table

        max_token_id = max(lexicon.tokens)
        ctc_topo = k2.ctc_topo(max_token_id, modified=False)
        self.ctc_topo = ctc_topo.to(self.device)
        self.max_token_id = max_token_id

        self.initial_bypass_weight = initial_bypass_weight
        self.initial_self_loop_weight = initial_self_loop_weight
        self.bypass_weight_decay = bypass_weight_decay
        self.self_loop_weight_decay = self_loop_weight_decay

    def get_max_token_id(self):
        return self.max_token_id

    def make_arc(
        self,
        from_state: int,
        to_state: int,
        symbol: Union[str, int],
        weight: float,
    ):
        return f"{from_state} {to_state} {symbol} {weight}"

    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of word IDs.

        Args:
          texts:
            It is a list of strings. Each string consists of space-separated
            words. An example containing two strings is given below:

                ['HELLO ICEFALL', 'HELLO k2']
        Returns:
          Return a list-of-list of word IDs.
        """
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split():
                if word in self.word_table:
                    word_ids.append(self.word_table[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)
        return word_ids_list

    def compile(
        self,
        texts: List[str],
        allow_bypass_arc: str2bool = True,
        allow_self_loop_arc: str2bool = True,
        bypass_weight: float = 0.0,
        self_loop_weight: float = 0.0,
    ) -> k2.Fsa:
        """Build an OTC graph from texts (a list of transcripts).

        Args:
          texts:
            A list of strings. Each string contains a sentence for an utterance.
            A sentence consists of space-separated words. An example `texts`
            looks like:
                ['hello icefall', 'CTC training with k2']
          allow_bypass_arc:
            Whether to add bypass arcs to the training graph for substitution
            and insertion errors (wrong or extra words in the transcript).
          allow_self_loop_arc:
            Whether to add self-loop arcs to the training graph for deletion
            errors (missing words in the transcript).
          bypass_weight:
            Weight associated with bypass arcs.
          self_loop_weight:
            Weight associated with self-loop arcs.

        Return:
          Return an FsaVec, which is the result of composing a
          CTC topology with OTC FSAs constructed from the given texts.
        """

        transcript_fsa = self.convert_transcript_to_fsa(
            texts,
            allow_bypass_arc,
            allow_self_loop_arc,
            bypass_weight,
            self_loop_weight,
        )
        fsa_with_self_loop = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
        fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop)

        graph = k2.compose(
            self.ctc_topo,
            fsa_with_self_loop,
            treat_epsilons_specially=False,
        )
        assert graph.requires_grad is False

        return graph

    def convert_transcript_to_fsa(
        self,
        texts: List[str],
        allow_bypass_arc: str2bool = True,
        allow_self_loop_arc: str2bool = True,
        bypass_weight: float = 0.0,
        self_loop_weight: float = 0.0,
    ):

        word_fsa_list = []
        for text in texts:
            word_ids = []

            for word in text.split():
                if word in self.word_table:
                    word_ids.append(self.word_table[word])
                else:
                    word_ids.append(self.oov_id)

            arcs = []
            start_state = 0
            cur_state = start_state
            next_state = 1

            for word_id in word_ids:
                if allow_self_loop_arc:
                    self_loop_arc = self.make_arc(
                        cur_state,
                        cur_state,
                        self.otc_id,
                        self_loop_weight,
                    )
                    arcs.append(self_loop_arc)

                arc = self.make_arc(cur_state, next_state, word_id, 0.0)
                arcs.append(arc)

                if allow_bypass_arc:
                    bypass_arc = self.make_arc(
                        cur_state,
                        next_state,
                        self.otc_id,
                        bypass_weight,
                    )
                    arcs.append(bypass_arc)

                cur_state = next_state
                next_state += 1

            if allow_self_loop_arc:
                self_loop_arc = self.make_arc(
                    cur_state,
                    cur_state,
                    self.otc_id,
                    self_loop_weight,
                )
                arcs.append(self_loop_arc)

            # Deal with final state
            final_state = next_state
            final_arc = self.make_arc(cur_state, final_state, -1, 0.0)
            arcs.append(final_arc)
            arcs.append(f"{final_state}")
            sorted_arcs = sorted(arcs, key=lambda a: int(a.split()[0]))

            word_fsa = k2.Fsa.from_str("\n".join(sorted_arcs))
            word_fsa = k2.arc_sort(word_fsa)
            word_fsa_list.append(word_fsa)

        word_fsa_vec = k2.create_fsa_vec(word_fsa_list).to(self.device)
        word_fsa_vec_with_self_loop = k2.add_epsilon_self_loops(word_fsa_vec)

        fsa = k2.intersect(
            self.L_inv, word_fsa_vec_with_self_loop, treat_epsilons_specially=False
        )
        ans_fsa = fsa.invert_()
        return k2.arc_sort(ans_fsa)
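convert_transcript_to_fsa above is the heart of OTC training: each word arc gets a parallel bypass arc labelled with the OTC token (for substitution and insertion errors in the transcript) and every state gets an OTC self-loop (for deletion errors), matching the compile() docstring. A minimal standalone sketch, not from the commit, that prints the arcs such a per-utterance FSA would contain for a two-word transcript with hypothetical IDs and weights:

otc_id = 7               # hypothetical word ID of "<star>"
word_ids = [10, 11]      # hypothetical IDs of the two transcript words
self_loop_weight = -0.5  # hypothetical penalties (the training script may decay them)
bypass_weight = -0.5

arcs, cur, nxt = [], 0, 1
for wid in word_ids:
    arcs.append(f"{cur} {cur} {otc_id} {self_loop_weight}")  # deletion self-loop
    arcs.append(f"{cur} {nxt} {wid} 0.0")                    # the word itself
    arcs.append(f"{cur} {nxt} {otc_id} {bypass_weight}")     # substitution/insertion bypass
    cur, nxt = nxt, nxt + 1
arcs.append(f"{cur} {cur} {otc_id} {self_loop_weight}")      # trailing self-loop
arcs.append(f"{cur} {nxt} -1 0.0")                           # arc to the final state
arcs.append(f"{nxt}")                                        # final state line
print("\n".join(arcs))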