Merge 448c28b3cc8d1b42179d4ac20989a980133b8f3f into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

Authored by chanchongleong on 2025-07-24 11:00:25 +08:00; committed by GitHub.
commit 2d52221fcb
3 changed files with 65 additions and 40 deletions

File 1 of 3 (the decode script):

@@ -366,13 +366,14 @@ def decode_dataset(
     num_cuts = 0

-    try:
-        num_batches = len(dl)
-    except TypeError:
-        num_batches = "?"
+    # try:
+    #     num_batches = len(dl)
+    # except TypeError:
+    #     num_batches = "?"

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
+        batch = batch[0]
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
@@ -399,9 +400,8 @@ def decode_dataset(
         num_cuts += len(batch["supervisions"]["text"])

         if batch_idx % 100 == 0:
-            batch_str = f"{batch_idx}/{num_batches}"
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            # batch_str = f"{batch_idx}/{num_batches}"
+            logging.info(f"batch {batch_idx}, cuts processed until now is {num_cuts}")

     return results
@@ -547,20 +547,19 @@ def main():
     test_sets = ["test"]
     test_dls = [test_dl]

-    for test_set, test_dl in zip(test_sets, test_dls):
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            HLG=HLG,
-            H=H,
-            lexicon=lexicon,
-            sos_id=sos_id,
-            eos_id=eos_id,
-        )
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    # for test_set, test_dl in zip(test_sets, test_dls):
+    results_dict = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        HLG=HLG,
+        H=H,
+        lexicon=lexicon,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    save_results(params=params, test_set_name=test_sets[0], results_dict=results_dict)

     logging.info("Done!")
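
A note on the new `batch = batch[0]` line in `decode_dataset`: the test DataLoader built in file 3 now passes `collate_fn=self.my_collate_fn`, which returns a list, so each batch the loader yields is a list whose first element is the usual batch dict. A minimal, self-contained sketch of that interaction, using a hypothetical toy dataset rather than the recipe's real one:

# Minimal sketch: a list-returning collate_fn makes every DataLoader batch a
# list, so the consumer has to unwrap it. ToyDataset is hypothetical.
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # Each item is already a fully formed batch dict, as in lhotse-style
        # datasets where the sampler does the batching.
        return {"supervisions": {"text": [f"utt-{idx}"]}}


def list_collate_fn(batch):
    # Return the items as-is, in a list (the same shape my_collate_fn produces).
    return list(batch)


dl = DataLoader(ToyDataset(), batch_size=1, collate_fn=list_collate_fn)
for batch in dl:
    batch = batch[0]  # unwrap the one-element list to recover the dict
    print(batch["supervisions"]["text"])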

File 2 of 3 (the training script):

@@ -22,9 +22,9 @@ from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple

+import os
 import k2
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
@@ -543,13 +543,9 @@ def train_one_epoch(
         params.best_train_loss = params.train_loss


-def run(rank, world_size, args):
+def run(world_size, args):
     """
     Args:
-      rank:
-        It is a value between 0 and `world_size-1`, which is
-        passed automatically by `mp.spawn()` in :func:`main`.
-        The node with rank 0 is responsible for saving checkpoint.
       world_size:
         Number of GPUs for DDP training.
       args:
@@ -560,13 +556,14 @@ def run(rank, world_size, args):
     fix_random_seed(params.seed)
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(use_ddp_launch=True, master_addr=params.master_port)

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    if local_rank == 0:
+        logging.info(params)

-    if args.tensorboard and rank == 0:
+    if args.tensorboard and local_rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None
@@ -577,7 +574,7 @@ def run(rank, world_size, args):
     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)

     graph_compiler = CharCtcTrainingGraphCompiler(
         lexicon=lexicon,
@@ -603,7 +600,8 @@
     model.to(device)

     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        torch.distributed.barrier()  # Ensure all processes have the same model parameters
+        model = DDP(model, device_ids=[local_rank])

     optimizer = Noam(
         model.parameters(),
@@ -629,7 +627,7 @@
             tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

-        if rank == 0:
+        if local_rank == 0:
             logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))

         params.cur_epoch = epoch
@@ -644,12 +642,14 @@
             tb_writer=tb_writer,
             world_size=world_size,
         )

+        if world_size > 1:
+            torch.distributed.barrier()
         save_checkpoint(
             params=params,
             model=model,
             optimizer=optimizer,
-            rank=rank,
+            rank=local_rank,
         )

     logging.info("Done!")
@@ -668,10 +668,7 @@ def main():
     world_size = args.world_size
     assert world_size >= 1

-    if world_size > 1:
-        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
-    else:
-        run(rank=0, world_size=1, args=args)
+    run(world_size=world_size, args=args)

     torch.set_num_threads(1)
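
The training-script change above drops self-spawned workers (`mp.spawn`) in favor of an externally launched process group: a launcher such as `torchrun` starts one process per GPU and exports `LOCAL_RANK` (plus `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`), which is why `run()` now reads `LOCAL_RANK` from the environment instead of receiving a `rank` argument. A minimal sketch of that launch pattern in plain PyTorch; icefall's `setup_dist(use_ddp_launch=True, ...)` wraps roughly this, but the helper below is an assumption, not the recipe's exact code:

# Sketch of a torchrun-style DDP setup. A launch would look like:
#   torchrun --nproc_per_node=4 train.py
# torchrun exports LOCAL_RANK/RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def init_ddp(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper, not icefall's setup_dist().
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1:
        # The launcher already set the rendezvous env vars, so no explicit
        # rank/world_size arguments are needed here.
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)
        model = model.cuda(local_rank)
        dist.barrier()  # every process reaches DDP wrapping together
        model = DDP(model, device_ids=[local_rank])
    return model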

File 3 of 3 (asr_datamodule.py):

@@ -23,6 +23,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, List, Optional

+from lhotse.cut import MonoCut
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -180,6 +181,33 @@ class AishellAsrDataModule:
             help="When enabled, select noise from MUSAN and mix it"
             "with training dataset. ",
         )

+    def to_dict(self, obj):
+        """
+        Recursively convert an object and its nested objects to dictionaries.
+        """
+        if isinstance(obj, (str, int, float, bool, type(None))):
+            return obj
+        elif isinstance(obj, list):
+            return [self.to_dict(item) for item in obj]
+        elif isinstance(obj, dict):
+            return {key: self.to_dict(value) for key, value in obj.items()}
+        elif hasattr(obj, "__dict__"):
+            return {key: self.to_dict(value) for key, value in obj.__dict__.items()}
+        else:
+            raise TypeError(f"Unsupported type: {type(obj)}")
+
+    def my_collate_fn(self, batch):
+        """
+        Convert each MonoCut in the batch to a plain dict.
+        """
+        return_batch = []
+        for item in batch:
+            if isinstance(item, MonoCut):
+                processed_item = self.to_dict(item)
+                return_batch.append(processed_item)
+            elif isinstance(item, dict):
+                return_batch.append(item)
+        return return_batch
+
     def train_dataloaders(
         self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
@@ -353,9 +381,10 @@
         )
         test_dl = DataLoader(
             test,
-            batch_size=None,
+            batch_size=100,  # set to a fixed value
             sampler=sampler,
-            num_workers=self.args.num_workers,
+            num_workers=4,  # larger values slow decoding and may hang
+            collate_fn=self.my_collate_fn,
         )
         return test_dl
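
The `to_dict` helper added above walks an object graph and flattens it into plain built-ins, which is what lets `my_collate_fn` hand `MonoCut`s onward as ordinary dicts. A standalone sketch of the same recursion on hypothetical toy classes (not lhotse's real `MonoCut`), just to show the traversal:

# Toy classes standing in for a cut with nested supervisions; both are
# hypothetical and exist only to exercise the recursion.
class Supervision:
    def __init__(self, text: str):
        self.text = text


class Cut:
    def __init__(self):
        self.id = "cut-0"
        self.supervisions = [Supervision("hello"), Supervision("world")]


def to_dict(obj):
    # Same shape as the method above: primitives pass through, containers
    # recurse, and anything with __dict__ is flattened attribute by attribute.
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    if isinstance(obj, list):
        return [to_dict(item) for item in obj]
    if isinstance(obj, dict):
        return {key: to_dict(value) for key, value in obj.items()}
    if hasattr(obj, "__dict__"):
        return {key: to_dict(value) for key, value in obj.__dict__.items()}
    raise TypeError(f"Unsupported type: {type(obj)}")


print(to_dict(Cut()))
# {'id': 'cut-0', 'supervisions': [{'text': 'hello'}, {'text': 'world'}]}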