From d6adf25c0664a7e68733741745eddf7248aa2a19 Mon Sep 17 00:00:00 2001
From: Desh Raj
Date: Tue, 13 Jun 2023 08:42:38 -0400
Subject: [PATCH] remove unwanted changes in utils

---
 icefall/utils.py | 44 +++++++++++---------------------------------
 1 file changed, 11 insertions(+), 33 deletions(-)

diff --git a/icefall/utils.py b/icefall/utils.py
index d72dd1e68..0feff9dc8 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -903,35 +903,13 @@ def write_surt_error_stats(
     is that this function finds the optimal speaker-agnostic WER using the
     ``meeteval`` toolkit.
 
-    It will write the following to the given file:
-
-        - WER
-        - number of insertions, deletions, substitutions, corrects and total
-          reference words. For example::
-
-              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
-              reference words (2337 correct)
-
-        - The difference between the reference transcript and predicted result.
-          An instance is given below::
-
-              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
-
-          The above example shows that the reference word is `EDISON`,
-          but it is predicted to `ADDISON` (a substitution error).
-
-          Another example is::
-
-              FOR THE FIRST DAY (SIR->*) I THINK
-
-          The reference word `SIR` is missing in the predicted
-          results (a deletion error).
-    results:
-      An iterable of tuples. The first element is the cur_id, the second is
-      the reference transcript and the third element is the predicted result.
-    enable_log:
-      If True, also print detailed WER to the console.
-      Otherwise, it is written only to the given file.
+    Args:
+      f: File to write the statistics to.
+      test_set_name: Name of the test set.
+      results: List of tuples containing the utterance ID, the reference
+        transcript, and the predicted result.
+      enable_log: Whether to also print detailed WER to the console.
+      num_channels: Number of output channels/branches. Defaults to 2.
     Returns:
       Return None.
     """
@@ -1282,10 +1260,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     assert lengths.ndim == 1, lengths.ndim
     max_len = max(max_len, lengths.max())
     n = lengths.size(0)
-    expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
-
-    return expaned_lengths >= lengths.unsqueeze(1)
+    seq_range = torch.arange(0, max_len, device=lengths.device)
+    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+
+    return expaned_lengths >= lengths.unsqueeze(-1)
 
 
 # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
@@ -1648,7 +1626,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
       List of timestamp of each word.
     """
     start_token = b"\xe2\x96\x81".decode()  # '▁'
-    assert len(tokens) == len(timestamp)
+    assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
     ans = []
     for i in range(len(tokens)):
         flag = False
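
Note on the first hunk: the trimmed docstring still points at ``meeteval`` for
the optimal speaker-agnostic WER. For reviewers who want to reproduce that
number by hand, the sketch below uses the ORC WER helper as the meeteval
Python API exposed it around the time of this patch (`orc_word_error_rate`);
the exact import path and result fields may differ in other meeteval versions,
and the transcripts are made up for illustration:

    # Sketch only: assumes `pip install meeteval` and the 2023-era API.
    from meeteval.wer import orc_word_error_rate

    # Reference utterances vs. hypotheses from num_channels=2 output branches.
    ref = ["hello how are you", "i am fine thank you", "see you later"]
    hyp = ["hello how are you see you later", "i am fine thank you"]

    res = orc_word_error_rate(reference=ref, hypothesis=hyp)
    print(res.error_rate)  # 0.0: every reference utterance matches one branch
    print(res.assignment)  # branch chosen per reference utterance, e.g. (0, 1, 0)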
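
Note on the make_pad_mask hunk: the rewrite allocates the position grid
directly on `lengths.device` (the old code built it on CPU and then copied it
over via `.to(lengths)`), and `unsqueeze(-1)` is equivalent to `unsqueeze(1)`
for a 1-D `lengths` tensor. A self-contained sketch that inlines the new body,
with `True` marking padding positions:

    import torch

    def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
        # Same logic as the patched function: positions >= length are padding.
        assert lengths.ndim == 1, lengths.ndim
        max_len = max(max_len, lengths.max())
        n = lengths.size(0)
        seq_range = torch.arange(0, max_len, device=lengths.device)
        expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
        return expaned_lengths >= lengths.unsqueeze(-1)

    print(make_pad_mask(torch.tensor([1, 3, 2])))
    # tensor([[False,  True,  True],
    #         [False, False, False],
    #         [False, False,  True]])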
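
Note on the parse_timestamp hunk: the assert now reports both lengths on
failure, which makes token/timestamp misalignment much easier to debug. For
context, b"\xe2\x96\x81" is the UTF-8 encoding of '▁' (U+2581), the
sentencepiece word-boundary symbol, and the function collapses token-level
timestamps to word-level ones at those boundaries. A simplified sketch of
that idea (not the full icefall implementation, whose loop contains extra
handling):

    start_token = b"\xe2\x96\x81".decode()  # '▁' (U+2581)

    def word_start_times(tokens, timestamp):
        # A token beginning with '▁' starts a new word, so its timestamp is
        # taken as that word's start time.
        assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
        ans = []
        for token, time in zip(tokens, timestamp):
            if token.startswith(start_token):
                ans.append(time)
        return ans

    print(word_start_times(["▁HE", "LLO", "▁WOR", "LD"], [0.0, 0.08, 0.4, 0.52]))
    # [0.0, 0.4]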