refactor branch exchange in cr-ctc (#1954)

This commit is contained in:
Zengwei Yao 2025-05-27 12:09:59 +08:00 committed by GitHub
parent 021e1a8846
commit ffb7d05635
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -210,10 +210,10 @@ class AsrModel(nn.Module):
)
# Compute consistency regularization loss
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
exchanged_targets = torch.cat(
[exchanged_targets[1], exchanged_targets[0]], dim=0
) # exchange: [x1, x2] -> [x2, x1]
batch_size = ctc_output.shape[0]
assert batch_size % 2 == 0, batch_size
# exchange: [x1, x2] -> [x2, x1]
exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
cr_loss = nn.functional.kl_div(
input=ctc_output,
target=exchanged_targets,