mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
updated
This commit is contained in:
parent
755430c29e
commit
c6334ae45d
@ -719,7 +719,7 @@ def greedy_search_batch(
|
|||||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
for (t, batch_size) in enumerate(batch_size_list):
|
for t, batch_size in enumerate(batch_size_list):
|
||||||
start = offset
|
start = offset
|
||||||
end = offset + batch_size
|
end = offset + batch_size
|
||||||
current_encoder_out = encoder_out.data[start:end]
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
@ -779,6 +779,74 @@ def greedy_search_batch(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def deprecated_greedy_search_batch(
|
||||||
|
model: nn.Module, encoder_out: torch.Tensor
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list of token IDs containing the decoded results.
|
||||||
|
len(ans) equals to encoder_out.size(0).
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
hyps,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (batch_size, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||||
|
|
||||||
|
# decoder_out: (batch_size, 1, decoder_out_dim)
|
||||||
|
for t in range(T):
|
||||||
|
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||||
|
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||||
|
)
|
||||||
|
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||||
|
assert logits.ndim == 2, logits.shape
|
||||||
|
y = logits.argmax(dim=1).tolist()
|
||||||
|
emitted = False
|
||||||
|
for i, v in enumerate(y):
|
||||||
|
if v not in (blank_id, unk_id):
|
||||||
|
hyps[i].append(v)
|
||||||
|
emitted = True
|
||||||
|
if emitted:
|
||||||
|
# update decoder output
|
||||||
|
decoder_input = [h[-context_size:] for h in hyps]
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
decoder_input,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
|
||||||
|
ans = [h[context_size:] for h in hyps]
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Hypothesis:
|
class Hypothesis:
|
||||||
# The predicted tokens so far.
|
# The predicted tokens so far.
|
||||||
@ -1019,7 +1087,7 @@ def modified_beam_search(
|
|||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
finalized_B = []
|
finalized_B = []
|
||||||
for (t, batch_size) in enumerate(batch_size_list):
|
for t, batch_size in enumerate(batch_size_list):
|
||||||
start = offset
|
start = offset
|
||||||
end = offset + batch_size
|
end = offset + batch_size
|
||||||
current_encoder_out = encoder_out.data[start:end]
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
@ -1227,7 +1295,7 @@ def modified_beam_search_lm_rescore(
|
|||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
finalized_B = []
|
finalized_B = []
|
||||||
for (t, batch_size) in enumerate(batch_size_list):
|
for t, batch_size in enumerate(batch_size_list):
|
||||||
start = offset
|
start = offset
|
||||||
end = offset + batch_size
|
end = offset + batch_size
|
||||||
current_encoder_out = encoder_out.data[start:end]
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
@ -1427,7 +1495,7 @@ def modified_beam_search_lm_rescore_LODR(
|
|||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
finalized_B = []
|
finalized_B = []
|
||||||
for (t, batch_size) in enumerate(batch_size_list):
|
for t, batch_size in enumerate(batch_size_list):
|
||||||
start = offset
|
start = offset
|
||||||
end = offset + batch_size
|
end = offset + batch_size
|
||||||
current_encoder_out = encoder_out.data[start:end]
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
@ -2599,7 +2667,6 @@ def modified_beam_search_LODR(
|
|||||||
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
||||||
new_token = topk_token_indexes[k]
|
new_token = topk_token_indexes[k]
|
||||||
if new_token not in (blank_id, unk_id):
|
if new_token not in (blank_id, unk_id):
|
||||||
|
|
||||||
ys.append(new_token)
|
ys.append(new_token)
|
||||||
state_cost = hyp.state_cost.forward_one_step(new_token)
|
state_cost = hyp.state_cost.forward_one_step(new_token)
|
||||||
|
|
||||||
@ -2721,7 +2788,7 @@ def modified_beam_search_lm_shallow_fusion(
|
|||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
finalized_B = []
|
finalized_B = []
|
||||||
for (t, batch_size) in enumerate(batch_size_list):
|
for t, batch_size in enumerate(batch_size_list):
|
||||||
start = offset
|
start = offset
|
||||||
end = offset + batch_size
|
end = offset + batch_size
|
||||||
current_encoder_out = encoder_out.data[start:end] # get batch
|
current_encoder_out = encoder_out.data[start:end] # get batch
|
||||||
@ -2863,7 +2930,6 @@ def modified_beam_search_lm_shallow_fusion(
|
|||||||
new_token = topk_token_indexes[k]
|
new_token = topk_token_indexes[k]
|
||||||
new_timestamp = hyp.timestamp[:]
|
new_timestamp = hyp.timestamp[:]
|
||||||
if new_token not in (blank_id, unk_id):
|
if new_token not in (blank_id, unk_id):
|
||||||
|
|
||||||
ys.append(new_token)
|
ys.append(new_token)
|
||||||
new_timestamp.append(t)
|
new_timestamp.append(t)
|
||||||
|
|
||||||
|
@ -108,6 +108,7 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
deprecated_greedy_search_batch,
|
||||||
fast_beam_search_nbest,
|
fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_LG,
|
fast_beam_search_nbest_LG,
|
||||||
fast_beam_search_nbest_oracle,
|
fast_beam_search_nbest_oracle,
|
||||||
@ -116,7 +117,7 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -273,8 +274,7 @@ def get_parser():
|
|||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="The context size in the decoder. 1 means bigram; "
|
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||||
"2 means tri-gram",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
@ -425,14 +425,15 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
elif (
|
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
params.decoding_method == "greedy_search"
|
# hyp_tokens = greedy_search_batch(
|
||||||
and params.max_sym_per_frame == 1
|
# model=model,
|
||||||
):
|
# encoder_out=encoder_out,
|
||||||
hyp_tokens = greedy_search_batch(
|
# encoder_out_lens=encoder_out_lens,
|
||||||
|
# )
|
||||||
|
hyp_tokens = deprecated_greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -559,9 +560,7 @@ def decode_dataset(
|
|||||||
if batch_idx % log_interval == 0:
|
if batch_idx % log_interval == 0:
|
||||||
batch_str = f"{batch_idx}/{num_batches}"
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
logging.info(
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -594,8 +593,7 @@ def save_results(
|
|||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
errs_info = (
|
errs_info = (
|
||||||
params.res_dir
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
|
||||||
)
|
)
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tWER", file=f)
|
print("settings\tWER", file=f)
|
||||||
@ -656,9 +654,7 @@ def main():
|
|||||||
if "LG" in params.decoding_method:
|
if "LG" in params.decoding_method:
|
||||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += (
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -690,9 +686,9 @@ def main():
|
|||||||
|
|
||||||
if not params.use_averaged_model:
|
if not params.use_averaged_model:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg
|
||||||
)[: params.avg]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -719,9 +715,9 @@ def main():
|
|||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg + 1
|
||||||
)[: params.avg + 1]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -780,9 +776,7 @@ def main():
|
|||||||
decoding_graph.scores *= params.ngram_lm_scale
|
decoding_graph.scores *= params.ngram_lm_scale
|
||||||
else:
|
else:
|
||||||
word_table = None
|
word_table = None
|
||||||
decoding_graph = k2.trivial_graph(
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
params.vocab_size - 1, device=device
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
word_table = None
|
word_table = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user