mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
add fast beam search for decoding
This commit is contained in:
parent
ff855ff821
commit
cc68bc256b
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#### 2022-03-21
|
#### 2022-03-21
|
||||||
|
|
||||||
Using the codes from this PR.
|
Using the codes from this PR https://github.com/k2-fsa/icefall/pull/261.
|
||||||
|
|
||||||
The WERs are
|
The WERs are
|
||||||
|
|
||||||
@ -62,6 +62,18 @@ avg=13
|
|||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
|
## fast beam search
|
||||||
|
./pruned_transducer_stateless/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--max-duration 1500 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8
|
||||||
```
|
```
|
||||||
|
|
||||||
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_tedlium3_pruned_transducer_stateless>
|
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_tedlium3_pruned_transducer_stateless>
|
||||||
@ -85,6 +97,7 @@ The WERs are
|
|||||||
| greedy search | 7.19 | 6.70 | --epoch 29, --avg 11, --max-duration 100 |
|
| greedy search | 7.19 | 6.70 | --epoch 29, --avg 11, --max-duration 100 |
|
||||||
| beam search (beam size 4) | 7.02 | 6.36 | --epoch 29, --avg 11, --max-duration 100 |
|
| beam search (beam size 4) | 7.02 | 6.36 | --epoch 29, --avg 11, --max-duration 100 |
|
||||||
| modified beam search (beam size 4) | 6.91 | 6.33 | --epoch 29, --avg 11, --max-duration 100 |
|
| modified beam search (beam size 4) | 6.91 | 6.33 | --epoch 29, --avg 11, --max-duration 100 |
|
||||||
|
| fast beam search (set as default) | 7.14 | 6.50 | --epoch 29, --avg 11, --max-duration 1500|
|
||||||
|
|
||||||
The training command for reproducing is given below:
|
The training command for reproducing is given below:
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
# Copyright 2020 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
# Mingshuang Luo)
|
# Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -18,14 +18,100 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
|
from icefall.decode import one_best_decoding
|
||||||
|
from icefall.utils import get_texts
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi..
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts pre stream per frame.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
vocab_size = model.decoder.vocab_size
|
||||||
|
unk_id = model.decoder.unk_id
|
||||||
|
|
||||||
|
B, T, C = encoder_out.shape
|
||||||
|
|
||||||
|
config = k2.RnntDecodingConfig(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
decoder_history_len=context_size,
|
||||||
|
beam=beam,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
max_states=max_states,
|
||||||
|
)
|
||||||
|
individual_streams = []
|
||||||
|
for i in range(B):
|
||||||
|
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
|
||||||
|
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
# shape is a RaggedShape of shape (B, context)
|
||||||
|
# contexts is a Tensor of shape (shape.NumElements(), context_size)
|
||||||
|
shape, contexts = decoding_streams.get_contexts()
|
||||||
|
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
|
||||||
|
contexts = contexts.to(torch.int64)
|
||||||
|
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
|
||||||
|
decoder_out = model.decoder(contexts, need_pad=False)
|
||||||
|
# current_encoder_out is of shape
|
||||||
|
# (shape.NumElements(), 1, encoder_out_dim)
|
||||||
|
# fmt: off
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
|
||||||
|
)
|
||||||
|
logits = logits.squeeze(1).squeeze(1)
|
||||||
|
log_probs = logits.log_softmax(dim=-1)
|
||||||
|
decoding_streams.advance(log_probs)
|
||||||
|
decoding_streams.terminate_and_flush_to_streams()
|
||||||
|
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||||
|
best_path = one_best_decoding(lattice)
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
new_hyps = []
|
||||||
|
for hyp in hyps:
|
||||||
|
hyp = [idx for idx in hyp if idx != unk_id]
|
||||||
|
new_hyps.append(hyp)
|
||||||
|
return new_hyps
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""Greedy search for a single utterance.
|
||||||
Args:
|
Args:
|
||||||
model:
|
model:
|
||||||
An instance of `Transducer`.
|
An instance of `Transducer`.
|
||||||
@ -98,6 +184,65 @@ def greedy_search(
|
|||||||
return hyp
|
return hyp
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search_batch(
|
||||||
|
model: Transducer, encoder_out: torch.Tensor
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list integers containing the decoded results.
|
||||||
|
len(ans) equals to encoder_out.size(0).
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = model.decoder.unk_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
hyps,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (batch_size, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
# decoder_out: (batch_size, 1, decoder_out_dim)
|
||||||
|
for t in range(T):
|
||||||
|
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||||
|
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
|
||||||
|
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||||
|
assert logits.ndim == 2, logits.shape
|
||||||
|
y = logits.argmax(dim=1).tolist()
|
||||||
|
emitted = False
|
||||||
|
for i, v in enumerate(y):
|
||||||
|
if v != blank_id and v != unk_id:
|
||||||
|
hyps[i].append(v)
|
||||||
|
emitted = True
|
||||||
|
if emitted:
|
||||||
|
# update decoder output
|
||||||
|
decoder_input = [h[-context_size:] for h in hyps]
|
||||||
|
decoder_input = torch.tensor(decoder_input, device=device)
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
|
||||||
|
ans = [h[context_size:] for h in hyps]
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Hypothesis:
|
class Hypothesis:
|
||||||
# The predicted tokens so far.
|
# The predicted tokens so far.
|
||||||
@ -132,8 +277,10 @@ class HypothesisList(object):
|
|||||||
|
|
||||||
def add(self, hyp: Hypothesis) -> None:
|
def add(self, hyp: Hypothesis) -> None:
|
||||||
"""Add a Hypothesis to `self`.
|
"""Add a Hypothesis to `self`.
|
||||||
|
|
||||||
If `hyp` already exists in `self`, its probability is updated using
|
If `hyp` already exists in `self`, its probability is updated using
|
||||||
`log-sum-exp` with the existed one.
|
`log-sum-exp` with the existed one.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hyp:
|
hyp:
|
||||||
The hypothesis to be added.
|
The hypothesis to be added.
|
||||||
@ -150,6 +297,7 @@ class HypothesisList(object):
|
|||||||
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
|
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
|
||||||
"""Get the most probable hypothesis, i.e., the one with
|
"""Get the most probable hypothesis, i.e., the one with
|
||||||
the largest `log_prob`.
|
the largest `log_prob`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
length_norm:
|
length_norm:
|
||||||
If True, the `log_prob` of a hypothesis is normalized by the
|
If True, the `log_prob` of a hypothesis is normalized by the
|
||||||
@ -166,8 +314,10 @@ class HypothesisList(object):
|
|||||||
|
|
||||||
def remove(self, hyp: Hypothesis) -> None:
|
def remove(self, hyp: Hypothesis) -> None:
|
||||||
"""Remove a given hypothesis.
|
"""Remove a given hypothesis.
|
||||||
|
|
||||||
Caution:
|
Caution:
|
||||||
`self` is modified **in-place**.
|
`self` is modified **in-place**.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hyp:
|
hyp:
|
||||||
The hypothesis to be removed from `self`.
|
The hypothesis to be removed from `self`.
|
||||||
@ -180,8 +330,10 @@ class HypothesisList(object):
|
|||||||
|
|
||||||
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||||
|
|
||||||
Caution:
|
Caution:
|
||||||
`self` is not modified. Instead, a new HypothesisList is returned.
|
`self` is not modified. Instead, a new HypothesisList is returned.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a new HypothesisList containing all hypotheses from `self`
|
Return a new HypothesisList containing all hypotheses from `self`
|
||||||
with `log_prob` being greater than the given `threshold`.
|
with `log_prob` being greater than the given `threshold`.
|
||||||
@ -223,6 +375,7 @@ def modified_beam_search(
|
|||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model:
|
model:
|
||||||
An instance of `Transducer`.
|
An instance of `Transducer`.
|
||||||
@ -324,7 +477,9 @@ def beam_search(
|
|||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||||
|
|
||||||
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
|
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model:
|
model:
|
||||||
An instance of `Transducer`.
|
An instance of `Transducer`.
|
||||||
@ -346,7 +501,9 @@ def beam_search(
|
|||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[blank_id] * context_size, device=device
|
[blank_id] * context_size,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
).reshape(1, context_size)
|
).reshape(1, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
@ -383,7 +540,9 @@ def beam_search(
|
|||||||
|
|
||||||
if cached_key not in decoder_cache:
|
if cached_key not in decoder_cache:
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[y_star.ys[-context_size:]], device=device
|
[y_star.ys[-context_size:]],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
).reshape(1, context_size)
|
).reshape(1, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
@ -397,7 +556,7 @@ def beam_search(
|
|||||||
current_encoder_out, decoder_out.unsqueeze(1)
|
current_encoder_out, decoder_out.unsqueeze(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(fangjun): Cache the blank posterior
|
# TODO(fangjun): Scale the blank posterior
|
||||||
|
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
log_prob = logits.log_softmax(dim=-1)
|
||||||
# log_prob is (1, 1, 1, vocab_size)
|
# log_prob is (1, 1, 1, vocab_size)
|
||||||
@ -409,7 +568,7 @@ def beam_search(
|
|||||||
|
|
||||||
# First, process the blank symbol
|
# First, process the blank symbol
|
||||||
skip_log_prob = log_prob[blank_id]
|
skip_log_prob = log_prob[blank_id]
|
||||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
|
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||||
|
|
||||||
# ys[:] returns a copy of ys
|
# ys[:] returns a copy of ys
|
||||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||||
@ -421,9 +580,8 @@ def beam_search(
|
|||||||
continue
|
continue
|
||||||
new_ys = y_star.ys + [i]
|
new_ys = y_star.ys + [i]
|
||||||
new_log_prob = y_star.log_prob + v
|
new_log_prob = y_star.log_prob + v
|
||||||
A.add(
|
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||||
Hypothesis(ys=new_ys, log_prob=torch.tensor(new_log_prob))
|
|
||||||
)
|
|
||||||
# Check whether B contains more than "beam" elements more probable
|
# Check whether B contains more than "beam" elements more probable
|
||||||
# than the most probable in A
|
# than the most probable in A
|
||||||
A_most_probable = A.get_most_probable()
|
A_most_probable = A.get_most_probable()
|
||||||
|
@ -35,7 +35,7 @@ Usage:
|
|||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(3) beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless/decode.py \
|
./pruned_transducer_stateless/decode.py \
|
||||||
--epoch 29 \
|
--epoch 29 \
|
||||||
--avg 13 \
|
--avg 13 \
|
||||||
@ -43,20 +43,37 @@ Usage:
|
|||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
"""
|
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./pruned_transducer_stateless/decode.py \
|
||||||
|
--epoch 29 \
|
||||||
|
--avg 13 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
|
--max-duration 1500 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import TedLiumAsrDataModule
|
from asr_datamodule import TedLiumAsrDataModule
|
||||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
from beam_search import (
|
||||||
|
beam_search,
|
||||||
|
fast_beam_search,
|
||||||
|
greedy_search,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -84,6 +101,7 @@ def get_parser():
|
|||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
@ -115,6 +133,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -122,8 +141,35 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
|
help="""An interger indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
beam_search""",
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -216,6 +262,7 @@ def decode_one_batch(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -238,6 +285,9 @@ def decode_one_batch(
|
|||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -256,36 +306,72 @@ def decode_one_batch(
|
|||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
hyps = []
|
hyps = []
|
||||||
batch_size = encoder_out.size(0)
|
|
||||||
|
|
||||||
for i in range(batch_size):
|
if params.decoding_method == "fast_beam_search":
|
||||||
# fmt: off
|
hyp_tokens = fast_beam_search(
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
model=model,
|
||||||
# fmt: on
|
decoding_graph=decoding_graph,
|
||||||
if params.decoding_method == "greedy_search":
|
encoder_out=encoder_out,
|
||||||
hyp = greedy_search(
|
encoder_out_lens=encoder_out_lens,
|
||||||
model=model,
|
beam=params.beam,
|
||||||
encoder_out=encoder_out_i,
|
max_contexts=params.max_contexts,
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyp = beam_search(
|
hyps.append(hyp.split())
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
elif (
|
||||||
)
|
params.decoding_method == "greedy_search"
|
||||||
elif params.decoding_method == "modified_beam_search":
|
and params.max_sym_per_frame == 1
|
||||||
hyp = modified_beam_search(
|
):
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
hyp_tokens = greedy_search_batch(
|
||||||
)
|
model=model,
|
||||||
else:
|
encoder_out=encoder_out,
|
||||||
raise ValueError(
|
)
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
for hyp in sp.decode(hyp_tokens):
|
||||||
)
|
hyps.append(hyp.split())
|
||||||
hyps.append(sp.decode(hyp).split())
|
else:
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
# fmt: off
|
||||||
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
|
# fmt: on
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "beam_search":
|
||||||
|
hyp = beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
@ -293,6 +379,7 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -305,6 +392,9 @@ def decode_dataset(
|
|||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -333,6 +423,7 @@ def decode_dataset(
|
|||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -412,12 +503,17 @@ def main():
|
|||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if "beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
params.suffix += f"-beam-{params.beam_size}"
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
@ -461,6 +557,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -480,6 +581,7 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -65,6 +65,7 @@ class Decoder(nn.Module):
|
|||||||
self.unk_id = unk_id
|
self.unk_id = unk_id
|
||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
self.context_size = context_size
|
self.context_size = context_size
|
||||||
|
self.vocab_size = vocab_size
|
||||||
if context_size > 1:
|
if context_size > 1:
|
||||||
self.conv = nn.Conv1d(
|
self.conv = nn.Conv1d(
|
||||||
in_channels=embedding_dim,
|
in_channels=embedding_dim,
|
||||||
|
@ -130,6 +130,7 @@ def get_parser():
|
|||||||
help="The context size in the decoder. 1 means bigram; "
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user