Add filter to remove too-short and too-long utterances

This commit is contained in:
Yuekai Zhang 2024-01-26 10:18:10 +08:00
parent 1600f7db95
commit bb07b65e45

View File

@ -262,7 +262,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 5000,
"valid_interval": 10000,
"env_info": get_env_info(),
}
)
@ -441,9 +441,6 @@ def compute_loss(
assert feature.ndim == 3
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
# make sure feature T no more than 3000, otherwise cut it
if feature.shape[2] > 3000:
feature = feature[:, :, :3000]
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
@ -822,7 +819,23 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = wenetspeech.train_dataloaders(wenetspeech.train_cuts())
def remove_short_and_long_utt(c: Cut):
    """Decide whether a cut is kept for training, based on its duration.

    Cuts shorter than 1 second or longer than 15 seconds are rejected;
    everything else is kept.

    Caution: the 15.0 upper bound was picked for a reason, see
    ../local/display_manifest_statistics.py. Run that script on your own
    dataset to get its utterance duration distribution and choose a
    threshold that fits it.

    Args:
      c: The cut to examine; only its ``duration`` attribute is read.

    Returns:
      True if the cut should be kept, False if it should be excluded.
    """
    duration = c.duration
    too_short = duration < 1.0
    too_long = duration > 15.0
    # Rejected cuts are dropped silently: no per-cut warning is logged.
    return not (too_short or too_long)
train_dl = wenetspeech.train_dataloaders(wenetspeech.train_cuts(remove_short_and_long_utt))
valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)