mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
update code for merging
This commit is contained in:
parent
5e507f3514
commit
676c257ba8
2
.flake8
2
.flake8
@ -4,6 +4,7 @@ statistics=true
|
||||
max-line-length = 80
|
||||
per-file-ignores =
|
||||
# line too long
|
||||
icefall/diagnostics.py: E501
|
||||
egs/*/ASR/*/conformer.py: E501,
|
||||
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
|
||||
egs/*/ASR/*/optim.py: E501,
|
||||
@ -11,6 +12,7 @@ per-file-ignores =
|
||||
|
||||
# invalid escape sequence (cause by tex formular), W605
|
||||
icefall/utils.py: E501, W605
|
||||
|
||||
exclude =
|
||||
.git,
|
||||
**/data/**,
|
||||
|
55
README.md
55
README.md
@ -12,7 +12,7 @@ for installation.
|
||||
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/index.html>
|
||||
for more information.
|
||||
|
||||
We provide 6 recipes at present:
|
||||
We provide some recipes:
|
||||
|
||||
- [yesno][yesno]
|
||||
- [LibriSpeech][librispeech]
|
||||
@ -20,6 +20,9 @@ We provide 6 recipes at present:
|
||||
- [TIMIT][timit]
|
||||
- [TED-LIUM3][tedlium3]
|
||||
- [GigaSpeech][gigaspeech]
|
||||
- [Aidatatang_200zh][aidatatang_200zh]
|
||||
- [WenetSpeech][wenetspeech]
|
||||
- [Alimeeting][alimeeting]
|
||||
|
||||
### yesno
|
||||
|
||||
@ -124,7 +127,7 @@ The best CER we currently have is:
|
||||
| CER | 4.26 |
|
||||
|
||||
|
||||
We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
|
||||
We provide a Colab notebook to run a pre-trained conformer CTC model: [
|
||||
|
||||
#### Transducer Stateless Model
|
||||
|
||||
@ -217,6 +220,47 @@ and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned R
|
||||
| fast beam search | 10.50 | 10.69 |
|
||||
| modified beam search | 10.40 | 10.51 |
|
||||
|
||||
### Aidatatang_200zh
|
||||
|
||||
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aidatatang_200zh_pruned_transducer_stateless2].
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
||||
|
||||
| | Dev | Test |
|
||||
|----------------------|-------|-------|
|
||||
| greedy search | 5.53 | 6.59 |
|
||||
| fast beam search | 5.30 | 6.34 |
|
||||
| modified beam search | 5.27 | 6.33 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing)
|
||||
|
||||
### WenetSpeech
|
||||
|
||||
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2].
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
|
||||
|
||||
| | Dev | Test-Net | Test-Meeting |
|
||||
|----------------------|-------|----------|--------------|
|
||||
| greedy search | 7.80 | 8.75 | 13.49 |
|
||||
| fast beam search | 7.94 | 8.74 | 13.80 |
|
||||
| modified beam search | 7.76 | 8.71 | 13.41 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
|
||||
|
||||
### Alimeeting
|
||||
|
||||
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Alimeeting_pruned_transducer_stateless2].
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with far subset)
|
||||
|
||||
| | Eval | Test-Net |
|
||||
|----------------------|--------|----------|
|
||||
| greedy search | 31.77 | 34.66 |
|
||||
| fast beam search | 31.39 | 33.02 |
|
||||
| modified beam search | 30.38 | 34.25 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
|
||||
|
||||
## Deployment with C++
|
||||
|
||||
@ -243,10 +287,17 @@ Please see: [
|
||||
|
19
egs/alimeeting/ASR/README.md
Normal file
19
egs/alimeeting/ASR/README.md
Normal file
@ -0,0 +1,19 @@
|
||||
|
||||
# Introduction
|
||||
|
||||
This recipe includes some different ASR models trained with Alimeeting (far).
|
||||
|
||||
[./RESULTS.md](./RESULTS.md) contains the latest results.
|
||||
|
||||
# Transducers
|
||||
|
||||
There are various folders containing the name `transducer` in this folder.
|
||||
The following table lists the differences among them.
|
||||
|
||||
| | Encoder | Decoder | Comment |
|
||||
|---------------------------------------|---------------------|--------------------|-----------------------------|
|
||||
| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | |
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
We place an additional Conv1d layer right after the input embedding layer.
|
71
egs/alimeeting/ASR/RESULTS.md
Normal file
71
egs/alimeeting/ASR/RESULTS.md
Normal file
@ -0,0 +1,71 @@
|
||||
## Results
|
||||
|
||||
### Alimeeting Char training results (Pruned Transducer Stateless2)
|
||||
|
||||
#### 2022-06-01
|
||||
|
||||
Using the codes from this PR https://github.com/k2-fsa/icefall/pull/378.
|
||||
|
||||
The WERs are
|
||||
| | eval | test | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| greedy search | 31.77 | 34.66 | --epoch 29, --avg 18, --max-duration 100 |
|
||||
| modified beam search (beam size 4) | 30.38 | 33.02 | --epoch 29, --avg 18, --max-duration 100 |
|
||||
| fast beam search (set as default) | 31.39 | 34.25 | --epoch 29, --avg 18, --max-duration 1500|
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 220 \
|
||||
--save-every-n 1000
|
||||
|
||||
```
|
||||
|
||||
The tensorboard training log can be found at
|
||||
https://tensorboard.dev/experiment/AoqgSvZKTZCJhJbOuG3W6g/#scalars
|
||||
|
||||
The decoding command is:
|
||||
```
|
||||
epoch=29
|
||||
avg=18
|
||||
|
||||
## greedy search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--lang-dir ./data/lang_char \
|
||||
--max-duration 100
|
||||
|
||||
## modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--lang-dir ./data/lang_char \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
## fast beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir ./data/lang_char \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
```
|
||||
|
||||
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_alimeeting_pruned_transducer_stateless2>
|
@ -31,7 +31,6 @@ from lhotse import (
|
||||
set_caching_enabled,
|
||||
)
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
@ -290,13 +289,13 @@ class AlimeetingAsrDataModule:
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using BucketingSampler.")
|
||||
train_sampler = BucketingSampler(
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
bucket_method="equal_duration",
|
||||
buffer_size=30000,
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
|
@ -16,11 +16,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
When training with the L subset, usage:
|
||||
When training with the far data, usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 6 \
|
||||
--avg 3 \
|
||||
--epoch 29 \
|
||||
--avg 18 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 100 \
|
||||
@ -28,8 +28,8 @@ When training with the L subset, usage:
|
||||
|
||||
(2) modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 6 \
|
||||
--avg 3 \
|
||||
--epoch 29 \
|
||||
--avg 18 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 100 \
|
||||
@ -38,8 +38,8 @@ When training with the L subset, usage:
|
||||
|
||||
(3) fast beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 6 \
|
||||
--avg 3 \
|
||||
--epoch 29 \
|
||||
--avg 18 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 1500 \
|
||||
@ -59,7 +59,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import Aidatatang_200zhAsrDataModule
|
||||
from asr_datamodule import AlimeetingAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
@ -67,6 +67,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.cut import Cut
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
@ -248,7 +249,6 @@ def decode_one_batch(
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
@ -442,7 +442,7 @@ def save_results(
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
Aidatatang_200zhAsrDataModule.add_arguments(parser)
|
||||
AlimeetingAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
@ -508,6 +508,13 @@ def main():
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
average = average_checkpoints(filenames, device=device)
|
||||
checkpoint = {"model": average}
|
||||
torch.save(
|
||||
checkpoint,
|
||||
"pruned_transducer_stateless2/exp/pretrained_epoch_29_avg_18.pt",
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
@ -528,14 +535,14 @@ def main():
|
||||
from lhotse import CutSet
|
||||
from lhotse.dataset.webdataset import export_to_webdataset
|
||||
|
||||
aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
|
||||
alimeeting = AlimeetingAsrDataModule(args)
|
||||
|
||||
dev = "dev"
|
||||
dev = "eval"
|
||||
test = "test"
|
||||
|
||||
if not os.path.exists(f"{dev}/shared-0.tar"):
|
||||
os.makedirs(dev)
|
||||
dev_cuts = aidatatang_200zh.valid_cuts()
|
||||
dev_cuts = alimeeting.valid_cuts()
|
||||
export_to_webdataset(
|
||||
dev_cuts,
|
||||
output_path=f"{dev}/shared-%d.tar",
|
||||
@ -544,7 +551,7 @@ def main():
|
||||
|
||||
if not os.path.exists(f"{test}/shared-0.tar"):
|
||||
os.makedirs(test)
|
||||
test_cuts = aidatatang_200zh.test_cuts()
|
||||
test_cuts = alimeeting.test_cuts()
|
||||
export_to_webdataset(
|
||||
test_cuts,
|
||||
output_path=f"{test}/shared-%d.tar",
|
||||
@ -573,8 +580,16 @@ def main():
|
||||
shuffle_shards=True,
|
||||
)
|
||||
|
||||
dev_dl = aidatatang_200zh.valid_dataloaders(cuts_dev_webdataset)
|
||||
test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
return 1.0 <= c.duration
|
||||
|
||||
cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
|
||||
cuts_test_webdataset = cuts_test_webdataset.filter(
|
||||
remove_short_and_long_utt
|
||||
)
|
||||
|
||||
dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
|
||||
test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)
|
||||
|
||||
test_sets = ["dev", "test"]
|
||||
test_dl = [dev_dl, test_dl]
|
||||
|
@ -22,7 +22,7 @@ Usage:
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--epoch 29 \
|
||||
--avg 19
|
||||
--avg 18
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
|
||||
@ -32,7 +32,7 @@ you can do:
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/aidatatang_200zh/ASR
|
||||
cd /path/to/egs/alimeeting/ASR
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--epoch 9999 \
|
||||
|
@ -16,13 +16,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
Here, the far data is used for training, usage:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method greedy_search \
|
||||
--decoding-method greedy_search \
|
||||
--max-sym-per-frame 1 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
@ -31,7 +31,7 @@ Usage:
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method modified_beam_search \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
@ -40,7 +40,7 @@ Usage:
|
||||
./pruned_transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method fast_beam_search \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8 \
|
||||
|
@ -19,26 +19,26 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1"
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
--world-size 2 \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 250 \
|
||||
--max-duration 220 \
|
||||
--save-every-n 1000
|
||||
|
||||
# For mix precision training:
|
||||
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
--world-size 2 \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 250 \
|
||||
--max-duration 220 \
|
||||
--save-every-n 1000
|
||||
--use-fp16 True
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user