mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00

Refactor beam search and update results. (#177)

This commit is contained in:
parent 273e5fb2f3
commit f94ff19bfe
@@ -84,7 +84,7 @@ The best WER using beam search with beam size 4 is:

 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 2.76       | 6.97       |
+| WER | 2.68       | 6.72       |

 Note: No auxiliary losses are used in the training and no LMs are used
 in the decoding.
@@ -13,8 +13,8 @@ The WERs are

 |                           | test-clean | test-other | comment                                  |
 |---------------------------|------------|------------|------------------------------------------|
-| greedy search             | 2.77       | 7.07       | --epoch 30, --avg 13, --max-duration 100 |
-| beam search (beam size 4) | 2.76       | 6.97       |                                          |
+| greedy search             | 2.69       | 6.81       | --epoch 71, --avg 15, --max-duration 100 |
+| beam search (beam size 4) | 2.68       | 6.72       | --epoch 71, --avg 15, --max-duration 100 |

 The training command for reproducing is given below:
@@ -23,7 +23,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 ./transducer_stateless/train.py \
   --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 76 \
   --start-epoch 0 \
   --exp-dir transducer_stateless/exp-full \
   --full-libri 1 \
@@ -32,12 +32,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ```

 The tensorboard training log can be found at
-<https://tensorboard.dev/experiment/6fnVojoUQTmEJVq1yG34Vw/>
+<https://tensorboard.dev/experiment/qGdqzHnxS0WJ695OXfZDzA/#scalars&_smoothingWeight=0>

 The decoding command is:
 ```
-epoch=36
-avg=13
+epoch=71
+avg=15

 ## greedy search
 ./transducer_stateless/decode.py \
@@ -58,6 +58,9 @@ avg=13
   --beam-size 4
 ```

+You can find a pretrained model by visiting
+<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10>
+

 #### Conformer encoder + LSTM decoder
 Using commit `8187d6236c2926500da5ee854f758e621df803cc`.
@@ -118,7 +118,7 @@ class Hypothesis:


 class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
         """
         Args:
           data:
@@ -130,11 +130,10 @@ class HypothesisList(object):
             self._data = data

     @property
-    def data(self):
+    def data(self) -> Dict[str, Hypothesis]:
         return self._data

-    # def add(self, ys: List[int], log_prob: float):
-    def add(self, hyp: Hypothesis):
+    def add(self, hyp: Hypothesis) -> None:
         """Add a Hypothesis to `self`.

         If `hyp` already exists in `self`, its probability is updated using
@@ -159,7 +158,8 @@ class HypothesisList(object):
           length_norm:
             If True, the `log_prob` of a hypothesis is normalized by the
             number of tokens in it.
+
         Returns:
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
             return max(
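For reference, the effect of `length_norm` can be illustrated with a tiny, self-contained sketch. The keys, scores, and token counts below are made up for illustration; only the normalization rule (divide `log_prob` by the number of tokens) comes from the docstring in this hunk.

```python
# Hypothetical scores: key -> (log_prob, number of tokens).
hyps = {"1_2": (-3.0, 2), "1_2_3_4": (-5.0, 4)}

best_raw = max(hyps, key=lambda k: hyps[k][0])                # "1_2"
best_norm = max(hyps, key=lambda k: hyps[k][0] / hyps[k][1])  # "1_2_3_4"
print(best_raw, best_norm)
```

Without normalization the shorter hypothesis wins on raw score; with normalization the longer hypothesis is no longer penalized merely for having more tokens.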
@@ -171,6 +171,9 @@ class HypothesisList(object):
     def remove(self, hyp: Hypothesis) -> None:
         """Remove a given hypothesis.

+        Caution:
+          `self` is modified **in-place**.
+
         Args:
           hyp:
             The hypothesis to be removed from `self`.
@@ -189,10 +192,10 @@ class HypothesisList(object):

         Returns:
           Return a new HypothesisList containing all hypotheses from `self`
-          that have `log_prob` being greater than the given `threshold`.
+          with `log_prob` being greater than the given `threshold`.
         """
         ans = HypothesisList()
-        for key, hyp in self._data.items():
+        for _, hyp in self._data.items():
             if hyp.log_prob > threshold:
                 ans.add(hyp)  # shallow copy
         return ans
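Taken together, `add`, `get_most_probable`, and `filter` behave roughly as in the sketch below. Note the assumptions: the `Hyp` class is a hypothetical stand-in for `Hypothesis`, the key is assumed to be the token sequence rendered as a string, and the log-add merge inside `add` is inferred from the truncated docstring above ("its probability is updated using ..."); the library's exact update rule may differ.

```python
import math
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Hyp:  # hypothetical stand-in for beam_search.Hypothesis
    ys: List[int]
    log_prob: float

    @property
    def key(self) -> str:
        return "_".join(map(str, self.ys))


data: Dict[str, Hyp] = {}


def add(hyp: Hyp) -> None:
    # Merge duplicates by key; the log-add below is an assumed update rule.
    if hyp.key in data:
        old = data[hyp.key]
        old.log_prob = math.log(math.exp(old.log_prob) + math.exp(hyp.log_prob))
    else:
        data[hyp.key] = hyp


add(Hyp([1, 2], -2.0))
add(Hyp([1, 2], -2.0))   # same key -> merged, log_prob rises to about -1.31
add(Hyp([3], -4.0))

best = max(data.values(), key=lambda h: h.log_prob)        # like get_most_probable
kept = [h.ys for h in data.values() if h.log_prob > -3.0]  # like filter(-3.0)
print(best.ys, kept)     # [1, 2] [[1, 2]]
```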
@@ -222,6 +225,93 @@ class HypothesisList(object):
         return ", ".join(s)


+def run_decoder(
+    ys: List[int],
+    model: Transducer,
+    decoder_cache: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Run the neural decoder model for a given hypothesis.
+
+    Args:
+      ys:
+        The current hypothesis.
+      model:
+        The transducer model.
+      decoder_cache:
+        Cache to save computations.
+    Returns:
+      Return a 1-D tensor of shape (decoder_out_dim,) containing
+      output of `model.decoder`.
+    """
+    context_size = model.decoder.context_size
+    key = "_".join(map(str, ys[-context_size:]))
+    if key in decoder_cache:
+        return decoder_cache[key]
+
+    device = model.device
+
+    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
+        1, context_size
+    )
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    decoder_cache[key] = decoder_out
+
+    return decoder_out
+
+
+def run_joiner(
+    key: str,
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    decoder_out: torch.Tensor,
+    encoder_out_len: torch.Tensor,
+    decoder_out_len: torch.Tensor,
+    joint_cache: Dict[str, torch.Tensor],
+):
+    """Run the joint network given outputs from the encoder and decoder.
+
+    Args:
+      key:
+        A key into the `joint_cache`.
+      model:
+        The transducer model.
+      encoder_out:
+        A tensor of shape (1, 1, encoder_out_dim).
+      decoder_out:
+        A tensor of shape (1, 1, decoder_out_dim).
+      encoder_out_len:
+        A tensor with value [1].
+      decoder_out_len:
+        A tensor with value [1].
+      joint_cache:
+        A dict to save computations.
+    Returns:
+      Return a tensor from the output of log-softmax.
+      Its shape is (vocab_size,).
+    """
+    if key in joint_cache:
+        return joint_cache[key]
+
+    logits = model.joiner(
+        encoder_out,
+        decoder_out,
+        encoder_out_len,
+        decoder_out_len,
+    )
+
+    # TODO(fangjun): Scale the blank posterior
+    log_prob = logits.log_softmax(dim=-1)
+    # log_prob is (1, 1, 1, vocab_size)
+
+    log_prob = log_prob.squeeze()
+    # Now log_prob is (vocab_size,)
+
+    joint_cache[key] = log_prob
+
+    return log_prob
+
+
 def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
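The two caches added above key their entries as follows: `decoder_cache` looks only at the last `context_size` labels of a hypothesis (the stateless decoder sees nothing older), while the joint cache additionally encodes the current frame index via the `-t-{t}` suffix built in `beam_search` below. A small sketch with made-up token ids and `context_size = 2` (the real value comes from `model.decoder.context_size`):

```python
from typing import List

context_size = 2  # stands in for model.decoder.context_size


def decoder_key(ys: List[int]) -> str:
    # Hypotheses ending in the same `context_size` labels share one cache entry,
    # so the decoder network runs only once for all of them.
    return "_".join(map(str, ys[-context_size:]))


def joiner_key(ys: List[int], t: int) -> str:
    # The joiner output also depends on the acoustic frame, hence the time suffix.
    return decoder_key(ys) + f"-t-{t}"


print(decoder_key([7, 3, 9, 5]))    # "9_5"
print(decoder_key([2, 9, 5]))       # "9_5"  -> decoder_cache hit
print(joiner_key([2, 9, 5], t=10))  # "9_5-t-10"
```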
@@ -288,36 +378,21 @@ def beam_search(
             y_star = A.get_most_probable()
             A.remove(y_star)

-            cached_key = y_star.key
-
-            if cached_key not in decoder_cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-context_size:]], device=device
-                ).reshape(1, context_size)
-
-                decoder_out = model.decoder(decoder_input, need_pad=False)
-                decoder_cache[cached_key] = decoder_out
-            else:
-                decoder_out = decoder_cache[cached_key]
-
-            cached_key += f"-t-{t}"
-            if cached_key not in joint_cache:
-                logits = model.joiner(
-                    current_encoder_out,
-                    decoder_out,
-                    encoder_out_len,
-                    decoder_out_len,
-                )
-
-                # TODO(fangjun): Ccale the blank posterior
-
-                log_prob = logits.log_softmax(dim=-1)
-                # log_prob is (1, 1, 1, vocab_size)
-                log_prob = log_prob.squeeze()
-                # Now log_prob is (vocab_size,)
-                joint_cache[cached_key] = log_prob
-            else:
-                log_prob = joint_cache[cached_key]
+            decoder_out = run_decoder(
+                ys=y_star.ys, model=model, decoder_cache=decoder_cache
+            )
+
+            key = "_".join(map(str, y_star.ys[-context_size:]))
+            key += f"-t-{t}"
+            log_prob = run_joiner(
+                key=key,
+                model=model,
+                encoder_out=current_encoder_out,
+                decoder_out=decoder_out,
+                encoder_out_len=encoder_out_len,
+                decoder_out_len=decoder_out_len,
+                joint_cache=joint_cache,
+            )

             # First, process the blank symbol
             skip_log_prob = log_prob[blank_id]
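For orientation, the refactored step above sits inside a loop over the hypotheses in `A` at each frame `t`; `B` (not visible in this hunk) collects hypotheses that have finished the frame. The sketch below shows one such expansion using the new helpers. `run_decoder`, `run_joiner`, `Hypothesis`, and the hypothesis lists are the objects from `beam_search.py` above; the `Hypothesis(ys=..., log_prob=...)` construction and the blank/non-blank handling follow the standard transducer beam-search step and are simplified here, not copied from the file.

```python
def expand_one_hypothesis(model, y_star, t, current_encoder_out,
                          encoder_out_len, decoder_out_len,
                          A, B, blank_id, decoder_cache, joint_cache):
    """Expand the most probable hypothesis at frame t (illustrative sketch)."""
    context_size = model.decoder.context_size

    decoder_out = run_decoder(
        ys=y_star.ys, model=model, decoder_cache=decoder_cache
    )

    key = "_".join(map(str, y_star.ys[-context_size:])) + f"-t-{t}"
    log_prob = run_joiner(
        key=key,
        model=model,
        encoder_out=current_encoder_out,
        decoder_out=decoder_out,
        encoder_out_len=encoder_out_len,
        decoder_out_len=decoder_out_len,
        joint_cache=joint_cache,
    )

    # Blank: the label sequence is unchanged; the hypothesis is done with frame t.
    B.add(Hypothesis(ys=y_star.ys[:],
                     log_prob=y_star.log_prob + log_prob[blank_id].item()))

    # Non-blank tokens: extend the hypothesis and keep it in A for frame t.
    for token, token_log_prob in enumerate(log_prob.tolist()):
        if token != blank_id:
            A.add(Hypothesis(ys=y_star.ys + [token],
                             log_prob=y_star.log_prob + token_log_prob))
```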