diff --git a/egs/ljspeech/tts/vits/train.py b/egs/ljspeech/tts/vits/train.py index c8df3c5d0..1a2c934fe 100755 --- a/egs/ljspeech/tts/vits/train.py +++ b/egs/ljspeech/tts/vits/train.py @@ -141,7 +141,9 @@ def get_parser(): help="""Save checkpoint after processing this number of epochs" periodically. We save checkpoint to exp-dir/ whenever params.cur_epoch % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/epoch-{params.cur_epoch}.pt' + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. """, ) @@ -836,7 +838,7 @@ def run(rank, world_size, args): diagnostic.print_diagnostics() break - if epoch % params.save_every_n == 0: + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" save_checkpoint( filename=filename,