mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Merge 74200583be92b96726de086023db2a0f45a1d401 into 8eff7a4642da89bace5ea60ea9c5cae7ef5043b0
This commit is contained in:
commit
b056e55ead
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/asr_datamodule.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/beam_search.py
|
634
egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
Executable file
634
egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
Executable file
@ -0,0 +1,634 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless4/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
return {
|
||||
(
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 100
|
||||
else:
|
||||
log_interval = 2
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/decoder.py
|
2181
egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
Normal file
2181
egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/encoder_interface.py
|
1
egs/librispeech/ASR/conv_emformer_transducer_stateless/joiner.py
Symbolic link
1
egs/librispeech/ASR/conv_emformer_transducer_stateless/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/joiner.py
|
1
egs/librispeech/ASR/conv_emformer_transducer_stateless/model.py
Symbolic link
1
egs/librispeech/ASR/conv_emformer_transducer_stateless/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/model.py
|
1
egs/librispeech/ASR/conv_emformer_transducer_stateless/optim.py
Symbolic link
1
egs/librispeech/ASR/conv_emformer_transducer_stateless/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/optim.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/scaling.py
|
@ -0,0 +1,585 @@
|
||||
import torch
|
||||
|
||||
|
||||
def test_rel_positional_encoding():
|
||||
from emformer import RelPositionalEncoding
|
||||
|
||||
D = 256
|
||||
pos_enc = RelPositionalEncoding(D, dropout=0.1)
|
||||
pos_len = 100
|
||||
neg_len = 100
|
||||
x = torch.randn(2, D)
|
||||
x, pos_emb = pos_enc(x, pos_len, neg_len)
|
||||
assert pos_emb.shape == (pos_len + neg_len - 1, D)
|
||||
|
||||
|
||||
def test_emformer_attention_forward():
|
||||
from emformer import EmformerAttention
|
||||
|
||||
B, D = 2, 256
|
||||
chunk_length = 4
|
||||
right_context_length = 2
|
||||
num_chunks = 3
|
||||
U = num_chunks * chunk_length
|
||||
R = num_chunks * right_context_length
|
||||
attention = EmformerAttention(
|
||||
embed_dim=D,
|
||||
nhead=8,
|
||||
chunk_length=chunk_length,
|
||||
right_context_length=right_context_length,
|
||||
)
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
S = num_chunks
|
||||
M = S - 1
|
||||
else:
|
||||
S, M = 0, 0
|
||||
|
||||
Q, KV = R + U + S, M + R + U
|
||||
utterance = torch.randn(U, B, D)
|
||||
lengths = torch.randint(1, U + 1, (B,))
|
||||
lengths[0] = U
|
||||
right_context = torch.randn(R, B, D)
|
||||
summary = torch.randn(S, B, D)
|
||||
memory = torch.randn(M, B, D)
|
||||
attention_mask = torch.rand(Q, KV) >= 0.5
|
||||
PE = 2 * U + right_context_length - 1
|
||||
pos_emb = torch.randn(PE, D)
|
||||
|
||||
(
|
||||
output_right_context_utterance,
|
||||
output_memory,
|
||||
probs_memory,
|
||||
probs_frames,
|
||||
) = attention(
|
||||
utterance,
|
||||
lengths,
|
||||
right_context,
|
||||
summary,
|
||||
memory,
|
||||
attention_mask,
|
||||
pos_emb,
|
||||
)
|
||||
assert output_right_context_utterance.shape == (R + U, B, D)
|
||||
assert output_memory.shape == (M, B, D)
|
||||
assert probs_memory.shape == (B, U)
|
||||
assert probs_frames.shape == (B, U)
|
||||
|
||||
|
||||
def test_emformer_attention_infer():
|
||||
from emformer import EmformerAttention
|
||||
|
||||
B, D = 2, 256
|
||||
chunk_length = 4
|
||||
right_context_length = 2
|
||||
num_chunks = 1
|
||||
U = chunk_length * num_chunks
|
||||
R = right_context_length * num_chunks
|
||||
L = 3
|
||||
attention = EmformerAttention(
|
||||
embed_dim=D,
|
||||
nhead=8,
|
||||
chunk_length=chunk_length,
|
||||
right_context_length=right_context_length,
|
||||
)
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
S, M = 1, 3
|
||||
else:
|
||||
S, M = 0, 0
|
||||
|
||||
utterance = torch.randn(U, B, D)
|
||||
lengths = torch.randint(1, U + 1, (B,))
|
||||
lengths[0] = U
|
||||
right_context = torch.randn(R, B, D)
|
||||
summary = torch.randn(S, B, D)
|
||||
memory = torch.randn(M, B, D)
|
||||
left_context_key = torch.randn(L, B, D)
|
||||
left_context_val = torch.randn(L, B, D)
|
||||
PE = (
|
||||
2 * U
|
||||
+ right_context_length
|
||||
- 1
|
||||
+ (M * chunk_length if M > 0 else L)
|
||||
)
|
||||
pos_emb = torch.randn(PE, D)
|
||||
|
||||
(
|
||||
output_right_context_utterance,
|
||||
output_memory,
|
||||
next_key,
|
||||
next_val,
|
||||
) = attention.infer(
|
||||
utterance,
|
||||
lengths,
|
||||
right_context,
|
||||
summary,
|
||||
memory,
|
||||
left_context_key,
|
||||
left_context_val,
|
||||
pos_emb,
|
||||
)
|
||||
assert output_right_context_utterance.shape == (R + U, B, D)
|
||||
assert output_memory.shape == (S, B, D)
|
||||
assert next_key.shape == (L + U, B, D)
|
||||
assert next_val.shape == (L + U, B, D)
|
||||
|
||||
|
||||
def test_convolution_module_forward():
|
||||
from emformer import ConvolutionModule
|
||||
|
||||
B, D = 2, 256
|
||||
chunk_length = 4
|
||||
right_context_length = 2
|
||||
num_chunks = 3
|
||||
U = num_chunks * chunk_length
|
||||
R = num_chunks * right_context_length
|
||||
kernel_size = 31
|
||||
conv_module = ConvolutionModule(
|
||||
chunk_length,
|
||||
right_context_length,
|
||||
D,
|
||||
kernel_size,
|
||||
)
|
||||
|
||||
utterance = torch.randn(U, B, D)
|
||||
right_context = torch.randn(R, B, D)
|
||||
cache = torch.randn(B, D, kernel_size - 1)
|
||||
|
||||
utterance, right_context, new_cache = conv_module(
|
||||
utterance, right_context, cache
|
||||
)
|
||||
assert utterance.shape == (U, B, D)
|
||||
assert right_context.shape == (R, B, D)
|
||||
assert new_cache.shape == (B, D, kernel_size - 1)
|
||||
|
||||
|
||||
def test_convolution_module_infer():
|
||||
from emformer import ConvolutionModule
|
||||
|
||||
B, D = 2, 256
|
||||
chunk_length = 4
|
||||
right_context_length = 2
|
||||
num_chunks = 1
|
||||
U = num_chunks * chunk_length
|
||||
R = num_chunks * right_context_length
|
||||
kernel_size = 31
|
||||
conv_module = ConvolutionModule(
|
||||
chunk_length,
|
||||
right_context_length,
|
||||
D,
|
||||
kernel_size,
|
||||
)
|
||||
|
||||
utterance = torch.randn(U, B, D)
|
||||
right_context = torch.randn(R, B, D)
|
||||
cache = torch.randn(B, D, kernel_size - 1)
|
||||
|
||||
utterance, right_context, new_cache = conv_module.infer(
|
||||
utterance, right_context, cache
|
||||
)
|
||||
assert utterance.shape == (U, B, D)
|
||||
assert right_context.shape == (R, B, D)
|
||||
assert new_cache.shape == (B, D, kernel_size - 1)
|
||||
|
||||
|
||||
def test_emformer_encoder_layer_forward():
|
||||
from emformer import EmformerEncoderLayer
|
||||
|
||||
B, D = 2, 256
|
||||
chunk_length = 8
|
||||
right_context_length = 2
|
||||
left_context_length = 8
|
||||
kernel_size = 31
|
||||
num_chunks = 3
|
||||
U = num_chunks * chunk_length
|
||||
R = num_chunks * right_context_length
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
S = num_chunks
|
||||
M = S - 1
|
||||
else:
|
||||
S, M = 0, 0
|
||||
|
||||
layer = EmformerEncoderLayer(
|
||||
d_model=D,
|
||||
nhead=8,
|
||||
dim_feedforward=1024,
|
||||
chunk_length=chunk_length,
|
||||
cnn_module_kernel=kernel_size,
|
||||
left_context_length=left_context_length,
|
||||
right_context_length=right_context_length,
|
||||
max_memory_size=M,
|
||||
)
|
||||
|
||||
Q, KV = R + U + S, M + R + U
|
||||
utterance = torch.randn(U, B, D)
|
||||
lengths = torch.randint(1, U + 1, (B,))
|
||||
lengths[0] = U
|
||||
right_context = torch.randn(R, B, D)
|
||||
memory = torch.randn(M, B, D)
|
||||
attention_mask = torch.rand(Q, KV) >= 0.5
|
||||
PE = 2 * U + right_context_length - 1
|
||||
pos_emb = torch.randn(PE, D)
|
||||
|
||||
output_utterance, output_right_context, output_memory = layer(
|
||||
utterance,
|
||||
lengths,
|
||||
right_context,
|
||||
memory,
|
||||
attention_mask,
|
||||
pos_emb,
|
||||
)
|
||||
assert output_utterance.shape == (U, B, D)
|
||||
assert output_right_context.shape == (R, B, D)
|
||||
assert output_memory.shape == (M, B, D)
|
||||
|
||||
|
||||
def test_emformer_encoder_layer_infer():
|
||||
from emformer import EmformerEncoderLayer
|
||||
|
||||
B, D = 2, 256
|
||||
chunk_length = 8
|
||||
right_context_length = 2
|
||||
left_context_length = 8
|
||||
kernel_size = 31
|
||||
num_chunks = 1
|
||||
U = num_chunks * chunk_length
|
||||
R = num_chunks * right_context_length
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
max_memory_size = 3
|
||||
M = 1
|
||||
else:
|
||||
max_memory_size = 0
|
||||
M = 0
|
||||
|
||||
layer = EmformerEncoderLayer(
|
||||
d_model=D,
|
||||
nhead=8,
|
||||
dim_feedforward=1024,
|
||||
chunk_length=chunk_length,
|
||||
cnn_module_kernel=kernel_size,
|
||||
left_context_length=left_context_length,
|
||||
right_context_length=right_context_length,
|
||||
max_memory_size=max_memory_size,
|
||||
)
|
||||
|
||||
utterance = torch.randn(U, B, D)
|
||||
lengths = torch.randint(1, U + 1, (B,))
|
||||
lengths[0] = U
|
||||
right_context = torch.randn(R, B, D)
|
||||
memory = torch.randn(M, B, D)
|
||||
state = None
|
||||
PE = (
|
||||
2 * U
|
||||
+ right_context_length
|
||||
- 1
|
||||
+ (
|
||||
max_memory_size * chunk_length
|
||||
if max_memory_size > 0
|
||||
else left_context_length
|
||||
)
|
||||
)
|
||||
pos_emb = torch.randn(PE, D)
|
||||
conv_cache = None
|
||||
(
|
||||
output_utterance,
|
||||
output_right_context,
|
||||
output_memory,
|
||||
output_state,
|
||||
conv_cache,
|
||||
) = layer.infer(
|
||||
utterance,
|
||||
lengths,
|
||||
right_context,
|
||||
memory,
|
||||
pos_emb,
|
||||
state,
|
||||
conv_cache,
|
||||
)
|
||||
assert output_utterance.shape == (U, B, D)
|
||||
assert output_right_context.shape == (R, B, D)
|
||||
if use_memory:
|
||||
assert output_memory.shape == (1, B, D)
|
||||
else:
|
||||
assert output_memory.shape == (0, B, D)
|
||||
assert len(output_state) == 4
|
||||
assert output_state[0].shape == (max_memory_size, B, D)
|
||||
assert output_state[1].shape == (left_context_length, B, D)
|
||||
assert output_state[2].shape == (left_context_length, B, D)
|
||||
assert output_state[3].shape == (1, B)
|
||||
assert conv_cache.shape == (B, D, kernel_size - 1)
|
||||
|
||||
|
||||
def test_emformer_encoder_forward():
|
||||
from emformer import EmformerEncoder
|
||||
|
||||
B, D = 2, 256
|
||||
chunk_length = 4
|
||||
right_context_length = 2
|
||||
left_context_length = 2
|
||||
num_chunks = 3
|
||||
U = num_chunks * chunk_length
|
||||
kernel_size = 31
|
||||
num_encoder_layers = 2
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
S = num_chunks
|
||||
M = S - 1
|
||||
else:
|
||||
S, M = 0, 0
|
||||
|
||||
encoder = EmformerEncoder(
|
||||
chunk_length=chunk_length,
|
||||
d_model=D,
|
||||
dim_feedforward=1024,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
cnn_module_kernel=kernel_size,
|
||||
left_context_length=left_context_length,
|
||||
right_context_length=right_context_length,
|
||||
max_memory_size=M,
|
||||
)
|
||||
|
||||
x = torch.randn(U + right_context_length, B, D)
|
||||
lengths = torch.randint(1, U + right_context_length + 1, (B,))
|
||||
lengths[0] = U + right_context_length
|
||||
|
||||
output, output_lengths = encoder(x, lengths)
|
||||
assert output.shape == (U, B, D)
|
||||
assert torch.equal(
|
||||
output_lengths, torch.clamp(lengths - right_context_length, min=0)
|
||||
)
|
||||
|
||||
|
||||
def test_emformer_encoder_infer():
|
||||
from emformer import EmformerEncoder
|
||||
|
||||
B, D = 2, 256
|
||||
num_encoder_layers = 2
|
||||
chunk_length = 4
|
||||
right_context_length = 2
|
||||
left_context_length = 2
|
||||
num_chunks = 3
|
||||
kernel_size = 31
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
max_memory_size = 3
|
||||
else:
|
||||
max_memory_size = 0
|
||||
|
||||
encoder = EmformerEncoder(
|
||||
chunk_length=chunk_length,
|
||||
d_model=D,
|
||||
dim_feedforward=1024,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
cnn_module_kernel=kernel_size,
|
||||
left_context_length=left_context_length,
|
||||
right_context_length=right_context_length,
|
||||
max_memory_size=max_memory_size,
|
||||
)
|
||||
|
||||
states = None
|
||||
conv_caches = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
x = torch.randn(chunk_length + right_context_length, B, D)
|
||||
lengths = torch.randint(
|
||||
1, chunk_length + right_context_length + 1, (B,)
|
||||
)
|
||||
lengths[0] = chunk_length + right_context_length
|
||||
output, output_lengths, states, conv_caches = encoder.infer(
|
||||
x, lengths, states, conv_caches
|
||||
)
|
||||
assert output.shape == (chunk_length, B, D)
|
||||
assert torch.equal(
|
||||
output_lengths,
|
||||
torch.clamp(lengths - right_context_length, min=0),
|
||||
)
|
||||
assert len(states) == num_encoder_layers
|
||||
for state in states:
|
||||
assert len(state) == 4
|
||||
assert state[0].shape == (max_memory_size, B, D)
|
||||
assert state[1].shape == (left_context_length, B, D)
|
||||
assert state[2].shape == (left_context_length, B, D)
|
||||
assert torch.equal(
|
||||
state[3],
|
||||
(chunk_idx + 1) * chunk_length * torch.ones_like(state[3]),
|
||||
)
|
||||
for conv_cache in conv_caches:
|
||||
assert conv_cache.shape == (B, D, kernel_size - 1)
|
||||
|
||||
|
||||
def test_emformer_encoder_forward_infer_consistency():
|
||||
from emformer import EmformerEncoder
|
||||
|
||||
chunk_length = 4
|
||||
num_chunks = 3
|
||||
U = chunk_length * num_chunks
|
||||
left_context_length, right_context_length = 1, 2
|
||||
D = 256
|
||||
num_encoder_layers = 3
|
||||
kernel_size = 31
|
||||
memory_sizes = [0, 3]
|
||||
|
||||
for max_memory_size in memory_sizes:
|
||||
encoder = EmformerEncoder(
|
||||
chunk_length=chunk_length,
|
||||
d_model=D,
|
||||
dim_feedforward=1024,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
cnn_module_kernel=kernel_size,
|
||||
left_context_length=left_context_length,
|
||||
right_context_length=right_context_length,
|
||||
max_memory_size=max_memory_size,
|
||||
)
|
||||
encoder.eval()
|
||||
|
||||
x = torch.randn(U + right_context_length, 1, D)
|
||||
lengths = torch.tensor([U + right_context_length])
|
||||
|
||||
# training mode with full utterance
|
||||
forward_output, forward_output_lengths = encoder(x, lengths)
|
||||
|
||||
# streaming inference mode with individual chunks
|
||||
states = None
|
||||
conv_caches = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
start_idx = chunk_idx * chunk_length
|
||||
end_idx = start_idx + chunk_length
|
||||
chunk = x[start_idx : end_idx + right_context_length] # noqa
|
||||
chunk_length = torch.tensor([chunk_length])
|
||||
(
|
||||
infer_output_chunk,
|
||||
infer_output_lengths,
|
||||
states,
|
||||
conv_caches,
|
||||
) = encoder.infer(chunk, chunk_length, states, conv_caches)
|
||||
forward_output_chunk = forward_output[start_idx:end_idx]
|
||||
assert torch.allclose(
|
||||
infer_output_chunk,
|
||||
forward_output_chunk,
|
||||
atol=1e-4,
|
||||
rtol=0.0,
|
||||
), (
|
||||
infer_output_chunk - forward_output_chunk
|
||||
)
|
||||
|
||||
|
||||
def test_emformer_forward():
|
||||
from emformer import Emformer
|
||||
|
||||
num_features = 80
|
||||
chunk_length = 16
|
||||
right_context_length = 8
|
||||
left_context_length = 8
|
||||
num_chunks = 3
|
||||
U = num_chunks * chunk_length
|
||||
B, D = 2, 256
|
||||
kernel_size = 31
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
max_memory_size = 3
|
||||
else:
|
||||
max_memory_size = 0
|
||||
model = Emformer(
|
||||
num_features=num_features,
|
||||
chunk_length=chunk_length,
|
||||
subsampling_factor=4,
|
||||
d_model=D,
|
||||
cnn_module_kernel=kernel_size,
|
||||
left_context_length=left_context_length,
|
||||
right_context_length=right_context_length,
|
||||
max_memory_size=max_memory_size,
|
||||
)
|
||||
x = torch.randn(B, U + right_context_length + 3, num_features)
|
||||
x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
|
||||
x_lens[0] = U + right_context_length + 3
|
||||
output, output_lengths = model(x, x_lens)
|
||||
assert output.shape == (B, U // 4, D)
|
||||
assert torch.equal(
|
||||
output_lengths,
|
||||
torch.clamp(
|
||||
((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4, min=0
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_emformer_infer():
|
||||
from emformer import Emformer
|
||||
|
||||
num_features = 80
|
||||
chunk_length = 8
|
||||
U = chunk_length
|
||||
left_context_length, right_context_length = 32, 4
|
||||
B, D = 2, 256
|
||||
num_chunks = 3
|
||||
num_encoder_layers = 2
|
||||
kernel_size = 31
|
||||
|
||||
for use_memory in [True, False]:
|
||||
if use_memory:
|
||||
max_memory_size = 32
|
||||
else:
|
||||
max_memory_size = 0
|
||||
model = Emformer(
|
||||
num_features=num_features,
|
||||
chunk_length=chunk_length,
|
||||
subsampling_factor=4,
|
||||
d_model=D,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
cnn_module_kernel=kernel_size,
|
||||
left_context_length=left_context_length,
|
||||
right_context_length=right_context_length,
|
||||
max_memory_size=max_memory_size,
|
||||
)
|
||||
states = None
|
||||
conv_caches = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
x = torch.randn(B, U + right_context_length + 3, num_features)
|
||||
x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,))
|
||||
x_lens[0] = U + right_context_length + 3
|
||||
output, output_lengths, states, conv_caches = model.infer(
|
||||
x, x_lens, states, conv_caches
|
||||
)
|
||||
assert output.shape == (B, U // 4, D)
|
||||
assert torch.equal(
|
||||
output_lengths,
|
||||
torch.clamp(
|
||||
((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4,
|
||||
min=0,
|
||||
),
|
||||
)
|
||||
assert len(states) == num_encoder_layers
|
||||
for state in states:
|
||||
assert len(state) == 4
|
||||
assert state[0].shape == (max_memory_size, B, D)
|
||||
assert state[1].shape == (left_context_length // 4, B, D)
|
||||
assert state[2].shape == (left_context_length // 4, B, D)
|
||||
assert torch.equal(
|
||||
state[3],
|
||||
U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
|
||||
)
|
||||
for conv_cache in conv_caches:
|
||||
assert conv_cache.shape == (B, D, kernel_size - 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rel_positional_encoding()
|
||||
test_emformer_attention_forward()
|
||||
test_emformer_attention_infer()
|
||||
test_convolution_module_forward()
|
||||
test_convolution_module_infer()
|
||||
test_emformer_encoder_layer_forward()
|
||||
test_emformer_encoder_layer_infer()
|
||||
test_emformer_encoder_forward()
|
||||
test_emformer_encoder_infer()
|
||||
test_emformer_encoder_forward_infer_consistency()
|
||||
test_emformer_forward()
|
||||
test_emformer_infer()
|
1121
egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
Executable file
1121
egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
Executable file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user