mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
added text norm for other decoding scripts
This commit is contained in:
parent
5492a6a5e2
commit
cd96f635c3
@ -122,7 +122,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriTTSAsrDataModule
|
||||
from lhotse import set_caching_enabled
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from train import add_model_arguments, get_model, get_params, normalize_text
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -949,13 +949,13 @@ def main():
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriTTSAsrDataModule(args)
|
||||
libritts = LibriTTSAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
|
||||
test_other_cuts = libritts.test_other_cuts().map(normalize_text)
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = libritts.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
@ -80,6 +80,7 @@ import torch.nn as nn
|
||||
from asr_datamodule import LibriTTSAsrDataModule
|
||||
from k2 import SymbolTable
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
from train import normalize_text
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
@ -290,13 +291,13 @@ def main():
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriTTSAsrDataModule(args)
|
||||
libritts = LibriTTSAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
|
||||
test_other_cuts = libritts.test_other_cuts().map(normalize_text)
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = libritts.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
@ -52,7 +52,7 @@ from streaming_beam_search import (
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from train import add_model_arguments, get_model, get_params, normalize_text
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -866,8 +866,8 @@ def main():
|
||||
|
||||
libritts = LibriTTSAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = libritts.test_clean_cuts()
|
||||
test_other_cuts = libritts.test_other_cuts()
|
||||
test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
|
||||
test_other_cuts = libritts.test_other_cuts().map(normalize_text)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_cuts = [test_clean_cuts, test_other_cuts]
|
||||
|
Loading…
x
Reference in New Issue
Block a user