remove find_unused_parameters=True and use bypass module

This commit is contained in:
Daniel Povey 2023-05-29 11:54:21 +08:00
parent 38246c8690
commit 16e51a7deb
2 changed files with 4 additions and 3 deletions

View File

@ -275,7 +275,7 @@ class Subformer(EncoderInterface):
weights))
else:
assert s == ')' # upsample
assert s == ')' # upsample and bypass
indexes, weights, x_orig = downsample_info.pop()
_attn_offset = attn_offsets.pop()
_pos_emb = pos_embs.pop()
@ -283,6 +283,8 @@ class Subformer(EncoderInterface):
x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)
bypass = self.bypasses[i]
x = bypass(x_orig, x)
# d = self.output_downsampling_factor
# lengths = (x_lens + d - 1) // d

View File

@ -953,8 +953,7 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
model = DDP(model, device_ids=[rank])
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(