update lipnet-ctc-vsr for grid

This commit is contained in:
Mingshuang Luo 2021-12-24 14:16:12 +08:00
parent f033379805
commit 149ccd1b85
2 changed files with 2 additions and 2 deletions

View File

@ -440,7 +440,7 @@ def main():
else:
G = None
model = LipNet(num_classes=max_token_id+1)
model = LipNet(num_classes=max_token_id + 1)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:

View File

@ -509,7 +509,7 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
model = LipNet(num_classes=max_token_id+1)
model = LipNet(num_classes=max_token_id + 1)
checkpoints = load_checkpoint_if_available(params=params, model=model)