commit 17768da017 (parent 9f94984dbb)

    update comments
@@ -197,7 +197,7 @@ class HiMiaWuwDataModule:
             default="_7_01",
             help="""channel of HI_MIA train dataset.
             All channels are used if it is set "all".
-            Please refer to state 6 in prepare.sh for its meaning and other
+            Please refer to stage 6 in prepare.sh for its meaning and other
             potential values. Currently, Only "_7_01" is verified.
             """,
         )
@@ -207,7 +207,7 @@ class HiMiaWuwDataModule:
             default="_7_01",
             help="""channel of HI_MIA dev dataset.
             All channels are used if it is set "all".
-            Please refer to state 6 in prepare.sh for its meaning and other
+            Please refer to stage 6 in prepare.sh for its meaning and other
             potential values. Currently, Only "_7_01" is verified.
             """,
         )
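For context, the fragments in the two hunks above belong to an argparse add_argument call. A minimal standalone sketch, assuming a hypothetical --train-channel flag (only the default and help text are taken from the diff):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train-channel",  # hypothetical flag name; not shown in the diff
        type=str,
        default="_7_01",
        help="""channel of HI_MIA train dataset.
        All channels are used if it is set "all".
        Please refer to stage 6 in prepare.sh for its meaning and other
        potential values. Currently, Only "_7_01" is verified.
        """,
    )

    args = parser.parse_args([])
    print(args.train_channel)  # _7_01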
@@ -19,7 +19,7 @@
 from typing import List
 
 
-def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
+def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str:
     """
     A graph starts with blank/unknown and following by wakeup word.
 
@@ -27,6 +27,10 @@ def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
       wakeup_word_tokens: A sequence of token ids corresponding wakeup_word.
         It should not contain 0 and 1.
         We assume 0 is for blank and 1 is for unknown.
+
+    Returns:
+      Returns a finite-state transducer in string format,
+      used as a decoding graph.
+      Arcs are separated with "\n".
     """
     assert 0 not in wakeup_word_tokens
     assert 1 not in wakeup_word_tokens
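The hunks above only touch the signature and docstring of ctc_trivial_decoding_graph; its body is not part of this commit. A sketch of how such a string-format graph could be assembled, assuming k2's text format ("src dst label score" per arc, final state on a line of its own) and an illustrative self-loop layout:

    from typing import List

    def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str:
        # Sketch only; the actual arcs in the commit are not shown here.
        assert 0 not in wakeup_word_tokens
        assert 1 not in wakeup_word_tokens
        arcs = []
        # State 0 absorbs blank (0) and unknown (1) before the wakeup word.
        arcs.append("0 0 0 0")
        arcs.append("0 0 1 0")
        cur = 0
        for token in wakeup_word_tokens:
            arcs.append(f"{cur} {cur + 1} {token} 0")
            cur += 1
            arcs.append(f"{cur} {cur} {token} 0")  # CTC repeats of the token
            arcs.append(f"{cur} {cur} 0 0")        # blanks between tokens
        arcs.append(f"{cur} {cur + 1} -1 0")       # arc into the final state
        arcs.append(f"{cur + 1}")                  # final state, by itself
        return "\n".join(arcs)

    print(ctc_trivial_decoding_graph([2, 3, 4, 5]))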
@@ -23,15 +23,15 @@ from typing import List, Tuple
 class WakeupWordTokenizer(object):
     def __init__(
         self,
-        wakeup_word: str = "",
-        wakeup_word_tokens: List[int] = None,
+        wakeup_word: str,
+        wakeup_word_tokens: List[int],
     ) -> None:
         """
         Args:
           wakeup_word: content of positive samples.
-            A sample will be treated as a negative sample unless its context
+            A sample will be treated as a negative sample unless its content
             is exactly the same to key_words.
-          wakeup_word_tokens: A list if int represents token ids of wakeup_word.
+          wakeup_word_tokens: A list of int representing token ids of wakeup_word.
             For example: the pronunciation of "你好米雅" is
             "n i h ao m i y a".
             Suppose we are using following lexicon:
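The lexicon the docstring refers to is truncated by the hunk above. With a hypothetical mapping (the real token ids are not shown here), the example pronunciation would tokenize as follows; note the ids avoid 0 (blank) and 1 (unknown):

    # Hypothetical phone-to-id lexicon; illustrative only.
    lexicon = {"n": 2, "i": 3, "h": 4, "ao": 5, "m": 6, "y": 7, "a": 8}

    pronunciation = "n i h ao m i y a".split()
    wakeup_word_tokens = [lexicon[p] for p in pronunciation]
    print(wakeup_word_tokens)  # [2, 3, 4, 5, 6, 3, 7, 8]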
@@ -67,7 +67,7 @@ class WakeupWordTokenizer(object):
     def texts_to_token_ids(
         self, texts: List[str]
     ) -> Tuple[torch.Tensor, torch.Tensor, int]:
-        """Convert a list of texts to a list of k2.Fsa based texts.
+        """Convert a list of texts to parameters needed by CTC loss.
 
         Args:
           texts:
@@ -76,7 +76,7 @@ class WakeupWordTokenizer(object):
         Returns:
           Return a tuple of 3 elements.
           The first one is torch.Tensor(List[List[int]]),
-          each List[int] is tokens sequence for each a reference text.
+          each List[int] is tokens sequence for each reference text.
 
           The second one is number of tokens for each sample,
           mainly used by CTC loss.
@@ -89,13 +89,13 @@ class WakeupWordTokenizer(object):
         number_positive_samples = 0
         for utt_text in texts:
             if utt_text == self.wakeup_word:
-                batch_token_ids.append(self.wakeup_word_tokens)
+                batch_token_ids.extend(self.wakeup_word_tokens)
                 target_lengths.append(self.positive_number_tokens)
                 number_positive_samples += 1
             else:
-                batch_token_ids.append(self.negative_word_tokens)
+                batch_token_ids.extend(self.negative_word_tokens)
                 target_lengths.append(self.negative_number_tokens)
 
-        target = torch.tensor(list(itertools.chain.from_iterable(batch_token_ids)))
+        target = torch.tensor(batch_token_ids)
         target_lengths = torch.tensor(target_lengths)
         return target, target_lengths, number_positive_samples
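The append-to-extend change flattens batch_token_ids into a single list of ints, which is why the itertools.chain call becomes unnecessary: torch.nn.CTCLoss accepts targets as one concatenated 1-D tensor alongside per-sample target_lengths. A standalone sketch with illustrative token ids and texts:

    import torch

    wakeup_word_tokens = [2, 3, 4, 5]   # hypothetical token ids
    negative_word_tokens = [1]          # "unknown"

    texts = ["你好米雅", "something else"]
    batch_token_ids, target_lengths = [], []
    for utt_text in texts:
        if utt_text == "你好米雅":
            batch_token_ids.extend(wakeup_word_tokens)
            target_lengths.append(len(wakeup_word_tokens))
        else:
            batch_token_ids.extend(negative_word_tokens)
            target_lengths.append(len(negative_word_tokens))

    target = torch.tensor(batch_token_ids)          # shape: (sum(target_lengths),)
    target_lengths = torch.tensor(target_lengths)   # shape: (batch,)

    # log_probs must be (T, N, C) for CTCLoss; random values for the sketch.
    T, N, C = 50, len(texts), 10
    log_probs = torch.randn(T, N, C).log_softmax(-1)
    input_lengths = torch.full((N,), T, dtype=torch.long)
    loss = torch.nn.CTCLoss(blank=0)(log_probs, target, input_lengths, target_lengths)
    print(loss)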
@@ -531,6 +531,9 @@ def run(rank, world_size, args):
 
     model = Tdnn(params.feature_dim, params.num_class)
 
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
     checkpoints = load_checkpoint_if_available(params=params, model=model)
 
     model.to(device)
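The added lines use the standard numel idiom for counting parameters. A quick check on a stand-in module (Tdnn itself is not constructed here):

    import torch

    model = torch.nn.Linear(40, 10)  # stand-in for Tdnn(params.feature_dim, params.num_class)
    num_param = sum([p.numel() for p in model.parameters()])
    print(num_param)  # 40 * 10 weights + 10 biases = 410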
|
Loading…
x
Reference in New Issue
Block a user