Fix librispeech train.py (#211)

* fix librispeech train.py

* remove note
Author: Wang, Guanbo (committed by GitHub)
Date:   2022-02-09 03:42:28 -05:00
Commit: 70a3c56a18
Parent: be1c86b06c


@@ -601,14 +601,14 @@ def run(rank, world_size, args):
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
 
-    if "lang_bpe" in params.lang_dir:
+    if "lang_bpe" in str(params.lang_dir):
         graph_compiler = BpeCtcTrainingGraphCompiler(
             params.lang_dir,
             device=device,
             sos_token="<sos/eos>",
             eos_token="<sos/eos>",
         )
-    elif "lang_phone" in params.lang_dir:
+    elif "lang_phone" in str(params.lang_dir):
         assert params.att_rate == 0, (
             "Attention decoder training does not support phone lang dirs "
             "at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
@@ -650,9 +650,7 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
-        # Note: find_unused_parameters=True is needed in case we
-        # want to set params.att_rate = 0 (i.e. att decoder is not trained)
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank])
 
     optimizer = Noam(
         model.parameters(),
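The second hunk drops find_unused_parameters=True from the DDP wrapper along with the note that justified it. Background, from the documented behavior of torch.nn.parallel.DistributedDataParallel: with the default find_unused_parameters=False, DDP expects every parameter to receive a gradient in each backward pass and raises a RuntimeError when one does not (for example, an attention decoder branch skipped because att_rate == 0); setting it to True tolerates that, but walks the autograd graph on every iteration to locate the unused parameters, which adds overhead. A minimal, self-contained sketch of the API in question; the single-process gloo setup and toy Linear model are illustrative, not from train.py:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group, just to make the sketch runnable.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)

# The default find_unused_parameters=False assumes every parameter gets a
# gradient in each backward pass; pass find_unused_parameters=True only when
# part of the model is intentionally skipped, at the cost of a per-iteration
# graph traversal.
ddp_model = DDP(model)  # on CPU with gloo, device_ids is omitted

dist.destroy_process_group()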