From 3feef0a7d05586b2971521e6b3392d9fc535ffe6 Mon Sep 17 00:00:00 2001
From: glynpu
Date: Thu, 16 Mar 2023 13:05:46 +0800
Subject: [PATCH] update tokenizer comments

---
 egs/himia/wuw/ctc_tdnn/tokenizer.py | 17 +++++++++++------
 egs/himia/wuw/ctc_tdnn/train.py     |  1 -
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py
index bb988da6d..bc207ec04 100644
--- a/egs/himia/wuw/ctc_tdnn/tokenizer.py
+++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py
@@ -64,18 +64,23 @@ class WakeupWordTokenizer(object):
         self.negative_word_tokens = [1]
         self.negative_number_tokens = 1
 
-    def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, int]:
+    def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, int]:
         """Convert a list of texts to a list of k2.Fsa based texts.
 
         Args:
           texts:
-            It is a list of strings.
+            It is a list of strings;
+            each element is the reference text of an audio.
         Returns:
-          Return a list of k2.Fsa, one for an element in texts.
-          If the element is `wakeup_word`, a graph for positive samples is appneded
-          into resulting graph_vec, otherwise, a graph for negative samples is appended.
+          Return a tuple of 3 elements.
+          The first one is torch.Tensor(List[List[int]]),
+          each List[int] being the token sequence of a reference text.
 
-          Number of positive samples is also returned to track its proportion.
+          The second one is the number of tokens of each sample,
+          mainly used by the CTC loss.
+
+          The last one is number_positive_samples,
+          used to track the proportion of positive samples in each batch.
         """
         batch_token_ids = []
         target_lengths = []
diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py
index 0b140020e..249821c29 100755
--- a/egs/himia/wuw/ctc_tdnn/train.py
+++ b/egs/himia/wuw/ctc_tdnn/train.py
@@ -37,7 +37,6 @@ import torch.nn as nn
 
 from asr_datamodule import HiMiaWuwDataModule
 from tdnn import Tdnn
-from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
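
Note (illustrative sketch, not part of the patch): the updated docstring describes
texts_to_token_ids() as returning a token-id tensor, per-sample target lengths,
and the number of positive samples. The snippet below is a minimal, hypothetical
illustration of how such a triple is typically consumed by
torch.nn.functional.ctc_loss; the tensor shapes, dummy values, and variable names
are assumptions for illustration only and are not taken from this repository.

    import torch
    import torch.nn.functional as F

    # Dummy network output: (T, N, C) log-probabilities over C token classes.
    T, N, C = 100, 4, 10
    log_probs = torch.randn(T, N, C).log_softmax(dim=-1)

    # Stand-ins for the three values returned by texts_to_token_ids():
    # token ids padded to (N, S), per-sample target lengths, positive count.
    # Token id 1 mirrors the negative_word_tokens placeholder; 2..5 stand in
    # for a wake-up-word token sequence.
    target = torch.tensor([[2, 3, 4, 5], [1, 0, 0, 0], [1, 0, 0, 0], [2, 3, 4, 5]])
    target_lengths = torch.tensor([4, 1, 1, 4])
    number_positive_samples = 2

    # Every frame of the dummy output is valid, so input lengths are all T.
    input_lengths = torch.full((N,), T, dtype=torch.long)
    loss = F.ctc_loss(log_probs, target, input_lengths, target_lengths)
    print(loss.item(), number_positive_samples / N)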