Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

Add multi-round nbest rescorer

commit 27c46b66ee (parent 0669aa8ab9)
@@ -6,6 +6,7 @@
 import argparse
 import logging
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -23,10 +24,12 @@ from icefall.decode import (
     nbest_decoding,
     one_best_decoding,
     rescore_with_attention_decoder,
+    rescore_with_attention_decoder_v2,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
+from icefall.score_estimator import ScoreEstimator
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -62,6 +65,7 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
+            # "exp_dir": Path("exp/conformer_ctc"),
             "exp_dir": Path("conformer_ctc/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
@@ -86,10 +90,17 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "attention-decoder",
+            "method": "attention-decoder-v2",
+            # "method": "nbest-rescoring",
+            # "method": "attention-decoder",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
             "num_paths": 100,
+            # top_k is used when method is "attention-decoder-v2"
+            "top_k": 10,
+            # dump_best_matching_feature is used when method is
+            # "attention-decoder-v2" to dump features to train a special model
+            "dump_best_matching_feature": False,
         }
     )
     return params
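Note: taken together, the three new entries above drive the v2 decoding modes. A minimal sketch of the dispatch they imply (names taken from the hunks above; control flow simplified, not the actual implementation):

def choose_rescoring_action(params) -> str:
    # Simplified dispatch implied by the new config entries.
    if params.method != "attention-decoder-v2":
        return "classic rescoring"
    if params.dump_best_matching_feature:
        # First pass: dump (stats, attention-score) features to train
        # the ScoreEstimator; no hypotheses are produced in this mode.
        return "dump training features"
    # Normal pass: multi-round n-best rescoring with top_k paths per round.
    return f"multi-round rescoring with top_k={params.top_k}"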
@@ -104,6 +115,7 @@ def decode_one_batch(
     lexicon: Lexicon,
     sos_id: int,
     eos_id: int,
+    rescore_est_model: nn.Module,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -135,12 +147,16 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      batch_idx:
+        The index of the current batch.
       lexicon:
         It contains the word symbol table.
       sos_id:
         The token ID of the SOS.
       eos_id:
        The token ID of the EOS.
+      rescore_est_model:
+        The model that estimates the rescoring mean and variance; used only
+        for attention-decoder-v2.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -242,15 +258,24 @@ def decode_one_batch(
         best_path_dict = rescore_with_attention_decoder_v2(
             lattice=rescored_lattice,
             batch_idx=batch_idx,
-            dump_best_matching_feature=params.dump_feature,
+            dump_best_matching_feature=params.dump_best_matching_feature,
             num_paths=params.num_paths,
             top_k=params.top_k,
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
+            rescore_est_model=rescore_est_model,
             sos_id=sos_id,
             eos_id=eos_id,
         )
+        if params.dump_best_matching_feature:
+            if best_path_dict.size()[0] > 0:
+                save_dir = params.exp_dir / "rescore/feat"
+                if not os.path.exists(save_dir):
+                    os.makedirs(save_dir)
+                file_name = save_dir / f"feats-epoch-{batch_idx}.pt"
+                torch.save(best_path_dict, file_name)
+            return dict()
     else:
         assert False, f"Unsupported decoding method: {params.method}"

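Each tensor saved by the dump branch above is training data for the score estimator. A round-trip sketch of the assumed file format (the column layout is inferred from the merge_tensor construction in rescore_with_attention_decoder_v2 and the collate function in icefall/score_estimator.py further below):

import torch

# Rows are tokens; columns 0..4 are best-matching stats and column 5 is
# the attention score used as the regression target (assumed layout).
feats = torch.randn(17, 6)  # 17 hypothetical tokens from one batch
torch.save(feats, "feats-epoch-0.pt")

x = torch.load("feats-epoch-0.pt")
inputs, targets = x[:, 0:5], x[:, 5]  # what DatasetCollateFunc extracts
assert inputs.shape == (17, 5) and targets.shape == (17,)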
@@ -270,6 +295,7 @@ def decode_dataset(
     lexicon: Lexicon,
     sos_id: int,
     eos_id: int,
+    rescore_est_model: nn.Module,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
     """Decode dataset.
@@ -289,6 +315,8 @@ def decode_dataset(
         The token ID for SOS.
       eos_id:
         The token ID for EOS.
+      rescore_est_model:
+        The model that estimates the rescoring mean and variance; used only
+        for attention-decoder-v2.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -303,7 +331,7 @@ def decode_dataset(
     results = []

     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
+    tot_batches = len(dl)

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -314,11 +342,12 @@ def decode_dataset(
             model=model,
             HLG=HLG,
             batch=batch,
-            batch_idx,
+            batch_idx=batch_idx,
             lexicon=lexicon,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
+            rescore_est_model=rescore_est_model,
         )

         for lm_scale, hyps in hyps_dict.items():
@@ -334,9 +363,8 @@ def decode_dataset(

         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}/{tot_batches}, cuts processed until now is "
+                f"{num_cuts}"
             )
     return results
@@ -430,6 +458,7 @@ def main():
         "nbest-rescoring",
         "whole-lattice-rescoring",
         "attention-decoder",
+        "attention-decoder-v2",
     ):
         if not (params.lm_dir / "G_4_gram.pt").is_file():
             logging.info("Loading G_4_gram.fst.txt")
@@ -453,7 +482,7 @@ def main():
             d = torch.load(params.lm_dir / "G_4_gram.pt")
             G = k2.Fsa.from_dict(d).to(device)

-        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+        if params.method in ["whole-lattice-rescoring", "attention-decoder", "attention-decoder-v2"]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
             G = k2.add_epsilon_self_loops(G)
@@ -465,6 +494,15 @@ def main():
         G.lm_scores = G.scores.clone()
     else:
         G = None

+    if params.method == "attention-decoder-v2":
+        rescore_est_model = ScoreEstimator()
+        rescore_est_model.load_state_dict(
+            torch.load(f"{params.exp_dir}/rescore/epoch-19.pt",
+                       map_location="cpu")
+        )
+        rescore_est_model.to(device)
+    else:
+        rescore_est_model = None

     model = Conformer(
         num_features=params.feature_dim,
@@ -504,6 +542,7 @@ def main():
     #
     test_sets = ["test-clean", "test-other"]
     for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
+        if test_set == "test-other": continue
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
@@ -513,6 +552,7 @@ def main():
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
+            rescore_est_model=rescore_est_model,
         )

         save_results(
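With main() wired up, the feature implies a three-step workflow; a sketch under the assumption that this commit's defaults are used (the epoch-19 checkpoint name is hard-coded in the hunk above):

from pathlib import Path

exp_dir = Path("conformer_ctc/exp")  # this commit's default exp_dir

# Step 1: decode once with "dump_best_matching_feature": True;
# feature tensors are written here:
feat_dir = exp_dir / "rescore/feat"

# Step 2: train the estimator, e.g.
#   python icefall/score_estimator.py --exp_dir conformer_ctc/exp
# which saves one checkpoint per epoch:
ckpt = exp_dir / "rescore/epoch-19.pt"

# Step 3: decode with "method": "attention-decoder-v2"; main() loads `ckpt`.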
@@ -1,10 +1,15 @@
 import logging
+import os
 from typing import Dict, List, Optional, Tuple, Union

 import k2
 import torch
 import torch.nn as nn

+from .nbest import Nbest
+from .utils import get_best_matching_stats
+
+from .score_estimator import ScoreEstimator

 def _intersect_device(
     a_fsas: k2.Fsa,
@@ -752,20 +757,25 @@ def rescore_nbest_with_attention_decoder(
       eos_id:
         The token ID for EOS.
     Returns:
-      A dict of FsaVec, whose key contains a string
-      ngram_lm_scale_attention_scale and the value is the
-      best decoding path for each sequence in the lattice.
+      An Nbest with all of the scores on its FSA arcs updated with attention scores.
     """
-    num_seqs = nbest.shape.Dim0()
-    token_seq = k2.RaggedInt(nbest.shape, nbest.fsas.labels().contiguous())
+    num_paths = nbest.shape.num_elements()
+    # token shape [utt][path][state][arc]
+    token_shape = k2.ragged.compose_ragged_shapes(nbest.shape, nbest.fsa.arcs.shape())
+
+    token_seq = k2.RaggedInt(token_shape, nbest.fsa.labels.contiguous())

     # Remove -1 from token_seq; there are no epsilon tokens in token_seq, we
     # removed them when generating the nbest list
     token_seq = k2.ragged.remove_values_leq(token_seq, -1)
+    # token seq shape [utt][path][token]
+    token_seq = k2.ragged.remove_axis(token_seq, 2)
+    # token seq shape [utt][token]
+    token_seq = k2.ragged.remove_axis(token_seq, 0)

     token_ids = k2.ragged.to_list(token_seq)

-    path_to_seq_map_long = token_seq.shape.row_ids(1).to(torch.long)
+    path_to_seq_map_long = nbest.shape.row_ids(1).to(torch.long)
     expanded_memory = memory.index_select(1, path_to_seq_map_long)

     expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
@@ -780,25 +790,27 @@ def rescore_nbest_with_attention_decoder(
         sos_id=sos_id,
         eos_id=eos_id,
     )

     assert nll.ndim == 2
-    assert nll.shape[0] == num_seqs
+    assert nll.shape[0] == num_paths

     attention_scores = torch.zeros(
-        nbest.fsas.labels().size()[0],
+        nbest.fsa.scores.size()[0],
         dtype=torch.float32,
-        device=nbest.device
+        device=nbest.fsa.device
     )

     start_index = 0
-    for i in range(num_seqs):
+    for i in range(num_paths):
         # Plus 1 to fill the score of the final arc
-        tokens_num = len(tokens_ids[i]) + 1
-        attention_scores[start_index: start_index + tokens_num] =
+        tokens_num = 0 if len(token_ids[i]) == 0 else len(token_ids[i]) + 1
+        attention_scores[start_index: start_index + tokens_num] = \
             nll[i][0: tokens_num]
         start_index += tokens_num

-    fsas = nbest.fsas.clone()
-    fsas.score = attention_scores
-    return Nbest(fsas, nbest.shape.clone())
+    fsas = nbest.fsa.clone()
+    fsas.scores = attention_scores
+    return Nbest(fsas, nbest.shape)


 def rescore_with_attention_decoder_v2(
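The scatter loop above copies each path's token-level attention scores onto that path's arcs. A toy, pure-torch illustration of the same bookkeeping (values hypothetical):

import torch

token_ids = [[5, 9], [7], []]         # three hypothetical paths
nll = torch.tensor([[0.1, 0.2, 0.3],  # per-token scores, padded per path
                    [0.4, 0.5, 0.0],
                    [0.0, 0.0, 0.0]])
num_arcs = 3 + 2 + 0                  # tokens + 1 final arc per non-empty path
attention_scores = torch.zeros(num_arcs)

start_index = 0
for i in range(len(token_ids)):
    # Plus 1 to fill the score of the final arc; empty paths contribute 0.
    tokens_num = 0 if len(token_ids[i]) == 0 else len(token_ids[i]) + 1
    attention_scores[start_index: start_index + tokens_num] = nll[i][0: tokens_num]
    start_index += tokens_num

assert torch.allclose(attention_scores, torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]))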
@@ -810,9 +822,10 @@ def rescore_with_attention_decoder_v2(
     model: nn.Module,
     memory: torch.Tensor,
     memory_key_padding_mask: torch.Tensor,
+    rescore_est_model: nn.Module,
     sos_id: int,
     eos_id: int,
-) -> Dict[str, k2.Fsa]:
+) -> Union[torch.Tensor, Dict[str, k2.Fsa]]:
     """This function extracts n paths from the given lattice and uses
     an attention decoder to rescore them. The path with the highest
     score is used as the decoding output.
@@ -820,6 +833,11 @@ def rescore_with_attention_decoder_v2(
     Args:
       lattice:
         An FsaVec. It can be the return value of :func:`get_lattice`.
+      batch_idx:
+        The index of the batch currently being processed.
+      dump_best_matching_feature:
+        Whether to dump best-matching features; used only for preparing
+        training data for attention-decoder-v2.
       num_paths:
         Number of paths to extract from the given lattice for rescoring.
       model:
@@ -831,6 +849,8 @@ def rescore_with_attention_decoder_v2(
         Its shape is `[T, N, C]`.
       memory_key_padding_mask:
         The padding mask for memory with shape [N, T].
+      rescore_est_model:
+        The model that estimates the rescoring mean and variance; used only
+        for attention-decoder-v2.
       sos_id:
         The token ID for SOS.
       eos_id:
@@ -841,23 +861,24 @@ def rescore_with_attention_decoder_v2(
       best decoding path for each sequence in the lattice.
     """
     nbest = generate_nbest_list(lattice, num_paths)
-    # Now we have nbest with scores
-    nbest = nbest.intersect(lattice)

     if dump_best_matching_feature:
+        if nbest.fsa.arcs.dim0() <= 2 * top_k or nbest.fsa.arcs.num_elements() == 0:
+            return torch.empty(0)
         nbest_k, nbest_q = nbest.split(k=top_k, sort=False)

         rescored_nbest_k = rescore_nbest_with_attention_decoder(
             nbest=nbest_k,
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
-            eos_id=eos_id,
+            eos_id=eos_id
         )
         stats_tensor = get_best_matching_stats(
             rescored_nbest_k,
             nbest_q,
-            max_order=3
+            max_order=5
         )
         rescored_nbest_q = rescore_nbest_with_attention_decoder(
             nbest=nbest_q,
@@ -865,11 +886,20 @@ def rescore_with_attention_decoder_v2(
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
-            eos_id=eos_id,
-            # return feature & label or dump to file
+            eos_id=eos_id
         )
+        merge_tensor = torch.cat(
+            (stats_tensor, rescored_nbest_q.fsa.scores.clone().view(-1, 1)),
+            dim=1
+        )
+        return merge_tensor

+    if nbest.fsa.arcs.dim0() >= 2 * top_k and nbest.fsa.arcs.num_elements() != 0:
         nbest_topk, nbest_remain = nbest.split(k=top_k)

+        am_scores = nbest_topk.fsa.scores - nbest_topk.fsa.lm_scores
+        lm_scores = nbest_topk.fsa.lm_scores

         rescored_nbest_topk = rescore_nbest_with_attention_decoder(
             nbest=nbest_topk,
             model=model,
@@ -878,15 +908,50 @@ def rescore_with_attention_decoder_v2(
             sos_id=sos_id,
             eos_id=eos_id,
         )

         stats_tensor = get_best_matching_stats(
             rescored_nbest_topk,
             nbest_remain,
-            max_order=3
+            max_order=5
         )

         # run the rescore estimation model to get the mean and var of each token
         mean, var = rescore_est_model(stats_tensor)

+        # mean_shape [utt][path][state][arcs]
+        mean_shape = k2.ragged.compose_ragged_shapes(
+            nbest_remain.shape, nbest_remain.fsa.arcs.shape())
+        # mean_shape [utt][path][arcs]
+        mean_shape = k2.ragged.remove_axis(mean_shape, 2)
+        ragged_mean = k2.RaggedFloat(mean_shape, mean.contiguous())
+        # path mean shape [utt][path]
+        path_mean = k2.ragged.sum_per_sublist(ragged_mean)

+        # var_shape [utt][path][state][arcs]
+        var_shape = k2.ragged.compose_ragged_shapes(
+            nbest_remain.shape, nbest_remain.fsa.arcs.shape())
+        # var_shape [utt][path][arcs]
+        var_shape = k2.ragged.remove_axis(var_shape, 2)
+        ragged_var = k2.RaggedFloat(var_shape, var.contiguous())
+        # path var shape [utt][path]
+        path_var = k2.ragged.sum_per_sublist(ragged_var)

+        # tot_scores() shape [utt][path]
+        # path_scores has as many elements as there are paths
+        # !!! Note: This is right only when utt equals 1
+        path_scores = nbest_remain.total_scores().values()
+        best_score = torch.max(rescored_nbest_topk.total_scores().values())
+        est_scores = 1 - 1/2 * (
+            1 + torch.erf(
+                (best_score - path_mean) / torch.sqrt(2 * path_var)
+            )
+        )
+        est_scores = k2.RaggedFloat(nbest_remain.shape, est_scores)

         # calculate the estimated scores of nbest_remain and select the top k
-        nbest_remain_topk = nbest_remain.top_k(k=top_k)
+        nbest_remain_topk = nbest_remain.top_k(k=top_k, scores=est_scores)
+        remain_am_scores = nbest_remain_topk.fsa.scores - nbest_remain_topk.fsa.lm_scores
+        remain_lm_scores = nbest_remain_topk.fsa.lm_scores
         rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder(
             nbest=nbest_remain_topk,
             model=model,
@@ -895,11 +960,58 @@ def rescore_with_attention_decoder_v2(
             sos_id=sos_id,
             eos_id=eos_id,
         )
-        best_path_dict=get_best_path_from_nbests(
-            rescored_nbest_topk,
-            rescored_nbest_remain_topk,
-        )
+        # !!! Note: This is right only when utt equals 1
+        merge_fsa = k2.cat([rescored_nbest_topk.fsa, rescored_nbest_remain_topk.fsa])
+        merge_row_ids = torch.zeros(
+            merge_fsa.arcs.dim0(),
+            dtype=torch.int32,
+            device=merge_fsa.device
+        )
+        rescore_nbest = Nbest(
+            merge_fsa, k2.ragged.create_ragged_shape2(row_ids=merge_row_ids)
+        )

+        attention_scores = rescore_nbest.fsa.scores
+        am_scores = torch.cat((am_scores, remain_am_scores))
+        lm_scores = torch.cat((lm_scores, remain_lm_scores))
+    else:
+        am_scores = nbest.fsa.scores - nbest.fsa.lm_scores
+        lm_scores = nbest.fsa.lm_scores
+        rescore_nbest = rescore_nbest_with_attention_decoder(
+            nbest=nbest,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id
+        )
+        attention_scores = rescore_nbest.fsa.scores

+    ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]

+    attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]

+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores
+                + n_scale * lm_scores
+                + a_scale * attention_scores
+            )
+            rescore_nbest.fsa.scores = tot_scores
+            # ragged tot scores shape [utt][path]
+            ragged_tot_scores = rescore_nbest.total_scores()

+            argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

+            best_fsas = k2.index_fsa(rescore_nbest.fsa, argmax_indexes)

+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_fsas
     return ans
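The est_scores expression above is the Gaussian survival function: for a path whose rescored score is modeled as N(path_mean, path_var), it is the probability that the path's true score exceeds best_score. A quick numerical check of that reading (values hypothetical):

import torch

path_mean = torch.tensor([1.0, -2.0, 0.5])
path_var = torch.tensor([0.25, 1.0, 4.0])
best_score = torch.tensor(1.2)

est_scores = 1 - 1/2 * (
    1 + torch.erf((best_score - path_mean) / torch.sqrt(2 * path_var))
)

# Same quantity via the Gaussian CDF: P(X > best_score) for X ~ N(mean, var).
normal = torch.distributions.Normal(path_mean, path_var.sqrt())
assert torch.allclose(est_scores, 1 - normal.cdf(best_score), atol=1e-6)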
@@ -920,49 +1032,90 @@ def generate_nbest_list(
       that represent the same word sequences, the number of paths
       in different sequences may not be equal.
     Return:
-      Return an Nbest object. Note the returned FSAs don't have epsilon
-      self-loops.
+      Return an Nbest object.
     '''
-    assert len(lats.shape) == 3

     # First, extract `num_paths` paths for each sequence.
-    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
-    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
+    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
+    path = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

-    # Seqs is a k2.RaggedInt sharing the same shape as `paths`.
-    # Note that it also contains 0s and -1s.
+    # word_seq is a k2.RaggedInt sharing the same shape as `path`
+    # but it contains word IDs. Note that it also contains 0s and -1s.
     # The last entry in each sublist is -1.
-    # Its axes are [seq][path][word_id]
-    if aux_labels:
-        # if aux_labels enable, seqs contains word_id
-        assert hasattr(lats, "aux_labels")
-        seqs = k2.index(lats.aux_labels, paths)
-    else:
-        # CAUTION: We use `phones` instead of `tokens` here because
-        # :func:`compile_HLG` uses `phones`
-        #
-        # Note: compile_HLG is from k2-fsa/snowfall
-        assert hasattr(lats, 'phones')
-        assert not hasattr(lats, 'tokens')
-        lats.tokens = lats.phones
-        seqs = k2.index(lats.tokens, paths)
-
-    # Remove epsilons (0s) and -1 from word_seqs
-    seqs = k2.ragged.remove_values_leq(seqs, 0)
-
-    # unique_word_seqs is still a k2.RaggedInt with axes [seq][path][word_id].
-    # But the number of paths in each sequence may be different.
-    unique_seqs, _, _ = k2.ragged.unique_sequences(
-        seqs, need_num_repeats=False, need_new2old_indexes=False)
-
-    seq_to_path_shape = k2.ragged.get_layer(unique_seqs.shape(), 0)
+    word_seq = k2.index(lats.aux_labels, path)
+
+    # Remove epsilons and -1 from word_seq
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+
+    # Remove paths that have identical word sequences.
+    #
+    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
+    # except that there are no repeated paths with the same word_seq
+    # within a sequence.
+    #
+    # num_repeats is also a k2.RaggedInt with 2 axes containing the
+    # multiplicities of each path.
+    # num_repeats.num_elements() == unique_word_seqs.num_elements()
+    #
+    # Since k2.ragged.unique_sequences will reorder paths within a seq,
+    # `new2old` is a 1-D torch.Tensor mapping from the output path index
+    # to the input path index.
+    # new2old.numel() == unique_word_seqs.tot_size(1)
+    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=True, need_new2old_indexes=True
+    )
+
+    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+
+    # path_to_seq_map is a 1-D torch.Tensor.
+    # path_to_seq_map[i] is the seq to which the i-th path belongs.
+    path_to_seq_map = seq_to_path_shape.row_ids(1)

     # Remove the seq axis.
-    # Now unique_word_seqs has only two axes [path][word_id]
-    unique_seqs = k2.ragged.remove_axis(unique_seqs, 0)
+    # Now unique_word_seq has only two axes [path][word]
+    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)

-    fsas = k2.linear_fsa(unique_seqs)
-
-    return Nbest(fsa=fsas, shape=seq_to_path_shape)
+    # word_fsa is an FsaVec with axes [path][state][arc]
+    word_fsa = k2.linear_fsa(unique_word_seq)
+
+    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+
+    # k2.compose() currently does not support b_to_a_map. To avoid
+    # replicating `lats`, we use k2.intersect_device here.
+    #
+    # lattice has token IDs as `labels` and word IDs as aux_labels, so we
+    # need to invert it here.
+    inv_lattice = k2.invert(lats)
+
+    # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
+    # and its `aux_labels` are token IDs (a k2.RaggedInt with 2 axes)
+
+    # Remove its `aux_labels` since it is not needed in the
+    # following computation
+    # del inv_lattice.aux_labels
+    inv_lattice = k2.arc_sort(inv_lattice)
+
+    path_lattice = _intersect_device(
+        inv_lattice,
+        word_fsa_with_epsilon_loops,
+        b_to_a_map=path_to_seq_map,
+        sorted_match_a=True,
+    )
+
+    # path_lattice now has token IDs as `labels` and word IDs as aux_labels.
+    path_lattice = k2.invert(path_lattice)
+
+    path_lattice = k2.top_sort(k2.connect(path_lattice))
+
+    # replace labels with tokens to remove repeated token IDs.
+    path_lattice.labels = path_lattice.tokens
+
+    n_best = k2.shortest_path(path_lattice, use_double_scores=True)
+
+    n_best = k2.remove_epsilon(n_best)
+
+    n_best = k2.top_sort(k2.connect(n_best))
+
+    # now we have nbest lists with am_scores and lm_scores
+    return Nbest(fsa=n_best, shape=seq_to_path_shape)
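The rewritten generate_nbest_list above deduplicates word sequences per utterance before intersecting the surviving paths back with the lattice (which restores per-arc am/lm scores). A plain-Python picture of what the unique_sequences step achieves (toy data; k2's version may also reorder paths, hence the new2old map):

# [seq][path][words]; utterance 0 sampled the same word sequence twice.
seqs = [[(1, 2), (3,), (1, 2)], [(4, 5)]]
unique = [list(dict.fromkeys(paths)) for paths in seqs]
assert unique == [[(1, 2), (3,)], [(4, 5)]]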
@@ -82,8 +82,11 @@ class Nbest(object):

         one_best = k2.remove_epsilon(one_best)

+        one_best = k2.top_sort(k2.connect(one_best))

         return Nbest(fsa=one_best, shape=self.shape)


     def total_scores(self) -> k2.RaggedFloat:
         '''Get total scores of the FSAs in this Nbest.

@@ -100,7 +103,7 @@ class Nbest(object):
         # If k2.RaggedDouble is wrapped, we can use double precision here.
         return k2.RaggedFloat(self.shape, scores.float())

-    def top_k(self, k: int) -> 'Nbest':
+    def top_k(self, k: int, scores: k2.RaggedFloat = None) -> 'Nbest':
         '''Get a subset of paths in the Nbest. The resulting Nbest is regular
         in that each sequence (i.e., utterance) has the same number of
         paths (k).
@@ -113,9 +116,13 @@ class Nbest(object):
         Args:
           k:
             Number of paths in each utterance.
+          scores:
+            The scores used to select the top-k paths.
         Returns:
           Return a new Nbest with a regular shape.
         '''
-        ragged_scores = self.total_scores()
+        ragged_scores = scores
+        if ragged_scores is None:
+            ragged_scores = self.total_scores()

         # indexes contains idx01's for self.shape
@@ -140,6 +147,7 @@ class Nbest(object):

         top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
                                                      dim1=k)
+        top_k_shape = top_k_shape.to(top_k_fsas.device)
         return Nbest(top_k_fsas, top_k_shape)


@@ -163,7 +171,7 @@ class Nbest(object):
         # indexes contains idx01's for self.shape
         indexes = torch.arange(
             self.shape.num_elements(), dtype=torch.int32,
-            device=self.shape.device
+            device=self.fsa.device
         )

         if sort:
@@ -176,9 +184,12 @@ class Nbest(object):

         ragged_indexes = k2.RaggedInt(self.shape, indexes)

-        padded_indexes = k2.ragged.pad(ragged_indexes, value=-1)
+        padded_indexes = k2.ragged.pad(ragged_indexes,
+                                       value=-1)

         # Select the idx01's of the top-k paths of each utterance
+        max_num_fsa = padded_indexes.size()[1]

         first_indexes = padded_indexes[:, :k].flatten().contiguous()

         # Remove the padding elements
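Nbest.top_k and Nbest.split above rely on a pad-then-slice pattern over ragged path indexes; the same idea in plain torch (toy shape: utterance 0 has 3 paths, utterance 1 has 1; in the real code the indexes are first sorted by score):

import torch

padded = torch.tensor([[0, 1, 2],
                       [3, -1, -1]])  # idx01's padded to a rectangle with -1
k = 2
first_indexes = padded[:, :k].flatten()
first_indexes = first_indexes[first_indexes >= 0]  # drop the padding
assert first_indexes.tolist() == [0, 1, 3]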
icefall/score_estimator.py (new file, 188 lines)
@@ -0,0 +1,188 @@
+import argparse
+import glob
+import logging
+from pathlib import Path
+from typing import Tuple, List
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall.utils import (
+    setup_logger,
+    str2bool,
+)
+
+
+class Dataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        path: Path,
+        model: str,
+    ) -> None:
+        super().__init__()
+        files = sorted(glob.glob(f"{path}/*.pt"))
+        if model == 'train':
+            self.files = files[0: int(len(files) * 0.8)]
+        elif model == 'dev':
+            self.files = files[int(len(files) * 0.8): int(len(files) * 0.9)]
+        elif model == 'test':
+            self.files = files[int(len(files) * 0.9):]
+
+    def __getitem__(self, index) -> torch.Tensor:
+        return torch.load(self.files[index])
+
+    def __len__(self) -> int:
+        return len(self.files)
+
+
+class DatasetCollateFunc:
+    def __call__(self, batch: List) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = torch.cat(batch)
+        return (x[:, 0:5], x[:, 5])
+
+
+class ScoreEstimator(nn.Module):
+    def __init__(
+        self,
+        input_dim: int = 5,
+        hidden_dim: int = 20,
+    ) -> None:
+        super().__init__()
+        self.embedding = nn.Linear(
+            in_features=input_dim,
+            out_features=hidden_dim
+        )
+        self.output = nn.Linear(
+            in_features=hidden_dim,
+            out_features=2
+        )
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = self.embedding(x)
+        x = self.sigmoid(x)
+        x = self.output(x)
+        mean, var = x[:, 0], x[:, 1]
+        var = torch.exp(var)
+        return mean, var
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--input-dim",
+        type=int,
+        default=5,
+        help="Dim of the input feature.",
+    )
+
+    parser.add_argument(
+        "--hidden-dim",
+        type=int,
+        default=20,
+        help="Number of neurons in the hidden layer.",
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=10,
+        help="Batch size of the dataloader.",
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=20,
+        help="Number of training epochs.",
+    )
+
+    parser.add_argument(
+        "--learning-rate",
+        type=float,
+        default=1e-4,
+        help="Learning rate.",
+    )
+
+    parser.add_argument(
+        "--exp_dir",
+        type=Path,
+        default=Path("conformer_ctc/exp"),
+        help="Directory to store experiment data.",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+
+    setup_logger(f"{args.exp_dir}/rescore/log")
+
+    model = ScoreEstimator(
+        input_dim=args.input_dim,
+        hidden_dim=args.hidden_dim
+    )
+
+    model = model.to("cuda")
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    loss_fn = nn.GaussianNLLLoss()
+
+    train_dataloader = DataLoader(
+        Dataset(f"{args.exp_dir}/rescore/feat", "train"),
+        collate_fn=DatasetCollateFunc(),
+        batch_size=args.batch_size,
+        shuffle=True
+    )
+    dev_dataloader = DataLoader(
+        Dataset(f"{args.exp_dir}/rescore/feat", "dev"),
+        collate_fn=DatasetCollateFunc(),
+        batch_size=args.batch_size,
+        shuffle=True
+    )
+
+    for epoch in range(args.epoch):
+        model.train()
+        training_loss = 0.0
+        step = 0
+        for x, y in train_dataloader:
+            mean, var = model(x.cuda())
+            loss = loss_fn(mean, y.cuda(), var)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            training_loss += loss.item()
+            step += len(y)
+        training_loss /= step
+
+        dev_loss = 0.0
+        step = 0
+        model.eval()
+        for x, y in dev_dataloader:
+            mean, var = model(x.cuda())
+            loss = loss_fn(mean, y.cuda(), var)
+            dev_loss += loss.item()
+            step += len(y)
+        dev_loss /= step
+
+        logging.info(f"Epoch {epoch} : training loss : {training_loss}, "
+                     f"dev loss : {dev_loss}")
+        torch.save(
+            model.state_dict(),
+            f"{args.exp_dir}/rescore/epoch-{epoch}.pt"
+        )
+
+
+if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    main()
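A quick smoke test of the estimator defined above: 5-dimensional best-matching features in, one (mean, var) pair per token out, with the variance kept positive by the exp():

import torch
# from icefall.score_estimator import ScoreEstimator

model = ScoreEstimator(input_dim=5, hidden_dim=20)
x = torch.randn(8, 5)  # 8 hypothetical token-feature rows
mean, var = model(x)
assert mean.shape == (8,) and var.shape == (8,)
assert (var > 0).all()  # var = exp(...) is always positive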
@@ -411,6 +411,10 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
     assert keys.shape.dim0() == queries.shape.dim0(), \
         f'Utterances number in keys and queries should be equal : \
         {keys.shape.dim0()} vs {queries.shape.dim0()}'
+    assert keys.fsa.device == queries.fsa.device, \
+        f'Device of keys and queries should be equal : \
+        {keys.fsa.device} vs {queries.fsa.device}'
+    device = keys.fsa.device

     # keys_tokens_shape [utt][path][token]
     keys_tokens_shape = k2.ragged.compose_ragged_shapes(keys.shape,
@@ -430,11 +434,13 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
     # counts on key positions are ones
     keys_counts = k2.RaggedInt(keys_tokens_shape,
                                torch.ones(keys_token_num,
-                                          dtype=torch.int32))
+                                          dtype=torch.int32,
+                                          device=device))
     # counts on query positions are zeros
     queries_counts = k2.RaggedInt(queries_tokens_shape,
                                   torch.zeros(queries_tokens_num,
-                                              dtype=torch.int32))
+                                              dtype=torch.int32,
+                                              device=device))
     counts = k2.ragged.cat([keys_counts, queries_counts], axis=1).values()

     # scores on key positions are the scores inherited from the nbest path
@@ -442,7 +448,8 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
     # scores on query positions MUST be zeros
     queries_scores = k2.RaggedFloat(queries_tokens_shape,
                                     torch.zeros(queries_tokens_num,
-                                                dtype=torch.float32))
+                                                dtype=torch.float32,
+                                                device=device))
     scores = k2.ragged.cat([keys_scores, queries_scores], axis=1).values()

     # we didn't remove -1 labels before
@@ -450,8 +457,16 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
     eos = -1
     max_token = torch.max(torch.max(keys.fsa.labels),
                           torch.max(queries.fsa.labels))
-    mean, var, counts_out, ngram = k2.get_best_matching_stats(tokens, scores,
-        counts, eos, min_token, max_token, max_order)
+    mean, var, counts_out, ngram = k2.get_best_matching_stats(
+        tokens.to(torch.device('cpu')), scores.to(torch.device('cpu')),
+        counts.to(torch.device('cpu')),
+        eos, min_token, max_token, max_order
+    )
+
+    mean = mean.to(device)
+    var = var.to(device)
+    counts_out = counts_out.to(device)
+    ngram = ngram.to(device)

     queries_init_scores = queries.fsa.scores.clone()
     # only return the stats on query positions
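The change above adopts a CPU round-trip, apparently because k2.get_best_matching_stats only supports CPU tensors (that is what the explicit .to('cpu') calls suggest): inputs are moved to the CPU, the stats are computed, and the four result tensors are moved back. The generic pattern, sketched (cpu_only_op is a stand-in):

import torch

def cpu_only_op(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # stand-in for an op that only supports CPU tensors

def run_anywhere(x: torch.Tensor) -> torch.Tensor:
    out = cpu_only_op(x.to("cpu"))
    return out.to(x.device)  # restore the caller's device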