mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
update lipnet-ctc-vsr
This commit is contained in:
parent
1abf255bdd
commit
f033379805
@ -386,6 +386,7 @@ def main():
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
@ -439,7 +440,7 @@ def main():
|
||||
else:
|
||||
G = None
|
||||
|
||||
model = LipNet()
|
||||
model = LipNet(num_classes=max_token_id+1)
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
|
@ -4,8 +4,9 @@ import torch.nn as nn
|
||||
|
||||
|
||||
class LipNet(torch.nn.Module):
|
||||
def __init__(self, dropout_p=0.1):
|
||||
def __init__(self, num_classes, dropout_p=0.1):
|
||||
super(LipNet, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.conv1 = nn.Conv3d(3, 32, (3, 5, 5), (1, 2, 2), (1, 2, 2))
|
||||
self.pool1 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))
|
||||
|
||||
@ -18,7 +19,7 @@ class LipNet(torch.nn.Module):
|
||||
self.gru1 = nn.GRU(96 * 4 * 8, 256, 1, bidirectional=True)
|
||||
self.gru2 = nn.GRU(512, 256, 1, bidirectional=True)
|
||||
|
||||
self.FC = nn.Linear(512, 28)
|
||||
self.FC = nn.Linear(512, self.num_classes)
|
||||
self.dropout_p = dropout_p
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
@ -180,7 +180,6 @@ def get_params() -> AttributeDict:
|
||||
"anno_path": Path("download/GRID/GRID_align_txt"),
|
||||
"train_list": Path("download/GRID/unseen_train.txt"),
|
||||
"vid_padding": 75,
|
||||
"aud_padding": 200,
|
||||
"num_workers": 16,
|
||||
"batch_size": 120,
|
||||
}
|
||||
@ -503,13 +502,14 @@ def run(rank, world_size, args):
|
||||
tb_writer = None
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
|
||||
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
||||
model = LipNet()
|
||||
model = LipNet(num_classes=max_token_id+1)
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user