seame zipformer-hat-lid recipe

This commit is contained in:
AmirHussein96 2024-04-05 09:58:02 -04:00
parent c45e9fecfb
commit 77d0f68b62
83 changed files with 17148 additions and 0 deletions

23
egs/seame/ASR/README.md Normal file
View File

@ -0,0 +1,23 @@
# Introduction
This recipe includes ASR models (zipformer, zipformer-hat, zipformer-hat-lid) trained and evaluated on SEAME dataset.
The SEAME corpora is Singaporean Codeswitched English and Mandarin.
This corpus comes defined with a training split and two development splits:
train -- A mix of codeswitched, Mandarin and Singaporean English
dev_sge -- A set of primarily Singaporean English though there is codeswitching
dev_man -- A set of primarily Mandarin though there is also some codeswitching
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Zipformer-hat
Zipformer with hybrid autoregressive transducer (HAT) loss https://arxiv.org/abs/2003.07705
see https://github.com/k2-fsa/icefall/pull/1291
# Zipformer-hat-lid
Zipformer-hat with auxiliary LID encoder and blank sharing for synchronization between ASR and LID as described here (will be shared soon)

168
egs/seame/ASR/RESULTS.md Normal file
View File

@ -0,0 +1,168 @@
## Results
#### Zipformer
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 21.87 | 29.04 | --epoch 25, --avg 5, --max-duration 500 |
The training command:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--num-epochs 25 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-asr-seame \
--causal 0 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--prune-range 10 \
--max-duration 500
```
The decoding command:
```
./zipformer/decode.py \
--epoch 25 \
--avg 5 \
--beam-size 10
--exp-dir ./zipformer/exp-asr-seame \
--max-duration 800 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--decoding-method modified_beam_search
```
The pretrained model is available at: <https://huggingface.co/AmirHussein/zipformer-seame>
### Zipformer-HAT
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 22.00 | 29.92 | --epoch 20, --avg 5, --max-duration 500 |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
./zipformer_hat/train.py \
--world-size 4 \
--num-epochs 25 \
--start-epoch 1 \
--base-lr 0.045 \
--lr-epochs 6 \
--use-fp16 1 \
--exp-dir zipformer_hat/exp \
--causal 0 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--prune-range 10 \
--max-duration 500
```
The decoding command is:
```
## modified beam search
./zipformer_hat/decode.py \
--epoch 25 --avg 5 --use-averaged-model True \
--beam-size 10 \
--causal 0 \
--exp-dir zipformer_hat/exp \
--bpe-model data_seame/lang_bpe_4000/bpe.model \
--max-duration 1000 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--decoding-method modified_beam_search
```
A pre-trained model and decoding logs can be found at <https://huggingface.co/AmirHussein/zipformer-hat-seame>
### Zipformer-HAT-LID
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| modified beam search | 20.04 | 26.91 | --epoch 15, --avg 5, --max-duration 500 |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
./zipformer_hat_lid/train.py \
--world-size 4 \
--lid True \
--num-epochs 25 \
--start-epoch 1 \
--base-lr 0.045 \
--use-fp16 1 \
--lid-loss-scale 0.3 \
--exp-dir zipformer_hat_lid/exp \
--causal 0 \
--lid-output-layer 3 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--lids "<en>,<zh>" \
--prune-range 10 \
--freeze-main-model False \
--use-lid-encoder True \
--use-lid-joiner True \
--lid-num-encoder-layers 2,2,2 \
--lid-downsampling-factor 2,4,2 \
--lid-feedforward-dim 256,256,256 \
--lid-num-heads 4,4,4 \
--lid-encoder-dim 256,256,256 \
--lid-encoder-unmasked-dim 128,128,128 \
--lid-cnn-module-kernel 31,15,31 \
--max-duration 500
```
The decoding command is:
```
## modified beam search
python zipformer_hat_lid/decode.py \
--epoch $epoch --avg 5 --use-averaged-model True \
--beam-size 10 \
--lid False \
--lids "<en>,<zh>" \
--exp-dir zipformer_hat_lid/exp \
--bpe-model data_seame/lang_bpe_4000/bpe.model \
--max-duration 800 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--decoding-method modified_beam_search \
--lid-output-layer 3 \
--use-lid-encoder True \
--use-lid-joiner True \
--lid-num-encoder-layers 2,2,2 \
--lid-downsampling-factor 2,4,2 \
--lid-feedforward-dim 256,256,256 \
--lid-num-heads 4,4,4 \
--lid-encoder-dim 256,256,256 \
--lid-encoder-unmasked-dim 128,128,128 \
--lid-cnn-module-kernel 31,15,31
```
A pre-trained model and decoding logs can be found at <https://huggingface.co/AmirHussein/zipformer-hat-lid-seame>

View File

@ -0,0 +1,59 @@
#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)
"""
This file cer from icefall decoded "recogs" file:
id [ref] xxx
id [hyp] yxy
"""
import argparse
import jiwer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dec-file",
type=str,
help="Decoded icefall recogs file"
)
return parser
def cer_(file):
hyp = []
ref = []
cer_results = 0
ref_lens = 0
with open(file, 'r', encoding='utf-8') as dec:
for line in dec:
id, target = line.split('\t')
id = id[0:-2]
target, txt = target.split("=")
if target == 'ref':
words = txt.strip().strip('[]').split(', ')
word_list = [word.strip("'") for word in words]
ref.append(" ".join(word_list))
elif target == 'hyp':
words = txt.strip().strip('[]').split(', ')
word_list = [word.strip("'") for word in words]
hyp.append(" ".join(word_list))
for h, r in zip(hyp, ref):
if r:
cer_results += (jiwer.cer(r, h)*len(r))
ref_lens += len(r)
print(cer_results / ref_lens)
def main():
parse = get_args()
args = parse.parse_args()
cer_(args.dec_file)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,158 @@
#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features.
It looks for manifests in the directory data_seame/manifests.
The generated fbank features are saved in data_seame/fbank.
"""
import logging
import os
from pathlib import Path
import argparse
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-splits",
type=int,
default=5,
help="Number of splits for the train set.",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Start index of the train set split.",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop index of the train set split.",
)
parser.add_argument(
"--test",
action="store_true",
help="If set, only compute features for the dev and val set.",
)
return parser.parse_args()
def compute_fbank_gpu(args):
src_dir = Path("data_seame/manifests")
output_dir = Path("data_seame/fbank")
num_jobs = min(os.cpu_count(),10)
num_mel_bins = 80
sampling_rate = 16000
sr = 16000
logging.info(f"Cpus {num_jobs}")
dataset_parts = (
"valid",
"dev_man",
"train",
"dev_sge",
)
prefix = ""
suffix = "jsonl.gz"
breakpoint
manifests = read_manifests_if_cached(
prefix=prefix, dataset_parts=dataset_parts, output_dir=src_dir,suffix=suffix,
)
assert manifests is not None
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / f"{cuts_filename}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
logging.info("About to split cuts into smaller chunks.")
if sr != None:
logging.info(f"Resampling to {sr}")
cut_set = cut_set.resample(sr)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
keep_all_channels=False)
cut_set = cut_set.filter(lambda c: c.duration >= .2 and c.duration <= 30)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
manifest_path=f"{src_dir}/{cuts_filename}",
batch_duration=2000,
num_workers=num_jobs,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
else:
logging.info(f"Processing {partition}")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
batch_duration=2000,
num_workers=num_jobs,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_gpu(args)

View File

@ -0,0 +1,145 @@
#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features for seame with different data size.
It looks for manifests in the directory data_seame/manifests.
The generated fbank features are saved in data_seame/fbank.
"""
import logging
import os
from pathlib import Path
import argparse
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-splits",
type=int,
default=5,
help="Number of splits for the train set.",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Start index of the train set split.",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop index of the train set split.",
)
parser.add_argument(
"--test",
action="store_true",
help="If set, only compute features for the dev and val set.",
)
return parser.parse_args()
def compute_fbank_gpu(args):
src_dir = Path("data_seame/manifests")
output_dir = Path("data_seame/fbank")
num_jobs = min(os.cpu_count(),10)
num_mel_bins = 80
sampling_rate = 16000
sr = 16000
logging.info(f"Cpus {num_jobs}")
dataset_parts = (
"train10",
"train50",
"train30",
)
prefix = ""
suffix = "jsonl.gz"
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
for part in dataset_parts:
cuts_filename = f"cuts_{part}.{suffix}"
cut_set = CutSet.from_file(src_dir / cuts_filename)
logging.info(f"Processing {part}")
logging.info("About to split cuts into smaller chunks.")
if sr != None:
logging.info(f"Resampling to {sr}")
cut_set = cut_set.resample(sr)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
keep_all_channels=False)
cut_set = cut_set.filter(lambda c: c.duration >= .5 and c.duration <= 30)
if "train" in part:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{part}",
batch_duration=2000,
num_workers=num_jobs,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cut_set.to_file(output_dir / f"cuts_{part}.jsonl.gz")
else:
logging.info(f"Processing {part}")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{part}",
manifest_path=f"{src_dir}/{cuts_filename}",
batch_duration=2000,
num_workers=num_jobs,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cut_set.to_file(output_dir / f"cuts_{part}.jsonl.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_gpu(args)

View File

@ -0,0 +1,109 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(30, os.cpu_count())
num_mel_bins = 80
dataset_parts = (
"music",
"speech",
"noise",
)
prefix = "musan"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return
logging.info("Extracting features for Musan")
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(is_cut_long)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/musan_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
)
musan_cuts.to_file(musan_cuts_path)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()

View File

@ -0,0 +1,97 @@
#!/usr/bin/python
from lhotse import RecordingSet, SupervisionSet, CutSet
import argparse
import logging
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
import pdb
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sup",
type=str,
default="",
help="Supervisions file",
)
parser.add_argument(
"--rec",
type=str,
default="",
help="Recordings file",
)
parser.add_argument(
"--cut",
type=str,
default="",
help="Cutset file",
)
parser.add_argument(
"--savecut",
type=str,
default="",
help="name of the cutset to be saved",
)
return parser
def valid_asr(cut):
tol = 2e-3
i=0
total_dur = 0
for c in cut:
if c.supervisions != []:
if c.supervisions[0].end > c.duration + tol:
logging.info(f"Supervision beyond the cut. Cut number: {i}")
total_dur += c.duration
logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}")
elif c.supervisions[0].start < -tol:
logging.info(f"Supervision starts before the cut. Cut number: {i}")
logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}")
else:
continue
else:
logging.info("Empty supervision")
logging.info(f"id: {c.id}")
i += 1
logging.info(f"filtered duration: {total_dur}")
def main():
parser = get_parser()
args = parser.parse_args()
if args.cut != "":
cuts = CutSet.from_file(args.cut)
else:
recordings = RecordingSet.from_file(args.rec)
supervisions = SupervisionSet.from_file(args.sup)
# breakpoint()
logging.info("Example from supervisions:")
logging.info(supervisions[0])
logging.info("Example from recordings")
logging.info("Fixing manifests")
recordings, supervisions = fix_manifests(recordings, supervisions)
logging.info("Validating manifests")
validate_recordings_and_supervisions(recordings, supervisions)
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
cuts.describe()
logging.info("Example from cut:")
logging.info(cuts[100])
logging.info("Validating manifests for ASR")
valid_asr(cuts)
if args.savecut != "":
cuts.to_file(args.savecut)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,414 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
from icefall.utils import str2bool
Lexicon = List[Tuple[str, List[str]]]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain a file lexicon.txt.
Generated files by this script are saved into this directory.
""",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
""",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
# the next un-allocated state, will be incremented as we go.
next_state = 3
arcs = []
assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [token2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
tokens = get_tokens(lexicon)
words = get_words(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
assert "<eps>" not in tokens
tokens = ["<eps>"] + tokens
assert "<eps>" not in words
assert "#0" not in words
assert "<s>" not in words
assert "</s>" not in words
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(lang_dir / "tokens.txt", token2id)
write_mapping(lang_dir / "words.txt", word2id)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,255 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/bpe.model,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import sentencepiece as spm
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
from icefall.utils import str2bool
import pdb
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def generate_lexicon(
model_file: str, words: List[str]
) -> Tuple[Lexicon, Dict[str, int]]:
"""Generate a lexicon from a BPE model.
Args:
model_file:
Path to a sentencepiece model.
words:
A list of strings representing words.
Returns:
Return a tuple with two elements:
- A dict whose keys are words and values are the corresponding
word pieces.
- A dict representing the token symbol, mapping from tokens to IDs.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(model_file))
words_pieces: List[List[str]] = sp.encode(words, out_type=str)
lexicon = []
for word, pieces in zip(words, words_pieces):
lexicon.append((word, pieces))
# The OOV word is <UNK>
lexicon.append(("<UNK>", [sp.id_to_piece(sp.unk_id())]))
token2id: Dict[str, int] = dict()
for i in range(sp.vocab_size()):
token2id[sp.id_to_piece(i)] = i
return lexicon, token2id
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the bpe.model and words.txt
""",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
See "test/test_bpe_lexicon.py" for usage.
""",
)
return parser.parse_args()
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
model_file = lang_dir / "bpe.model"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
lexicon, token_sym_table = generate_lexicon(model_file, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright 2023 Johns Hopkins University (Amir Hussein)
# Apache 2.0
# This script prepares givel a column of words lexicon.
import argparse
def get_args():
parser = argparse.ArgumentParser(
description="""Creates the list of characters and words in lexicon"""
)
parser.add_argument("input", type=str, help="""Input list of words file""")
parser.add_argument("output", type=str, help="""output graphemic lexicon""")
args = parser.parse_args()
return args
def main():
lex = {}
args = get_args()
with open(args.input, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
characters = list(line)
characters = " ".join(
["V" if char == "*" else char for char in characters]
)
lex[line] = characters
with open(args.output, "w", encoding="utf-8") as fp:
for key in sorted(lex):
fp.write(key + " " + lex[key] + "\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,54 @@
# Copyright 2023 Johns Hopkins University (Amir Hussein)
#!/usr/bin/python
"""
This script prepares transcript_words.txt from cutset
"""
from lhotse import CutSet
import argparse
import logging
import pdb
from pathlib import Path
import os
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--cut",
type=str,
default="",
help="Cutset file",
)
parser.add_argument(
"--langdir",
type=str,
default="",
help="name of the lang-dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
logging.info("Reading the cuts")
cuts = CutSet.from_file(args.cut)
langdir = Path(args.langdir)
if not os.path.exists(langdir):
os.makedirs(langdir)
with open(langdir / "transcript_words.txt", 'w') as txt:
for c in cuts:
#breakpoint()
text = c.supervisions[0].text
txt.write(text + '\n')
if __name__ == "__main__":
main()

View File

@ -0,0 +1,90 @@
#!/usr/bin/python
# Johns Hopkins University (authors: Amir Hussein)
"""
Sample data given duration in seconds.
"""
from lhotse import RecordingSet, SupervisionSet, CutSet
import argparse
import logging
import os
from pathlib import Path
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sup",
type=str,
default="",
help="Supervisions file",
)
parser.add_argument(
"--rec",
type=str,
default="",
help="Recordings file",
)
parser.add_argument(
"--cut",
type=str,
default="",
help="Cutset file",
)
parser.add_argument(
"--outcut",
type=str,
default="",
help="name of the cutset to be saved",
)
parser.add_argument(
"--dur",
type=float,
default=10.0,
help="duration of the cut in seconds",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
if args.cut != "":
logging.info(f"Loading {args.cut}")
cuts = CutSet.from_file(args.cut)
outdir = Path(os.path.dirname(args.cut))
else:
outdir = Path(os.path.dirname(args.sup))
logging.info(f"Loading supervisions")
recordings = RecordingSet.from_file(args.rec)
supervisions = SupervisionSet.from_file(args.sup)
logging.info("Fixing manifests")
cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,)
cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
shuffled = cuts.shuffle()
total_dur = 0
cuts_list = []
for c in shuffled:
if total_dur < args.dur and "_sp" not in c.id:
total_dur += c.duration
cuts_list.append(c.id)
else:
break
cuts = cuts.filter(lambda c: c.id in cuts_list)
cuts.describe()
logging.info(f"Saving {args.outcut}")
cuts.to_file(outdir / args.outcut)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,120 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
import argparse
import shutil
from pathlib import Path
from typing import Dict
import sentencepiece as spm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the training corpus: transcript_words.txt.
The generated bpe.model is saved to this directory.
""",
)
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
)
parser.add_argument(
"--predef-symbols",
type=str,
help="User predefined symbols",
default=None,
)
return parser.parse_args()
def generate_tokens(lang_dir: Path):
"""
Generate the tokens.txt from a bpe model.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(lang_dir / "bpe.model"))
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f:
for sym, i in token2id.items():
f.write(f"{sym} {i}\n")
def main():
args = get_args()
vocab_size = args.vocab_size
lang_dir = Path(args.lang_dir)
predef_sym = args.predef_symbols
model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
train_text = args.transcript
character_coverage = 1.0
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
if predef_sym:
syms = predef_sym.split(',')
for i in syms:
user_defined_symbols.append(i)
# Note: unk_id is fixed to 2.
# If you change it, you should also change other
# places that are using it.
model_file = Path(model_prefix + ".model")
if not model_file.is_file():
spm.SentencePieceTrainer.train(
input=train_text,
vocab_size=vocab_size,
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
bos_id=-1,
eos_id=-1,
)
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
generate_tokens(lang_dir)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,220 @@
#!/usr/bin/env python3
# Johns Hopkins University (authors: Amir Hussein)
"""
Compute WER per language
"""
import sys, codecs, math, pickle, unicodedata, re
from collections import Counter
import argparse
from kaldialign import align
from collections import defaultdict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--rec",
type=str,
default="",
help="Cut ref file",
)
return parser
lids = "en,zh"
lids_dict = {lid:id+1 for id, lid in enumerate(lids.split(","))}
id2lang = {id+1: lid for id, lid in enumerate(lids.split(","))}
bad_id = []
def extract_info(line, info):
# Split the line at the first colon to separate the ID
id_part, rest = line.split(':', 1)
# Extract 'ref' by finding its start and end
ref_start = rest.find(info)
ref_end = rest.find(']', ref_start)
ref = rest[ref_start+len(info):ref_end].replace("'", "").split(', ')
# Extract 'lid'
if 'lid=' in rest:
lid_start = rest.find('lid=[')
lid_end = rest.find(']', lid_start)
lid = rest[lid_start+len('lid=['):lid_end].split(', ')
else:
lid = ['']
if lid[0]=='':
bad_id.append(id_part)
if " ".join(lid):
lid = [int(i) for i in lid] # Convert each element to integer
return id_part.strip(), ref, lid
def is_English(c):
"""check character is in English"""
return ord(c.lower()) >= ord("a") and ord(c.lower()) <= ord("z")
def get_en(text):
res = []
for w in text:
if w:
if is_English(w[0]):
res.append(w)
else:
continue
return res
def get_zh(text):
res = []
for w in text:
if w:
if is_English(w[0]):
continue
else:
res.append(w)
return res
def extract_info_lid(line, tag):
# Split the line at the first colon to separate the ID
id_part, rest = line.split(':', 1)
# Extract 'ref' by finding its start and end
ref_start = rest.find(tag)
ref_end = rest.find(']', ref_start)
ref = rest[ref_start+len(tag):ref_end].replace("'", "").split(', ')
return id_part.strip(), ref
def align_lid2(labels_a, labels_b, a, b):
# Alignment
EPS = '*'
ali = align(a, b, EPS, sclite_mode=True)
a2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(a,labels_a))}
b2idx = {(i,idx):j for idx,(i,j) in enumerate(zip(b,labels_b))}
# Comparing labels of aligned elements
idx_a = 0
idx_b = 0
ali_idx=0
aligned_a = []
aligned_b = []
while idx_a <len(a) and idx_b <len(b) and ali_idx < len(ali):
elem_a, elem_b = ali[ali_idx]
if elem_a == EPS:
idx_b += 1
elif elem_b == EPS:
idx_a += 1
elif elem_a != EPS and elem_b != EPS:
label_a = a2idx[(elem_a,idx_a)]
label_b = b2idx[(elem_b,idx_b)]
aligned_a.append(label_a)
aligned_b.append(label_b)
idx_b += 1
idx_a += 1
ali_idx += 1
return aligned_a, aligned_b
def align_lid(labels_a, labels_b):
# Alignment
res_a, res_b = [], []
EPS = '*'
ali = align(labels_a, labels_b, EPS, sclite_mode=True)
# Comparing labels of aligned elements
for val_a, val_b in ali:
res_a.append(val_a)
res_b.append(val_b)
return res_a, res_b
def read_file(infile, tag):
""""returns list of dict (id, lid, text)"""
res = []
with open(infile, 'r') as file:
for line in file:
_, rest = line.split(':', 1)
if tag in rest:
_id, text = extract_info_lid(line, tag)
res.append((_id, text))
return res
def wer(results, sclite_mode=False):
subs = defaultdict(int)
ins = defaultdict(int)
dels = defaultdict(int)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
for cut_id, ref, hyp in results:
ali = align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for _, r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
print(f"%WER = {tot_err_rate}")
return tot_err_rate
if __name__ == '__main__':
parser = get_parser()
args = parser.parse_args()
ref_data = read_file(args.rec, 'ref=[')
ref_data = sorted(ref_data)
hyp_data = read_file(args.rec, 'hyp=[')
hyp_data = sorted(hyp_data)
results = defaultdict(list)
for (ref, hyp) in zip(ref_data, hyp_data):
assert ref[0] == hyp[0], f"ref_id: {ref[0]} != hyp_id: {hyp[0]}"
_, text_ref = ref
_, hyp_text = hyp
if ref:
ref_text_en = get_en(text_ref)
ref_text_zh = get_zh(text_ref)
if hyp:
hyp_text_en = get_en(hyp_text)
hyp_text_zh = get_zh(hyp_text)
results['en'].append((ref[0],ref_text_en, hyp_text_en))
results['zh'].append((ref[0],ref_text_zh, hyp_text_zh))
for key, val in results.items():
print(key)
res = wer(val)

115
egs/seame/ASR/prepare.sh Normal file
View File

@ -0,0 +1,115 @@
#!/usr/bin/env bash
# Copyright 2024 Johns Hopkins University (Amir Hussein)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
set -eou pipefail
nj=20
stage=-1
stop_stage=6
# We assume dl_dir (download dir) contains the following
# directories and files.
#
# - $dl_dir/mgb2
#
# You can download the data from
#
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
#
dl_dir=download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
4000
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage 0: Download data"
# Note: Download SEAME from https://catalog.ldc.upenn.edu/LDC2015S04
#
# downlaod the splits https://github.com/zengzp0912/SEAME-dev-set.git to $dl_dir
#
# If you have pre-downloaded it to /path/to/seame,
# you can create a symlink
#
# ln -sfv /path/to/seame $dl_dir/seame
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
fbank=data_seame/fbank
manifests=data_seame/manifests
mkdir -p $manifests
if [ $stage -le 0] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Prepare seame manifest"
# We assume that you have downloaded the corpus
# to $dl_dir/seame
lhotse prepare seame $dl_dir/seame $dl_dir/SEAME-dev-set data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank features"
mkdir -p ${fbank}
python local/compute_fbank_gpu_seame.py
gunzip -c $fbank/cuts_train.jsonl.gz | shuf | gzip -c > ${fbank}/cuts_train_shuf.jsonl.gz
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data_seame/lang_bpe_${vocab_size}
mkdir -p ${lang_dir}
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate text for BPE training from data_seame/fbank/cuts_train_shuf.jsonl.gz"
python local/prepare_transcripts.py --cut ${fbank}/cuts_train_shuf.jsonl.gz --langdir ${lang_dir}
fi
source data_seame/manifests/token.man.1
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt \
--predef-symbols "$bpe_nlsyms"
done
fi

1
egs/seame/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -0,0 +1,423 @@
# Copyright Johns Hopkins University (Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class SeameAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data_seame/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--lid",
type=str2bool,
default=False,
help="Whether to add LID",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=10,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
lid=self.args.lid,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
lid=self.args.lid,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
lid=self.args.lid,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
lid=self.args.lid,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
lid=self.args.lid,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("Train data: About to get training cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("Dev data: About to get develop cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_valid.jsonl.gz"
)
@lru_cache()
def dev_man(self) -> CutSet:
logging.info("About to get dev_man cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_man.jsonl.gz"
)
def dev_sge(self) -> CutSet:
logging.info("About to get dev_sge cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_sge.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/beam_search.py

1054
egs/seame/ASR/zipformer/decode.py Executable file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decode_stream.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/generate_averaged_model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_check.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_decode.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/pretrained.py

View File

@ -0,0 +1,170 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage: ./zipformer/profile.py
"""
import argparse
import logging
import sentencepiece as spm
import torch
from typing import Tuple
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.profiler import get_model_profile
from scaling import BiasNorm
from train import (
get_encoder_embed,
get_encoder_model,
get_joiner_model,
add_model_arguments,
get_params,
)
from zipformer import BypassModule
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
add_model_arguments(parser)
return parser
def _bias_norm_flops_compute(module, input, output):
assert len(input) == 1, len(input)
# estimate as layer_norm, see icefall/profiler.py
flops = input[0].numel() * 5
module.__flops__ += int(flops)
def _swoosh_module_flops_compute(module, input, output):
# For SwooshL and SwooshR modules
assert len(input) == 1, len(input)
# estimate as swish/silu, see icefall/profiler.py
flops = input[0].numel()
module.__flops__ += int(flops)
def _bypass_module_flops_compute(module, input, output):
# For Bypass module
assert len(input) == 2, len(input)
flops = input[0].numel() * 2
module.__flops__ += int(flops)
MODULE_HOOK_MAPPING = {
BiasNorm: _bias_norm_flops_compute,
BypassModule: _bypass_module_flops_compute,
}
class Model(nn.Module):
"""A Wrapper for encoder, encoder_embed, and encoder_proj"""
def __init__(
self,
encoder: nn.Module,
encoder_embed: nn.Module,
encoder_proj: nn.Module,
) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
logits = self.encoder_proj(encoder_out)
return logits, encoder_out_lens
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
# We only profile the encoder part
model = Model(
encoder=get_encoder_model(params),
encoder_embed=get_encoder_embed(params),
encoder_proj=get_joiner_model(params).encoder_proj,
)
model.eval()
model.to(device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# for 30-second input
B, T, D = 1, 3000, 80
feature = torch.ones(B, T, D, dtype=torch.float32).to(device)
feature_lens = torch.full((B,), T, dtype=torch.int64).to(device)
flops, params = get_model_profile(
model=model,
args=(feature, feature_lens),
module_hoop_mapping=MODULE_HOOK_MAPPING,
)
logging.info(f"For the encoder part, params: {params}, flops: {flops}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/streaming_beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/streaming_decode.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/test_scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/test_subsampling.py

1388
egs/seame/ASR/zipformer/train.py Executable file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -0,0 +1 @@
../zipformer/asr_datamodule.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,933 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer_hat/decode.py \
--epoch 25 --avg 5 --use-averaged-model True \
--beam-size 10 \
--causal 0 \
--exp-dir zipformer_hat_seame/exp \
--bpe-model data_seame/lang_bpe_4000/bpe.model \
--max-duration 1000 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--decoding-method greedy_search
(3) modified beam search
./zipformer_hat/decode.py \
--epoch 25 --avg 5 --use-averaged-model True \
--beam-size 10 \
--causal 0 \
--exp-dir zipformer_hat/exp \
--bpe-model data_seame/lang_bpe_4000/bpe.model \
--max-duration 1000 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,1024,1024,1024,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--decoding-method modified_beam_search
"""
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import SeameAsrDataModule
from beam_search import (
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_lm_rescore_LODR,
modified_beam_search_LODR,
)
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
import string
import re
LOG_EPS = math.log(1e-10)
def remove_punc(text):
"""This function removes all English punctuations except the single quote (verbatim)."""
english_punctuations = string.punctuation + "¿¡"
# # Remove the single quote from the punctuations as it is verbatim
# english_punctuations = english_punctuations.replace("'", "")
# Create a translation table that maps each punctuation to a space.
translator = str.maketrans(english_punctuations, ' ' * len(english_punctuations))
# Translate the text using the translation table
text = text.translate(translator)
return text
def clean(text):
text = remove_punc(text)
text = text.lower()
text = re.sub(r'\s+', ' ', text)
text = text.rstrip()
return text
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data_seame/lang_bpe_4000/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data_seame/lang_bpe_4000",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--clean",
type=bool,
default=True,
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- modified_beam_search_LODR
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding-method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding-method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
default=False,
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--subtract-ilm",
type=str2bool,
default=False,
help="""Subtract the ILME LM score from the NN LM score.
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--ilm-scale",
type=float,
default=0.1,
help="""The scale of the ILME LM that will be subtracted.""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
default=2,
help="""The order of the ngram lm.
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="ID of the backoff symbol in the ngram LM",
)
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score of each token for the context biasing words/phrases.
Used only when --decoding-method is modified_beam_search and
modified_beam_search_LODR.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path of the context biasing lists, one word/phrase each line
Used only when --decoding-method is modified_beam_search and
modified_beam_search_LODR.
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network language model.
ngram_lm:
A ngram language model
ngram_lm_scale:
The scale for the ngram language model.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
hyps = []
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
# context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
subtract_ilm=params.subtract_ilm,
ilm_scale=params.ilm_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
lm_scale_list = [0.05 * i for i in range(4, 10)]
hyp_tokens = modified_beam_search_lm_rescore_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
LODR_lm=ngram_lm,
sp=sp,
lm_scale_list=lm_scale_list,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "modified_beam_search" in params.decoding_method:
prefix = f"beam_size_{params.beam_size}"
if params.has_contexts:
prefix += f"-context-score-{params.context_score}"
return {prefix: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
context_graph=context_graph,
word_table=word_table,
batch=batch,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
if params.clean:
tmp_hyp = " ".join(hyp_words)
tmp_hyp = clean(tmp_hyp)
ref_text = clean(ref_text)
hyp_words = tmp_hyp.split()
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
SeameAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"modified_beam_search",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_lm_rescore_LODR",
"modified_beam_search_LODR",
)
params.res_dir = params.exp_dir / params.decoding_method
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
if params.decoding_method in (
"modified_beam_search",
"modified_beam_search_LODR",
):
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_shallow_fusion:
params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
# only load the neural network LM if required
if (
params.use_shallow_fusion
or params.decoding_method in (
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
"modified_beam_search_lm_rescore_LODR",)
):
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
LM.to(device)
LM.eval()
else:
LM = None
# only load N-gram LM when needed
if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
try:
import kenlm
except ImportError:
print("Please install kenlm first. You can use")
print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
print("to install it")
import sys
sys.exit(-1)
ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
logging.info(f"lm filename: {ngram_file_name}")
ngram_lm = kenlm.Model(ngram_file_name)
ngram_lm_scale = None # use a list to search
elif params.decoding_method == "modified_beam_search_LODR":
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"Loading token level lm: {lm_filename}")
ngram_lm = NgramLm(
str(params.lang_dir / lm_filename),
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
ngram_lm_scale = params.ngram_lm_scale
else:
ngram_lm = None
ngram_lm_scale = None
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
if "modified_beam_search" in params.decoding_method:
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
seame = SeameAsrDataModule(args)
dev_man = seame.dev_man()
dev_sge = seame.dev_sge()
dev_man_dl = seame.test_dataloaders(dev_man)
dev_sge_dl = seame.test_dataloaders(dev_sge)
test_sets = ["dev_man", "dev_sge"]
test_dl = [dev_man_dl, dev_sge_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
context_graph=context_graph,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../zipformer/decoder.py

View File

@ -0,0 +1 @@
../zipformer/encoder_interface.py

View File

@ -0,0 +1 @@
../zipformer/export.py

View File

@ -0,0 +1 @@
../zipformer/generate_averaged_model.py

View File

@ -0,0 +1 @@
../zipformer/jit_pretrained.py

View File

@ -0,0 +1 @@
../zipformer/jit_pretrained_streaming.py

View File

@ -0,0 +1 @@
../zipformer/joiner.py

View File

@ -0,0 +1,358 @@
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class AsrModel(nn.Module):
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
encoder_dim: int = 384,
decoder_dim: int = 512,
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
):
"""A joint CTC & Transducer ASR model.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
It is used when use_transducer is True.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It is used when use_transducer is True.
use_transducer:
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
"""
super().__init__()
assert (
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder_embed = encoder_embed
self.encoder = encoder
self.use_transducer = use_transducer
if use_transducer:
# Modules for Transducer head
assert decoder is not None
assert hasattr(decoder, "blank_id")
assert joiner is not None
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_scale=0.25
)
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25
)
else:
assert decoder is None
assert joiner is None
self.use_ctc = use_ctc
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
use_hat_loss=True,
)
return simple_loss, pruned_loss
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss)
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
if self.use_transducer:
# Compute transducer loss
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
else:
ctc_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss

View File

@ -0,0 +1 @@
../zipformer/optim.py

View File

@ -0,0 +1 @@
../zipformer/pretrained.py

View File

@ -0,0 +1 @@
../zipformer/pretrained_ctc.py

View File

@ -0,0 +1 @@
../zipformer/profile.py

View File

@ -0,0 +1 @@
../zipformer/scaling.py

View File

@ -0,0 +1 @@
../zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../zipformer/subsampling.py

View File

@ -0,0 +1 @@
../zipformer/test_scaling.py

View File

@ -0,0 +1 @@
../zipformer/test_subsampling.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../zipformer/zipformer.py

View File

@ -0,0 +1 @@
../zipformer/asr_datamodule.py

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../zipformer/decoder.py

View File

@ -0,0 +1,526 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is a example for librispeech dataset, if you are using different
dataset, you should change the argument values according to your dataset.
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
Check ./jit_pretrained_streaming.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
- For non-streaming model:
To use the generated file with `zipformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./zipformer/decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
- For streaming model:
To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
# simulated streaming decoding
./zipformer/decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
# chunk-wise streaming decoding
./zipformer/streaming_decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
- non-streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
- streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
# You will find the pre-trained models in exp dir
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data_libri/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
class StreamingEncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""Streaming forward for encoder_embed and encoder.
Args:
features: (N, T, C)
feature_lengths: (N,)
states: a list of Tensors
Returns encoder outputs, output lengths, and updated states.
"""
chunk_size = self.chunk_size
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../zipformer/generate_averaged_model.py

View File

@ -0,0 +1,72 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from scaling import ScaledLinear
from typing import Optional
class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
encoder_lid: Optional[int] = None,
):
super().__init__()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
if encoder_lid:
self.lid_proj = ScaledLinear(encoder_lid, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
lid_out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim, (
encoder_out.shape,
decoder_out.shape,
)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
elif lid_out is not None:
logit = encoder_out + decoder_out + lid_out
else:
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit

View File

@ -0,0 +1,426 @@
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao)
# 2023 Johns Hopkins University (author: Amir Hussein)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class AsrModel(nn.Module):
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
lid_joiner: Optional[nn.Module] = None,
lid_encoder: Optional[nn.Module] = None,
encoder_dim: int = 384,
decoder_dim: int = 512,
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
):
"""A joint CTC & Transducer ASR model.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
It is used when use_transducer is True.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It is used when use_transducer is True.
use_transducer:
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
"""
super().__init__()
assert (
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder_embed = encoder_embed
self.encoder = encoder
self.use_transducer = use_transducer
if use_transducer:
# Modules for Transducer head
assert decoder is not None
assert hasattr(decoder, "blank_id")
assert joiner is not None
self.decoder = decoder
self.joiner = joiner
self.lid_joiner = lid_joiner
self.lid_encoder = lid_encoder
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_scale=0.25
)
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25
)
else:
assert decoder is None
assert joiner is None
self.use_ctc = use_ctc
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens, lid_output = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
if self.lid_encoder is not None:
lid_encoder_out, lid_encoder_out_lens = self.lid_encoder(
lid_output, x_lens, src_key_padding_mask
)
assert encoder_out_lens.eq(lid_encoder_out_lens).all()
lid_encoder_out = lid_encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
else:
lid_encoder_out = None
return encoder_out, encoder_out_lens, lid_encoder_out
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
lid_encoder_out: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
y_lid: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
lid_pruned, _ = k2.do_rnnt_pruning(
am=self.joiner.lid_proj(lid_encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
if self.lid_joiner is not None:
lid_am_pruned, lid_lm_pruned = k2.do_rnnt_pruning(
am=self.lid_joiner.encoder_proj(lid_encoder_out),
lm=self.lid_joiner.decoder_proj(decoder_out),
ranges=ranges,
)
lid_logits = self.lid_joiner(
lid_am_pruned, lid_lm_pruned, project_input=False)
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False, lid_out=lid_pruned)
# Add blank logits to lid_logits
logits = torch.cat((lid_logits[..., 0].unsqueeze(-1), logits), dim=-1)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
use_hat_loss=True,
)
# Compute HAT loss for auxiliary lm joiner
if self.lid_joiner is not None:
with torch.cuda.amp.autocast(enabled=False):
pruned_lid_loss = k2.rnnt_loss_pruned(
logits=lid_logits.float(),
symbols=y_lid.pad(mode="constant", padding_value=blank_id).to(torch.int64),
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
use_hat_loss=True,
)
return simple_loss, pruned_loss, pruned_lid_loss
else:
return simple_loss, pruned_loss
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lid: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss)
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert y_lid.num_axes == 2, y_lid.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
# Compute encoder outputs
if self.lid_joiner != None:
encoder_out, encoder_out_lens, lid_encoder_out = self.forward_encoder(x, x_lens)
else:
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
if self.use_transducer:
# Compute transducer loss
if self.lid_joiner != None:
simple_loss, pruned_loss, pruned_loss_lm = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
lid_encoder_out=lid_encoder_out,
y=y.to(x.device),
y_lens=y_lens,
y_lid=y_lid,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
else:
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
pruned_loss_lm = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
else:
ctc_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss, pruned_loss_lm

View File

@ -0,0 +1 @@
../zipformer/optim.py

View File

@ -0,0 +1 @@
../zipformer/scaling.py

View File

@ -0,0 +1 @@
../zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../zipformer/subsampling.py

View File

@ -0,0 +1 @@
../zipformer/test_scaling.py

View File

@ -0,0 +1 @@
../zipformer/test_subsampling.py

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff