remove unwanted changes in utils

This commit is contained in:
Desh Raj 2023-06-13 08:42:38 -04:00
parent 2d3063becd
commit d6adf25c06

View File

@ -903,35 +903,13 @@ def write_surt_error_stats(
is that this function finds the optimal speaker-agnostic WER using the ``meeteval``
toolkit.
It will write the following to the given file: Args:
f: File to write the statistics to.
- WER test_set_name: Name of the test set.
- number of insertions, deletions, substitutions, corrects and total results: List of tuples containing the utterance ID and the predicted
reference words. For example:: transcript.
enable_log: Whether to enable logging.
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 num_channels: Number of output channels/branches. Defaults to 2.
reference words (2337 correct)
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).
Another example is::
FOR THE FIRST DAY (SIR->*) I THINK
The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return None.
"""
@ -1282,10 +1260,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths) return expaned_lengths >= lengths.unsqueeze(-1)
return expaned_lengths >= lengths.unsqueeze(1)
# Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
@ -1648,7 +1626,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
List of timestamp of each word.
"""
start_token = b"\xe2\x96\x81".decode() # '_'
assert len(tokens) == len(timestamp) assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
ans = []
for i in range(len(tokens)):
flag = False