Implement recipe for Fluent Speech Commands dataset

Signed-off-by: Xinyuan Li <xli257@c13.clsp.jhu.edu>
Xinyuan Li 2024-01-19 13:37:00 -05:00
parent bbb03f7962
commit d305c7cceb
35 changed files with 6697 additions and 19 deletions

View File

@ -0,0 +1,9 @@
## Fluent Speech Commands recipe
This is a recipe for the Fluent Speech Commands dataset, a spoken language understanding (SLU) corpus that maps short utterances (such as "turn the lights on in the kitchen") to action frames (such as {"action": "activate", "object": "lights", "location": "kitchen"}). The training set contains 23,132 utterances and the test set contains 3,793 utterances.
Dataset Paper link: <https://paperswithcode.com/dataset/fluent-speech-commands>
cd icefall/egs/fluent_speech_commands/
Training: python transducer/train.py
Decoding: python transducer/decode.py
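
As a quick sanity check, the sketch below shows how one annotation row maps to an action frame. It assumes the `data/train_data.csv` layout read by `local/generate_lexicon.py`; the column names are taken from that script.

```python
import pandas

# train_data.csv stores one utterance per row with its transcription
# and the three slot columns used to build the action frame.
data = pandas.read_csv("path/to/fluent_speech_commands/data/train_data.csv", index_col=0, header=0)
row = data.iloc[0]
frame = {"action": row["action"], "object": row["object"], "location": row["location"]}
print(row["transcription"], "->", frame)
```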

View File

@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: Path) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lm/frames or data/lm/transcript.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
logging.info("Loading G.fst.txt")
with open(lang_dir / "G.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
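# A minimal sketch of how the generated graph is loaded back for decoding
# (this mirrors decode.py; the path assumes the frames lang dir produced
# by prepare.sh):
#
#   d = torch.load("data/lm/frames/HLG.pt", map_location="cpu")
#   HLG = k2.Fsa.from_dict(d)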

View File

@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""
This file computes fbank features of the Fluent Speech Commands dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a
# lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_slu(manifest_dir, fbanks_dir):
src_dir = Path(manifest_dir)
output_dir = Path(fbanks_dir)
# This dataset is rather small, so we use only one job
num_jobs = min(1, os.cpu_count())
num_mel_bins = 23
dataset_parts = (
"train",
"valid",
"test",
)
prefix = "slu"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
if cuts_file.is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 1, # use one job
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(cuts_file)
parser = argparse.ArgumentParser()
parser.add_argument('manifest_dir')
parser.add_argument('fbanks_dir')
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
args = parser.parse_args()
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_slu(args.manifest_dir, args.fbanks_dir)
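# A minimal sketch of how to inspect the generated cuts (the path follows
# the defaults used by prepare.sh):
#
#   from lhotse import load_manifest_lazy
#   cuts = load_manifest_lazy("data/fbanks/slu_cuts_train.jsonl.gz")
#   print(next(iter(cuts)))  # recording, supervision and feature metadata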

View File

@ -0,0 +1,47 @@
import argparse
import pandas
from tqdm import tqdm
def generate_lexicon(corpus_dir, lm_dir):
data = pandas.read_csv(str(corpus_dir) + '/data/train_data.csv', index_col=0, header=0)
vocab_transcript = set()
vocab_frames = set()
transcripts = data['transcription'].tolist()
frames = list(zip(data['action'].tolist(), data['object'].tolist(), data['location'].tolist()))
for transcript in tqdm(transcripts):
for word in transcript.split():
vocab_transcript.add(word)
for frame in tqdm(frames):
for word in frame:
vocab_frames.add('_'.join(word.split()))
with open(lm_dir + '/words_transcript.txt', 'w') as lexicon_transcript_file:
lexicon_transcript_file.write("<UNK> 1" + '\n')
lexicon_transcript_file.write("<s> 2" + '\n')
lexicon_transcript_file.write("</s> 0" + '\n')
word_id = 3
for vocab in vocab_transcript:
lexicon_transcript_file.write(vocab + ' ' + str(word_id) + '\n')
word_id += 1
with open(lm_dir + '/words_frames.txt', 'w') as lexicon_frames_file:
lexicon_frames_file.write("<UNK> 1" + '\n')
lexicon_frames_file.write("<s> 2" + '\n')
lexicon_frames_file.write("</s> 0" + '\n')
word_id = 3
for vocab in vocab_frames:
lexicon_frames_file.write(vocab + ' ' + str(word_id) + '\n')
word_id += 1
parser = argparse.ArgumentParser()
parser.add_argument('corpus_dir')
parser.add_argument('lm_dir')
def main():
args = parser.parse_args()
generate_lexicon(args.corpus_dir, args.lm_dir)
if __name__ == "__main__":
main()
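# Example usage (prepare.sh invokes this as
# `python ./local/generate_lexicon.py $data_dir $lm_dir`).
# The resulting word lists start with the special symbols written above:
#   <UNK> 1
#   <s> 2
#   </s> 0
# followed by one line per vocabulary word, numbered from 3.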

View File

@ -0,0 +1,369 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import argparse
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
Lexicon = List[Tuple[str, List[str]]]
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
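# For example, if two lexicon entries share the token sequence "a b",
# they become "a b #1" and "a b #2" respectively, which keeps the
# lexicon FST determinizable.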
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "!SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
# assert token2id["<eps>"] == 0
# assert word2id["<eps>"] == 0
eps = 0
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [token2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
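# The text format accepted by k2.Fsa.from_str above is one arc per line,
#   src_state dest_state label aux_label score
# followed by a final line containing only the final state ID.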
parser = argparse.ArgumentParser()
parser.add_argument('lm_dir')
def main():
args = parser.parse_args()
out_dir = Path(args.lm_dir)
lexicon_filenames = [out_dir / "words_frames.txt", out_dir / "words_transcript.txt"]
names = ["frames", "transcript"]
sil_token = "!SIL"
sil_prob = 0.5
for name, lexicon_filename in zip(names, lexicon_filenames):
lexicon = read_lexicon(lexicon_filename)
tokens = get_words(lexicon)
words = get_words(lexicon)
new_lexicon = []
for lexicon_item in lexicon:
new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
lexicon = new_lexicon
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
tokens = ["<eps>"] + tokens
words = ["<eps>"] + words + ["#0", "!SIL"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=word2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=word2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
if __name__ == "__main__":
main()
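# Example usage (prepare.sh invokes this as `python ./local/prepare_lang.py $lm_dir`).
# It writes L_{frames,transcript}.pt, L_disambig_{frames,transcript}.pt and the
# token/word tables into lm_dir.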

View File

@ -0,0 +1,112 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=1
stop_stage=5
data_dir=path/to/fluent/speech/commands
target_root_dir=data/
lang_dir=${target_root_dir}/lang_phone
lm_dir=${target_root_dir}/lm
manifest_dir=${target_root_dir}/manifests
fbanks_dir=${target_root_dir}/fbanks
. shared/parse_options.sh || exit 1
mkdir -p $lang_dir
mkdir -p $lm_dir
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "data_dir: $data_dir"
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare slu manifest"
mkdir -p $manifest_dir
lhotse prepare slu $data_dir $manifest_dir
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute fbank for yesno"
mkdir -p $fbanks_dir
python ./local/compute_fbank_slu.py $manifest_dir $fbanks_dir
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare lang"
# NOTE: "<UNK> SIL" is added for implementation convenience
# as the graph compiler code requires that there is an OOV word
# in the lexicon.
python ./local/generate_lexicon.py $data_dir $lm_dir
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Train LM"
# We use a unigram G
./shared/make_kn_lm.py \
-ngram-order 1 \
-text $lm_dir/words_transcript.txt \
-lm $lm_dir/G_transcript.arpa
./shared/make_kn_lm.py \
-ngram-order 1 \
-text $lm_dir/words_frames.txt \
-lm $lm_dir/G_frames.arpa
python ./local/prepare_lang.py $lm_dir
if [ ! -f $lm_dir/G_transcript.fst.txt ]; then
python -m kaldilm \
--read-symbol-table="$lm_dir/words_transcript.txt" \
$lm_dir/G_transcript.arpa > $lm_dir/G_transcript.fst.txt
fi
if [ ! -f $lm_dir/G_frames.fst.txt ]; then
python -m kaldilm \
--read-symbol-table="$lm_dir/words_frames.txt" \
$lm_dir/G_frames.arpa > $lm_dir/G_frames.fst.txt
fi
mkdir -p $lm_dir/frames
mkdir -p $lm_dir/transcript
chmod -R +777 .
for i in G_frames.arpa G_frames.fst.txt L_disambig_frames.pt L_frames.pt lexicon_disambig_frames.txt tokens_frames.txt words_frames.txt;
do
j=${i//"_frames"/}
mv "$lm_dir/$i" $lm_dir/frames/$j
done
for i in G_transcript.arpa G_transcript.fst.txt L_disambig_transcript.pt L_transcript.pt lexicon_disambig_transcript.txt tokens_transcript.txt words_transcript.txt;
do
j=${i//"_transcript"/}
mv "$lm_dir/$i" $lm_dir/transcript/$j
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compile HLG"
./local/compile_hlg.py --lang-dir $lm_dir/frames
./local/compile_hlg.py --lang-dir $lm_dir/transcript
fi
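# Example: re-run a single stage (parse_options.sh maps --stage/--stop-stage
# onto the variables above), e.g. only the fbank computation:
#   ./prepare.sh --stage 2 --stop-stage 2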

View File

@ -0,0 +1 @@
../../icefall/shared/

View File

@ -0,0 +1,292 @@
# Copyright 2021 Piotr Żelasko
# 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class SluDataModule(DataModule):
"""
DataModule for k2 ASR experiments.
It assumes there is always one train dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
"""
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--feature-dir",
type=Path,
default=Path("data/fbanks"),
help="Path to directory with train/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=30.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=False,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=10,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
logging.info("About to create train dataset")
transforms = []
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=23))
),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=True,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
)
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get valid cuts")
cuts_valid = self.valid_cuts()
logging.debug("About to create valid dataset")
valid = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create valid dataloader")
valid_dl = DataLoader(
valid,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
persistent_workers=True,
)
return valid_dl
def test_dataloaders(self) -> DataLoader:
logging.info("About to get test cuts")
cuts_test = self.test_cuts()
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts_test,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
persistent_workers=True,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest_lazy(
self.args.feature_dir / "slu_cuts_train.jsonl.gz"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get valid cuts")
cuts_valid = load_manifest_lazy(
self.args.feature_dir / "slu_cuts_valid.jsonl.gz"
)
return cuts_valid
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
cuts_test = load_manifest_lazy(
self.args.feature_dir / "slu_cuts_test.jsonl.gz"
)
return cuts_test
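# A minimal usage sketch of this data module (mirroring what decode.py does):
#
#   parser = argparse.ArgumentParser()
#   SluDataModule.add_arguments(parser)
#   args = parser.parse_args()
#   slu = SluDataModule(args)
#   test_dl = slu.test_dataloaders()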

View File

@ -0,0 +1,315 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import SluDataModule
from model import Tdnn
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice, one_best_decoding
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=13,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn/exp/"),
"lang_dir": Path("data/lm/frames"),
"feature_dim": 23,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
word_table: k2.SymbolTable,
) -> List[List[int]]:
"""Decode one batch and return the result in a list-of-list.
Each sub list contains the word IDs for an utterance in the batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding.
- params.method is "nbest", it uses nbest decoding.
model:
The neural model.
HLG:
The decoding graph.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
(https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
word_table:
It is the word symbol table.
Returns:
Return the decoding result. `len(ans)` == batch size.
"""
device = HLG.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
nnet_output = model(feature)
# nnet_output is (N, T, C)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
word_table: k2.SymbolTable,
) -> List[Tuple[str, List[str], List[str]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph.
word_table:
It is word symbol table.
Returns:
Return a list of tuples. Each tuple contains
(cut_id, ref_words, hyp_words): the cut ID, the reference
transcript and the predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
for batch_idx, batch in enumerate(dl):
# texts = batch["supervisions"]["custom"]["frames"]
texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
word_table=word_table,
)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
exp_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
) -> None:
"""Save results to `exp_dir`.
Args:
exp_dir:
The output directory. This function creates the following files inside
this directory:
- recogs-{test_set_name}.txt
- errs-{test_set_name}.txt
It contains the detailed WER.
test_set_name:
The name of the test set, which will be part of the result filename.
results:
A list of tuples, each of which contains (cut_id, ref_words, hyp_words).
Returns:
Return None.
"""
recog_path = exp_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = exp_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
write_error_stats(f, f"{test_set_name}", results)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
@torch.no_grad()
def main():
parser = get_parser()
SluDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
model = Tdnn(
num_features=params.feature_dim,
num_classes=max_token_id + 1, # +1 for the blank symbol
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
return
model.to(device)
model.eval()
# we need cut ids to display recognition results.
args.return_cuts = True
slu = SluDataModule(args)
test_dl = slu.test_dataloaders()
results = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
word_table=lexicon.word_table,
)
save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
logging.info("Done!")
if __name__ == "__main__":
main()
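# Example usage (defaults follow get_params above):
#   ./tdnn/decode.py --epoch 13 --avg 1
# Results are written to tdnn/exp/recogs-test_set.txt and
# tdnn/exp/errs-test_set.txt.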

View File

@ -0,0 +1,118 @@
#!/usr/bin/env python3
"""
This file exports a trained model either to a single checkpoint
or to a torchscript model.
(1) Generate the checkpoint tdnn/exp/pretrained.pt
./tdnn/export.py \
--epoch 14 \
--avg 2
See ./tdnn/pretrained.py for how to use the generated file.
(2) Generate torchscript model tdnn/exp/cpu_jit.pt
./tdnn/export.py \
--epoch 14 \
--avg 2 \
--jit 1
See ./tdnn/jit_pretrained.py for how to use the generated file.
"""
import argparse
import logging
import torch
from model import Tdnn
from train import get_params
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=14,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=2,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
model = Tdnn(
num_features=params.feature_dim,
num_classes=max_token_id + 1, # +1 for the blank symbol
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,158 @@
#!/usr/bin/env python3
"""
This file is for exporting trained models to onnx.
Usage:
./tdnn/export_onnx.py \
--epoch 14 \
--avg 2
The above command generates the following two files:
- ./tdnn/exp/model-epoch-14-avg-2.onnx
- ./tdnn/exp/model-epoch-14-avg-2.int8.onnx
See ./tdnn/onnx_pretrained.py for how to use them.
"""
import argparse
import logging
from typing import Dict
import onnx
import torch
from model import Tdnn
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import get_params
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=14,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=2,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
@torch.no_grad()
def main():
args = get_parser().parse_args()
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
model = Tdnn(
num_features=params.feature_dim,
num_classes=max_token_id + 1, # +1 for the blank symbol
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to("cpu")
model.eval()
N = 1
T = 100
C = params.feature_dim
x = torch.rand(N, T, C)
opset_version = 13
onnx_filename = f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.onnx"
torch.onnx.export(
model,
x,
onnx_filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["log_prob"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"log_prob": {0: "N", 1: "T"},
},
)
logging.info(f"Saved to {onnx_filename}")
meta_data = {
"model_type": "tdnn",
"version": "1",
"model_author": "k2-fsa",
"comment": "non-streaming tdnn for the yesno recipe",
"vocab_size": max_token_id + 1,
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=onnx_filename, meta_data=meta_data)
logging.info("Generate int8 quantization models")
onnx_filename_int8 = (
f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.int8.onnx"
)
quantize_dynamic(
model_input=onnx_filename,
model_output=onnx_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
logging.info(f"Saved to {onnx_filename_int8}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
This file shows how to use a torchscript model for decoding.
Usage:
./tdnn/jit_pretrained.py \
--nn-model ./tdnn/exp/cpu_jit.pt \
--HLG ./data/lm/frames/HLG.pt \
--words-file ./data/lm/frames/words.txt \
/path/to/foo.wav \
/path/to/bar.wav
Note that to generate ./tdnn/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""
import argparse
import logging
from typing import List
import math
import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1
to obtain it
""",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
"num_classes": 4, # [<blk>, N, SIL, Y]
"sample_rate": 8000,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Loading torchscript model")
model = torch.jit.load(args.nn_model)
model.eval()
model.to(device)
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
# Note: We don't use key padding mask for attention during decoding
nnet_output = model(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,81 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)
import torch
import torch.nn as nn
class Tdnn(nn.Module):
def __init__(self, num_features: int, num_classes: int):
"""
Args:
num_features:
Model input dimension.
num_classes:
Model output dimension
"""
super().__init__()
self.tdnn = nn.Sequential(
nn.Conv1d(
in_channels=num_features,
out_channels=32,
kernel_size=3,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=2,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=4,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
)
self.output_linear = nn.Linear(in_features=32, out_features=num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
The input tensor with shape [N, T, C].
Returns:
The output tensor with shape [N, T', num_classes], where T' < T
because the convolutions above use no padding.
"""
x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
x = self.tdnn(x)
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
x = self.output_linear(x)
x = nn.functional.log_softmax(x, dim=-1)
return x
def test_tdnn():
num_features = 23
num_classes = 4
model = Tdnn(num_features=num_features, num_classes=num_classes)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
N = 2
T = 100
C = num_features
x = torch.randn(N, T, C)
y = model(x)
print(x.shape)
print(y.shape)
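# With N=2 and T=100, this prints torch.Size([2, 100, 23]) and
# torch.Size([2, 74, 4]): the unpadded convolutions above shrink the
# time axis by (3-1)*1 + (5-1)*2 + (5-1)*4 = 26 frames.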
if __name__ == "__main__":
test_tdnn()

View File

@ -0,0 +1,242 @@
#!/usr/bin/env python3
"""
This file shows how to use an ONNX model for decoding with onnxruntime.
Usage:
(1) Use a non-quantized ONNX model, i.e., a float32 model
./tdnn/onnx_pretrained.py \
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
--HLG ./data/lm/frames/HLG.pt \
--words-file ./data/lm/frames/words.txt \
/path/to/foo.wav \
/path/to/bar.wav
(2) Use a quantized ONNX model, i.e., an int8 model
./tdnn/onnx_pretrained.py \
--nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \
--HLG ./data/lm/frames/HLG.pt \
--words-file ./data/lm/frames/words.txt \
/path/to/foo.wav \
/path/to/bar.wav
Note that to generate ./tdnn/exp/model-epoch-14-avg-2.onnx
and ./tdnn/exp/model-epoch-14-avg-2.int8.onnx,
you can use ./export_onnx.py --epoch 14 --avg 2
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
class OnnxModel:
def __init__(self, nn_model: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.model = ort.InferenceSession(
nn_model,
sess_options=self.session_opts,
)
meta = self.model.get_modelmeta().custom_metadata_map
self.vocab_size = int(meta["vocab_size"])
def run(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor log_prob of shape (N, T, C)
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
},
)
return torch.from_numpy(out[0])
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1
to obtain it
""",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
"sample_rate": 8000,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info(f"Loading onnx model {params.nn_model}")
model = OnnxModel(params.nn_model)
logging.info(f"Loading HLG from {args.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
# Note: We don't use key padding mask for attention during decoding
nnet_output = model.run(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,221 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file shows how to use a checkpoint for decoding.
Usage:
./tdnn/pretrained.py \
--checkpoint ./tdnn/exp/pretrained.pt \
--HLG ./data/lm/frames/HLG.pt \
--words-file ./data/lm/frames/words.txt \
/path/to/foo.wav \
/path/to/bar.wav
Note that to generate ./tdnn/exp/pretrained.pt,
you can use ./export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from model import Tdnn
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint(). "
"You can use ./tdnn/export.py to obtain it.",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
"num_classes": 4, # [<blk>, N, SIL, Y]
"sample_rate": 8000,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = Tdnn(
num_features=params.feature_dim,
num_classes=params.num_classes,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
# Note: We don't use key padding mask for attention during decoding
nnet_output = model(features)
batch_size = nnet_output.shape[0]
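# Each row of supervision_segments is (sequence_index, start_frame, num_frames);
# every padded utterance here starts at frame 0 and spans the full output length.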
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,581 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import SluDataModule
from lhotse.utils import fix_random_seed
from model import Tdnn
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=100,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=14,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
tdnn/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used for writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if `batch_idx % log_interval` is 0
- valid_interval: Run validation if `batch_idx % valid_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
"""
params = AttributeDict(
{
"exp_dir": Path("tdnn/exp"),
"lang_dir": Path("data/lm/frames"),
"lr": 1e-3,
"feature_dim": 23,
"weight_decay": 1e-6,
"start_epoch": 0,
"num_epochs": 5,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 100,
"reset_interval": 20,
"valid_interval": 300,
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True
}
)
return params
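# A minimal sketch of how commandline options end up in `params`
# (AttributeDict supports both attribute and key access):
#
#   params = get_params()
#   params.update(vars(get_parser().parse_args()))
#   assert params.lr == params["lr"]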
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> Optional[dict]:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return the loaded checkpoint contents if a checkpoint was loaded, otherwise None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Tdnn in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build a decoding graph from a ctc topo and training
transcript. The training transcript is contained in the given `batch`,
while the ctc topo is built when this compiler is instantiated.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
with torch.set_grad_enabled(is_training):
nnet_output = model(feature)
# nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervisions = batch["supervisions"]
# texts = supervisions["custom"]["frames"]
texts = [' '.join(a.supervisions[0].custom["frames"]) for a in supervisions["cut"]]
texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
decoding_graph = graph_compiler.compile(texts)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
)
loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: CtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# Summary stats: tot_loss is an exponential moving average of recent batch
# losses; contributions older than ~reset_interval batches decay away.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
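# Clip the gradient L2 norm (norm type 2.0) to at most 5.0 to stabilize training.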
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
lexicon = Lexicon(params.lang_dir)
max_phone_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"device: {device}")
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
model = Tdnn(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
optimizer = optim.SGD(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
slu = SluDataModule(args)
train_dl = slu.train_dataloaders()
# We use the test data as validation.
valid_dl = slu.test_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
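# Tell the sampler which epoch this is so it reshuffles the data differently.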
train_dl.sampler.set_epoch(epoch)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
scheduler=None,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
SluDataModule.add_arguments(parser)
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../tdnn/asr_datamodule.py

View File

@ -0,0 +1,69 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from transducer.model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor, id2word: dict) -> List[str]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Only N == 1 is supported for now.
id2word:
A dict mapping word IDs to their word strings.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
hyp = []
max_u = 1000 # terminate after this number of steps
u = 0
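# Greedy RNN-T search: at each encoder frame t, run the joiner on the current
# decoder state; a non-blank argmax emits a symbol and advances the decoder (u),
# while a blank advances to the next frame (t). max_u guards against runaway emission.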
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (N, 1, 1)
# TODO: Use logits.argmax()
y = log_prob.argmax()
if y != blank_id:
hyp.append(y.item())
y = y.reshape(1, 1)
decoder_out, (h, c) = model.decoder(y, (h, c))
u += 1
else:
t += 1
# id2word = {1: "YES", 2: "NO"}
hyp = [id2word[i] for i in hyp]
return hyp

File diff suppressed because it is too large.

View File

@ -0,0 +1,349 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn as nn
from transducer.asr_datamodule import SluDataModule
from transducer.beam_search import greedy_search
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.conformer import Conformer
from transducer.joiner import Joiner
from transducer.model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_id2word(params):
id2word = {}
# 0 is blank
id = 1
try:
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
id2word[id] = line.split()[0]
id += 1
except OSError:
# If the lexicon file is missing or unreadable, fall back to an
# empty mapping instead of aborting decoding.
pass
return id2word
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=6,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="Directory from which to load the checkpoints",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lm/frames"
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
"lang_dir": Path("data/lm/frames"),
# encoder/decoder params
"vocab_size": 3, # blank, yes, no
"blank_id": 0,
"embedding_dim": 32,
"hidden_dim": 16,
"num_decoder_layers": 4,
}
)
vocab_size = 1
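# Start from 1 so that ID 0 stays reserved for the blank symbol;
# every non-empty lexicon line then contributes one word ID.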
with open(params.lang_dir / 'lexicon_disambig.txt') as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:  # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
vocab_size += 1
params.vocab_size = vocab_size
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
id2word: dict
) -> List[List[int]]:
"""Decode one batch and return the result in a list-of-list.
Each sub list contains the word IDs for an utterance in the batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding.
- params.method is "nbest", it uses nbest decoding.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
(https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
Returns:
Return the decoding result. `len(ans)` == batch size.
"""
device = model.device
feature = batch["inputs"]
feature = feature.to(device)
# at entry, feature is (N, T, C)
feature_lens = batch["supervisions"]["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
hyp = greedy_search(model=model, encoder_out=encoder_out_i, id2word=id2word)
hyps.append(hyp)
# hyps = [[word_table[i] for i in ids] for ids in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
) -> List[Tuple[str, List[str], List[str]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a list of tuples. Each tuple contains
(cut_id, ref_words, hyp_words), where ref_words is the reference
transcript and hyp_words is the predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
id2word = get_id2word(params)
results = []
for batch_idx, batch in enumerate(dl):
texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps = decode_one_batch(
params=params,
model=model,
batch=batch,
id2word=id2word
)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
exp_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
) -> None:
"""Save results to `exp_dir`.
Args:
exp_dir:
The output directory. This function creates the following files inside
this directory:
- recogs-{test_set_name}.txt
It contains the reference and hypothesis results, like below::
ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
- errs-{test_set_name}.txt
It contains the detailed WER.
test_set_name:
The name of the test set, which will be part of the result filename.
results:
A list of tuples, each of which contains (cut_id, ref_words, hyp_words).
Returns:
Return None.
"""
recog_path = exp_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = exp_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
write_error_stats(f, f"{test_set_name}", results)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
def get_transducer_model(params: AttributeDict):
# encoder = Tdnn(
# num_features=params.feature_dim,
# output_dim=params.hidden_dim,
# )
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.hidden_dim,
)
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.hidden_dim,
embedding_dropout=0.4,
rnn_dropout=0.4,
)
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer
@torch.no_grad()
def main():
parser = get_parser()
SluDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
# Skip negative epoch indices, which have no saved checkpoints.
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
model.device = device
# we need cut ids to display recognition results.
args.return_cuts = True
slu = SluDataModule(args)
test_dl = slu.test_dataloaders()
results = decode_dataset(
dl=test_dl,
params=params,
model=model,
)
test_set_name = str(args.feature_dir).split('/')[-2]
save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,92 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
class Decoder(nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
blank_id: int,
num_layers: int,
hidden_dim: int,
embedding_dropout: float = 0.0,
rnn_dropout: float = 0.0,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit.
embedding_dim:
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
num_layers:
Number of RNN layers.
hidden_dim:
Hidden dimension of RNN layers.
embedding_dropout:
Dropout rate for the embedding layer.
rnn_dropout:
Dropout for RNN layers.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
)
self.embedding_dropout = nn.Dropout(embedding_dropout)
self.rnn = nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=rnn_dropout,
)
self.blank_id = blank_id
self.output_linear = nn.Linear(hidden_dim, hidden_dim)
def forward(
self,
y: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
y:
A 2-D tensor of shape (N, U).
states:
A tuple of two tensors containing the state information of
RNN layers in this decoder.
Returns:
Return a tuple containing:
- rnn_output, a tensor of shape (N, U, C)
- (h, c), which contain the state information for RNN layers.
Both are of shape (num_layers, N, C)
"""
embedding_out = self.embedding(y)
embedding_out = self.embedding_dropout(embedding_out)
rnn_out, (h, c) = self.rnn(embedding_out, states)
out = self.output_linear(rnn_out)
return out, (h, c)

View File

@ -0,0 +1,87 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
# We use a TDNN model as encoder, as it works very well with CTC training
# for this tiny dataset.
class Tdnn(nn.Module):
def __init__(self, num_features: int, output_dim: int):
"""
Args:
num_features:
Model input dimension.
output_dim:
Model output dimension.
"""
super().__init__()
# Note: We don't use paddings inside conv layers
self.tdnn = nn.Sequential(
nn.Conv1d(
in_channels=num_features,
out_channels=32,
kernel_size=3,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=2,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=4,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
)
self.output_linear = nn.Linear(in_features=32, out_features=output_dim)
def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor with shape (N, T, C)
x_lens:
It contains the number of frames in each utterance in x
before padding.
Returns:
Return a tuple with 2 tensors:
- logits, a tensor of shape (N, T, C)
- logit_lens, a tensor of shape (N,)
"""
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.tdnn(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
logits = self.output_linear(x)
# The first conv layer reduces T by 3-1 = 2 frames,
# the second by (5-1)*2 = 8 frames,
# and the third by (5-1)*4 = 16 frames,
# so T shrinks by 2 + 8 + 16 = 26 frames in total.
x_lens = x_lens - 26
return logits, x_lens

View File

@ -0,0 +1,43 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (batch_size, input_seq_len, num_features)
containing the input features.
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
containing unnormalized probabilities, i.e., the output of a
linear layer.
- encoder_out_lens, a tensor of shape (batch_size,) containing
the number of frames in `encoder_out` before padding.
"""
raise NotImplementedError("Please implement it in a subclass")

View File

@ -0,0 +1,55 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, C).
decoder_out:
Output from the decoder. Its shape is (N, U, C).
Returns:
Return a tensor of shape (N, T, U, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == decoder_out.size(2)
encoder_out = encoder_out.unsqueeze(2)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1)
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
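# Now logit is (N, T, U, C): broadcasting sums every encoder frame
# with every decoder step, one joint representation per (t, u) pair.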
logit = F.relu(logit)
output = self.output_linear(logit)
return output

View File

@ -0,0 +1,120 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional
from icefall.utils import add_sos
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: nn.Module,
decoder: nn.Module,
joiner: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return the transducer loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
decoder_out, _ = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
y_padded = y.pad(mode="constant", padding_value=0)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="mean",
)
return loss

View File

@ -0,0 +1,153 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
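For example, T = 100 gives T' = ((100-1)//2 - 1)//2 = 24.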
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, idim: int, odim: int) -> None:
"""
Args:
idim:
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
return x
class VggSubsampling(nn.Module):
"""Trying to follow the setup described in the following paper:
https://arxiv.org/pdf/1910.09799.pdf
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
def __init__(self, idim: int, odim: int) -> None:
"""Construct a VggSubsampling object.
This uses 2 VGG blocks with 2 Conv2d layers each,
subsampling its input by a factor of 4 in the time dimensions.
Args:
idim:
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
super().__init__()
cur_channels = 1
layers = []
block_dims = [32, 64]
# The decision to use padding=1 for the 1st convolution, then padding=0
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
# a back-compatibility concern so that the number of frames at the
# output would be equal to:
# (((T-1)//2)-1)//2.
# We can consider changing this by using padding=1 on the
# 2nd convolution, so the num-frames at the output would be T//4.
for block_dim in block_dims:
layers.append(
torch.nn.Conv2d(
in_channels=cur_channels,
out_channels=block_dim,
kernel_size=3,
padding=1,
stride=1,
)
)
layers.append(torch.nn.ReLU())
layers.append(
torch.nn.Conv2d(
in_channels=block_dim,
out_channels=block_dim,
kernel_size=3,
padding=0,
stride=1,
)
)
layers.append(
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
x = x.unsqueeze(1)
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x

View File

@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/fluent_speech_commands
python ./transducer/test_decoder.py
"""
import torch
from transducer.decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
embedding_dim = 128
num_layers = 2
hidden_dim = 6
N = 3
U = 5
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
num_layers=num_layers,
hidden_dim=hidden_dim,
embedding_dropout=0.0,
rnn_dropout=0.0,
)
x = torch.randint(1, vocab_size, (N, U))
rnn_out, (h, c) = decoder(x)
assert rnn_out.shape == (N, U, hidden_dim)
assert h.shape == (num_layers, N, hidden_dim)
assert c.shape == (num_layers, N, hidden_dim)
rnn_out, (h, c) = decoder(x, (h, c))
assert rnn_out.shape == (N, U, hidden_dim)
assert h.shape == (num_layers, N, hidden_dim)
assert c.shape == (num_layers, N, hidden_dim)
def main():
test_decoder()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/fluent_speech_commands
python ./transducer/test_encoder.py
"""
import torch
from transducer.encoder import Tdnn
def test_encoder():
input_dim = 10
output_dim = 20
encoder = Tdnn(input_dim, output_dim)
N = 10
T = 85
x = torch.rand(N, T, input_dim)
x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
logits, logit_lens = encoder(x, x_lens)
assert logits.shape == (N, T - 26, output_dim)
assert torch.all(torch.eq(x_lens - 26, logit_lens))
def main():
test_encoder()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/fluent_speech_commands
python ./transducer/test_joiner.py
"""
import torch
from transducer.joiner import Joiner
def test_joiner():
N = 2
T = 3
C = 4
U = 5
joiner = Joiner(C, 10)
encoder_out = torch.rand(N, T, C)
decoder_out = torch.rand(N, U, C)
joint = joiner(encoder_out, decoder_out)
assert joint.shape == (N, T, U, 10)
def main():
test_joiner()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/fluent_speech_commands
python ./transducer/test_transducer.py
"""
import k2
import torch
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.joiner import Joiner
from transducer.model import Transducer
def test_transducer():
# encoder params
input_dim = 10
output_dim = 20
# decoder params
vocab_size = 3
blank_id = 0
embedding_dim = 128
num_layers = 2
encoder = Tdnn(input_dim, output_dim)
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
num_layers=num_layers,
hidden_dim=output_dim,
embedding_dropout=0.0,
rnn_dropout=0.0,
)
joiner = Joiner(output_dim, vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
N = y.dim0
T = 50
x = torch.rand(N, T, input_dim)
x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
x_lens[0] = T
loss = transducer(x, x_lens, y)
print(loss)
def main():
test_transducer()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,633 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import List, Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import SluDataModule
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
# from torch.utils.tensorboard import SummaryWriter
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.conformer import Conformer
from transducer.joiner import Joiner
from transducer.model import Transducer
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_word2id(params):
word2id = {}
# 0 is blank
id = 1
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
word2id[line.split()[0]] = id
id += 1
return word2id
def get_labels(texts: List[str], word2id) -> k2.RaggedTensor:
"""
Args:
texts:
A list of transcripts.
word2id:
A dict mapping words to their integer IDs.
Returns:
Return a ragged tensor containing the corresponding word ID.
"""
# blank is 0
word_ids = []
for t in texts:
words = t.split()
ids = [word2id[w] for w in words]
word_ids.append(ids)
return k2.RaggedTensor(word_ids)
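# Sketch (hypothetical IDs): with word2id = {'<s>': 1, 'turn': 2, '</s>': 3},
# get_labels(['<s> turn </s>'], word2id) returns a ragged tensor [[1, 2, 3]].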
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=7,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
tdnn/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="Directory to save results",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lm/frames"
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used for writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if `batch_idx % log_interval` is 0
- valid_interval: Run validation if `batch_idx % valid_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
"""
params = AttributeDict(
{
"lr": 1e-4,
"feature_dim": 23,
"weight_decay": 1e-6,
"start_epoch": 0,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 100,
"reset_interval": 20,
"valid_interval": 3000,
"exp_dir": Path("transducer/exp"),
"lang_dir": Path("data/lm/frames"),
# encoder/decoder params
"vocab_size": 3, # blank, yes, no
"blank_id": 0,
"embedding_dim": 32,
"hidden_dim": 16,
"num_decoder_layers": 4,
}
)
vocab_size = 1
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:  # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
vocab_size += 1
params.vocab_size = vocab_size
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> Optional[dict]:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return the loaded checkpoint contents if a checkpoint was loaded, otherwise None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
is_training: bool,
word2ids: dict,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute RNN-T loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Transducer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
feature_lens = batch["supervisions"]["num_frames"].to(device)
texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
labels = get_labels(texts, word2ids).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=labels)
assert loss.requires_grad == is_training
info = MetricsTracker()
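# Note: "frames" here actually counts utterances (batch size), so the
# reported loss is normalized per utterance rather than per frame.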
info["frames"] = feature.size(0)
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
word2ids,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
is_training=False,
word2ids=word2ids
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
word2ids,
tb_writer=None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
      word2ids:
        A mapping from words to integer ids, passed through to
        :func:`compute_loss`.
      tb_writer:
        Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
is_training=True,
word2ids=word2ids
)
        # Summary stats: keep an exponential moving average of the loss,
        # decaying the running total by (1 - 1/reset_interval) each batch.
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
word2ids=word2ids
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def get_transducer_model(params: AttributeDict):
# encoder = Tdnn(
# num_features=params.feature_dim,
# output_dim=params.hidden_dim,
# )
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.hidden_dim,
)
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.hidden_dim,
embedding_dropout=0.4,
rnn_dropout=0.4,
)
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer
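# A quick sanity check of the model above (a sketch; assumes get_params()
# supplies feature_dim, hidden_dim, vocab_size, etc.):
#
#   model = get_transducer_model(get_params())
#   num_param = sum(p.numel() for p in model.parameters())
#   logging.info(f"Number of model parameters: {num_param}")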
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
word2ids = get_word2id(params)
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
# if args.tensorboard and rank == 0:
# tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
# else:
# tb_writer = None
tb_writer = None
if torch.cuda.is_available():
device = torch.device("cuda", rank)
else:
device = torch.device("cpu")
logging.info(f"device: {device}")
model = get_transducer_model(params)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
slu = SluDataModule(args)
train_dl = slu.train_dataloaders()
    # This recipe does not wire up a separate dev set,
    # so we use the test data as validation.
valid_dl = slu.test_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
word2ids=word2ids
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
scheduler=None,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
SluDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,416 @@
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transducer.encoder_interface import EncoderInterface
from transducer.subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import make_pad_mask
class Transformer(EncoderInterface):
def __init__(
self,
num_features: int,
output_dim: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
) -> None:
"""
Args:
num_features:
The input dimension of the model.
output_dim:
The output dimension of the model.
subsampling_factor:
Number of output frames is num_in_frames // subsampling_factor.
Currently, subsampling_factor MUST be 4.
d_model:
Attention dimension.
nhead:
Number of heads in multi-head attention.
            Must satisfy d_model % nhead == 0.
dim_feedforward:
The output dimension of the feedforward layers in encoder.
num_encoder_layers:
Number of encoder layers.
dropout:
Dropout in encoder.
normalize_before:
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
"""
super().__init__()
self.num_features = num_features
self.output_dim = output_dim
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = PositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
encoder_norm = nn.LayerNorm(d_model)
else:
encoder_norm = None
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_encoder_layers,
norm=encoder_norm,
)
# TODO(fangjun): remove dropout
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
logits = self.encoder_output_layer(x)
        logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
return logits, lengths
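# Example of the length arithmetic above: the two stride-2 convolutions in
# Conv2dSubsampling roughly quarter the frame count, so for x_lens = 100,
# ((100 - 1) // 2 - 1) // 2 = 24 output frames (close to 100 // 4 = 25).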
class TransformerEncoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerEncoderLayer.
Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model:
the number of expected features in the input (required).
nhead:
the number of heads in the multiheadattention models (required).
dim_feedforward:
the dimension of the feedforward network model (default=2048).
dropout:
the dropout value (default=0.1).
activation:
the activation function of intermediate layer, relu or
gelu (default=relu).
normalize_before:
whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerEncoderLayer, self).__setstate__(state)
def forward(
self,
src: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional)
Shape:
src: (S, N, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
        S is the source sequence length, N is the batch size,
        E is the feature number.
"""
residual = src
if self.normalize_before:
src = self.norm1(src)
src2 = self.self_attn(
src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout1(src2)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src2)
if not self.normalize_before:
src = self.norm2(src)
return src
def _get_activation_fn(activation: str):
if activation == "relu":
return nn.functional.relu
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class PositionalEncoding(nn.Module):
"""This class implements the positional encoding
proposed in the following paper:
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
    PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
    PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
    Note::
      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
                               = exp(-1 * 2i / d_model * log(10000))
                               = exp(2i * -(log(10000) / d_model))
"""
def __init__(self, d_model: int, dropout: float = 0.1) -> None:
"""
Args:
d_model:
Embedding dimension.
dropout:
Dropout probability to be applied to the output of this module.
"""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
# not doing: self.pe = None because of errors thrown by torchscript
self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
        is (N, T, d_model). If T > T1, then we change the shape of self.pe
        to (1, T, d_model). Otherwise, nothing is done.
Args:
x:
It is a tensor of shape (N, T, C).
Returns:
Return None.
"""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
# Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Add positional encoding.
Args:
x:
Its shape is (N, T, C)
Returns:
Return a tensor of shape (N, T, C)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]
return self.dropout(x)
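# A small numeric check of the identity in the PositionalEncoding docstring
# (a sketch, not part of the module):
#
#   import math, torch
#   d_model = 4
#   two_i = torch.arange(0, d_model, 2, dtype=torch.float32)  # 2i
#   a = 1.0 / (10000.0 ** (two_i / d_model))
#   b = torch.exp(two_i * -(math.log(10000.0) / d_model))
#   assert torch.allclose(a, b)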
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* min(step ** (-0.5), step * self.warmup ** (-1.5))
)
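    # With the defaults above (factor=10.0, model_size=256, warm_step=25000)
    # the schedule peaks at step == warm_step, where both branches of min()
    # agree:
    #
    #   rate = 10.0 * 256 ** (-0.5) * 25000 ** (-0.5)  # ~= 3.95e-3
    #
    # The rate grows linearly before the peak (step * warmup ** -1.5 branch)
    # and decays as step ** -0.5 afterwards.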
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
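# Typical usage of Noam (a sketch; `model` is any nn.Module):
#
#   optimizer = Noam(model.parameters(), model_size=256, factor=10.0)
#   loss.backward()
#   optimizer.step()       # updates the learning rate, then the parameters
#   optimizer.zero_grad()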

View File

@ -33,7 +33,7 @@ parser.add_argument(
"-ngram-order",
type=int,
default=4,
-    choices=[2, 3, 4, 5, 6, 7],
+    choices=[1, 2, 3, 4, 5, 6, 7],
help="Order of n-gram",
)
parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
@ -105,7 +105,7 @@ class NgramCounts:
# do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
# array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
-        assert ngram_order >= 2
+        assert ngram_order >= 1
self.ngram_order = ngram_order
self.bos_symbol = bos_symbol
@ -169,7 +169,7 @@ class NgramCounts:
with open(filename, encoding=default_encoding) as fp:
for line in fp:
line = line.strip(strip_chars)
-                self.add_raw_counts_from_line(line)
+                self.add_raw_counts_from_line(line.split()[0])
lines_processed += 1
if lines_processed == 0 or args.verbose > 0:
print(

View File

@ -38,7 +38,7 @@ import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn as nn
-from torch.utils.tensorboard import SummaryWriter
+# from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import average_checkpoints
@ -1125,22 +1125,22 @@ class MetricsTracker(collections.defaultdict):
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
-    def write_summary(
-        self,
-        tb_writer: SummaryWriter,
-        prefix: str,
-        batch_idx: int,
-    ) -> None:
-        """Add logging information to a TensorBoard writer.
-        Args:
-            tb_writer: a TensorBoard writer
-            prefix: a prefix for the name of the loss, e.g. "train/valid_",
-            or "train/current_"
-            batch_idx: The current batch index, used as the x-axis of the plot.
-        """
-        for k, v in self.norm_items():
-            tb_writer.add_scalar(prefix + k, v, batch_idx)
+    # def write_summary(
+    #     self,
+    #     tb_writer: SummaryWriter,
+    #     prefix: str,
+    #     batch_idx: int,
+    # ) -> None:
+    #     """Add logging information to a TensorBoard writer.
+    #     Args:
+    #         tb_writer: a TensorBoard writer
+    #         prefix: a prefix for the name of the loss, e.g. "train/valid_",
+    #         or "train/current_"
+    #         batch_idx: The current batch index, used as the x-axis of the plot.
+    #     """
+    #     for k, v in self.norm_items():
+    #         tb_writer.add_scalar(prefix + k, v, batch_idx)
def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor: