Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)

Commit 78ddda4296 (parent f0eb710163): iwslt_ta ST recipe
@@ -227,15 +227,9 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_ta_1000/bpe.model",
+        default="data/lang_bpe_1000/bpe.model",
         help="Path to source data BPE model",
     )
-    parser.add_argument(
-        "--bpe-tgt-model",
-        type=str,
-        default="data/lang_bpe_en_1000/bpe.model",
-        help="Path to target data BPE model",
-    )
     parser.add_argument(
         "--initial-lr",
         type=float,

@@ -617,7 +611,6 @@ def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
-    sp_tgt: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
     warmup: float = 1.0,

@@ -655,11 +648,8 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
     #pdb.set_trace()
     texts = batch["supervisions"]["text"]
-    tgt_texts = batch["supervisions"]["tgt_text"]
     y = sp.encode(texts, out_type=int)
-    y_tgt = sp_tgt.encode(tgt_texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
-    y_tgt = k2.RaggedTensor(y_tgt).to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(

@@ -736,7 +726,6 @@ def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
-    sp_tgt: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:

@@ -750,7 +739,6 @@ def compute_validation_loss(
             params=params,
             model=model,
             sp=sp,
-            sp_tgt=sp_tgt,
             batch=batch,
             is_training=False,
         )

@@ -774,7 +762,6 @@ def train_one_epoch(
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
-    sp_tgt: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,

@@ -834,7 +821,6 @@ def train_one_epoch(
                     params=params,
                     model=model,
                     sp=sp,
-                    sp_tgt=sp_tgt,
                     batch=batch,
                     is_training=True,
                     warmup=(

@@ -927,7 +913,6 @@ def train_one_epoch(
                 params=params,
                 model=model,
                 sp=sp,
-                sp_tgt=sp_tgt,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )

@@ -1007,9 +992,7 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")

     sp = spm.SentencePieceProcessor()
-    sp_tgt = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
-    sp_tgt.load(params.bpe_tgt_model)
     # pdb.set_trace()
     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")

@@ -1139,7 +1122,6 @@ def run(rank, world_size, args):
             train_dl=train_dl,
             optimizer=optimizer,
             sp=sp,
-            sp_tgt=sp_tgt,
             params=params,
             warmup=0.0 if params.start_epoch == 1 else 1.0,
         )

@@ -1167,7 +1149,6 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             scheduler=scheduler,
             sp=sp,
-            sp_tgt=sp_tgt,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,

@@ -1236,7 +1217,6 @@ def scan_pessimistic_batches_for_oom(
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
-    sp_tgt: spm.SentencePieceProcessor,
     params: AttributeDict,
     warmup: float,
 ):

@@ -1258,7 +1238,6 @@ def scan_pessimistic_batches_for_oom(
                 params=params,
                 model=model,
                 sp=sp,
-                sp_tgt=sp_tgt,
                 batch=batch,
                 is_training=True,
                 warmup=warmup,

egs/iwslt22_ta/ST/README.md (new file, 28 lines)
@@ -0,0 +1,28 @@

# IWSLT_Ta

The IWSLT Tunisian dataset is a 3-way parallel dataset of approximately 160 hours
and 200,000 lines of aligned audio, Tunisian transcripts, and English translations. The data
is conversational telephone speech recorded at a sampling rate of 8 kHz. The train, dev,
and test1 splits of the IWSLT 2022 shared task correspond to LDC catalog number LDC2022E01;
note that access to this data requires an LDC subscription from your institution. To obtain the
dataset, download the predefined splits by cloning https://github.com/kevinduh/iwslt22-dialect.git,
as shown below. For more detailed information about the shared task, please refer to the task
paper: https://aclanthology.org/2022.iwslt-1.10/.
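
Clone the split definitions with:

```
git clone https://github.com/kevinduh/iwslt22-dialect.git
```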

## Stateless Pruned Transducer Performance Record (after 20 epochs)

| Decoding method      | dev BLEU | test BLEU | comment                                         |
|----------------------|----------|-----------|-------------------------------------------------|
| modified beam search | 11.1     | 9.2       | --epoch 20, --avg 10, beam(10), pruned range 5  |

## Zipformer Performance Record (after 20 epochs)

| Decoding method      | dev BLEU | test BLEU | comment                                         |
|----------------------|----------|-----------|-------------------------------------------------|
| modified beam search | 14.7     | 12.4      | --epoch 20, --avg 10, beam(10), pruned range 5  |
| modified beam search | 15.5     | 13.0      | --epoch 20, --avg 10, beam(20), pruned range 5  |
| modified beam search | 17.6     | 14.8      | --epoch 20, --avg 10, beam(10), pruned range 10 |

See [RESULTS](/egs/iwslt22_ta/ST/RESULTS.md) for details.

egs/iwslt22_ta/ST/RESULTS.md (new file, 123 lines)
@@ -0,0 +1,123 @@

# Results

### IWSLT Tunisian training results (Stateless Pruned Transducer)

#### 2023-06-01

| Decoding method      | dev BLEU | test BLEU | comment                                        |
|----------------------|----------|-----------|------------------------------------------------|
| modified beam search | 11.1     | 9.2       | --epoch 20, --avg 10, beam(10), pruned range 5 |

The training command for reproducing is given below:

```
export CUDA_VISIBLE_DEVICES="0,1,2,3"

./pruned_transducer_stateless5/train_st.py \
  --world-size 4 \
  --num-epochs 20 \
  --start-epoch 1 \
  --exp-dir pruned_transducer_stateless5/exp \
  --max-duration 300 \
  --bucketing-sampler 1 \
  --num-buckets 50
```

The tensorboard training log can be found at
https://tensorboard.dev/experiment/YnzQNCVDSxCvP1onrCzg9A/

The decoding command is:

```
for method in modified_beam_search; do
  for epoch in 15 20; do
    ./pruned_transducer_stateless5/decode_st.py \
      --epoch $epoch \
      --beam-size 20 \
      --avg 10 \
      --exp-dir ./pruned_transducer_stateless5/exp_st_single_task2 \
      --max-duration 300 \
      --decoding-method $method \
      --max-sym-per-frame 1 \
      --num-encoder-layers 12 \
      --dim-feedforward 1024 \
      --nhead 8 \
      --encoder-dim 256 \
      --decoder-dim 256 \
      --joiner-dim 256 \
      --use-averaged-model true
  done
done
```

### IWSLT Tunisian training results (Zipformer)

#### 2023-06-01

You can find a pretrained model, training logs, decoding logs, and decoding results at:

| Decoding method      | dev BLEU | test BLEU | comment                                         |
|----------------------|----------|-----------|-------------------------------------------------|
| modified beam search | 14.7     | 12.4      | --epoch 20, --avg 10, beam(10), pruned range 5  |
| modified beam search | 15.5     | 13.0      | --epoch 20, --avg 10, beam(20), pruned range 5  |
| modified beam search | 17.6     | 14.8      | --epoch 20, --avg 10, beam(10), pruned range 10 |

To reproduce the above results, use the following command for training:

```
# Note: the model was trained on a V-100 32GB GPU
# ST medium model, 42.5M parameters, prune-range 10
./zipformer/train_st.py \
  --world-size 4 \
  --num-epochs 20 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp-st-medium-prun10 \
  --causal 0 \
  --num-encoder-layers 2,2,2,2,2,2 \
  --feedforward-dim 512,768,1024,1536,1024,768 \
  --encoder-dim 192,256,384,512,384,256 \
  --encoder-unmasked-dim 192,192,256,256,256,192 \
  --max-duration 300 \
  --context-size 2 \
  --prune-range 10
```

The tensorboard training log can be found at
https://tensorboard.dev/experiment/4sa4M1mRQyKjOE4o95mWUw/

The decoding command is:

```
for method in modified_beam_search; do
  for epoch in 15 20; do
    ./zipformer/decode_st.py \
      --epoch $epoch \
      --beam-size 20 \
      --avg 10 \
      --exp-dir ./zipformer/exp-st-medium-prun10 \
      --max-duration 800 \
      --decoding-method $method \
      --num-encoder-layers 2,2,2,2,2,2 \
      --feedforward-dim 512,768,1024,1536,1024,768 \
      --encoder-dim 192,256,384,512,384,256 \
      --encoder-unmasked-dim 192,192,256,256,256,192 \
      --context-size 2 \
      --use-averaged-model true
  done
done
```

egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_streaming.py → egs/iwslt22_ta/ST/local/__init__.py (renamed, normal file → executable file, 0 changes)

egs/iwslt22_ta/ST/local/cer.py (new file, 58 lines)
@@ -0,0 +1,58 @@

#!/usr/bin/python
# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
This script computes CER for the decodings generated by the icefall recipe.
"""

import argparse
import os

import jiwer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dec-file",
        type=str,
        help="file with decoded text",
    )

    return parser


def cer_(file):
    hyp = []
    ref = []
    cer_results = 0
    ref_lens = 0
    with open(file, "r", encoding="utf-8") as dec:
        for line in dec:
            # Each line is expected to look like:
            #   <utt-id>\tref=['w1', 'w2', ...]  or  <utt-id>\thyp=['w1', 'w2', ...]
            id, target = line.split("\t")
            id = id[0:-2]  # strip the last two characters of the utterance id
            target, txt = target.split("=")
            if target == "ref":
                words = txt.strip().strip("[]").split(", ")
                word_list = [word.strip("'") for word in words]
                ref.append(" ".join(word_list))
            elif target == "hyp":
                words = txt.strip().strip("[]").split(", ")
                word_list = [word.strip("'") for word in words]
                hyp.append(" ".join(word_list))

    # Length-weighted average CER over all reference/hypothesis pairs.
    for h, r in zip(hyp, ref):
        # breakpoint()
        cer_results += jiwer.cer(r, h) * len(r)
        ref_lens += len(r)
    print(os.path.basename(file))
    print(cer_results / ref_lens)


def main():
    parser = get_args()
    args = parser.parse_args()
    cer_(args.dec_file)


if __name__ == "__main__":
    main()
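
A minimal sketch of how this script might be run; the path below is only illustrative and should point at the decoded-text file written by one of the decode scripts:

```
# Hypothetical file name; substitute the decoded-text file from your decoding run.
python ./local/cer.py --dec-file pruned_transducer_stateless5/exp/recogs-dev.txt
```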

egs/iwslt22_ta/ST/local/compile_hlg.py (new executable file, 159 lines)
@@ -0,0 +1,159 @@

#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates HLG from

    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
    - L, the lexicon, built from lang_dir/L_disambig.pt

      Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_3_gram.fst.txt

The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_HLG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

    Return:
      An FSA representing HLG.
    """
    lexicon = Lexicon(lang_dir)
    max_token_id = max(lexicon.tokens)
    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G_3_gram.pt").is_file():
        logging.info("Loading pre-compiled G_3_gram")
        d = torch.load("data/lm/G_3_gram.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        with open("data/lm/G_3_gram.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    LG.labels[LG.labels >= first_token_disambig_id] = 0
    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set LG.properties to None
    LG.__dict__["_properties"] = None

    assert isinstance(LG.aux_labels, k2.RaggedTensor)
    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    # CAUTION: The name of the inner_labels is fixed
    # to `tokens`. If you want to change it, please
    # also change other places in icefall that are using
    # it.
    HLG = k2.compose(H, LG, inner_labels="tokens")

    logging.info("Connecting LG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting LG")
    HLG = k2.arc_sort(HLG)
    logging.info(f"HLG.shape: {HLG.shape}")

    return HLG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "HLG.pt").is_file():
        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    HLG = compile_HLG(lang_dir)
    logging.info(f"Saving HLG.pt to {lang_dir}")
    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
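
A typical invocation is not shown in this commit; a sketch, assuming the lang directory already contains lexicon.txt and L_disambig.pt and that data/lm/G_3_gram.fst.txt has been prepared:

```
./local/compile_hlg.py --lang-dir data/lang_bpe_1000
```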

egs/iwslt22_ta/ST/local/compute_fbank_gpu.py (new executable file, 173 lines)
@@ -0,0 +1,173 @@

#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the IWSLT Tunisian dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-splits",
        type=int,
        default=20,
        help="Number of splits for the train set.",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Start index of the train set split.",
    )
    parser.add_argument(
        "--stop",
        type=int,
        default=-1,
        help="Stop index of the train set split.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="If set, only compute features for the dev and val set.",
    )

    return parser.parse_args()


def compute_fbank_gpu(args):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = os.cpu_count()
    num_mel_bins = 80
    sampling_rate = 16000
    sr = 16000

    dataset_parts = (
        "train",
        "test1",
        "dev",
    )
    manifests = read_manifests_if_cached(
        prefix="iwslt", dataset_parts=dataset_parts, output_dir=src_dir
    )
    assert manifests is not None

    extractor = KaldifeatFbank(
        KaldifeatFbankConfig(
            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
            device="cuda",
        )
    )

    for partition, m in manifests.items():
        if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
            logging.info(f"{partition} already exists - skipping.")
            continue
        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=m["recordings"],
            supervisions=m["supervisions"],
        )

        logging.info("About to split cuts into smaller chunks.")
        if sr is not None:
            logging.info(f"Resampling to {sr}")
            cut_set = cut_set.resample(sr)

        cut_set = cut_set.trim_to_supervisions(
            keep_overlapping=False,
            keep_all_channels=False,
        )
        cut_set = cut_set.filter(lambda c: c.duration >= 0.2 and c.duration <= 30)
        if "train" in partition:
            cut_set = (
                cut_set
                + cut_set.perturb_speed(0.9)
                + cut_set.perturb_speed(1.1)
            )
            cut_set = cut_set.to_eager()
            chunk_size = len(cut_set) // args.num_splits
            cut_sets = cut_set.split_lazy(
                output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
                chunk_size=chunk_size,
            )
            start = args.start
            stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
            num_digits = len(str(args.num_splits))

            for i in range(start, stop):
                idx = f"{i + 1}".zfill(num_digits)
                cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz"
                logging.info(f"Processing train split {i}")
                cs = cut_sets[i].compute_and_store_features_batch(
                    extractor=extractor,
                    storage_path=output_dir / f"feats_train_{idx}",
                    batch_duration=1000,
                    num_workers=8,
                    storage_type=LilcomChunkyWriter,
                    overwrite=True,
                )
                cs.to_file(cuts_train_idx_path)
        else:
            logging.info(f"Processing {partition}")
            cut_set = cut_set.compute_and_store_features_batch(
                extractor=extractor,
                storage_path=output_dir / f"feats_{partition}",
                batch_duration=1000,
                num_workers=10,
                storage_type=LilcomChunkyWriter,
                overwrite=True,
            )
            cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()

    compute_fbank_gpu(args)
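
A hedged example of running the GPU feature extraction, assuming the iwslt-prefixed manifests are already under data/manifests; --start and --stop select which of the --num-splits train chunks to process:

```
./local/compute_fbank_gpu.py --num-splits 20 --start 0 --stop 20
```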

egs/iwslt22_ta/ST/local/compute_fbank_musan.py (new executable file, 109 lines)
@@ -0,0 +1,109 @@

#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def is_cut_long(c: MonoCut) -> bool:
    return c.duration > 5


def compute_fbank_musan():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = min(30, os.cpu_count())
    num_mel_bins = 80

    dataset_parts = (
        "music",
        "speech",
        "noise",
    )
    prefix = "musan"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"

    if musan_cuts_path.is_file():
        logging.info(f"{musan_cuts_path} already exists - skipping")
        return

    logging.info("Extracting features for Musan")

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        # create chunks of Musan with duration 5 - 10 seconds
        musan_cuts = (
            CutSet.from_manifests(
                recordings=combine(part["recordings"] for part in manifests.values())
            )
            .cut_into_windows(10.0)
            .filter(is_cut_long)
            .compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/musan_feats",
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
        )
        musan_cuts.to_file(musan_cuts_path)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    compute_fbank_musan()
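
The script takes no arguments; it reads the musan manifests from data/manifests and writes features to data/fbank:

```
./local/compute_fbank_musan.py
```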

egs/iwslt22_ta/ST/local/convert_transcript_words_to_tokens.py (new executable file, 107 lines)
@@ -0,0 +1,107 @@

#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
"""
Convert a transcript file containing words to a corpus file containing tokens
for LM training with the help of a lexicon.

If the lexicon contains phones, the resulting LM will be a phone LM; If the
lexicon contains word pieces, the resulting LM will be a word piece LM.

If a word has multiple pronunciations, the one that appears first in the lexicon
is kept; others are removed.

If the input transcript is:

    hello zoo world hello
    world zoo
    foo zoo world hellO

and if the lexicon is

    <UNK> SPN
    hello h e l l o 2
    hello h e l l o
    world w o r l d
    zoo z o o

Then the output is

    h e l l o 2 z o o w o r l d h e l l o 2
    w o r l d z o o
    SPN z o o w o r l d SPN
"""

import argparse
from pathlib import Path
from typing import Dict, List

from generate_unique_lexicon import filter_multiple_pronunications

from icefall.lexicon import read_lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transcript",
        type=str,
        help="The input transcript file."
        "We assume that the transcript file consists of "
        "lines. Each line consists of space separated words.",
    )
    parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
    parser.add_argument(
        "--oov", type=str, default="<UNK>", help="The OOV word."
    )

    return parser.parse_args()


def process_line(
    lexicon: Dict[str, List[str]], line: str, oov_token: str
) -> None:
    """
    Args:
      lexicon:
        A dict containing pronunciations. Its keys are words and values
        are pronunciations (i.e., tokens).
      line:
        A line of transcript consisting of space(s) separated words.
      oov_token:
        The pronunciation of the oov word if a word in `line` is not present
        in the lexicon.
    Returns:
      Return None.
    """
    s = ""
    words = line.strip().split()
    for i, w in enumerate(words):
        tokens = lexicon.get(w, oov_token)
        s += " ".join(tokens)
        s += " "
    print(s.strip())


def main():
    args = get_args()
    assert Path(args.lexicon).is_file()
    assert Path(args.transcript).is_file()
    assert len(args.oov) > 0

    # Only the first pronunciation of a word is kept
    lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))

    lexicon = dict(lexicon)

    assert args.oov in lexicon

    oov_token = lexicon[args.oov]

    with open(args.transcript) as f:
        for line in f:
            process_line(lexicon=lexicon, line=line, oov_token=oov_token)


if __name__ == "__main__":
    main()
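
A sketch of one possible invocation; the file locations are illustrative (the lexicon is typically the uniq_lexicon.txt produced by generate_unique_lexicon.py):

```
# Illustrative paths; adjust to your lang directory and desired output file.
./local/convert_transcript_words_to_tokens.py \
  --transcript data/lang_phone/transcript_words.txt \
  --lexicon data/lang_phone/uniq_lexicon.txt \
  --oov "<UNK>" > data/lm/transcript_tokens.txt
```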

egs/iwslt22_ta/ST/local/cuts_validate.py (new executable file, 109 lines)
@@ -0,0 +1,109 @@

#!/usr/bin/python
# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
This script helps validating the prepared manifests (recordings, supervisions)
and CutSets.
"""
import argparse
import logging
import pdb

from lhotse import CutSet, RecordingSet, SupervisionSet
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--sup",
        type=str,
        default="",
        help="Supervisions file",
    )
    parser.add_argument(
        "--rec",
        type=str,
        default="",
        help="Recordings file",
    )
    parser.add_argument(
        "--cut",
        type=str,
        default="",
        help="Cutset file",
    )
    parser.add_argument(
        "--savecut",
        type=str,
        default="",
        help="name of the cutset to be saved",
    )

    return parser


def valid_asr(cut):
    tol = 2e-3
    i = 0
    total_dur = 0
    for c in cut:
        if c.supervisions != []:
            if c.supervisions[0].end > c.duration + tol:
                logging.info(f"Supervision beyond the cut. Cut number: {i}")
                total_dur += c.duration
                logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}")
            elif c.supervisions[0].start < -tol:
                logging.info(f"Supervision starts before the cut. Cut number: {i}")
                logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}")
            else:
                continue
        else:
            logging.info("Empty supervision")
            logging.info(f"id: {c.id}")
        i += 1
    logging.info(f"filtered duration: {total_dur}")


def main():
    parser = get_parser()
    args = parser.parse_args()
    if args.cut != "":
        cuts = CutSet.from_file(args.cut)
    else:
        recordings = RecordingSet.from_file(args.rec)
        supervisions = SupervisionSet.from_file(args.sup)
        logging.info("Example from supervisions:")
        logging.info(supervisions[0])
        logging.info("Example from recordings")
        print(recordings[0])
        logging.info("Fixing manifests")
        recordings, supervisions = fix_manifests(recordings, supervisions)

        logging.info("Validating manifests")
        validate_recordings_and_supervisions(recordings, supervisions)

        cuts = CutSet.from_manifests(
            recordings=recordings,
            supervisions=supervisions,
        )

    cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
    logging.info("Example from cut:")
    print(cuts[100])
    breakpoint()
    cuts.describe()
    logging.info("Validating manifests for ASR")
    valid_asr(cuts)
    if args.savecut != "":
        cuts.to_file(args.savecut)


if __name__ == "__main__":
    main()
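
Two hypothetical ways to run the validation; the manifest file names below are illustrative:

```
# Validate an existing CutSet
./local/cuts_validate.py --cut data/fbank/cuts_train.jsonl.gz

# Or build cuts from recordings + supervisions and save the validated CutSet
./local/cuts_validate.py \
  --rec data/manifests/recordings_train.jsonl.gz \
  --sup data/manifests/supervisions_train.jsonl.gz \
  --savecut data/fbank/cuts_train_validated.jsonl.gz
```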

egs/iwslt22_ta/ST/local/display_manifest_statistics.py (new executable file, 97 lines)
@@ -0,0 +1,97 @@

#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.

See the function `remove_short_and_long_utt()` in transducer/train.py
for usage.
"""


from lhotse import load_manifest


def main():
    # path = "./data/fbank/cuts_train.jsonl.gz"
    path = "./data/fbank/cuts_dev.jsonl.gz"
    # path = "./data/fbank/cuts_test.jsonl.gz"

    cuts = load_manifest(path)
    cuts.describe()


if __name__ == "__main__":
    main()

"""
# train

Cuts count: 1125309
Total duration (hours): 3403.9
Speech duration (hours): 3403.9 (100.0%)
***
Duration statistics (seconds):
mean    10.9
std     10.1
min     0.2
25%     5.2
50%     7.8
75%     12.7
99%     52.0
99.5%   65.1
99.9%   99.5
max     228.9

# test
Cuts count: 5365
Total duration (hours): 9.6
Speech duration (hours): 9.6 (100.0%)
***
Duration statistics (seconds):
mean    6.4
std     1.5
min     1.6
25%     5.3
50%     6.5
75%     7.6
99%     9.5
99.5%   9.7
99.9%   10.3
max     12.4

# dev
Cuts count: 5002
Total duration (hours): 8.5
Speech duration (hours): 8.5 (100.0%)
***
Duration statistics (seconds):
mean    6.1
std     1.7
min     1.5
25%     4.8
50%     6.2
75%     7.4
99%     9.5
99.5%   9.7
99.9%   10.1
max     20.3
"""
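
The manifest path is hard-coded in main(); edit it to the split you want to inspect and run the script directly:

```
./local/display_manifest_statistics.py
```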

egs/iwslt22_ta/ST/local/download_lm.py (new executable file, 97 lines)
@@ -0,0 +1,97 @@

#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file downloads the following LibriSpeech LM files:

    - 3-gram.pruned.1e-7.arpa.gz
    - 4-gram.arpa.gz
    - librispeech-vocab.txt
    - librispeech-lexicon.txt

from http://www.openslr.org/resources/11
and save them in the user provided directory.

Files are not re-downloaded if they already exist.

Usage:
    ./local/download_lm.py --out-dir ./download/lm
"""

import argparse
import gzip
import logging
import os
import shutil
from pathlib import Path

from lhotse.utils import urlretrieve_progress
from tqdm.auto import tqdm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--out-dir", type=str, help="Output directory.")

    args = parser.parse_args()
    return args


def main(out_dir: str):
    url = "http://www.openslr.org/resources/11"
    out_dir = Path(out_dir)

    files_to_download = (
        "3-gram.pruned.1e-7.arpa.gz",
        "4-gram.arpa.gz",
        "librispeech-vocab.txt",
        "librispeech-lexicon.txt",
    )

    for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
        filename = out_dir / f
        if filename.is_file() is False:
            urlretrieve_progress(
                f"{url}/{f}",
                filename=filename,
                desc=f"Downloading {filename}",
            )
        else:
            logging.info(f"{filename} already exists - skipping")

        if ".gz" in str(filename):
            unzipped = Path(os.path.splitext(filename)[0])
            if unzipped.is_file() is False:
                with gzip.open(filename, "rb") as f_in:
                    with open(unzipped, "wb") as f_out:
                        shutil.copyfileobj(f_in, f_out)
            else:
                logging.info(f"{unzipped} already exist - skipping")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    logging.info(f"out_dir: {args.out_dir}")

    main(out_dir=args.out_dir)

egs/iwslt22_ta/ST/local/generate_unique_lexicon.py (new executable file, 100 lines)
@@ -0,0 +1,100 @@

#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file takes as input a lexicon.txt and output a new lexicon,
in which each word has a unique pronunciation.

The way to do this is to keep only the first pronunciation of a word
in lexicon.txt.
"""


import argparse
import logging
from pathlib import Path
from typing import List, Tuple

from icefall.lexicon import read_lexicon, write_lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain a file lexicon.txt.
        This file will generate a new file uniq_lexicon.txt
        in it.
        """,
    )

    return parser.parse_args()


def filter_multiple_pronunications(
    lexicon: List[Tuple[str, List[str]]]
) -> List[Tuple[str, List[str]]]:
    """Remove multiple pronunciations of words from a lexicon.

    If a word has more than one pronunciation in the lexicon, only
    the first one is kept, while other pronunciations are removed
    from the lexicon.

    Args:
      lexicon:
        The input lexicon, containing a list of (word, [p1, p2, ..., pn]),
        where "p1, p2, ..., pn" are the pronunciations of the "word".
    Returns:
      Return a new lexicon where each word has a unique pronunciation.
    """
    seen = set()
    ans = []

    for word, tokens in lexicon:
        if word in seen:
            continue
        seen.add(word)
        ans.append((word, tokens))
    return ans


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    lexicon_filename = lang_dir / "lexicon.txt"

    in_lexicon = read_lexicon(lexicon_filename)

    out_lexicon = filter_multiple_pronunications(in_lexicon)

    write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon)

    logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}")
    logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
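
A hedged example invocation; the lang directory must already contain lexicon.txt:

```
./local/generate_unique_lexicon.py --lang-dir data/lang_phone
```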

egs/iwslt22_ta/ST/local/prep_lexicon.sh (new executable file, 21 lines)
@@ -0,0 +1,21 @@

#!/usr/bin/env bash

# Copyright 2022 QCRI (author: Amir Hussein)
# Apache 2.0
# This script prepares the graphemic lexicon.

dir=data/local/dict
stage=0
lang_dir_src=$1
lang_dir_tgt=$2

cat $lang_dir_src/transcript_words.txt | tr -s " " "\n" | sort -u > $lang_dir_src/uniq_words
cat $lang_dir_tgt/transcript_words.txt | tr -s " " "\n" | sort -u > $lang_dir_tgt/uniq_words

echo "$0: processing lexicon text and creating lexicon... $(date)."
# remove vowels and rare alef wasla
cat $lang_dir_src/uniq_words | sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sed -r '/^\s*$/d' | sort -u > $lang_dir_src/words.txt
cat $lang_dir_tgt/uniq_words | sed -r '/^\s*$/d' | sort -u > $lang_dir_tgt/words.txt

echo "$0: Lexicon preparation succeeded"
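
The script takes the source and target lang directories as positional arguments; each must contain transcript_words.txt. A sketch using the BPE lang directory names that appear elsewhere in this commit:

```
./local/prep_lexicon.sh data/lang_bpe_ta_1000 data/lang_bpe_en_1000
```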

egs/iwslt22_ta/ST/local/prepare_lang.py (new executable file, 414 lines; the diff view is truncated below)
@@ -0,0 +1,414 @@

#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:

1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt

2. Generate tokens.txt, the token table mapping a token to a unique integer.

3. Generate words.txt, the word table mapping a word to a unique integer.

4. Generate L.pt, in k2 format. It can be loaded by

        d = torch.load("L.pt")
        lexicon = k2.Fsa.from_dict(d)

5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import k2
import torch

from icefall.lexicon import read_lexicon, write_lexicon
from icefall.utils import str2bool

Lexicon = List[Tuple[str, List[str]]]


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain a file lexicon.txt.
        Generated files by this script are saved into this directory.
        """,
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="""True for debugging, which will generate
        a visualization of the lexicon FST.

        Caution: If your lexicon contains hundreds of thousands
        of lines, please set it to False!
        """,
    )

    return parser.parse_args()


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
    """Write a symbol to ID mapping to a file.

    Note:
      No need to implement `read_mapping` as it can be done
      through :func:`k2.SymbolTable.from_file`.

    Args:
      filename:
        Filename to save the mapping.
      sym2id:
        A dict mapping symbols to IDs.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for sym, i in sym2id.items():
            f.write(f"{sym} {i}\n")


def get_tokens(lexicon: Lexicon) -> List[str]:
    """Get tokens from a lexicon.

    Args:
      lexicon:
        It is the return value of :func:`read_lexicon`.
    Returns:
      Return a list of unique tokens.
    """
    ans = set()
    for _, tokens in lexicon:
        ans.update(tokens)
    sorted_ans = sorted(list(ans))
    return sorted_ans


def get_words(lexicon: Lexicon) -> List[str]:
    """Get words from a lexicon.

    Args:
      lexicon:
        It is the return value of :func:`read_lexicon`.
    Returns:
      Return a list of unique words.
    """
    ans = set()
    for word, _ in lexicon:
        ans.add(word)
    sorted_ans = sorted(list(ans))
    return sorted_ans


def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
    """It adds pseudo-token disambiguation symbols #1, #2 and so on
    at the ends of tokens to ensure that all pronunciations are different,
    and that none is a prefix of another.

    See also add_lex_disambig.pl from kaldi.

    Args:
      lexicon:
        It is returned by :func:`read_lexicon`.
    Returns:
      Return a tuple with two elements:

        - The output lexicon with disambiguation symbols
        - The ID of the max disambiguation symbol that appears
          in the lexicon
    """

    # (1) Work out the count of each token-sequence in the
    # lexicon.
    count = defaultdict(int)
    for _, tokens in lexicon:
        count[" ".join(tokens)] += 1

    # (2) For each left sub-sequence of each token-sequence, note down
    # that it exists (for identifying prefixes of longer strings).
    issubseq = defaultdict(int)
    for _, tokens in lexicon:
        tokens = tokens.copy()
        tokens.pop()
        while tokens:
            issubseq[" ".join(tokens)] = 1
            tokens.pop()

    # (3) For each entry in the lexicon:
    # if the token sequence is unique and is not a
    # prefix of another word, no disambig symbol.
    # Else output #1, or #2, #3, ... if the same token-seq
    # has already been assigned a disambig symbol.
    ans = []

    # We start with #1 since #0 has its own purpose
    first_allowed_disambig = 1
    max_disambig = first_allowed_disambig - 1
    last_used_disambig_symbol_of = defaultdict(int)

    for word, tokens in lexicon:
        tokenseq = " ".join(tokens)
        assert tokenseq != ""
        if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
            ans.append((word, tokens))
            continue

        cur_disambig = last_used_disambig_symbol_of[tokenseq]
        if cur_disambig == 0:
            cur_disambig = first_allowed_disambig
        else:
            cur_disambig += 1

        if cur_disambig > max_disambig:
            max_disambig = cur_disambig
        last_used_disambig_symbol_of[tokenseq] = cur_disambig
        tokenseq += f" #{cur_disambig}"
        ans.append((word, tokenseq.split()))
    return ans, max_disambig


def generate_id_map(symbols: List[str]) -> Dict[str, int]:
    """Generate ID maps, i.e., map a symbol to a unique ID.

    Args:
      symbols:
        A list of unique symbols.
    Returns:
      A dict containing the mapping between symbols and IDs.
    """
    return {sym: i for i, sym in enumerate(symbols)}


def add_self_loops(
    arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
    """Adds self-loops to states of an FST to propagate disambiguation symbols
    through it. They are added on each state with non-epsilon output symbols
    on at least one arc out of the state.

    See also fstaddselfloops.pl from Kaldi. One difference is that
    Kaldi uses OpenFst style FSTs and it has multiple final states.
    This function uses k2 style FSTs and it does not need to add self-loops
    to the final state.

    The input label of a self-loop is `disambig_token`, while the output
    label is `disambig_word`.

    Args:
      arcs:
        A list-of-list. The sublist contains
        `[src_state, dest_state, label, aux_label, score]`
      disambig_token:
        It is the token ID of the symbol `#0`.
      disambig_word:
        It is the word ID of the symbol `#0`.

    Return:
      Return new `arcs` containing self-loops.
    """
    states_needs_self_loops = set()
    for arc in arcs:
        src, dst, ilabel, olabel, score = arc
        if olabel != 0:
            states_needs_self_loops.add(src)

    ans = []
    for s in states_needs_self_loops:
        ans.append([s, s, disambig_token, disambig_word, 0])

    return arcs + ans


def lexicon_to_fst(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    sil_token: str = "SIL",
    sil_prob: float = 0.5,
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format) with optional silence at
    the beginning and end of each word.

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
A dict mapping words to IDs.
|
||||||
|
sil_token:
|
||||||
|
The silence token.
|
||||||
|
sil_prob:
|
||||||
|
The probability for adding a silence at the beginning and end
|
||||||
|
of the word.
|
||||||
|
need_self_loops:
|
||||||
|
If True, add self-loop to states with non-epsilon output symbols
|
||||||
|
on at least one arc out of the state. The input label for this
|
||||||
|
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||||
|
Returns:
|
||||||
|
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||||
|
"""
|
||||||
|
assert sil_prob > 0.0 and sil_prob < 1.0
|
||||||
|
# CAUTION: we use score, i.e, negative cost.
|
||||||
|
sil_score = math.log(sil_prob)
|
||||||
|
no_sil_score = math.log(1.0 - sil_prob)
|
||||||
|
|
||||||
|
start_state = 0
|
||||||
|
loop_state = 1 # words enter and leave from here
|
||||||
|
sil_state = 2 # words terminate here when followed by silence; this state
|
||||||
|
# has a silence transition to loop_state.
|
||||||
|
# the next un-allocated state, will be incremented as we go.
|
||||||
|
next_state = 3
|
||||||
|
arcs = []
|
||||||
|
|
||||||
|
assert token2id["<eps>"] == 0
|
||||||
|
assert word2id["<eps>"] == 0
|
||||||
|
|
||||||
|
eps = 0
|
||||||
|
|
||||||
|
sil_token = token2id[sil_token]
|
||||||
|
|
||||||
|
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
|
||||||
|
arcs.append([start_state, sil_state, eps, eps, sil_score])
|
||||||
|
arcs.append([sil_state, loop_state, sil_token, eps, 0])
|
||||||
|
|
||||||
|
for word, tokens in lexicon:
|
||||||
|
assert len(tokens) > 0, f"{word} has no pronunciations"
|
||||||
|
cur_state = loop_state
|
||||||
|
|
||||||
|
word = word2id[word]
|
||||||
|
tokens = [token2id[i] for i in tokens]
|
||||||
|
|
||||||
|
for i in range(len(tokens) - 1):
|
||||||
|
w = word if i == 0 else eps
|
||||||
|
arcs.append([cur_state, next_state, tokens[i], w, 0])
|
||||||
|
|
||||||
|
cur_state = next_state
|
||||||
|
next_state += 1
|
||||||
|
|
||||||
|
# now for the last token of this word
|
||||||
|
# It has two out-going arcs, one to the loop state,
|
||||||
|
# the other one to the sil_state.
|
||||||
|
i = len(tokens) - 1
|
||||||
|
w = word if i == 0 else eps
|
||||||
|
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
|
||||||
|
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
|
||||||
|
|
||||||
|
if need_self_loops:
|
||||||
|
disambig_token = token2id["#0"]
|
||||||
|
disambig_word = word2id["#0"]
|
||||||
|
arcs = add_self_loops(
|
||||||
|
arcs,
|
||||||
|
disambig_token=disambig_token,
|
||||||
|
disambig_word=disambig_word,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_state = next_state
|
||||||
|
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||||
|
arcs.append([final_state])
|
||||||
|
|
||||||
|
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||||
|
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||||
|
arcs = [" ".join(arc) for arc in arcs]
|
||||||
|
arcs = "\n".join(arcs)
|
||||||
|
|
||||||
|
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||||
|
return fsa
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
lang_dir = Path(args.lang_dir)
|
||||||
|
lexicon_filename = lang_dir / "lexicon.txt"
|
||||||
|
sil_token = "SIL"
|
||||||
|
sil_prob = 0.5
|
||||||
|
|
||||||
|
lexicon = read_lexicon(lexicon_filename)
|
||||||
|
tokens = get_tokens(lexicon)
|
||||||
|
words = get_words(lexicon)
|
||||||
|
|
||||||
|
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||||
|
|
||||||
|
for i in range(max_disambig + 1):
|
||||||
|
disambig = f"#{i}"
|
||||||
|
assert disambig not in tokens
|
||||||
|
tokens.append(f"#{i}")
|
||||||
|
|
||||||
|
assert "<eps>" not in tokens
|
||||||
|
tokens = ["<eps>"] + tokens
|
||||||
|
|
||||||
|
assert "<eps>" not in words
|
||||||
|
assert "#0" not in words
|
||||||
|
assert "<s>" not in words
|
||||||
|
assert "</s>" not in words
|
||||||
|
|
||||||
|
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
|
||||||
|
|
||||||
|
token2id = generate_id_map(tokens)
|
||||||
|
word2id = generate_id_map(words)
|
||||||
|
|
||||||
|
write_mapping(lang_dir / "tokens.txt", token2id)
|
||||||
|
write_mapping(lang_dir / "words.txt", word2id)
|
||||||
|
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||||
|
|
||||||
|
L = lexicon_to_fst(
|
||||||
|
lexicon,
|
||||||
|
token2id=token2id,
|
||||||
|
word2id=word2id,
|
||||||
|
sil_token=sil_token,
|
||||||
|
sil_prob=sil_prob,
|
||||||
|
)
|
||||||
|
|
||||||
|
L_disambig = lexicon_to_fst(
|
||||||
|
lexicon_disambig,
|
||||||
|
token2id=token2id,
|
||||||
|
word2id=word2id,
|
||||||
|
sil_token=sil_token,
|
||||||
|
sil_prob=sil_prob,
|
||||||
|
need_self_loops=True,
|
||||||
|
)
|
||||||
|
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||||
|
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||||
|
|
||||||
|
if args.debug:
|
||||||
|
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||||
|
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||||
|
|
||||||
|
L.labels_sym = labels_sym
|
||||||
|
L.aux_labels_sym = aux_labels_sym
|
||||||
|
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
|
||||||
|
|
||||||
|
L_disambig.labels_sym = labels_sym
|
||||||
|
L_disambig.aux_labels_sym = aux_labels_sym
|
||||||
|
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
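A small illustrative sketch (not part of the commit) of what add_disambig_symbols above does to a toy lexicon; it assumes this prepare_lang.py is importable from the current directory (e.g. egs/iwslt22_ta/ST/local).

# Illustration only: shared pronunciations and prefixes receive #N symbols.
from prepare_lang import add_disambig_symbols

toy_lexicon = [
    ("foo", ["f", "o", "o"]),         # prefix of "food" -> gets #1
    ("food", ["f", "o", "o", "d"]),   # collides with "food2" -> #1
    ("food2", ["f", "o", "o", "d"]),  # same token sequence -> #2
    ("bar", ["b", "a", "r"]),         # unique, not a prefix -> unchanged
]
lexicon_disambig, max_disambig = add_disambig_symbols(toy_lexicon)
for word, tokens in lexicon_disambig:
    print(word, " ".join(tokens))
print("max disambig:", max_disambig)  # expected: 2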
255
egs/iwslt22_ta/ST/local/prepare_lang_bpe.py
Executable file
@ -0,0 +1,255 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

"""

This script takes as input `lang_dir`, which should contain::

    - lang_dir/bpe.model,
    - lang_dir/words.txt

and generates the following files in the directory `lang_dir`:

    - lexicon.txt
    - lexicon_disambig.txt
    - L.pt
    - L_disambig.pt
    - tokens.txt
"""

import argparse
from pathlib import Path
from typing import Dict, List, Tuple

import k2
import sentencepiece as spm
import torch
from prepare_lang import (
    Lexicon,
    add_disambig_symbols,
    add_self_loops,
    write_lexicon,
    write_mapping,
)

from icefall.utils import str2bool
import pdb


def lexicon_to_fst_no_sil(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format).

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    loop_state = 0  # words enter and leave from here
    next_state = 1  # the next un-allocated state, will be incremented as we go

    arcs = []

    # The blank symbol <blk> is defined in local/train_bpe_model.py
    assert token2id["<blk>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0

    for word, pieces in lexicon:
        assert len(pieces) > 0, f"{word} has no pronunciations"
        cur_state = loop_state

        word = word2id[word]
        pieces = [token2id[i] for i in pieces]

        for i in range(len(pieces) - 1):
            w = word if i == 0 else eps
            arcs.append([cur_state, next_state, pieces[i], w, 0])

            cur_state = next_state
            next_state += 1

        # now for the last piece of this word
        i = len(pieces) - 1
        w = word if i == 0 else eps
        arcs.append([cur_state, loop_state, pieces[i], w, 0])

    if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
        arcs = add_self_loops(
            arcs,
            disambig_token=disambig_token,
            disambig_word=disambig_word,
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa


def generate_lexicon(
    model_file: str, words: List[str]
) -> Tuple[Lexicon, Dict[str, int]]:
    """Generate a lexicon from a BPE model.

    Args:
      model_file:
        Path to a sentencepiece model.
      words:
        A list of strings representing words.
    Returns:
      Return a tuple with two elements:
        - A dict whose keys are words and values are the corresponding
          word pieces.
        - A dict representing the token symbol, mapping from tokens to IDs.
    """
    sp = spm.SentencePieceProcessor()
    sp.load(str(model_file))

    words_pieces: List[List[str]] = sp.encode(words, out_type=str)

    lexicon = []
    for word, pieces in zip(words, words_pieces):
        lexicon.append((word, pieces))

    # The OOV word is <UNK>
    lexicon.append(("<UNK>", [sp.id_to_piece(sp.unk_id())]))

    token2id: Dict[str, int] = dict()
    for i in range(sp.vocab_size()):
        token2id[sp.id_to_piece(i)] = i

    return lexicon, token2id


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain the bpe.model and words.txt
        """,
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="""True for debugging, which will generate
        a visualization of the lexicon FST.

        Caution: If your lexicon contains hundreds of thousands
        of lines, please set it to False!

        See "test/test_bpe_lexicon.py" for usage.
        """,
    )

    return parser.parse_args()


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    model_file = lang_dir / "bpe.model"

    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

    words = word_sym_table.symbols

    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
    for w in excluded:
        if w in words:
            words.remove(w)

    lexicon, token_sym_table = generate_lexicon(model_file, words)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    next_token_id = max(token_sym_table.values()) + 1
    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in token_sym_table
        token_sym_table[disambig] = next_token_id
        next_token_id += 1

    word_sym_table.add("#0")
    word_sym_table.add("<s>")
    word_sym_table.add("</s>")

    write_mapping(lang_dir / "tokens.txt", token_sym_table)

    write_lexicon(lang_dir / "lexicon.txt", lexicon)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst_no_sil(
        lexicon,
        token2id=token_sym_table,
        word2id=word_sym_table,
    )

    L_disambig = lexicon_to_fst_no_sil(
        lexicon_disambig,
        token2id=token_sym_table,
        word2id=word_sym_table,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if args.debug:
        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")

        L.labels_sym = labels_sym
        L.aux_labels_sym = aux_labels_sym
        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")

        L_disambig.labels_sym = labels_sym
        L_disambig.aux_labels_sym = aux_labels_sym
        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")


if __name__ == "__main__":
    main()
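A short usage sketch (not part of the commit) of generate_lexicon() above; it assumes a trained model already exists at data/lang_bpe_1000/bpe.model (the source-BPE default used elsewhere in this recipe) and that this script's directory is on PYTHONPATH. The printed pieces shown in the comment are only an example of the expected shape, not actual output.

# Illustration only: inspect the word-to-pieces lexicon produced by the BPE model.
from prepare_lang_bpe import generate_lexicon

lexicon, token2id = generate_lexicon(
    "data/lang_bpe_1000/bpe.model", ["example", "words"]
)
for word, pieces in lexicon:
    print(word, " ".join(pieces))  # e.g. "example ▁exam ple"; last entry is ("<UNK>", ["<unk>"])
print(token2id["<blk>"])  # 0, since <blk> is the first user-defined symbol in train_bpe_model.py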
39
egs/iwslt22_ta/ST/local/prepare_lexicon.py
Executable file
@ -0,0 +1,39 @@
#!/usr/bin/env python3

# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0

# This script prepares a graphemic lexicon, given a column of words.

import argparse


def get_args():
    parser = argparse.ArgumentParser(
        description="""Creates the list of characters and words in lexicon"""
    )
    parser.add_argument("input", type=str, help="""Input list of words file""")
    parser.add_argument("output", type=str, help="""Output graphemic lexicon""")
    args = parser.parse_args()
    return args


def main():
    lex = {}
    args = get_args()
    with open(args.input, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            characters = list(line)
            characters = " ".join(
                ["V" if char == "*" else char for char in characters]
            )
            lex[line] = characters

    with open(args.output, "w", encoding="utf-8") as fp:
        for key in sorted(lex):
            fp.write(key + " " + lex[key] + "\n")


if __name__ == "__main__":
    main()
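A worked one-line illustration (not part of the commit) of the character mapping performed inside main() above: each word is split into space-separated graphemes and the `*` character is replaced by the placeholder symbol `V`.

# Illustration only.
word = "A*B"
chars = " ".join("V" if c == "*" else c for c in word)
print(word, chars)  # prints: A*B A V B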
67
egs/iwslt22_ta/ST/local/prepare_transcripts.py
Executable file
@ -0,0 +1,67 @@
#!/usr/bin/python

# Copyright 2023 Johns Hopkins University (Amir Hussein)

"""
This script prepares transcript_words.txt from a cutset.
"""

from lhotse import CutSet
import argparse
import logging
from pathlib import Path
import os


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--cut",
        type=str,
        default="",
        help="Cutset file",
    )
    parser.add_argument(
        "--src-langdir",
        type=str,
        default="",
        help="name of the source lang-dir",
    )
    parser.add_argument(
        "--tgt-langdir",
        type=str,
        default=None,
        help="name of the target lang-dir",
    )
    return parser


def main():

    parser = get_parser()
    args = parser.parse_args()

    logging.info("Reading the cuts")
    cuts = CutSet.from_file(args.cut)
    if args.tgt_langdir is not None:
        logging.info("Target dir is not None")
        langdirs = [Path(args.src_langdir), Path(args.tgt_langdir)]
    else:
        langdirs = [Path(args.src_langdir)]

    for langdir in langdirs:
        if not os.path.exists(langdir):
            os.makedirs(langdir)

    with open(langdirs[0] / "transcript_words.txt", "w") as src:
        # Only open the target transcript when a target lang-dir was given.
        tgt = (
            open(langdirs[1] / "transcript_words.txt", "w")
            if len(langdirs) > 1
            else None
        )
        for c in cuts:
            src_txt = c.supervisions[0].text
            src.write(src_txt + "\n")
            if tgt is not None:
                tgt_txt = c.supervisions[0].custom["tgt_text"]
                tgt.write(tgt_txt + "\n")
        if tgt is not None:
            tgt.close()


if __name__ == "__main__":
    main()
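A minimal sketch (not part of the commit) of the supervision layout this script expects: the source transcript sits in `text` and the English translation in `custom["tgt_text"]`, matching the fields read in main() above and in compute_loss(). The identifiers and values below are made up for illustration; it only assumes lhotse is installed.

# Illustration only.
from lhotse import SupervisionSegment

sup = SupervisionSegment(
    id="utt1", recording_id="rec1", start=0.0, duration=2.5,
    text="source-language transcript",
    custom={"tgt_text": "English translation"},
)
print(sup.text, "|", sup.custom["tgt_text"])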
106
egs/iwslt22_ta/ST/local/test_prepare_lang.py
Executable file
@ -0,0 +1,106 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

import os
import tempfile

import k2
from prepare_lang import (
    add_disambig_symbols,
    generate_id_map,
    get_tokens,
    get_words,
    lexicon_to_fst,
    read_lexicon,
    write_lexicon,
    write_mapping,
)


def generate_lexicon_file() -> str:
    fd, filename = tempfile.mkstemp()
    os.close(fd)
    s = """
!SIL SIL
<SPOKEN_NOISE> SPN
<UNK> SPN
f f
a a
foo f o o
bar b a r
bark b a r k
food f o o d
food2 f o o d
fo f o
""".strip()
    with open(filename, "w") as f:
        f.write(s)
    return filename


def test_read_lexicon(filename: str):
    lexicon = read_lexicon(filename)
    tokens = get_tokens(lexicon)
    words = get_words(lexicon)
    print(lexicon)
    print(tokens)
    print(words)
    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
    print(lexicon_disambig)
    print("max disambig:", f"#{max_disambig}")

    tokens = ["<eps>", "SIL", "SPN"] + tokens
    for i in range(max_disambig + 1):
        tokens.append(f"#{i}")
    words = ["<eps>"] + words

    token2id = generate_id_map(tokens)
    word2id = generate_id_map(words)

    print(token2id)
    print(word2id)

    write_mapping("tokens.txt", token2id)
    write_mapping("words.txt", word2id)

    write_lexicon("a.txt", lexicon)
    write_lexicon("a_disambig.txt", lexicon_disambig)

    fsa = lexicon_to_fst(lexicon, token2id=token2id, word2id=word2id)
    fsa.labels_sym = k2.SymbolTable.from_file("tokens.txt")
    fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
    fsa.draw("L.pdf", title="L")

    fsa_disambig = lexicon_to_fst(
        lexicon_disambig, token2id=token2id, word2id=word2id
    )
    fsa_disambig.labels_sym = k2.SymbolTable.from_file("tokens.txt")
    fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
    fsa_disambig.draw("L_disambig.pdf", title="L_disambig")


def main():
    filename = generate_lexicon_file()
    test_read_lexicon(filename)
    os.remove(filename)


if __name__ == "__main__":
    main()
98
egs/iwslt22_ta/ST/local/train_bpe_model.py
Executable file
@ -0,0 +1,98 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# You can install sentencepiece via:
#
#  pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96

import argparse
import shutil
from pathlib import Path

import sentencepiece as spm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain the training corpus: transcript_words.txt.
        The generated bpe.model is saved to this directory.
        """,
    )

    parser.add_argument(
        "--transcript",
        type=str,
        help="Training transcript.",
    )

    parser.add_argument(
        "--vocab-size",
        type=int,
        help="Vocabulary size for BPE training",
    )

    return parser.parse_args()


def main():
    args = get_args()
    vocab_size = args.vocab_size
    lang_dir = Path(args.lang_dir)

    model_type = "unigram"

    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
    train_text = args.transcript
    character_coverage = 1.0
    input_sentence_size = 100000000

    user_defined_symbols = ["<blk>", "<sos/eos>"]
    unk_id = len(user_defined_symbols)
    # Note: unk_id is fixed to 2.
    # If you change it, you should also change other
    # places that are using it.

    model_file = Path(model_prefix + ".model")
    if not model_file.is_file():
        spm.SentencePieceTrainer.train(
            input=train_text,
            vocab_size=vocab_size,
            model_type=model_type,
            model_prefix=model_prefix,
            input_sentence_size=input_sentence_size,
            character_coverage=character_coverage,
            user_defined_symbols=user_defined_symbols,
            unk_id=unk_id,
            bos_id=-1,
            eos_id=-1,
        )

    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")


if __name__ == "__main__":
    main()
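A quick sanity check (not part of the commit) on a model trained by this script: with user_defined_symbols = ["<blk>", "<sos/eos>"] and unk_id = 2 as above, pieces 0, 1 and 2 should come out as shown in the comments. The path data/lang_bpe_1000/bpe.model is assumed to have been produced already for this recipe.

# Illustration only.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_1000/bpe.model")
print(sp.id_to_piece(0))  # <blk>
print(sp.id_to_piece(1))  # <sos/eos>
print(sp.id_to_piece(2))  # <unk>
print(sp.unk_id())        # 2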
@ -1,645 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins (authors: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless5/decode.py \
    --epoch 22 \
    --avg 5 \
    --exp-dir ./pruned_transducer_stateless5/exp \
    --max-duration 200 \
    --decoding-method greedy_search

(2) beam search (not recommended)
./pruned_transducer_stateless5/decode.py \
    --epoch 22 \
    --avg 5 \
    --exp-dir ./pruned_transducer_stateless5/exp \
    --max-duration 200 \
    --decoding-method beam_search \
    --beam-size 10

(3) modified beam search
./pruned_transducer_stateless5/decode.py \
    --epoch 22 \
    --avg 5 \
    --exp-dir ./pruned_transducer_stateless5/exp \
    --max-duration 600 \
    --decoding-method modified_beam_search \
    --beam-size 10

(4) fast beam search
./pruned_transducer_stateless5/decode.py \
    --epoch 22 \
    --avg 5 \
    --exp-dir ./pruned_transducer_stateless5/exp \
    --max-duration 200 \
    --decoding-method fast_beam_search \
    --beam-size 10 \
    --max-contexts 4 \
    --max-states 8
"""


import argparse
import logging
import pdb
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import MGB2AsrDataModule
from beam_search import (
    beam_search,
    fast_beam_search_one_best,
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)

LOG_EPS = math.log(1e-10)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless5/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_2000/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )
    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding_method is greedy_search""",
    )

    add_model_arguments(parser)

    return parser


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               if greedy_search is used, it would be "greedy_search"
               If beam search with a beam size of 7 is used, it would be
               "beam_7"
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.encoder(
        x=feature, x_lens=feature_lens
    )
    hyps = []

    if params.decoding_method == "fast_beam_search":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif (
        params.decoding_method == "greedy_search"
        and params.max_sym_per_frame == 1
    ):
        # pdb.set_trace()
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        # pdb.set_trace()
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    else:
        batch_size = encoder_out.size(0)

        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.decoding_method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(sp.decode(hyp).split())

    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
    elif params.decoding_method == "fast_beam_search":
        return {
            (
                f"beam_{params.beam}_"
                f"max_contexts_{params.max_contexts}_"
                f"max_states_{params.max_states}"
            ): hyps
        }
    else:
        return {f"beam_size_{params.beam_size}": hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
        logging.info(f"Decoding {batch_idx}-th batch")
        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            batch=batch,
        )
        # pdb.set_trace()
        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_ids, hyp_words, ref_text in zip(cut_ids, hyps, texts):

                ref_words = ref_text.split()
                this_batch.append((cut_ids, ref_words, hyp_words))

            results[name].extend(this_batch)
        # pdb.set_trace()
        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir /
            f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        # pdb.set_trace()
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    MGB2AsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(
                average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(
                average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg + 1]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
    # we need cut ids to display recognition results.
    args.return_cuts = True
    MGB2 = MGB2AsrDataModule(args)

    test_cuts = MGB2.test_cuts()
    dev_cuts = MGB2.dev_cuts()

    test_dl = MGB2.test_dataloaders(test_cuts)
    dev_dl = MGB2.test_dataloaders(dev_cuts)

    test_sets = ["test", "dev"]
    test_all_dl = [test_dl, dev_dl]

    for test_set, test_dl in zip(test_sets, test_all_dl):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -1,834 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
#
|
|
||||||
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
|
||||||
# Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
Usage:
|
|
||||||
(1) greedy search
|
|
||||||
./zipformer/decode.py \
|
|
||||||
--epoch 28 \
|
|
||||||
--avg 15 \
|
|
||||||
--exp-dir ./zipformer/exp \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method greedy_search
|
|
||||||
|
|
||||||
(2) beam search (not recommended)
|
|
||||||
./zipformer/decode.py \
|
|
||||||
--epoch 28 \
|
|
||||||
--avg 15 \
|
|
||||||
--exp-dir ./zipformer/exp \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method beam_search \
|
|
||||||
--beam-size 4
|
|
||||||
|
|
||||||
(3) modified beam search
|
|
||||||
./zipformer/decode.py \
|
|
||||||
--epoch 28 \
|
|
||||||
--avg 15 \
|
|
||||||
--exp-dir ./zipformer/exp \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method modified_beam_search \
|
|
||||||
--beam-size 4
|
|
||||||
|
|
||||||
(4) fast beam search (one best)
|
|
||||||
./zipformer/decode.py \
|
|
||||||
--epoch 28 \
|
|
||||||
--avg 15 \
|
|
||||||
--exp-dir ./zipformer/exp \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method fast_beam_search \
|
|
||||||
--beam 20.0 \
|
|
||||||
--max-contexts 8 \
|
|
||||||
--max-states 64
|
|
||||||
|
|
||||||
(5) fast beam search (nbest)
|
|
||||||
./zipformer/decode.py \
|
|
||||||
--epoch 28 \
|
|
||||||
--avg 15 \
|
|
||||||
--exp-dir ./zipformer/exp \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method fast_beam_search_nbest \
|
|
||||||
--beam 20.0 \
|
|
||||||
--max-contexts 8 \
|
|
||||||
--max-states 64 \
|
|
||||||
--num-paths 200 \
|
|
||||||
--nbest-scale 0.5
|
|
||||||
|
|
||||||
(6) fast beam search (nbest oracle WER)
|
|
||||||
./zipformer/decode.py \
|
|
||||||
--epoch 28 \
|
|
||||||
--avg 15 \
|
|
||||||
--exp-dir ./zipformer/exp \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method fast_beam_search_nbest_oracle \
|
|
||||||
--beam 20.0 \
|
|
||||||
--max-contexts 8 \
|
|
||||||
--max-states 64 \
|
|
||||||
--num-paths 200 \
|
|
||||||
--nbest-scale 0.5
|
|
||||||
|
|
||||||
(7) fast beam search (with LG)
|
|
||||||
./zipformer/decode.py \
|
|
||||||
--epoch 28 \
|
|
||||||
--avg 15 \
|
|
||||||
--exp-dir ./zipformer/exp \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method fast_beam_search_nbest_LG \
|
|
||||||
--beam 20.0 \
|
|
||||||
--max-contexts 8 \
|
|
||||||
--max-states 64
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import k2
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from asr_datamodule import IWSLTDialectSTDataModule
|
|
||||||
from beam_search import (
|
|
||||||
beam_search,
|
|
||||||
fast_beam_search_nbest,
|
|
||||||
fast_beam_search_nbest_LG,
|
|
||||||
fast_beam_search_nbest_oracle,
|
|
||||||
fast_beam_search_one_best,
|
|
||||||
greedy_search,
|
|
||||||
greedy_search_batch,
|
|
||||||
modified_beam_search,
|
|
||||||
)
|
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
|
||||||
average_checkpoints,
|
|
||||||
average_checkpoints_with_averaged_model,
|
|
||||||
find_checkpoints,
|
|
||||||
load_checkpoint,
|
|
||||||
)
|
|
||||||
from icefall.lexicon import Lexicon
|
|
||||||
from icefall.utils import (
|
|
||||||
AttributeDict,
|
|
||||||
make_pad_mask,
|
|
||||||
setup_logger,
|
|
||||||
store_transcripts,
|
|
||||||
str2bool,
|
|
||||||
write_error_stats,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
|
||||||


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_bpe_500",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
        - greedy_search
        - beam_search
        - modified_beam_search
        - fast_beam_search
        - fast_beam_search_nbest
        - fast_beam_search_nbest_oracle
        - fast_beam_search_nbest_LG
        If you use fast_beam_search_nbest_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=20.0,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search,
        fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle
        """,
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.01,
        help="""
        Used only when --decoding-method is fast_beam_search_nbest_LG.
        It specifies the scale for n-gram LM scores.
        """,
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=64,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding-method is greedy_search""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=200,
        help="""Number of paths for nbest decoding.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""Scale applied to lattice scores when computing nbest paths.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    add_model_arguments(parser)

    return parser


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for decoding. For example,
           if greedy_search is used, it would be "greedy_search".
           If beam search with a beam size of 7 is used, it would be
           "beam_7".
    - value: It contains the decoding result. `len(value)` equals the
             batch size. `value[i]` is the decoding result for the i-th
             utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    if params.causal:
        # this seems to cause insertions at the end of the utterance if used with zipformer.
        pad_len = 30
        feature_lens += pad_len
        feature = torch.nn.functional.pad(
            feature,
            pad=(0, 0, 0, pad_len),
            value=LOG_EPS,
        )
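        # The extra tail frames (filled with LOG_EPS, i.e. near-silence) give a
        # causal/streaming encoder enough right context to emit the final real
        # frames of the utterance.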

    x, x_lens = model.encoder_embed(feature, feature_lens)

    src_key_padding_mask = make_pad_mask(x_lens)
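    # src_key_padding_mask is True at padded positions (>= x_lens) after
    # subsampling, so the encoder's attention ignores those frames.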
    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

    encoder_out, encoder_out_lens = model.encoder(
        x, x_lens, src_key_padding_mask
    )
    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

    hyps = []

    if params.decoding_method == "fast_beam_search":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "fast_beam_search_nbest_LG":
        hyp_tokens = fast_beam_search_nbest_LG(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        for hyp in hyp_tokens:
            hyps.append([word_table[i] for i in hyp])
    elif params.decoding_method == "fast_beam_search_nbest":
        hyp_tokens = fast_beam_search_nbest(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "fast_beam_search_nbest_oracle":
        hyp_tokens = fast_beam_search_nbest_oracle(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            ref_texts=sp.encode(supervisions["text"]),
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif (
        params.decoding_method == "greedy_search"
        and params.max_sym_per_frame == 1
    ):
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    else:
        batch_size = encoder_out.size(0)

        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.decoding_method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(sp.decode(hyp).split())

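    # The key of the returned dict records the decoding configuration; it is
    # later used by save_results() to name the result and error-stats files.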
    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
    elif "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
        key += f"max_states_{params.max_states}"
        if "nbest" in params.decoding_method:
            key += f"_num_paths_{params.num_paths}_"
            key += f"nbest_scale_{params.nbest_scale}"
            if "LG" in params.decoding_method:
                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

        return {key: hyps}
    else:
        return {f"beam_size_{params.beam_size}": hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if a beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut id, the reference transcript, and the predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, hyp_words))

            results[name].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    IWSLTDialectSTDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )

    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.causal:
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
        params.suffix += f"-chunk-{params.chunk_size}"
        params.suffix += f"-left-context-{params.left_context_frames}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
            if "LG" in params.decoding_method:
                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg + 1]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

if "fast_beam_search" in params.decoding_method:
|
|
||||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
|
||||||
word_table = lexicon.word_table
|
|
||||||
lg_filename = params.lang_dir / "LG.pt"
|
|
||||||
logging.info(f"Loading {lg_filename}")
|
|
||||||
decoding_graph = k2.Fsa.from_dict(
|
|
||||||
torch.load(lg_filename, map_location=device)
|
|
||||||
)
|
|
||||||
decoding_graph.scores *= params.ngram_lm_scale
|
|
||||||
else:
|
|
||||||
word_table = None
|
|
||||||
decoding_graph = k2.trivial_graph(
|
|
||||||
params.vocab_size - 1, device=device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
decoding_graph = None
|
|
||||||
word_table = None
|
|
||||||
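    # NOTE: LG.pt constrains fast beam search with the lexicon and an n-gram LM
    # (scaled by --ngram-lm-scale), while k2.trivial_graph places no constraint
    # on the emitted token sequence.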

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    data_module = IWSLTDialectSTDataModule(args)

    test_clean_cuts = data_module.test_clean_cuts()
    test_other_cuts = data_module.test_other_cuts()

    test_clean_dl = data_module.test_dataloaders(test_clean_cuts)
    test_other_dl = data_module.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    test_dls = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
File diff suppressed because it is too large
@ -1,6 +1,7 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
#                                        Mingshuang Luo,
#                                        Zengwei Yao)
#           2023 Johns Hopkins University (authors: Amir Hussein)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@ -488,6 +489,57 @@ def store_transcripts_and_timestamps(
            print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)

def store_translations(
    filename: Pathlike,
    texts: Iterable[Tuple[str, List[str], List[str], List[str]]],
    lowercase: bool = True,
) -> None:
    """Save reference transcripts, reference translations and predicted
    translations to files, and compute corpus-level BLEU.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cut_id, the second
        is the reference transcript, the third is the reference translation,
        and the fourth is the predicted result.
      lowercase:
        If True, BLEU is computed case-insensitively.
    Returns:
      Return None.
    """
    bleu = BLEU(lowercase=lowercase)
    hyp_list = []
    ref_list = []
    dir_ = os.path.dirname(filename)
    reftgt = os.path.join(dir_, "reftgt-" + str(os.path.basename(filename)))
    refsrc = os.path.join(dir_, "refsrc-" + str(os.path.basename(filename)))
    hyp_file = os.path.join(dir_, "hyp-" + str(os.path.basename(filename)))
    bleu_file = os.path.join(dir_, "bleu-" + str(os.path.basename(filename)))
    with open(filename, "w") as f, open(reftgt, "w") as f_tgt, open(hyp_file, "w") as f_hyp, open(refsrc, "w") as f_src:
        for cut_id, ref, ref_tgt, hyp in texts:
            ref = " ".join(ref)
            ref_tgt = " ".join(ref_tgt)
            hyp = " ".join(hyp)
            print(f"{cut_id}: ref {ref}", file=f)
            print(f"{cut_id}: ref_tgt {ref_tgt}", file=f)
            print(f"{cut_id}: hyp {hyp}", file=f)
            print("\n", file=f)

            print(f"{ref}", file=f_src)
            print(f"{ref_tgt}", file=f_tgt)
            print(f"{hyp}", file=f_hyp)

            hyp_list.append(hyp)
            ref_list.append(ref_tgt)

    with open(bleu_file, "w") as b:
        print(str(bleu.corpus_score(hyp_list, [ref_list])), file=b)
        print(f"BLEU signature: {str(bleu.get_signature())}", file=b)

    logging.info(
        f"[{bleu.corpus_score(hyp_list, [ref_list])}] "
        f"BLEU signature: {str(bleu.get_signature())}"
    )


def write_error_stats(
    f: TextIO,
    test_set_name: str,