mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
refactor branch exchange in cr-ctc (#1954)
This commit is contained in:
parent
021e1a8846
commit
ffb7d05635
@ -210,10 +210,10 @@ class AsrModel(nn.Module):
|
||||
)
|
||||
|
||||
# Compute consistency regularization loss
|
||||
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
|
||||
exchanged_targets = torch.cat(
|
||||
[exchanged_targets[1], exchanged_targets[0]], dim=0
|
||||
) # exchange: [x1, x2] -> [x2, x1]
|
||||
batch_size = ctc_output.shape[0]
|
||||
assert batch_size % 2 == 0, batch_size
|
||||
# exchange: [x1, x2] -> [x2, x1]
|
||||
exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
|
||||
cr_loss = nn.functional.kl_div(
|
||||
input=ctc_output,
|
||||
target=exchanged_targets,
|
||||
|
Loading…
x
Reference in New Issue
Block a user