remove unwanted changes in utils

This commit is contained in:
Desh Raj 2023-06-13 08:42:38 -04:00
parent 2d3063becd
commit d6adf25c06

View File

@ -903,35 +903,13 @@ def write_surt_error_stats(
is that this function finds the optimal speaker-agnostic WER using the ``meeteval``
toolkit.
It will write the following to the given file:
- WER
- number of insertions, deletions, substitutions, corrects and total
reference words. For example::
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).
Another example is::
FOR THE FIRST DAY (SIR->*) I THINK
The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Args:
f: File to write the statistics to.
test_set_name: Name of the test set.
results: List of tuples containing the utterance ID and the predicted
transcript.
enable_log: Whether to enable logging.
num_channels: Number of output channels/branches. Defaults to 2.
Returns:
Return None.
"""
@ -1282,10 +1260,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
return expaned_lengths >= lengths.unsqueeze(1)
return expaned_lengths >= lengths.unsqueeze(-1)
# Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
@ -1648,7 +1626,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
List of timestamp of each word.
"""
start_token = b"\xe2\x96\x81".decode() # '_'
assert len(tokens) == len(timestamp)
assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
ans = []
for i in range(len(tokens)):
flag = False