mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
update codes
This commit is contained in:
parent
d4e0baf14d
commit
6772a3cf84
@ -18,7 +18,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
@ -141,7 +140,7 @@ class HypothesisList(object):
|
|||||||
key = hyp.key
|
key = hyp.key
|
||||||
if key in self:
|
if key in self:
|
||||||
old_hyp = self._data[key]
|
old_hyp = self._data[key]
|
||||||
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
|
old_hyp.log_prob = torch.logaddexp(old_hyp.log_prob, hyp.log_prob)
|
||||||
else:
|
else:
|
||||||
self._data[key] = hyp
|
self._data[key] = hyp
|
||||||
|
|
||||||
@ -211,6 +210,106 @@ class HypothesisList(object):
|
|||||||
return ", ".join(s)
|
return ", ".join(s)
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search(
|
||||||
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
beam: int = 4,
|
||||||
|
) -> List[int]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||||
|
beam:
|
||||||
|
Beam size.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
|
# support only batch_size == 1 for now
|
||||||
|
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = model.decoder.unk_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
|
B = HypothesisList()
|
||||||
|
B.add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
# fmt: off
|
||||||
|
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||||
|
# current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
|
||||||
|
# fmt: on
|
||||||
|
A = list(B)
|
||||||
|
B = HypothesisList()
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
|
||||||
|
# ys_log_probs is of shape (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyp in A],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
# decoder_input is of shape (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
|
# decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
|
||||||
|
|
||||||
|
current_encoder_out = current_encoder_out.expand(
|
||||||
|
decoder_out.size(0), 1, 1, -1
|
||||||
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
)
|
||||||
|
# logits is of shape (num_hyps, 1, 1, vocab_size)
|
||||||
|
logits = logits.squeeze(1).squeeze(1)
|
||||||
|
|
||||||
|
# now logits is of shape (num_hyps, vocab_size)
|
||||||
|
log_probs = logits.log_softmax(dim=-1)
|
||||||
|
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
topk_log_probs, topk_indexes = log_probs.topk(beam)
|
||||||
|
|
||||||
|
# topk_hyp_indexes are indexes into `A`
|
||||||
|
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
||||||
|
topk_token_indexes = topk_indexes % logits.size(-1)
|
||||||
|
|
||||||
|
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||||
|
topk_token_indexes = topk_token_indexes.tolist()
|
||||||
|
|
||||||
|
for i in range(len(topk_hyp_indexes)):
|
||||||
|
hyp = A[topk_hyp_indexes[i]]
|
||||||
|
new_ys = hyp.ys[:]
|
||||||
|
new_token = topk_token_indexes[i]
|
||||||
|
if new_token != blank_id and new_token != unk_id:
|
||||||
|
new_ys.append(new_token)
|
||||||
|
new_log_prob = topk_log_probs[i]
|
||||||
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
|
B.add(new_hyp)
|
||||||
|
|
||||||
|
best_hyp = B.get_most_probable(length_norm=True)
|
||||||
|
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||||
|
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
|
@ -47,7 +47,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import TedLiumAsrDataModule
|
from asr_datamodule import TedLiumAsrDataModule
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -105,6 +105,7 @@ def get_parser():
|
|||||||
help="""Possible values are:
|
help="""Possible values are:
|
||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -262,6 +263,10 @@ def decode_one_batch(
|
|||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
)
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp = modified_beam_search(
|
||||||
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
@ -398,6 +403,7 @@ def main():
|
|||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user