Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)

commit d5252a4157 (parent 2f0d3d7ae3)

    black
@@ -56,11 +56,15 @@ def get_args():
     )

     parser.add_argument(
-        "--transcript", type=str, help="Training transcript.",
+        "--transcript",
+        type=str,
+        help="Training transcript.",
     )

     parser.add_argument(
-        "--vocab-size", type=int, help="Vocabulary size for BPE training",
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
     )

     return parser.parse_args()
@@ -215,7 +215,9 @@ class LibriHeavyAsrDataModule:
         )

     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
     ) -> DataLoader:
         """
         Args:
@@ -357,10 +359,13 @@ class LibriHeavyAsrDataModule:
             )
         else:
             validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms, return_cuts=self.args.return_cuts,
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
             )
         valid_sampler = DynamicBucketingSampler(
-            cuts_valid, max_duration=self.args.max_duration, shuffle=False,
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
@@ -382,11 +387,16 @@ class LibriHeavyAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts, max_duration=self.args.max_duration, shuffle=False,
+            cuts,
+            max_duration=self.args.max_duration,
+            shuffle=False,
         )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
-            test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
         )
         return test_dl

@@ -174,7 +174,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--exp-dir", type=str, default="zipformer/exp", help="The experiment dir",
+        "--exp-dir",
+        type=str,
+        default="zipformer/exp",
+        help="The experiment dir",
     )

     parser.add_argument(
@@ -349,7 +352,9 @@ def decode_one_batch(
         pad_len = 30
         feature_lens += pad_len
         feature = torch.nn.functional.pad(
-            feature, pad=(0, 0, 0, pad_len), value=LOG_EPS,
+            feature,
+            pad=(0, 0, 0, pad_len),
+            value=LOG_EPS,
         )

     encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
@@ -399,7 +404,9 @@ def decode_one_batch(
             hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
-            model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens,
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -427,7 +434,9 @@ def decode_one_batch(
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
-                    model=model, encoder_out=encoder_out_i, beam=params.beam_size,
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
                 )
             else:
                 raise ValueError(
@@ -773,7 +782,9 @@ def main():
         )

         save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict,
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
         )

     logging.info("Done!")
@@ -255,7 +255,10 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )

     parser.add_argument(
-        "--use-ctc", type=str2bool, default=False, help="If True, use CTC head.",
+        "--use-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use CTC head.",
     )


@@ -265,7 +268,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
     )

     parser.add_argument(
@@ -283,7 +289,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--num-epochs", type=int, default=30, help="Number of epochs to train.",
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
     )

     parser.add_argument(
@@ -391,7 +400,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--ctc-loss-scale", type=float, default=0.2, help="Scale for CTC loss.",
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
     )

     parser.add_argument(
@@ -853,7 +865,11 @@ def compute_validation_loss(

     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
-            params=params, model=model, sp=sp, batch=batch, is_training=False,
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -943,7 +959,11 @@ def train_one_epoch(
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
-                    params=params, model=model, sp=sp, batch=batch, is_training=True,
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -974,7 +994,9 @@ def train_one_epoch(
                 and params.batch_idx_train % params.average_period == 0
             ):
                 update_averaged_model(
-                    params=params, model_cur=model, model_avg=model_avg,
+                    params=params,
+                    model_cur=model,
+                    model_avg=model_avg,
                 )

             if (
@@ -994,7 +1016,9 @@ def train_one_epoch(
                     rank=rank,
                 )
                 remove_checkpoints(
-                    out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
+                    out_dir=params.exp_dir,
+                    topk=params.keep_last_k,
+                    rank=rank,
                 )

             if batch_idx % 100 == 0 and params.use_fp16:
@@ -1156,7 +1180,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1297,7 +1321,9 @@ def run(rank, world_size, args):


 def display_and_save_batch(
-    batch: dict, params: AttributeDict, sp: spm.SentencePieceProcessor,
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
 ) -> None:
     """Display the batch statistics and save the batch into disk.

@@ -1344,7 +1370,11 @@ def scan_pessimistic_batches_for_oom(
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
-                    params=params, model=model, sp=sp, batch=batch, is_training=True,
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
                 )
             loss.backward()
             optimizer.zero_grad()