Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Filter training data where T < S for the WenetSpeech train recipe (#753)
* filter the case of T < S for training data
* fix style issues
* fix style issues
* fix style issues

Co-authored-by: 张云斌 <zhangyunbin@MacBook-Air.local>
parent: 02c18ba4b2
commit: e83409cbe5
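For reference, here is a minimal standalone sketch of the check this commit adds. The subsampling expression and the space-stripped character tokens are taken from the diff below; the helper names, the 100-frame example, and the 10 ms frame-shift figure are illustrative assumptions, not part of the recipe.

```python
# A minimal, self-contained sketch of the T >= S check added in this commit.
# The subsampling expression mirrors the one quoted from ./conformer.py;
# helper names and example numbers are illustrative only.


def frames_after_subsampling(num_frames: int) -> int:
    # The Conformer front end subsamples by roughly a factor of 4.
    return ((num_frames - 1) // 2 - 1) // 2


def keeps_cut(num_frames: int, text: str) -> bool:
    # S is taken as the number of non-space characters in the transcript,
    # matching the predicate in the diff below.
    tokens = text.replace(" ", "")
    return frames_after_subsampling(num_frames) >= len(tokens)


if __name__ == "__main__":
    # 100 feature frames (about 1 s at a typical 10 ms frame shift) give
    # ((100 - 1) // 2 - 1) // 2 = 24 frames after subsampling, so a
    # 30-character transcript is dropped while a 20-character one is kept.
    print(frames_after_subsampling(100))  # 24
    print(keeps_cut(100, "好" * 30))       # False
    print(keeps_cut(100, "好" * 20))       # True
```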
@@ -861,15 +861,41 @@ def run(rank, world_size, args):
     valid_cuts = wenetspeech.valid_cuts()

     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # Keep only utterances with duration between 1 second and 10 seconds
         #
-        # Caution: There is a reason to select 15.0 here. Please see
+        # Caution: There is a reason to select 10.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 15.0
+        if c.duration < 1.0 or c.duration > 10.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
+        tokens = c.supervisions[0].text.replace(" ", "")
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
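The second hunk below applies the identical change at a second call site (around line 1006). Before retraining, it can be handy to estimate how many cuts the new rule drops. The following is a hypothetical pre-check, not part of the commit: the manifest path is a placeholder, and load_manifest_lazy is lhotse's lazy manifest loader.

```python
# Hypothetical pre-check (not part of this commit): count how many cuts a
# manifest would lose under the new duration and T >= S rules.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")  # placeholder path

kept = dropped = 0
for c in cuts:
    T = ((c.num_frames - 1) // 2 - 1) // 2
    tokens = c.supervisions[0].text.replace(" ", "")
    if not (1.0 <= c.duration <= 10.0) or T < len(tokens):
        dropped += 1
    else:
        kept += 1

print(f"kept: {kept}, dropped: {dropped}")
```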
@@ -1006,15 +1006,41 @@ def run(rank, world_size, args):
     valid_cuts = wenetspeech.valid_cuts()

     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # Keep only utterances with duration between 1 second and 10 seconds
         #
-        # Caution: There is a reason to select 15.0 here. Please see
+        # Caution: There is a reason to select 10.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 15.0
+        if c.duration < 1.0 or c.duration > 10.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
+        tokens = c.supervisions[0].text.replace(" ", "")
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
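As the comments note, the 1 s / 10 s duration cut-offs should be chosen from the duration distribution reported by ../local/display_manifest_statistics.py. A quick alternative sketch, assuming a lhotse cuts manifest (placeholder path below), uses CutSet.describe(), which prints duration statistics.

```python
# Hedged sketch: inspect the utterance-duration distribution before choosing
# the duration thresholds, in the spirit of ../local/display_manifest_statistics.py.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")  # placeholder path
cuts.describe()  # prints cut counts, total duration, and duration percentiles
```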