pruned-transducer-stateless2-for-wenetspeech

luomingshuang 2022-04-13 14:58:50 +08:00
parent abcb0b31e5
commit 5319429d76

decode.py

@@ -63,10 +63,9 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
-import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import WenetSpeechAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search,
@@ -81,6 +80,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,
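
The import changes summarize the whole commit: the LibriSpeech recipe detokenizes BPE pieces with a sentencepiece model, while the WenetSpeech recipe models Chinese characters directly, so decoding only needs the id-to-symbol mapping provided by icefall's Lexicon. A simplified sketch of the parts of Lexicon this file relies on (the class name here is invented; the real class in icefall/lexicon.py also loads words.txt and handles disambiguation symbols):

    from pathlib import Path

    import k2

    class LexiconSketch:
        """Simplified stand-in for icefall.lexicon.Lexicon."""

        def __init__(self, lang_dir: Path):
            # tokens.txt contains one "<symbol> <id>" pair per line.
            self.token_table = k2.SymbolTable.from_file(str(lang_dir / "tokens.txt"))

        @property
        def tokens(self):
            # All token ids except epsilon (id 0).
            return [
                self.token_table[s]
                for s in self.token_table.symbols
                if self.token_table[s] != 0
            ]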
@@ -203,7 +203,7 @@ def get_parser():
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    lexicon: Lexicon,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
@@ -222,8 +222,6 @@ def decode_one_batch(
       It's the return value of :func:`get_params`.
     model:
       The neural model.
-    sp:
-      The BPE model.
     batch:
       It is the return value from iterating
       `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -260,8 +258,8 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
@@ -270,16 +268,16 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             beam=params.beam_size,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     else:
         batch_size = encoder_out.size(0)
@@ -303,7 +301,7 @@ def decode_one_batch(
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            hyps.append([lexicon.token_table[idx] for idx in hyp])
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
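
The pattern is the same in every branch: where the LibriSpeech recipe batch-detokenizes with sp.decode and splits the result on whitespace, the WenetSpeech recipe walks each utterance's token ids through lexicon.token_table. Note that the batched branches index hyp_tokens[i] to pick out the i-th utterance's ids; iterating over hyp_tokens directly would hand whole lists to the symbol table. A library-free toy example of the lookup, with made-up ids:

    # token_table behaves like an id -> symbol mapping; ids below are invented.
    token_table = {0: "<blk>", 2: "我", 3: "们"}

    hyp_tokens = [[2, 3], [3, 2]]  # one list of token ids per utterance
    hyps = [[token_table[idx] for idx in hyp] for hyp in hyp_tokens]
    assert hyps == [["我", "们"], ["们", "我"]]  # characters; no .split() needed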
@@ -323,7 +321,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    lexicon: Lexicon,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -335,8 +333,6 @@ def decode_dataset(
       It is returned by :func:`get_params`.
     model:
       The neural model.
-    sp:
-      The BPE model.
     decoding_graph:
       The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
       only when --decoding_method is fast_beam_search.
@@ -366,7 +362,7 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
-            sp=sp,
+            lexicon=lexicon,
             decoding_graph=decoding_graph,
             batch=batch,
         )
@@ -438,7 +434,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    WenetSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
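
Each icefall datamodule contributes its dataset-specific flags to the shared parser through an add_arguments classmethod, which is why swapping the datamodule is a one-line change here. A minimal sketch of that pattern with an invented class and option names (the real WenetSpeechAsrDataModule in asr_datamodule.py defines many more options):

    import argparse

    class DataModuleSketch:
        @classmethod
        def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
            # Group the dataset options so --help stays readable.
            group = parser.add_argument_group(title="ASR data related options")
            group.add_argument("--manifest-dir", type=str, default="data/fbank")
            group.add_argument("--max-duration", type=int, default=200)

    parser = argparse.ArgumentParser()
    DataModuleSketch.add_arguments(parser)
    args = parser.parse_args(["--max-duration", "300"])
    assert args.max_duration == 300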
@@ -473,12 +469,9 @@ def main():
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
 
     logging.info(params)
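
Both recipes derive blank_id and vocab_size from their token inventory; only the source differs. With BPE, sentencepiece owns the mapping; here it comes from tokens.txt under the lang directory. A toy illustration with made-up ids:

    # Stand-in for the id mapping read from tokens.txt; ids are invented.
    token_to_id = {"<blk>": 0, "<unk>": 1, "我": 2, "们": 3}

    blank_id = token_to_id["<blk>"]             # like lexicon.token_table["<blk>"]
    vocab_size = max(token_to_id.values()) + 1  # like max(lexicon.tokens) + 1
    assert (blank_id, vocab_size) == (0, 4)     # ids are contiguous from 0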
@@ -514,26 +507,24 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    librispeech = LibriSpeechAsrDataModule(args)
-
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
-
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
-
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    wenetspeech = WenetSpeechAsrDataModule(args)
+
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+
+    test_net_dl = wenetspeech.valid_dataloaders(test_net_cuts)
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
+
+    test_sets = ["TEST_NET", "TEST_MEETING"]
+    test_dl = [test_net_dl, test_meeting_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
-            sp=sp,
+            lexicon=lexicon,
             decoding_graph=decoding_graph,
         )
 
         save_results(
             params=params,
             test_set_name=test_set,
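
One detail worth noting in the final hunk: the loop target rebinds the name test_dl, which is harmless because zip() is evaluated once, against the original list object, before the rebinding takes effect. A stand-alone demonstration with placeholder strings instead of dataloaders:

    test_sets = ["TEST_NET", "TEST_MEETING"]
    test_dl = ["net_dl", "meeting_dl"]  # placeholders for the real dataloaders

    # zip() captures the list before test_dl is rebound by the loop target.
    for test_set, test_dl in zip(test_sets, test_dl):
        print(test_set, test_dl)
    # TEST_NET net_dl
    # TEST_MEETING meeting_dl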