mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Fix dataloader in decode.py

commit 42513d2e98
parent 47565959d9
@@ -273,8 +273,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(
@@ -353,9 +352,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

     hyps = []

@@ -415,10 +412,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -543,9 +537,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results


@@ -578,8 +570,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -630,9 +621,7 @@ def main():
     if "LG" in params.decoding_method:
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -660,9 +649,9 @@ def main():

     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -689,9 +678,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
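
Note: both branches above slice the same list. As far as I can tell from icefall's checkpoint utilities, find_checkpoints(params.exp_dir, iteration=-params.iter) returns paths sorted from the most recent iteration downwards, and the averaged-model branch keeps one extra checkpoint because it needs both endpoints of the averaging interval. A minimal sketch of the slicing, with invented file names:

# Hypothetical checkpoint names, newest first, mimicking what
# find_checkpoints() is assumed to return.
checkpoints = [
    "checkpoint-30000.pt",
    "checkpoint-28000.pt",
    "checkpoint-26000.pt",
    "checkpoint-24000.pt",
]

avg = 2
print(checkpoints[:avg])       # plain averaging: the 2 most recent checkpoints
print(checkpoints[: avg + 1])  # use_averaged_model: one extra interval endpoint
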
@@ -750,9 +739,7 @@ def main():
             decoding_graph.scores *= params.ngram_lm_scale
         else:
             # word_table = None
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
         # word_table = None
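
Note: k2.trivial_graph builds the unconstrained decoding graph used when no LG graph is supplied: a single-state FSA that accepts any sequence of token IDs from 1 to vocab_size - 1. A small sketch; the toy vocabulary size is invented for illustration:

import k2
import torch

vocab_size = 6  # stand-in for params.vocab_size
decoding_graph = k2.trivial_graph(vocab_size - 1, device=torch.device("cpu"))
print(decoding_graph.num_arcs)  # one self-loop per non-blank token, plus the final arc
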
@@ -760,89 +747,18 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)

-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
-
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev, exist_ok=True)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net, exist_ok=True)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting, exist_ok=True)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    print("done")
-
-    dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
+
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
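
Note: the net effect of the hunk above is that decode.py now feeds each lazily loaded CutSet straight into a dataloader, dropping the webdataset export/re-import round trip (and the pinned webdataset==0.1.103 dependency) entirely. A minimal sketch of the new data path, assuming a WenetSpeechAsrDataModule constructed with args.return_cuts = True as in the diff:

from typing import List, Tuple

def build_test_dataloaders(wenetspeech) -> Tuple[List[str], List]:
    """Wrap each CutSet in a dataloader directly; no tar shards on disk."""
    dev_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
    test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
    test_meeting_dl = wenetspeech.test_dataloaders(wenetspeech.test_meeting_cuts())
    return ["DEV", "TEST_NET", "TEST_MEETING"], [dev_dl, test_net_dl, test_meeting_dl]

The hunks that follow belong to the commit's second file, a training script; apart from the new ScaledAdam arguments near the end, they are the same style of mechanical line joins.
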
@@ -83,9 +83,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
@@ -271,8 +269,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(
@@ -295,8 +292,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )

     parser.add_argument(
@@ -645,11 +641,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
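
Note: the collapsed conditional keeps the usual idiom for finding the device of a possibly-DDP-wrapped model: DDP records the device it was constructed on, while a bare nn.Module has no .device attribute, so the fallback asks the first parameter. A self-contained sketch:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def model_device(model: nn.Module) -> torch.device:
    # DDP exposes .device; for a plain module, use the first parameter's device.
    if isinstance(model, DDP):
        return model.device
    return next(model.parameters()).device

print(model_device(nn.Linear(4, 4)))  # -> cpu
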
@@ -697,9 +689,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -862,9 +852,7 @@ def train_one_epoch(
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (
-                cur_grad_scale < 8.0 and batch_idx % 400 == 0
-            ):
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
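
Note: for readers new to this scheme, the fp16 path nudges a collapsed GradScaler scale back up: it doubles the scale immediately while it is below 1.0, and once every 400 batches while it is below 8.0. The same policy in isolation, with the constants copied from the hunk:

def maybe_grow_grad_scale(cur_grad_scale: float, batch_idx: int) -> float:
    # Double a grad scale that has collapsed: aggressively below 1.0,
    # and only periodically (every 400 batches) below 8.0.
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        return cur_grad_scale * 2.0
    return cur_grad_scale

assert maybe_grow_grad_scale(0.5, batch_idx=7) == 1.0    # always grows below 1.0
assert maybe_grow_grad_scale(4.0, batch_idx=400) == 8.0  # periodic growth below 8.0
assert maybe_grow_grad_scale(4.0, batch_idx=401) == 4.0  # otherwise unchanged
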
@@ -882,11 +870,7 @@ def train_one_epoch(
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (
-                    f"grad_scale: {scaler._scale.item()}"
-                    if params.use_fp16
-                    else ""
-                )
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )

             if tb_writer is not None:
@@ -897,9 +881,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
                         "train/grad_scale",
@@ -907,10 +889,7 @@ def train_one_epoch(
                         params.batch_idx_train,
                     )

-            if (
-                batch_idx % params.valid_interval == 0
-                and not params.print_diagnostics
-            ):
+            if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
                 logging.info("Computing validation loss")
                 valid_info = compute_validation_loss(
                     params=params,
@@ -1001,8 +980,15 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
     optimizer = ScaledAdam(
-        model.parameters(), lr=params.base_lr, clipping_scale=2.0
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
     )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
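
Note: the new parameters_names argument supplies one list of parameter names per parameter group (a single group here). As far as I can tell, ScaledAdam, which comes from the recipe's optim.py, uses these names in its gradient-clipping diagnostics to report which parameters dominate the gradient norm. A toy sketch of the list's shape; the nn.Linear stand-in is invented:

import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in for the real transducer model
parameters_names = [[name for name, _ in model.named_parameters()]]
print(parameters_names)  # [['weight', 'bias']]
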
@@ -1021,7 +1007,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
