Update tedlium3 transducer stateless

Mingshuang Luo 2022-02-24 19:57:36 +08:00
parent 47e49a6663
commit 536ad2252e
16 changed files with 590 additions and 99 deletions

View File: egs/tedlium3/ASR/README.md

@@ -0,0 +1,18 @@
# Introduction
This recipe includes several ASR models trained on TedLium3.
# Transducers
There are various subfolders whose names contain `transducer` in this folder.
The following table lists the differences among them.
| | Encoder | Decoder |
|------------------------|-----------|--------------------|
| `transducer_stateless` | Conformer | Embedding + Conv1d |
The decoder in `transducer_stateless` is modified from the paper
[RNN-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
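
For intuition, here is a minimal sketch of such a stateless decoder; the class and dimension names are illustrative, not the recipe's exact code:

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding followed by a Conv1d over the last few tokens; no recurrence."""

    def __init__(self, vocab_size: int, embedding_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # The Conv1d (kernel size = context_size) mixes the embeddings of
        # the previous `context_size` tokens into one prediction vector.
        self.conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, context_size) token ids
        emb = self.embedding(y).permute(0, 2, 1)  # (N, embedding_dim, context_size)
        out = self.conv(emb).permute(0, 2, 1)     # (N, 1, embedding_dim)
        return torch.relu(out)
```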

View File: egs/tedlium3/ASR/RESULTS.md

@@ -0,0 +1,68 @@
## Results
### TedLium3 BPE training results (Transducer)
#### Conformer encoder + embedding decoder
Using the code from this commit.
Conformer encoder + non-recurrent decoder. The decoder
contains only an embedding layer and a Conv1d (with kernel size 2).
The WERs are
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| greedy search                      | 7.31       | 6.73       | --epoch 29, --avg 15, --max-duration 100 |
| beam search (beam size 4)          | 7.12       | 6.58       | --epoch 29, --avg 15, --max-duration 100 |
| modified beam search (beam size 4) | 7.20       | 6.65       | --epoch 29, --avg 15, --max-duration 100 |
The training command for reproducing is given below:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"

./transducer_stateless/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 0 \
  --exp-dir transducer_stateless/exp \
  --max-duration 180 \
  --lr-factor 5.0
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/DnRwoZF8RRyod4kkfG5q5Q/#scalars
The decoding command is:
```bash
epoch=29
avg=15

## greedy search
./transducer_stateless/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir transducer_stateless/exp \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  --max-duration 100 \
  --decoding-method greedy_search

## beam search
./transducer_stateless/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir transducer_stateless/exp \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  --max-duration 100 \
  --decoding-method beam_search \
  --beam-size 4

## modified beam search
./transducer_stateless/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir transducer_stateless/exp \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  --max-duration 100 \
  --decoding-method modified_beam_search \
  --beam-size 4
```
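
Note that all three methods skip the `<unk>` token during search. A simplified sketch of the per-frame greedy decision (the function name is illustrative, not the recipe's exact code):

```python
import torch


def greedy_step(
    log_probs: torch.Tensor, blank_id: int, unk_id: int, hyp: list
) -> bool:
    """One greedy-search step: emit the argmax token unless it is blank or <unk>.

    Returns True if a symbol was emitted, i.e. the decoder state must be updated.
    """
    y = int(log_probs.argmax())
    if y != blank_id and y != unk_id:
        hyp.append(y)
        return True
    return False
```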

View File: egs/tedlium3/ASR/local/compute_fbank_tedlium.py

@@ -17,7 +17,7 @@
"""
-This file computes fbank features of the LibriSpeech dataset.
+This file computes fbank features of the TedLium3 dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
@@ -43,7 +43,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)

-def compute_fbank_librispeech():
+def compute_fbank_tedlium():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@@ -96,4 +96,4 @@ if __name__ == "__main__":
    logging.basicConfig(format=formatter, level=logging.INFO)

-    compute_fbank_librispeech()
+    compute_fbank_tedlium()

View File: egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py

@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)

"""
Convert word-based transcripts to lists of BPE ids with the given BPE model.
For example, if we use 2 as the id of <unk>, here are three examples:

texts = ['this is a <unk> day and in the room there are three <unk> laying in the bed']
spm_ids = [[38, 33, 6, 2, 316, 8, 16, 5, 257, 193, 103, 61, 331, 2, 196, 21, 14, 16, 5, 47, 12]]

texts = ['<unk> this is a sunny day and in the room there are three people in the <unk>']
spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316, 8, 16, 5, 257, 193, 103, 61, 331, 107, 16, 5, 2]]

texts = ['<unk>']
spm_ids = [[2]]
"""

import argparse
import logging
from typing import List

import sentencepiece as spm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--texts",
        type=str,
        nargs="+",
        help="The input transcripts.",
    )
    parser.add_argument(
        "--unk-id",
        type=int,
        default=2,
        help="The integer id of the token '<unk>'.",
    )
    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model.",
    )
    return parser.parse_args()


def convert_texts_into_ids(
    texts: List[str],
    unk_id: int,
    sp: spm.SentencePieceProcessor,
) -> List[List[int]]:
    """
    Args:
      texts:
        A list of transcripts, such as ['Today is Monday', 'It is sunny'].
      unk_id:
        The integer id of the token '<unk>'.
      sp:
        A SentencePiece processor with the BPE model loaded.
    Returns:
      Return a list of BPE id lists, one per transcript.
    """
    y = []
    for text in texts:
        y_ids = []
        if "<unk>" in text:
            # Encode the segments around each <unk> separately and
            # splice the unk id in between them.
            text_segments = text.split("<unk>")
            id_segments = sp.encode(text_segments, out_type=int)
            for i in range(len(id_segments)):
                if i != len(id_segments) - 1:
                    y_ids.extend(id_segments[i] + [unk_id])
                else:
                    y_ids.extend(id_segments[i])
        else:
            y_ids = sp.encode([text], out_type=int)[0]
        y.append(y_ids)
    return y


def main():
    args = get_args()
    texts = args.texts
    bpe_model = args.bpe_model

    sp = spm.SentencePieceProcessor()
    sp.load(bpe_model)
    unk_id = sp.piece_to_id("<unk>")

    y = convert_texts_into_ids(
        texts=texts,
        unk_id=unk_id,
        sp=sp,
    )

    logging.info(f"The input texts: {texts}")
    logging.info(f"The encoding ids: {y}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
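
A quick usage sketch of the function above (the model path is the recipe's default; the example text is made up):

```python
import sentencepiece as spm
from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

# Each <unk> maps to the single unk id; everything else is normal BPE encoding.
ids = convert_texts_into_ids(
    texts=["hello <unk> world"],
    unk_id=sp.piece_to_id("<unk>"),
    sp=sp,
)
print(ids)  # e.g. [[..., 2, ...]]
```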

View File: egs/tedlium3/ASR/local/display_manifest_statistics.py

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()` in transducer/train.py
for usage.
"""
import numpy as np
from lhotse import load_manifest
def describe(cuts) -> None:
"""
Print a message describing details about the ``CutSet`` - the number of cuts and the
duration statistics, including the total duration and the percentage of speech segments.
Example output:
Cuts count: 804789
Total duration (hours): 1370.6
Speech duration (hours): 1370.6 (100.0%)
***
Duration statistics (seconds):
mean 6.1
std 3.1
min 0.5
25% 3.7
50% 6.0
75% 8.3
99.5% 14.9
99.9% 16.6
max 33.3
In the above example, we set 15(>14.9) as the maximum duration of training samples.
"""
durations = np.array([c.duration for c in cuts])
speech_durations = np.array(
[s.duration for c in cuts for s in c.trimmed_supervisions]
)
total_sum = durations.sum()
speech_sum = speech_durations.sum()
print("Cuts count:", len(cuts))
print(f"Total duration (hours): {total_sum / 3600:.1f}")
print(
f"Speech duration (hours): {speech_sum / 3600:.1f} ({speech_sum / total_sum:.1%})"
)
print("***")
print("Duration statistics (seconds):")
print(f"mean\t{np.mean(durations):.1f}")
print(f"std\t{np.std(durations):.1f}")
print(f"min\t{np.min(durations):.1f}")
print(f"25%\t{np.percentile(durations, 25):.1f}")
print(f"50%\t{np.median(durations):.1f}")
print(f"75%\t{np.percentile(durations, 75):.1f}")
print(f"99.5%\t{np.percentile(durations, 99.5):.1f}")
print(f"99.9%\t{np.percentile(durations, 99.9):.1f}")
print(f"max\t{np.max(durations):.1f}")
def main():
path = "./data/fbank/cuts_train.json.gz"
# path = "./data/fbank/cuts_dev.json.gz"
# path = "./data/fbank/cuts_test.json.gz"
cuts = load_manifest(path)
describe(cuts)
if __name__ == "__main__":
main()
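
The printed statistics feed directly into the duration filter in train.py; a minimal sketch of such a filter, assuming 1 s / 15 s bounds chosen from the table above:

```python
from lhotse import load_manifest

cuts = load_manifest("./data/fbank/cuts_train.json.gz")


def remove_short_and_long_utt(c) -> bool:
    # Keep cuts between 1 and 15 seconds; 15 is just above the
    # 99.5% percentile (14.9 s) shown by describe().
    return 1.0 <= c.duration <= 15.0


cuts = cuts.filter(remove_short_and_long_utt)
```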

View File: egs/tedlium3/ASR/prepare.sh

@@ -151,6 +151,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
    log "Generate data for BPE training"
    cat data/lang_phone/train.text | cut -d " " -f 2- \
      > $lang_dir/transcript_words.txt
+   # remove <unk> from transcript_words.txt
    sed -i 's/ <unk>//g' $lang_dir/transcript_words.txt
    sed -i 's/<unk> //g' $lang_dir/transcript_words.txt
    sed -i 's/<unk>//g' $lang_dir/transcript_words.txt

View File: egs/tedlium3/ASR/transducer_stateless/README.md

@@ -7,7 +7,7 @@ https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
You can use the following command to start the training:

```bash
-cd egs/librispeech/ASR
+cd egs/tedlium3/ASR

export CUDA_VISIBLE_DEVICES="0,1,2,3"
@@ -16,7 +16,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  --num-epochs 30 \
  --start-epoch 0 \
  --exp-dir transducer_stateless/exp \
-  --full-libri 1 \
-  --max-duration 250 \
-  --lr-factor 2.5
+  --max-duration 180 \
+  --lr-factor 5.0
```

View File: egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py

@@ -1,4 +1,5 @@
# Copyright 2021 Piotr Żelasko
+# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#

View File: egs/tedlium3/ASR/transducer_stateless/beam_search.py

@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -17,7 +18,6 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

-import numpy as np
import torch
from model import Transducer
@@ -43,12 +43,13 @@ def greedy_search(
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    blank_id = model.decoder.blank_id
+    unk_id = model.decoder.unk_id
    context_size = model.decoder.context_size

    device = model.device

    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size, device=device, dtype=torch.int64
    ).reshape(1, context_size)
    decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -84,7 +85,7 @@ def greedy_search(
        # logits is (1, 1, 1, vocab_size)
        y = logits.argmax().item()
-        if y != blank_id:
+        if y != blank_id and y != unk_id:
            hyp.append(y)
            decoder_input = torch.tensor(
                [hyp[-context_size:]], device=device
@@ -108,8 +109,9 @@ class Hypothesis:
    # Newly predicted tokens are appended to `ys`.
    ys: List[int]

-    # The log prob of ys
-    log_prob: float
+    # The log prob of ys.
+    # It contains only one entry.
+    log_prob: torch.Tensor

    @property
    def key(self) -> str:
@@ -118,7 +120,7 @@
class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
        """
        Args:
          data:
@@ -130,11 +132,10 @@ class HypothesisList(object):
            self._data = data

    @property
-    def data(self):
+    def data(self) -> Dict[str, Hypothesis]:
        return self._data

-    # def add(self, ys: List[int], log_prob: float):
-    def add(self, hyp: Hypothesis):
+    def add(self, hyp: Hypothesis) -> None:
        """Add a Hypothesis to `self`.

        If `hyp` already exists in `self`, its probability is updated using
@@ -146,8 +147,10 @@
        """
        key = hyp.key
        if key in self:
-            old_hyp = self._data[key]
-            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
        else:
            self._data[key] = hyp
@@ -159,7 +162,8 @@
          length_norm:
            If True, the `log_prob` of a hypothesis is normalized by the
            number of tokens in it.
+        Returns:
+          Return the hypothesis that has the largest `log_prob`.
        """
        if length_norm:
            return max(
@@ -171,6 +175,9 @@
    def remove(self, hyp: Hypothesis) -> None:
        """Remove a given hypothesis.

+        Caution:
+          `self` is modified **in-place**.
+
        Args:
          hyp:
            The hypothesis to be removed from `self`.
@@ -181,7 +188,7 @@
        assert key in self, f"{key} does not exist"
        del self._data[key]

-    def filter(self, threshold: float) -> "HypothesisList":
+    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
        """Remove all Hypotheses whose log_prob is less than threshold.

        Caution:
@@ -189,10 +196,10 @@
        Returns:
          Return a new HypothesisList containing all hypotheses from `self`
-          that have `log_prob` being greater than the given `threshold`.
+          with `log_prob` being greater than the given `threshold`.
        """
        ans = HypothesisList()
-        for key, hyp in self._data.items():
+        for _, hyp in self._data.items():
            if hyp.log_prob > threshold:
                ans.add(hyp)  # shallow copy
        return ans
@@ -222,6 +229,201 @@
        return ", ".join(s)
def run_decoder(
    ys: List[int],
    model: Transducer,
    decoder_cache: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Run the neural decoder model for a given hypothesis.

    Args:
      ys:
        The current hypothesis.
      model:
        The transducer model.
      decoder_cache:
        Cache to save computations.
    Returns:
      Return a tensor of shape (1, 1, decoder_out_dim) containing
      the output of `model.decoder`.
    """
    context_size = model.decoder.context_size
    key = "_".join(map(str, ys[-context_size:]))
    if key in decoder_cache:
        return decoder_cache[key]

    device = model.device

    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
        1, context_size
    )
    decoder_out = model.decoder(decoder_input, need_pad=False)
    decoder_cache[key] = decoder_out
    return decoder_out


def run_joiner(
    key: str,
    model: Transducer,
    encoder_out: torch.Tensor,
    decoder_out: torch.Tensor,
    encoder_out_len: torch.Tensor,
    decoder_out_len: torch.Tensor,
    joint_cache: Dict[str, torch.Tensor],
):
    """Run the joint network given outputs from the encoder and decoder.

    Args:
      key:
        A key into the `joint_cache`.
      model:
        The transducer model.
      encoder_out:
        A tensor of shape (1, 1, encoder_out_dim).
      decoder_out:
        A tensor of shape (1, 1, decoder_out_dim).
      encoder_out_len:
        A tensor with value [1].
      decoder_out_len:
        A tensor with value [1].
      joint_cache:
        A dict to save computations.
    Returns:
      Return a tensor from the output of log-softmax.
      Its shape is (vocab_size,).
    """
    if key in joint_cache:
        return joint_cache[key]

    logits = model.joiner(
        encoder_out,
        decoder_out,
        encoder_out_len,
        decoder_out_len,
    )

    # TODO(fangjun): Scale the blank posterior

    log_prob = logits.log_softmax(dim=-1)
    # log_prob is (1, 1, 1, vocab_size)

    log_prob = log_prob.squeeze()
    # Now log_prob is (vocab_size,)

    joint_cache[key] = log_prob

    return log_prob


def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
) -> List[int]:
    """It limits the maximum number of symbols per frame to 1.

    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
    Returns:
      Return the decoded result.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    blank_id = model.decoder.blank_id
    unk_id = model.decoder.unk_id
    context_size = model.decoder.context_size
    device = model.device

    decoder_input = torch.tensor(
        [blank_id] * context_size, device=device
    ).reshape(1, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)

    T = encoder_out.size(1)

    B = HypothesisList()
    B.add(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
        )
    )

    encoder_out_len = torch.tensor([1])
    decoder_out_len = torch.tensor([1])

    for t in range(T):
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # current_encoder_out is of shape (1, 1, encoder_out_dim)
        # fmt: on
        A = list(B)
        B = HypothesisList()

        ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
        # ys_log_probs is of shape (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyp in A],
            device=device,
        )
        # decoder_input is of shape (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False)
        # decoder_out is of shape (num_hyps, 1, decoder_output_dim)

        current_encoder_out = current_encoder_out.expand(
            decoder_out.size(0), 1, -1
        )

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            encoder_out_len.expand(decoder_out.size(0)),
            decoder_out_len.expand(decoder_out.size(0)),
        )
        # logits is of shape (num_hyps, vocab_size)
        log_probs = logits.log_softmax(dim=-1)

        log_probs.add_(ys_log_probs)

        log_probs = log_probs.reshape(-1)
        topk_log_probs, topk_indexes = log_probs.topk(beam)

        # topk_hyp_indexes are indexes into `A`
        topk_hyp_indexes = topk_indexes // logits.size(-1)
        topk_token_indexes = topk_indexes % logits.size(-1)

        topk_hyp_indexes = topk_hyp_indexes.tolist()
        topk_token_indexes = topk_token_indexes.tolist()

        for i in range(len(topk_hyp_indexes)):
            hyp = A[topk_hyp_indexes[i]]
            new_ys = hyp.ys[:]
            new_token = topk_token_indexes[i]
            if new_token != blank_id and new_token != unk_id:
                new_ys.append(new_token)
            new_log_prob = topk_log_probs[i]
            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
            B.add(new_hyp)

    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks

    return ys
def beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
@@ -247,6 +449,7 @@ def beam_search(
    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    blank_id = model.decoder.blank_id
+    unk_id = model.decoder.unk_id
    context_size = model.decoder.context_size

    device = model.device
@@ -261,7 +464,12 @@
    t = 0

    B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+        )
+    )

    max_sym_per_utt = 20000
@@ -281,58 +489,43 @@
        joint_cache: Dict[str, torch.Tensor] = {}

-        # TODO(fangjun): Implement prefix search to update the `log_prob`
-        # of hypotheses in A

        while True:
            y_star = A.get_most_probable()
            A.remove(y_star)

-            cached_key = y_star.key
-
-            if cached_key not in decoder_cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-context_size:]], device=device
-                ).reshape(1, context_size)
-
-                decoder_out = model.decoder(decoder_input, need_pad=False)
-                decoder_cache[cached_key] = decoder_out
-            else:
-                decoder_out = decoder_cache[cached_key]
-
-            cached_key += f"-t-{t}"
-            if cached_key not in joint_cache:
-                logits = model.joiner(
-                    current_encoder_out,
-                    decoder_out,
-                    encoder_out_len,
-                    decoder_out_len,
-                )
-
-                # TODO(fangjun): Scale the blank posterior
-
-                log_prob = logits.log_softmax(dim=-1)
-                # log_prob is (1, 1, 1, vocab_size)
-                log_prob = log_prob.squeeze()
-                # Now log_prob is (vocab_size,)
-                joint_cache[cached_key] = log_prob
-            else:
-                log_prob = joint_cache[cached_key]
+            decoder_out = run_decoder(
+                ys=y_star.ys, model=model, decoder_cache=decoder_cache
+            )
+
+            key = "_".join(map(str, y_star.ys[-context_size:]))
+            key += f"-t-{t}"
+            log_prob = run_joiner(
+                key=key,
+                model=model,
+                encoder_out=current_encoder_out,
+                decoder_out=decoder_out,
+                encoder_out_len=encoder_out_len,
+                decoder_out_len=decoder_out_len,
+                joint_cache=joint_cache,
+            )

            # First, process the blank symbol
            skip_log_prob = log_prob[blank_id]
-            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+            new_y_star_log_prob = y_star.log_prob + skip_log_prob

            # ys[:] returns a copy of ys
            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))

            # Second, process other non-blank labels
            values, indices = log_prob.topk(beam + 1)
-            for i, v in zip(indices.tolist(), values.tolist()):
-                if i == blank_id:
+            for idx in range(values.size(0)):
+                i = indices[idx].item()
+                if i == blank_id or i == unk_id:
                    continue
                new_ys = y_star.ys + [i]
-                new_log_prob = y_star.log_prob + v
+                new_log_prob = y_star.log_prob + values[idx]
                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))

        # Check whether B contains more than "beam" elements more probable
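
In `modified_beam_search` above, the per-hypothesis scores are flattened before `topk`, and integer division/modulo recover which hypothesis each candidate extends and with which token. A small self-contained illustration:

```python
import torch

num_hyps, vocab_size = 2, 5
# Joint log-probs for every (hypothesis, token) pair.
log_probs = torch.randn(num_hyps, vocab_size)

flat = log_probs.reshape(-1)       # (num_hyps * vocab_size,)
topk_vals, topk_idx = flat.topk(4)

hyp_idx = topk_idx // vocab_size   # which hypothesis each entry extends
token_idx = topk_idx % vocab_size  # which token it would append

for h, tok in zip(hyp_idx.tolist(), token_idx.tolist()):
    print(f"extend hypothesis {h} with token {tok}")
```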

View File: egs/tedlium3/ASR/transducer_stateless/conformer.py

@@ -615,7 +615,7 @@ class RelPositionMultiheadAttention(nn.Module):
            E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
            L is the target sequence length, S is the source sequence length.
-        """  # noqa
+        """

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == embed_dim_to_check
@@ -635,7 +635,7 @@ class RelPositionMultiheadAttention(nn.Module):
        elif torch.equal(key, value):
            # encoder-decoder attention
-            # This is inline in_proj function with in_proj_weight and in_proj_bias  # noqa
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
@@ -643,7 +643,7 @@ class RelPositionMultiheadAttention(nn.Module):
            if _b is not None:
                _b = _b[_start:_end]
            q = nn.functional.linear(query, _w, _b)

-            # This is inline in_proj function with in_proj_weight and in_proj_bias  # noqa
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = None
@@ -653,7 +653,7 @@ class RelPositionMultiheadAttention(nn.Module):
            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)

        else:
-            # This is inline in_proj function with in_proj_weight and in_proj_bias  # noqa
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
@@ -662,7 +662,7 @@ class RelPositionMultiheadAttention(nn.Module):
                _b = _b[_start:_end]
            q = nn.functional.linear(query, _w, _b)

-            # This is inline in_proj function with in_proj_weight and in_proj_bias  # noqa
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
@@ -671,7 +671,7 @@ class RelPositionMultiheadAttention(nn.Module):
                _b = _b[_start:_end]
            k = nn.functional.linear(key, _w, _b)

-            # This is inline in_proj function with in_proj_weight and in_proj_bias  # noqa
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
@@ -687,12 +687,12 @@ class RelPositionMultiheadAttention(nn.Module):
                or attn_mask.dtype == torch.float16
                or attn_mask.dtype == torch.uint8
                or attn_mask.dtype == torch.bool
-            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(  # noqa
+            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
                attn_mask.dtype
            )

            if attn_mask.dtype == torch.uint8:
                warnings.warn(
-                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."  # noqa
+                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
                )
                attn_mask = attn_mask.to(torch.bool)
@@ -725,7 +725,7 @@ class RelPositionMultiheadAttention(nn.Module):
            and key_padding_mask.dtype == torch.uint8
        ):
            warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."  # noqa
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
            )
            key_padding_mask = key_padding_mask.to(torch.bool)
@@ -760,7 +760,7 @@ class RelPositionMultiheadAttention(nn.Module):
        # compute attention score
        # first compute matrix a and matrix c
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3  # noqa
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
        matrix_ac = torch.matmul(
            q_with_bias_u, k
@@ -832,7 +832,7 @@
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py  # noqa
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py

    Args:
        channels (int): The number of channels of conv layers.

View File: egs/tedlium3/ASR/transducer_stateless/decode.py

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
-# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
+#                                            Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -19,16 +20,16 @@
Usage:
(1) greedy search
./transducer_stateless/decode.py \
-    --epoch 14 \
-    --avg 7 \
+    --epoch 29 \
+    --avg 15 \
    --exp-dir ./transducer_stateless/exp \
    --max-duration 100 \
    --decoding-method greedy_search

(2) beam search
./transducer_stateless/decode.py \
-    --epoch 14 \
-    --avg 7 \
+    --epoch 29 \
+    --avg 15 \
    --exp-dir ./transducer_stateless/exp \
    --max-duration 100 \
    --decoding-method beam_search \
@@ -45,8 +46,8 @@ from typing import Dict, List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search
+from asr_datamodule import TedLiumAsrDataModule
+from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@@ -77,7 +78,7 @@ def get_parser():
    parser.add_argument(
        "--avg",
        type=int,
-        default=13,
+        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
@@ -169,6 +170,7 @@ def get_decoder_model(params: AttributeDict):
        vocab_size=params.vocab_size,
        embedding_dim=params.encoder_out_dim,
        blank_id=params.blank_id,
+        unk_id=params.unk_id,
        context_size=params.context_size,
    )
    return decoder
@@ -256,6 +258,10 @@ def decode_one_batch(
            hyp = beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
            )
+        elif params.decoding_method == "modified_beam_search":
+            hyp = modified_beam_search(
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            )
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
@@ -382,14 +388,18 @@ def save_results(
@torch.no_grad()
def main():
    parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    TedLiumAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

-    assert params.decoding_method in ("greedy_search", "beam_search")
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "modified_beam_search",
+    )
    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
@@ -413,6 +423,7 @@ def main():
    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)
@@ -439,16 +450,12 @@
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

-    librispeech = LibriSpeechAsrDataModule(args)
-
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
-
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
-
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    tedlium = TedLiumAsrDataModule(args)
+    test_cuts = tedlium.test_cuts()
+    test_dl = tedlium.test_dataloaders(test_cuts)
+
+    test_sets = ["test"]
+    test_dl = [test_dl]

    for test_set, test_dl in zip(test_sets, test_dl):
        results_dict = decode_dataset(

View File: egs/tedlium3/ASR/transducer_stateless/decoder.py

@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -37,6 +38,7 @@ class Decoder(nn.Module):
        vocab_size: int,
        embedding_dim: int,
        blank_id: int,
+        unk_id: int,
        context_size: int,
    ):
        """
@@ -47,6 +49,8 @@
            Dimension of the input embedding.
          blank_id:
            The ID of the blank symbol.
+          unk_id:
+            The ID of the unk symbol.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
@@ -58,6 +62,7 @@
            padding_idx=blank_id,
        )
        self.blank_id = blank_id
+        self.unk_id = unk_id
        assert context_size >= 1, context_size
        self.context_size = context_size

View File: egs/tedlium3/ASR/transducer_stateless/model.py

@@ -120,7 +120,6 @@ class Transducer(nn.Module):
            target_lengths=y_lens,
            blank=blank_id,
            reduction="sum",
-            from_log_softmax=False,
        )

        return loss

View File: egs/tedlium3/ASR/transducer_stateless/pretrained.py

@@ -50,7 +50,7 @@ import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
-from beam_search import beam_search, greedy_search
+from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@@ -167,6 +167,7 @@ def get_decoder_model(params: AttributeDict):
        vocab_size=params.vocab_size,
        embedding_dim=params.encoder_out_dim,
        blank_id=params.blank_id,
+        unk_id=params.unk_id,
        context_size=params.context_size,
    )
    return decoder
@@ -230,6 +231,7 @@ def main():
    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(f"{params}")
@@ -300,6 +302,10 @@
        hyp = beam_search(
            model=model, encoder_out=encoder_out_i, beam=params.beam_size
        )
+    elif params.method == "modified_beam_search":
+        hyp = modified_beam_search(
+            model=model, encoder_out=encoder_out_i, beam=params.beam_size
+        )
    else:
        raise ValueError(f"Unsupported method: {params.method}")

View File: egs/tedlium3/ASR/transducer_stateless/test_decoder.py

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -18,7 +19,7 @@
"""
To run this file, do:

-    cd icefall/egs/librispeech/ASR
+    cd icefall/egs/tedlium3/ASR
    python ./transducer_stateless/test_decoder.py
"""
@@ -29,6 +30,7 @@ from decoder import Decoder
def test_decoder():
    vocab_size = 3
    blank_id = 0
+    unk_id = 2
    embedding_dim = 128
    context_size = 4
@@ -36,6 +38,7 @@ def test_decoder():
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        blank_id=blank_id,
+        unk_id=unk_id,
        context_size=context_size,
    )
    N = 100

View File: egs/tedlium3/ASR/transducer_stateless/train.py

@@ -26,9 +26,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
    --num-epochs 30 \
    --start-epoch 0 \
    --exp-dir transducer_stateless/exp \
-    --full-libri 1 \
-    --max-duration 250 \
-    --lr-factor 2.5
+    --max-duration 180 \
+    --lr-factor 5.0
"""
@@ -56,6 +55,8 @@ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam

+from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
+
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
@@ -233,6 +234,7 @@ def get_decoder_model(params: AttributeDict):
        vocab_size=params.vocab_size,
        embedding_dim=params.encoder_out_dim,
        blank_id=params.blank_id,
+        unk_id=params.unk_id,
        context_size=params.context_size,
    )
    return decoder
@@ -379,7 +381,9 @@ def compute_loss(
    feature_lens = supervisions["num_frames"].to(device)[: feature.size(0)]

    texts = batch["supervisions"]["text"][: feature.size(0)]
-    y = sp.encode(texts, out_type=int)
+
+    unk_id = params.unk_id
+    y = convert_texts_into_ids(texts, unk_id, sp=sp)
    y = k2.RaggedTensor(y).to(device)

    with torch.set_grad_enabled(is_training):
@@ -565,6 +569,7 @@ def run(rank, world_size, args):
    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)