mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
refactor branch exchange in cr-ctc (#1954)
This commit is contained in:
parent
021e1a8846
commit
ffb7d05635
@ -210,10 +210,10 @@ class AsrModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Compute consistency regularization loss
|
# Compute consistency regularization loss
|
||||||
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
|
batch_size = ctc_output.shape[0]
|
||||||
exchanged_targets = torch.cat(
|
assert batch_size % 2 == 0, batch_size
|
||||||
[exchanged_targets[1], exchanged_targets[0]], dim=0
|
# exchange: [x1, x2] -> [x2, x1]
|
||||||
) # exchange: [x1, x2] -> [x2, x1]
|
exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
|
||||||
cr_loss = nn.functional.kl_div(
|
cr_loss = nn.functional.kl_div(
|
||||||
input=ctc_output,
|
input=ctc_output,
|
||||||
target=exchanged_targets,
|
target=exchanged_targets,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user