mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Merge remote-tracking branch 'dan/master' into dataset-parallel-augmentation-example
This commit is contained in:
commit
516696f3e4
@ -383,3 +383,7 @@ Please see: [:
|
|||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="""
|
help="""
|
||||||
Whether to adapt. If true, we will mix 5% of the new data
|
Whether to adapt. If true, we will mix 5%% of the new data
|
||||||
with 95% of the original data to fine-tune. This is useful
|
with 95%% of the original data to fine-tune. This is useful
|
||||||
if you want to maintain the performance on the original domain
|
if you want to maintain the performance on the original domain
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
@ -210,10 +210,10 @@ class AsrModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Compute consistency regularization loss
|
# Compute consistency regularization loss
|
||||||
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
|
batch_size = ctc_output.shape[0]
|
||||||
exchanged_targets = torch.cat(
|
assert batch_size % 2 == 0, batch_size
|
||||||
[exchanged_targets[1], exchanged_targets[0]], dim=0
|
# exchange: [x1, x2] -> [x2, x1]
|
||||||
) # exchange: [x1, x2] -> [x2, x1]
|
exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
|
||||||
cr_loss = nn.functional.kl_div(
|
cr_loss = nn.functional.kl_div(
|
||||||
input=ctc_output,
|
input=ctc_output,
|
||||||
target=exchanged_targets,
|
target=exchanged_targets,
|
||||||
|
@ -186,7 +186,7 @@ class AttributeDict(dict):
|
|||||||
tmp = {}
|
tmp = {}
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
# PosixPath is ont JSON serializable
|
# PosixPath is ont JSON serializable
|
||||||
if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
|
if isinstance(v, (pathlib.Path, torch.device, torch.dtype)):
|
||||||
v = str(v)
|
v = str(v)
|
||||||
tmp[k] = v
|
tmp[k] = v
|
||||||
return json.dumps(tmp, indent=indent, sort_keys=True)
|
return json.dumps(tmp, indent=indent, sort_keys=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user