update aishell-1 recipe with k2.rnnt_loss (#215)

* update aishell-1 recipe with k2.rnnt_loss

* fix flake8 style

* typo

* add pretrained model link to result.md
This commit is contained in:
PF Luo 2022-02-19 15:56:39 +08:00 committed by GitHub
parent 827b9df51a
commit 277cc3f9bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 18 additions and 20 deletions

View File

@ -113,7 +113,7 @@ The best CER we currently have is:
| | test |
|-----|------|
| CER | 5.7 |
| CER | 5.4 |
We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)

View File

@ -45,12 +45,13 @@ You can use the following commands to reproduce our results:
### Aishell training results (Transducer-stateless)
#### 2021-12-29
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/sPEDmAQ3QcWuDAWGiKprVg/>
#### 2022-02-18
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/SG1KV62hRzO5YZswwMQnoQ/>
And the pretrained model is available at <https://huggingface.co/pfluo/icefall-aishell-transducer-stateless-char-2021-12-29>
||test|
|--|--|
|CER| 5.7% |
|CER| 5.4% |
You can use the following commands to reproduce our results:

View File

@ -31,7 +31,6 @@ from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
@ -403,12 +402,9 @@ def main():
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
# params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)

View File

@ -108,18 +108,18 @@ class Transducer(nn.Module):
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
)
loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary)
return loss
return torch.sum(loss)

View File

@ -558,7 +558,8 @@ def run(rank, world_size, args):
oov="<unk>",
)
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
# params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)