mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Update tedlium3 transducer stateless
This commit is contained in:
parent
47e49a6663
commit
536ad2252e
18
egs/tedlium3/ASR/README.md
Normal file
18
egs/tedlium3/ASR/README.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
|
||||||
|
# Introduction
|
||||||
|
|
||||||
|
This recipe includes some different ASR models trained with TedLium3.
|
||||||
|
|
||||||
|
# Transducers
|
||||||
|
|
||||||
|
There are various folders containing the name `transducer` in this folder.
|
||||||
|
The following table lists the differences among them.
|
||||||
|
|
||||||
|
| | Encoder | Decoder |
|
||||||
|
|------------------------|-----------|--------------------|
|
||||||
|
| `transducer_stateless` | Conformer | Embedding + Conv1d |
|
||||||
|
|
||||||
|
|
||||||
|
The decoder in `transducer_stateless` is modified from the paper
|
||||||
|
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||||
|
We place an additional Conv1d layer right after the input embedding layer.
|
||||||
68
egs/tedlium3/ASR/RESULTS.md
Normal file
68
egs/tedlium3/ASR/RESULTS.md
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
## Results
|
||||||
|
|
||||||
|
### TedLium3 BPE training results (Transducer)
|
||||||
|
|
||||||
|
#### Conformer encoder + embedding decoder
|
||||||
|
|
||||||
|
Using the codes from this commit .
|
||||||
|
|
||||||
|
Conformer encoder + non-current decoder. The decoder
|
||||||
|
contains only an embedding layer and a Conv1d (with kernel size 2).
|
||||||
|
|
||||||
|
The WERs are
|
||||||
|
|
||||||
|
| | dev | test | comment |
|
||||||
|
|------------------------------------|------------|------------|------------------------------------------|
|
||||||
|
| greedy search | 7.31 | 6.73 | --epoch 71, --avg 15, --max-duration 100 |
|
||||||
|
| beam search (beam size 4) | 7.12 | 6.58 | --epoch 71, --avg 15, --max-duration 100 |
|
||||||
|
| modified beam search (beam size 4) | 7.20 | 6.65 | --epoch 71, --avg 15, --max-duration 100 |
|
||||||
|
|
||||||
|
The training command for reproducing is given below:
|
||||||
|
|
||||||
|
```
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
|
./transducer_stateless/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--start-epoch 0 \
|
||||||
|
--exp-dir transducer_stateless/exp \
|
||||||
|
--max-duration 180 \
|
||||||
|
```
|
||||||
|
|
||||||
|
The tensorboard training log can be found at
|
||||||
|
https://tensorboard.dev/experiment/DnRwoZF8RRyod4kkfG5q5Q/#scalars
|
||||||
|
|
||||||
|
The decoding command is:
|
||||||
|
```
|
||||||
|
epoch=29
|
||||||
|
avg=15
|
||||||
|
|
||||||
|
## greedy search
|
||||||
|
./transducer_stateless/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir transducer_stateless/exp \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--max-duration 100
|
||||||
|
|
||||||
|
## beam search
|
||||||
|
./transducer_stateless/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir transducer_stateless/exp \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--max-duration 100 \
|
||||||
|
--decoding-method beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
## modified beam search
|
||||||
|
./transducer_stateless/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir transducer_stateless/exp \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--max-duration 100 \
|
||||||
|
--decoding-method beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
```
|
||||||
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file computes fbank features of the LibriSpeech dataset.
|
This file computes fbank features of the TedLium3 dataset.
|
||||||
It looks for manifests in the directory data/manifests.
|
It looks for manifests in the directory data/manifests.
|
||||||
|
|
||||||
The generated fbank features are saved in data/fbank.
|
The generated fbank features are saved in data/fbank.
|
||||||
@ -43,7 +43,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_librispeech():
|
def compute_fbank_tedlium():
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -96,4 +96,4 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
compute_fbank_librispeech()
|
compute_fbank_tedlium()
|
||||||
|
|||||||
@ -0,0 +1,97 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||||
|
"""
|
||||||
|
Convert a transcript based on words to a list of BPE ids with the related BPE model.
|
||||||
|
|
||||||
|
For example, if we use 2 as the encoding id of <unk>, there are four examples:
|
||||||
|
|
||||||
|
texts = ['this is a <unk> day and in the room there are three <unk> laying in the bed']
|
||||||
|
spm_ids = [[38, 33, 6, 2, 316, 8, 16, 5, 257, 193, 103, 61, 331, 2, 196, 21, 14, 16, 5, 47, 12]]
|
||||||
|
|
||||||
|
texts = ['<unk> this is a sunny day and in the room there are three people in the <unk>']
|
||||||
|
spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316, 8, 16, 5, 257, 193, 103, 61, 331, 107, 16, 5, 2]]
|
||||||
|
|
||||||
|
texts = ['<unk>']
|
||||||
|
spm_ids = [[2]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sentencepiece as spm
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--texts", type=List[str], help="The input transcripts list."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--unk-id",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The number id for the token '<unk>'.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_texts_into_ids(
|
||||||
|
texts: List[str],
|
||||||
|
unk_id: int,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
texts:
|
||||||
|
A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
|
||||||
|
unk_id:
|
||||||
|
A number id for the token '<unk>'.
|
||||||
|
Returns:
|
||||||
|
Return a integer list of bpe ids.
|
||||||
|
"""
|
||||||
|
y = []
|
||||||
|
for text in texts:
|
||||||
|
y_ids = []
|
||||||
|
if "<unk>" in text:
|
||||||
|
text_segments = text.split("<unk>")
|
||||||
|
id_segments = sp.encode(text_segments, out_type=int)
|
||||||
|
for i in range(len(id_segments)):
|
||||||
|
if i != len(id_segments) - 1:
|
||||||
|
y_ids.extend(id_segments[i] + [unk_id])
|
||||||
|
else:
|
||||||
|
y_ids.extend(id_segments[i])
|
||||||
|
else:
|
||||||
|
y_ids = sp.encode([text], out_type=int)[0]
|
||||||
|
y.append(y_ids)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
texts = args.texts
|
||||||
|
bpe_model = args.bpe_model
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(bpe_model)
|
||||||
|
unk_id = sp.piece_to_id("<unk>")
|
||||||
|
|
||||||
|
y = convert_texts_into_ids(
|
||||||
|
texts=texts,
|
||||||
|
unk_id=unk_id,
|
||||||
|
sp=sp,
|
||||||
|
)
|
||||||
|
logging.info(f"The input texts: {texts}")
|
||||||
|
logging.info(f"The encoding ids: {y}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
89
egs/tedlium3/ASR/local/display_manifest_statistics.py
Normal file
89
egs/tedlium3/ASR/local/display_manifest_statistics.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
|
# Mingshuang Luo)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file displays duration statistics of utterances in a manifest.
|
||||||
|
You can use the displayed value to choose minimum/maximum duration
|
||||||
|
to remove short and long utterances during the training.
|
||||||
|
|
||||||
|
See the function `remove_short_and_long_utt()` in transducer/train.py
|
||||||
|
for usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from lhotse import load_manifest
|
||||||
|
|
||||||
|
|
||||||
|
def describe(cuts) -> None:
|
||||||
|
"""
|
||||||
|
Print a message describing details about the ``CutSet`` - the number of cuts and the
|
||||||
|
duration statistics, including the total duration and the percentage of speech segments.
|
||||||
|
|
||||||
|
Example output:
|
||||||
|
Cuts count: 804789
|
||||||
|
Total duration (hours): 1370.6
|
||||||
|
Speech duration (hours): 1370.6 (100.0%)
|
||||||
|
***
|
||||||
|
Duration statistics (seconds):
|
||||||
|
mean 6.1
|
||||||
|
std 3.1
|
||||||
|
min 0.5
|
||||||
|
25% 3.7
|
||||||
|
50% 6.0
|
||||||
|
75% 8.3
|
||||||
|
99.5% 14.9
|
||||||
|
99.9% 16.6
|
||||||
|
max 33.3
|
||||||
|
|
||||||
|
In the above example, we set 15(>14.9) as the maximum duration of training samples.
|
||||||
|
"""
|
||||||
|
durations = np.array([c.duration for c in cuts])
|
||||||
|
speech_durations = np.array(
|
||||||
|
[s.duration for c in cuts for s in c.trimmed_supervisions]
|
||||||
|
)
|
||||||
|
total_sum = durations.sum()
|
||||||
|
speech_sum = speech_durations.sum()
|
||||||
|
print("Cuts count:", len(cuts))
|
||||||
|
print(f"Total duration (hours): {total_sum / 3600:.1f}")
|
||||||
|
print(
|
||||||
|
f"Speech duration (hours): {speech_sum / 3600:.1f} ({speech_sum / total_sum:.1%})"
|
||||||
|
)
|
||||||
|
print("***")
|
||||||
|
print("Duration statistics (seconds):")
|
||||||
|
print(f"mean\t{np.mean(durations):.1f}")
|
||||||
|
print(f"std\t{np.std(durations):.1f}")
|
||||||
|
print(f"min\t{np.min(durations):.1f}")
|
||||||
|
print(f"25%\t{np.percentile(durations, 25):.1f}")
|
||||||
|
print(f"50%\t{np.median(durations):.1f}")
|
||||||
|
print(f"75%\t{np.percentile(durations, 75):.1f}")
|
||||||
|
print(f"99.5%\t{np.percentile(durations, 99.5):.1f}")
|
||||||
|
print(f"99.9%\t{np.percentile(durations, 99.9):.1f}")
|
||||||
|
print(f"max\t{np.max(durations):.1f}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
path = "./data/fbank/cuts_train.json.gz"
|
||||||
|
# path = "./data/fbank/cuts_dev.json.gz"
|
||||||
|
# path = "./data/fbank/cuts_test.json.gz"
|
||||||
|
|
||||||
|
cuts = load_manifest(path)
|
||||||
|
describe(cuts)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -151,6 +151,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|||||||
log "Generate data for BPE training"
|
log "Generate data for BPE training"
|
||||||
cat data/lang_phone/train.text | cut -d " " -f 2-
|
cat data/lang_phone/train.text | cut -d " " -f 2-
|
||||||
> $lang_dir/transcript_words.txt
|
> $lang_dir/transcript_words.txt
|
||||||
|
# remove the <unk> for transcript_words.txt
|
||||||
sed -i 's/ <unk>//g' $lang_dir/transcript_words.txt
|
sed -i 's/ <unk>//g' $lang_dir/transcript_words.txt
|
||||||
sed -i 's/<unk> //g' $lang_dir/transcript_words.txt
|
sed -i 's/<unk> //g' $lang_dir/transcript_words.txt
|
||||||
sed -i 's/<unk>//g' $lang_dir/transcript_words.txt
|
sed -i 's/<unk>//g' $lang_dir/transcript_words.txt
|
||||||
|
|||||||
@ -7,7 +7,7 @@ https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
|||||||
You can use the following command to start the training:
|
You can use the following command to start the training:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd egs/librispeech/ASR
|
cd egs/tedlium3/ASR
|
||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
@ -16,7 +16,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--exp-dir transducer_stateless/exp \
|
--exp-dir transducer_stateless/exp \
|
||||||
--full-libri 1 \
|
--max-duration 180 \
|
||||||
--max-duration 250 \
|
--lr-factor 5.0
|
||||||
--lr-factor 2.5
|
|
||||||
```
|
```
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# Copyright 2021 Piotr Żelasko
|
# Copyright 2021 Piotr Żelasko
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
|
# Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -17,7 +18,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
@ -43,12 +43,13 @@ def greedy_search(
|
|||||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = model.decoder.unk_id
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[blank_id] * context_size, device=device
|
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||||
).reshape(1, context_size)
|
).reshape(1, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
@ -84,7 +85,7 @@ def greedy_search(
|
|||||||
# logits is (1, 1, 1, vocab_size)
|
# logits is (1, 1, 1, vocab_size)
|
||||||
|
|
||||||
y = logits.argmax().item()
|
y = logits.argmax().item()
|
||||||
if y != blank_id:
|
if y != blank_id and y != unk_id:
|
||||||
hyp.append(y)
|
hyp.append(y)
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[hyp[-context_size:]], device=device
|
[hyp[-context_size:]], device=device
|
||||||
@ -108,8 +109,9 @@ class Hypothesis:
|
|||||||
# Newly predicted tokens are appended to `ys`.
|
# Newly predicted tokens are appended to `ys`.
|
||||||
ys: List[int]
|
ys: List[int]
|
||||||
|
|
||||||
# The log prob of ys
|
# The log prob of ys.
|
||||||
log_prob: float
|
# It contains only one entry.
|
||||||
|
log_prob: torch.Tensor
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self) -> str:
|
def key(self) -> str:
|
||||||
@ -118,7 +120,7 @@ class Hypothesis:
|
|||||||
|
|
||||||
|
|
||||||
class HypothesisList(object):
|
class HypothesisList(object):
|
||||||
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
|
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
data:
|
data:
|
||||||
@ -130,11 +132,10 @@ class HypothesisList(object):
|
|||||||
self._data = data
|
self._data = data
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self):
|
def data(self) -> Dict[str, Hypothesis]:
|
||||||
return self._data
|
return self._data
|
||||||
|
|
||||||
# def add(self, ys: List[int], log_prob: float):
|
def add(self, hyp: Hypothesis) -> None:
|
||||||
def add(self, hyp: Hypothesis):
|
|
||||||
"""Add a Hypothesis to `self`.
|
"""Add a Hypothesis to `self`.
|
||||||
|
|
||||||
If `hyp` already exists in `self`, its probability is updated using
|
If `hyp` already exists in `self`, its probability is updated using
|
||||||
@ -146,8 +147,10 @@ class HypothesisList(object):
|
|||||||
"""
|
"""
|
||||||
key = hyp.key
|
key = hyp.key
|
||||||
if key in self:
|
if key in self:
|
||||||
old_hyp = self._data[key]
|
old_hyp = self._data[key] # shallow copy
|
||||||
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
|
torch.logaddexp(
|
||||||
|
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._data[key] = hyp
|
self._data[key] = hyp
|
||||||
|
|
||||||
@ -159,7 +162,8 @@ class HypothesisList(object):
|
|||||||
length_norm:
|
length_norm:
|
||||||
If True, the `log_prob` of a hypothesis is normalized by the
|
If True, the `log_prob` of a hypothesis is normalized by the
|
||||||
number of tokens in it.
|
number of tokens in it.
|
||||||
|
Returns:
|
||||||
|
Return the hypothesis that has the largest `log_prob`.
|
||||||
"""
|
"""
|
||||||
if length_norm:
|
if length_norm:
|
||||||
return max(
|
return max(
|
||||||
@ -171,6 +175,9 @@ class HypothesisList(object):
|
|||||||
def remove(self, hyp: Hypothesis) -> None:
|
def remove(self, hyp: Hypothesis) -> None:
|
||||||
"""Remove a given hypothesis.
|
"""Remove a given hypothesis.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
`self` is modified **in-place**.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hyp:
|
hyp:
|
||||||
The hypothesis to be removed from `self`.
|
The hypothesis to be removed from `self`.
|
||||||
@ -181,7 +188,7 @@ class HypothesisList(object):
|
|||||||
assert key in self, f"{key} does not exist"
|
assert key in self, f"{key} does not exist"
|
||||||
del self._data[key]
|
del self._data[key]
|
||||||
|
|
||||||
def filter(self, threshold: float) -> "HypothesisList":
|
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||||
|
|
||||||
Caution:
|
Caution:
|
||||||
@ -189,10 +196,10 @@ class HypothesisList(object):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a new HypothesisList containing all hypotheses from `self`
|
Return a new HypothesisList containing all hypotheses from `self`
|
||||||
that have `log_prob` being greater than the given `threshold`.
|
with `log_prob` being greater than the given `threshold`.
|
||||||
"""
|
"""
|
||||||
ans = HypothesisList()
|
ans = HypothesisList()
|
||||||
for key, hyp in self._data.items():
|
for _, hyp in self._data.items():
|
||||||
if hyp.log_prob > threshold:
|
if hyp.log_prob > threshold:
|
||||||
ans.add(hyp) # shallow copy
|
ans.add(hyp) # shallow copy
|
||||||
return ans
|
return ans
|
||||||
@ -222,6 +229,201 @@ class HypothesisList(object):
|
|||||||
return ", ".join(s)
|
return ", ".join(s)
|
||||||
|
|
||||||
|
|
||||||
|
def run_decoder(
|
||||||
|
ys: List[int],
|
||||||
|
model: Transducer,
|
||||||
|
decoder_cache: Dict[str, torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Run the neural decoder model for a given hypothesis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ys:
|
||||||
|
The current hypothesis.
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
decoder_cache:
|
||||||
|
Cache to save computations.
|
||||||
|
Returns:
|
||||||
|
Return a 1-D tensor of shape (decoder_out_dim,) containing
|
||||||
|
output of `model.decoder`.
|
||||||
|
"""
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
key = "_".join(map(str, ys[-context_size:]))
|
||||||
|
if key in decoder_cache:
|
||||||
|
return decoder_cache[key]
|
||||||
|
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
|
||||||
|
1, context_size
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
decoder_cache[key] = decoder_out
|
||||||
|
|
||||||
|
return decoder_out
|
||||||
|
|
||||||
|
|
||||||
|
def run_joiner(
|
||||||
|
key: str,
|
||||||
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
encoder_out_len: torch.Tensor,
|
||||||
|
decoder_out_len: torch.Tensor,
|
||||||
|
joint_cache: Dict[str, torch.Tensor],
|
||||||
|
):
|
||||||
|
"""Run the joint network given outputs from the encoder and decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key:
|
||||||
|
A key into the `joint_cache`.
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (1, 1, encoder_out_dim).
|
||||||
|
decoder_out:
|
||||||
|
A tensor of shape (1, 1, decoder_out_dim).
|
||||||
|
encoder_out_len:
|
||||||
|
A tensor with value [1].
|
||||||
|
decoder_out_len:
|
||||||
|
A tensor with value [1].
|
||||||
|
joint_cache:
|
||||||
|
A dict to save computations.
|
||||||
|
Returns:
|
||||||
|
Return a tensor from the output of log-softmax.
|
||||||
|
Its shape is (vocab_size,).
|
||||||
|
"""
|
||||||
|
if key in joint_cache:
|
||||||
|
return joint_cache[key]
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
encoder_out_len,
|
||||||
|
decoder_out_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(fangjun): Scale the blank posterior
|
||||||
|
log_prob = logits.log_softmax(dim=-1)
|
||||||
|
# log_prob is (1, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
log_prob = log_prob.squeeze()
|
||||||
|
# Now log_prob is (vocab_size,)
|
||||||
|
|
||||||
|
joint_cache[key] = log_prob
|
||||||
|
|
||||||
|
return log_prob
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search(
|
||||||
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
beam: int = 4,
|
||||||
|
) -> List[int]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||||
|
beam:
|
||||||
|
Beam size.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
|
# support only batch_size == 1 for now
|
||||||
|
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = model.decoder.unk_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[blank_id] * context_size, device=device
|
||||||
|
).reshape(1, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
|
B = HypothesisList()
|
||||||
|
B.add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out_len = torch.tensor([1])
|
||||||
|
decoder_out_len = torch.tensor([1])
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
# fmt: off
|
||||||
|
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||||
|
# current_encoder_out is of shape (1, 1, encoder_out_dim)
|
||||||
|
# fmt: on
|
||||||
|
A = list(B)
|
||||||
|
B = HypothesisList()
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
|
||||||
|
# ys_log_probs is of shape (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyp in A],
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# decoder_input is of shape (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
# decoder_output is of shape (num_hyps, 1, decoder_output_dim)
|
||||||
|
|
||||||
|
current_encoder_out = current_encoder_out.expand(
|
||||||
|
decoder_out.size(0), 1, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
encoder_out_len.expand(decoder_out.size(0)),
|
||||||
|
decoder_out_len.expand(decoder_out.size(0)),
|
||||||
|
)
|
||||||
|
# logits is of shape (num_hyps, vocab_size)
|
||||||
|
log_probs = logits.log_softmax(dim=-1)
|
||||||
|
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
topk_log_probs, topk_indexes = log_probs.topk(beam)
|
||||||
|
|
||||||
|
# topk_hyp_indexes are indexes into `A`
|
||||||
|
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
||||||
|
topk_token_indexes = topk_indexes % logits.size(-1)
|
||||||
|
|
||||||
|
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||||
|
topk_token_indexes = topk_token_indexes.tolist()
|
||||||
|
|
||||||
|
for i in range(len(topk_hyp_indexes)):
|
||||||
|
hyp = A[topk_hyp_indexes[i]]
|
||||||
|
new_ys = hyp.ys[:]
|
||||||
|
new_token = topk_token_indexes[i]
|
||||||
|
if new_token != blank_id and new_token != unk_id:
|
||||||
|
new_ys.append(new_token)
|
||||||
|
new_log_prob = topk_log_probs[i]
|
||||||
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
|
B.add(new_hyp)
|
||||||
|
|
||||||
|
best_hyp = B.get_most_probable(length_norm=True)
|
||||||
|
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||||
|
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
@ -247,6 +449,7 @@ def beam_search(
|
|||||||
# support only batch_size == 1 for now
|
# support only batch_size == 1 for now
|
||||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = model.decoder.unk_id
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
device = model.device
|
device = model.device
|
||||||
@ -261,7 +464,12 @@ def beam_search(
|
|||||||
t = 0
|
t = 0
|
||||||
|
|
||||||
B = HypothesisList()
|
B = HypothesisList()
|
||||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
B.add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
max_sym_per_utt = 20000
|
max_sym_per_utt = 20000
|
||||||
|
|
||||||
@ -281,58 +489,43 @@ def beam_search(
|
|||||||
|
|
||||||
joint_cache: Dict[str, torch.Tensor] = {}
|
joint_cache: Dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
# TODO(fangjun): Implement prefix search to update the `log_prob`
|
|
||||||
# of hypotheses in A
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
y_star = A.get_most_probable()
|
y_star = A.get_most_probable()
|
||||||
A.remove(y_star)
|
A.remove(y_star)
|
||||||
|
|
||||||
cached_key = y_star.key
|
decoder_out = run_decoder(
|
||||||
|
ys=y_star.ys, model=model, decoder_cache=decoder_cache
|
||||||
|
)
|
||||||
|
|
||||||
if cached_key not in decoder_cache:
|
key = "_".join(map(str, y_star.ys[-context_size:]))
|
||||||
decoder_input = torch.tensor(
|
key += f"-t-{t}"
|
||||||
[y_star.ys[-context_size:]], device=device
|
log_prob = run_joiner(
|
||||||
).reshape(1, context_size)
|
key=key,
|
||||||
|
model=model,
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
encoder_out=current_encoder_out,
|
||||||
decoder_cache[cached_key] = decoder_out
|
decoder_out=decoder_out,
|
||||||
else:
|
encoder_out_len=encoder_out_len,
|
||||||
decoder_out = decoder_cache[cached_key]
|
decoder_out_len=decoder_out_len,
|
||||||
|
joint_cache=joint_cache,
|
||||||
cached_key += f"-t-{t}"
|
)
|
||||||
if cached_key not in joint_cache:
|
|
||||||
logits = model.joiner(
|
|
||||||
current_encoder_out,
|
|
||||||
decoder_out,
|
|
||||||
encoder_out_len,
|
|
||||||
decoder_out_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(fangjun): Ccale the blank posterior
|
|
||||||
|
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
|
||||||
# log_prob is (1, 1, 1, vocab_size)
|
|
||||||
log_prob = log_prob.squeeze()
|
|
||||||
# Now log_prob is (vocab_size,)
|
|
||||||
joint_cache[cached_key] = log_prob
|
|
||||||
else:
|
|
||||||
log_prob = joint_cache[cached_key]
|
|
||||||
|
|
||||||
# First, process the blank symbol
|
# First, process the blank symbol
|
||||||
skip_log_prob = log_prob[blank_id]
|
skip_log_prob = log_prob[blank_id]
|
||||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
|
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||||
|
|
||||||
# ys[:] returns a copy of ys
|
# ys[:] returns a copy of ys
|
||||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||||
|
|
||||||
# Second, process other non-blank labels
|
# Second, process other non-blank labels
|
||||||
values, indices = log_prob.topk(beam + 1)
|
values, indices = log_prob.topk(beam + 1)
|
||||||
for i, v in zip(indices.tolist(), values.tolist()):
|
for idx in range(values.size(0)):
|
||||||
if i == blank_id:
|
i = indices[idx].item()
|
||||||
|
if i == blank_id or i == unk_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_ys = y_star.ys + [i]
|
new_ys = y_star.ys + [i]
|
||||||
new_log_prob = y_star.log_prob + v
|
|
||||||
|
new_log_prob = y_star.log_prob + values[idx]
|
||||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||||
|
|
||||||
# Check whether B contains more than "beam" elements more probable
|
# Check whether B contains more than "beam" elements more probable
|
||||||
|
|||||||
@ -615,7 +615,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
E is the embedding dimension.
|
E is the embedding dimension.
|
||||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||||
L is the target sequence length, S is the source sequence length.
|
L is the target sequence length, S is the source sequence length.
|
||||||
""" # noqa
|
"""
|
||||||
|
|
||||||
tgt_len, bsz, embed_dim = query.size()
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
assert embed_dim == embed_dim_to_check
|
assert embed_dim == embed_dim_to_check
|
||||||
@ -635,7 +635,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
elif torch.equal(key, value):
|
elif torch.equal(key, value):
|
||||||
# encoder-decoder attention
|
# encoder-decoder attention
|
||||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
_b = in_proj_bias
|
_b = in_proj_bias
|
||||||
_start = 0
|
_start = 0
|
||||||
_end = embed_dim
|
_end = embed_dim
|
||||||
@ -643,7 +643,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
if _b is not None:
|
if _b is not None:
|
||||||
_b = _b[_start:_end]
|
_b = _b[_start:_end]
|
||||||
q = nn.functional.linear(query, _w, _b)
|
q = nn.functional.linear(query, _w, _b)
|
||||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
_b = in_proj_bias
|
_b = in_proj_bias
|
||||||
_start = embed_dim
|
_start = embed_dim
|
||||||
_end = None
|
_end = None
|
||||||
@ -653,7 +653,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
_b = in_proj_bias
|
_b = in_proj_bias
|
||||||
_start = 0
|
_start = 0
|
||||||
_end = embed_dim
|
_end = embed_dim
|
||||||
@ -662,7 +662,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
_b = _b[_start:_end]
|
_b = _b[_start:_end]
|
||||||
q = nn.functional.linear(query, _w, _b)
|
q = nn.functional.linear(query, _w, _b)
|
||||||
|
|
||||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
_b = in_proj_bias
|
_b = in_proj_bias
|
||||||
_start = embed_dim
|
_start = embed_dim
|
||||||
_end = embed_dim * 2
|
_end = embed_dim * 2
|
||||||
@ -671,7 +671,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
_b = _b[_start:_end]
|
_b = _b[_start:_end]
|
||||||
k = nn.functional.linear(key, _w, _b)
|
k = nn.functional.linear(key, _w, _b)
|
||||||
|
|
||||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
_b = in_proj_bias
|
_b = in_proj_bias
|
||||||
_start = embed_dim * 2
|
_start = embed_dim * 2
|
||||||
_end = None
|
_end = None
|
||||||
@ -687,12 +687,12 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
or attn_mask.dtype == torch.float16
|
or attn_mask.dtype == torch.float16
|
||||||
or attn_mask.dtype == torch.uint8
|
or attn_mask.dtype == torch.uint8
|
||||||
or attn_mask.dtype == torch.bool
|
or attn_mask.dtype == torch.bool
|
||||||
), "Only float, byte, and bool types are supported for attn_mask, not {}".format( # noqa
|
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
||||||
attn_mask.dtype
|
attn_mask.dtype
|
||||||
)
|
)
|
||||||
if attn_mask.dtype == torch.uint8:
|
if attn_mask.dtype == torch.uint8:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Byte tensor for attn_mask is deprecated. Use bool tensor instead." # noqa
|
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
|
||||||
)
|
)
|
||||||
attn_mask = attn_mask.to(torch.bool)
|
attn_mask = attn_mask.to(torch.bool)
|
||||||
|
|
||||||
@ -725,7 +725,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
and key_padding_mask.dtype == torch.uint8
|
and key_padding_mask.dtype == torch.uint8
|
||||||
):
|
):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." # noqa
|
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||||
)
|
)
|
||||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||||
|
|
||||||
@ -760,7 +760,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
# compute attention score
|
# compute attention score
|
||||||
# first compute matrix a and matrix c
|
# first compute matrix a and matrix c
|
||||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa
|
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||||
matrix_ac = torch.matmul(
|
matrix_ac = torch.matmul(
|
||||||
q_with_bias_u, k
|
q_with_bias_u, k
|
||||||
@ -832,7 +832,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
class ConvolutionModule(nn.Module):
|
class ConvolutionModule(nn.Module):
|
||||||
"""ConvolutionModule in Conformer model.
|
"""ConvolutionModule in Conformer model.
|
||||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py # noqa
|
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
channels (int): The number of channels of conv layers.
|
channels (int): The number of channels of conv layers.
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
|
||||||
|
# Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -19,16 +20,16 @@
|
|||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./transducer_stateless/decode.py \
|
./transducer_stateless/decode.py \
|
||||||
--epoch 14 \
|
--epoch 29 \
|
||||||
--avg 7 \
|
--avg 15 \
|
||||||
--exp-dir ./transducer_stateless/exp \
|
--exp-dir ./transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search
|
||||||
./transducer_stateless/decode.py \
|
./transducer_stateless/decode.py \
|
||||||
--epoch 14 \
|
--epoch 29 \
|
||||||
--avg 7 \
|
--avg 15 \
|
||||||
--exp-dir ./transducer_stateless/exp \
|
--exp-dir ./transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
@ -45,8 +46,8 @@ from typing import Dict, List, Tuple
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import TedLiumAsrDataModule
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -77,7 +78,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=13,
|
default=15,
|
||||||
help="Number of checkpoints to average. Automatically select "
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
@ -169,6 +170,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
|
unk_id=params.unk_id,
|
||||||
context_size=params.context_size,
|
context_size=params.context_size,
|
||||||
)
|
)
|
||||||
return decoder
|
return decoder
|
||||||
@ -256,6 +258,10 @@ def decode_one_batch(
|
|||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
)
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp = modified_beam_search(
|
||||||
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
@ -382,14 +388,18 @@ def save_results(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
TedLiumAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
assert params.decoding_method in ("greedy_search", "beam_search")
|
assert params.decoding_method in (
|
||||||
|
"greedy_search",
|
||||||
|
"beam_search",
|
||||||
|
"modified_beam_search",
|
||||||
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
@ -413,6 +423,7 @@ def main():
|
|||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
@ -439,16 +450,12 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
tedlium = TedLiumAsrDataModule(args)
|
||||||
|
test_cuts = tedlium.test_cuts()
|
||||||
|
test_dl = tedlium.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_sets = ["test"]
|
||||||
test_other_cuts = librispeech.test_other_cuts()
|
test_dl = [test_dl]
|
||||||
|
|
||||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
|
||||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
|
||||||
|
|
||||||
test_sets = ["test-clean", "test-other"]
|
|
||||||
test_dl = [test_clean_dl, test_other_dl]
|
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
|
# Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -37,6 +38,7 @@ class Decoder(nn.Module):
|
|||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
blank_id: int,
|
blank_id: int,
|
||||||
|
unk_id: int,
|
||||||
context_size: int,
|
context_size: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -47,6 +49,8 @@ class Decoder(nn.Module):
|
|||||||
Dimension of the input embedding.
|
Dimension of the input embedding.
|
||||||
blank_id:
|
blank_id:
|
||||||
The ID of the blank symbol.
|
The ID of the blank symbol.
|
||||||
|
unk_id:
|
||||||
|
The ID of the unk symbol.
|
||||||
context_size:
|
context_size:
|
||||||
Number of previous words to use to predict the next word.
|
Number of previous words to use to predict the next word.
|
||||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||||
@ -58,6 +62,7 @@ class Decoder(nn.Module):
|
|||||||
padding_idx=blank_id,
|
padding_idx=blank_id,
|
||||||
)
|
)
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
|
self.unk_id = unk_id
|
||||||
|
|
||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
self.context_size = context_size
|
self.context_size = context_size
|
||||||
|
|||||||
@ -120,7 +120,6 @@ class Transducer(nn.Module):
|
|||||||
target_lengths=y_lens,
|
target_lengths=y_lens,
|
||||||
blank=blank_id,
|
blank=blank_id,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
from_log_softmax=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@ -50,7 +50,7 @@ import kaldifeat
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -167,6 +167,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
|
unk_id=params.unk_id,
|
||||||
context_size=params.context_size,
|
context_size=params.context_size,
|
||||||
)
|
)
|
||||||
return decoder
|
return decoder
|
||||||
@ -230,6 +231,7 @@ def main():
|
|||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
@ -300,6 +302,10 @@ def main():
|
|||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
)
|
)
|
||||||
|
elif params.method == "modified_beam_search":
|
||||||
|
hyp = modified_beam_search(
|
||||||
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
|
# Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -18,7 +19,7 @@
|
|||||||
"""
|
"""
|
||||||
To run this file, do:
|
To run this file, do:
|
||||||
|
|
||||||
cd icefall/egs/librispeech/ASR
|
cd icefall/egs/tedlium3/ASR
|
||||||
python ./transducer_stateless/test_decoder.py
|
python ./transducer_stateless/test_decoder.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ from decoder import Decoder
|
|||||||
def test_decoder():
|
def test_decoder():
|
||||||
vocab_size = 3
|
vocab_size = 3
|
||||||
blank_id = 0
|
blank_id = 0
|
||||||
|
unk_id = 2
|
||||||
embedding_dim = 128
|
embedding_dim = 128
|
||||||
context_size = 4
|
context_size = 4
|
||||||
|
|
||||||
@ -36,6 +38,7 @@ def test_decoder():
|
|||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
blank_id=blank_id,
|
blank_id=blank_id,
|
||||||
|
unk_id=unk_id,
|
||||||
context_size=context_size,
|
context_size=context_size,
|
||||||
)
|
)
|
||||||
N = 100
|
N = 100
|
||||||
|
|||||||
@ -26,9 +26,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--exp-dir transducer_stateless/exp \
|
--exp-dir transducer_stateless/exp \
|
||||||
--full-libri 1 \
|
--max-duration 180 \
|
||||||
--max-duration 250 \
|
--lr-factor 5.0
|
||||||
--lr-factor 2.5
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -56,6 +55,8 @@ from torch.nn.utils import clip_grad_norm_
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from transformer import Noam
|
from transformer import Noam
|
||||||
|
|
||||||
|
from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
@ -233,6 +234,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
|
unk_id=params.unk_id,
|
||||||
context_size=params.context_size,
|
context_size=params.context_size,
|
||||||
)
|
)
|
||||||
return decoder
|
return decoder
|
||||||
@ -379,7 +381,9 @@ def compute_loss(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)[: feature.size(0)]
|
feature_lens = supervisions["num_frames"].to(device)[: feature.size(0)]
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"][: feature.size(0)]
|
texts = batch["supervisions"]["text"][: feature.size(0)]
|
||||||
y = sp.encode(texts, out_type=int)
|
|
||||||
|
unk_id = params.unk_id
|
||||||
|
y = convert_texts_into_ids(texts, unk_id, sp=sp)
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
@ -565,6 +569,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user