Update tedlium3 transducer stateless

This commit is contained in:
Mingshuang Luo 2022-02-24 19:57:36 +08:00
parent 47e49a6663
commit 536ad2252e
16 changed files with 590 additions and 99 deletions

View File

@ -0,0 +1,18 @@
# Introduction
This recipe includes some different ASR models trained with TedLium3.
# Transducers
There are various folders containing the name `transducer` in this folder.
The following table lists the differences among them.
| | Encoder | Decoder |
|------------------------|-----------|--------------------|
| `transducer_stateless` | Conformer | Embedding + Conv1d |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.

View File

@ -0,0 +1,68 @@
## Results
### TedLium3 BPE training results (Transducer)
#### Conformer encoder + embedding decoder
Using the codes from this commit .
Conformer encoder + non-current decoder. The decoder
contains only an embedding layer and a Conv1d (with kernel size 2).
The WERs are
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| greedy search | 7.31 | 6.73 | --epoch 71, --avg 15, --max-duration 100 |
| beam search (beam size 4) | 7.12 | 6.58 | --epoch 71, --avg 15, --max-duration 100 |
| modified beam search (beam size 4) | 7.20 | 6.65 | --epoch 71, --avg 15, --max-duration 100 |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp \
--max-duration 180 \
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/DnRwoZF8RRyod4kkfG5q5Q/#scalars
The decoding command is:
```
epoch=29
avg=15
## greedy search
./transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless/exp \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100
## beam search
./transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless/exp \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
## modified beam search
./transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless/exp \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
```

View File

@ -17,7 +17,7 @@
"""
This file computes fbank features of the LibriSpeech dataset.
This file computes fbank features of the TedLium3 dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
@ -43,7 +43,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_librispeech():
def compute_fbank_tedlium():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -96,4 +96,4 @@ if __name__ == "__main__":
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_librispeech()
compute_fbank_tedlium()

View File

@ -0,0 +1,97 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
"""
Convert a transcript based on words to a list of BPE ids with the related BPE model.
For example, if we use 2 as the encoding id of <unk>, there are four examples:
texts = ['this is a <unk> day and in the room there are three <unk> laying in the bed']
spm_ids = [[38, 33, 6, 2, 316, 8, 16, 5, 257, 193, 103, 61, 331, 2, 196, 21, 14, 16, 5, 47, 12]]
texts = ['<unk> this is a sunny day and in the room there are three people in the <unk>']
spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316, 8, 16, 5, 257, 193, 103, 61, 331, 107, 16, 5, 2]]
texts = ['<unk>']
spm_ids = [[2]]
"""
import argparse
import logging
import sentencepiece as spm
from typing import List
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--texts", type=List[str], help="The input transcripts list."
)
parser.add_argument(
"--unk-id",
type=int,
default=2,
help="The number id for the token '<unk>'.",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
return parser.parse_args()
def convert_texts_into_ids(
texts: List[str],
unk_id: int,
sp: spm.SentencePieceProcessor,
) -> List[int]:
"""
Args:
texts:
A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
unk_id:
A number id for the token '<unk>'.
Returns:
Return a integer list of bpe ids.
"""
y = []
for text in texts:
y_ids = []
if "<unk>" in text:
text_segments = text.split("<unk>")
id_segments = sp.encode(text_segments, out_type=int)
for i in range(len(id_segments)):
if i != len(id_segments) - 1:
y_ids.extend(id_segments[i] + [unk_id])
else:
y_ids.extend(id_segments[i])
else:
y_ids = sp.encode([text], out_type=int)[0]
y.append(y_ids)
return y
def main():
args = get_args()
texts = args.texts
bpe_model = args.bpe_model
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
unk_id = sp.piece_to_id("<unk>")
y = convert_texts_into_ids(
texts=texts,
unk_id=unk_id,
sp=sp,
)
logging.info(f"The input texts: {texts}")
logging.info(f"The encoding ids: {y}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,89 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()` in transducer/train.py
for usage.
"""
import numpy as np
from lhotse import load_manifest
def describe(cuts) -> None:
"""
Print a message describing details about the ``CutSet`` - the number of cuts and the
duration statistics, including the total duration and the percentage of speech segments.
Example output:
Cuts count: 804789
Total duration (hours): 1370.6
Speech duration (hours): 1370.6 (100.0%)
***
Duration statistics (seconds):
mean 6.1
std 3.1
min 0.5
25% 3.7
50% 6.0
75% 8.3
99.5% 14.9
99.9% 16.6
max 33.3
In the above example, we set 15(>14.9) as the maximum duration of training samples.
"""
durations = np.array([c.duration for c in cuts])
speech_durations = np.array(
[s.duration for c in cuts for s in c.trimmed_supervisions]
)
total_sum = durations.sum()
speech_sum = speech_durations.sum()
print("Cuts count:", len(cuts))
print(f"Total duration (hours): {total_sum / 3600:.1f}")
print(
f"Speech duration (hours): {speech_sum / 3600:.1f} ({speech_sum / total_sum:.1%})"
)
print("***")
print("Duration statistics (seconds):")
print(f"mean\t{np.mean(durations):.1f}")
print(f"std\t{np.std(durations):.1f}")
print(f"min\t{np.min(durations):.1f}")
print(f"25%\t{np.percentile(durations, 25):.1f}")
print(f"50%\t{np.median(durations):.1f}")
print(f"75%\t{np.percentile(durations, 75):.1f}")
print(f"99.5%\t{np.percentile(durations, 99.5):.1f}")
print(f"99.9%\t{np.percentile(durations, 99.9):.1f}")
print(f"max\t{np.max(durations):.1f}")
def main():
path = "./data/fbank/cuts_train.json.gz"
# path = "./data/fbank/cuts_dev.json.gz"
# path = "./data/fbank/cuts_test.json.gz"
cuts = load_manifest(path)
describe(cuts)
if __name__ == "__main__":
main()

