Add subformer into zipformer

This commit is contained in:
pkufool 2023-08-02 14:43:23 +08:00
parent af8907e1ec
commit a25af9d61d
16 changed files with 4866 additions and 0 deletions

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,834 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = int(params.chunk_size)
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../zipformer/decoder.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
../zipformer/joiner.py

View File

@ -0,0 +1,484 @@
#!/usr/bin/env python3
# Copyright (c) 2023 Xiaomi Corp. (author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import warnings
from typing import List, Optional, Tuple, Union
import logging
import torch
import random
from encoder_interface import EncoderInterface
from scaling import (
Balancer,
BiasNorm,
Dropout2,
ChunkCausalDepthwiseConv1d,
ActivationDropoutAndLinear,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
Whiten,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
penalize_abs_values_gt,
softmax,
ScheduledFloat,
FloatLike,
limit_param_value,
convert_num_channels,
)
from subformer import (
BypassModule,
CompactRelPositionalEncoding,
LearnedDownsamplingModule,
SubformerEncoder,
SubformerEncoderLayer,
)
from zipformer import (
DownsampledZipformer2Encoder,
SimpleDownsample,
SimpleUpsample,
Zipformer2Encoder,
Zipformer2EncoderLayer,
)
from torch import Tensor, nn
class Mixformer(EncoderInterface):
def __init__(
self,
structure: str = "ZZS(S(S)S)SZ",
output_downsampling_factor: int = 2,
downsampling_factor: Tuple[int] = (1, 1, 2, 2, 1),
encoder_dim: Union[int, Tuple[int]] = (
192,
192,
256,
384,
512,
384,
256,
192,
),
num_encoder_layers: Union[int, Tuple[int]] = (
2,
2,
2,
2,
2,
2,
2,
2,
),
encoder_unmasked_dim: Union[int, Tuple[int]] = (192, 192, 192),
query_head_dim: Union[int, Tuple[int]] = (32,),
value_head_dim: Union[int, Tuple[int]] = (12,),
pos_head_dim: Union[int, Tuple[int]] = (4,),
pos_dim: int = (48,),
num_heads: Union[int, Tuple[int]] = (4,),
feedforward_dim: Union[int, Tuple[int]] = (
512,
768,
1024,
1536,
2048,
1536,
1024,
768,
),
cnn_module_kernel: Union[int, Tuple[int]] = (15, 31, 31),
encoder_chunk_sizes: Tuple[Tuple[int, ...]] = ((128, 1024),),
memory_dim: int = -1,
dropout: Optional[FloatLike] = None, # see code below for default
warmup_batches: float = 4000.0,
causal: bool = False,
) -> None:
super(Mixformer, self).__init__()
if dropout is None:
dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
num_zip_encoders = len([s for s in structure if s == 'Z'])
num_sub_encoders = len([s for s in structure if s == 'S'])
num_encoders = num_zip_encoders + num_sub_encoders
num_downsamplers = len([s for s in structure if s == '('])
def _to_tuple(x, length):
"""Converts a single int or a 1-tuple of an int to a tuple with the same length
as downsampling_factor"""
assert isinstance(x, tuple)
if len(x) == 1:
x = x * length
else:
assert len(x) == length and isinstance(
x[0], int
)
return x
self.output_downsampling_factor = output_downsampling_factor # int
self.downsampling_factor = (
downsampling_factor
) = _to_tuple(downsampling_factor, num_zip_encoders + num_downsamplers) # tuple
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim, num_encoders) # tuple
self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
encoder_unmasked_dim, num_zip_encoders
) # tuple
num_encoder_layers = _to_tuple(num_encoder_layers, num_encoders)
self.query_head_dim = query_head_dim = _to_tuple(query_head_dim, num_encoders)
self.value_head_dim = value_head_dim = _to_tuple(value_head_dim, num_encoders)
pos_head_dim = _to_tuple(pos_head_dim, num_encoders)
pos_dim = _to_tuple(pos_dim, num_encoders)
self.num_heads = num_heads = _to_tuple(num_heads, num_encoders)
feedforward_dim = _to_tuple(feedforward_dim, num_encoders)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(
cnn_module_kernel, num_zip_encoders
)
encoder_chunk_sizes = _to_tuple(encoder_chunk_sizes, num_sub_encoders)
self.causal = causal
# for u, d in zip(encoder_unmasked_dim, encoder_dim):
# assert u <= d
# each one will be Zipformer2Encoder, DownsampledZipformer2Encoder,
# SubformerEncoder or DownsampledSubformerEncoder
zip_encoders = []
sub_encoders = []
downsamplers = []
bypasses = []
layer_indexes = []
cur_max_dim = 0
downsampling_factors_list = []
def cur_downsampling_factor():
c = 1
for d in downsampling_factors_list: c *= d
return c
zip_encoder_dim = []
zip_downsampling_factor = []
for s in structure:
if s == "Z":
i = len(zip_encoders) + len(sub_encoders)
j = len(zip_encoders)
k = len(downsamplers) + len(zip_encoders)
assert encoder_unmasked_dim[j] <= encoder_dim[i]
zip_encoder_dim.append(encoder_dim[i])
encoder_layer = Zipformer2EncoderLayer(
embed_dim=encoder_dim[i],
pos_dim=pos_dim[i],
num_heads=num_heads[i],
query_head_dim=query_head_dim[i],
pos_head_dim=pos_head_dim[i],
value_head_dim=value_head_dim[i],
feedforward_dim=feedforward_dim[i],
dropout=dropout,
cnn_module_kernel=cnn_module_kernel[j],
causal=causal,
)
# For the segment of the warmup period, we let the Conv2dSubsampling
# layer learn something. Then we start to warm up the other encoders.
encoder = Zipformer2Encoder(
encoder_layer,
num_encoder_layers[i],
pos_dim=pos_dim[i],
dropout=dropout,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[k] ** 0.5),
)
if downsampling_factor[k] != 1:
encoder = DownsampledZipformer2Encoder(
encoder,
dim=encoder_dim[i],
downsample=downsampling_factor[k],
dropout=dropout,
)
zip_downsampling_factor.append(downsampling_factor[k])
layer_indexes.append(len(zip_encoders))
zip_encoders.append(encoder)
elif s == 'S':
i = len(zip_encoders) + len(sub_encoders)
j = len(sub_encoders)
if len(sub_encoders) == 0:
cur_max_dim = encoder_dim[i]
encoder_layer = SubformerEncoderLayer(
embed_dim=encoder_dim[i],
pos_dim=pos_head_dim[i],
num_heads=num_heads[i],
query_head_dim=query_head_dim[i],
value_head_dim=value_head_dim[i],
feedforward_dim=feedforward_dim[i],
memory_dim=memory_dim,
dropout=dropout,
causal=causal,
)
cur_max_dim = max(cur_max_dim, encoder_dim[i])
encoder = SubformerEncoder(
encoder_layer,
num_encoder_layers[i],
embed_dim=cur_max_dim,
dropout=dropout,
chunk_sizes=encoder_chunk_sizes[j],
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (cur_downsampling_factor() ** 0.5),
)
layer_indexes.append(len(sub_encoders))
sub_encoders.append(encoder)
elif s =='(':
i = len(zip_encoders) + len(downsamplers)
downsampler = LearnedDownsamplingModule(cur_max_dim,
downsampling_factor[i])
downsampling_factors_list.append(downsampling_factor[i])
layer_indexes.append(len(downsamplers))
downsamplers.append(downsampler)
else:
assert s == ')'
bypass = BypassModule(cur_max_dim, straight_through_rate=0.0)
layer_indexes.append(len(bypasses))
bypasses.append(bypass)
downsampling_factors_list.pop()
logging.info(f"cur_downsampling_factor={cur_downsampling_factor()}")
self.zip_encoder_dim = zip_encoder_dim
self.zip_downsampling_factor = zip_downsampling_factor
self.layer_indexes = layer_indexes
self.structure = structure
self.zip_encoders = nn.ModuleList(zip_encoders)
self.sub_encoders = nn.ModuleList(sub_encoders)
self.downsamplers = nn.ModuleList(downsamplers)
self.bypasses = nn.ModuleList(bypasses)
self.encoder_pos = CompactRelPositionalEncoding(64, pos_head_dim[0],
dropout_rate=0.15,
length_factor=1.0)
self.downsample_output = SimpleDownsample(
max(encoder_dim),
downsample=output_downsampling_factor,
dropout=dropout,
)
def _get_full_dim_output(self, outputs: List[Tensor]):
num_encoders = len(self.zip_encoders) + 1
assert len(outputs) == num_encoders
output_dim = max(self.encoder_dim)
output_pieces = [outputs[-1]]
cur_dim = self.encoder_dim[-1]
for i in range(num_encoders - 2, -1, -1):
d = list(outputs[i].shape)[-1]
if d > cur_dim:
this_output = outputs[i]
output_pieces.append(this_output[..., cur_dim:d])
cur_dim = d
assert cur_dim == output_dim, (cur_dim, output_dim)
return torch.cat(output_pieces, dim=-1)
def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
"""
In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
randomized feature masks, one per encoder.
On e.g. 15% of frames, these masks will zero out all enocder dims larger than
some supplied number, e.g. >256, so in effect on those frames we are using
a smaller encoer dim.
We generate the random masks at this level because we want the 2 masks to 'agree'
all the way up the encoder stack. This will mean that the 1st mask will have
mask values repeated self.zipformer_subsampling_factor times.
Args:
x: the embeddings (needed for the shape and dtype and device), of shape
(1, batch_size, encoder_dims0)
"""
num_encoders = len(self.zip_encoders)
if not self.training:
return [1.0] * num_encoders
(num_frames0, batch_size, _encoder_dims0) = x.shape
assert self.encoder_dim[0] == _encoder_dims0
feature_mask_dropout_prob = 0.125
# mask1 shape: (1, batch_size, 1)
mask1 = (
torch.rand(1, batch_size, 1, device=x.device)
> feature_mask_dropout_prob
).to(x.dtype)
# mask2 has additional sequences masked, about twice the number.
mask2 = torch.logical_and(
mask1,
(
torch.rand(1, batch_size, 1, device=x.device)
> feature_mask_dropout_prob
).to(x.dtype),
)
# dim: (1, batch_size, 2)
mask = torch.cat((mask1, mask2), dim=-1)
feature_masks = []
for i in range(num_encoders):
channels = self.zip_encoder_dim[i]
feature_mask = torch.ones(
1, batch_size, channels, dtype=x.dtype, device=x.device
)
u1 = self.encoder_unmasked_dim[i]
u2 = u1 + (channels - u1) // 2
feature_mask[:, :, u1:u2] *= mask[..., 0:1]
feature_mask[:, :, u2:] *= mask[..., 1:2]
feature_masks.append(feature_mask)
return feature_masks
def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
"""
Return attention offset of shape (1 or batch_size, seq_len, seq_len), interpreted as (1 or batch_size, tgt_seq_len,
src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros.
Args:
x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len) with True in masked positions.
"""
seq_len, batch_size, _num_channels = x.shape
ans = torch.zeros(batch_size, seq_len, seq_len, device=x.device)
if self.causal:
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
src_t = t
tgt_t = t.unsqueeze(-1)
attn_mask = (src_t > tgt_t)
ans.masked_fill_(attn_mask, float('-inf'))
if src_key_padding_mask is not None:
ans.masked_fill_(src_key_padding_mask.unsqueeze(1), float('-inf'))
# now ans: (batch_size, seq_len, seq_len).
return ans
def forward(
self,
x: Tensor,
x_lens: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x:
The input tensor. Its shape is (seq_len, batch_size, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
src_key_padding_mask:
The mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
Returns:
Return a tuple containing 2 tensors:
- embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
- lengths, a tensor of shape (batch_size,) containing the number
of frames in `embeddings` before padding.
"""
outputs = []
attn_offsets = [ self._get_attn_offset(x, src_key_padding_mask) ]
pos_embs = [ self.encoder_pos(x) ]
downsample_info = []
if torch.jit.is_scripting():
feature_masks = [1.0] * len(self.zip_encoders)
else:
feature_masks = self.get_feature_masks(x)
for s, i in zip(self.structure, self.layer_indexes):
if s == 'Z':
encoder = self.zip_encoders[i]
ds = self.zip_downsampling_factor[i]
x = convert_num_channels(x, self.zip_encoder_dim[i])
x = encoder(
x,
feature_mask=feature_masks[i],
src_key_padding_mask=(
None
if src_key_padding_mask is None
else src_key_padding_mask[..., ::ds]
),
)
outputs.append(x)
elif s == 'S':
encoder = self.sub_encoders[i] # one encoder stack
x = encoder(x,
pos_embs[-1],
attn_offset=attn_offsets[-1])
# only the last output of subformer will be used to combine the
# final output.
if i == len(self.sub_encoders) - 1:
outputs.append(x)
# x will have the maximum dimension up till now, even if
# `encoder` uses lower dim in its layers.
elif s == '(':
downsampler = self.downsamplers[i]
indexes, weights, x_new = downsampler(x)
downsample_info.append((indexes, weights, x))
x = x_new
pos_embs.append(downsampler.downsample_pos_emb(pos_embs[-1], indexes))
attn_offsets.append(downsampler.downsample_attn_offset(attn_offsets[-1],
indexes,
weights))
else:
assert s == ')' # upsample and bypass
indexes, weights, x_orig = downsample_info.pop()
_attn_offset = attn_offsets.pop()
_pos_emb = pos_embs.pop()
x_orig = convert_num_channels(x_orig, x.shape[-1])
x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)
bypass = self.bypasses[i]
x = bypass(x_orig, x)
# Only "balanced" structure is supported now
assert len(downsample_info) == 0, len(downsample_info)
# if the last output has the largest dimension, x will be unchanged,
# it will be the same as outputs[-1]. Otherwise it will be concatenated
# from different pieces of 'outputs', taking each dimension from the
# most recent output that has it present.
x = self._get_full_dim_output(outputs)
x = self.downsample_output(x)
# class Downsample has this rounding behavior..
assert self.output_downsampling_factor == 2
if torch.jit.is_scripting():
lengths = (x_lens + 1) // 2
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
lengths = (x_lens + 1) // 2
return x, lengths

View File

@ -0,0 +1,217 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
"""
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder_embed = encoder_embed
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim,
vocab_size,
initial_scale=0.25,
)
self.simple_lm_proj = ScaledLinear(
decoder_dim,
vocab_size,
initial_scale=0.25,
)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)

