Fix dataloader in decode.py

This commit is contained in:
pkufool 2023-05-16 11:33:11 +08:00
parent 47565959d9
commit 42513d2e98
2 changed files with 37 additions and 135 deletions

View File

@ -273,8 +273,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -353,9 +352,7 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
@ -415,10 +412,7 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -543,9 +537,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -578,8 +570,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -630,9 +621,7 @@ def main():
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -660,9 +649,9 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for"
@ -689,9 +678,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for"
@ -750,9 +739,7 @@ def main():
decoding_graph.scores *= params.ngram_lm_scale decoding_graph.scores *= params.ngram_lm_scale
else: else:
# word_table = None # word_table = None
decoding_graph = k2.trivial_graph( decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
params.vocab_size - 1, device=device
)
else: else:
decoding_graph = None decoding_graph = None
# word_table = None # word_table = None
@ -760,89 +747,18 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
# Note: Please use "pip install webdataset==0.1.103"
# for installing the webdataset.
import glob
import os
from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args) wenetspeech = WenetSpeechAsrDataModule(args)
dev = "dev"
test_net = "test_net"
test_meeting = "test_meeting"
if not os.path.exists(f"{dev}/shared-0.tar"):
os.makedirs(dev, exist_ok=True)
dev_cuts = wenetspeech.valid_cuts() dev_cuts = wenetspeech.valid_cuts()
export_to_webdataset( dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
dev_cuts,
output_path=f"{dev}/shared-%d.tar",
shard_size=300,
)
if not os.path.exists(f"{test_net}/shared-0.tar"):
os.makedirs(test_net, exist_ok=True)
test_net_cuts = wenetspeech.test_net_cuts() test_net_cuts = wenetspeech.test_net_cuts()
export_to_webdataset( test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
test_net_cuts,
output_path=f"{test_net}/shared-%d.tar",
shard_size=300,
)
if not os.path.exists(f"{test_meeting}/shared-0.tar"):
os.makedirs(test_meeting, exist_ok=True)
test_meeting_cuts = wenetspeech.test_meeting_cuts() test_meeting_cuts = wenetspeech.test_meeting_cuts()
export_to_webdataset( test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_meeting_cuts,
output_path=f"{test_meeting}/shared-%d.tar",
shard_size=300,
)
print("done")
dev_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
]
cuts_dev_webdataset = CutSet.from_webdataset(
dev_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
test_net_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
]
cuts_test_net_webdataset = CutSet.from_webdataset(
test_net_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
test_meeting_shards = [
str(path)
for path in sorted(
glob.glob(os.path.join(test_meeting, "shared-*.tar"))
)
]
cuts_test_meeting_webdataset = CutSet.from_webdataset(
test_meeting_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl] test_dl = [dev_dl, test_net_dl, test_meeting_dl]

View File

@ -83,9 +83,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
@ -271,8 +269,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -295,8 +292,7 @@ def get_parser():
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" help="The scale to smooth the loss with am (output of encoder network)" "part.",
"part.",
) )
parser.add_argument( parser.add_argument(
@ -645,11 +641,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = ( device = model.device if isinstance(model, DDP) else next(model.parameters()).device
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -697,9 +689,7 @@ def compute_loss(
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -862,9 +852,7 @@ def train_one_epoch(
# of the grad scaler is configurable, but we can't configure it to have different # of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale. # behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item() cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 1.0 or ( if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
cur_grad_scale < 8.0 and batch_idx % 400 == 0
):
scaler.update(cur_grad_scale * 2.0) scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
@ -882,11 +870,7 @@ def train_one_epoch(
f"batch {batch_idx}, loss[{loss_info}], " f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, " f"lr: {cur_lr:.2e}, "
+ ( + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
f"grad_scale: {scaler._scale.item()}"
if params.use_fp16
else ""
)
) )
if tb_writer is not None: if tb_writer is not None:
@ -897,9 +881,7 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary( tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer, "train/tot_", params.batch_idx_train
)
if params.use_fp16: if params.use_fp16:
tb_writer.add_scalar( tb_writer.add_scalar(
"train/grad_scale", "train/grad_scale",
@ -907,10 +889,7 @@ def train_one_epoch(
params.batch_idx_train, params.batch_idx_train,
) )
if ( if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
batch_idx % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss") logging.info("Computing validation loss")
valid_info = compute_validation_loss( valid_info = compute_validation_loss(
params=params, params=params,
@ -1001,8 +980,15 @@ def run(rank, world_size, args):
logging.info("Using DDP") logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True) model = DDP(model, device_ids=[rank], find_unused_parameters=True)
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in model.named_parameters()]
)
optimizer = ScaledAdam( optimizer = ScaledAdam(
model.parameters(), lr=params.base_lr, clipping_scale=2.0 model.parameters(),
lr=params.base_lr,
clipping_scale=2.0,
parameters_names=parameters_names,
) )
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)