From f03337980571f1d14ca4489b1c3153bd694f16b8 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Fri, 24 Dec 2021 14:07:57 +0800 Subject: [PATCH] update lipnet-ctc-vsr --- egs/grid/AVSR/lipnet_ctc_vsr/decode.py | 3 ++- egs/grid/AVSR/lipnet_ctc_vsr/model.py | 5 +++-- egs/grid/AVSR/lipnet_ctc_vsr/train.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/egs/grid/AVSR/lipnet_ctc_vsr/decode.py b/egs/grid/AVSR/lipnet_ctc_vsr/decode.py index d4f3910a3..6b535f156 100644 --- a/egs/grid/AVSR/lipnet_ctc_vsr/decode.py +++ b/egs/grid/AVSR/lipnet_ctc_vsr/decode.py @@ -386,6 +386,7 @@ def main(): logging.info(params) lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) device = torch.device("cpu") if torch.cuda.is_available(): @@ -439,7 +440,7 @@ def main(): else: G = None - model = LipNet() + model = LipNet(num_classes=max_token_id+1) if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: diff --git a/egs/grid/AVSR/lipnet_ctc_vsr/model.py b/egs/grid/AVSR/lipnet_ctc_vsr/model.py index ce246899c..4fb70b269 100644 --- a/egs/grid/AVSR/lipnet_ctc_vsr/model.py +++ b/egs/grid/AVSR/lipnet_ctc_vsr/model.py @@ -4,8 +4,9 @@ import torch.nn as nn class LipNet(torch.nn.Module): - def __init__(self, dropout_p=0.1): + def __init__(self, num_classes, dropout_p=0.1): super(LipNet, self).__init__() + self.num_classes = num_classes self.conv1 = nn.Conv3d(3, 32, (3, 5, 5), (1, 2, 2), (1, 2, 2)) self.pool1 = nn.MaxPool3d((1, 2, 2), (1, 2, 2)) @@ -18,7 +19,7 @@ class LipNet(torch.nn.Module): self.gru1 = nn.GRU(96 * 4 * 8, 256, 1, bidirectional=True) self.gru2 = nn.GRU(512, 256, 1, bidirectional=True) - self.FC = nn.Linear(512, 28) + self.FC = nn.Linear(512, self.num_classes) self.dropout_p = dropout_p self.relu = nn.ReLU(inplace=True) diff --git a/egs/grid/AVSR/lipnet_ctc_vsr/train.py b/egs/grid/AVSR/lipnet_ctc_vsr/train.py index 9caa9b55a..b0aab49ee 100644 --- a/egs/grid/AVSR/lipnet_ctc_vsr/train.py +++ b/egs/grid/AVSR/lipnet_ctc_vsr/train.py @@ -180,7 +180,6 @@ def get_params() -> AttributeDict: "anno_path": Path("download/GRID/GRID_align_txt"), "train_list": Path("download/GRID/unseen_train.txt"), "vid_padding": 75, - "aud_padding": 200, "num_workers": 16, "batch_size": 120, } @@ -503,13 +502,14 @@ def run(rank, world_size, args): tb_writer = None lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", rank) graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - model = LipNet() + model = LipNet(num_classes=max_token_id+1) checkpoints = load_checkpoint_if_available(params=params, model=model)