Add multi-round n-best rescorer
This commit is contained in:
parent 0669aa8ab9
commit 27c46b66ee
@@ -6,6 +6,7 @@
import argparse
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -23,10 +24,12 @@ from icefall.decode import (
    nbest_decoding,
    one_best_decoding,
    rescore_with_attention_decoder,
    rescore_with_attention_decoder_v2,
    rescore_with_n_best_list,
    rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.score_estimator import ScoreEstimator
from icefall.utils import (
    AttributeDict,
    get_texts,
@@ -62,6 +65,7 @@ def get_parser():
def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            # "exp_dir": Path("exp/conformer_ctc"),
            "exp_dir": Path("conformer_ctc/exp"),
            "lang_dir": Path("data/lang_bpe"),
            "lm_dir": Path("data/lm"),
@@ -86,10 +90,17 @@ def get_params() -> AttributeDict:
            # - whole-lattice-rescoring
            # - attention-decoder
            # "method": "whole-lattice-rescoring",
            "method": "attention-decoder",
            "method": "attention-decoder-v2",
            # "method": "nbest-rescoring",
            # "method": "attention-decoder",
            # num_paths is used when method is "nbest", "nbest-rescoring",
            # and attention-decoder
            "num_paths": 100,
            # top_k is used when method is "attention-decoder-v2"
            "top_k": 10,
            # dump_best_matching_feature is used when method is
            # "attention-decoder-v2" to dump features to train a special model
            "dump_best_matching_feature": False,
        }
    )
    return params
@@ -104,6 +115,7 @@ def decode_one_batch(
    lexicon: Lexicon,
    sos_id: int,
    eos_id: int,
    rescore_est_model: nn.Module,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
    """Decode one batch and return the result in a dict. The dict has the
@@ -135,12 +147,16 @@ def decode_one_batch(
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      batch_idx:
        The index of the current batch.
      lexicon:
        It contains the word symbol table.
      sos_id:
        The token ID of the SOS.
      eos_id:
        The token ID of the EOS.
      rescore_est_model:
        The model that estimates the mean and variance of rescoring scores;
        used only when params.method is "attention-decoder-v2".
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
@@ -242,15 +258,24 @@ def decode_one_batch(
        best_path_dict = rescore_with_attention_decoder_v2(
            lattice=rescored_lattice,
            batch_idx=batch_idx,
            dump_best_matching_feature=params.dump_feature,
            dump_best_matching_feature=params.dump_best_matching_feature,
            num_paths=params.num_paths,
            top_k=params.top_k,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            rescore_est_model=rescore_est_model,
            sos_id=sos_id,
            eos_id=eos_id,
        )
        if params.dump_best_matching_feature:
            if best_path_dict.size()[0] > 0:
                save_dir = params.exp_dir / f"rescore/feat"
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                file_name = save_dir / f"feats-epoch-{batch_idx}.pt"
                torch.save(best_path_dict, file_name)
            return dict()
    else:
        assert False, f"Unsupported decoding method: {params.method}"
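When dump_best_matching_feature is enabled, each saved tensor appears to hold one row per token of the held-out paths: the five best-matching statistics followed by the attention-decoder score used as a regression target. A minimal sketch of inspecting such a dump (the file name is hypothetical; the column split mirrors DatasetCollateFunc in icefall/score_estimator.py added later in this commit):

import torch

feats = torch.load("conformer_ctc/exp/rescore/feat/feats-epoch-0.pt")  # hypothetical dump file
x, y = feats[:, 0:5], feats[:, 5]  # 5 statistics per token, attention score as target
print(x.shape, y.shape)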
@@ -270,6 +295,7 @@ def decode_dataset(
    lexicon: Lexicon,
    sos_id: int,
    eos_id: int,
    rescore_est_model: nn.Module,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
    """Decode dataset.
@@ -289,6 +315,8 @@ def decode_dataset(
        The token ID for SOS.
      eos_id:
        The token ID for EOS.
      rescore_est_model:
        The model that estimates the mean and variance of rescoring scores;
        used only when params.method is "attention-decoder-v2".
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
@@ -303,7 +331,7 @@
    results = []

    num_cuts = 0
    tot_num_cuts = len(dl.dataset.cuts)
    tot_batches = len(dl)

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
@@ -314,11 +342,12 @@
            model=model,
            HLG=HLG,
            batch=batch,
            batch_idx,
            batch_idx=batch_idx,
            lexicon=lexicon,
            G=G,
            sos_id=sos_id,
            eos_id=eos_id,
            rescore_est_model=rescore_est_model,
        )

        for lm_scale, hyps in hyps_dict.items():
@@ -334,9 +363,8 @@

        if batch_idx % 100 == 0:
            logging.info(
                f"batch {batch_idx}, cuts processed until now is "
                f"{num_cuts}/{tot_num_cuts} "
                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
                f"batch {batch_idx}/{tot_batches}, cuts processed until now is "
                f"{num_cuts}"
            )
    return results
@@ -430,6 +458,7 @@ def main():
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
        "attention-decoder-v2",
    ):
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
@@ -453,7 +482,7 @@
        d = torch.load(params.lm_dir / "G_4_gram.pt")
        G = k2.Fsa.from_dict(d).to(device)

        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
        if params.method in ["whole-lattice-rescoring", "attention-decoder", "attention-decoder-v2"]:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
@@ -465,6 +494,15 @@
        G.lm_scores = G.scores.clone()
    else:
        G = None
    if params.method == "attention-decoder-v2":
        rescore_est_model = ScoreEstimator()
        rescore_est_model.load_state_dict(
            torch.load(f"{params.exp_dir}/rescore/epoch-19.pt",
                       map_location="cpu")
        )
        rescore_est_model.to(device)
    else:
        rescore_est_model = None

    model = Conformer(
        num_features=params.feature_dim,
@@ -504,6 +542,7 @@
    #
    test_sets = ["test-clean", "test-other"]
    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
        if test_set == "test-other": continue
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
@@ -513,6 +552,7 @@
            G=G,
            sos_id=sos_id,
            eos_id=eos_id,
            rescore_est_model=rescore_est_model,
        )

        save_results(
@@ -1,10 +1,15 @@
import logging
import os
from typing import Dict, List, Optional, Tuple, Union

import k2
import torch
import torch.nn as nn

from .nbest import Nbest
from .utils import get_best_matching_stats

from .score_estimator import ScoreEstimator


def _intersect_device(
    a_fsas: k2.Fsa,
@@ -752,20 +757,25 @@ def rescore_nbest_with_attention_decoder(
      eos_id:
        The token ID for EOS.
    Returns:
      A dict of FsaVec, whose key contains a string
      ngram_lm_scale_attention_scale and the value is the
      best decoding path for each sequence in the lattice.
      An Nbest with all of the scores on FSA arcs updated with attention scores.
    """
    num_seqs = nbest.shape.Dim0()
    token_seq = k2.RaggedInt(nbest.shape, nbest.fsas.labels().contiguous())
    num_paths = nbest.shape.num_elements()
    # token shape [utt][path][state][arc]
    token_shape = k2.ragged.compose_ragged_shapes(nbest.shape, nbest.fsa.arcs.shape())

    token_seq = k2.RaggedInt(token_shape, nbest.fsa.labels.contiguous())

    # Remove -1 from token_seq. There are no epsilon tokens in token_seq; we
    # removed them when generating the nbest list.
    token_seq = k2.ragged.remove_values_leq(token_seq, -1)
    # token seq shape [utt][path][token]
    token_seq = k2.ragged.remove_axis(token_seq, 2)
    # token seq shape [utt][token]
    token_seq = k2.ragged.remove_axis(token_seq, 0)

    token_ids = k2.ragged.to_list(token_seq)

    path_to_seq_map_long = token_seq.shape.row_ids(1).to(torch.long)
    path_to_seq_map_long = nbest.shape.row_ids(1).to(torch.long)
    expanded_memory = memory.index_select(1, path_to_seq_map_long)

    expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
@@ -780,25 +790,27 @@ def rescore_nbest_with_attention_decoder(
        sos_id=sos_id,
        eos_id=eos_id,
    )

    assert nll.ndim == 2
    assert nll.shape[0] == num_seqs
    assert nll.shape[0] == num_paths

    attention_scores = torch.zeros(
        nbest.fsas.labels().size()[0],
        nbest.fsa.scores.size()[0],
        dtype=torch.float32,
        device=nbest.device
        device=nbest.fsa.device
    )

    start_index = 0
    for i in range(num_seqs):
    for i in range(num_paths):
        # Plus 1 to fill the score of the final arc
        tokens_num = len(tokens_ids[i]) + 1
        attention_scores[start_index: start_index + tokens_num] =
        tokens_num = 0 if len(token_ids[i]) == 0 else len(token_ids[i]) + 1
        attention_scores[start_index: start_index + tokens_num] =\
            nll[i][0: tokens_num]
        start_index += tokens_num

    fsas = nbest.fsas.clone()
    fsas.score = attention_scores
    return Nbest(fsas, nbest.shape.clone())
    fsas = nbest.fsa.clone()
    fsas.scores = attention_scores
    return Nbest(fsas, nbest.shape)


def rescore_with_attention_decoder_v2(
@@ -810,9 +822,10 @@ def rescore_with_attention_decoder_v2(
    model: nn.Module,
    memory: torch.Tensor,
    memory_key_padding_mask: torch.Tensor,
    rescore_est_model: nn.Module,
    sos_id: int,
    eos_id: int,
) -> Dict[str, k2.Fsa]:
) -> Union[torch.Tensor, Dict[str, k2.Fsa]]:
    """This function extracts n paths from the given lattice and uses
    an attention decoder to rescore them. The path with the highest
    score is used as the decoding output.
@@ -820,6 +833,11 @@ def rescore_with_attention_decoder_v2(
    Args:
      lattice:
        An FsaVec. It can be the return value of :func:`get_lattice`.
      batch_idx:
        The index of the batch currently being processed.
      dump_best_matching_feature:
        Whether to dump the best-matching features; only used for preparing
        training data for attention-decoder-v2.
      num_paths:
        Number of paths to extract from the given lattice for rescoring.
      model:
@@ -831,6 +849,8 @@ def rescore_with_attention_decoder_v2(
        Its shape is `[T, N, C]`.
      memory_key_padding_mask:
        The padding mask for memory with shape [N, T].
      rescore_est_model:
        The model that estimates the mean and variance of rescoring scores;
        used only for attention-decoder-v2.
      sos_id:
        The token ID for SOS.
      eos_id:
@@ -841,23 +861,24 @@ def rescore_with_attention_decoder_v2(
      best decoding path for each sequence in the lattice.
    """
    nbest = generate_nbest_list(lattice, num_paths)
    # Now we have nbest with scores
    nbest = nbest.intersect(lattice)

    if dump_best_matching_feature:
        if nbest.fsa.arcs.dim0() <= 2 * top_k or nbest.fsa.arcs.num_elements() == 0:
            return torch.empty(0)
        nbest_k, nbest_q = nbest.split(k=top_k, sort=False)

        rescored_nbest_k = rescore_nbest_with_attention_decoder(
            nbest=nbest_k,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
            eos_id=eos_id
        )
        stats_tensor = get_best_matching_stats(
            rescored_nbest_k,
            nbest_q,
            max_order=3
            max_order=5
        )
        rescored_nbest_q = rescore_nbest_with_attention_decoder(
            nbest=nbest_q,
@@ -865,41 +886,132 @@ def rescore_with_attention_decoder_v2(
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id
        )
        merge_tensor = torch.cat(
            (stats_tensor, rescored_nbest_q.fsa.scores.clone().view(-1, 1)),
            dim=1
        )
        return merge_tensor

    if nbest.fsa.arcs.dim0() >= 2 * top_k and nbest.fsa.arcs.num_elements() != 0:
        nbest_topk, nbest_remain = nbest.split(k=top_k)

        am_scores = nbest_topk.fsa.scores - nbest_topk.fsa.lm_scores
        lm_scores = nbest_topk.fsa.lm_scores

        rescored_nbest_topk = rescore_nbest_with_attention_decoder(
            nbest=nbest_topk,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
            # return feature & label or dump to file
        )

        nbest_topk, nbest_remain = nbest.split(k=top_k)
        stats_tensor = get_best_matching_stats(
            rescored_nbest_topk,
            nbest_remain,
            max_order=5
        )

        rescored_nbest_topk = rescore_nbest_with_attention_decoder(
            nbest=nbest_topk,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
        )
        stats_tensor = get_best_matching_stats(
            rescored_nbest_topk,
            nbest_remain,
            max_order=3
        )
        # run the rescore estimation model to get the mean and var of each token
        mean, var = rescore_est_model(stats_tensor)
        # calculate the estimated scores of nbest_remain and select the top k
        nbest_remain_topk = nbest_remain.top_k(k=top_k)
        rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder(
            nbest=nbest_remain_topk,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
        )
        best_path_dict = get_best_path_from_nbests(
            rescored_nbest_topk,
            rescored_nbest_remain_topk,
        )
        # run the rescore estimation model to get the mean and var of each token
        mean, var = rescore_est_model(stats_tensor)

        # mean_shape [utt][path][state][arcs]
        mean_shape = k2.ragged.compose_ragged_shapes(
            nbest_remain.shape, nbest_remain.fsa.arcs.shape())
        # mean_shape [utt][path][arcs]
        mean_shape = k2.ragged.remove_axis(mean_shape, 2)
        ragged_mean = k2.RaggedFloat(mean_shape, mean.contiguous())
        # path mean shape [utt][path]
        path_mean = k2.ragged.sum_per_sublist(ragged_mean)

        # var_shape [utt][path][state][arcs]
        var_shape = k2.ragged.compose_ragged_shapes(
            nbest_remain.shape, nbest_remain.fsa.arcs.shape())
        # var_shape [utt][path][arcs]
        var_shape = k2.ragged.remove_axis(var_shape, 2)
        ragged_var = k2.RaggedFloat(var_shape, var.contiguous())
        # path var shape [utt][path]
        path_var = k2.ragged.sum_per_sublist(ragged_var)

        # tot_scores() shape [utt][path]
        # path_scores has as many elements as there are paths
        # !!! Note: This is correct only when the number of utterances is 1
        path_scores = nbest_remain.total_scores().values()
        best_score = torch.max(rescored_nbest_topk.total_scores().values())
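        # The expression below is the Gaussian upper-tail probability
        # P(score > best_score) for a path whose attention score is modelled
        # as N(path_mean, path_var):
        #   1 - Phi((best_score - path_mean) / sqrt(path_var)),
        # where Phi(x) = (1 + erf(x / sqrt(2))) / 2 is the standard normal CDF.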
        est_scores = 1 - 1/2 * (
            1 + torch.erf(
                (best_score - path_mean) / torch.sqrt(2 * path_var)
            )
        )
        est_scores = k2.RaggedFloat(nbest_remain.shape, est_scores)

        # calculate the estimated scores of nbest_remain and select the top k
        nbest_remain_topk = nbest_remain.top_k(k=top_k, scores=est_scores)
        remain_am_scores = nbest_remain_topk.fsa.scores - nbest_remain_topk.fsa.lm_scores
        remain_lm_scores = nbest_remain_topk.fsa.lm_scores
        rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder(
            nbest=nbest_remain_topk,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
        )

        # !!! Note: This is correct only when the number of utterances is 1
        merge_fsa = k2.cat([rescored_nbest_topk.fsa, rescored_nbest_remain_topk.fsa])
        merge_row_ids = torch.zeros(
            merge_fsa.arcs.dim0(),
            dtype=torch.int32,
            device=merge_fsa.device
        )
        rescore_nbest = Nbest(
            merge_fsa, k2.ragged.create_ragged_shape2(row_ids=merge_row_ids)
        )

        attention_scores = rescore_nbest.fsa.scores
        am_scores = torch.cat((am_scores, remain_am_scores))
        lm_scores = torch.cat((lm_scores, remain_lm_scores))
    else:
        am_scores = nbest.fsa.scores - nbest.fsa.lm_scores
        lm_scores = nbest.fsa.lm_scores
        rescore_nbest = rescore_nbest_with_attention_decoder(
            nbest=nbest,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id
        )
        attention_scores = rescore_nbest.fsa.scores

    ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
    ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]

    attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
    attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]

    ans = dict()
    for n_scale in ngram_lm_scale_list:
        for a_scale in attention_scale_list:
            tot_scores = (
                am_scores
                + n_scale * lm_scores
                + a_scale * attention_scores
            )
            rescore_nbest.fsa.scores = tot_scores
            # ragged tot scores shape [utt][path]
            ragged_tot_scores = rescore_nbest.total_scores()

            argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

            best_fsas = k2.index_fsa(rescore_nbest.fsa, argmax_indexes)

            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
            ans[key] = best_fsas
    return ans
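The second round above rescores only the remaining paths that are most likely to beat the best exactly-rescored path. A minimal, self-contained sketch of that selection rule using plain tensors (assuming a single utterance; select_second_round and its arguments are illustrative names, not part of this diff):

import torch

def select_second_round(best_score: torch.Tensor,
                        path_mean: torch.Tensor,
                        path_var: torch.Tensor,
                        k: int) -> torch.Tensor:
    # Rank remaining paths by 1 - Phi((best - mean) / std), the probability
    # that a path's (unknown) attention score exceeds the current best one.
    std = path_var.clamp_min(1e-8).sqrt()
    exceed_prob = 1.0 - 0.5 * (1.0 + torch.erf((best_score - path_mean) / (std * 2 ** 0.5)))
    return exceed_prob.topk(min(k, exceed_prob.numel())).indices

# Example: six remaining paths, pick three for exact attention rescoring.
picked = select_second_round(torch.tensor(-3.0),
                             path_mean=torch.randn(6),
                             path_var=torch.rand(6) + 0.1,
                             k=3)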
@@ -920,49 +1032,90 @@ def generate_nbest_list(
        that represent the same word sequences, the number of paths
        in different sequences may not be equal.
    Return:
      Return an Nbest object. Note the returned FSAs don't have epsilon
      self-loops.
      Return an Nbest object.
    '''
    assert len(lats.shape) == 3

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
    path = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # Seqs is a k2.RaggedInt sharing the same shape as `paths`.
    # Note that it also contains 0s and -1s.
    # word_seq is a k2.RaggedInt sharing the same shape as `path`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    # Its axes are [seq][path][word_id]
    if aux_labels:
        # if aux_labels is enabled, seqs contains word IDs
        assert hasattr(lats, "aux_labels")
        seqs = k2.index(lats.aux_labels, paths)
    else:
        # CAUTION: We use `phones` instead of `tokens` here because
        # :func:`compile_HLG` uses `phones`
        #
        # Note: compile_HLG is from k2-fsa/snowfall
        assert hasattr(lats, 'phones')
    word_seq = k2.index(lats.aux_labels, path)

        assert not hasattr(lats, 'tokens')
        lats.tokens = lats.phones
        seqs = k2.index(lats.tokens, paths)
    # Remove epsilons and -1 from word_seq
    word_seq = k2.ragged.remove_values_leq(word_seq, 0)

    # Remove epsilons (0s) and -1 from word_seqs
    seqs = k2.ragged.remove_values_leq(seqs, 0)
    # Remove paths that have identical word sequences.
    #
    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
    # except that there are no repeated paths with the same word_seq
    # within a sequence.
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seqs.num_elements()
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.tot_size(1)
    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
        word_seq, need_num_repeats=True, need_new2old_indexes=True
    )

    # unique_word_seqs is still a k2.RaggedInt with axes [seq][path][word_id].
    # But the number of paths in each sequence may be different.
    unique_seqs, _, _ = k2.ragged.unique_sequences(
        seqs, need_num_repeats=False, need_new2old_indexes=False)
    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)

    seq_to_path_shape = k2.ragged.get_layer(unique_seqs.shape(), 0)
    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word_id]
    unique_seqs = k2.ragged.remove_axis(unique_seqs, 0)
    # Now unique_word_seq has only two axes [path][word]
    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)

    fsas = k2.linear_fsa(unique_seqs)
    # word_fsa is an FsaVec with axes [path][state][arc]
    word_fsa = k2.linear_fsa(unique_word_seq)

    return Nbest(fsa=fsas, shape=seq_to_path_shape)
    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)

    # k2.compose() currently does not support b_to_a_map. To avoid
    # replicating `lats`, we use k2.intersect_device here.
    #
    # lattice has token IDs as `labels` and word IDs as aux_labels, so we
    # need to invert it here.
    inv_lattice = k2.invert(lats)

    # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
    # and its `aux_labels` are token IDs (a k2.RaggedInt with 2 axes)

    # Remove its `aux_labels` since it is not needed in the
    # following computation
    # del inv_lattice.aux_labels
    inv_lattice = k2.arc_sort(inv_lattice)

    path_lattice = _intersect_device(
        inv_lattice,
        word_fsa_with_epsilon_loops,
        b_to_a_map=path_to_seq_map,
        sorted_match_a=True,
    )

    # path_lattice now has token IDs as `labels` and word IDs as aux_labels.
    path_lattice = k2.invert(path_lattice)

    path_lattice = k2.top_sort(k2.connect(path_lattice))

    # replace labels with tokens to remove repeated token IDs.
    path_lattice.labels = path_lattice.tokens

    n_best = k2.shortest_path(path_lattice, use_double_scores=True)

    n_best = k2.remove_epsilon(n_best)

    n_best = k2.top_sort(k2.connect(n_best))

    # now we have nbest lists with am_scores and lm_scores
    return Nbest(fsa=n_best, shape=seq_to_path_shape)
@@ -82,8 +82,11 @@ class Nbest(object):

        one_best = k2.remove_epsilon(one_best)

        one_best = k2.top_sort(k2.connect(one_best))

        return Nbest(fsa=one_best, shape=self.shape)

    def total_scores(self) -> k2.RaggedFloat:
        '''Get total scores of the FSAs in this Nbest.

@@ -100,7 +103,7 @@ class Nbest(object):
        # If k2.RaggedDouble is wrapped, we can use double precision here.
        return k2.RaggedFloat(self.shape, scores.float())

    def top_k(self, k: int) -> 'Nbest':
    def top_k(self, k: int, scores: k2.RaggedFloat = None) -> 'Nbest':
        '''Get a subset of paths in the Nbest. The resulting Nbest is regular
        in that each sequence (i.e., utterance) has the same number of
        paths (k).
@@ -113,10 +116,14 @@ class Nbest(object):
        Args:
          k:
            Number of paths in each utterance.
          scores:
            The scores used to select the top-k paths.
        Returns:
          Return a new Nbest with a regular shape.
        '''
        ragged_scores = self.total_scores()
        ragged_scores = scores
        if ragged_scores is None:
            ragged_scores = self.total_scores()

        # indexes contains idx01's for self.shape
        # ragged_scores.values()[indexes] is sorted
@@ -140,6 +147,7 @@ class Nbest(object):

        top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
                                                     dim1=k)
        top_k_shape = top_k_shape.to(top_k_fsas.device)
        return Nbest(top_k_fsas, top_k_shape)

@@ -163,7 +171,7 @@ class Nbest(object):
        # indexes contains idx01's for self.shape
        indexes = torch.arange(
            self.shape.num_elements(), dtype=torch.int32,
            device=self.shape.device
            device=self.fsa.device
        )

        if sort:
@@ -176,9 +184,12 @@ class Nbest(object):

        ragged_indexes = k2.RaggedInt(self.shape, indexes)

        padded_indexes = k2.ragged.pad(ragged_indexes, value=-1)
        padded_indexes = k2.ragged.pad(ragged_indexes,
                                       value=-1)

        # Select the idx01's of top-k paths of each utterance
        max_num_fsa = padded_indexes.size()[1]

        first_indexes = padded_indexes[:, :k].flatten().contiguous()

        # Remove the padding elements
icefall/score_estimator.py (new file, 188 lines)
@@ -0,0 +1,188 @@
import argparse
import glob
import logging
from pathlib import Path
from typing import Tuple, List

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from icefall.utils import (
    setup_logger,
    str2bool,
)


class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        path: Path,
        model: str,
    ) -> None:
        super().__init__()
        files = sorted(glob.glob(f"{path}/*.pt"))
        if model == 'train':
            self.files = files[0: int(len(files) * 0.8)]
        elif model == 'dev':
            self.files = files[int(len(files) * 0.8): int(len(files) * 0.9)]
        elif model == 'test':
            self.files = files[int(len(files) * 0.9):]

    def __getitem__(self, index) -> torch.Tensor:
        return torch.load(self.files[index])

    def __len__(self) -> int:
        return len(self.files)


class DatasetCollateFunc:
    def __call__(self, batch: List) -> Tuple[torch.Tensor, torch.Tensor]:
        x = torch.cat(batch)
        return (x[:, 0:5], x[:, 5])


class ScoreEstimator(nn.Module):
    def __init__(
        self,
        input_dim: int = 5,
        hidden_dim: int = 20,
    ) -> None:
        super().__init__()
        self.embedding = nn.Linear(
            in_features=input_dim,
            out_features=hidden_dim
        )
        self.output = nn.Linear(
            in_features=hidden_dim,
            out_features=2
        )
        self.sigmod = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.embedding(x)
        x = self.sigmod(x)
        x = self.output(x)
        mean, var = x[:, 0], x[:, 1]
        var = torch.exp(var)
        return mean, var


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--input-dim",
        type=int,
        default=5,
        help="Dim of the input feature.",
    )

    parser.add_argument(
        "--hidden-dim",
        type=int,
        default=20,
        help="Number of neurons in the hidden layer.",
    )

    parser.add_argument(
        "--batch-size",
        type=int,
        default=10,
        help="Batch size of the dataloader.",
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=20,
        help="Number of training epochs.",
    )

    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate.",
    )

    parser.add_argument(
        "--exp_dir",
        type=Path,
        default=Path("conformer_ctc/exp"),
        help="Directory to store experiment data.",
    )
    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    setup_logger(f"{args.exp_dir}/rescore/log")

    model = ScoreEstimator(
        input_dim=args.input_dim,
        hidden_dim=args.hidden_dim
    )

    model = model.to("cuda")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    loss_fn = nn.GaussianNLLLoss()

    train_dataloader = DataLoader(
        Dataset(f"{args.exp_dir}/rescore/feat", "train"),
        collate_fn=DatasetCollateFunc(),
        batch_size=args.batch_size,
        shuffle=True
    )
    dev_dataloader = DataLoader(
        Dataset(f"{args.exp_dir}/rescore/feat", "dev"),
        collate_fn=DatasetCollateFunc(),
        batch_size=args.batch_size,
        shuffle=True
    )

    for epoch in range(args.epoch):
        model.train()
        training_loss = 0.0
        step = 0
        for x, y in train_dataloader:
            mean, var = model(x.cuda())
            loss = loss_fn(mean, y, var)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
            step += len(y)
        training_loss /= step

        dev_loss = 0.0
        step = 0
        model.eval()
        for x, y in dev_dataloader:
            mean, var = model(x.cuda())
            loss = loss_fn(mean, y, var)
            dev_loss += loss.item()
            step += len(y)
        dev_loss /= step

        logging.info(f"Epoch {epoch} : training loss : {training_loss}, "
                     f"dev loss : {dev_loss}"
                     )
        torch.save(
            model.state_dict(),
            f"{args.exp_dir}/rescore/epoch-{epoch}.pt"
        )


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()
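A hedged usage sketch of the estimator above: random tensors stand in for the dumped 5-column statistics, the loss is the same nn.GaussianNLLLoss used in main(), and ScoreEstimator is assumed importable from the file above. With the default 20 epochs, the last checkpoint saved here (rescore/epoch-19.pt) is the one the decoding script loads for attention-decoder-v2.

import torch
import torch.nn as nn

est = ScoreEstimator(input_dim=5, hidden_dim=20)
stats = torch.randn(8, 5)   # 8 tokens, 5 best-matching statistics each
target = torch.randn(8)     # attention-decoder scores to regress
mean, var = est(stats)
loss = nn.GaussianNLLLoss()(mean, target, var)
loss.backward()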
@@ -411,6 +411,10 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
    assert keys.shape.dim0() == queries.shape.dim0(), \
        f'Number of utterances in keys and queries should be equal: \
        {keys.shape.dim0()} vs {queries.shape.dim0()}'
    assert keys.fsa.device == queries.fsa.device, \
        f'Device of keys and queries should be equal: \
        {keys.fsa.device} vs {queries.fsa.device}'
    device = keys.fsa.device

    # keys_tokens_shape [utt][path][token]
    keys_tokens_shape = k2.ragged.compose_ragged_shapes(keys.shape,
@@ -430,11 +434,13 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
    # counts on key positions are ones
    keys_counts = k2.RaggedInt(keys_tokens_shape,
                               torch.ones(keys_token_num,
                                          dtype=torch.int32))
                                          dtype=torch.int32,
                                          device=device))
    # counts on query positions are zeros
    queries_counts = k2.RaggedInt(queries_tokens_shape,
                                  torch.zeros(queries_tokens_num,
                                              dtype=torch.int32))
                                              dtype=torch.int32,
                                              device=device))
    counts = k2.ragged.cat([keys_counts, queries_counts], axis=1).values()

    # scores on key positions are the scores inherited from the nbest path
@@ -442,7 +448,8 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
    # scores on query positions MUST be zeros
    queries_scores = k2.RaggedFloat(queries_tokens_shape,
                                    torch.zeros(queries_tokens_num,
                                                dtype=torch.float32))
                                                dtype=torch.float32,
                                                device=device))
    scores = k2.ragged.cat([keys_scores, queries_scores], axis=1).values()

    # we didn't remove -1 labels before
@@ -450,8 +457,16 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
    eos = -1
    max_token = torch.max(torch.max(keys.fsa.labels),
                          torch.max(queries.fsa.labels))
    mean, var, counts_out, ngram = k2.get_best_matching_stats(tokens, scores,
        counts, eos, min_token, max_token, max_order)
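    # Note: the statistics are computed on CPU here; the inputs are moved to
    # torch.device('cpu') before calling k2.get_best_matching_stats and the
    # returned tensors are moved back to `device` below.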
    mean, var, counts_out, ngram = k2.get_best_matching_stats(
        tokens.to(torch.device('cpu')), scores.to(torch.device('cpu')),
        counts.to(torch.device('cpu')),
        eos, min_token, max_token, max_order
    )

    mean = mean.to(device)
    var = var.to(device)
    counts_out = counts_out.to(device)
    ngram = ngram.to(device)

    queries_init_scores = queries.fsa.scores.clone()
    # only return the stats on query positions