View File

@ -0,0 +1 @@
../zipformer/optim.py

View File

@ -0,0 +1 @@
../zipformer/scaling.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../zipformer/subsampling.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../zipformer/zipformer.py

View File

@ -274,6 +274,24 @@ def softmax(x: Tensor, dim: int):
return SoftmaxFunction.apply(x, dim)
class ClipGradFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
limit: float):
ctx.limit = limit
return x
@staticmethod
def backward(ctx, x_grad, *args):
return x_grad.clamp(-ctx.limit, ctx.limit), None
def clip_grad(x: Tensor, limit: float):
return ClipGradFunction.apply(x, limit)
class MaxEigLimiterFunction(torch.autograd.Function):
@staticmethod
def forward(
@ -875,6 +893,40 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
return x
class AbsValuePenalizer(nn.Module):
"""
This module adds a penalty to the loss function when ever the absolute value of
any element of the input tensor exceeds a certain limit.
"""
def __init__(self,
limit: float,
prob: float = 0.1,
penalty: float = 1.0e-04):
super().__init__()
self.limit = limit
self.penalty = penalty
self.prob = prob
self.name = None # will be set in training loop
# 20% of the time we will return and do nothing because memory usage is
# too high.
self.mem_cutoff = CutoffEstimator(0.2)
def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad
or not self.training
or random.random() > self.prob):
# or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
return _no_op(x) # the _no_op op is to make our diagnostics code work.
x = penalize_abs_values_gt(x,
limit=self.limit,
penalty=self.penalty,
name=self.name)
return x
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
if x.ndim == 2:
return x.diag()