mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-09 17:14:20 +00:00)

commit 4c2cbff501 ("do some changes")
parent b4b3a848ed
@@ -29,7 +29,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -88,7 +88,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=LilcomChunkyWriter,
+                storage_type=ChunkedLilcomHdf5Writer,
             )
 
             logging.info("About splitting cuts into smaller chunks")
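Both writer classes implement lhotse's FeaturesWriter interface and are passed as a class, not an instance; compute_and_store_features instantiates the writer per job. A minimal sketch of the surrounding call, with hypothetical paths, assuming a lhotse version that still exports ChunkedLilcomHdf5Writer at the top level:

from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig

# Hypothetical manifest/output paths, for illustration only.
cut_set = CutSet.from_file("data/manifests/aishell4_cuts_train_S.jsonl.gz")
cut_set = cut_set.compute_and_store_features(
    extractor=Fbank(FbankConfig(num_mel_bins=80)),
    storage_path="data/fbank/feats_train_S",
    num_jobs=4,  # the hunk above bumps this to 80 when an executor is supplied
    storage_type=ChunkedLilcomHdf5Writer,  # lilcom-compressed chunks in one HDF5 file
)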
@@ -97,7 +97,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                 min_duration=None,
             )
 
-            cut_set.to_json(output_dir / cuts_filename)
+            cut_set.to_file(output_dir / cuts_filename)
 
 
 def get_args():
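The to_json to to_file switch matters because to_file dispatches on the file extension, so the recipe can write compressed JSONL without changing the call site. A small illustration, reusing cut_set and output_dir from the hunk above:

# Same call, different serialization, selected by extension:
cut_set.to_file(output_dir / "cuts_train.json")      # plain JSON
cut_set.to_file(output_dir / "cuts_train.jsonl.gz")  # gzipped JSONL, streamable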
@@ -230,7 +230,7 @@ class Aishell4AsrDataModule:
         """
         logging.info("About to get Musan cuts")
         cuts_musan = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
         )
 
         transforms = []
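The rename tracks the MUSAN manifest name produced by the feature-preparation stage. load_manifest_lazy streams the gzipped JSONL line by line instead of materializing every MUSAN cut in memory; a sketch with a hypothetical path:

from lhotse import load_manifest_lazy

# Streams cuts one JSON line at a time; suitable for large manifests.
cuts_musan = load_manifest_lazy("data/fbank/musan_cuts.jsonl.gz")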
@@ -429,7 +429,7 @@ class Aishell4AsrDataModule:
         logging.info("About to get train cuts")
         return load_manifest_lazy(
             self.args.manifest_dir
-            / "aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
+            / f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
         )
 
     @lru_cache()
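Without the f prefix the braces are literal, so the data module would look for a file literally named aishell4_cuts_train_{self.args.training_subset}.jsonl.gz and never find the real manifest. A two-line demonstration:

subset = "L"
print("aishell4_cuts_train_{subset}.jsonl.gz")   # aishell4_cuts_train_{subset}.jsonl.gz
print(f"aishell4_cuts_train_{subset}.jsonl.gz")  # aishell4_cuts_train_L.jsonl.gz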
@@ -64,6 +64,7 @@ from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from local.text_normalize import text_normalize
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -388,7 +389,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 1,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for conformer
@@ -612,13 +613,11 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
 
     y = graph_compiler.texts_to_ids(texts)
     if type(y) == list:
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)
 
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
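For context on the type check kept in this hunk: texts_to_ids may return a plain list of ID lists (e.g. from a character-based compiler) or an already-ragged k2 tensor, and the branch normalizes both to a k2.RaggedTensor on the right device. A minimal sketch, assuming k2 is installed:

import k2

y = [[23, 5, 9], [7, 42]]  # token-ID sequences of unequal length
if type(y) == list:
    y = k2.RaggedTensor(y)  # ragged rows; no padding needed
# y now has 2 rows, of lengths 3 and 2.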
@@ -642,7 +641,6 @@ def compute_loss(
            params.simple_loss_scale * simple_loss
            + pruned_loss_scale * pruned_loss
        )
 
    assert loss.requires_grad == is_training
 
    info = MetricsTracker()
@@ -752,6 +750,7 @@ def train_one_epoch(
 
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
+        # print(batch["supervisions"])
 
         with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
@@ -869,8 +868,6 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
-    if params.full_libri is False:
-        params.valid_interval = 1600
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -959,7 +956,15 @@ def run(rank, world_size, args):
         # the threshold
         return 1.0 <= c.duration <= 20.0
 
+    def text_normalize_for_cut(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        c.supervisions[0].text = text_normalize(text)
+        return c
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.map(text_normalize_for_cut)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
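Like filter, CutSet.map is lazy: text_normalize_for_cut only runs when the sampler iterates the cuts, and it assumes exactly one supervision per cut. A self-contained sketch, with a placeholder normalizer standing in for the recipe-local local/text_normalize.py and train_cuts assumed to be a lhotse CutSet as above:

# Placeholder for local.text_normalize.text_normalize (recipe-local module).
def text_normalize(text: str) -> str:
    return text.upper()

def text_normalize_for_cut(c):
    text = c.supervisions[0].text        # assumes one supervision per cut
    text = text.strip("\n").strip("\t")  # drop stray newlines/tabs first
    c.supervisions[0].text = text_normalize(text)
    return c

train_cuts = train_cuts.map(text_normalize_for_cut)  # applied lazily, per cut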
@@ -972,7 +977,8 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = aishell4.dev_cuts()
+    valid_cuts = aishell4.valid_cuts()
+    valid_cuts = valid_cuts.map(text_normalize_for_cut)
     valid_dl = aishell4.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics: