Use k2 pruned transducer loss to train conformer-transducer model (#194)

* Using k2 pruned version transducer loss to train model

* Fix style

* Minor fixes
This commit is contained in:
Wei Kang 2022-02-17 13:33:54 +08:00 committed by GitHub
parent e8eb408760
commit b702281e90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 2750 additions and 1 deletions

View File

@ -6,6 +6,8 @@ per-file-ignores =
# line too long # line too long
egs/librispeech/ASR/*/conformer.py: E501, egs/librispeech/ASR/*/conformer.py: E501,
egs/aishell/ASR/*/conformer.py: E501, egs/aishell/ASR/*/conformer.py: E501,
# invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605
exclude = exclude =
.git, .git,

View File

@ -1,5 +1,53 @@
## Results ## Results
### LibriSpeech BPE training results (Pruned Transducer)
#### Conformer encoder + embedding decoder
Conformer encoder + non-current decoder. The decoder
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
layer (to transform tensor dim).
The WERs are
| | test-clean | test-other | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.85 | 6.98 | --epoch 28, --avg 15, --max-duration 100 |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--max-duration 300 \
--prune-range 5 \
--lr-factor 5 \
--lm-scale 0.25 \
```
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/ejG7VpakRYePNNj6AbDEUw/#scalars>
The decoding command is:
```
epoch=28
avg=15
## greedy search
./pruned_transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless/exp \
--max-duration 100
```
### LibriSpeech BPE training results (Transducer) ### LibriSpeech BPE training results (Transducer)
#### Conformer encoder + embedding decoder #### Conformer encoder + embedding decoder

View File

@ -0,0 +1 @@
../transducer/asr_datamodule.py

View File

@ -0,0 +1,340 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import torch
from model import Transducer
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
) -> List[int]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
max_sym_per_frame:
Maximum number of symbols per frame. If it is set to 0, the WER
would be 100%.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1)
t = 0
hyp = [blank_id] * context_size
# Maximum symbols per utterance.
max_sym_per_utt = 1000
# symbols per frame
sym_per_frame = 0
# symbols per utterance decoded so far
sym_per_utt = 0
while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0
t += 1
continue
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y != blank_id:
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
sym_per_utt += 1
sym_per_frame += 1
else:
sym_per_frame = 0
t += 1
hyp = hyp[context_size:] # remove blanks
return hyp
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys
log_prob: float
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
"""
Args:
data:
A dict of Hypotheses. Its key is its `value.key`.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self):
return self._data
def add(self, hyp: Hypothesis):
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existed one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key]
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
else:
self._data[key] = hyp
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
"""
if length_norm:
return max(
self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
)
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: float) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
that have `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for key, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
def topk(self, k: int) -> "HypothesisList":
"""Return the top-k hypothesis."""
hyps = list(self._data.items())
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str):
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for key in self:
s.append(key)
return ", ".join(s)
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1)
t = 0
B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
max_sym_per_utt = 20000
sym_per_utt = 0
decoder_cache: Dict[str, torch.Tensor] = {}
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
A = B
B = HypothesisList()
joint_cache: Dict[str, torch.Tensor] = {}
# TODO(fangjun): Implement prefix search to update the `log_prob`
# of hypotheses in A
while True:
y_star = A.get_most_probable()
A.remove(y_star)
cached_key = y_star.key
if cached_key not in decoder_cache:
decoder_input = torch.tensor(
[y_star.ys[-context_size:]], device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_cache[cached_key] = decoder_out
else:
decoder_out = decoder_cache[cached_key]
cached_key += f"-t-{t}"
if cached_key not in joint_cache:
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1)
)
# TODO(fangjun): Cache the blank posterior
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
joint_cache[cached_key] = log_prob
else:
log_prob = joint_cache[cached_key]
# First, process the blank symbol
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()):
if i == blank_id:
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
# Check whether B contains more than "beam" elements more probable
# than the most probable in A
A_most_probable = A.get_most_probable()
kept_B = B.filter(A_most_probable.log_prob)
if len(kept_B) >= beam:
B = kept_B.topk(beam)
break
t += 1
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys

View File

@ -0,0 +1 @@
../transducer_stateless/conformer.py

View File

@ -0,0 +1,476 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --decoding-method is beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.vocab_size,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.vocab_size,
inner_dim=params.embedding_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
else:
return {f"beam_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search")
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,100 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
embedding_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
embedding_dim:
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=embedding_dim,
out_channels=embedding_dim,
kernel_size=context_size,
padding=0,
groups=embedding_dim,
bias=False,
)
self.output_linear = nn.Linear(embedding_dim, vocab_size)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U) with blank prepended.
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
embedding_out = self.embedding(y)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = self.output_linear(F.relu(embedding_out))
return embedding_out

View File

@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,252 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless/export.py \
--exp-dir ./pruned_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless/decode.py \
--exp-dir ./pruned_transducer_stateless/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.vocab_size,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.vocab_size,
inner_dim=params.embedding_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,50 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
super().__init__()
self.inner_linear = nn.Linear(input_dim, inner_dim)
self.output_linear = nn.Linear(inner_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 4
assert encoder_out.shape == decoder_out.shape
logit = encoder_out + decoder_out
logit = self.inner_linear(torch.tanh(logit))
output = self.output_linear(F.relu(logit))
return output

View File

@ -0,0 +1,169 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, C]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out,
am=encoder_out,
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, C]
# lm_pruned : [B, T, prune_range, C]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=encoder_out, lm=decoder_out, ranges=ranges
)
# logits : [B, T, prune_range, C]
logits = self.joiner(am_pruned, lm_pruned)
pruned_loss = k2.rnnt_loss_pruned(
logits=logits,
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)

View File

@ -0,0 +1,326 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
(1) beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by
./pruned_transducer_stateless/export.py
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.vocab_size,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.vocab_size,
inner_dim=params.embedding_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../transducer/subsampling.py

View File

@ -0,0 +1,58 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless/test_decoder.py
"""
import torch
from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
embedding_dim = 128
context_size = 4
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
context_size=context_size,
)
N = 100
U = 20
x = torch.randint(low=0, high=vocab_size, size=(N, U))
y = decoder(x)
assert y.shape == (N, U, vocab_size)
# for inference
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
y = decoder(x, need_pad=False)
assert y.shape == (N, 1, vocab_size)
def main():
test_decoder()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,830 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--max-duration 300
"""
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
measure_gradient_norms,
measure_weight_norms,
optim_step_and_measure_param_change,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_stateless/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lr-factor",
type=float,
default=5.0,
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
"log_diagnostics": False,
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
# parameters for Noam
"warm_step": 80000, # For the 100h subset, use 30000
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.vocab_size,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.vocab_size,
inner_dim=params.embedding_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
loss = params.simple_loss_scale * simple_loss + pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
def maybe_log_gradients(tag: str):
if (
params.log_diagnostics
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
tb_writer.add_scalars(
tag,
measure_gradient_norms(model, norm="l2"),
global_step=params.batch_idx_train,
)
def maybe_log_weights(tag: str):
if (
params.log_diagnostics
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
tb_writer.add_scalars(
tag,
measure_weight_norms(model, norm="l2"),
global_step=params.batch_idx_train,
)
def maybe_log_param_relative_changes():
if (
params.log_diagnostics
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
deltas = optim_step_and_measure_param_change(model, optimizer)
tb_writer.add_scalars(
"train/relative_param_change_per_minibatch",
deltas,
global_step=params.batch_idx_train,
)
else:
optimizer.step()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
loss.backward()
maybe_log_weights("train/param_norms")
maybe_log_gradients("train/grad_norms")
maybe_log_param_relative_changes()
optimizer.zero_grad()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 800
params.warm_step = 30000
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../transducer_stateless/transformer.py

View File

@ -25,13 +25,15 @@ from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union from typing import Dict, Iterable, List, TextIO, Optional, Tuple, Union
import k2 import k2
import k2.version import k2.version
import kaldialign import kaldialign
import torch import torch
import torch.nn as nn
import torch.distributed as dist import torch.distributed as dist
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
Pathlike = Union[str, Path] Pathlike = Union[str, Path]
@ -690,3 +692,94 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths) expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
return expaned_lengths >= lengths.unsqueeze(1) return expaned_lengths >= lengths.unsqueeze(1)
def l1_norm(x):
return torch.sum(torch.abs(x))
def l2_norm(x):
return torch.sum(torch.pow(x, 2))
def linf_norm(x):
return torch.max(torch.abs(x))
def measure_weight_norms(
model: nn.Module, norm: str = "l2"
) -> Dict[str, float]:
"""
Compute the norms of the model's parameters.
:param model: a torch.nn.Module instance
:param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
:return: a dict mapping from parameter's name to its norm.
"""
with torch.no_grad():
norms = {}
for name, param in model.named_parameters():
if norm == "l1":
val = l1_norm(param)
elif norm == "l2":
val = l2_norm(param)
elif norm == "linf":
val = linf_norm(param)
else:
raise ValueError(f"Unknown norm type: {norm}")
norms[name] = val.item()
return norms
def measure_gradient_norms(
model: nn.Module, norm: str = "l1"
) -> Dict[str, float]:
"""
Compute the norms of the gradients for each of model's parameters.
:param model: a torch.nn.Module instance
:param norm: how to compute the norm. Available values: 'l1', 'l2', 'linf'
:return: a dict mapping from parameter's name to its gradient's norm.
"""
with torch.no_grad():
norms = {}
for name, param in model.named_parameters():
if norm == "l1":
val = l1_norm(param.grad)
elif norm == "l2":
val = l2_norm(param.grad)
elif norm == "linf":
val = linf_norm(param.grad)
else:
raise ValueError(f"Unknown norm type: {norm}")
norms[name] = val.item()
return norms
def optim_step_and_measure_param_change(
model: nn.Module,
optimizer: torch.optim.Optimizer,
scaler: Optional[GradScaler] = None,
) -> Dict[str, float]:
"""
Perform model weight update and measure the "relative change in parameters per minibatch."
It is understood as a ratio between the L2 norm of the difference between original and updates parameters,
and the L2 norm of the original parameter. It is given by the formula:
.. math::
\begin{aligned}
\delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
\end{aligned}
"""
param_copy = {n: p.detach().clone() for n, p in model.named_parameters()}
if scaler:
scaler.step(optimizer)
else:
optimizer.step()
relative_change = {}
with torch.no_grad():
for n, p_new in model.named_parameters():
p_orig = param_copy[n]
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
relative_change[n] = delta.item()
return relative_change