mirror of https://github.com/k2-fsa/icefall.git
Fix dataloader in decode.py
parent 47565959d9
commit 42513d2e98
@@ -273,8 +273,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -353,9 +352,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -415,10 +412,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -543,9 +537,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -578,8 +570,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -630,9 +621,7 @@ def main():
         if "LG" in params.decoding_method:
             params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -660,9 +649,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -689,9 +678,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -750,9 +739,7 @@ def main():
             decoding_graph.scores *= params.ngram_lm_scale
         else:
             # word_table = None
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
         # word_table = None
@@ -760,89 +747,18 @@
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
 
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev, exist_ok=True)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
 
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net, exist_ok=True)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting, exist_ok=True)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    print("done")
-
-    dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
 
     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
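Below is a minimal sketch (not the literal recipe code) of the dataloader wiring that decode.py is left with after this fix: each evaluation CutSet from the WenetSpeech data module is passed straight to the module's dataloader helpers, and the webdataset export/re-import round trip removed above is gone entirely. The helper name build_test_dataloaders is illustrative only.

from typing import List, Tuple


def build_test_dataloaders(wenetspeech) -> Tuple[List[str], List]:
    # `wenetspeech` is assumed to behave like WenetSpeechAsrDataModule in this
    # recipe: it exposes valid_cuts()/test_net_cuts()/test_meeting_cuts() and
    # valid_dataloaders()/test_dataloaders().
    dev_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
    test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
    test_meeting_dl = wenetspeech.test_dataloaders(wenetspeech.test_meeting_cuts())

    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
    return test_sets, test_dl

Note that args.return_cuts = True must still be set before the data module is built, because the decoding loop uses cut ids to display recognition results (see the retained comment in the hunk above).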
@@ -83,9 +83,7 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
@@ -271,8 +269,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -295,8 +292,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -645,11 +641,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -697,9 +689,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -862,9 +852,7 @@ def train_one_epoch(
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (
-                cur_grad_scale < 8.0 and batch_idx % 400 == 0
-            ):
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
@@ -882,11 +870,7 @@ def train_one_epoch(
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (
-                    f"grad_scale: {scaler._scale.item()}"
-                    if params.use_fp16
-                    else ""
-                )
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
 
             if tb_writer is not None:
@@ -897,9 +881,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
                         "train/grad_scale",
@@ -907,10 +889,7 @@ def train_one_epoch(
                         params.batch_idx_train,
                     )
 
-        if (
-            batch_idx % params.valid_interval == 0
-            and not params.print_diagnostics
-        ):
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -1001,8 +980,15 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
     optimizer = ScaledAdam(
-        model.parameters(), lr=params.base_lr, clipping_scale=2.0
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
     )
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
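For reference, a hedged sketch of what the optimizer construction above amounts to: ScaledAdam (the recipe-local optimizer from optim.py) now also receives parameters_names, a list containing a single inner list with the names of all model parameters, so the optimizer can refer to parameters by name in its diagnostics. The wrapper function below is illustrative only.

def make_scaled_adam(model, base_lr, clipping_scale=2.0):
    # Assumes the recipe-local optim.py is importable, as in train.py.
    from optim import ScaledAdam

    # One name group covering the whole model, matching the single
    # model.parameters() group passed to the optimizer.
    parameters_names = []
    parameters_names.append(
        [name_param_pair[0] for name_param_pair in model.named_parameters()]
    )
    return ScaledAdam(
        model.parameters(),
        lr=base_lr,
        clipping_scale=clipping_scale,
        parameters_names=parameters_names,
    )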
@@ -1021,7 +1007,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 