mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
remove find_unused_parameters=True and use bypass module
This commit is contained in:
parent
38246c8690
commit
16e51a7deb
@ -275,7 +275,7 @@ class Subformer(EncoderInterface):
|
|||||||
weights))
|
weights))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert s == ')' # upsample
|
assert s == ')' # upsample and bypass
|
||||||
indexes, weights, x_orig = downsample_info.pop()
|
indexes, weights, x_orig = downsample_info.pop()
|
||||||
_attn_offset = attn_offsets.pop()
|
_attn_offset = attn_offsets.pop()
|
||||||
_pos_emb = pos_embs.pop()
|
_pos_emb = pos_embs.pop()
|
||||||
@ -283,6 +283,8 @@ class Subformer(EncoderInterface):
|
|||||||
|
|
||||||
x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)
|
x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)
|
||||||
|
|
||||||
|
bypass = self.bypasses[i]
|
||||||
|
x = bypass(x_orig, x)
|
||||||
|
|
||||||
# d = self.output_downsampling_factor
|
# d = self.output_downsampling_factor
|
||||||
# lengths = (x_lens + d - 1) // d
|
# lengths = (x_lens + d - 1) // d
|
||||||
|
|||||||
@ -953,8 +953,7 @@ def run(rank, world_size, args):
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
logging.info("Using DDP")
|
logging.info("Using DDP")
|
||||||
model = DDP(model, device_ids=[rank],
|
model = DDP(model, device_ids=[rank])
|
||||||
find_unused_parameters=True)
|
|
||||||
|
|
||||||
optimizer = ScaledAdam(
|
optimizer = ScaledAdam(
|
||||||
get_parameter_groups_with_lrs(
|
get_parameter_groups_with_lrs(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user