mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
update aishell-1 recipe with k2.rnnt_loss (#215)
* update aishell-1 recipe with k2.rnnt_loss * fix flak8 style * typo * add pretrained model link to result.md
This commit is contained in:
parent
827b9df51a
commit
277cc3f9bf
@ -113,7 +113,7 @@ The best CER we currently have is:
|
||||
|
||||
| | test |
|
||||
|-----|------|
|
||||
| CER | 5.7 |
|
||||
| CER | 5.4 |
|
||||
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
|
||||
|
@ -45,12 +45,13 @@ You can use the following commands to reproduce our results:
|
||||
|
||||
|
||||
### Aishell training results (Transducer-stateless)
|
||||
#### 2021-12-29
|
||||
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/sPEDmAQ3QcWuDAWGiKprVg/>
|
||||
#### 2022-02-18
|
||||
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/SG1KV62hRzO5YZswwMQnoQ/>
|
||||
And pretrained model is available at <https://huggingface.co/pfluo/icefall-aishell-transducer-stateless-char-2021-12-29>
|
||||
|
||||
||test|
|
||||
|--|--|
|
||||
|CER| 5.7% |
|
||||
|CER| 5.4% |
|
||||
|
||||
You can use the following commands to reproduce our results:
|
||||
|
||||
|
@ -31,7 +31,6 @@ from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from model import Transducer
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
@ -403,12 +402,9 @@ def main():
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||
# params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
@ -108,18 +108,18 @@ class Transducer(nn.Module):
|
||||
# Note: y does not start with SOS
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
||||
f"Current torchaudio version: {torchaudio.__version__}\n"
|
||||
"Please install a version >= 0.10.0"
|
||||
)
|
||||
|
||||
loss = torchaudio.functional.rnnt_loss(
|
||||
logits=logits,
|
||||
targets=y_padded,
|
||||
logit_lengths=x_lens,
|
||||
target_lengths=y_lens,
|
||||
blank=blank_id,
|
||||
reduction="sum",
|
||||
)
|
||||
loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary)
|
||||
|
||||
return loss
|
||||
return torch.sum(loss)
|
||||
|
@ -558,7 +558,8 @@ def run(rank, world_size, args):
|
||||
oov="<unk>",
|
||||
)
|
||||
|
||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||
# params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
Loading…
x
Reference in New Issue
Block a user