mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
update lipnet-ctc-vsr
This commit is contained in:
parent
1abf255bdd
commit
f033379805
@ -386,6 +386,7 @@ def main():
|
|||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -439,7 +440,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
G = None
|
G = None
|
||||||
|
|
||||||
model = LipNet()
|
model = LipNet(num_classes=max_token_id+1)
|
||||||
if params.avg == 1:
|
if params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
|
@ -4,8 +4,9 @@ import torch.nn as nn
|
|||||||
|
|
||||||
|
|
||||||
class LipNet(torch.nn.Module):
|
class LipNet(torch.nn.Module):
|
||||||
def __init__(self, dropout_p=0.1):
|
def __init__(self, num_classes, dropout_p=0.1):
|
||||||
super(LipNet, self).__init__()
|
super(LipNet, self).__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
self.conv1 = nn.Conv3d(3, 32, (3, 5, 5), (1, 2, 2), (1, 2, 2))
|
self.conv1 = nn.Conv3d(3, 32, (3, 5, 5), (1, 2, 2), (1, 2, 2))
|
||||||
self.pool1 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))
|
self.pool1 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ class LipNet(torch.nn.Module):
|
|||||||
self.gru1 = nn.GRU(96 * 4 * 8, 256, 1, bidirectional=True)
|
self.gru1 = nn.GRU(96 * 4 * 8, 256, 1, bidirectional=True)
|
||||||
self.gru2 = nn.GRU(512, 256, 1, bidirectional=True)
|
self.gru2 = nn.GRU(512, 256, 1, bidirectional=True)
|
||||||
|
|
||||||
self.FC = nn.Linear(512, 28)
|
self.FC = nn.Linear(512, self.num_classes)
|
||||||
self.dropout_p = dropout_p
|
self.dropout_p = dropout_p
|
||||||
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
@ -180,7 +180,6 @@ def get_params() -> AttributeDict:
|
|||||||
"anno_path": Path("download/GRID/GRID_align_txt"),
|
"anno_path": Path("download/GRID/GRID_align_txt"),
|
||||||
"train_list": Path("download/GRID/unseen_train.txt"),
|
"train_list": Path("download/GRID/unseen_train.txt"),
|
||||||
"vid_padding": 75,
|
"vid_padding": 75,
|
||||||
"aud_padding": 200,
|
|
||||||
"num_workers": 16,
|
"num_workers": 16,
|
||||||
"batch_size": 120,
|
"batch_size": 120,
|
||||||
}
|
}
|
||||||
@ -503,13 +502,14 @@ def run(rank, world_size, args):
|
|||||||
tb_writer = None
|
tb_writer = None
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
||||||
model = LipNet()
|
model = LipNet(num_classes=max_token_id+1)
|
||||||
|
|
||||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user