mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
update aishell-1 recipe with k2.rnnt_loss (#215)
* update aishell-1 recipe with k2.rnnt_loss * fix flake8 style * typo * add pretrained model link to result.md
This commit is contained in:
parent
827b9df51a
commit
277cc3f9bf
@ -113,7 +113,7 @@ The best CER we currently have is:
|
|||||||
|
|
||||||
| | test |
|
| | test |
|
||||||
|-----|------|
|
|-----|------|
|
||||||
| CER | 5.7 |
|
| CER | 5.4 |
|
||||||
|
|
||||||
|
|
||||||
We provide a Colab notebook to run a pre-trained TransducerStateless model: [Open in Colab](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
|
We provide a Colab notebook to run a pre-trained TransducerStateless model: [Open in Colab](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
|
||||||
|
@ -45,12 +45,13 @@ You can use the following commands to reproduce our results:
|
|||||||
|
|
||||||
|
|
||||||
### Aishell training results (Transducer-stateless)
|
### Aishell training results (Transducer-stateless)
|
||||||
#### 2021-12-29
|
#### 2022-02-18
|
||||||
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/sPEDmAQ3QcWuDAWGiKprVg/>
|
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/SG1KV62hRzO5YZswwMQnoQ/>
|
||||||
|
The pretrained model is available at <https://huggingface.co/pfluo/icefall-aishell-transducer-stateless-char-2021-12-29>
|
||||||
|
|
||||||
||test|
|
||test|
|
||||||
|--|--|
|
|--|--|
|
||||||
|CER| 5.7% |
|
|CER| 5.4% |
|
||||||
|
|
||||||
You can use the following commands to reproduce our results:
|
You can use the following commands to reproduce our results:
|
||||||
|
|
||||||
|
@ -31,7 +31,6 @@ from decoder import Decoder
|
|||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
@ -403,12 +402,9 @@ def main():
|
|||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
|
||||||
lexicon=lexicon,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
# params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||||
|
params.blank_id = 0
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
@ -108,18 +108,18 @@ class Transducer(nn.Module):
|
|||||||
# Note: y does not start with SOS
|
# Note: y does not start with SOS
|
||||||
y_padded = y.pad(mode="constant", padding_value=0)
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
|
||||||
|
y_padded = y_padded.to(torch.int64)
|
||||||
|
boundary = torch.zeros(
|
||||||
|
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||||
|
)
|
||||||
|
boundary[:, 2] = y_lens
|
||||||
|
boundary[:, 3] = x_lens
|
||||||
|
|
||||||
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
||||||
f"Current torchaudio version: {torchaudio.__version__}\n"
|
f"Current torchaudio version: {torchaudio.__version__}\n"
|
||||||
"Please install a version >= 0.10.0"
|
"Please install a version >= 0.10.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = torchaudio.functional.rnnt_loss(
|
loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary)
|
||||||
logits=logits,
|
|
||||||
targets=y_padded,
|
|
||||||
logit_lengths=x_lens,
|
|
||||||
target_lengths=y_lens,
|
|
||||||
blank=blank_id,
|
|
||||||
reduction="sum",
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss
|
return torch.sum(loss)
|
||||||
|
@ -558,7 +558,8 @@ def run(rank, world_size, args):
|
|||||||
oov="<unk>",
|
oov="<unk>",
|
||||||
)
|
)
|
||||||
|
|
||||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
# params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||||
|
params.blank_id = 0
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user