[Ready to merge] stateless6: stateless4 + hubert distillation. (#387)

* a copy of stateless4 as base

* distillation with hubert

* fix typo

* example usage

* usage

* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* fix comment

* add results of 100hours

* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* check fairseq and quantization

* a short intro to distillation framework

* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* add intro of stateless6 in README

* fix type error of dst_manifest_dir

* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* make export.py call stateless6/train.py instead of stateless2/train.py

* update results by stateless6

* adjust results format

* fix typo

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
LIyong.Guo 2022-05-28 12:37:50 +08:00 committed by GitHub
parent c8c8645081
commit c4ee2bc0af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 4429 additions and 5 deletions

View File

@@ -21,6 +21,7 @@ The following table lists the differences among them.
| `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data |
| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training |
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|
| `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert|
The decoder in `transducer_stateless` is modified from the paper

View File

@@ -3,6 +3,31 @@
This page shows the WERs for test-clean/test-other using only
train-clean-100 subset as training data.
## Distillation with hubert
### 2022-05-27
Related models/log/tensorboard:
https://huggingface.co/GuoLiyong/stateless6_baseline_vs_disstillation
The following results are obtained by ./distillation_with_hubert.sh.
The only difference is in pruned_transducer_stateless6/train.py:
for the baseline, set enable_distillation=False;
for distillation, set enable_distillation=True (the default).
The decoding method is modified beam search.
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|------------------------------------------|
| baseline no vq distillation | 7.09 | 18.88 | --epoch 20, --avg 10, --max-duration 200 |
| baseline no vq distillation | 6.83 | 18.19 | --epoch 30, --avg 10, --max-duration 200 |
| baseline no vq distillation | 6.73 | 17.79 | --epoch 40, --avg 10, --max-duration 200 |
| baseline no vq distillation | 6.75 | 17.68 | --epoch 50, --avg 10, --max-duration 200 |
| distillation with hubert | 5.82 | 15.98 | --epoch 20, --avg 10, --max-duration 200 |
| distillation with hubert | 5.52 | 15.15 | --epoch 30, --avg 10, --max-duration 200 |
| distillation with hubert | 5.45 | 14.94 | --epoch 40, --avg 10, --max-duration 200 |
| distillation with hubert | 5.50 | 14.77 | --epoch 50, --avg 10, --max-duration 200 |
## Conformer encoder + embedding decoder
### 2022-02-21

View File

@@ -0,0 +1,144 @@
# A short introduction to the distillation framework.
#
# A typical traditional distillation method is
# Loss(teacher embedding, student embedding).
#
# Compared to this, the proposed distillation framework contains two main steps:
# codebook indexes = quantizer.encode(teacher embedding)
# Loss(codebook indexes, student embedding)
#
# Things worth mentioning:
# 1. The float-type teacher embedding is quantized into a sequence of
# 8-bit integer codebook indexes.
# 2. A middle layer, layer 36 (1-based) out of 48 total layers, is used to extract
# teacher embeddings.
# 3. Layer 6 (1-based) out of 6 total layers is used to extract
# student embeddings.
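#
# A minimal sketch of the two steps in pseudocode (hypothetical names,
# assuming a trained quantizer plus teacher/student forward passes):
#
#   teacher_embedding = teacher.extract_embedding(batch)    # float, (N, T, C)
#   codebook_indexes = quantizer.encode(teacher_embedding)  # 8-bit ints, (N, T, num_codebooks)
#   student_embedding = student_middle_layer_output(batch)
#   loss = codebook_loss(student_embedding, codebook_indexes)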
# This is an example of doing distillation with the librispeech clean-100 subset.
# Run with the command:
# bash distillation_with_hubert.sh [0|1|2|3|4]
#
# For example, the command
# bash distillation_with_hubert.sh 0
# will download the hubert model.
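#
# A typical end-to-end run goes through the stages below in order
# (a sketch; each stage is defined later in this script):
#   bash distillation_with_hubert.sh 0   # download hubert model
#   bash distillation_with_hubert.sh 1   # (optional) verify the hubert model with ctc_greedy_search
#   bash distillation_with_hubert.sh 2   # train quantizer and extract codebook indexes
#   bash distillation_with_hubert.sh 3   # train student model with distillation
#   bash distillation_with_hubert.sh 4   # decode the student model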
stage=$1
# Set the GPUs available.
# This script requires at least one GPU.
# You MUST set the environment variable "CUDA_VISIBLE_DEVICES",
# even if you only have ONE GPU. It is needed by CodebookIndexExtractor
# to determine the number of jobs used to extract codebook indexes in parallel.
# Suppose only one GPU exists:
# export CUDA_VISIBLE_DEVICES="0"
#
# Suppose GPU 2,3,4,5 are available.
export CUDA_VISIBLE_DEVICES="2,3,4,5"
if [ $stage -eq 0 ]; then
# Preparation stage.
# Install fairseq according to:
# https://github.com/pytorch/fairseq
# when testing this code:
# commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
has_fairseq=$(python3 -c "import importlib.util; print(importlib.util.find_spec('fairseq') is not None)")
if [ $has_fairseq == 'False' ]; then
echo "Please install fairseq before running following stages"
exit 1
fi
# Install quantization toolkit:
# pip install git+https://github.com/danpovey/quantization.git@master
# when testing this code:
# commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used.
has_quantization=$(python3 -c "import importlib.util; print(importlib.util.find_spec('quantization') is not None)")
if [ $has_quantization == 'False' ]; then
echo "Please install quantization before running following stages"
exit 1
fi
echo "Download hubert model."
# Parameters about model.
exp_dir=./pruned_transducer_stateless6/exp/
model_id=hubert_xtralarge_ll60k_finetune_ls960
hubert_model_dir=${exp_dir}/hubert_models
hubert_model=${hubert_model_dir}/${model_id}.pt
mkdir -p ${hubert_model_dir}
# For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert
if [ -f ${hubert_model} ]; then
echo "hubert model already exists."
else
wget -c https://dl.fbaipublicfiles.com/hubert/${model_id}.pt -P ${hubert_model_dir}
wget -c https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
fi
fi
if [ ! -d ./data/fbank ]; then
echo "This script assumes ./data/fbank is already generated by prepare.sh"
exit 1
fi
if [ $stage -eq 1 ]; then
# This stage is not directly used by codebook index extraction.
# It is a way to "prove" that the downloaded hubert model
# is run correctly: the WERs should look normal.
# Expect WERs:
# [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
# [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
./pruned_transducer_stateless6/hubert_decode.py
fi
if [ $stage -eq 2 ]; then
# Analysis of disk usage:
# With num_codebooks==8, each teacher embedding is quantized into
# a sequence of eight 8-bit integers, i.e. only eight bytes are needed.
# The training dataset, i.e. clean-100 with speed perturbation 0.9 and 1.1, has 300 hours.
# The output frame rate of hubert is 50 frames per second.
# Theoretically, 412M = 300 * 3600 * 50 * 8 / 1024 / 1024 is needed.
# The actual size of all "*.h5" files storing codebook indexes is 450M.
# I think the extra "48M" usage is some meta information.
# Time consumption analysis:
# For quantizer training data (teacher embedding) extraction, only 1000 utts from clean-100 are used.
# Together with quantizer training, no more than 20 minutes are needed.
#
# For codebook index extraction,
# with two NVIDIA A100 GPUs, around three hours are needed to process the 300 hours of training data,
# i.e. clean-100 with speed perturbation 0.9 and 1.1.
# GPU usage:
# During the extraction of the quantizer's training data (teacher embeddings) and its training,
# only the FIRST GPU is used.
# During codebook index extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used.
./pruned_transducer_stateless6/extract_codebook_index.py \
--full-libri False
fi
if [ $stage -eq 3 ]; then
# Example training script.
# Note: it's better to set --spec-aug-time-warp-factor=-1.
WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
./pruned_transducer_stateless6/train.py \
--manifest-dir ./data/vq_fbank \
--master-port 12359 \
--full-libri False \
--spec-aug-time-warp-factor -1 \
--max-duration 300 \
--world-size ${WORLD_SIZE} \
--num-epochs 20
fi
if [ $stage -eq 4 ]; then
# Results should be similar to:
# errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67
# errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60
./pruned_transducer_stateless6/decode.py \
--decoding-method "modified_beam_search" \
--epoch 20 \
--avg 10 \
--max-duration 200 \
--exp-dir ./pruned_transducer_stateless6/exp
fi

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/__init__.py

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large

View File

@@ -0,0 +1,634 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless6/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless6/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless6/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless6/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless6/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless6/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless6/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless6/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless6/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
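# Note: unlike most other recipes, the stateless6 encoder returns outputs
# from multiple layers (used for distillation during training); the last
# element is the regular encoder output.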
layer_results, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out = layer_results[-1]
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@@ -0,0 +1,217 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless6/export.py \
--exp-dir ./pruned_transducer_stateless6/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless6/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless6/decode.py \
--exp-dir ./pruned_transducer_stateless6/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to("cpu")
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,80 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import torch
from vq_utils import CodebookIndexExtractor
from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--exp-dir",
type=Path,
default="pruned_transducer_stateless6/exp/",
help="The experiment dir",
)
return parser
def get_world_size():
warn_message = (
"It's better to use a GPU to extract codebook indexes. "
"Please set it with a command like: export CUDA_VISIBLE_DEVICES=0,1,2,3"
)
assert (
torch.cuda.is_available() and "CUDA_VISIBLE_DEVICES" in os.environ
), warn_message
world_size = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
assert world_size > 0, warn_message
return world_size
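# Example invocation (a sketch; distillation_with_hubert.sh stage 2 calls
# this script with the same flags):
#   export CUDA_VISIBLE_DEVICES="0,1"
#   ./pruned_transducer_stateless6/extract_codebook_index.py --full-libri False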
def main():
world_size = get_world_size()
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
HubertXlargeFineTuned.add_arguments(parser)
CodebookIndexExtractor.add_arguments(parser)
args = parser.parse_args()
params = AttributeDict()
params.update(vars(args))
# reset some parameters needed by hubert.
params.update(HubertXlargeFineTuned.get_params())
params.device = torch.device("cuda", 0)
params.world_size = world_size
extractor = CodebookIndexExtractor(params=params)
extractor.extract_and_save_embedding()
extractor.train_quantizer()
extractor.extract_codebook_indexes()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,205 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--exp-dir",
type=Path,
default="pruned_transducer_stateless6/exp/",
help="The experiment dir",
)
return parser
def decode_dataset(
dl: torch.utils.data.DataLoader,
hubert_model: HubertXlargeFineTuned,
params: AttributeDict,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
hubert_model:
The wrapped hubert model used for ctc_greedy_search.
params:
It is returned by :func:`get_params`.
Returns:
Return a dict, whose key is decoding method "ctc_greedy_search".
Its value is a list of tuples.
Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
hyps = hubert_model.ctc_greedy_search(batch)
texts = batch["supervisions"]["text"]
assert len(hyps) == len(texts)
this_batch = []
for hyp_text, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
hyp_words = hyp_text.split()
this_batch.append((ref_words, hyp_words))
results["ctc_greedy_search"].extend(this_batch)
num_cuts += len(texts)
if batch_idx % 20 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
HubertXlargeFineTuned.add_arguments(parser)
args = parser.parse_args()
params = AttributeDict()
params.update(vars(args))
# reset some parameters needed by hubert.
params.update(HubertXlargeFineTuned.get_params())
params.res_dir = (
params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}"
)
setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
params.device = device
hubert_model = HubertXlargeFineTuned(params)
librispeech = LibriSpeechAsrDataModule(params)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
hubert_model=hubert_model,
params=params,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,220 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import torch
from fairseq import (
checkpoint_utils,
tasks,
utils,
)
from fairseq.data.data_utils import post_process
from omegaconf import OmegaConf
from icefall.utils import AttributeDict
def _load_hubert_model(params: AttributeDict):
"""
Load the hubert model.
The model loaded is specified by params.hubert_model_dir
and params.teacher_model_id.
The returned model carries hubert,
while the processor is responsible for mapping the model's output to human-readable transcripts.
"""
cfg_task = OmegaConf.create(
{
"_name": "hubert_pretraining",
"single_target": True,
"fine_tuning": True,
"data": str(params.hubert_model_dir),
}
)
model_path = Path(params.hubert_model_dir) / (
params.teacher_model_id + ".pt"
)
task = tasks.setup_task(cfg_task)
processor = task.target_dictionary
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(str(model_path), separator="\\"),
arg_overrides={},
strict=True,
suffix="",
num_shards=1,
)
model = models[0]
model.to(params.device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
return model, processor
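# A minimal usage sketch (assumes `params` carries device, hubert_model_dir
# and teacher_model_id, as set up by get_params()/add_arguments below, and
# that `batch` comes from LibriSpeechAsrDataModule; see hubert_decode.py):
#
#   params = AttributeDict()
#   params.update(HubertXlargeFineTuned.get_params())
#   params.device = torch.device("cuda", 0)
#   teacher = HubertXlargeFineTuned(params)
#   hyps = teacher.ctc_greedy_search(batch)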
class HubertXlargeFineTuned:
"""
A wrapper of the hubert extra-large fine-tuned model.
A teacher model is responsible for:
1. loading the teacher model;
2. extracting embeddings to train the quantizer;
3. extracting codebook indexes;
4. verifying its performance with the ctc_greedy_search method.
"""
def __init__(self, params: AttributeDict):
self.model, self.processor = _load_hubert_model(params)
self.w2v_model = self.model.w2v_encoder.w2v_model
self.params = params
@staticmethod
def get_params() -> AttributeDict:
"""Return a dict containing parameters defined in other modules.
Their default values conflict with hubert's requirements, so they are reset as follows.
"""
params = AttributeDict(
{
# parameters defined in asr_datamodule.py
"input_strategy": "AudioSamples",
"enable_musan": False,
"enable_spec_aug": False,
"return_cuts": True,
"drop_last": False,
# parameters used by quantizer
"embedding_dim": 1280,
}
)
return params
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
# Options about model loading.
parser.add_argument(
"--hubert-model-dir",
type=Path,
default="./pruned_transducer_stateless6/exp/hubert_models/",
help="path to save downloaded hubert models.",
)
parser.add_argument(
"--teacher-model-id",
type=str,
default="hubert_xtralarge_ll60k_finetune_ls960",
help="""could be one of:
[
"hubert_xtralarge_ll60k_finetune_ls960", # fine-tuned model.
"hubert_xtralarge_ll60k.pt", # pretrained model without fintuing.
]""",
)
parser.add_argument(
"--total-layers",
type=int,
default=48,
)
# Modified from HubertModel.forward to extract all middle layers output
def extract_layers_result(
self,
batch: Dict,
) -> List[torch.Tensor]:
"""
Extract activations from all layers.
"""
features = batch["inputs"]
# corresponds to task.normalize in fairseq
features = torch.nn.functional.layer_norm(features, features.shape)
supervisions = batch["supervisions"]
num_samples = supervisions["num_samples"]
B, T = features.shape
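# padding_mask is True at padded positions, i.e. at sample indexes
# beyond each utterance's num_samples.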
padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape(
[-1, 1]
)
padding_mask = padding_mask.to(self.params.device)
features = features.to(self.params.device)
features = self.w2v_model.forward_features(features)
features = features.transpose(1, 2)
features = self.w2v_model.layer_norm(features)
padding_mask = self.w2v_model.forward_padding_mask(
features, padding_mask
)
if self.w2v_model.post_extract_proj is not None:
features = self.w2v_model.post_extract_proj(features)
_, layer_results = self.w2v_model.encoder(
features,
padding_mask=padding_mask,
)
return layer_results
def extract_embedding(self, batch) -> Tuple[torch.tensor, List[int]]:
"""
Extract embeddings from the layer specified by self.params.embedding_layer.
These embeddings can be used to train the quantizer
or to extract codebook indexes.
The returned List[int] contains the valid length of each embedding.
We only want to store codebook indexes related to
these valid embeddings.
"""
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
assert all(c.start == 0 for c in cut_list)
layer_results = self.extract_layers_result(batch)
embeddings = layer_results[self.params.embedding_layer - 1][0]
encoder_embedding = embeddings.transpose(0, 1) # N, T, C
N = encoder_embedding.shape[0]
assert len(cut_list) == N
# 320 is from: 16,000 / 50 = sample_rate / hubert output frame rate
num_frames = (supervisions["num_samples"] // 320).tolist()
return encoder_embedding, num_frames
def ctc_greedy_search(self, batch):
"""
Mainly used to verify that the hubert model is used correctly.
"""
layer_results = self.extract_layers_result(batch=batch)
encoder_out = self.w2v_model.encoder.layer_norm(
layer_results[self.params.total_layers - 1][0]
)
encoder_out = self.model.w2v_encoder.proj(encoder_out.transpose(0, 1))
toks = encoder_out.argmax(dim=-1)
blank = 0
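# Standard CTC greedy decoding: collapse repeated tokens,
# then remove blanks before mapping ids back to letters.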
toks = [tok.unique_consecutive() for tok in toks]
hyps = [
self.processor.string(tok[tok != blank].int().cpu()) for tok in toks
]
hyps = [post_process(hyp, "letter") for hyp in hyps]
return hyps

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@@ -0,0 +1,249 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from quantization.prediction import JointCodebookLoss
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
num_codebooks: int = 0,
):
"""
Args:
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output
contains unnormalized probs, i.e., not processed by log-softmax.
num_codebooks:
Used by distillation loss.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if num_codebooks > 0:
self.codebook_loss_net = JointCodebookLoss(
predictor_channels=encoder_dim, num_codebooks=num_codebooks
)
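# A note on the distillation branch (as used in forward() below): the
# codebook loss net takes the student's middle-layer output together with
# the teacher's codebook indexes and returns a frame-wise prediction loss;
# see concat_successive_codebook_indexes for how the frame rates are matched.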
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
codebook_indexes: torch.Tensor = None,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
warmup:
A value warmup >= 0 that determines which modules are active; for values
warmup > 1 the model is "fully warmed up" and all modules will be active.
codebook_indexes:
codebook_indexes extracted from a teacher model.
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
encoder_out = layer_results[-1]
middle_layer_output = layer_results[0]
if self.training and codebook_indexes is not None:
assert hasattr(self, "codebook_loss_net")
if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
codebook_indexes = self.concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
)
codebook_loss = self.codebook_loss_net(
middle_layer_output, codebook_indexes
)
else:
# when codebook index is not available.
codebook_loss = None
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss, codebook_loss)
@staticmethod
def concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
):
# The output rate of hubert is 50 frames per second,
# while that of the current encoder is 25.
# The following code handles two issues:
# 1.
# Roughly speaking, to generate one extra output frame,
# hubert needs two extra input frames,
# while the current encoder needs four extra frames.
# So if only three extra frames are provided,
# hubert will generate one more frame while the current encoder does nothing.
# 2.
# The codebook loss is a frame-wise loss. To enable the 25-frame student output
# to learn from the 50-frame teacher output, two successive frames of the
# teacher output are concatenated together.
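# A worked example with hypothetical shapes: given codebook_indexes of
# shape (N, T, C) = (1, 103, 8) and a student output with t_expected = 50
# frames, the indexes are truncated to T = 100 and reshaped to (1, 50, 16),
# so each student frame learns from two successive teacher frames.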
t_expected = middle_layer_output.shape[1]
N, T, C = codebook_indexes.shape
# Handling issue 1.
if T >= t_expected * 2:
codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
# Handling issue 2.
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
return codebook_indexes

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

View File

@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless6/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.enable_distiallation = False
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

File diff suppressed because it is too large

View File

@@ -0,0 +1,399 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import glob
import logging
import os
from functools import cached_property
from pathlib import Path
from typing import List, Tuple
import numpy as np
import torch
import torch.multiprocessing as mp
import quantization
from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import (
AttributeDict,
setup_logger,
)
from lhotse import CutSet, load_manifest
from lhotse.features.io import NumpyHdf5Writer
class CodebookIndexExtractor:
"""
A wrapper of quantization.Quantizer.
It's responsible for:
1. extracting and saving activations from a teacher model;
2. training a quantizer from previous activations;
3. extracting codebook indexes for the whole training set.
Normally this step needs multiple GPUs.
"""
def __init__(self, params: AttributeDict):
self.params = params
params.subsets = ["clean-100"]
if self.params.full_libri:
self.params.subsets += ["clean-360", "other-500"]
self.init_dirs()
setup_logger(f"{self.vq_dir}/log-vq_extraction")
def init_dirs(self):
# vq_dir is the root dir for the quantizer:
# training data / quantizer / extracted codebook indexes
self.vq_dir = (
self.params.exp_dir / f"vq/{self.params.teacher_model_id}/"
)
self.vq_dir.mkdir(parents=True, exist_ok=True)
# manifest_dir is for:
# split original manifests,
# extracted codebook indexes and their related manifests
self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}"
self.manifest_dir.mkdir(parents=True, exist_ok=True)
# It doesn't matter whether ori_manifest_dir is str or Path.
# Set it to Path to be consistent.
self.ori_manifest_dir = Path("./data/fbank/")
self.dst_manifest_dir = Path("./data/vq_fbank/")
self.dst_manifest_dir.mkdir(parents=True, exist_ok=True)
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
# Options about teacher embedding extraction.
parser.add_argument(
"--embedding-layer",
type=int,
help="layer to extract teacher embeddings, 1-based.",
default=36,
)
parser.add_argument(
"--num-utts",
type=int,
default=1000,
help="num utts to train quantizer",
)
parser.add_argument(
"--num-codebooks",
type=int,
default=8,
help="""number of codebooks,
i.e. number of codebook indexes each teacher embedding is compressed.
""",
)
@property
def embedding_file_path(self):
"""
The saved embeddings are used to train the quantizer.
"""
embedding_file_id = (
f"num_utts_{self.params.num_utts}"
+ f"-layer_{self.params.embedding_layer}"
+ "-embedding_embeddings.h5"
)
embedding_file_path = self.vq_dir / embedding_file_id
return embedding_file_path
@torch.no_grad()
def extract_and_save_embedding(self):
"""
The extracted embeddings are used to train the quantizer.
"""
if self.embedding_file_path.exists():
warn_message = (
f"{self.embedding_file_path} already exists."
+ " Skip extracting embeddings from teacher model"
)
logging.warn(warn_message)
return
total_cuts = 0
with NumpyHdf5Writer(self.embedding_file_path) as writer:
for batch_idx, batch in enumerate(self.quantizer_train_dl):
cut_list = batch["supervisions"]["cut"]
(
encoder_embedding,
num_frames,
) = self.teacher_model.extract_embedding(batch)
encoder_embedding = encoder_embedding.cpu().numpy()
for idx, cut in enumerate(cut_list):
cut.encoder_embedding = writer.store_array(
key=cut.id,
value=encoder_embedding[idx][: num_frames[idx]],
)
total_cuts += len(cut_list)
logging.info(
f"Processed {total_cuts} output of {self.params.num_utts} cuts."
)
logging.info(f"Processed all {total_cuts} cuts.")
@property
def quantizer_train_dl(self):
# used to train quantizer.
librispeech = LibriSpeechAsrDataModule(self.params)
quantizer_train_cuts = librispeech.train_clean_100_cuts().subset(
first=self.params.num_utts
)
return librispeech.train_dataloaders(quantizer_train_cuts)
@cached_property
def quantizer_file_path(self):
quantizer_file_id = (
f"num_utts-{self.params.num_utts}"
+ f"-layer-{self.params.embedding_layer}"
+ f"-num_codebooks_{self.params.num_codebooks}"
+ "-quantizer.pt"
)
quantizer_file_path = Path(self.vq_dir) / quantizer_file_id
return quantizer_file_path
def train_quantizer(self):
if self.quantizer_file_path.exists():
warn_message = (
f"{self.quantizer_file_path} already exists."
+ " Skip trainning quantizer."
)
logging.warn(warn_message)
return
assert self.embedding_file_path.exists()
trainer = quantization.QuantizerTrainer(
dim=self.params.embedding_dim,
bytes_per_frame=self.params.num_codebooks,
device=self.params.device,
)
train, valid = quantization.read_hdf5_data(self.embedding_file_path)
B = 512 # Minibatch size. This is fairly arbitrary; it's close to what
# we used when we tuned this method.
def minibatch_generator(data: torch.Tensor, repeat: bool):
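# Yields minibatches of exactly B frames. With repeat=True it cycles
# through the data indefinitely (the trainer decides when to stop);
# with repeat=False it makes a single pass. The modulo keeps `start`
# in [0, data.shape[0] - B], so every yielded batch is full-sized.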
assert 3 * B < data.shape[0]
cur_offset = 0
while repeat or cur_offset + B <= data.shape[0]:
start = cur_offset % (data.shape[0] + 1 - B)
end = start + B
cur_offset += B
yield data[start:end, :].to(self.params.device).to(
dtype=torch.float
)
for x in minibatch_generator(train, repeat=True):
trainer.step(x)
if trainer.done():
break
quantizer = trainer.get_quantizer()
torch.save(quantizer.state_dict(), self.quantizer_file_path)
def split_ori_manifests(self):
"""
When multiple GPUs are available, split the original manifests
so that codebook indexes can be extracted in parallel.
"""
for subset in self.params.subsets:
logging.info(f"About to split {subset}.")
ori_manifest = f"./data/fbank/cuts_train-{subset}.json.gz"
split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}"
os.system(f"{split_cmd}")
def merge_vq_manifests(self):
"""
Merge the generated manifests (which include codebook indexes) and store them in self.dst_manifest_dir.
"""
for subset in self.params.subsets:
vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-cuts_train-{subset}*.json.gz"
dst_vq_manifest = (
self.dst_manifest_dir / f"cuts_train-{subset}.json.gz"
)
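# With a single job there is exactly one piece per subset, so a plain
# copy suffices; otherwise the pieces are concatenated by "lhotse combine".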
if self.params.world_size == 1:
merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}"
else:
merge_cmd = f"lhotse combine {vq_manifests} {dst_vq_manifest}"
os.system(f"{merge_cmd}")
def reuse_manifests(self):
"""
Codebook indexes are extracted only for the train-* subsets.
The rest of the subsets are just symlinked from ./data/fbank.
"""
def is_train(manifest: str) -> bool:
for train_subset in ["clean-100", "clean-360", "other-500"]:
if train_subset in manifest:
return True
return False
# Type of self.ori_manifest_dir is Path
# and result type of glob.glob is str.
reusable_manifests = [
manifest
for manifest in glob.glob(f"{self.ori_manifest_dir}/*.gz")
if not is_train(manifest)
]
for manifest_path in reusable_manifests:
ori_manifest_path = Path(manifest_path).resolve()
# A Path cannot be used as a parameter of str.replace.
# Cast them to str.
dst_manifest_path = Path(
manifest_path.replace(
str(self.ori_manifest_dir), str(self.dst_manifest_dir)
)
).resolve()
if not dst_manifest_path.exists():
os.symlink(ori_manifest_path, dst_manifest_path)
def create_vq_fbank(self):
self.reuse_manifests()
self.merge_vq_manifests()
@cached_property
def teacher_model(self):
return HubertXlargeFineTuned(self.params)
@cached_property
def quantizer(self):
assert self.quantizer_file_path.exists()
quantizer = quantization.Quantizer(
dim=self.params.embedding_dim,
num_codebooks=self.params.num_codebooks,
codebook_size=256,
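# A codebook size of 256 means each index fits in one byte, matching
# bytes_per_frame=num_codebooks in train_quantizer() and the
# "max < 256" assertion in extract_codebook_indexes_imp().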
)
quantizer.load_state_dict(torch.load(self.quantizer_file_path))
quantizer.to(self.params.device)
return quantizer
def load_ori_dl(self, subset):
if self.params.world_size == 1:
ori_manifest_path = f"./data/fbank/cuts_train-{subset}.json.gz"
else:
ori_manifest_path = (
self.manifest_dir
/ f"cuts_train-{subset}.{self.params.manifest_index}.json.gz"
)
cuts = load_manifest(ori_manifest_path)
dl = LibriSpeechAsrDataModule(self.params).train_dataloaders(cuts)
return dl
def _release_gpu_memory(self):
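# teacher_model and quantizer are cached_property values stored in
# self.__dict__; popping them drops the references so the tensors can
# be freed and their CUDA memory reclaimed by empty_cache().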
self.__dict__.pop("teacher_model", None)
self.__dict__.pop("quantizer", None)
torch.cuda.empty_cache()
def extract_codebook_indexes(self):
if self.params.world_size == 1:
self.extract_codebook_indexes_imp()
else:
# Since a new extractor will be created for each rank in
# compute_codebook_indexes_parallel, it's better to
# release the GPU memory occupied by current extractor.
self._release_gpu_memory()
# Prepare split manifests for each job.
self.split_ori_manifests()
mp.spawn(
compute_codebook_indexes_parallel,
args=(self.params,),
nprocs=self.params.world_size,
join=True,
)
self.create_vq_fbank()
@torch.no_grad()
def extract_codebook_indexes_imp(self):
for subset in self.params.subsets:
num_cuts = 0
cuts = []
if self.params.world_size == 1:
manifest_file_id = f"{subset}"
else:
manifest_file_id = f"{subset}-{self.params.manifest_index}"
manifest_file_path = self.manifest_dir / manifest_file_id
with NumpyHdf5Writer(manifest_file_path) as writer:
for batch_idx, batch in enumerate(self.load_ori_dl(subset)):
(
encoder_embedding,
num_frames,
) = self.teacher_model.extract_embedding(batch)
codebook_indexes = self.quantizer.encode(encoder_embedding)
# [N, T, C]
codebook_indexes = codebook_indexes.to("cpu").numpy()
assert np.min(codebook_indexes) >= 0
assert np.max(codebook_indexes) < 256
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
assert len(cut_list) == codebook_indexes.shape[0]
assert all(c.start == 0 for c in supervisions["cut"])
for idx, cut in enumerate(cut_list):
cut.codebook_indexes = writer.store_array(
key=cut.id,
value=codebook_indexes[idx][: num_frames[idx]],
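# The 0.02 s frame shift assumes the 20 ms stride of the
# HuBERT front-end (320 samples at 16 kHz).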
frame_shift=0.02,
temporal_dim=0,
start=0,
)
cuts += cut_list
num_cuts += len(cut_list)
message = f"Processed {num_cuts} cuts from {subset}"
if self.params.world_size > 1:
message += f" by job {self.params.manifest_index}"
logging.info(f"{message}.")
json_file_path = (
self.manifest_dir
/ f"with_codebook_indexes-cuts_train-{manifest_file_id}.json.gz"
)
CutSet.from_cuts(cuts).to_json(json_file_path)
@torch.no_grad()
def compute_codebook_indexes_parallel(
rank: int,
params,
) -> List[Tuple[str, List[int]]]:
"""Create an extractor for each rank and extract codebook indexes parallelly.
Normally, this function is called by torch.multiprocessing
when multi GPUs are available.
"""
params = copy.deepcopy(params)
device = torch.device("cuda", rank)
params.device = device
# rank is 0-based while the manifests produced by "lhotse split" are 1-based.
params.manifest_index = rank + 1
extractor = CodebookIndexExtractor(params=params)
extractor.extract_codebook_indexes_imp()


@ -25,7 +25,7 @@ from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
BucketingSampler,
CutConcatenate,
CutMix,
@ -34,7 +34,10 @@ from lhotse.dataset import (
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
@ -150,6 +153,12 @@ class LibriSpeechAsrDataModule:
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
@ -192,6 +201,13 @@ class LibriSpeechAsrDataModule:
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
@ -263,6 +279,7 @@ class LibriSpeechAsrDataModule:
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
@ -296,7 +313,7 @@ class LibriSpeechAsrDataModule:
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
@ -371,7 +388,7 @@ class LibriSpeechAsrDataModule:
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = BucketingSampler(


@ -127,7 +127,11 @@ def setup_logger(
level = logging.CRITICAL
logging.basicConfig(
filename=log_filename, format=formatter, level=level, filemode="w"
filename=log_filename,
format=formatter,
level=level,
filemode="w",
force=True,
)
if use_console:
console = logging.StreamHandler()