mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
update lipnet-ctc-vsr for grid
This commit is contained in:
parent
f033379805
commit
149ccd1b85
@ -440,7 +440,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
G = None
|
G = None
|
||||||
|
|
||||||
model = LipNet(num_classes=max_token_id+1)
|
model = LipNet(num_classes=max_token_id + 1)
|
||||||
if params.avg == 1:
|
if params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
|
@ -509,7 +509,7 @@ def run(rank, world_size, args):
|
|||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
||||||
model = LipNet(num_classes=max_token_id+1)
|
model = LipNet(num_classes=max_token_id + 1)
|
||||||
|
|
||||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user