From 149ccd1b85337131a196966e8f23632c1ebe3f55 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Fri, 24 Dec 2021 14:16:12 +0800 Subject: [PATCH] update lipnet-ctc-vsr for grid --- egs/grid/AVSR/lipnet_ctc_vsr/decode.py | 2 +- egs/grid/AVSR/lipnet_ctc_vsr/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/grid/AVSR/lipnet_ctc_vsr/decode.py b/egs/grid/AVSR/lipnet_ctc_vsr/decode.py index 6b535f156..3ed36f339 100644 --- a/egs/grid/AVSR/lipnet_ctc_vsr/decode.py +++ b/egs/grid/AVSR/lipnet_ctc_vsr/decode.py @@ -440,7 +440,7 @@ def main(): else: G = None - model = LipNet(num_classes=max_token_id+1) + model = LipNet(num_classes=max_token_id + 1) if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: diff --git a/egs/grid/AVSR/lipnet_ctc_vsr/train.py b/egs/grid/AVSR/lipnet_ctc_vsr/train.py index b0aab49ee..d8b7b3315 100644 --- a/egs/grid/AVSR/lipnet_ctc_vsr/train.py +++ b/egs/grid/AVSR/lipnet_ctc_vsr/train.py @@ -509,7 +509,7 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - model = LipNet(num_classes=max_token_id+1) + model = LipNet(num_classes=max_token_id + 1) checkpoints = load_checkpoint_if_available(params=params, model=model)