View File

@ -151,6 +151,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Generate data for BPE training"
cat data/lang_phone/train.text | cut -d " " -f 2-
> $lang_dir/transcript_words.txt
# remove the <unk> for transcript_words.txt
sed -i 's/ <unk>//g' $lang_dir/transcript_words.txt
sed -i 's/<unk> //g' $lang_dir/transcript_words.txt
sed -i 's/<unk>//g' $lang_dir/transcript_words.txt

View File

@ -7,7 +7,7 @@ https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
You can use the following command to start the training:
```bash
cd egs/librispeech/ASR
cd egs/tedlium3/ASR
export CUDA_VISIBLE_DEVICES="0,1,2,3"
@ -16,7 +16,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp \
--full-libri 1 \
--max-duration 250 \
--lr-factor 2.5
--max-duration 180 \
--lr-factor 5.0
```

View File

@ -1,4 +1,5 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#

View File

@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -17,7 +18,6 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import torch
from model import Transducer
@ -43,12 +43,13 @@ def greedy_search(
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = model.decoder.unk_id
context_size = model.decoder.context_size
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
[blank_id] * context_size, device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
@ -84,7 +85,7 @@ def greedy_search(
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y != blank_id:
if y != blank_id and y != unk_id:
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device
@ -108,8 +109,9 @@ class Hypothesis:
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys
log_prob: float
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
@property
def key(self) -> str:
@ -118,7 +120,7 @@ class Hypothesis:
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
@ -130,11 +132,10 @@ class HypothesisList(object):
self._data = data
@property
def data(self):
def data(self) -> Dict[str, Hypothesis]:
return self._data
# def add(self, ys: List[int], log_prob: float):
def add(self, hyp: Hypothesis):
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
@ -146,8 +147,10 @@ class HypothesisList(object):
"""
key = hyp.key
if key in self:
old_hyp = self._data[key]
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
old_hyp = self._data[key] # shallow copy
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
else:
self._data[key] = hyp
@ -159,7 +162,8 @@ class HypothesisList(object):
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(
@ -171,6 +175,9 @@ class HypothesisList(object):
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
@ -181,7 +188,7 @@ class HypothesisList(object):
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: float) -> "HypothesisList":
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
@ -189,10 +196,10 @@ class HypothesisList(object):
Returns:
Return a new HypothesisList containing all hypotheses from `self`
that have `log_prob` being greater than the given `threshold`.
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for key, hyp in self._data.items():
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
@ -222,6 +229,201 @@ class HypothesisList(object):
return ", ".join(s)
def run_decoder(
ys: List[int],
model: Transducer,
decoder_cache: Dict[str, torch.Tensor],
) -> torch.Tensor:
"""Run the neural decoder model for a given hypothesis.
Args:
ys:
The current hypothesis.
model:
The transducer model.
decoder_cache:
Cache to save computations.
Returns:
Return a 1-D tensor of shape (decoder_out_dim,) containing
output of `model.decoder`.
"""
context_size = model.decoder.context_size
key = "_".join(map(str, ys[-context_size:]))
if key in decoder_cache:
return decoder_cache[key]
device = model.device
decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_cache[key] = decoder_out
return decoder_out
def run_joiner(
key: str,
model: Transducer,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
encoder_out_len: torch.Tensor,
decoder_out_len: torch.Tensor,
joint_cache: Dict[str, torch.Tensor],
):
"""Run the joint network given outputs from the encoder and decoder.
Args:
key:
A key into the `joint_cache`.
model:
The transducer model.
encoder_out:
A tensor of shape (1, 1, encoder_out_dim).
decoder_out:
A tensor of shape (1, 1, decoder_out_dim).
encoder_out_len:
A tensor with value [1].
decoder_out_len:
A tensor with value [1].
joint_cache:
A dict to save computations.
Returns:
Return a tensor from the output of log-softmax.
Its shape is (vocab_size,).
"""
if key in joint_cache:
return joint_cache[key]
logits = model.joiner(
encoder_out,
decoder_out,
encoder_out_len,
decoder_out_len,
)
# TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
joint_cache[key] = log_prob
return log_prob
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = model.decoder.unk_id
context_size = model.decoder.context_size
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1)
B = HypothesisList()
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on
A = list(B)
B = HypothesisList()
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
# ys_log_probs is of shape (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyp in A],
device=device,
)
# decoder_input is of shape (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_output is of shape (num_hyps, 1, decoder_output_dim)
current_encoder_out = current_encoder_out.expand(
decoder_out.size(0), 1, -1
)
logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len.expand(decoder_out.size(0)),
decoder_out_len.expand(decoder_out.size(0)),
)
# logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1)
log_probs.add_(ys_log_probs)
log_probs = log_probs.reshape(-1)
topk_log_probs, topk_indexes = log_probs.topk(beam)
# topk_hyp_indexes are indexes into `A`
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[i]
if new_token != blank_id and new_token != unk_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
@ -247,6 +449,7 @@ def beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = model.decoder.unk_id
context_size = model.decoder.context_size
device = model.device
@ -261,7 +464,12 @@ def beam_search(
t = 0
B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
max_sym_per_utt = 20000
@ -281,58 +489,43 @@ def beam_search(
joint_cache: Dict[str, torch.Tensor] = {}
# TODO(fangjun): Implement prefix search to update the `log_prob`
# of hypotheses in A
while True:
y_star = A.get_most_probable()
A.remove(y_star)
cached_key = y_star.key
decoder_out = run_decoder(
ys=y_star.ys, model=model, decoder_cache=decoder_cache
)
if cached_key not in decoder_cache:
decoder_input = torch.tensor(
[y_star.ys[-context_size:]], device=device
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_cache[cached_key] = decoder_out
else:
decoder_out = decoder_cache[cached_key]
cached_key += f"-t-{t}"
if cached_key not in joint_cache:
logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len,
decoder_out_len,
)
# TODO(fangjun): Ccale the blank posterior
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
joint_cache[cached_key] = log_prob
else:
log_prob = joint_cache[cached_key]
key = "_".join(map(str, y_star.ys[-context_size:]))
key += f"-t-{t}"
log_prob = run_joiner(
key=key,
model=model,
encoder_out=current_encoder_out,
decoder_out=decoder_out,
encoder_out_len=encoder_out_len,
decoder_out_len=decoder_out_len,
joint_cache=joint_cache,
)
# First, process the blank symbol
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()):
if i == blank_id:
for idx in range(values.size(0)):
i = indices[idx].item()
if i == blank_id or i == unk_id:
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
new_log_prob = y_star.log_prob + values[idx]
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
# Check whether B contains more than "beam" elements more probable

View File

@ -615,7 +615,7 @@ class RelPositionMultiheadAttention(nn.Module):
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
""" # noqa
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
@ -635,7 +635,7 @@ class RelPositionMultiheadAttention(nn.Module):
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
@ -643,7 +643,7 @@ class RelPositionMultiheadAttention(nn.Module):
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
@ -653,7 +653,7 @@ class RelPositionMultiheadAttention(nn.Module):
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
@ -662,7 +662,7 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
@ -671,7 +671,7 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
@ -687,12 +687,12 @@ class RelPositionMultiheadAttention(nn.Module):
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format( # noqa
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask is deprecated. Use bool tensor instead." # noqa
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
@ -725,7 +725,7 @@ class RelPositionMultiheadAttention(nn.Module):
and key_padding_mask.dtype == torch.uint8
):
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." # noqa
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
@ -760,7 +760,7 @@ class RelPositionMultiheadAttention(nn.Module):
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
@ -832,7 +832,7 @@ class RelPositionMultiheadAttention(nn.Module):
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py # noqa
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -19,16 +20,16 @@
Usage:
(1) greedy search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--epoch 29 \
--avg 15 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--epoch 29 \
--avg 15 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--decoding-method beam_search \
@ -45,8 +46,8 @@ from typing import Dict, List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search
from asr_datamodule import TedLiumAsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@ -77,7 +78,7 @@ def get_parser():
parser.add_argument(
"--avg",
type=int,
default=13,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
@ -169,6 +170,7 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
@ -256,6 +258,10 @@ def decode_one_batch(
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
@ -382,14 +388,18 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
TedLiumAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search")
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
@ -413,6 +423,7 @@ def main():
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -439,16 +450,12 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
tedlium = TedLiumAsrDataModule(args)
test_cuts = tedlium.test_cuts()
test_dl = tedlium.test_dataloaders(test_cuts)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
test_sets = ["test"]
test_dl = [test_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(

View File

@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -37,6 +38,7 @@ class Decoder(nn.Module):
vocab_size: int,
embedding_dim: int,
blank_id: int,
unk_id: int,
context_size: int,
):
"""
@ -47,6 +49,8 @@ class Decoder(nn.Module):
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
unk_id:
The ID of the unk symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
@ -58,6 +62,7 @@ class Decoder(nn.Module):
padding_idx=blank_id,
)
self.blank_id = blank_id
self.unk_id = unk_id
assert context_size >= 1, context_size
self.context_size = context_size

View File

@ -120,7 +120,6 @@ class Transducer(nn.Module):
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
from_log_softmax=False,
)
return loss

View File

@ -50,7 +50,7 @@ import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import beam_search, greedy_search
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@ -167,6 +167,7 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
@ -230,6 +231,7 @@ def main():
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
@ -300,6 +302,10 @@ def main():
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(f"Unsupported method: {params.method}")

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -18,7 +19,7 @@
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
cd icefall/egs/tedlium3/ASR
python ./transducer_stateless/test_decoder.py
"""
@ -29,6 +30,7 @@ from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
unk_id = 2
embedding_dim = 128
context_size = 4
@ -36,6 +38,7 @@ def test_decoder():
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
unk_id=unk_id,
context_size=context_size,
)
N = 100

View File

@ -26,9 +26,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp \
--full-libri 1 \
--max-duration 250 \
--lr-factor 2.5
--max-duration 180 \
--lr-factor 5.0
"""
@ -56,6 +55,8 @@ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
@ -233,6 +234,7 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
@ -379,7 +381,9 @@ def compute_loss(
feature_lens = supervisions["num_frames"].to(device)[: feature.size(0)]
texts = batch["supervisions"]["text"][: feature.size(0)]
y = sp.encode(texts, out_type=int)
unk_id = params.unk_id
y = convert_texts_into_ids(texts, unk_id, sp=sp)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
@ -565,6 +569,7 @@ def run(rank, world_size, args):
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)