update comments

glynpu 2023-03-17 14:26:50 +08:00
parent 9f94984dbb
commit 17768da017
4 changed files with 19 additions and 12 deletions

View File

@@ -197,7 +197,7 @@ class HiMiaWuwDataModule:
            default="_7_01",
            help="""channel of HI_MIA train dataset.
            All channels are used if it is set "all".
-            Please refer to state 6 in prepare.sh for its meaning and other
+            Please refer to stage 6 in prepare.sh for its meaning and other
            potential values. Currently, Only "_7_01" is verified.
            """,
        )
@@ -207,7 +207,7 @@ class HiMiaWuwDataModule:
            default="_7_01",
            help="""channel of HI_MIA dev dataset.
            All channels are used if it is set "all".
-            Please refer to state 6 in prepare.sh for its meaning and other
+            Please refer to stage 6 in prepare.sh for its meaning and other
            potential values. Currently, Only "_7_01" is verified.
            """,
        )
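Both options above document the channel-selection behaviour handled by stage 6 of prepare.sh. For readers unfamiliar with the recipe, a minimal argparse sketch of such an option follows; the flag name "--train-channel" is a hypothetical stand-in for illustration, not necessarily the exact name used in this data module.

import argparse

# Minimal sketch, assuming a hypothetical flag named "--train-channel".
parser = argparse.ArgumentParser()
parser.add_argument(
    "--train-channel",
    type=str,
    default="_7_01",
    help='Channel of the HI_MIA train set; "all" selects every channel. '
    "See stage 6 in prepare.sh for other possible values.",
)

args = parser.parse_args([])
print(args.train_channel)  # "_7_01" unless overridden on the command line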

View File

@@ -19,7 +19,7 @@
from typing import List

-def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
+def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str:
    """
    A graph starts with blank/unknown and following by wakeup word.
@@ -27,6 +27,10 @@ def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
      wakeup_word_tokens: A sequence of token ids corresponding wakeup_word.
        It should not contain 0 and 1.
        We assume 0 is for blank and 1 is for unknown.
+    Returns:
+      Returns a finite-state transducer in string format,
+      used as a decoding graph.
+      Arcs are separated with "\n".
    """
    assert 0 not in wakeup_word_tokens
    assert 1 not in wakeup_word_tokens
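The new Returns section says the function emits the decoding graph as a string with one arc per line, but the graph construction itself is not part of this hunk. The snippet below is only a rough sketch of how such a trivial wakeup-word graph string could be assembled, assuming one common text layout (the "src dst label score" acceptor format accepted by k2.Fsa.from_str, with the final state alone on the last line); it is not the recipe's actual implementation.

from typing import List


def trivial_graph_sketch(wakeup_word_tokens: List[int]) -> str:
    """Illustrative only: build a linear CTC-style decoding graph string.

    State 0 loops on blank (0) and unknown (1); each wakeup-word token then
    advances one state, with a self-loop so repeated frames of the same
    token are absorbed. One arc per line: "src dst label score".
    """
    arcs = []
    arcs.append("0 0 0 0.0")  # self-loop on blank
    arcs.append("0 0 1 0.0")  # self-loop on unknown
    cur = 0
    for token in wakeup_word_tokens:
        nxt = cur + 1
        arcs.append(f"{cur} {nxt} {token} 0.0")  # consume the token
        arcs.append(f"{nxt} {nxt} {token} 0.0")  # absorb repeated frames
        cur = nxt
    final = cur + 1
    arcs.append(f"{cur} {final} -1 0.0")  # arc to the final state (label -1)
    arcs.append(f"{final}")               # final state on the last line
    return "\n".join(arcs)


print(trivial_graph_sketch([2, 3, 4, 5]))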

View File

@@ -23,15 +23,15 @@ from typing import List, Tuple
class WakeupWordTokenizer(object):
    def __init__(
        self,
-        wakeup_word: str = "",
-        wakeup_word_tokens: List[int] = None,
+        wakeup_word: str,
+        wakeup_word_tokens: List[int],
    ) -> None:
        """
        Args:
          wakeup_word: content of positive samples.
-            A sample will be treated as a negative sample unless its context
+            A sample will be treated as a negative sample unless its content
            is exactly the same to key_words.
-          wakeup_word_tokens: A list if int represents token ids of wakeup_word.
+          wakeup_word_tokens: A list of int representing token ids of wakeup_word.
            For example: the pronunciation of "你好米雅" is
            "n i h ao m i y a".
            Suppose we are using following lexicon:
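The lexicon referenced at the end of this hunk is cut off by the diff context. Purely to illustrate the mapping the docstring describes, a toy version might look like the following; the id assignments are made up and need not match the recipe's real lexicon, which reserves 0 for blank and 1 for unknown.

# Toy lexicon for illustration only; real ids come from the recipe's lexicon.
toy_lexicon = {"n": 2, "i": 3, "h": 4, "ao": 5, "m": 6, "y": 7, "a": 8}

pronunciation = "n i h ao m i y a".split()
wakeup_word_tokens = [toy_lexicon[p] for p in pronunciation]
print(wakeup_word_tokens)  # [2, 3, 4, 5, 6, 3, 7, 8]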
@@ -67,7 +67,7 @@ class WakeupWordTokenizer(object):
    def texts_to_token_ids(
        self, texts: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
-        """Convert a list of texts to a list of k2.Fsa based texts.
+        """Convert a list of texts to parameters needed by CTC loss.

        Args:
          texts:
@@ -76,7 +76,7 @@ class WakeupWordTokenizer(object):
        Returns:
          Return a tuple of 3 elements.
          The first one is torch.Tensor(List[List[int]]),
-          each List[int] is tokens sequence for each a reference text.
+          each List[int] is tokens sequence for each reference text.
          The second one is number of tokens for each sample,
          mainly used by CTC loss.
@@ -89,13 +89,13 @@ class WakeupWordTokenizer(object):
        number_positive_samples = 0
        for utt_text in texts:
            if utt_text == self.wakeup_word:
-                batch_token_ids.append(self.wakeup_word_tokens)
+                batch_token_ids.extend(self.wakeup_word_tokens)
                target_lengths.append(self.positive_number_tokens)
                number_positive_samples += 1
            else:
-                batch_token_ids.append(self.negative_word_tokens)
+                batch_token_ids.extend(self.negative_word_tokens)
                target_lengths.append(self.negative_number_tokens)
-        target = torch.tensor(list(itertools.chain.from_iterable(batch_token_ids)))
+        target = torch.tensor(batch_token_ids)
        target_lengths = torch.tensor(target_lengths)
        return target, target_lengths, number_positive_samples
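Switching from append plus itertools.chain to extend keeps batch_token_ids as one flat list, which is exactly the 1-D concatenated-targets layout torch.nn.CTCLoss accepts alongside target_lengths. Below is a small self-contained sketch of that layout; the token ids, batch size, and feature dimensions are made up for illustration and are not the recipe's real values.

import torch

# Two utterances: a positive sample with 4 wakeup-word tokens and a negative
# sample collapsed to a single token. All ids here are placeholders.
batch_token_ids = []
target_lengths = []

batch_token_ids.extend([2, 3, 4, 5])  # positive sample tokens
target_lengths.append(4)
batch_token_ids.extend([1])           # negative sample token
target_lengths.append(1)

targets = torch.tensor(batch_token_ids)        # shape: (sum(target_lengths),)
target_lengths = torch.tensor(target_lengths)  # shape: (batch,)

# CTCLoss accepts 1-D concatenated targets together with per-sample lengths.
log_probs = torch.randn(50, 2, 10).log_softmax(dim=-1)  # (T, batch, classes)
input_lengths = torch.tensor([50, 50])
loss = torch.nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
print(loss.item())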

View File

@@ -531,6 +531,9 @@ def run(rank, world_size, args):
    model = Tdnn(params.feature_dim, params.num_class)

+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)