diff --git a/README.md b/README.md
index 28c9b6ce4..214e85ad0 100644
--- a/README.md
+++ b/README.md
@@ -113,7 +113,7 @@ The best CER we currently have is:
| | test |
|-----|------|
-| CER | 5.7 |
+| CER | 5.4 |
We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 61f7e500e..688e0f60c 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -45,12 +45,13 @@ You can use the following commands to reproduce our results:
### Aishell training results (Transducer-stateless)
-#### 2021-12-29
-(Pingfeng Luo) : The tensorboard log for training is available at
+#### 2022-02-18
+(Pingfeng Luo) : The tensorboard log for training is available at
+And the pretrained model is available at
||test|
|--|--|
-|CER| 5.7% |
+|CER| 5.4% |
You can use the following commands to reproduce our results:
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index f27e4cdcf..a7b030fa5 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -31,7 +31,6 @@ from decoder import Decoder
from joiner import Joiner
from model import Transducer
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
@@ -403,12 +402,9 @@ def main():
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
- graph_compiler = CharCtcTrainingGraphCompiler(
- lexicon=lexicon,
- device=device,
- )
- params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+ # params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+ params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
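Both here and in train.py the change hard-codes `params.blank_id = 0`, relying on the blank token sitting at index 0 of the token table instead of looking it up through the graph compiler. A small sanity check in that spirit is sketched below; the `data/lang_char/tokens.txt` path, its one-symbol-id-pair-per-line layout, and the `<blk>` symbol name are assumptions for illustration, not something this patch enforces.

```python
# Sketch only: verify that the blank token really has id 0 before relying on
# params.blank_id = 0.  Assumes a tokens.txt with one "<symbol> <id>" pair
# per line and a blank symbol written as "<blk>".
from pathlib import Path


def read_token_ids(tokens_file: Path) -> dict:
    """Map every symbol in tokens.txt to its integer id."""
    token2id = {}
    for line in tokens_file.read_text(encoding="utf-8").splitlines():
        if line.strip():
            sym, idx = line.rsplit(maxsplit=1)
            token2id[sym] = int(idx)
    return token2id


token2id = read_token_ids(Path("data/lang_char/tokens.txt"))
blank_id = token2id.get("<blk>", 0)  # fall back to 0 if the symbol is absent
assert blank_id == 0, f"Expected the blank at index 0, got {blank_id}"
```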
diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py
index 2f0f9a183..0322edeed 100644
--- a/egs/aishell/ASR/transducer_stateless/model.py
+++ b/egs/aishell/ASR/transducer_stateless/model.py
@@ -108,18 +108,18 @@ class Transducer(nn.Module):
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
+ y_padded = y_padded.to(torch.int64)
+ boundary = torch.zeros(
+ (x.size(0), 4), dtype=torch.int64, device=x.device
+ )
+ boundary[:, 2] = y_lens
+ boundary[:, 3] = x_lens
+
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
- loss = torchaudio.functional.rnnt_loss(
- logits=logits,
- targets=y_padded,
- logit_lengths=x_lens,
- target_lengths=y_lens,
- blank=blank_id,
- reduction="sum",
- )
+ loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary)
- return loss
+ return torch.sum(loss)
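For reference, a minimal self-contained sketch of how the padded targets and the `boundary` tensor built above feed `k2.rnnt_loss`; the batch size, lengths, and vocabulary below are dummy values for illustration, and the final sum mirrors what the patched `forward` returns.

```python
# Illustrative shapes only; in the recipe the logits come from the joiner.
import torch
import k2

B, T, S, C = 2, 10, 5, 30  # batch, frames, max target length, vocab size
blank_id = 0

logits = torch.randn(B, T, S + 1, C, dtype=torch.float32)
symbols = torch.randint(1, C, (B, S), dtype=torch.int64)  # padded targets

# boundary rows are [0, 0, num_symbols, num_frames] per utterance
y_lens = torch.tensor([5, 3], dtype=torch.int64)
x_lens = torch.tensor([10, 8], dtype=torch.int64)
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens

loss = k2.rnnt_loss(logits, symbols, blank_id, boundary)
total = torch.sum(loss)  # mirrors the recipe's torch.sum(loss)
print(total)
```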
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index 0c180b260..b562f9dd4 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -558,7 +558,8 @@ def run(rank, world_size, args):
oov="",
)
- params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+ # params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+ params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
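Since decode.py and train.py now each set `params.blank_id = 0` on their own, a shared guard like the hypothetical helper below would keep the two scripts consistent; it only restates the assumptions already implicit in this patch (blank at index 0, `lexicon.tokens` holding the non-blank token ids used for `vocab_size`).

```python
def check_token_config(blank_id: int, vocab_size: int, tokens) -> None:
    """Hypothetical guard: the blank occupies id 0 and every lexicon token
    id fits inside the model's output dimension."""
    assert blank_id == 0, blank_id
    assert min(tokens) > blank_id, "lexicon tokens must not reuse the blank id"
    assert max(tokens) + 1 == vocab_size, (max(tokens), vocab_size)


# In either script, right after params.blank_id / params.vocab_size are set:
# check_token_config(params.blank_id, params.vocab_size, lexicon.tokens)
```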