mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-11 11:02:29 +00:00
update comments
This commit is contained in:
parent
9f94984dbb
commit
17768da017
@ -197,7 +197,7 @@ class HiMiaWuwDataModule:
|
||||
default="_7_01",
|
||||
help="""channel of HI_MIA train dataset.
|
||||
All channels are used if it is set "all".
|
||||
Please refer to state 6 in prepare.sh for its meaning and other
|
||||
Please refer to stage 6 in prepare.sh for its meaning and other
|
||||
potential values. Currently, Only "_7_01" is verified.
|
||||
""",
|
||||
)
|
||||
@ -207,7 +207,7 @@ class HiMiaWuwDataModule:
|
||||
default="_7_01",
|
||||
help="""channel of HI_MIA dev dataset.
|
||||
All channels are used if it is set "all".
|
||||
Please refer to state 6 in prepare.sh for its meaning and other
|
||||
Please refer to stage 6 in prepare.sh for its meaning and other
|
||||
potential values. Currently, Only "_7_01" is verified.
|
||||
""",
|
||||
)
|
||||
|
@ -19,7 +19,7 @@
|
||||
from typing import List
|
||||
|
||||
|
||||
def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
|
||||
def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str:
|
||||
"""
|
||||
A graph that starts with blank/unknown, followed by the wakeup word.
|
||||
|
||||
@ -27,6 +27,10 @@ def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
|
||||
wakeup_word_tokens: A sequence of token ids corresponding to wakeup_word.
|
||||
It should not contain 0 and 1.
|
||||
We assume 0 is for blank and 1 is for unknown.
|
||||
Returns:
|
||||
Returns a finite-state transducer in string format,
|
||||
used as a decoding graph.
|
||||
Arcs are separated with "\n".
|
||||
"""
|
||||
assert 0 not in wakeup_word_tokens
|
||||
assert 1 not in wakeup_word_tokens
|
||||
|
@ -23,15 +23,15 @@ from typing import List, Tuple
|
||||
class WakeupWordTokenizer(object):
|
||||
def __init__(
|
||||
self,
|
||||
wakeup_word: str = "",
|
||||
wakeup_word_tokens: List[int] = None,
|
||||
wakeup_word: str,
|
||||
wakeup_word_tokens: List[int],
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
wakeup_word: content of positive samples.
|
||||
A sample will be treated as a negative sample unless its context
|
||||
A sample will be treated as a negative sample unless its content
|
||||
is exactly the same as wakeup_word.
|
||||
wakeup_word_tokens: A list if int represents token ids of wakeup_word.
|
||||
wakeup_word_tokens: A list of int representing token ids of wakeup_word.
|
||||
For example: the pronunciation of "你好米雅" is
|
||||
"n i h ao m i y a".
|
||||
Suppose we are using following lexicon:
|
||||
@ -67,7 +67,7 @@ class WakeupWordTokenizer(object):
|
||||
def texts_to_token_ids(
|
||||
self, texts: List[str]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
"""Convert a list of texts to a list of k2.Fsa based texts.
|
||||
"""Convert a list of texts to parameters needed by CTC loss.
|
||||
|
||||
Args:
|
||||
texts:
|
||||
@ -76,7 +76,7 @@ class WakeupWordTokenizer(object):
|
||||
Returns:
|
||||
Return a tuple of 3 elements.
|
||||
The first one is torch.Tensor(List[List[int]]),
|
||||
each List[int] is tokens sequence for each a reference text.
|
||||
each List[int] is tokens sequence for each reference text.
|
||||
|
||||
The second one is number of tokens for each sample,
|
||||
mainly used by CTC loss.
|
||||
@ -89,13 +89,13 @@ class WakeupWordTokenizer(object):
|
||||
number_positive_samples = 0
|
||||
for utt_text in texts:
|
||||
if utt_text == self.wakeup_word:
|
||||
batch_token_ids.append(self.wakeup_word_tokens)
|
||||
batch_token_ids.extend(self.wakeup_word_tokens)
|
||||
target_lengths.append(self.positive_number_tokens)
|
||||
number_positive_samples += 1
|
||||
else:
|
||||
batch_token_ids.append(self.negative_word_tokens)
|
||||
batch_token_ids.extend(self.negative_word_tokens)
|
||||
target_lengths.append(self.negative_number_tokens)
|
||||
|
||||
target = torch.tensor(list(itertools.chain.from_iterable(batch_token_ids)))
|
||||
target = torch.tensor(batch_token_ids)
|
||||
target_lengths = torch.tensor(target_lengths)
|
||||
return target, target_lengths, number_positive_samples
|
||||
|
@ -531,6 +531,9 @@ def run(rank, world_size, args):
|
||||
|
||||
model = Tdnn(params.feature_dim, params.num_class)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user