remove streaming/greedy_search results folder

root 2024-08-01 16:00:27 +09:00 committed by root
parent eebe6add4c
commit cfef53dcbe
6 changed files with 148 additions and 171 deletions

View File

@@ -103,9 +103,10 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import k2
+import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -134,7 +135,6 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    make_pad_mask,
     setup_logger,
     store_transcripts,
     str2bool,
@@ -205,7 +205,7 @@ def get_parser():
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_char",
+        default="data/lang_bpe_500",
         help="The lang dir containing word table and LG graph",
     )
@@ -371,6 +371,7 @@ def get_parser():
         modified_beam_search_LODR.
         """,
     )
+<<<<<<< HEAD

     parser.add_argument(
         "--skip-scoring",
@@ -398,7 +399,7 @@ def get_parser():
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    sp: Tokenizer,
+    sp: spm.SentencePieceProcessor,
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
@@ -477,10 +478,9 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_LG":
         hyp_tokens = fast_beam_search_nbest_LG(
             model=model,
@@ -492,7 +492,6 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -507,10 +506,9 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
         hyp_tokens = fast_beam_search_nbest_oracle(
             model=model,
@@ -523,19 +521,17 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
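For context: the blank_penalty argument, removed throughout this diff, is conventionally applied by subtracting a constant from the blank symbol's logit before the next-symbol decision, which discourages blank emissions (fewer deletions, more insertions). A minimal sketch of that idea, assuming blank has id 0 and per-frame logits of shape (batch, vocab); the helper name is hypothetical:

    import torch

    def apply_blank_penalty(logits: torch.Tensor, blank_penalty: float) -> torch.Tensor:
        # Assumed convention: column 0 holds the <blk> logit.
        # Lowering it makes the search less eager to emit blank.
        if blank_penalty != 0.0:
            logits = logits.clone()
            logits[:, 0] = logits[:, 0] - blank_penalty
        return logits

    # Example: a penalty of 2.0 flips a blank-favoring frame toward a real token.
    frame = torch.tensor([[1.0, 0.5, -0.2]])
    print(apply_blank_penalty(frame, 2.0).argmax(dim=-1))  # tensor([1]) instead of tensor([0])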
@@ -543,10 +539,9 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
-            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
         hyp_tokens = modified_beam_search_lm_shallow_fusion(
             model=model,
@@ -556,7 +551,7 @@ def decode_one_batch(
             LM=LM,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search_LODR":
         hyp_tokens = modified_beam_search_LODR(
             model=model,
@@ -569,7 +564,7 @@ def decode_one_batch(
             context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(sp.text2word(hyp))
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search_lm_rescore":
         lm_scale_list = [0.01 * i for i in range(10, 50)]
         ans_dict = modified_beam_search_lm_rescore(
@@ -615,7 +610,7 @@ def decode_one_batch(
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
-            hyps.append(sp.text2word(sp.decode(hyp)))
+            hyps.append(sp.decode(hyp).split())

    # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
    prefix = f"{params.decoding_method}"
@@ -658,14 +653,14 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    sp: Tokenizer,
+    sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

     Args:
@@ -724,7 +719,7 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_words = sp.text2word(ref_text)
+                ref_words = ref_text.split()
                 this_batch.append((cut_id, ref_words, hyp_words))

             results[name].extend(this_batch)
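The tuples collected here drive scoring later: each entry is (cut_id, ref_words, hyp_words), keyed by decoding-method name. A toy sketch of consuming that structure, with made-up utterance ids and text; the naive zip count below tallies aligned word mismatches, not a true edit-distance WER:

    from typing import Dict, List, Tuple

    results: Dict[str, List[Tuple[str, List[str], List[str]]]] = {
        "greedy_search": [
            ("utt-0001", "hello world".split(), "hello word".split()),  # hypothetical
        ],
    }
    for name, batch in results.items():
        for cut_id, ref_words, hyp_words in batch:
            errs = sum(r != h for r, h in zip(ref_words, hyp_words))
            print(name, cut_id, f"{errs} word(s) differ")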
@@ -742,7 +737,7 @@ def save_asr_output(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     """
     Save text produced by ASR.
     """
@@ -760,7 +755,7 @@ def save_wer_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]],
 ):
     """
     Save WER and per-utterance word alignments.
     """
@@ -797,8 +792,8 @@

 @torch.no_grad()
 def main():
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
-    Tokenizer.add_arguments(parser)
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -872,8 +867,6 @@ def main():
             f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
         )

-    params.suffix += f"-blank-penalty-{params.blank_penalty}"
-
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
@@ -886,9 +879,10 @@ def main():
     logging.info(f"Device: {device}")

-    sp = Tokenizer.load(params.lang, params.lang_type)
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)

-    # <blk> and <unk> are defined in local/prepare_lang_char.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
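A usage sketch of the new tokenizer setup: the special-symbol ids and vocabulary size are read off the loaded model rather than hard-coded. The path and printed values below are illustrative, not taken from this run:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # assumed path

    # <blk> and <unk> are expected to exist in the model, as set up by
    # the BPE training script the comment above refers to.
    blank_id = sp.piece_to_id("<blk>")
    unk_id = sp.piece_to_id("<unk>")
    vocab_size = sp.get_piece_size()
    print(blank_id, unk_id, vocab_size)  # e.g. 0 2 500 for a 500-piece model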
@@ -1060,13 +1054,20 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)

-    for subdir in ["valid"]:
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
-            dl=reazonspeech_corpus.test_dataloaders(
-                getattr(reazonspeech_corpus, f"{subdir}_cuts")()
-            ),
+            dl=test_dl,
             params=params,
             model=model,
             sp=sp,
@@ -1080,20 +1081,9 @@ def main():
         save_asr_output(
             params=params,
-            test_set_name=subdir,
+            test_set_name=test_set,
             results_dict=results_dict,
         )

-        # with (
-        #     params.res_dir
-        #     / (
-        #         f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
-        #         f"_{params.avg}_{params.epoch}.cer"
-        #     )
-        # ).open("w") as fout:
-        #     if len(tot_err) == 1:
-        #         fout.write(f"{tot_err[0][1]}")
-        #     else:
-        #         fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
-
         if not params.skip_scoring:
             save_wer_results(

View File

@@ -1,2 +0,0 @@
-2024-07-29 17:52:11,668 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 17:52:11,669 INFO [streaming_decode.py:742] Device: cuda:0

View File

@@ -1,2 +0,0 @@
-2024-07-29 17:54:22,556 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 17:54:22,556 INFO [streaming_decode.py:742] Device: cuda:0

View File

@@ -1,2 +0,0 @@
-2024-07-29 17:55:15,276 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 17:55:15,277 INFO [streaming_decode.py:742] Device: cuda:0

View File

@@ -1,2 +0,0 @@
-2024-07-29 17:59:02,028 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 17:59:02,029 INFO [streaming_decode.py:742] Device: cuda:0

View File

@@ -1,5 +0,0 @@
-2024-07-29 18:01:06,736 INFO [streaming_decode.py:736] Decoding started
-2024-07-29 18:01:06,736 INFO [streaming_decode.py:742] Device: cuda:0
-2024-07-29 18:01:06,740 INFO [streaming_decode.py:753] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.24.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '8f976a1e1407e330e2a233d68f81b1eb5269fdaa', 'k2-git-date': 'Thu Jun 6 02:13:08 2024', 'lhotse-version': '1.26.0.dev+git.bd12d5d.clean', 'torch-version': '2.3.1+cu121', 'torch-cuda-available': True, 'torch-cuda-version': '12.1', 'python-version': '3.10', 'icefall-git-branch': 'jp-streaming', 'icefall-git-sha1': '4af81af-dirty', 'icefall-git-date': 'Thu Jul 18 22:05:59 2024', 'icefall-path': '/root/tmp/icefall', 'k2-path': '/root/miniconda3/envs/myenv/lib/python3.10/site-packages/k2/__init__.py', 'lhotse-path': '/root/miniconda3/envs/myenv/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'KDA00', 'IP address': '192.168.0.1'}, 'epoch': 28, 'iter': 0, 'avg': 15, 'use_averaged_model': True, 'exp_dir': PosixPath('zipformer'), 'bpe_model': 'data/lang_bpe_500/bpe.model', 'lang_dir': PosixPath('data/lang_char'), 'decoding_method': 'greedy_search', 'num_active_paths': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 32, 'context_size': 2, 'num_decode_streams': 2000, 'num_encoder_layers': '2,2,3,4,3,2', 'downsampling_factor': '1,2,4,8,4,2', 'feedforward_dim': '512,768,1024,1536,1024,768', 'num_heads': '4,4,4,8,4,4', 'encoder_dim': '192,256,384,512,384,256', 'query_head_dim': '32', 'value_head_dim': '12', 'pos_head_dim': '4', 'pos_dim': 48, 'encoder_unmasked_dim': '192,192,256,256,256,192', 'cnn_module_kernel': '31,31,15,15,15,31', 'decoder_dim': 512, 'joiner_dim': 512, 'causal': True, 'chunk_size': '32', 'left_context_frames': '256', 'use_transducer': True, 'use_ctc': False, 'manifest_dir': PosixPath('data/manifests'), 'max_duration': 200.0, 'bucketing_sampler': True, 'num_buckets': 30, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'drop_last': True, 'return_cuts': False, 'num_workers': 2, 'enable_spec_aug': True, 'spec_aug_time_warp_factor': 80, 'enable_musan': False, 'lang': PosixPath('data/lang_char'), 'lang_type': None, 'res_dir': PosixPath('zipformer/streaming/greedy_search'), 'suffix': 'epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model', 'blank_id': 0, 'unk_id': 2990, 'vocab_size': 2992}
-2024-07-29 18:01:06,740 INFO [streaming_decode.py:755] About to create model
-2024-07-29 18:01:07,118 INFO [streaming_decode.py:822] Calculating the averaged model over epoch range from 13 (excluded) to 28