update comments

glynpu 2023-03-17 14:26:50 +08:00
parent 9f94984dbb
commit 17768da017
4 changed files with 19 additions and 12 deletions


@@ -197,7 +197,7 @@ class HiMiaWuwDataModule:
default="_7_01",
help="""channel of HI_MIA train dataset.
All channels are used if it is set "all".
- Please refer to state 6 in prepare.sh for its meaning and other
+ Please refer to stage 6 in prepare.sh for its meaning and other
potential values. Currently, Only "_7_01" is verified.
""",
)
@@ -207,7 +207,7 @@ class HiMiaWuwDataModule:
default="_7_01",
help="""channel of HI_MIA dev dataset.
All channels are used if it is set "all".
- Please refer to state 6 in prepare.sh for its meaning and other
+ Please refer to stage 6 in prepare.sh for its meaning and other
potential values. Currently, Only "_7_01" is verified.
""",
)
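Both hunks only touch the help text ("state 6" becomes "stage 6"). For illustration, a minimal argparse sketch of how such a channel option could be declared; the flag name --himia-train-channel below is a made-up name for this example, not necessarily the option used in the recipe.

import argparse

# Sketch only: the flag name is an assumption for illustration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--himia-train-channel",
    type=str,
    default="_7_01",
    help='Channel of the HI_MIA train set; "all" selects every channel, '
    "see stage 6 in prepare.sh for other valid values.",
)
args = parser.parse_args(["--himia-train-channel", "all"])
print(args.himia_train_channel)  # -> all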


@@ -19,7 +19,7 @@
from typing import List
- def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
+ def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str:
"""
A graph starts with blank/unknown and following by wakeup word.
@@ -27,6 +27,10 @@ def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
wakeup_word_tokens: A sequence of token ids corresponding wakeup_word.
It should not contain 0 and 1.
We assume 0 is for blank and 1 is for unknown.
+ Returns:
+ Returns a finite-state transducer in string format,
+ used as a decoding graph.
+ Arcs are separated with "\n".
"""
assert 0 not in wakeup_word_tokens
assert 1 not in wakeup_word_tokens
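The new Returns section documents that the function emits the decoding graph as a plain string, one arc per line. As a hedged illustration of what such a string can look like (not the repo's actual topology), the sketch below writes arcs in k2-style text format, "src dst label score" per line, with -1 labelling the arc into the final state and the final state alone on the last line.

from typing import List


def ctc_trivial_decoding_graph_sketch(wakeup_word_tokens: List[int]) -> str:
    """Sketch only: a linear graph that idles on blank (0) or unknown (1)
    and then accepts the wakeup word, with the usual CTC self-loops for
    blanks and repeated tokens."""
    assert 0 not in wakeup_word_tokens
    assert 1 not in wakeup_word_tokens

    arcs = ["0 0 0 0", "0 0 1 0"]  # stay in the start state on blank / unknown
    cur = 0
    for token in wakeup_word_tokens:
        nxt = cur + 1
        arcs.append(f"{cur} {nxt} {token} 0")  # consume the next token
        arcs.append(f"{nxt} {nxt} 0 0")        # optional blank after it
        arcs.append(f"{nxt} {nxt} {token} 0")  # CTC repeats of the token
        cur = nxt
    final = cur + 1
    arcs.append(f"{cur} {final} -1 0")  # -1 labels the arc to the final state
    arcs.append(str(final))
    return "\n".join(arcs)


print(ctc_trivial_decoding_graph_sketch([2, 3, 4, 5]))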


@@ -23,15 +23,15 @@ from typing import List, Tuple
class WakeupWordTokenizer(object):
def __init__(
self,
- wakeup_word: str = "",
- wakeup_word_tokens: List[int] = None,
+ wakeup_word: str,
+ wakeup_word_tokens: List[int],
) -> None:
"""
Args:
wakeup_word: content of positive samples.
- A sample will be treated as a negative sample unless its context
+ A sample will be treated as a negative sample unless its content
is exactly the same to key_words.
- wakeup_word_tokens: A list if int represents token ids of wakeup_word.
+ wakeup_word_tokens: A list of int representing token ids of wakeup_word.
For example: the pronunciation of "你好米雅" is
"n i h ao m i y a".
Suppose we are using following lexicon:
@@ -67,7 +67,7 @@ class WakeupWordTokenizer(object):
def texts_to_token_ids(
self, texts: List[str]
) -> Tuple[torch.Tensor, torch.Tensor, int]:
- """Convert a list of texts to a list of k2.Fsa based texts.
+ """Convert a list of texts to parameters needed by CTC loss.
Args:
texts:
@@ -76,7 +76,7 @@ class WakeupWordTokenizer(object):
Returns:
Return a tuple of 3 elements.
The first one is torch.Tensor(List[List[int]]),
- each List[int] is tokens sequence for each a reference text.
+ each List[int] is tokens sequence for each reference text.
The second one is number of tokens for each sample,
mainly used by CTC loss.
@@ -89,13 +89,13 @@ class WakeupWordTokenizer(object):
number_positive_samples = 0
for utt_text in texts:
if utt_text == self.wakeup_word:
- batch_token_ids.append(self.wakeup_word_tokens)
+ batch_token_ids.extend(self.wakeup_word_tokens)
target_lengths.append(self.positive_number_tokens)
number_positive_samples += 1
else:
- batch_token_ids.append(self.negative_word_tokens)
+ batch_token_ids.extend(self.negative_word_tokens)
target_lengths.append(self.negative_number_tokens)
- target = torch.tensor(list(itertools.chain.from_iterable(batch_token_ids)))
+ target = torch.tensor(batch_token_ids)
target_lengths = torch.tensor(target_lengths)
return target, target_lengths, number_positive_samples
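With this change, target is one flat 1-D tensor holding every utterance's token ids back to back, and target_lengths records how many tokens belong to each utterance; that is precisely the concatenated-targets form torch.nn.CTCLoss accepts. A small self-contained example with made-up token ids and shapes:

import torch

# Made-up example: two utterances, a positive sample tokenized to
# [2, 3, 4, 5] and a negative sample mapped to the single token [1].
target = torch.tensor([2, 3, 4, 5, 1])   # flat concatenation of token ids
target_lengths = torch.tensor([4, 1])    # tokens per utterance

# Fake acoustic output: (T, N, C) log-probabilities, T=50 frames,
# N=2 utterances, C=10 classes with 0 reserved for blank.
log_probs = torch.randn(50, 2, 10).log_softmax(dim=-1)
input_lengths = torch.tensor([50, 50])

ctc_loss = torch.nn.CTCLoss(blank=0, reduction="mean")
loss = ctc_loss(log_probs, target, input_lengths, target_lengths)
print(loss.item())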


@@ -531,6 +531,9 @@ def run(rank, world_size, args):
model = Tdnn(params.feature_dim, params.num_class)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
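The added call restores model weights from an existing checkpoint before the model is moved to the device. The sketch below shows the generic resume pattern such a helper usually implements; it is an assumption for illustration (hence the _sketch suffix), not icefall's actual load_checkpoint_if_available, and it assumes checkpoints are saved as epoch-*.pt files holding a "model" state dict.

import logging
from pathlib import Path

import torch


def load_checkpoint_if_available_sketch(params, model):
    """Sketch: load the newest epoch-*.pt from params.exp_dir into
    `model`, returning the checkpoint dict (or None) so the caller can
    also restore optimizer state, epoch counters, etc."""
    # Lexicographic sort is a simplification; a real helper would sort
    # by epoch number.
    candidates = sorted(Path(params.exp_dir).glob("epoch-*.pt"))
    if not candidates:
        return None
    checkpoint = torch.load(candidates[-1], map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    logging.info(f"Loaded checkpoint {candidates[-1]}")
    return checkpoint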