diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index cb0bd5c2d..e77327146 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -601,14 +601,14 @@ def run(rank, world_size, args): if torch.cuda.is_available(): device = torch.device("cuda", rank) - if "lang_bpe" in params.lang_dir: + if "lang_bpe" in str(params.lang_dir): graph_compiler = BpeCtcTrainingGraphCompiler( params.lang_dir, device=device, sos_token="", eos_token="", ) - elif "lang_phone" in params.lang_dir: + elif "lang_phone" in str(params.lang_dir): assert params.att_rate == 0, ( "Attention decoder training does not support phone lang dirs " "at this time due to a missing symbol. Set --att-rate=0 " @@ -652,7 +652,7 @@ def run(rank, world_size, args): if world_size > 1: # Note: find_unused_parameters=True is needed in case we # want to set params.att_rate = 0 (i.e. att decoder is not trained) - model = DDP(model, device_ids=[rank], find_unused_parameters=True) + model = DDP(model, device_ids=[rank]) optimizer = Noam( model.parameters(),