Mirror of https://github.com/k2-fsa/icefall.git
Synced 2025-09-08 16:44:20 +00:00

Commit: 4c2cbff501 ("do some changes")
Parent: b4b3a848ed

@@ -29,7 +29,7 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

@@ -88,7 +88,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
storage_type=ChunkedLilcomHdf5Writer,
)

logging.info("About splitting cuts into smaller chunks")

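These keyword arguments belong to lhotse's CutSet.compute_and_store_features, so the change swaps the on-disk feature backend from LilcomChunkyWriter to the HDF5-based ChunkedLilcomHdf5Writer. A minimal sketch of that call, with the 80-bin Fbank config and the output path taken as assumptions from the surrounding recipe:

    from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig

    from icefall.utils import get_executor


    def extract_fbank(cut_set: CutSet, output_dir: str, num_jobs: int = 15) -> CutSet:
        # 80-dim log-mel filterbank features, as in compute_fbank_aishell4.
        extractor = Fbank(FbankConfig(num_mel_bins=80))
        with get_executor() as ex:  # None locally, or a grid executor if one is available
            return cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/feats",
                # when an executor is specified, make more partitions
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                # ChunkedLilcomHdf5Writer stores lilcom-compressed chunks in HDF5
                # (needs h5py); LilcomChunkyWriter writes a plain chunked binary file instead.
                storage_type=ChunkedLilcomHdf5Writer,
            )
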
@@ -97,7 +97,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
min_duration=None,
)

cut_set.to_json(output_dir / cuts_filename)
cut_set.to_file(output_dir / cuts_filename)


def get_args():

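CutSet.to_json always writes JSON, while to_file lets lhotse pick the serializer from the filename suffix, so a cuts filename ending in .jsonl.gz comes out as a gzipped JSON-Lines manifest that can later be read lazily. A small sketch (the helper name and path are illustrative):

    from lhotse import CutSet, load_manifest_lazy


    def save_and_reload(cut_set: CutSet, path: str) -> CutSet:
        # to_file() picks the serializer from the filename suffix:
        # ".jsonl.gz" -> gzipped JSON Lines, ".json" -> a single JSON document.
        cut_set.to_file(path)
        # JSONL manifests can be re-opened lazily, one cut at a time.
        return load_manifest_lazy(path)
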
@@ -230,7 +230,7 @@ class Aishell4AsrDataModule:
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest_lazy(
self.args.manifest_dir / "cuts_musan.jsonl.gz"
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)

transforms = []

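The MUSAN manifest filename changes from cuts_musan.jsonl.gz to musan_cuts.jsonl.gz, presumably to match the name written by the MUSAN preparation stage. For context, a hedged sketch of how such a cut set is typically wired into the dataloader as a noise-mixing transform in icefall datamodules (the probability and SNR values are illustrative):

    from lhotse import load_manifest_lazy
    from lhotse.dataset import CutMix

    cuts_musan = load_manifest_lazy("data/fbank/musan_cuts.jsonl.gz")

    # Mix MUSAN noise into training cuts at a random SNR; the values are illustrative.
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)]
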
@@ -429,7 +429,7 @@ class Aishell4AsrDataModule:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir
/ "aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
/ f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
)

@lru_cache()

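The added f prefix is the actual fix here: without it Python keeps the braces literally, so the datamodule would look for a file whose name contains the text {self.args.training_subset} instead of the chosen subset. A tiny illustration:

    training_subset = "L"  # AISHELL-4 training subsets are named L, M and S

    broken = "aishell4_cuts_train_{training_subset}.jsonl.gz"
    fixed = f"aishell4_cuts_train_{training_subset}.jsonl.gz"

    print(broken)  # aishell4_cuts_train_{training_subset}.jsonl.gz
    print(fixed)   # aishell4_cuts_train_L.jsonl.gz
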
@@ -64,6 +64,7 @@ from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from local.text_normalize import text_normalize
from model import Transducer
from optim import Eden, Eve
from torch import Tensor

@@ -388,7 +389,7 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"log_interval": 1,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for conformer

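log_interval controls how often the training loop writes a progress line; dropping it from 50 to 1 logs after every batch, which is handy while debugging a new recipe but very chatty for a full run. The guard typically looks like the following in icefall training loops (an assumption about this file's exact wording):

    import logging


    def maybe_log(params, epoch: int, batch_idx: int, loss_info) -> None:
        # With log_interval=1 this fires on every single batch.
        if batch_idx % params.log_interval == 0:
            logging.info(f"Epoch {epoch}, batch {batch_idx}, loss[{loss_info}]")
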
@@ -612,13 +613,11 @@ def compute_loss(
feature_lens = supervisions["num_frames"].to(device)

texts = batch["supervisions"]["text"]

y = graph_compiler.texts_to_ids(texts)
if type(y) == list:
y = k2.RaggedTensor(y).to(device)
else:
y = y.to(device)

with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
x=feature,

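The if/else kept here accepts both things graph_compiler.texts_to_ids may return: a plain Python list of token-id lists, which is wrapped into a k2.RaggedTensor, or an already-ragged tensor, which only needs moving to the device. A minimal sketch with made-up token ids:

    import k2
    import torch

    device = torch.device("cpu")

    # Hypothetical output of graph_compiler.texts_to_ids() for two utterances:
    y = [[23, 7, 112], [5, 98]]

    if type(y) == list:
        y = k2.RaggedTensor(y).to(device)  # wrap the list of lists into a ragged tensor
    else:
        y = y.to(device)  # already ragged; just move it to the right device
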
@@ -642,7 +641,6 @@ def compute_loss(
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)

assert loss.requires_grad == is_training

info = MetricsTracker()

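For reference, these lines compute the training objective as a weighted sum of the transducer's simple and pruned losses. A bare-bones sketch with placeholder numbers (in the real recipe the scales come from params and a warm-up schedule):

    import torch

    # Placeholder values standing in for the model outputs and the configuration.
    simple_loss = torch.tensor(12.5)
    pruned_loss = torch.tensor(4.2)
    simple_loss_scale = 0.5  # params.simple_loss_scale
    pruned_loss_scale = 1.0  # typically ramped up during warm-up in pruned-transducer recipes

    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
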
@@ -752,6 +750,7 @@ def train_one_epoch(

params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
# print(batch["supervisions"])

with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(

@@ -869,8 +868,6 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 1600

fix_random_seed(params.seed)
if world_size > 1:

@@ -959,7 +956,15 @@ def run(rank, world_size, args):
# the threshold
return 1.0 <= c.duration <= 20.0

def text_normalize_for_cut(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
c.supervisions[0].text = text_normalize(text)
return c

train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_cuts = train_cuts.map(text_normalize_for_cut)

if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint

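text_normalize_for_cut rewrites each supervision's text in place, and because CutSet.filter and CutSet.map are evaluated lazily in recent lhotse versions, both the length filter and the normalization run as the sampler draws cuts rather than up front. A self-contained sketch of the same pattern, with a trivial normalizer standing in for local.text_normalize (whose implementation is not part of this diff):

    from lhotse.cut import Cut


    def toy_text_normalize(text: str) -> str:
        # Stand-in for local.text_normalize.text_normalize, which lives elsewhere
        # in the recipe and is not shown in this commit.
        return text.strip().upper()


    def remove_short_and_long_utt(c: Cut) -> bool:
        return 1.0 <= c.duration <= 20.0


    def text_normalize_for_cut(c: Cut) -> Cut:
        text = c.supervisions[0].text
        text = text.strip("\n").strip("\t")
        c.supervisions[0].text = toy_text_normalize(text)
        return c


    # Both are applied lazily as the sampler iterates over the cuts:
    # train_cuts = train_cuts.filter(remove_short_and_long_utt)
    # train_cuts = train_cuts.map(text_normalize_for_cut)
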
@@ -972,7 +977,8 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)

valid_cuts = aishell4.dev_cuts()
valid_cuts = aishell4.valid_cuts()
valid_cuts = valid_cuts.map(text_normalize_for_cut)
valid_dl = aishell4.valid_dataloaders(valid_cuts)

if not params.print_diagnostics: