diff --git a/egs/grid/AVSR/lipnet_ctc_vsr/decode.py b/egs/grid/AVSR/lipnet_ctc_vsr/decode.py index 6b535f156..3ed36f339 100644 --- a/egs/grid/AVSR/lipnet_ctc_vsr/decode.py +++ b/egs/grid/AVSR/lipnet_ctc_vsr/decode.py @@ -440,7 +440,7 @@ def main(): else: G = None - model = LipNet(num_classes=max_token_id+1) + model = LipNet(num_classes=max_token_id + 1) if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: diff --git a/egs/grid/AVSR/lipnet_ctc_vsr/train.py b/egs/grid/AVSR/lipnet_ctc_vsr/train.py index b0aab49ee..d8b7b3315 100644 --- a/egs/grid/AVSR/lipnet_ctc_vsr/train.py +++ b/egs/grid/AVSR/lipnet_ctc_vsr/train.py @@ -509,7 +509,7 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - model = LipNet(num_classes=max_token_id+1) + model = LipNet(num_classes=max_token_id + 1) checkpoints = load_checkpoint_if_available(params=params, model=model)