Add multi-round nbest rescorer

This commit is contained in:
pkufool 2021-08-18 15:00:13 +08:00
parent 0669aa8ab9
commit 27c46b66ee
5 changed files with 508 additions and 101 deletions

View File

@ -6,6 +6,7 @@
import argparse
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -23,10 +24,12 @@ from icefall.decode import (
nbest_decoding,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_attention_decoder_v2,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.score_estimator import ScoreEstimator
from icefall.utils import (
AttributeDict,
get_texts,
@ -62,6 +65,7 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
# "exp_dir": Path("exp/conformer_ctc"),
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
@ -86,10 +90,17 @@ def get_params() -> AttributeDict:
# - whole-lattice-rescoring
# - attention-decoder
# "method": "whole-lattice-rescoring",
"method": "attention-decoder",
"method": "attention-decoder-v2",
# "method": "nbest-rescoring",
# "method": "attention-decoder",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
"num_paths": 100,
# top_k is used when method is "attention-decoder-v2"
"top_k" : 10,
# dump_best_matching_feature is used when method is
# "attention-decoder-v2" to dump best-matching features for
# training the score estimator model
"dump_best_matching_feature": False,
}
)
return params
@ -104,6 +115,7 @@ def decode_one_batch(
lexicon: Lexicon,
sos_id: int,
eos_id: int,
rescore_est_model: nn.Module,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
@ -135,12 +147,16 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
batch_idx:
The index of the current batch.
lexicon:
It contains word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
rescore_est_model:
The model that estimates the mean and variance of rescoring scores; only needed when the method is "attention-decoder-v2".
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
@ -242,15 +258,24 @@ def decode_one_batch(
best_path_dict = rescore_with_attention_decoder_v2(
lattice=rescored_lattice,
batch_idx=batch_idx,
dump_best_matching_feature=params.dump_feature,
dump_best_matching_feature=params.dump_best_matching_feature,
num_paths=params.num_paths,
top_k=params.top_k,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
rescore_est_model=rescore_est_model,
sos_id=sos_id,
eos_id=eos_id,
)
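# When dump_best_matching_feature is enabled, the call above returns a
# feature tensor instead of decoding results; it is saved below so that
# icefall/score_estimator.py can use it to train the ScoreEstimator.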
if params.dump_best_matching_feature:
if best_path_dict.size()[0] > 0:
save_dir = params.exp_dir / "rescore/feat"
if not os.path.exists(save_dir):
os.makedirs(save_dir)
file_name = save_dir / f"feats-epoch-{batch_idx}.pt"
torch.save(best_path_dict, file_name)
return dict()
else:
assert False, f"Unsupported decoding method: {params.method}"
@ -270,6 +295,7 @@ def decode_dataset(
lexicon: Lexicon,
sos_id: int,
eos_id: int,
rescore_est_model: nn.Module,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
"""Decode dataset.
@ -289,6 +315,8 @@ def decode_dataset(
The token ID for SOS.
eos_id:
The token ID for EOS.
rescore_est_model:
The model that estimates the mean and variance of rescoring scores; only needed when the method is "attention-decoder-v2".
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
@ -303,7 +331,7 @@ def decode_dataset(
results = []
num_cuts = 0
tot_num_cuts = len(dl.dataset.cuts)
tot_batches = len(dl)
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -314,11 +342,12 @@ def decode_dataset(
model=model,
HLG=HLG,
batch=batch,
batch_idx,
batch_idx=batch_idx,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
rescore_est_model=rescore_est_model,
)
for lm_scale, hyps in hyps_dict.items():
@ -334,9 +363,8 @@ def decode_dataset(
if batch_idx % 100 == 0:
logging.info(
f"batch {batch_idx}, cuts processed until now is "
f"{num_cuts}/{tot_num_cuts} "
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
f"batch {batch_idx}/{tot_batches}, cuts processed until now is "
f"{num_cuts}"
)
return results
@ -430,6 +458,7 @@ def main():
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"attention-decoder-v2",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
@ -453,7 +482,7 @@ def main():
d = torch.load(params.lm_dir / "G_4_gram.pt")
G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
if params.method in ["whole-lattice-rescoring", "attention-decoder", "attention-decoder-v2"]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
@ -465,6 +494,15 @@ def main():
G.lm_scores = G.scores.clone()
else:
G = None
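# attention-decoder-v2 additionally needs a trained ScoreEstimator
# checkpoint, produced by icefall/score_estimator.py from the dumped
# best-matching features (the epoch-19 checkpoint is hard-coded here).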
if params.method == "attention-decoder-v2":
rescore_est_model = ScoreEstimator()
rescore_est_model.load_state_dict(
torch.load(f"{params.exp_dir}/rescore/epoch-19.pt",
map_location="cpu")
)
rescore_est_model.to(device)
else:
rescore_est_model = None
model = Conformer(
num_features=params.feature_dim,
@ -504,6 +542,7 @@ def main():
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
if test_set == "test-other": continue
results_dict = decode_dataset(
dl=test_dl,
params=params,
@ -513,6 +552,7 @@ def main():
G=G,
sos_id=sos_id,
eos_id=eos_id,
rescore_est_model=rescore_est_model,
)
save_results(

View File

@ -1,10 +1,15 @@
import logging
import os
from typing import Dict, List, Optional, Tuple, Union
import k2
import torch
import torch.nn as nn
from .nbest import Nbest
from .utils import get_best_matching_stats
from .score_estimator import ScoreEstimator
def _intersect_device(
a_fsas: k2.Fsa,
@ -752,20 +757,25 @@ def rescore_nbest_with_attention_decoder(
eos_id:
The token ID for EOS.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
best decoding path for each sequence in the lattice.
An Nbest whose arc scores have been updated with attention-decoder scores.
"""
num_seqs = nbest.shape.Dim0()
token_seq = k2.RaggedInt(nbest.shape, nbest.fsas.labels().contiguous())
num_paths = nbest.shape.num_elements()
# token shape [utt][path][state][arc]
token_shape = k2.ragged.compose_ragged_shapes(nbest.shape, nbest.fsa.arcs.shape())
token_seq = k2.RaggedInt(token_shape, nbest.fsa.labels.contiguous())
# Remove -1 from token_seq. There are no epsilon tokens in token_seq;
# they were removed when generating the nbest list.
token_seq = k2.ragged.remove_values_leq(token_seq, -1)
# token seq shape [utt][path][token]
token_seq = k2.ragged.remove_axis(token_seq, 2)
# token seq shape [utt][token]
token_seq = k2.ragged.remove_axis(token_seq, 0)
token_ids = k2.ragged.to_list(token_seq)
path_to_seq_map_long = token_seq.shape.row_ids(1).to(torch.long)
path_to_seq_map_long = nbest.shape.row_ids(1).to(torch.long)
expanded_memory = memory.index_select(1, path_to_seq_map_long)
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
@ -780,25 +790,27 @@ def rescore_nbest_with_attention_decoder(
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == num_seqs
assert nll.shape[0] == num_paths
attention_scores = torch.zeros(
nbest.fsas.labels().size()[0],
nbest.fsa.scores.size()[0],
dtype=torch.float32,
device=nbest.device
device=nbest.fsa.device
)
start_index = 0
for i in range(num_seqs):
for i in range(num_paths):
# Plus 1 to fill in the score of the final arc
tokens_num = len(tokens_ids[i]) + 1
attention_scores[start_index: start_index + tokens_num] =
tokens_num = 0 if len(token_ids[i]) == 0 else len(token_ids[i]) + 1
attention_scores[start_index: start_index + tokens_num] =\
nll[i][0: tokens_num]
start_index += tokens_num
fsas = nbest.fsas.clone()
fsas.score = attention_scores
return Nbest(fsas, nbest.shape.clone())
fsas = nbest.fsa.clone()
fsas.scores = attention_scores
return Nbest(fsas, nbest.shape)
def rescore_with_attention_decoder_v2(
@ -810,9 +822,10 @@ def rescore_with_attention_decoder_v2(
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
rescore_est_model: nn.Module,
sos_id: int,
eos_id: int,
) -> Dict[str, k2.Fsa]:
) -> Union[torch.Tensor, Dict[str, k2.Fsa]]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
score is used as the decoding output.
@ -820,6 +833,11 @@ def rescore_with_attention_decoder_v2(
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
batch_idx:
The batch index currently processed.
dump_best_matching_feature:
Whether to dump best-matching features. Used only to prepare
training data for the attention-decoder-v2 score estimator.
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
@ -831,6 +849,8 @@ def rescore_with_attention_decoder_v2(
Its shape is `[T, N, C]`.
memory_key_padding_mask:
The padding mask for memory with shape [N, T].
rescore_est_model:
The model that estimates the mean and variance of rescoring scores; only needed when the method is "attention-decoder-v2".
sos_id:
The token ID for SOS.
eos_id:
@ -841,23 +861,24 @@ def rescore_with_attention_decoder_v2(
best decoding path for each sequence in the lattice.
"""
nbest = generate_nbest_list(lattice, num_paths)
# Now we have nbest with scores
nbest = nbest.intersect(lattice)
if dump_best_matching_feature:
if nbest.fsa.arcs.dim0() <= 2 * top_k or nbest.fsa.arcs.num_elements() == 0:
return torch.empty(0)
nbest_k, nbest_q = nbest.split(k=top_k, sort=False)
rescored_nbest_k = rescore_nbest_with_attention_decoder(
nbest=nbest_k,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
eos_id=eos_id
)
stats_tensor = get_best_matching_stats(
rescored_nbest_k,
nbest_q,
max_order=3
max_order=5
)
rescored_nbest_q = rescore_nbest_with_attention_decoder(
nbest=nbest_q,
@ -865,41 +886,132 @@ def rescore_with_attention_decoder_v2(
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id
)
merge_tensor = torch.cat(
(stats_tensor, rescored_nbest_q.fsa.scores.clone().view(-1, 1)),
dim=1
)
return merge_tensor
if nbest.fsa.arcs.dim0() >= 2 * top_k and nbest.fsa.arcs.num_elements() != 0:
nbest_topk, nbest_remain = nbest.split(k=top_k)
am_scores = nbest_topk.fsa.scores - nbest_topk.fsa.lm_scores
lm_scores = nbest_topk.fsa.lm_scores
rescored_nbest_topk = rescore_nbest_with_attention_decoder(
nbest=nbest_topk,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
# return feature & label or dump to file
)
nbest_topk, nbest_remain = nbest.split(k=top_k)
stats_tensor = get_best_matching_stats(
rescored_nbest_topk,
nbest_remain,
max_order=5
)
rescored_nbest_topk = rescore_nbest_with_attention_decoder(
nbest=nbest_topk,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
stats_tensor = get_best_matching_stats(
rescored_nbest_topk,
nbest_remain,
max_order=3
)
# run rescore estimation model to get the mean and var of each token
mean, var = rescore_est_model(stats_tensor)
# calculate nbest_remain estimated score and select topk
nbest_remain_topk = nbest_remain.top_k(k=top_k)
rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder(
nbest=nbest_remain_topk,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
best_path_dict=get_best_path_from_nbests(
rescored_nbest_topk,
rescored_nbest_remain_topk,
)
# run rescore estimation model to get the mean and var of each token
mean, var = rescore_est_model(stats_tensor)
# mean_shape [utt][path][state][arcs]
mean_shape = k2.ragged.compose_ragged_shapes(
nbest_remain.shape, nbest_remain.fsa.arcs.shape())
# mean_shape [utt][path][arcs]
mean_shape = k2.ragged.remove_axis(mean_shape, 2)
ragged_mean = k2.RaggedFloat(mean_shape, mean.contiguous())
# path mean shape [utt][path]
path_mean = k2.ragged.sum_per_sublist(ragged_mean)
# var_shape [utt][path][state][arcs]
var_shape = k2.ragged.compose_ragged_shapes(
nbest_remain.shape, nbest_remain.fsa.arcs.shape())
# var_shape [utt][path][arcs]
var_shape = k2.ragged.remove_axis(var_shape, 2)
ragged_var = k2.RaggedFloat(var_shape, var.contiguous())
# path var shape [utt][path]
path_var = k2.ragged.sum_per_sublist(ragged_var)
# tot_scores() shape [utt][path]
# path_scores has as many elements as there are paths
# !!! Note: this is only correct when the number of utterances is 1
path_scores = nbest_remain.total_scores().values()
best_score = torch.max(rescored_nbest_topk.total_scores().values())
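# Assuming each remaining path's attention score is Gaussian with the
# predicted mean and variance (summed over its tokens), est_scores below
# is the probability that the path would beat the current best score:
#   P(X > best_score) = 1 - Phi((best_score - mean) / sqrt(var))
# which is what the erf expression computes.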
est_scores = 1 - 1/2 * (
1 + torch.erf(
(best_score - path_mean) / torch.sqrt(2 * path_var)
)
)
est_scores = k2.RaggedFloat(nbest_remain.shape, est_scores)
# calculate nbest_remain estimated score and select topk
nbest_remain_topk = nbest_remain.top_k(k=top_k, scores=est_scores)
remain_am_scores = nbest_remain_topk.fsa.scores - nbest_remain_topk.fsa.lm_scores
remain_lm_scores = nbest_remain_topk.fsa.lm_scores
rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder(
nbest=nbest_remain_topk,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
# !!! Note: this is only correct when the number of utterances is 1
merge_fsa = k2.cat([rescored_nbest_topk.fsa, rescored_nbest_remain_topk.fsa])
merge_row_ids = torch.zeros(
merge_fsa.arcs.dim0(),
dtype=torch.int32,
device=merge_fsa.device
)
rescore_nbest = Nbest(
merge_fsa, k2.ragged.create_ragged_shape2(row_ids=merge_row_ids)
)
attention_scores = rescore_nbest.fsa.scores
am_scores = torch.cat((am_scores, remain_am_scores))
lm_scores = torch.cat((lm_scores, remain_lm_scores))
else:
am_scores = nbest.fsa.scores - nbest.fsa.lm_scores
lm_scores = nbest.fsa.lm_scores
rescore_nbest = rescore_nbest_with_attention_decoder(
nbest=nbest,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id
)
attention_scores = rescore_nbest.fsa.scores
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
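# Grid-search over LM and attention scales: for every (n_scale, a_scale)
# pair the total score is am + n_scale * lm + a_scale * attention, and the
# best path per utterance is returned under the key
# "ngram_lm_scale_<n_scale>_attention_scale_<a_scale>".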
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
tot_scores = (
am_scores
+ n_scale * lm_scores
+ a_scale * attention_scores
)
rescore_nbest.fsa.scores = tot_scores
# ragged tot scores shape [utt][path]
ragged_tot_scores = rescore_nbest.total_scores()
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
best_fsas = k2.index_fsa(rescore_nbest.fsa, argmax_indexes)
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_fsas
return ans
@ -920,49 +1032,90 @@ def generate_nbest_list(
that represent the same word sequences, the number of paths
in different sequences may not be equal.
Return:
Return an Nbest object. Note the returned FSAs don't have epsilon
self-loops.
Return an Nbest object.
'''
assert len(lats.shape) == 3
# First, extract `num_paths` paths for each sequence.
# paths is a k2.RaggedInt with axes [seq][path][arc_pos]
paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
# Seqs is a k2.RaggedInt sharing the same shape as `paths`.
# Note that it also contains 0s and -1s.
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
# Its axes are [seq][path][word_id]
if aux_labels:
# if aux_labels is enabled, seqs contains word IDs
assert hasattr(lats, "aux_labels")
seqs = k2.index(lats.aux_labels, paths)
else:
# CAUTION: We use `phones` instead of `tokens` here because
# :func:`compile_HLG` uses `phones`
#
# Note: compile_HLG is from k2-fsa/snowfall
assert hasattr(lats, 'phones')
word_seq = k2.index(lats.aux_labels, path)
assert not hasattr(lats, 'tokens')
lats.tokens = lats.phones
seqs = k2.index(lats.tokens, paths)
# Remove epsilons and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0)
# Remove epsilons (0s) and -1 from word_seqs
seqs = k2.ragged.remove_values_leq(seqs, 0)
# Remove paths that has identical word sequences.
#
# unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
# except that there are no repeated paths with the same word_seq
# within a sequence.
#
# num_repeats is also a k2.RaggedInt with 2 axes containing the
# multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.num_elements()
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
# new2old.numel() == unique_word_seqs.tot_size(1)
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
word_seq, need_num_repeats=True, need_new2old_indexes=True
)
# unique_word_seqs is still a k2.RaggedInt with axes [seq][path][word_id].
# But the number of paths in each sequence may be different.
unique_seqs, _, _ = k2.ragged.unique_sequences(
seqs, need_num_repeats=False, need_new2old_indexes=False)
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
seq_to_path_shape = k2.ragged.get_layer(unique_seqs.shape(), 0)
# path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path
# belongs.
path_to_seq_map = seq_to_path_shape.row_ids(1)
# Remove the seq axis.
# Now unique_word_seqs has only two axes [path][word_id]
unique_seqs = k2.ragged.remove_axis(unique_seqs, 0)
# Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
fsas = k2.linear_fsa(unique_seqs)
# word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq)
return Nbest(fsa=fsas, shape=seq_to_path_shape)
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
# k2.compose() currently does not support b_to_a_map. To avoid
# replicating `lats`, we use k2.intersect_device here.
#
# lattice has token IDs as `labels` and word IDs as aux_labels, so we
# need to invert it here.
inv_lattice = k2.invert(lats)
# Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
# and its `aux_labels` are token IDs (a k2.RaggedInt with 2 axes)
# Remove its `aux_labels` since it is not needed in the
# following computation
# del inv_lattice.aux_labels
inv_lattice = k2.arc_sort(inv_lattice)
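# Intersect the word-level path FSAs with the (inverted) lattice so that
# each path is mapped back onto lattice arcs and regains the acoustic and
# LM scores attached to them.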
path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_seq_map,
sorted_match_a=True,
)
# path_lattice now has token IDs as `labels` and word IDs as `aux_labels`.
path_lattice = k2.invert(path_lattice)
path_lattice = k2.top_sort(k2.connect(path_lattice))
# replace labels with tokens to remove repeated token IDs.
path_lattice.labels = path_lattice.tokens
n_best = k2.shortest_path(path_lattice, use_double_scores=True)
n_best = k2.remove_epsilon(n_best)
n_best = k2.top_sort(k2.connect(n_best))
# now we have nbest lists with am_scores and lm_scores
return Nbest(fsa=n_best, shape=seq_to_path_shape)

View File

@ -82,8 +82,11 @@ class Nbest(object):
one_best = k2.remove_epsilon(one_best)
one_best = k2.top_sort(k2.connect(one_best))
return Nbest(fsa=one_best, shape=self.shape)
def total_scores(self) -> k2.RaggedFloat:
'''Get total scores of the FSAs in this Nbest.
@ -100,7 +103,7 @@ class Nbest(object):
# If k2.RaggedDouble is wrapped, we can use double precision here.
return k2.RaggedFloat(self.shape, scores.float())
def top_k(self, k: int) -> 'Nbest':
def top_k(self, k: int, scores: k2.RaggedFloat = None) -> 'Nbest':
'''Get a subset of paths in the Nbest. The resulting Nbest is regular
in that each sequence (i.e., utterance) has the same number of
paths (k).
@ -113,10 +116,14 @@ class Nbest(object):
Args:
k:
Number of paths in each utterance.
scores:
The scores used to select the top-k paths. If None, self.total_scores() is used.
Returns:
Return a new Nbest with a regular shape.
'''
ragged_scores = self.total_scores()
ragged_scores = scores
if ragged_scores is None:
ragged_scores = self.total_scores()
# indexes contains idx01's for self.shape
# ragged_scores.values()[indexes] is sorted
@ -140,6 +147,7 @@ class Nbest(object):
top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
dim1=k)
top_k_shape = top_k_shape.to(top_k_fsas.device)
return Nbest(top_k_fsas, top_k_shape)
@ -163,7 +171,7 @@ class Nbest(object):
# indexes contains idx01's for self.shape
indexes = torch.arange(
self.shape.num_elements(), dtype=torch.int32,
device=self.shape.device
device=self.fsa.device
)
if sort:
@ -176,9 +184,12 @@ class Nbest(object):
ragged_indexes = k2.RaggedInt(self.shape, indexes)
padded_indexes = k2.ragged.pad(ragged_indexes, value=-1)
padded_indexes = k2.ragged.pad(ragged_indexes,
value=-1)
# Select the idx01's of top-k paths of each utterance
max_num_fsa = padded_indexes.size()[1]
first_indexes = padded_indexes[:, :k].flatten().contiguous()
# Remove the padding elements

icefall/score_estimator.py (new file, 188 lines added)
View File

@ -0,0 +1,188 @@
import argparse
import glob
import logging
from pathlib import Path
from typing import Tuple, List
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from icefall.utils import (
setup_logger,
str2bool,
)
class Dataset(torch.utils.data.Dataset):
def __init__(
self,
path: Path,
model: str,
) -> None:
super().__init__()
files = sorted(glob.glob(f"{path}/*.pt"))
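# Deterministic 80/10/10 split of the dumped feature files into
# train/dev/test subsets.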
if model == 'train':
self.files = files[0: int(len(files) * 0.8)]
elif model == 'dev':
self.files = files[int(len(files) * 0.8): int(len(files) * 0.9)]
elif model == 'test':
self.files = files[int(len(files) * 0.9):]
def __getitem__(self, index) -> torch.Tensor:
return torch.load(self.files[index])
def __len__(self) -> int:
return len(self.files)
class DatasetCollateFunc:
def __call__(self, batch: List) -> Tuple[torch.Tensor, torch.Tensor]:
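# Each item in `batch` is a 2-D feature tensor dumped by
# rescore_with_attention_decoder_v2; after concatenation the first five
# columns hold the best-matching statistics (model input) and the last
# column is the attention-decoder score used as the regression target.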
x = torch.cat(batch)
return (x[:, 0:5], x[:, 5])
class ScoreEstimator(nn.Module):
def __init__(
self,
input_dim: int = 5,
hidden_dim: int = 20,
) -> None:
super().__init__()
self.embedding = nn.Linear(
in_features=input_dim,
out_features=hidden_dim
)
self.output = nn.Linear(
in_features=hidden_dim,
out_features=2
)
self.sigmoid = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.embedding(x)
x = self.sigmoid(x)
x = self.output(x)
mean, var = x[:, 0], x[:, 1]
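# The raw second output is treated as a log-variance; exponentiating it
# keeps the predicted variance strictly positive.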
var = torch.exp(var)
return mean, var
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--input-dim",
type=int,
default=5,
help="Dim of input feature.",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=20,
help="Neural number of didden layer.",
)
parser.add_argument(
"--batch-size",
type=int,
default=10,
help="Batch size of dataloader.",
)
parser.add_argument(
"--epoch",
type=int,
default=20,
help="Training epochs",
)
parser.add_argument(
"--learning-rate",
type=float,
default=1e-4,
help="Learning rate.",
)
parser.add_argument(
"--exp_dir",
type=Path,
default=Path("conformer_ctc/exp"),
help="Directory to store experiment data.",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
torch.manual_seed(42)
torch.cuda.manual_seed(42)
setup_logger(f"{args.exp_dir}/rescore/log")
model = ScoreEstimator(
input_dim=args.input_dim,
hidden_dim=args.hidden_dim
)
model = model.to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
loss_fn = nn.GaussianNLLLoss()
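# GaussianNLLLoss(mean, target, var) minimizes the negative log-likelihood
# of the target attention score under N(mean, var), so the model learns to
# predict both a score estimate and its uncertainty.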
train_dataloader = DataLoader(
Dataset(f"{args.exp_dir}/rescore/feat", "train"),
collate_fn=DatasetCollateFunc(),
batch_size=args.batch_size,
shuffle=True
)
dev_dataloader = DataLoader(
Dataset(f"{args.exp_dir}/rescore/feat", "dev"),
collate_fn=DatasetCollateFunc(),
batch_size=args.batch_size,
shuffle=True
)
for epoch in range(args.epoch):
model.train()
training_loss = 0.0
step = 0
for x, y in train_dataloader:
mean, var = model(x.cuda())
loss = loss_fn(mean, y, var)
optimizer.zero_grad()
loss.backward()
optimizer.step()
training_loss += loss.item()
step += len(y)
training_loss /= step
dev_loss = 0.0
step = 0
model.eval()
for x, y in dev_dataloader:
mean, var = model(x.cuda())
loss = loss_fn(mean, y, var)
dev_loss += loss.item()
step += len(y)
dev_loss /= step
logging.info(f"Epoch {epoch} : training loss : {training_loss}, "
f"dev loss : {dev_loss}"
)
torch.save(
model.state_dict(),
f"{args.exp_dir}/rescore/epoch-{epoch}.pt"
)
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

View File

@ -411,6 +411,10 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
assert keys.shape.dim0() == queries.shape.dim0(), \
f'Number of utterances in keys and queries should be equal : \
{keys.shape.dim0()} vs {queries.shape.dim0()}'
assert keys.fsa.device == queries.fsa.device, \
f'Device of keys and queries should be equal : \
{keys.fsa.device} vs {queries.fsa.device}'
device = keys.fsa.device
# keys_tokens_shape [utt][path][token]
keys_tokens_shape = k2.ragged.compose_ragged_shapes(keys.shape,
@ -430,11 +434,13 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
# counts on key positions are ones
keys_counts = k2.RaggedInt(keys_tokens_shape,
torch.ones(keys_token_num,
dtype=torch.int32))
dtype=torch.int32,
device=device))
# counts on query positions are zeros
queries_counts = k2.RaggedInt(queries_tokens_shape,
torch.zeros(queries_tokens_num,
dtype=torch.int32))
dtype=torch.int32,
device=device))
counts = k2.ragged.cat([keys_counts, queries_counts], axis=1).values()
# scores on key positions are the scores inherited from the nbest paths
@ -442,7 +448,8 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
# scores on query positions MUST be zeros
queries_scores = k2.RaggedFloat(queries_tokens_shape,
torch.zeros(queries_tokens_num,
dtype=torch.float32))
dtype=torch.float32,
device=device))
scores = k2.ragged.cat([keys_scores, queries_scores], axis=1).values()
# we didn't remove -1 labels before
@ -450,8 +457,16 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
eos = -1
max_token = torch.max(torch.max(keys.fsa.labels),
torch.max(queries.fsa.labels))
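# The inputs are moved to CPU before calling k2.get_best_matching_stats
# and the resulting statistics are copied back to the original device.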
mean, var, counts_out, ngram = k2.get_best_matching_stats(tokens, scores,
counts, eos, min_token, max_token, max_order)
mean, var, counts_out, ngram = k2.get_best_matching_stats(
tokens.to(torch.device('cpu')), scores.to(torch.device('cpu')),
counts.to(torch.device('cpu')),
eos, min_token, max_token, max_order
)
mean = mean.to(device)
var = var.to(device)
counts_out = counts_out.to(device)
ngram = ngram.to(device)
queries_init_scores = queries.fsa.scores.clone()
# only return the stats on query positions