From 4c2cbff50139bfe1834b20765d6822fa8e981973 Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Tue, 7 Jun 2022 22:30:50 +0800
Subject: [PATCH] do some changes

---
 .../ASR/local/compute_fbank_aishell4.py       |  6 +++---
 .../asr_datamodule.py                         |  4 ++--
 .../ASR/pruned_transducer_stateless5/train.py | 20 +++++++++++++-------
 3 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
index 590b56984..09f885636 100755
--- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py
+++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -29,7 +29,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -88,7 +88,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=LilcomChunkyWriter,
+                storage_type=ChunkedLilcomHdf5Writer,
             )
 
             logging.info("About splitting cuts into smaller chunks")
@@ -97,7 +97,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                 min_duration=None,
             )
 
-            cut_set.to_json(output_dir / cuts_filename)
+            cut_set.to_file(output_dir / cuts_filename)
 
 
 def get_args():
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
index e8079b63f..4e3cb87d1 100644
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -230,7 +230,7 @@ class Aishell4AsrDataModule:
         """
         logging.info("About to get Musan cuts")
         cuts_musan = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
         )
 
         transforms = []
@@ -429,7 +429,7 @@ class Aishell4AsrDataModule:
         logging.info("About to get train cuts")
         return load_manifest_lazy(
             self.args.manifest_dir
-            / "aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
+            / f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
         )
 
     @lru_cache()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index a8fdb5c05..e306ce68b 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -64,6 +64,7 @@ from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from local.text_normalize import text_normalize
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -388,7 +389,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 1,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for conformer
@@ -612,13 +613,11 @@
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
-
     y = graph_compiler.texts_to_ids(texts)
     if type(y) == list:
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)
-
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
@@ -642,7 +641,6 @@
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
 
-    assert loss.requires_grad == is_training
 
     info = MetricsTracker()
@@ -752,6 +750,7 @@
 
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
+        # print(batch["supervisions"])
 
         with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
@@ -869,8 +868,6 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
-    if params.full_libri is False:
-        params.valid_interval = 1600
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -959,7 +956,15 @@ def run(rank, world_size, args):
         # the threshold
        return 1.0 <= c.duration <= 20.0
 
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        c.supervisions[0].text = text_normalize(text)
+        return c
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.map(text_normalize_for_cut)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -972,7 +977,8 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = aishell4.dev_cuts()
+    valid_cuts = aishell4.valid_cuts()
+    valid_cuts = valid_cuts.map(text_normalize_for_cut)
     valid_dl = aishell4.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics: