Merge remote-tracking branch 'dan/master' into dataset-parallel-augmentation-example

Fangjun Kuang 2025-05-29 17:04:50 +08:00
commit 516696f3e4
4 changed files with 12 additions and 8 deletions


@@ -383,3 +383,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
 [vctk]: egs/vctk/TTS
 [ljspeech]: egs/ljspeech/TTS
 [libritts_tts]: egs/libritts/TTS
+
+## Acknowledgements
+
+Some contributors to this project were supported by Xiaomi Corporation. Others were supported by National Science Foundation CCRI award 2120435. This is not an exhaustive list of sources of support.


@@ -140,8 +140,8 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
         type=str2bool,
         default=False,
         help="""
-        Whether to adapt. If true, we will mix 5% of the new data
-        with 95% of the original data to fine-tune. This is useful
+        Whether to adapt. If true, we will mix 5%% of the new data
+        with 95%% of the original data to fine-tune. This is useful
         if you want to maintain the performance on the original domain
         """,
     )
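
The doubled percent signs are needed because argparse passes help text through %-style string formatting (for fields such as %(default)s), so a bare % can raise a ValueError once --help is rendered. A minimal sketch, using a hypothetical --do-adapt flag in place of the real argument:

```python
import argparse

parser = argparse.ArgumentParser()
# argparse expands "%(...)s" fields in help text with %-formatting, so a
# literal percent sign must be escaped as "%%"; an unescaped "5%" here
# would raise "ValueError: unsupported format character" at render time.
parser.add_argument(
    "--do-adapt",  # hypothetical flag name
    action="store_true",
    help="Mix 5%% of the new data with 95%% of the original data "
    "(default: %(default)s).",
)
print(parser.format_help())  # each "%%" renders as a single "%"
```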
@@ -1134,7 +1134,7 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, "
+                f"lr: {cur_lr: .2e}, "
                 + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
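
The added space in {cur_lr: .2e} is Python's sign-aware ' ' format flag: it reserves one leading character for a sign on non-negative numbers, so log columns stay aligned if a value is ever negative. A quick illustration:

```python
lr = 3.25e-04
print(f"lr: {lr:.2e}")    # lr: 3.25e-04
print(f"lr: {lr: .2e}")   # lr:  3.25e-04  (space reserved for a sign)
print(f"lr: {-lr: .2e}")  # lr: -3.25e-04  (lines up with the line above)
```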


@@ -210,10 +210,10 @@ class AsrModel(nn.Module):
             )
             # Compute consistency regularization loss
-            exchanged_targets = ctc_output.detach().chunk(2, dim=0)
-            exchanged_targets = torch.cat(
-                [exchanged_targets[1], exchanged_targets[0]], dim=0
-            )  # exchange: [x1, x2] -> [x2, x1]
+            batch_size = ctc_output.shape[0]
+            assert batch_size % 2 == 0, batch_size
+            # exchange: [x1, x2] -> [x2, x1]
+            exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
             cr_loss = nn.functional.kl_div(
                 input=ctc_output,
                 target=exchanged_targets,
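
The new code collapses the chunk-then-cat swap into a single torch.roll: shifting the tensor by half its batch size along dim 0 moves the second half in front of the first. A standalone sketch (toy shapes, not the model's real ctc_output) checking that the two forms agree:

```python
import torch

x = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # batch of 4: halves [x1, x2]

# Old approach: split into two halves, concatenate in swapped order.
halves = x.chunk(2, dim=0)
swapped_cat = torch.cat([halves[1], halves[0]], dim=0)

# New approach: one roll by half the batch size along dim 0.
swapped_roll = torch.roll(x, x.shape[0] // 2, dims=0)

assert torch.equal(swapped_cat, swapped_roll)  # both realize [x1, x2] -> [x2, x1]
```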


@@ -186,7 +186,7 @@ class AttributeDict(dict):
         tmp = {}
         for k, v in self.items():
             # PosixPath is not JSON serializable
-            if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
+            if isinstance(v, (pathlib.Path, torch.device, torch.dtype)):
                 v = str(v)
             tmp[k] = v
         return json.dumps(tmp, indent=indent, sort_keys=True)
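
Adding torch.dtype to the tuple matters because json.dumps raises TypeError on any torch object; converting to str first is what keeps the dump working. A minimal reproduction (a plain dict standing in for the real AttributeDict):

```python
import json
import pathlib

import torch

params = {
    "exp_dir": pathlib.Path("exp/ctc"),
    "device": torch.device("cpu"),
    "dtype": torch.float16,
}
# json.dumps(params) would raise:
#   TypeError: Object of type PosixPath is not JSON serializable
safe = {
    k: str(v) if isinstance(v, (pathlib.Path, torch.device, torch.dtype)) else v
    for k, v in params.items()
}
print(json.dumps(safe, indent=2))  # all three values serialize as plain strings
```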