update lipnet-ctc-vsr

This commit is contained in:
Mingshuang Luo 2021-12-24 14:07:57 +08:00
parent 1abf255bdd
commit f033379805
3 changed files with 7 additions and 5 deletions

View File

@ -386,6 +386,7 @@ def main():
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
@ -439,7 +440,7 @@ def main():
else:
G = None
model = LipNet()
model = LipNet(num_classes=max_token_id+1)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:

View File

@ -4,8 +4,9 @@ import torch.nn as nn
class LipNet(torch.nn.Module):
def __init__(self, dropout_p=0.1):
def __init__(self, num_classes, dropout_p=0.1):
super(LipNet, self).__init__()
self.num_classes = num_classes
self.conv1 = nn.Conv3d(3, 32, (3, 5, 5), (1, 2, 2), (1, 2, 2))
self.pool1 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))
@ -18,7 +19,7 @@ class LipNet(torch.nn.Module):
self.gru1 = nn.GRU(96 * 4 * 8, 256, 1, bidirectional=True)
self.gru2 = nn.GRU(512, 256, 1, bidirectional=True)
self.FC = nn.Linear(512, 28)
self.FC = nn.Linear(512, self.num_classes)
self.dropout_p = dropout_p
self.relu = nn.ReLU(inplace=True)

View File

@ -180,7 +180,6 @@ def get_params() -> AttributeDict:
"anno_path": Path("download/GRID/GRID_align_txt"),
"train_list": Path("download/GRID/unseen_train.txt"),
"vid_padding": 75,
"aud_padding": 200,
"num_workers": 16,
"batch_size": 120,
}
@ -503,13 +502,14 @@ def run(rank, world_size, args):
tb_writer = None
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
model = LipNet()
model = LipNet(num_classes=max_token_id+1)
checkpoints = load_checkpoint_if_available(params=params, model=model)