mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Fix librispeech train.py (#211)
* fix librispeech train.py * remove note
This commit is contained in:
parent
be1c86b06c
commit
70a3c56a18
@ -601,14 +601,14 @@ def run(rank, world_size, args):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
if "lang_bpe" in params.lang_dir:
|
if "lang_bpe" in str(params.lang_dir):
|
||||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||||
params.lang_dir,
|
params.lang_dir,
|
||||||
device=device,
|
device=device,
|
||||||
sos_token="<sos/eos>",
|
sos_token="<sos/eos>",
|
||||||
eos_token="<sos/eos>",
|
eos_token="<sos/eos>",
|
||||||
)
|
)
|
||||||
elif "lang_phone" in params.lang_dir:
|
elif "lang_phone" in str(params.lang_dir):
|
||||||
assert params.att_rate == 0, (
|
assert params.att_rate == 0, (
|
||||||
"Attention decoder training does not support phone lang dirs "
|
"Attention decoder training does not support phone lang dirs "
|
||||||
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
|
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
|
||||||
@ -650,9 +650,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
# Note: find_unused_parameters=True is needed in case we
|
model = DDP(model, device_ids=[rank])
|
||||||
# want to set params.att_rate = 0 (i.e. att decoder is not trained)
|
|
||||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
|
||||||
|
|
||||||
optimizer = Noam(
|
optimizer = Noam(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user