mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add remove long short
This commit is contained in:
parent
1600f7db95
commit
bb07b65e45
@ -262,7 +262,7 @@ def get_params() -> AttributeDict:
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 5000,
|
||||
"valid_interval": 10000,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -441,9 +441,6 @@ def compute_loss(
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
feature = feature.transpose(1, 2) # (N, C, T)
|
||||
# make sure feature T no more than 3000, otherwise cut it
|
||||
if feature.shape[2] > 3000:
|
||||
feature = feature[:, :, :3000]
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
@ -822,7 +819,23 @@ def run(rank, world_size, args):
|
||||
else:
|
||||
sampler_state_dict = None
|
||||
|
||||
train_dl = wenetspeech.train_dataloaders(wenetspeech.train_cuts())
|
||||
def remove_short_and_long_utt(c: Cut):
    """Return True when cut ``c`` should be kept for training.

    A cut is kept unless its duration is below 1 second or above
    15 seconds.

    Caution: the 15.0 upper bound was chosen for a reason; see
    ../local/display_manifest_statistics.py. Run that script on your own
    dataset to obtain its utterance duration distribution and pick a
    suitable threshold.
    """
    # Too short: reject.
    if c.duration < 1.0:
        return False
    # Too long: reject.
    if c.duration > 15.0:
        return False
    # To debug which cuts are dropped, log c.id and c.duration in the
    # two rejecting branches above.
    return True
|
||||
|
||||
train_dl = wenetspeech.train_dataloaders(wenetspeech.train_cuts(remove_short_and_long_utt))
|
||||
valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user