docker + ubiqus + pyonmttok

Valentin Berkes 2022-05-11 09:32:40 +02:00
parent b3e6bf66df
commit 092f69b477
22 changed files with 4952 additions and 14 deletions

docker/Makefile Normal file
View File

@@ -0,0 +1,11 @@
build_docker:
docker build -t icefall/pytorch1.7.1:latest -f ./Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile ./
run_docker:
docker run -it --rm --runtime=nvidia \
--gpus all \
-v /data1:/data1 \
-v /data1/merge_all_short/raw/fr_token_list/bpe_unigram5000/bpe.pyonmttok.vocab:/data/vocab \
-v /data1/merge_all_short/raw/fr_token_list/bpe_unigram5000/bpe.pyonmttok:/data/bpe.pyonmttok \
-v /nas-labs/ASR/valentin_work/icefall:/workspace/icefall \
--name val_icefall_3 icefall/pytorch1.7.1:latest bash

View File

@@ -1,7 +1,10 @@
FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel
# install normal source
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-get update && \
apt-get install -y --no-install-recommends \
g++ \
@@ -26,13 +29,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/*
RUN mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
mv /opt/conda/lib/libnvrtc.so.11.0 /opt/libnvrtc.so.11.1.bak && \
mv /opt/conda/lib/libnvToolsExt.so.1 /opt/libnvToolsExt.so.1.bak && \
mv /opt/conda/lib/libcudart.so.11.0 /opt/libcudart.so.11.0.bak
# cmake
RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
@@ -72,20 +68,26 @@ RUN git clone https://github.com/csukuangfj/kaldifeat.git /opt/kaldifeat && \
cd -
RUN conda install pytorch torchvision torchaudio=0.11 cudatoolkit=11.3 -c pytorch
#install k2 from source
# RUN conda install -c k2-fsa -c pytorch -c conda-forge k2 cudatoolkit=11.3 pytorch=1.10.0
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
cd /opt/k2 && \
python3 setup.py install && \
cd -
# RUN pip install k2
# install lhotse
RUN pip install git+https://github.com/lhotse-speech/lhotse
#RUN pip install lhotse
# RUN pip install git+https://github.com/lhotse-speech/lhotse
RUN pip install lhotse
# install icefall
RUN git clone https://github.com/k2-fsa/icefall && \
cd icefall && \
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
# RUN git clone https://github.com/k2-fsa/icefall && \
# cd icefall && \
# pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
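
A quick sanity check for the resulting image (a minimal sketch, assuming the build above completed and the container is started via `make run_docker`) is to import the installed toolchain and print its versions:

# check_env.py -- run inside the container
import torch
import torchaudio
import k2       # built from source above
import lhotse   # installed via pip above

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchaudio", torchaudio.__version__)
print("lhotse", lhotse.__version__)
print("k2 imported from", k2.__file__)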

View File

@@ -0,0 +1,394 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures, AudioSamples
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class UbiqusAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("/data1/merge_all_manifest/raw"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=300,
help="The number of buckets for the BucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=True,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
# if self.args.enable_musan:
# logging.info("Enable MUSAN")
# logging.info("About to get Musan cuts")
# cuts_musan = load_manifest(
# self.args.manifest_dir / "cuts_musan.json.gz"
# )
# transforms.append(
# CutMix(
# cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
# )
# )
# else:
# logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
input_strategy=AudioSamples(),
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
input_strategy=AudioSamples(),
)
valid_sampler = BucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
rec = load_manifest(
self.args.manifest_dir / "train_sp/recordings.jsonl.gz"
)
sup = load_manifest(
self.args.manifest_dir / "train_sp/supervisions.jsonl.gz"
)
return CutSet.from_manifests(
recordings=rec,
supervisions=sup,
)
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
rec = load_manifest(self.args.manifest_dir / "dev/recordings.jsonl.gz")
sup = load_manifest(
self.args.manifest_dir / "dev/supervisions.jsonl.gz"
)
return CutSet.from_manifests(
recordings=rec,
supervisions=sup,
)
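
A minimal usage sketch for the data module above (assuming this file is the recipe's asr_datamodule.py, the script is run from the same directory, and the manifests under the default --manifest-dir exist):

import argparse
from asr_datamodule import UbiqusAsrDataModule

parser = argparse.ArgumentParser()
UbiqusAsrDataModule.add_arguments(parser)
args = parser.parse_args([])              # keep the defaults defined above

ubiqus = UbiqusAsrDataModule(args)
train_cuts = ubiqus.train_cuts()          # CutSet built from recordings + supervisions
train_dl = ubiqus.train_dataloaders(train_cuts)
batch = next(iter(train_dl))
# With the default --on-the-fly-feats=True, "inputs" holds 80-dim fbank features;
# otherwise the dataset is built with input_strategy=AudioSamples().
print(batch["inputs"].shape, len(batch["supervisions"]["text"]))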

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless/beam_search.py

View File

@@ -0,0 +1,549 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_emformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_emformer/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_emformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_emformer/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./transducer_emformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_emformer/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_emformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_emformer/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_emformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()
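
As a worked example of the checkpoint-averaging window in main() above: with the defaults --epoch 28 --avg 15, start = 28 - 15 + 1 = 14, so epoch-14.pt through epoch-28.pt are averaged before decoding.

epoch, avg = 28, 15
start = epoch - avg + 1
filenames = [f"transducer_emformer/exp/epoch-{i}.pt" for i in range(start, epoch + 1)]
print(len(filenames), filenames[0], filenames[-1])
# 15 transducer_emformer/exp/epoch-14.pt transducer_emformer/exp/epoch-28.pt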

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless/decoder.py

View File

@@ -0,0 +1,271 @@
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling
from torchaudio.models import Emformer as _Emformer
LOG_EPSILON = math.log(1e-10)
def unstack_states(
states: List[List[torch.Tensor]],
) -> List[List[List[torch.Tensor]]]:
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Args:
states:
A list-of-list of tensors. ``len(states)`` equals the number of
layers in the emformer. ``states[i]`` contains the states for
the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``
"""
batch_size = states[0][0].size(1)
num_layers = len(states)
ans = [None] * batch_size
for i in range(batch_size):
ans[i] = [[] for _ in range(num_layers)]
for li, layer in enumerate(states):
for s in layer:
s_list = s.unbind(dim=1)
# We will use stack(dim=1) later in stack_states()
for bi, b in enumerate(ans):
b[li].append(s_list[bi])
return ans
def stack_states(
state_list: List[List[List[torch.Tensor]]],
) -> List[List[torch.Tensor]]:
"""Stack list of emformer states that correspond to separate utterances
into a single emformer state so that it can be used as an input for
emformer when those utterances are formed into a batch.
Note:
It is the inverse of :func:`unstack_states`.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the emformer model for a single utterance.
Returns:
Return a new state corresponding to a batch of utterances.
See the input argument of :func:`unstack_states` for the meaning
of the returned tensor.
"""
batch_size = len(state_list)
ans = []
for layer in state_list[0]:
# layer is a list of tensors
if batch_size > 1:
ans.append([[s] for s in layer])
# Note: We will stack ans[layer][s][] later to get ans[layer][s]
else:
ans.append([s.unsqueeze(1) for s in layer])
for b, states in enumerate(state_list[1:], 1):
for li, layer in enumerate(states):
for si, s in enumerate(layer):
ans[li][si].append(s)
if b == batch_size - 1:
ans[li][si] = torch.stack(ans[li][si], dim=1)
# We will use unbind(dim=1) later in unstack_states()
return ans
class Emformer(EncoderInterface):
"""This is just a simple wrapper around torchaudio.models.Emformer.
We may replace it with our own implementation some time later.
"""
def __init__(
self,
num_features: int,
output_dim: int,
d_model: int,
nhead: int,
dim_feedforward: int,
num_encoder_layers: int,
segment_length: int,
left_context_length: int,
right_context_length: int,
max_memory_size: int = 0,
dropout: float = 0.1,
subsampling_factor: int = 4,
vgg_frontend: bool = False,
) -> None:
"""
Args:
num_features:
The input dimension of the model.
output_dim:
The output dimension of the model.
d_model:
Attention dimension.
nhead:
Number of heads in multi-head attention.
dim_feedforward:
The output dimension of the feedforward layers in encoder.
num_encoder_layers:
Number of encoder layers.
segment_length:
Number of frames per segment before subsampling.
left_context_length:
Number of frames in the left context before subsampling.
right_context_length:
Number of frames in the right context before subsampling.
max_memory_size:
TODO.
dropout:
Dropout in encoder.
subsampling_factor:
Number of output frames is num_in_frames // subsampling_factor.
Currently, subsampling_factor MUST be 4.
vgg_frontend:
True to use vgg style frontend for subsampling.
"""
super().__init__()
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
print(num_features, d_model, output_dim)
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.segment_length = segment_length
self.right_context_length = right_context_length
assert right_context_length % subsampling_factor == 0
assert segment_length % subsampling_factor == 0
assert left_context_length % subsampling_factor == 0
left_context_length = left_context_length // subsampling_factor
right_context_length = right_context_length // subsampling_factor
segment_length = segment_length // subsampling_factor
self.model = _Emformer(
input_dim=d_model,
num_heads=nhead,
ffn_dim=dim_feedforward,
num_layers=num_encoder_layers,
segment_length=segment_length,
dropout=dropout,
activation="relu",
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=max_memory_size,
weight_init_scale_strategy="depthwise",
tanh_on_mem=False,
negative_inf=-1e8,
)
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
Input features of shape (N, T, C).
x_lens:
An int32 tensor of shape (N,) containing valid frames in `x` before
padding. We have `x.size(1) == x_lens.max()`
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of shape (N, T', C)
- encoder_out_lens, an int32 tensor of shape (N,) containing the
valid frames in `encoder_out` before padding
"""
x = nn.functional.pad(
x,
# (left, right, top, bottom)
# left/right are for the channel dimension, i.e., axis 2
# top/bottom are for the time dimension, i.e., axis 1
(0, 0, 0, self.right_context_length),
value=LOG_EPSILON,
) # (N, T, C) -> (N, T+right_context_length, C)
x = self.encoder_embed(x)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4!
x_lens = ((x_lens - 1) // 2 - 1) // 2
emformer_out, emformer_out_lens = self.model(x, x_lens)
logits = self.encoder_output_layer(emformer_out)
return logits, emformer_out_lens
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
):
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,) containing the number of valid frames for each
element in `x` before padding.
states:
Internal states of the model.
Returns:
Return a tuple containing 3 tensors:
- encoder_out, a 3-D tensor of shape (N, T, C)
- encoder_out_lens: a 1-D tensor of shape (N,)
- next_state, internal model states for the next chunk
"""
x = self.encoder_embed(x)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4!
x_lens = ((x_lens - 1) // 2 - 1) // 2
emformer_out, emformer_out_lens, states = self.model.infer(
x, x_lens, states
)
logits = self.encoder_output_layer(emformer_out)
return logits, emformer_out_lens, states
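
A small round-trip check for the two state helpers above (a sketch, assuming it is run from this directory so the `emformer` module and its local imports resolve): stack_states should exactly undo unstack_states.

import torch
from emformer import stack_states, unstack_states

num_layers, T, N, C = 2, 4, 3, 8
# Per layer: one (T, N, C) state and one (C, N) state, following the docstring above.
states = [[torch.randn(T, N, C), torch.randn(C, N)] for _ in range(num_layers)]

per_utterance = unstack_states(states)   # one entry per utterance in the batch
assert len(per_utterance) == N
restacked = stack_states(per_utterance)  # inverse of unstack_states
for layer_before, layer_after in zip(states, restacked):
    for before, after in zip(layer_before, layer_after):
        assert torch.equal(before, after)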

View File

@@ -0,0 +1,289 @@
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling
from torchaudio.models import Emformer as _Emformer
from torchaudio.models.wav2vec2 import components
LOG_EPSILON = math.log(1e-10)
def unstack_states(
states: List[List[torch.Tensor]],
) -> List[List[List[torch.Tensor]]]:
"""Unstack the emformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Args:
states:
A list-of-list of tensors. ``len(states)`` equals the number of
layers in the emformer. ``states[i]`` contains the states for
the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape
``(T, N, C)`` or a 2-D tensor of shape ``(C, N)``
"""
batch_size = states[0][0].size(1)
num_layers = len(states)
ans = [None] * batch_size
for i in range(batch_size):
ans[i] = [[] for _ in range(num_layers)]
for li, layer in enumerate(states):
for s in layer:
s_list = s.unbind(dim=1)
# We will use stack(dim=1) later in stack_states()
for bi, b in enumerate(ans):
b[li].append(s_list[bi])
return ans
def stack_states(
state_list: List[List[List[torch.Tensor]]],
) -> List[List[torch.Tensor]]:
"""Stack list of emformer states that correspond to separate utterances
into a single emformer state so that it can be used as an input for
emformer when those utterances are formed into a batch.
Note:
It is the inverse of :func:`unstack_states`.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the emformer model for a single utterance.
Returns:
Return a new state corresponding to a batch of utterances.
See the input argument of :func:`unstack_states` for the meaning
of the returned tensor.
"""
batch_size = len(state_list)
ans = []
for layer in state_list[0]:
# layer is a list of tensors
if batch_size > 1:
ans.append([[s] for s in layer])
# Note: We will stack ans[layer][s][] later to get ans[layer][s]
else:
ans.append([s.unsqueeze(1) for s in layer])
for b, states in enumerate(state_list[1:], 1):
for li, layer in enumerate(states):
for si, s in enumerate(layer):
ans[li][si].append(s)
if b == batch_size - 1:
ans[li][si] = torch.stack(ans[li][si], dim=1)
# We will use unbind(dim=1) later in unstack_states()
return ans
class EmformerRaw(EncoderInterface):
"""This is just a simple wrapper around torchaudio.models.Emformer.
We may replace it with our own implementation some time later.
"""
def __init__(
self,
num_features: int,
output_dim: int,
d_model: int,
nhead: int,
dim_feedforward: int,
num_encoder_layers: int,
segment_length: int,
left_context_length: int,
right_context_length: int,
max_memory_size: int = 0,
dropout: float = 0.1,
subsampling_factor: int = 4,
vgg_frontend: bool = False,
) -> None:
"""
Args:
num_features:
The input dimension of the model.
output_dim:
The output dimension of the model.
d_model:
Attention dimension.
nhead:
Number of heads in multi-head attention.
dim_feedforward:
The output dimension of the feedforward layers in encoder.
num_encoder_layers:
Number of encoder layers.
segment_length:
Number of frames per segment before subsampling.
left_context_length:
Number of frames in the left context before subsampling.
right_context_length:
Number of frames in the right context before subsampling.
max_memory_size:
TODO.
dropout:
Dropout in encoder.
subsampling_factor:
Number of output frames is num_in_frames // subsampling_factor.
Currently, subsampling_factor MUST be 4.
vgg_frontend:
True to use vgg style frontend for subsampling.
"""
super().__init__()
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
extractor_conv_layer_config = (
[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
)
extractor_mode = "layer_norm"
extractor_conv_bias = True
self.feature_extractor = components._get_feature_extractor(
extractor_mode, extractor_conv_layer_config, extractor_conv_bias
)
print(num_features, d_model, output_dim)
# if vgg_frontend:
# self.encoder_embed = VggSubsampling(num_features, d_model)
# else:
# self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.segment_length = segment_length
self.right_context_length = right_context_length
assert right_context_length % subsampling_factor == 0
assert segment_length % subsampling_factor == 0
assert left_context_length % subsampling_factor == 0
left_context_length = left_context_length // subsampling_factor
right_context_length = right_context_length // subsampling_factor
segment_length = segment_length // subsampling_factor
print(extractor_conv_layer_config[-1][0])
print(dim_feedforward)
self.model = _Emformer(
input_dim=extractor_conv_layer_config[-1][0],
num_heads=nhead,
ffn_dim=dim_feedforward,
num_layers=num_encoder_layers,
segment_length=segment_length,
dropout=dropout,
activation="relu",
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=max_memory_size,
weight_init_scale_strategy="depthwise",
tanh_on_mem=False,
negative_inf=-1e8,
)
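# NOTE: torchaudio's Emformer preserves its input width, so self.model outputs
# extractor_conv_layer_config[-1][0] (= 512) features per frame; the
# nn.Linear(d_model, output_dim) below therefore assumes d_model == 512.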
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
Input features of shape (N, T, C).
x_lens:
An int32 tensor of shape (N,) containing valid frames in `x` before
padding. We have `x.size(1) == x_lens.max()`
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of shape (N, T', C)
- encoder_out_lens, an int32 tensor of shape (N,) containing the
valid frames in `encoder_out` before padding
"""
print(x.shape)
x = nn.functional.pad(
x,
# (left, right, top, bottom)
# left/right are for the channel dimension, i.e., axis 2
# top/bottom are for the time dimension, i.e., axis 1
(0, 0, 0, self.right_context_length),
value=LOG_EPSILON,
) # (N, T, C) -> (N, T+right_context_length, C)
print(x.shape, x_lens)
x, x_lens = self.feature_extractor(x.squeeze(-1), x_lens)
x_lens -= 1
print(x.shape, x_lens)
# with warnings.catch_warnings():
# warnings.simplefilter("ignore")
# # Caution: We assume the subsampling factor is 4!
# x_lens = ((x_lens - 1) // 2 - 1) // 2
emformer_out, emformer_out_lens = self.model(x, x_lens)
logits = self.encoder_output_layer(emformer_out)
return logits, emformer_out_lens
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[List[List[torch.Tensor]]] = None,
):
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,) containing the number of valid frames for each
element in `x` before padding.
states:
Internal states of the model.
Returns:
Return a tuple containing 3 tensors:
- encoder_out, a 3-D tensor of shape (N, T, C)
- encoder_out_lens: a 1-D tensor of shape (N,)
- next_state, internal model states for the next chunk
"""
x, x_lens = self.feature_extractor(x, x_lens)
x_lens -= 1
# Sure about that ?
# with warnings.catch_warnings():
# warnings.simplefilter("ignore")
# # Caution: We assume the subsampling factor is 4!
# x_lens = ((x_lens - 1) // 2 - 1) // 2
emformer_out, emformer_out_lens, states = self.model.infer(
x, x_lens, states
)
logits = self.encoder_output_layer(emformer_out)
return logits, emformer_out_lens, states
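
For reference, the frame rate implied by the convolutional front-end of EmformerRaw can be checked with a short helper (a sketch using the same layer config as above); each conv layer maps a length L to (L - kernel) // stride + 1:

def conv_out_length(num_samples: int) -> int:
    # Same layout as extractor_conv_layer_config in EmformerRaw.__init__
    config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    length = num_samples
    for _channels, kernel, stride in config:
        length = (length - kernel) // stride + 1
    return length

# One second of 16 kHz audio -> 49 frames; the strides multiply to
# 5 * 2**6 = 320 samples, i.e. roughly a 20 ms hop.
print(conv_out_length(16000))  # 49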

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless/encoder_interface.py

View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer_emformer/export.py \
--exp-dir ./transducer_emformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer_emformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./transducer_emformer/decode.py \
--exp-dir ./transducer_emformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1000 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
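
The exported file stores a plain {"model": state_dict} dict (see torch.save above), so besides the `ln -s pretrained.pt epoch-9999.pt` trick from the docstring it can be inspected or reloaded directly; a minimal sketch, assuming the export above has been run:

import torch

checkpoint = torch.load("transducer_emformer/exp/pretrained.pt", map_location="cpu")
state_dict = checkpoint["model"]   # the same layout load_checkpoint() expects
print(len(state_dict), "tensors, first key:", next(iter(state_dict)))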

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless/joiner.py

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless/model.py

View File

@@ -0,0 +1,104 @@
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* min(step ** (-0.5), step * self.warmup ** (-1.5))
)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
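
The schedule in rate() ramps the learning rate up during the first warm_step steps and then decays it as step ** -0.5; the two terms inside min() meet at step == warm_step, which is where the peak learning rate occurs. A quick check with the defaults above:

factor, model_size, warm_step = 10.0, 256, 25000

def noam_rate(step: int) -> float:
    return factor * model_size ** (-0.5) * min(step ** (-0.5), step * warm_step ** (-1.5))

print(f"{noam_rate(1):.2e}")              # ~1.58e-07, start of warmup
print(f"{noam_rate(warm_step):.2e}")      # ~3.95e-03, the peak
print(f"{noam_rate(4 * warm_step):.2e}")  # ~1.98e-03, decaying as step ** -0.5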

View File

@@ -0,0 +1,748 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import warnings
from pathlib import Path
from typing import List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from emformer import LOG_EPSILON, stack_states, unstack_states
from streaming_feature_extractor import FeatureExtractionStream
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_emformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--sampling-rate",
type=float,
default=16000,
help="Sample rate of the audio",
)
add_model_arguments(parser)
return parser
class StreamingAudioSamples(object):
"""This class takes as input a list of audio samples and returns
them in a streaming fashion.
"""
def __init__(self, samples: List[torch.Tensor]) -> None:
"""
Args:
samples:
A list of audio samples. Each entry is a 1-D tensor of dtype
torch.float32, containing the audio samples of an utterance.
"""
self.samples = samples
self.cur_indexes = [0] * len(self.samples)
@property
def done(self) -> bool:
"""Return True if all samples have been processed.
Return False otherwise.
"""
for i, samples in zip(self.cur_indexes, self.samples):
if i < samples.numel():
return False
return True
def get_next(self) -> List[torch.Tensor]:
"""Return a list of audio samples. Each entry may have different
lengths. It is OK if an entry contains no samples at all, which
means it reaches the end of the utterance.
"""
ans = []
num = [1024] * len(self.samples)
for i in range(len(self.samples)):
start = self.cur_indexes[i]
end = start + num[i]
self.cur_indexes[i] = end
s = self.samples[i][start:end]
ans.append(s)
return ans
class StreamList(object):
def __init__(
self,
batch_size: int,
context_size: int,
decoding_method: str,
):
"""
Args:
batch_size:
Size of this batch.
context_size:
Context size of the RNN-T decoder model.
decoding_method:
Decoding method. The possible values are:
- greedy_search
- modified_beam_search
"""
self.streams = [
FeatureExtractionStream(
context_size=context_size, decoding_method=decoding_method
)
for _ in range(batch_size)
]
@property
def done(self) -> bool:
"""Return True if all streams have reached end of utterance.
That is, no more audio samples are available for all utterances.
"""
return all(stream.done for stream in self.streams)
def accept_waveform(
self,
audio_samples: List[torch.Tensor],
sampling_rate: float,
):
"""Feed audio samples to each stream.
Args:
audio_samples:
A list of 1-D tensors containing the audio samples for each
utterance in the batch. If an entry is empty, it means
end-of-utterance has been reached.
sampling_rate:
Sampling rate of the given audio samples.
"""
assert len(audio_samples) == len(self.streams)
for stream, samples in zip(self.streams, audio_samples):
if stream.done:
assert samples.numel() == 0
continue
stream.accept_waveform(
sampling_rate=sampling_rate,
waveform=samples,
)
if samples.numel() == 0:
stream.input_finished()
def build_batch(
self,
chunk_length: int,
segment_length: int,
) -> Tuple[Optional[torch.Tensor], Optional[List[FeatureExtractionStream]]]:
"""
Args:
chunk_length:
Number of frames for each chunk. It equals to
``segment_length + right_context_length``.
segment_length:
Number of frames for each segment.
Returns:
Return a tuple containing:
- features, a 3-D tensor of shape ``(num_active_streams, T, C)``
- active_streams, a list of active streams. We say a stream is
active when it has enough feature frames to be fed into the
encoder model.
"""
feature_list = []
stream_list = []
for stream in self.streams:
if len(stream.feature_frames) >= chunk_length:
# this_chunk is a list of tensors, each of which
# has a shape (1, feature_dim)
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = stream.feature_frames[segment_length:]
features = torch.cat(chunk, dim=0)
feature_list.append(features)
stream_list.append(stream)
elif stream.done and len(stream.feature_frames) > 0:
chunk = stream.feature_frames[:chunk_length]
stream.feature_frames = []
features = torch.cat(chunk, dim=0)
features = torch.nn.functional.pad(
features,
(0, 0, 0, chunk_length - features.size(0)),
mode="constant",
value=LOG_EPSILON,
)
feature_list.append(features)
stream_list.append(stream)
if len(feature_list) == 0:
return None, None
features = torch.stack(feature_list, dim=0)
return features, stream_list
def greedy_search(
model: nn.Module,
streams: List[FeatureExtractionStream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
):
"""
Args:
model:
The RNN-T model.
streams:
A list of stream objects.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
sp:
The BPE model.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
if streams[0].decoder_out is None:
for stream in streams:
stream.hyp = [blank_id] * context_size
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
# decoder_out is of shape (N, decoder_out_dim)
else:
decoder_out = torch.stack(
[stream.decoder_out for stream in streams],
dim=0,
)
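    # At this point decoder_out has shape (batch_size, decoder_out_dim): it was
    # either freshly computed from blank-padded contexts (first chunk of the
    # session) or restored from the per-stream cache saved at the end of the
    # previous call to greedy_search.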
for t in range(T):
current_encoder_out = encoder_out[:, t]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out)
        # logits' shape: (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
).squeeze(1)
for k, stream in enumerate(streams):
result = sp.decode(stream.decoding_result())
logging.info(f"Partial result {k}:\n{result}")
decoder_out_list = decoder_out.unbind(dim=0)
for i, d in enumerate(decoder_out_list):
streams[i].decoder_out = d
def modified_beam_search(
model: nn.Module,
streams: List[FeatureExtractionStream],
encoder_out: torch.Tensor,
sp: spm.SentencePieceProcessor,
beam: int = 4,
):
"""
Args:
model:
The RNN-T model.
streams:
A list of stream objects.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
sp:
The BPE model.
beam:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
batch_size = len(streams)
T = encoder_out.size(1)
for stream in streams:
if len(stream.hyps) == 0:
stream.hyps.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
B = [stream.hyps for stream in streams]
for t in range(T):
current_encoder_out = encoder_out[:, t]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
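        # hyps_shape is a 2-axis ragged shape in which row i corresponds to
        # stream i and has one element per active hypothesis, so row_ids(1)
        # maps every hypothesis back to the stream it belongs to.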
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
# decoder_out is of shape (num_hyps, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out)
# logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
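        # The flattened log_probs tensor is interpreted as a ragged tensor with
        # one row per stream: each hypothesis of a stream contributes
        # vocab_size consecutive entries, hence the row splits are scaled by
        # vocab_size.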
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
streams[i].hyps = B[i]
result = sp.decode(streams[i].decoding_result())
logging.info(f"Partial result {i}:\n{result}")
def process_features(
model: nn.Module,
features: torch.Tensor,
streams: List[FeatureExtractionStream],
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Process features for each stream in parallel.
Args:
model:
The RNN-T model.
features:
A 3-D tensor of shape (N, T, C).
streams:
A list of streams of size (N,).
params:
It is the return value of :func:`get_params`.
sp:
The BPE model.
"""
assert features.ndim == 3
assert features.size(0) == len(streams)
batch_size = features.size(0)
device = model.device
features = features.to(device)
feature_lens = torch.full(
(batch_size,),
fill_value=features.size(1),
device=device,
)
    # Caution: this assumes that if one of the streams has an empty
    # state, then all other streams also have empty states.
if streams[0].states is None:
states = None
else:
state_list = [stream.states for stream in streams]
states = stack_states(state_list)
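    # `states` now holds the layer-wise encoder states of all streams in the
    # batch, stacked so that a single streaming_forward call advances them
    # together; they are unstacked again right after the call.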
(encoder_out, encoder_out_lens, states,) = model.encoder.streaming_forward(
features,
feature_lens,
states,
)
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
if params.decoding_method == "greedy_search":
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=streams,
encoder_out=encoder_out,
sp=sp,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
def decode_batch(
batched_samples: List[torch.Tensor],
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> List[str]:
"""
Args:
batched_samples:
A list of 1-D tensors containing the audio samples of each utterance.
model:
The RNN-T model.
params:
It is the return value of :func:`get_params`.
sp:
The BPE model.
"""
# number of frames before subsampling
segment_length = model.encoder.segment_length
right_context_length = model.encoder.right_context_length
# We add 3 here since the subsampling method is using
    # ((len - 1) // 2 - 1) // 2
chunk_length = (segment_length + 3) + right_context_length
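    # For reference: with this subsampling, an input of 4 * k + 3 frames yields
    # exactly k output frames (e.g. 7 -> 1, 11 -> 2), hence the 3 extra frames
    # added on top of the segment.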
batch_size = len(batched_samples)
streaming_audio_samples = StreamingAudioSamples(batched_samples)
stream_list = StreamList(
batch_size=batch_size,
context_size=params.context_size,
decoding_method=params.decoding_method,
)
while not streaming_audio_samples.done:
samples = streaming_audio_samples.get_next()
stream_list.accept_waveform(
audio_samples=samples,
sampling_rate=params.sampling_rate,
)
features, active_streams = stream_list.build_batch(
chunk_length=chunk_length,
segment_length=segment_length,
)
if features is not None:
process_features(
model=model,
features=features,
streams=active_streams,
params=params,
sp=sp,
)
results = []
for stream in stream_list.streams:
text = sp.decode(stream.decoding_result())
results.append(text)
return results
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
setup_logger(f"{params.res_dir}/log-streaming-decode")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.device = device
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
            if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
batch_size = 3
ground_truth = []
batched_samples = []
for num, cut in enumerate(test_clean_cuts):
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
        assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
batched_samples.append(samples)
ground_truth.append(cut.supervisions[0].text)
if len(batched_samples) >= batch_size:
decoded_results = decode_batch(
batched_samples=batched_samples,
model=model,
params=params,
sp=sp,
)
s = "\n"
for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)):
s += f"hyp {i}:\n{hyp}\n"
s += f"ref {i}:\n{ref}\n\n"
logging.info(s)
batched_samples = []
ground_truth = []
# break after processing the first batch for test purposes
break
if __name__ == "__main__":
torch.manual_seed(20220410)
main()

View File

@ -0,0 +1,132 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
from beam_search import HypothesisList
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def _create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
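# Note: with these options and kaldifeat's default frame settings (presumably
# a 25 ms window with a 10 ms shift), one second of 16 kHz audio produces
# roughly 100 frames of 80-dim fbank features; snip_edges=False keeps the
# frame count close to num_samples / frame_shift.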
class FeatureExtractionStream(object):
def __init__(self, context_size: int, decoding_method: str) -> None:
"""
Args:
context_size:
Context size of the RNN-T decoder model.
decoding_method:
Decoding method. The possible values are:
- greedy_search
- modified_beam_search
"""
self.feature_extractor = _create_streaming_feature_extractor()
# It contains a list of 1-D tensors representing the feature frames.
self.feature_frames: List[torch.Tensor] = []
self.num_fetched_frames = 0
# After calling `self.input_finished()`, we set this flag to True
self._done = False
# For the emformer model, it contains the states of each
# encoder layer.
self.states: Optional[List[List[torch.Tensor]]] = None
        # It uses different attributes for different decoding methods.
self.context_size = context_size
self.decoding_method = decoding_method
if decoding_method == "greedy_search":
self.hyp: Optional[List[int]] = None
self.decoder_out: Optional[torch.Tensor] = None
elif decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
else:
raise ValueError(f"Unsupported decoding method: {decoding_method}")
def accept_waveform(
self,
sampling_rate: float,
waveform: torch.Tensor,
) -> None:
"""Feed audio samples to the feature extractor and compute features
if there are enough samples available.
Caution:
The range of the audio samples should match the one used in the
training. That is, if you use the range [-1, 1] in the training, then
the input audio samples should also be normalized to [-1, 1].
        Args:
          sampling_rate:
            The sampling rate of the input audio samples. It is used for a
            sanity check to ensure that the input sampling rate equals the one
            used in the extractor. If they are not equal, no resampling will
            be performed; instead an error will be thrown.
waveform:
A 1-D torch tensor of dtype torch.float32 containing audio samples.
It should be on CPU.
"""
self.feature_extractor.accept_waveform(
sampling_rate=sampling_rate,
waveform=waveform,
)
self._fetch_frames()
def input_finished(self) -> None:
"""Signal that no more audio samples available and the feature
extractor should flush the buffered samples to compute frames.
"""
self.feature_extractor.input_finished()
self._fetch_frames()
self._done = True
@property
def done(self) -> bool:
"""Return True if `self.input_finished()` has been invoked"""
return self._done
def _fetch_frames(self) -> None:
"""Fetch frames from the feature extractor"""
while self.num_fetched_frames < self.feature_extractor.num_frames_ready:
frame = self.feature_extractor.get_frame(self.num_fetched_frames)
self.feature_frames.append(frame)
self.num_fetched_frames += 1
def decoding_result(self) -> List[int]:
"""Obtain current decoding result."""
if self.decoding_method == "greedy_search":
return self.hyp[self.context_size :]
else:
assert self.decoding_method == "modified_beam_search"
best_hyp = self.hyps.get_most_probable(length_norm=True)
return best_hyp.ys[self.context_size :]

View File

@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/subsampling.py

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_emformer/test_emformer.py
"""
import warnings
import torch
from emformer import Emformer, stack_states, unstack_states
def test_emformer():
N = 3
T = 300
C = 80
output_dim = 500
encoder = Emformer(
num_features=C,
output_dim=output_dim,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=20,
segment_length=16,
left_context_length=120,
right_context_length=4,
vgg_frontend=False,
)
x = torch.rand(N, T, C)
x_lens = torch.randint(100, T, (N,))
x_lens[0] = T
y, y_lens = encoder(x, x_lens)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
assert (y_lens == ((x_lens - 1) // 2 - 1) // 2).all()
    assert y.size(0) == x.size(0)
assert y.size(1) == max(y_lens)
assert y.size(2) == output_dim
num_param = sum([p.numel() for p in encoder.parameters()])
print(f"Number of encoder parameters: {num_param}")
def test_emformer_streaming_forward():
N = 3
C = 80
output_dim = 500
encoder = Emformer(
num_features=C,
output_dim=output_dim,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=20,
segment_length=16,
left_context_length=120,
right_context_length=4,
vgg_frontend=False,
)
x = torch.rand(N, 23, C)
x_lens = torch.full((N,), 23)
y, y_lens, states = encoder.streaming_forward(x=x, x_lens=x_lens)
state_list = unstack_states(states)
states2 = stack_states(state_list)
for ss, ss2 in zip(states, states2):
for s, s2 in zip(ss, ss2):
assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}"
@torch.no_grad()
def main():
# test_emformer()
test_emformer_streaming_forward()
if __name__ == "__main__":
torch.manual_seed(20220329)
main()

View File

@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_emformer/test_streaming_feature_extractor.py
"""
import torch
from streaming_feature_extractor import FeatureExtractionStream
def test_streaming_feature_extractor():
    stream = FeatureExtractionStream(
        context_size=2, decoding_method="greedy_search"
    )
samples = torch.rand(16000)
start = 0
while True:
n = torch.randint(50, 500, (1,)).item()
end = start + n
this_chunk = samples[start:end]
start = end
if len(this_chunk) == 0:
break
stream.accept_waveform(sampling_rate=16000, waveform=this_chunk)
print(len(stream.feature_frames))
stream.input_finished()
print(len(stream.feature_frames))
def main():
test_streaming_feature_extractor()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,57 @@
# sp = spm.SentencePieceProcessor()
# sp.load(params.bpe_model)
# # <blk> is defined in local/train_bpe_model.py
# params.blank_id = sp.piece_to_id("<blk>")
# params.vocab_size = sp.get_piece_size()
# sp.encode(texts, out_type=int)
from typing import List, Union
import pyonmttok
class PyonmttokProcessor:
def __init__(self):
self.tok = None
def load(self, path: str) -> None:
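        # Note: the `path` argument is kept only for interface compatibility
        # with SentencePieceProcessor.load() and is ignored here; the BPE
        # model and vocabulary are read from the hard-coded
        # /data/bpe.pyonmttok and /data/vocab paths.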
args = {
"mode": "aggressive",
"joiner_annotate": True,
"preserve_placeholders": True,
"case_markup": True,
"soft_case_regions": True,
"preserve_segmented_tokens": True,
}
self.tok = pyonmttok.Tokenizer(
**args,
bpe_model_path="/data/bpe.pyonmttok",
vocabulary_path="/data/vocab"
)
self.vocab = []
self.reverse_vocab = dict()
with open("/data/vocab", "r") as f:
for i, l in enumerate(f):
word = l.rstrip("\n")
self.vocab.append(word)
self.reverse_vocab[word] = i
def piece_to_id(self, token: str) -> int:
return self.reverse_vocab.get(token, self.reverse_vocab["<unk>"])
    def encode(
        self, texts: List[str], out_type: type = int
    ) -> Union[List[List[str]], List[List[int]]]:
batch_tokens = [self.tok.tokenize(text)[0] for text in texts]
# print(texts)
# print(batch_tokens)
if out_type == str:
return batch_tokens
elif out_type == int:
return [
[self.piece_to_id(token) for token in tokens]
for tokens in batch_tokens
]
        raise ValueError(f"Unsupported out_type: {out_type}")
def get_piece_size(self) -> int:
return len(self.vocab)
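# A minimal usage sketch (the /data paths must exist; the example text is
# hypothetical). PyonmttokProcessor mirrors the small subset of the
# SentencePieceProcessor API shown in the commented lines above:
#
#   sp = PyonmttokProcessor()
#   sp.load("/data/bpe.pyonmttok")  # the path argument is ignored
#   blank_id = sp.piece_to_id("<blk>")
#   vocab_size = sp.get_piece_size()
#   ids = sp.encode(["bonjour tout le monde"], out_type=int)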

File diff suppressed because it is too large

File diff suppressed because it is too large