do some changes

Author: luomingshuang
Date: 2022-06-07 22:30:50 +08:00
Parent: b4b3a848ed
Commit: 4c2cbff501
3 changed files with 18 additions and 12 deletions

File 1 of 3: AISHELL-4 fbank feature-extraction script (compute_fbank_aishell4)

@@ -29,7 +29,7 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor
@@ -88,7 +88,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=LilcomChunkyWriter,
+                storage_type=ChunkedLilcomHdf5Writer,
             )
             logging.info("About splitting cuts into smaller chunks")
@@ -97,7 +97,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                 min_duration=None,
             )
-            cut_set.to_json(output_dir / cuts_filename)
+            cut_set.to_file(output_dir / cuts_filename)


 def get_args():
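
Note on this file: the changes swap the feature storage backend (LilcomChunkyWriter to ChunkedLilcomHdf5Writer) and save the cut manifest with to_file() instead of to_json(). A minimal sketch of the resulting extraction flow, assuming the lhotse API used in the script; the paths and the num_jobs value below are illustrative only, not taken from the recipe:

    from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig

    # Illustrative paths; the real script builds them from its manifests.
    cut_set = CutSet.from_file("data/manifests/aishell4_cuts_train_S.jsonl.gz")
    extractor = Fbank(FbankConfig(num_mel_bins=80))

    cut_set = cut_set.compute_and_store_features(
        extractor=extractor,
        storage_path="data/fbank/aishell4_feats_train_S",
        num_jobs=15,  # illustrative
        storage_type=ChunkedLilcomHdf5Writer,  # chunked, lilcom-compressed HDF5 storage
    )
    # to_file() infers the serialization format from the extension (.jsonl.gz here).
    cut_set.to_file("data/fbank/aishell4_cuts_train_S.jsonl.gz")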

File 2 of 3: ASR data module (Aishell4AsrDataModule)

@@ -230,7 +230,7 @@ class Aishell4AsrDataModule:
         """
         logging.info("About to get Musan cuts")
         cuts_musan = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
         )

         transforms = []
@@ -429,7 +429,7 @@ class Aishell4AsrDataModule:
         logging.info("About to get train cuts")
         return load_manifest_lazy(
             self.args.manifest_dir
-            / "aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
+            / f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
         )

     @lru_cache()
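
Note on this file: the MUSAN manifest is now read under its newer name (musan_cuts.jsonl.gz), and a missing f-string prefix is added so the training-subset name is actually interpolated into the train manifest path. A small sketch of the corrected lookups, assuming lhotse's load_manifest_lazy; manifest_dir and the subset name are illustrative values:

    from pathlib import Path
    from lhotse import load_manifest_lazy

    manifest_dir = Path("data/fbank")  # illustrative
    training_subset = "L"              # illustrative subset name

    cuts_musan = load_manifest_lazy(manifest_dir / "musan_cuts.jsonl.gz")
    # Without the f prefix, the literal text "{self.args.training_subset}" ended up in
    # the filename; with it, the chosen subset is substituted as intended.
    train_cuts = load_manifest_lazy(
        manifest_dir / f"aishell4_cuts_train_{training_subset}.jsonl.gz"
    )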

File 3 of 3: transducer training script (train)

@@ -64,6 +64,7 @@ from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from local.text_normalize import text_normalize
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -388,7 +389,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 1,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for conformer
@@ -612,13 +613,11 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)

     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
     if type(y) == list:
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
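
For context on the compute_loss hunk above: graph_compiler.texts_to_ids() yields one list of token IDs per utterance, and the transducer loss consumes them as a k2.RaggedTensor. A minimal sketch of that conversion, with made-up token IDs:

    import k2
    import torch

    device = torch.device("cpu")

    # Two utterances of different lengths; the IDs are made up for illustration.
    y = [[12, 7, 305], [44, 9]]
    if isinstance(y, list):
        y = k2.RaggedTensor(y).to(device)
    print(y)  # ragged tensor with 2 rows, of lengths 3 and 2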
@@ -642,7 +641,6 @@ def compute_loss(
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )

     assert loss.requires_grad == is_training

     info = MetricsTracker()
@@ -752,6 +750,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
+        # print(batch["supervisions"])

         with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
@@ -869,8 +868,6 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
-    if params.full_libri is False:
-        params.valid_interval = 1600

     fix_random_seed(params.seed)
     if world_size > 1:
@@ -959,7 +956,15 @@ def run(rank, world_size, args):
         # the threshold
         return 1.0 <= c.duration <= 20.0

+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        c.supervisions[0].text = text_normalize(text)
+        return c
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.map(text_normalize_for_cut)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
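
The new text_normalize_for_cut hook rewrites each supervision's text before training and is attached with CutSet.map, next to the existing duration filter. A rough sketch of how these per-cut transforms compose; the toy normalizer below is a stand-in, not the recipe's local/text_normalize.py, and the manifest path is illustrative:

    from lhotse import CutSet
    from lhotse.cut import Cut

    def toy_normalize(text: str) -> str:
        # Stand-in for text_normalize(); the real rules live in local/text_normalize.py.
        return text.strip().upper()

    def remove_short_and_long_utt(c: Cut) -> bool:
        return 1.0 <= c.duration <= 20.0

    def text_normalize_for_cut(c: Cut) -> Cut:
        text = c.supervisions[0].text.strip("\n").strip("\t")
        c.supervisions[0].text = toy_normalize(text)
        return c

    train_cuts = CutSet.from_file("data/fbank/aishell4_cuts_train_S.jsonl.gz")  # illustrative
    # filter() and map() run per cut as the sampler iterates the (lazy) manifest.
    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_cuts = train_cuts.map(text_normalize_for_cut)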
@@ -972,7 +977,8 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    valid_cuts = aishell4.dev_cuts()
+    valid_cuts = aishell4.valid_cuts()
+    valid_cuts = valid_cuts.map(text_normalize_for_cut)
     valid_dl = aishell4.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics: