mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge 448c28b3cc8d1b42179d4ac20989a980133b8f3f into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in:
commit
2d52221fcb
@ -366,13 +366,14 @@ def decode_dataset(
|
|||||||
|
|
||||||
num_cuts = 0
|
num_cuts = 0
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
num_batches = len(dl)
|
# num_batches = len(dl)
|
||||||
except TypeError:
|
# except TypeError:
|
||||||
num_batches = "?"
|
# num_batches = "?"
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
batch = batch[0]
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
@ -399,9 +400,8 @@ def decode_dataset(
|
|||||||
num_cuts += len(batch["supervisions"]["text"])
|
num_cuts += len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
if batch_idx % 100 == 0:
|
if batch_idx % 100 == 0:
|
||||||
batch_str = f"{batch_idx}/{num_batches}"
|
# batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
logging.info(f"batch {batch_idx}, cuts processed until now is {num_cuts}")
|
||||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -547,20 +547,19 @@ def main():
|
|||||||
|
|
||||||
test_sets = ["test"]
|
test_sets = ["test"]
|
||||||
test_dls = [test_dl]
|
test_dls = [test_dl]
|
||||||
|
# for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
HLG=HLG,
|
||||||
|
H=H,
|
||||||
|
lexicon=lexicon,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
)
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dls):
|
save_results(params=params, test_set_name=test_sets[0], results_dict=results_dict)
|
||||||
results_dict = decode_dataset(
|
|
||||||
dl=test_dl,
|
|
||||||
params=params,
|
|
||||||
model=model,
|
|
||||||
HLG=HLG,
|
|
||||||
H=H,
|
|
||||||
lexicon=lexicon,
|
|
||||||
sos_id=sos_id,
|
|
||||||
eos_id=eos_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
|
||||||
|
|
||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|
||||||
|
@ -22,9 +22,9 @@ from pathlib import Path
|
|||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import os
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from asr_datamodule import AishellAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -543,13 +543,9 @@ def train_one_epoch(
|
|||||||
params.best_train_loss = params.train_loss
|
params.best_train_loss = params.train_loss
|
||||||
|
|
||||||
|
|
||||||
def run(rank, world_size, args):
|
def run(world_size, args):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
rank:
|
|
||||||
It is a value between 0 and `world_size-1`, which is
|
|
||||||
passed automatically by `mp.spawn()` in :func:`main`.
|
|
||||||
The node with rank 0 is responsible for saving checkpoint.
|
|
||||||
world_size:
|
world_size:
|
||||||
Number of GPUs for DDP training.
|
Number of GPUs for DDP training.
|
||||||
args:
|
args:
|
||||||
@ -560,13 +556,14 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
fix_random_seed(params.seed)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(use_ddp_launch=True, master_addr=params.master_port)
|
||||||
|
|
||||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||||
logging.info("Training started")
|
logging.info("Training started")
|
||||||
logging.info(params)
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
if local_rank == 0:
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
if args.tensorboard and rank == 0:
|
if args.tensorboard and local_rank == 0:
|
||||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||||
else:
|
else:
|
||||||
tb_writer = None
|
tb_writer = None
|
||||||
@ -577,7 +574,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", local_rank)
|
||||||
|
|
||||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
@ -603,7 +600,8 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
model = DDP(model, device_ids=[rank])
|
torch.distributed.barrier() # Ensure all processes have the same model parameters
|
||||||
|
model = DDP(model, device_ids=[local_rank])
|
||||||
|
|
||||||
optimizer = Noam(
|
optimizer = Noam(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
@ -629,7 +627,7 @@ def run(rank, world_size, args):
|
|||||||
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
|
||||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
if rank == 0:
|
if local_rank == 0:
|
||||||
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||||
|
|
||||||
params.cur_epoch = epoch
|
params.cur_epoch = epoch
|
||||||
@ -644,12 +642,14 @@ def run(rank, world_size, args):
|
|||||||
tb_writer=tb_writer,
|
tb_writer=tb_writer,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
|
if world_size > 1:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
rank=rank,
|
rank=local_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
@ -668,10 +668,7 @@ def main():
|
|||||||
|
|
||||||
world_size = args.world_size
|
world_size = args.world_size
|
||||||
assert world_size >= 1
|
assert world_size >= 1
|
||||||
if world_size > 1:
|
run(world_size=world_size, args=args)
|
||||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
|
||||||
else:
|
|
||||||
run(rank=0, world_size=1, args=args)
|
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
|
@ -23,6 +23,7 @@ from functools import lru_cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from lhotse.cut import MonoCut
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||||
from lhotse.dataset import (
|
from lhotse.dataset import (
|
||||||
CutConcatenate,
|
CutConcatenate,
|
||||||
@ -180,7 +181,34 @@ class AishellAsrDataModule:
|
|||||||
help="When enabled, select noise from MUSAN and mix it"
|
help="When enabled, select noise from MUSAN and mix it"
|
||||||
"with training dataset. ",
|
"with training dataset. ",
|
||||||
)
|
)
|
||||||
|
def to_dict(self, obj):
|
||||||
|
"""
|
||||||
|
Recursively convert an object and its nested objects to dictionaries.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, (str, int, float, bool, type(None))):
|
||||||
|
return obj
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [to_dict(item) for item in obj]
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {key: to_dict(value) for key, value in obj.items()}
|
||||||
|
elif hasattr(obj, '__dict__'):
|
||||||
|
return {key: to_dict(value) for key, value in obj.__dict__.items()}
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported type: {type(obj)}")
|
||||||
|
|
||||||
|
def my_collate_fn(self, batch):
|
||||||
|
"""
|
||||||
|
Convert MonoCut to dict.
|
||||||
|
"""
|
||||||
|
return_batch = []
|
||||||
|
for item in batch:
|
||||||
|
if isinstance(item, MonoCut):
|
||||||
|
processed_item = self.to_dict(item)
|
||||||
|
return_batch.append(processed_item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
return_batch.append(item)
|
||||||
|
return return_batch
|
||||||
|
|
||||||
def train_dataloaders(
|
def train_dataloaders(
|
||||||
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
|
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
@ -353,9 +381,10 @@ class AishellAsrDataModule:
|
|||||||
)
|
)
|
||||||
test_dl = DataLoader(
|
test_dl = DataLoader(
|
||||||
test,
|
test,
|
||||||
batch_size=None,
|
batch_size=100, # specified to some value
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=4, # if larger, it will be more time-consuming for decoding, may stuck
|
||||||
|
collate_fn=self.my_collate_fn
|
||||||
)
|
)
|
||||||
return test_dl
|
return test_dl
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user