mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
remove unwanted changes in utils
This commit is contained in:
parent 2d3063becd
commit d6adf25c06
icefall/utils.py

@@ -903,35 +903,13 @@ def write_surt_error_stats(
     is that this function finds the optimal speaker-agnostic WER using the ``meeteval``
     toolkit.
 
-    It will write the following to the given file:
-
-        - WER
-        - number of insertions, deletions, substitutions, corrects and total
-          reference words. For example::
-
-              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
-              reference words (2337 correct)
-
-        - The difference between the reference transcript and predicted result.
-          An instance is given below::
-
-              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
-
-          The above example shows that the reference word is `EDISON`,
-          but it is predicted to `ADDISON` (a substitution error).
-
-          Another example is::
-
-              FOR THE FIRST DAY (SIR->*) I THINK
-
-          The reference word `SIR` is missing in the predicted
-          results (a deletion error).
-
-    results:
-      An iterable of tuples. The first element is the cur_id, the second is
-      the reference transcript and the third element is the predicted result.
-    enable_log:
-      If True, also print detailed WER to the console.
-      Otherwise, it is written only to the given file.
+    Args:
+      f: File to write the statistics to.
+      test_set_name: Name of the test set.
+      results: List of tuples containing the utterance ID and the predicted
+        transcript.
+      enable_log: Whether to enable logging.
+      num_channels: Number of output channels/branches. Defaults to 2.
 
     Returns:
       Return None.
     """
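The deleted docstring lines above describe the (REF->HYP) error-annotation format used in these reports. As an illustration only: icefall produces its reports by aligning with kaldialign, but the same rendering can be sketched with the standard library. The `annotate` helper below is hypothetical, not part of icefall:

import difflib
from itertools import zip_longest

def annotate(ref: str, hyp: str) -> str:
    """Render a hypothesis against its reference in the (REF->HYP) style."""
    ref_words, hyp_words = ref.split(), hyp.split()
    out = []
    for op, i1, i2, j1, j2 in difflib.SequenceMatcher(
        a=ref_words, b=hyp_words
    ).get_opcodes():
        if op == "equal":
            out.extend(ref_words[i1:i2])
        elif op == "replace":  # substitutions, e.g. (EDISON->ADDISON)
            out.extend(
                f"({r}->{h})"
                for r, h in zip_longest(
                    ref_words[i1:i2], hyp_words[j1:j2], fillvalue="*"
                )
            )
        elif op == "delete":  # deletions, e.g. (SIR->*)
            out.extend(f"({r}->*)" for r in ref_words[i1:i2])
        else:  # "insert": extra hypothesis words, shown as (*->WORD)
            out.extend(f"(*->{h})" for h in hyp_words[j1:j2])
    return " ".join(out)

print(annotate("THE ASSOCIATION OF EDISON ILLUMINATING COMPANIES",
               "THE ASSOCIATION OF ADDISON ILLUMINATING COMPANIES"))
# THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
print(annotate("FOR THE FIRST DAY SIR I THINK",
               "FOR THE FIRST DAY I THINK"))
# FOR THE FIRST DAY (SIR->*) I THINK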
@@ -1282,10 +1260,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     assert lengths.ndim == 1, lengths.ndim
     max_len = max(max_len, lengths.max())
     n = lengths.size(0)
-    seq_range = torch.arange(0, max_len, device=lengths.device)
-    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+
+    expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
 
-    return expaned_lengths >= lengths.unsqueeze(1)
+    return expaned_lengths >= lengths.unsqueeze(-1)
 
 
 # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
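Assuming the `+` side reconstructed above is the post-commit code, the function can be sanity-checked standalone. The int() cast is added in this sketch only so arange/expand receive plain Python ints rather than 0-dim tensors:

import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # Mirrors the "+" side of the hunk above; int() added for the sketch.
    assert lengths.ndim == 1, lengths.ndim
    max_len = int(max(max_len, lengths.max()))
    n = lengths.size(0)

    expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)

    # For 1-D `lengths`, unsqueeze(-1) and unsqueeze(1) both yield shape
    # (n, 1), so the swapped call in the hunk does not change behavior.
    return expaned_lengths >= lengths.unsqueeze(-1)

print(make_pad_mask(torch.tensor([1, 3, 2, 5])))
# tensor([[False,  True,  True,  True,  True],
#         [False, False, False,  True,  True],
#         [False, False,  True,  True,  True],
#         [False, False, False, False, False]])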
@@ -1648,7 +1626,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
       List of timestamp of each word.
     """
     start_token = b"\xe2\x96\x81".decode()  # '_'
-    assert len(tokens) == len(timestamp)
+    assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
     ans = []
     for i in range(len(tokens)):
         flag = False
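The one functional change in this hunk attaches a diagnostic tuple to the assert. A minimal illustration, with hypothetical token/timestamp values, of what that buys when the lengths disagree:

tokens = ["▁HELLO", "▁WORLD"]   # hypothetical inputs: lengths 2 and 3
timestamp = [0.0, 0.4, 0.8]

# The bare form fails with an empty "AssertionError"; the new form fails with
# "AssertionError: (2, 3)", immediately exposing the mismatched lengths.
assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))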