Update the AISHELL-4 recipe: feature writer and manifest fixes, plus text normalization

luomingshuang 2022-06-07 22:30:50 +08:00
parent b4b3a848ed
commit 4c2cbff501
3 changed files with 18 additions and 12 deletions


@@ -29,7 +29,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -88,7 +88,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=LilcomChunkyWriter,
+                storage_type=ChunkedLilcomHdf5Writer,
             )
 
             logging.info("About splitting cuts into smaller chunks")
@@ -97,7 +97,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                 min_duration=None,
             )
-            cut_set.to_json(output_dir / cuts_filename)
+            cut_set.to_file(output_dir / cuts_filename)
 
 
 def get_args():
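
Note on the changes above: the fbank features are now stored with ChunkedLilcomHdf5Writer, and the cut manifest is written with to_file() instead of the older to_json(). The following is a minimal sketch of that lhotse pattern, not the recipe's actual script; the manifest and storage paths, num_jobs value, and single-partition handling are placeholder assumptions.

# Sketch of the feature-extraction pattern edited above (assumed paths, one partition).
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig

cut_set = CutSet.from_file("data/manifests/aishell4_cuts_example.jsonl.gz")  # placeholder path
cut_set = cut_set.compute_and_store_features(
    extractor=Fbank(FbankConfig(num_mel_bins=80)),
    storage_path="data/fbank/feats_example",   # placeholder path
    num_jobs=4,                                # the recipe passes an executor here instead
    storage_type=ChunkedLilcomHdf5Writer,      # HDF5-backed chunked storage (needs h5py)
)
# to_file() infers the serialization format from the extension (.jsonl.gz here),
# which is why it replaces the older to_json() call.
cut_set.to_file("data/fbank/aishell4_cuts_example.jsonl.gz")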


@@ -230,7 +230,7 @@ class Aishell4AsrDataModule:
         """
         logging.info("About to get Musan cuts")
         cuts_musan = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
         )
 
         transforms = []
@@ -429,7 +429,7 @@ class Aishell4AsrDataModule:
         logging.info("About to get train cuts")
         return load_manifest_lazy(
             self.args.manifest_dir
-            / "aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
+            / f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
         )
 
     @lru_cache()
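
The second change above fixes a missing f-string prefix: without the leading f, the braces are kept literally and the loader looks for a manifest that does not exist. A tiny illustration, with a placeholder subset value:

training_subset = "L"  # placeholder value
broken = "aishell4_cuts_train_{training_subset}.jsonl.gz"   # literal braces, wrong filename
fixed = f"aishell4_cuts_train_{training_subset}.jsonl.gz"   # interpolated at runtime
assert fixed == "aishell4_cuts_train_L.jsonl.gz"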


@@ -64,6 +64,7 @@ from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from local.text_normalize import text_normalize
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -388,7 +389,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 1,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for conformer
@@ -612,13 +613,11 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
     if type(y) == list:
         y = k2.RaggedTensor(y).to(device)
-    else:
-        y = y.to(device)
 
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
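
For reference, the label handling kept in compute_loss wraps the token ids returned by the graph compiler in a k2 ragged tensor before moving them to the training device. A small sketch, with made-up token ids and the CPU device standing in for the CUDA device used in training:

import k2
import torch

token_ids = [[23, 7, 401], [5, 19]]   # stands in for graph_compiler.texts_to_ids(texts)
y = k2.RaggedTensor(token_ids)        # ragged because utterances differ in length
y = y.to(torch.device("cpu"))         # the recipe uses the training device here
print(y)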
@@ -642,7 +641,6 @@ def compute_loss(
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
 
-    assert loss.requires_grad == is_training
 
     info = MetricsTracker()
@@ -752,6 +750,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
+        # print(batch["supervisions"])
 
         with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
@@ -869,8 +868,6 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
-    if params.full_libri is False:
-        params.valid_interval = 1600
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -959,7 +956,15 @@ def run(rank, world_size, args):
         # the threshold
         return 1.0 <= c.duration <= 20.0
 
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        c.supervisions[0].text = text_normalize(text)
+        return c
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.map(text_normalize_for_cut)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
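
The new text_normalize_for_cut helper applies the recipe's text normalization to every supervision before training. Below is a hedged, self-contained sketch of the same map pattern; the normalization rule and the manifest path are placeholders, not the contents of local/text_normalize.py.

from lhotse import CutSet
from lhotse.cut import Cut

def text_normalize(text: str) -> str:
    # Placeholder rule; the real rules live in local/text_normalize.py.
    return text.upper().replace(",", "")

def text_normalize_for_cut(c: Cut) -> Cut:
    text = c.supervisions[0].text.strip("\n").strip("\t")
    c.supervisions[0].text = text_normalize(text)
    return c

cuts = CutSet.from_file("data/fbank/aishell4_cuts_train_L.jsonl.gz")  # placeholder path
cuts = cuts.map(text_normalize_for_cut)  # returns a new CutSet with normalized transcripts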
@@ -972,7 +977,8 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = aishell4.dev_cuts()
+    valid_cuts = aishell4.valid_cuts()
+    valid_cuts = valid_cuts.map(text_normalize_for_cut)
     valid_dl = aishell4.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics: