diff --git a/README.md b/README.md
index 7dee1c1d6..38c25900f 100644
--- a/README.md
+++ b/README.md
@@ -84,7 +84,7 @@ The best WER using beam search with beam size 4 is:
| | test-clean | test-other |
|-----|------------|------------|
-| WER | 2.76 | 6.97 |
+| WER | 2.68 | 6.72 |
Note: No auxiliary losses are used in training, and no LMs are used
in decoding.
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 9f65f56bd..ffeaaae68 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -13,8 +13,8 @@ The WERs are
| | test-clean | test-other | comment |
|---------------------------|------------|------------|------------------------------------------|
-| greedy search | 2.77 | 7.07 | --epoch 30, --avg 13, --max-duration 100 |
-| beam search (beam size 4) | 2.76 | 6.97 | |
+| greedy search | 2.69 | 6.81 | --epoch 71, --avg 15, --max-duration 100 |
+| beam search (beam size 4) | 2.68 | 6.72 | --epoch 71, --avg 15, --max-duration 100 |
The training command to reproduce the results is given below:
@@ -23,7 +23,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless/train.py \
--world-size 4 \
- --num-epochs 30 \
+ --num-epochs 76 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp-full \
--full-libri 1 \
@@ -32,12 +32,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
```
The tensorboard training log can be found at
-
+
The decoding command is:
```
-epoch=36
-avg=13
+epoch=71
+avg=15
## greedy search
./transducer_stateless/decode.py \
@@ -58,6 +58,9 @@ avg=13
--beam-size 4
```
+You can find a pretrained model by visiting
+
+
#### Conformer encoder + LSTM decoder
Using commit `8187d6236c2926500da5ee854f758e621df803cc`.
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 989caa802..341c74fab 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -118,7 +118,7 @@ class Hypothesis:
class HypothesisList(object):
- def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+ def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
@@ -130,11 +130,10 @@ class HypothesisList(object):
self._data = data
@property
- def data(self):
+ def data(self) -> Dict[str, Hypothesis]:
return self._data
- # def add(self, ys: List[int], log_prob: float):
- def add(self, hyp: Hypothesis):
+ def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
@@ -159,7 +158,8 @@ class HypothesisList(object):
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
-
+ Returns:
+ Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(
@@ -171,6 +171,9 @@ class HypothesisList(object):
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
+ Caution:
+ `self` is modified **in-place**.
+
Args:
hyp:
The hypothesis to be removed from `self`.
@@ -189,10 +192,10 @@ class HypothesisList(object):
Returns:
Return a new HypothesisList containing all hypotheses from `self`
- that have `log_prob` being greater than the given `threshold`.
+          with `log_prob` greater than the given `threshold`.
"""
ans = HypothesisList()
- for key, hyp in self._data.items():
+        for hyp in self._data.values():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
@@ -222,6 +225,93 @@ class HypothesisList(object):
return ", ".join(s)
+def run_decoder(
+ ys: List[int],
+ model: Transducer,
+ decoder_cache: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+ """Run the neural decoder model for a given hypothesis.
+
+ Args:
+ ys:
+ The current hypothesis.
+ model:
+ The transducer model.
+ decoder_cache:
+ Cache to save computations.
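+        The key is formed by joining the last `context_size` tokens
+        of `ys` with "_"; e.g., for context_size=2 and ys ending in
+        [5, 9], the entry is cached under the key "5_9".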
+    Returns:
+      Return a tensor of shape (1, 1, decoder_out_dim) containing
+      the output of `model.decoder`.
+ """
+ context_size = model.decoder.context_size
+ key = "_".join(map(str, ys[-context_size:]))
+ if key in decoder_cache:
+ return decoder_cache[key]
+
+ device = model.device
+
+ decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
+ 1, context_size
+ )
+
+ decoder_out = model.decoder(decoder_input, need_pad=False)
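+    # decoder_out is of shape (1, 1, decoder_out_dim): the input has
+    # exactly context_size tokens and need_pad is False, so the decoder
+    # emits a single frame.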
+ decoder_cache[key] = decoder_out
+
+ return decoder_out
+
+
+def run_joiner(
+ key: str,
+ model: Transducer,
+ encoder_out: torch.Tensor,
+ decoder_out: torch.Tensor,
+ encoder_out_len: torch.Tensor,
+ decoder_out_len: torch.Tensor,
+ joint_cache: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+ """Run the joint network given outputs from the encoder and decoder.
+
+ Args:
+ key:
+ A key into the `joint_cache`.
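+        In `beam_search` below it is the decoder cache key of the
+        current hypothesis plus the time index, e.g., "5_9-t-3".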
+ model:
+ The transducer model.
+ encoder_out:
+ A tensor of shape (1, 1, encoder_out_dim).
+ decoder_out:
+ A tensor of shape (1, 1, decoder_out_dim).
+ encoder_out_len:
+ A tensor with value [1].
+ decoder_out_len:
+ A tensor with value [1].
+ joint_cache:
+ A dict to save computations.
+ Returns:
+      Return the output of log-softmax as a 1-D tensor of shape
+      (vocab_size,).
+ """
+ if key in joint_cache:
+ return joint_cache[key]
+
+ logits = model.joiner(
+ encoder_out,
+ decoder_out,
+ encoder_out_len,
+ decoder_out_len,
+ )
+
+ # TODO(fangjun): Scale the blank posterior
+ log_prob = logits.log_softmax(dim=-1)
+ # log_prob is (1, 1, 1, vocab_size)
+
+ log_prob = log_prob.squeeze()
+ # Now log_prob is (vocab_size,)
+
+ joint_cache[key] = log_prob
+
+ return log_prob
+
+
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
@@ -288,36 +378,21 @@ def beam_search(
y_star = A.get_most_probable()
A.remove(y_star)
- cached_key = y_star.key
+ decoder_out = run_decoder(
+ ys=y_star.ys, model=model, decoder_cache=decoder_cache
+ )
- if cached_key not in decoder_cache:
- decoder_input = torch.tensor(
- [y_star.ys[-context_size:]], device=device
- ).reshape(1, context_size)
-
- decoder_out = model.decoder(decoder_input, need_pad=False)
- decoder_cache[cached_key] = decoder_out
- else:
- decoder_out = decoder_cache[cached_key]
-
- cached_key += f"-t-{t}"
- if cached_key not in joint_cache:
- logits = model.joiner(
- current_encoder_out,
- decoder_out,
- encoder_out_len,
- decoder_out_len,
- )
-
- # TODO(fangjun): Ccale the blank posterior
-
- log_prob = logits.log_softmax(dim=-1)
- # log_prob is (1, 1, 1, vocab_size)
- log_prob = log_prob.squeeze()
- # Now log_prob is (vocab_size,)
- joint_cache[cached_key] = log_prob
- else:
- log_prob = joint_cache[cached_key]
+ key = "_".join(map(str, y_star.ys[-context_size:]))
+ key += f"-t-{t}"
+ log_prob = run_joiner(
+ key=key,
+ model=model,
+ encoder_out=current_encoder_out,
+ decoder_out=decoder_out,
+ encoder_out_len=encoder_out_len,
+ decoder_out_len=decoder_out_len,
+ joint_cache=joint_cache,
+ )
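+            # log_prob is a 1-D tensor of shape (vocab_size,)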
# First, process the blank symbol
skip_log_prob = log_prob[blank_id]