Add multi-round nbest rescorer

Author: pkufool, 2021-08-18 15:00:13 +08:00
parent 0669aa8ab9
commit 27c46b66ee
5 changed files with 508 additions and 101 deletions

View File

@@ -6,6 +6,7 @@
 import argparse
 import logging
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -23,10 +24,12 @@ from icefall.decode import (
     nbest_decoding,
     one_best_decoding,
     rescore_with_attention_decoder,
+    rescore_with_attention_decoder_v2,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
+from icefall.score_estimator import ScoreEstimator
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -62,6 +65,7 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
+            # "exp_dir": Path("exp/conformer_ctc"),
             "exp_dir": Path("conformer_ctc/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
@@ -86,10 +90,17 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "attention-decoder",
+            "method": "attention-decoder-v2",
+            # "method": "nbest-rescoring",
+            # "method": "attention-decoder",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
             "num_paths": 100,
+            # top_k is used when method is "attention-decoder-v2"
+            "top_k": 10,
+            # dump_best_matching_feature is used when method is
+            # "attention-decoder-v2" to dump features for training the score estimator
+            "dump_best_matching_feature": False,
         }
     )
     return params
@@ -104,6 +115,7 @@ def decode_one_batch(
     lexicon: Lexicon,
     sos_id: int,
     eos_id: int,
+    rescore_est_model: nn.Module,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -135,12 +147,16 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      batch_idx:
+        The index of the current batch.
       lexicon:
         It contains word symbol table.
       sos_id:
         The token ID of the SOS.
       eos_id:
         The token ID of the EOS.
+      rescore_est_model:
+        The model that estimates the rescoring mean and variance (attention-decoder-v2 only).
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -242,15 +258,24 @@ def decode_one_batch(
         best_path_dict = rescore_with_attention_decoder_v2(
             lattice=rescored_lattice,
             batch_idx=batch_idx,
-            dump_best_matching_feature=params.dump_feature,
+            dump_best_matching_feature=params.dump_best_matching_feature,
             num_paths=params.num_paths,
             top_k=params.top_k,
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
+            rescore_est_model=rescore_est_model,
             sos_id=sos_id,
             eos_id=eos_id,
         )
+        if params.dump_best_matching_feature:
+            if best_path_dict.size()[0] > 0:
+                save_dir = params.exp_dir / "rescore/feat"
+                if not os.path.exists(save_dir):
+                    os.makedirs(save_dir)
+                file_name = save_dir / f"feats-epoch-{batch_idx}.pt"
+                torch.save(best_path_dict, file_name)
+            return dict()
     else:
         assert False, f"Unsupported decoding method: {params.method}"
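Not part of the diff: a minimal sketch of what the dumped tensor looks like, assuming the 5-feature/1-label column layout implied by `DatasetCollateFunc` in icefall/score_estimator.py further down; the file name follows the pattern used in the block above.

```python
import torch

# Load one feature file written by decode_one_batch() when
# params.dump_best_matching_feature is True.
feats = torch.load("conformer_ctc/exp/rescore/feat/feats-epoch-0.pt")

# One row per token position of a query path: columns 0-4 appear to be the
# best-matching statistics, column 5 the attention-decoder score used as
# the regression target.
x, y = feats[:, 0:5], feats[:, 5]
print(x.shape, y.shape)
```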
@@ -270,6 +295,7 @@ def decode_dataset(
     lexicon: Lexicon,
     sos_id: int,
     eos_id: int,
+    rescore_est_model: nn.Module,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
     """Decode dataset.
@@ -289,6 +315,8 @@ def decode_dataset(
         The token ID for SOS.
       eos_id:
         The token ID for EOS.
+      rescore_est_model:
+        The model that estimates the rescoring mean and variance (attention-decoder-v2 only).
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -303,7 +331,7 @@ def decode_dataset(
     results = []
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
+    tot_batches = len(dl)

     results = defaultdict(list)

     for batch_idx, batch in enumerate(dl):
@@ -314,11 +342,12 @@ def decode_dataset(
             model=model,
             HLG=HLG,
             batch=batch,
-            batch_idx,
+            batch_idx=batch_idx,
             lexicon=lexicon,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
+            rescore_est_model=rescore_est_model,
         )

         for lm_scale, hyps in hyps_dict.items():
@@ -334,9 +363,8 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}/{tot_batches}, cuts processed until now is "
+                f"{num_cuts}"
             )

     return results
@@ -430,6 +458,7 @@ def main():
             "nbest-rescoring",
             "whole-lattice-rescoring",
             "attention-decoder",
+            "attention-decoder-v2",
         ):
             if not (params.lm_dir / "G_4_gram.pt").is_file():
                 logging.info("Loading G_4_gram.fst.txt")
@@ -453,7 +482,7 @@ def main():
             d = torch.load(params.lm_dir / "G_4_gram.pt")
             G = k2.Fsa.from_dict(d).to(device)

-        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+        if params.method in ["whole-lattice-rescoring", "attention-decoder", "attention-decoder-v2"]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
             G = k2.add_epsilon_self_loops(G)
@@ -465,6 +494,15 @@ def main():
             G.lm_scores = G.scores.clone()
     else:
         G = None

+    if params.method == "attention-decoder-v2":
+        rescore_est_model = ScoreEstimator()
+        rescore_est_model.load_state_dict(
+            torch.load(f"{params.exp_dir}/rescore/epoch-19.pt",
+                       map_location="cpu")
+        )
+        rescore_est_model.to(device)
+    else:
+        rescore_est_model = None

     model = Conformer(
         num_features=params.feature_dim,
@@ -504,6 +542,7 @@ def main():
     #
     test_sets = ["test-clean", "test-other"]
     for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
+        if test_set == "test-other": continue
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
@@ -513,6 +552,7 @@ def main():
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
+            rescore_est_model=rescore_est_model,
         )

         save_results(

View File

@ -1,10 +1,15 @@
import logging import logging
import os
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from .nbest import Nbest
from .utils import get_best_matching_stats
from .score_estimator import ScoreEstimator
def _intersect_device( def _intersect_device(
a_fsas: k2.Fsa, a_fsas: k2.Fsa,
@ -752,20 +757,25 @@ def rescore_nbest_with_attention_decoder(
eos_id: eos_id:
The token ID for EOS. The token ID for EOS.
Returns: Returns:
A dict of FsaVec, whose key contains a string A Nbest with all of the scores on fsa arcs updated with attention scores.
ngram_lm_scale_attention_scale and the value is the
best decoding path for each sequence in the lattice.
""" """
num_seqs = nbest.shape.Dim0() num_paths = nbest.shape.num_elements()
token_seq = k2.RaggedInt(nbest.shape, nbest.fsas.labels().contiguous()) # token shape [utt][path][state][arc]
token_shape = k2.ragged.compose_ragged_shapes(nbest.shape, nbest.fsa.arcs.shape())
token_seq = k2.RaggedInt(token_shape, nbest.fsa.labels.contiguous())
# Remove -1 from token_seq, there is no epsilon tokens in token_seq, we # Remove -1 from token_seq, there is no epsilon tokens in token_seq, we
# removed it when generating nbest list # removed it when generating nbest list
token_seq = k2.ragged.remove_values_leq(token_seq, -1) token_seq = k2.ragged.remove_values_leq(token_seq, -1)
# token seq shape [utt][path][token]
token_seq = k2.ragged.remove_axis(token_seq, 2)
# token seq shape [utt][token]
token_seq = k2.ragged.remove_axis(token_seq, 0)
token_ids = k2.ragged.to_list(token_seq) token_ids = k2.ragged.to_list(token_seq)
path_to_seq_map_long = token_seq.shape.row_ids(1).to(torch.long) path_to_seq_map_long = nbest.shape.row_ids(1).to(torch.long)
expanded_memory = memory.index_select(1, path_to_seq_map_long) expanded_memory = memory.index_select(1, path_to_seq_map_long)
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
@ -780,25 +790,27 @@ def rescore_nbest_with_attention_decoder(
sos_id=sos_id, sos_id=sos_id,
eos_id=eos_id, eos_id=eos_id,
) )
assert nll.ndim == 2 assert nll.ndim == 2
assert nll.shape[0] == num_seqs assert nll.shape[0] == num_paths
attention_scores = torch.zeros( attention_scores = torch.zeros(
nbest.fsas.labels().size()[0], nbest.fsa.scores.size()[0],
dtype=torch.float32, dtype=torch.float32,
device=nbest.device device=nbest.fsa.device
) )
start_index = 0 start_index = 0
for i in range(num_seqs): for i in range(num_paths):
# Plus 1 to fill the score of final arc # Plus 1 to fill the score of final arc
tokens_num = len(tokens_ids[i]) + 1 tokens_num = 0 if len(token_ids[i]) == 0 else len(token_ids[i]) + 1
attention_scores[start_index: start_index + tokens_num] = attention_scores[start_index: start_index + tokens_num] =\
nll[i][0: tokens_num] nll[i][0: tokens_num]
start_index += tokens_num start_index += tokens_num
fsas = nbest.fsas.clone() fsas = nbest.fsa.clone()
fsas.score = attention_scores fsas.scores = attention_scores
return Nbest(fsas, nbest.shape.clone()) return Nbest(fsas, nbest.shape)
def rescore_with_attention_decoder_v2( def rescore_with_attention_decoder_v2(
@@ -810,9 +822,10 @@ def rescore_with_attention_decoder_v2(
     model: nn.Module,
     memory: torch.Tensor,
     memory_key_padding_mask: torch.Tensor,
+    rescore_est_model: nn.Module,
     sos_id: int,
     eos_id: int,
-) -> Dict[str, k2.Fsa]:
+) -> Union[torch.Tensor, Dict[str, k2.Fsa]]:
     """This function extracts n paths from the given lattice and uses
     an attention decoder to rescore them. The path with the highest
     score is used as the decoding output.
@@ -820,6 +833,11 @@ def rescore_with_attention_decoder_v2(
     Args:
       lattice:
         An FsaVec. It can be the return value of :func:`get_lattice`.
+      batch_idx:
+        The index of the batch currently being processed.
+      dump_best_matching_feature:
+        Whether to dump the best-matching features; used only to prepare
+        training data for attention-decoder-v2.
       num_paths:
         Number of paths to extract from the given lattice for rescoring.
       model:
@@ -831,6 +849,8 @@ def rescore_with_attention_decoder_v2(
         Its shape is `[T, N, C]`.
       memory_key_padding_mask:
         The padding mask for memory with shape [N, T].
+      rescore_est_model:
+        The model that estimates the rescoring mean and variance (attention-decoder-v2 only).
       sos_id:
         The token ID for SOS.
       eos_id:
@@ -841,23 +861,24 @@ def rescore_with_attention_decoder_v2(
       best decoding path for each sequence in the lattice.
     """
     nbest = generate_nbest_list(lattice, num_paths)
     # Now we have nbest with scores
     nbest = nbest.intersect(lattice)

     if dump_best_matching_feature:
+        if nbest.fsa.arcs.dim0() <= 2 * top_k or nbest.fsa.arcs.num_elements() == 0:
+            return torch.empty(0)
         nbest_k, nbest_q = nbest.split(k=top_k, sort=False)
         rescored_nbest_k = rescore_nbest_with_attention_decoder(
             nbest=nbest_k,
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
-            eos_id=eos_id,
+            eos_id=eos_id
         )
         stats_tensor = get_best_matching_stats(
             rescored_nbest_k,
             nbest_q,
-            max_order=3
+            max_order=5
         )
         rescored_nbest_q = rescore_nbest_with_attention_decoder(
             nbest=nbest_q,
@@ -865,11 +886,20 @@ def rescore_with_attention_decoder_v2(
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
-            eos_id=eos_id,
+            eos_id=eos_id
         )
-        # return feature & label or dump to file
+        merge_tensor = torch.cat(
+            (stats_tensor, rescored_nbest_q.fsa.scores.clone().view(-1, 1)),
+            dim=1
+        )
+        return merge_tensor

+    if nbest.fsa.arcs.dim0() >= 2 * top_k and nbest.fsa.arcs.num_elements() != 0:
         nbest_topk, nbest_remain = nbest.split(k=top_k)
+        am_scores = nbest_topk.fsa.scores - nbest_topk.fsa.lm_scores
+        lm_scores = nbest_topk.fsa.lm_scores
         rescored_nbest_topk = rescore_nbest_with_attention_decoder(
             nbest=nbest_topk,
             model=model,
@@ -878,15 +908,50 @@ def rescore_with_attention_decoder_v2(
             sos_id=sos_id,
             eos_id=eos_id,
         )
         stats_tensor = get_best_matching_stats(
             rescored_nbest_topk,
             nbest_remain,
-            max_order=3
+            max_order=5
         )
         # run the rescore estimation model to get the mean and var of each token
         mean, var = rescore_est_model(stats_tensor)
+        # mean_shape: [utt][path][state][arc]
+        mean_shape = k2.ragged.compose_ragged_shapes(
+            nbest_remain.shape, nbest_remain.fsa.arcs.shape())
+        # mean_shape: [utt][path][arc]
+        mean_shape = k2.ragged.remove_axis(mean_shape, 2)
+        ragged_mean = k2.RaggedFloat(mean_shape, mean.contiguous())
+        # path_mean shape: [utt][path]
+        path_mean = k2.ragged.sum_per_sublist(ragged_mean)
+
+        # var_shape: [utt][path][state][arc]
+        var_shape = k2.ragged.compose_ragged_shapes(
+            nbest_remain.shape, nbest_remain.fsa.arcs.shape())
+        # var_shape: [utt][path][arc]
+        var_shape = k2.ragged.remove_axis(var_shape, 2)
+        ragged_var = k2.RaggedFloat(var_shape, var.contiguous())
+        # path_var shape: [utt][path]
+        path_var = k2.ragged.sum_per_sublist(ragged_var)
+
+        # total_scores() shape: [utt][path]
+        # path_scores has as many elements as there are paths
+        # !!! Note: this is right only when the number of utterances is 1
+        path_scores = nbest_remain.total_scores().values()
+        best_score = torch.max(rescored_nbest_topk.total_scores().values())
+        est_scores = 1 - 1/2 * (
+            1 + torch.erf(
+                (best_score - path_mean) / torch.sqrt(2 * path_var)
+            )
+        )
+        est_scores = k2.RaggedFloat(nbest_remain.shape, est_scores)
+
         # calculate the estimated scores of nbest_remain and select the top-k
-        nbest_remain_topk = nbest_remain.top_k(k=top_k)
+        nbest_remain_topk = nbest_remain.top_k(k=top_k, scores=est_scores)
+        remain_am_scores = nbest_remain_topk.fsa.scores - nbest_remain_topk.fsa.lm_scores
+        remain_lm_scores = nbest_remain_topk.fsa.lm_scores
         rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder(
             nbest=nbest_remain_topk,
             model=model,
@@ -895,11 +960,58 @@ def rescore_with_attention_decoder_v2(
             sos_id=sos_id,
             eos_id=eos_id,
         )
-        best_path_dict=get_best_path_from_nbests(
-            rescored_nbest_topk,
-            rescored_nbest_remain_topk,
+        # !!! Note: this is right only when the number of utterances is 1
+        merge_fsa = k2.cat([rescored_nbest_topk.fsa, rescored_nbest_remain_topk.fsa])
+        merge_row_ids = torch.zeros(
+            merge_fsa.arcs.dim0(),
+            dtype=torch.int32,
+            device=merge_fsa.device
+        )
+        rescore_nbest = Nbest(
+            merge_fsa, k2.ragged.create_ragged_shape2(row_ids=merge_row_ids)
         )
+        attention_scores = rescore_nbest.fsa.scores
+        am_scores = torch.cat((am_scores, remain_am_scores))
+        lm_scores = torch.cat((lm_scores, remain_lm_scores))
+    else:
+        am_scores = nbest.fsa.scores - nbest.fsa.lm_scores
+        lm_scores = nbest.fsa.lm_scores
+        rescore_nbest = rescore_nbest_with_attention_decoder(
+            nbest=nbest,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id
+        )
+        attention_scores = rescore_nbest.fsa.scores

+    ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]

+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores
+                + n_scale * lm_scores
+                + a_scale * attention_scores
+            )
+            rescore_nbest.fsa.scores = tot_scores
+            # ragged_tot_scores shape: [utt][path]
+            ragged_tot_scores = rescore_nbest.total_scores()
+            argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
+            best_fsas = k2.index_fsa(rescore_nbest.fsa, argmax_indexes)
+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_fsas
     return ans
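A note on the `est_scores` expression above: reading the new code, the per-token means and variances predicted by `rescore_est_model` are summed over each remaining path into a path-level mean $\mu_p$ and variance $\sigma_p^2$ (which implicitly treats the per-token predictions as independent), and each path is ranked by the probability that its attention-decoder score would beat the best score $s^{*}$ from the already rescored top-k paths, under a Gaussian assumption:

$$
\mathrm{est}(p) \;=\; P\big(s_p > s^{*}\big)
\;=\; 1 - \Phi\!\left(\frac{s^{*} - \mu_p}{\sigma_p}\right)
\;=\; 1 - \frac{1}{2}\left[1 + \operatorname{erf}\!\left(\frac{s^{*} - \mu_p}{\sigma_p\sqrt{2}}\right)\right]
$$

which matches `1 - 1/2 * (1 + torch.erf((best_score - path_mean) / torch.sqrt(2 * path_var)))`, since $\sigma_p\sqrt{2} = \sqrt{2\sigma_p^2}$.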
@@ -920,49 +1032,90 @@ def generate_nbest_list(
       that represent the same word sequences, the number of paths
       in different sequences may not be equal.
     Return:
-      Return an Nbest object. Note the returned FSAs don't have epsilon
-      self-loops.
+      Return an Nbest object.
     '''
+    assert len(lats.shape) == 3

     # First, extract `num_paths` paths for each sequence.
-    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
-    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
+    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
+    path = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

-    # Seqs is a k2.RaggedInt sharing the same shape as `paths`.
-    # Note that it also contains 0s and -1s.
+    # word_seq is a k2.RaggedInt sharing the same shape as `path`
+    # but it contains word IDs. Note that it also contains 0s and -1s.
     # The last entry in each sublist is -1.
-    # Its axes are [seq][path][word_id]
-    if aux_labels:
-        # if aux_labels is enabled, seqs contains word IDs
-        assert hasattr(lats, "aux_labels")
-        seqs = k2.index(lats.aux_labels, paths)
-    else:
-        # CAUTION: We use `phones` instead of `tokens` here because
-        # :func:`compile_HLG` uses `phones`
-        #
-        # Note: compile_HLG is from k2-fsa/snowfall
-        assert hasattr(lats, 'phones')
-        assert not hasattr(lats, 'tokens')
-        lats.tokens = lats.phones
-        seqs = k2.index(lats.tokens, paths)
+    word_seq = k2.index(lats.aux_labels, path)

-    # Remove epsilons (0s) and -1 from word_seqs
-    seqs = k2.ragged.remove_values_leq(seqs, 0)
+    # Remove epsilons and -1 from word_seq
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)

-    # unique_word_seqs is still a k2.RaggedInt with axes [seq][path][word_id].
-    # But the number of paths in each sequence may be different.
-    unique_seqs, _, _ = k2.ragged.unique_sequences(
-        seqs, need_num_repeats=False, need_new2old_indexes=False)
+    # Remove paths that have identical word sequences.
+    #
+    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
+    # except that there are no repeated paths with the same word_seq
+    # within a sequence.
+    #
+    # num_repeats is also a k2.RaggedInt with 2 axes containing the
+    # multiplicities of each path.
+    # num_repeats.num_elements() == unique_word_seq.num_elements()
+    #
+    # Since k2.ragged.unique_sequences will reorder paths within a seq,
+    # `new2old` is a 1-D torch.Tensor mapping from the output path index
+    # to the input path index.
+    # new2old.numel() == unique_word_seq.tot_size(1)
+    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=True, need_new2old_indexes=True
+    )

-    seq_to_path_shape = k2.ragged.get_layer(unique_seqs.shape(), 0)
+    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)

+    # path_to_seq_map is a 1-D torch.Tensor.
+    # path_to_seq_map[i] is the seq to which the i-th path belongs.
+    path_to_seq_map = seq_to_path_shape.row_ids(1)

     # Remove the seq axis.
-    # Now unique_word_seqs has only two axes [path][word_id]
-    unique_seqs = k2.ragged.remove_axis(unique_seqs, 0)
+    # Now unique_word_seq has only two axes [path][word]
+    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)

-    fsas = k2.linear_fsa(unique_seqs)
+    # word_fsa is an FsaVec with axes [path][state][arc]
+    word_fsa = k2.linear_fsa(unique_word_seq)

-    return Nbest(fsa=fsas, shape=seq_to_path_shape)
+    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)

+    # k2.compose() currently does not support b_to_a_map. To avoid
+    # replicating `lats`, we use k2.intersect_device here.
+    #
+    # lattice has token IDs as `labels` and word IDs as aux_labels, so we
+    # need to invert it here.
+    inv_lattice = k2.invert(lats)

+    # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
+    # and its `aux_labels` are token IDs (a k2.RaggedInt with 2 axes)

+    # Remove its `aux_labels` since it is not needed in the
+    # following computation
+    # del inv_lattice.aux_labels
+    inv_lattice = k2.arc_sort(inv_lattice)

+    path_lattice = _intersect_device(
+        inv_lattice,
+        word_fsa_with_epsilon_loops,
+        b_to_a_map=path_to_seq_map,
+        sorted_match_a=True,
+    )

+    # path_lattice now has token IDs as `labels` and word IDs as aux_labels.
+    path_lattice = k2.invert(path_lattice)
+    path_lattice = k2.top_sort(k2.connect(path_lattice))
+    # replace labels with tokens to remove repeated token IDs.
+    path_lattice.labels = path_lattice.tokens

+    n_best = k2.shortest_path(path_lattice, use_double_scores=True)
+    n_best = k2.remove_epsilon(n_best)
+    n_best = k2.top_sort(k2.connect(n_best))

+    # now we have nbest lists with am_scores and lm_scores
+    return Nbest(fsa=n_best, shape=seq_to_path_shape)

View File

@@ -82,8 +82,11 @@ class Nbest(object):
         one_best = k2.remove_epsilon(one_best)
+        one_best = k2.top_sort(k2.connect(one_best))

         return Nbest(fsa=one_best, shape=self.shape)

     def total_scores(self) -> k2.RaggedFloat:
         '''Get total scores of the FSAs in this Nbest.
@@ -100,7 +103,7 @@ class Nbest(object):
         # If k2.RaggedDouble is wrapped, we can use double precision here.
         return k2.RaggedFloat(self.shape, scores.float())

-    def top_k(self, k: int) -> 'Nbest':
+    def top_k(self, k: int, scores: k2.RaggedFloat = None) -> 'Nbest':
         '''Get a subset of paths in the Nbest. The resulting Nbest is regular
         in that each sequence (i.e., utterance) has the same number of
         paths (k).
@@ -113,9 +116,13 @@ class Nbest(object):
         Args:
           k:
             Number of paths in each utterance.
+          scores:
+            The scores used to select the top-k paths.
         Returns:
           Return a new Nbest with a regular shape.
         '''
+        ragged_scores = scores
+        if ragged_scores is None:
             ragged_scores = self.total_scores()

         # indexes contains idx01's for self.shape
@@ -140,6 +147,7 @@ class Nbest(object):
         top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
                                                      dim1=k)
+        top_k_shape = top_k_shape.to(top_k_fsas.device)
         return Nbest(top_k_fsas, top_k_shape)
@@ -163,7 +171,7 @@ class Nbest(object):
         # indexes contains idx01's for self.shape
         indexes = torch.arange(
             self.shape.num_elements(), dtype=torch.int32,
-            device=self.shape.device
+            device=self.fsa.device
         )

         if sort:
@@ -176,9 +184,12 @@ class Nbest(object):
         ragged_indexes = k2.RaggedInt(self.shape, indexes)

-        padded_indexes = k2.ragged.pad(ragged_indexes, value=-1)
+        padded_indexes = k2.ragged.pad(ragged_indexes,
+                                       value=-1)

         # Select the idx01's of the top-k paths of each utterance
+        max_num_fsa = padded_indexes.size()[1]
         first_indexes = padded_indexes[:, :k].flatten().contiguous()

         # Remove the padding elements
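Not part of the diff: a small usage sketch of the extended `top_k`, assuming `Nbest` lives in `icefall.nbest` (the import path added in icefall/decode.py above) and using only helpers that already appear in this commit.

```python
import k2
import torch

from icefall.nbest import Nbest  # import path is an assumption

# Two candidate paths for a single utterance, represented as linear FSAs.
fsa = k2.linear_fsa([[1, 2, 3], [1, 5]])
shape = k2.ragged.regular_ragged_shape(dim0=1, dim1=2)  # axes [utt][path]
nbest = Nbest(fsa=fsa, shape=shape)

# Rank paths by an external per-path score instead of nbest.total_scores();
# this is how rescore_with_attention_decoder_v2 passes `est_scores` when
# selecting the second round of paths.
external = k2.RaggedFloat(shape, torch.tensor([0.5, 2.0]))
best = nbest.top_k(k=1, scores=external)
```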

icefall/score_estimator.py (new file, 188 lines)
View File

@@ -0,0 +1,188 @@
import argparse
import glob
import logging
from pathlib import Path
from typing import Tuple, List

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from icefall.utils import (
    setup_logger,
    str2bool,
)


class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        path: Path,
        mode: str,
    ) -> None:
        super().__init__()
        files = sorted(glob.glob(f"{path}/*.pt"))
        if mode == 'train':
            self.files = files[0: int(len(files) * 0.8)]
        elif mode == 'dev':
            self.files = files[int(len(files) * 0.8): int(len(files) * 0.9)]
        elif mode == 'test':
            self.files = files[int(len(files) * 0.9):]

    def __getitem__(self, index) -> torch.Tensor:
        return torch.load(self.files[index])

    def __len__(self) -> int:
        return len(self.files)


class DatasetCollateFunc:
    def __call__(self, batch: List) -> Tuple[torch.Tensor, torch.Tensor]:
        x = torch.cat(batch)
        # The first 5 columns are the features, the last column is the target.
        return (x[:, 0:5], x[:, 5])


class ScoreEstimator(nn.Module):
    def __init__(
        self,
        input_dim: int = 5,
        hidden_dim: int = 20,
    ) -> None:
        super().__init__()
        self.embedding = nn.Linear(
            in_features=input_dim,
            out_features=hidden_dim
        )
        self.output = nn.Linear(
            in_features=hidden_dim,
            out_features=2
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.embedding(x)
        x = self.sigmoid(x)
        x = self.output(x)
        mean, var = x[:, 0], x[:, 1]
        # Predict log-variance and exponentiate to keep the variance positive.
        var = torch.exp(var)
        return mean, var


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--input-dim",
        type=int,
        default=5,
        help="Dim of the input feature.",
    )
    parser.add_argument(
        "--hidden-dim",
        type=int,
        default=20,
        help="Number of units in the hidden layer.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=10,
        help="Batch size of the dataloader.",
    )
    parser.add_argument(
        "--epoch",
        type=int,
        default=20,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate.",
    )
    parser.add_argument(
        "--exp_dir",
        type=Path,
        default=Path("conformer_ctc/exp"),
        help="Directory to store experiment data.",
    )
    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    setup_logger(f"{args.exp_dir}/rescore/log")

    model = ScoreEstimator(
        input_dim=args.input_dim,
        hidden_dim=args.hidden_dim
    )
    model = model.to("cuda")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    loss_fn = nn.GaussianNLLLoss()

    train_dataloader = DataLoader(
        Dataset(f"{args.exp_dir}/rescore/feat", "train"),
        collate_fn=DatasetCollateFunc(),
        batch_size=args.batch_size,
        shuffle=True
    )
    dev_dataloader = DataLoader(
        Dataset(f"{args.exp_dir}/rescore/feat", "dev"),
        collate_fn=DatasetCollateFunc(),
        batch_size=args.batch_size,
        shuffle=True
    )

    for epoch in range(args.epoch):
        model.train()
        training_loss = 0.0
        step = 0
        for x, y in train_dataloader:
            mean, var = model(x.cuda())
            loss = loss_fn(mean, y.cuda(), var)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
            step += len(y)
        training_loss /= step

        dev_loss = 0.0
        step = 0
        model.eval()
        for x, y in dev_dataloader:
            mean, var = model(x.cuda())
            loss = loss_fn(mean, y.cuda(), var)
            dev_loss += loss.item()
            step += len(y)
        dev_loss /= step

        logging.info(f"Epoch {epoch} : training loss : {training_loss}, "
                     f"dev loss : {dev_loss}"
                     )
        torch.save(
            model.state_dict(),
            f"{args.exp_dir}/rescore/epoch-{epoch}.pt"
        )


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()
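Not part of the diff: how the pieces appear to fit together, judging only from the paths and defaults visible in this commit. Decoding with `dump_best_matching_feature=True` writes feature files under `conformer_ctc/exp/rescore/feat`, this script trains `ScoreEstimator` on them and saves `conformer_ctc/exp/rescore/epoch-<n>.pt`, and decode.py then loads `epoch-19.pt` for the attention-decoder-v2 method. The training objective is torch's `GaussianNLLLoss`, roughly ½(log σ² + (y − μ)²/σ²) per sample. A minimal usage sketch of the model itself:

```python
import torch

from icefall.score_estimator import ScoreEstimator

model = ScoreEstimator(input_dim=5, hidden_dim=20)
model.eval()

# One row per query token: the 5 best-matching statistics produced by
# get_best_matching_stats() in icefall/utils.py (random values here).
stats = torch.randn(8, 5)
with torch.no_grad():
    mean, var = model(stats)  # each of shape (8,)

# var is strictly positive because forward() applies exp() to it.
print(mean.shape, var.min().item() > 0)
```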

View File

@@ -411,6 +411,10 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
     assert keys.shape.dim0() == queries.shape.dim0(), \
         f'Utterances number in keys and queries should be equal : \
         {keys.shape.dim0()} vs {queries.shape.dim0()}'
+    assert keys.fsa.device == queries.fsa.device, \
+        f'Device of keys and queries should be equal : \
+        {keys.fsa.device} vs {queries.fsa.device}'
+    device = keys.fsa.device

     # keys_tokens_shape [utt][path][token]
     keys_tokens_shape = k2.ragged.compose_ragged_shapes(keys.shape,
@@ -430,11 +434,13 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
     # counts on key positions are ones
     keys_counts = k2.RaggedInt(keys_tokens_shape,
                                torch.ones(keys_token_num,
-                                          dtype=torch.int32))
+                                          dtype=torch.int32,
+                                          device=device))
     # counts on query positions are zeros
     queries_counts = k2.RaggedInt(queries_tokens_shape,
                                   torch.zeros(queries_tokens_num,
-                                              dtype=torch.int32))
+                                              dtype=torch.int32,
+                                              device=device))
     counts = k2.ragged.cat([keys_counts, queries_counts], axis=1).values()

     # scores on key positions are the scores inherited from the nbest path
@@ -442,7 +448,8 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
     # scores on query positions MUST be zeros
     queries_scores = k2.RaggedFloat(queries_tokens_shape,
                                     torch.zeros(queries_tokens_num,
-                                                dtype=torch.float32))
+                                                dtype=torch.float32,
+                                                device=device))
     scores = k2.ragged.cat([keys_scores, queries_scores], axis=1).values()

     # we didn't remove -1 labels before
@@ -450,8 +457,16 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
     eos = -1
     max_token = torch.max(torch.max(keys.fsa.labels),
                           torch.max(queries.fsa.labels))
-    mean, var, counts_out, ngram = k2.get_best_matching_stats(tokens, scores,
-        counts, eos, min_token, max_token, max_order)
+    mean, var, counts_out, ngram = k2.get_best_matching_stats(
+        tokens.to(torch.device('cpu')), scores.to(torch.device('cpu')),
+        counts.to(torch.device('cpu')),
+        eos, min_token, max_token, max_order
+    )
+    mean = mean.to(device)
+    var = var.to(device)
+    counts_out = counts_out.to(device)
+    ngram = ngram.to(device)

     queries_init_scores = queries.fsa.scores.clone()
     # only return the stats on query positions