From 05a79c2cbb752516a066ace825e9e43eb3180166 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 29 Sep 2021 19:22:16 +0800 Subject: [PATCH 1/5] Delete utils.py --- icefall/utils.py | 494 ----------------------------------------------- 1 file changed, 494 deletions(-) delete mode 100644 icefall/utils.py diff --git a/icefall/utils.py b/icefall/utils.py deleted file mode 100644 index 876c926d9..000000000 --- a/icefall/utils.py +++ /dev/null @@ -1,494 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import collections -import os -import subprocess -from collections import defaultdict -from contextlib import contextmanager -from datetime import datetime -from pathlib import Path -from typing import Dict, Iterable, List, TextIO, Tuple, Union - -import k2 -import kaldialign -import torch -import torch.distributed as dist -from torch.utils.tensorboard import SummaryWriter - -Pathlike = Union[str, Path] - - -@contextmanager -def get_executor(): - # We'll either return a process pool or a distributed worker pool. - # Note that this has to be a context manager because we might use multiple - # context manager ("with" clauses) inside, and this way everything will - # free up the resources at the right time. - try: - # If this is executed on the CLSP grid, we will try to use the - # Grid Engine to distribute the tasks. - # Other clusters can also benefit from that, provided a - # cluster-specific wrapper. - # (see https://github.com/pzelasko/plz for reference) - # - # The following must be installed: - # $ pip install dask distributed - # $ pip install git+https://github.com/pzelasko/plz - name = subprocess.check_output("hostname -f", shell=True, text=True) - if name.strip().endswith(".clsp.jhu.edu"): - import plz - from distributed import Client - - with plz.setup_cluster() as cluster: - cluster.scale(80) - yield Client(cluster) - return - except Exception: - pass - # No need to return anything - compute_and_store_features - # will just instantiate the pool itself. - yield None - - -def str2bool(v): - """Used in argparse.ArgumentParser.add_argument to indicate - that a type is a bool type and user can enter - - - yes, true, t, y, 1, to represent True - - no, false, f, n, 0, to represent False - - See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa - """ - if isinstance(v, bool): - return v - if v.lower() in ("yes", "true", "t", "y", "1"): - return True - elif v.lower() in ("no", "false", "f", "n", "0"): - return False - else: - raise argparse.ArgumentTypeError("Boolean value expected.") - - -def setup_logger( - log_filename: Pathlike, log_level: str = "info", use_console: bool = True -) -> None: - """Setup log level. - - Args: - log_filename: - The filename to save the log. 
- log_level: - The log level to use, e.g., "debug", "info", "warning", "error", - "critical" - """ - now = datetime.now() - date_time = now.strftime("%Y-%m-%d-%H-%M-%S") - - if dist.is_available() and dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa - log_filename = f"{log_filename}-{date_time}-{rank}" - else: - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - log_filename = f"{log_filename}-{date_time}" - - os.makedirs(os.path.dirname(log_filename), exist_ok=True) - - level = logging.ERROR - if log_level == "debug": - level = logging.DEBUG - elif log_level == "info": - level = logging.INFO - elif log_level == "warning": - level = logging.WARNING - elif log_level == "critical": - level = logging.CRITICAL - - logging.basicConfig( - filename=log_filename, format=formatter, level=level, filemode="w" - ) - if use_console: - console = logging.StreamHandler() - console.setLevel(level) - console.setFormatter(logging.Formatter(formatter)) - logging.getLogger("").addHandler(console) - - -def get_env_info(): - """ - TODO: - """ - return { - "k2-git-sha1": None, - "k2-version": None, - "lhotse-version": None, - "torch-version": None, - "icefall-sha1": None, - "icefall-version": None, - } - - -class AttributeDict(dict): - def __getattr__(self, key): - if key in self: - return self[key] - raise AttributeError(f"No such attribute '{key}'") - - def __setattr__(self, key, value): - self[key] = value - - def __delattr__(self, key): - if key in self: - del self[key] - return - raise AttributeError(f"No such attribute '{key}'") - - -def encode_supervisions( - supervisions: dict, subsampling_factor: int -) -> Tuple[torch.Tensor, List[str]]: - """Encodes Lhotse's ``batch["supervisions"]`` dict into - a pair of torch Tensor, and a list of transcription strings. - - The supervision tensor has shape ``(batch_size, 3)``. - Its second dimension contains information about sequence index [0], - start frames [1] and num frames [2]. - - The batch items might become re-ordered during this operation -- the - returned tensor and list of strings are guaranteed to be consistent with - each other. - """ - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // subsampling_factor, - supervisions["num_frames"] // subsampling_factor, - ), - 1, - ).to(torch.int32) - - indices = torch.argsort(supervision_segments[:, 2], descending=True) - supervision_segments = supervision_segments[indices] - texts = supervisions["text"] - texts = [texts[idx] for idx in indices] - - return supervision_segments, texts - - -def get_texts( - best_paths: k2.Fsa, return_ragged: bool = False -) -> Union[List[List[int]], k2.RaggedTensor]: - """Extract the texts (as word IDs) from the best-path FSAs. - Args: - best_paths: - A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. - containing multiple FSAs, which is expected to be the result - of k2.shortest_path (otherwise the returned values won't - be meaningful). - return_ragged: - True to return a ragged tensor with two axes [utt][word_id]. - False to return a list-of-list word IDs. - Returns: - Returns a list of lists of int, containing the label sequences we - decoded. - """ - if isinstance(best_paths.aux_labels, k2.RaggedTensor): - # remove 0's and -1's. 
- aux_labels = best_paths.aux_labels.remove_values_leq(0) - # TODO: change arcs.shape() to arcs.shape - aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) - - # remove the states and arcs axes. - aux_shape = aux_shape.remove_axis(1) - aux_shape = aux_shape.remove_axis(1) - aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values) - else: - # remove axis corresponding to states. - aux_shape = best_paths.arcs.shape().remove_axis(1) - aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) - # remove 0's and -1's. - aux_labels = aux_labels.remove_values_leq(0) - - assert aux_labels.num_axes == 2 - if return_ragged: - return aux_labels - else: - return aux_labels.tolist() - - -def store_transcripts( - filename: Pathlike, texts: Iterable[Tuple[str, str]] -) -> None: - """Save predicted results and reference transcripts to a file. - - Args: - filename: - File to save the results to. - texts: - An iterable of tuples. The first element is the reference transcript - while the second element is the predicted result. - Returns: - Return None. - """ - with open(filename, "w") as f: - for ref, hyp in texts: - print(f"ref={ref}", file=f) - print(f"hyp={hyp}", file=f) - - -def write_error_stats( - f: TextIO, - test_set_name: str, - results: List[Tuple[str, str]], - enable_log: bool = True, -) -> float: - """Write statistics based on predicted results and reference transcripts. - - It will write the following to the given file: - - - WER - - number of insertions, deletions, substitutions, corrects and total - reference words. For example:: - - Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 - reference words (2337 correct) - - - The difference between the reference transcript and predicted result. - An instance is given below:: - - THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES - - The above example shows that the reference word is `EDISON`, - but it is predicted to `ADDISON` (a substitution error). - - Another example is:: - - FOR THE FIRST DAY (SIR->*) I THINK - - The reference word `SIR` is missing in the predicted - results (a deletion error). - results: - An iterable of tuples. The first element is the reference transcript - while the second element is the predicted result. - enable_log: - If True, also print detailed WER to the console. - Otherwise, it is written only to the given file. - Returns: - Return None. 
- """ - subs: Dict[Tuple[str, str], int] = defaultdict(int) - ins: Dict[str, int] = defaultdict(int) - dels: Dict[str, int] = defaultdict(int) - - # `words` stores counts per word, as follows: - # corr, ref_sub, hyp_sub, ins, dels - words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) - num_corr = 0 - ERR = "*" - for ref, hyp in results: - ali = kaldialign.align(ref, hyp, ERR) - for ref_word, hyp_word in ali: - if ref_word == ERR: - ins[hyp_word] += 1 - words[hyp_word][3] += 1 - elif hyp_word == ERR: - dels[ref_word] += 1 - words[ref_word][4] += 1 - elif hyp_word != ref_word: - subs[(ref_word, hyp_word)] += 1 - words[ref_word][1] += 1 - words[hyp_word][2] += 1 - else: - words[ref_word][0] += 1 - num_corr += 1 - ref_len = sum([len(r) for r, _ in results]) - sub_errs = sum(subs.values()) - ins_errs = sum(ins.values()) - del_errs = sum(dels.values()) - tot_errs = sub_errs + ins_errs + del_errs - tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) - - if enable_log: - logging.info( - f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " - f"[{tot_errs} / {ref_len}, {ins_errs} ins, " - f"{del_errs} del, {sub_errs} sub ]" - ) - - print(f"%WER = {tot_err_rate}", file=f) - print( - f"Errors: {ins_errs} insertions, {del_errs} deletions, " - f"{sub_errs} substitutions, over {ref_len} reference " - f"words ({num_corr} correct)", - file=f, - ) - print( - "Search below for sections starting with PER-UTT DETAILS:, " - "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", - file=f, - ) - - print("", file=f) - print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) - for ref, hyp in results: - ali = kaldialign.align(ref, hyp, ERR) - combine_successive_errors = True - if combine_successive_errors: - ali = [[[x], [y]] for x, y in ali] - for i in range(len(ali) - 1): - if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: - ali[i + 1][0] = ali[i][0] + ali[i + 1][0] - ali[i + 1][1] = ali[i][1] + ali[i + 1][1] - ali[i] = [[], []] - ali = [ - [ - list(filter(lambda a: a != ERR, x)), - list(filter(lambda a: a != ERR, y)), - ] - for x, y in ali - ] - ali = list(filter(lambda x: x != [[], []], ali)) - ali = [ - [ - ERR if x == [] else " ".join(x), - ERR if y == [] else " ".join(y), - ] - for x, y in ali - ] - - print( - " ".join( - ( - ref_word - if ref_word == hyp_word - else f"({ref_word}->{hyp_word})" - for ref_word, hyp_word in ali - ) - ), - file=f, - ) - - print("", file=f) - print("SUBSTITUTIONS: count ref -> hyp", file=f) - - for count, (ref, hyp) in sorted( - [(v, k) for k, v in subs.items()], reverse=True - ): - print(f"{count} {ref} -> {hyp}", file=f) - - print("", file=f) - print("DELETIONS: count ref", file=f) - for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): - print(f"{count} {ref}", file=f) - - print("", file=f) - print("INSERTIONS: count hyp", file=f) - for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): - print(f"{count} {hyp}", file=f) - - print("", file=f) - print( - "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f - ) - for _, word, counts in sorted( - [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True - ): - (corr, ref_sub, hyp_sub, ins, dels) = counts - tot_errs = ref_sub + hyp_sub + ins + dels - ref_count = corr + ref_sub + dels - hyp_count = corr + hyp_sub + ins - - print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) - return float(tot_err_rate) - - -class LossRecord(collections.defaultdict): - def __init__(self): - # Passing the type 'int' to the base-class constructor - 
# makes undefined items default to int() which is zero. - super(LossRecord, self).__init__(int) - - def __add__(self, other: 'LossRecord') -> 'LossRecord': - ans = LossRecord() - for k, v in self.items(): - ans[k] = v - for k, v in other.items(): - ans[k] = ans[k] + v - return ans - - def __mul__(self, alpha: float) -> 'LossRecord': - ans = LossRecord() - for k, v in self.items(): - ans[k] = v * alpha - return ans - - def __str__(self) -> str: - ans = '' - for k, v in self.norm_items(): - norm_value = '%.4g' % v - ans += (str(k) + '=' + str(norm_value) + ', ') - frames = str(self['frames']) - ans += 'over ' + frames + ' frames.' - return ans - - def norm_items(self) -> List[Tuple[str, float]]: - """ - Returns a list of pairs, like: - [('ctc_loss', 0.1), ('att_loss', 0.07)] - """ - num_frames = self['frames'] if 'frames' in self else 1 - ans = [] - for k, v in self.items(): - if k != 'frames': - norm_value = float(v) / num_frames - ans.append((k, norm_value)) - return ans - - def reduce(self, device): - """ - Reduce using torch.distributed, which I believe ensures that - all processes get the total. - """ - keys = sorted(self.keys()) - s = torch.tensor([float(self[k]) for k in keys], - device=device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - for k, v in zip(keys, s.cpu().tolist()): - self[k] = v - - def write_summary( - self, - tb_writer: SummaryWriter, - prefix: str, - batch_idx: int, - ) -> None: - """Add logging information to a TensorBoard writer. - - Args: - tb_writer: a TensorBoard writer - prefix: a prefix for the name of the loss, e.g. "train/valid_", - or "train/current_" - batch_idx: The current batch index, used as the x-axis of the plot. - """ - for k, v in self.norm_items(): - tb_writer.add_scalar(prefix + k, v, batch_idx) From 0dfe0e66801bed997bc6a491770088c0f2d538de Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 29 Sep 2021 19:22:36 +0800 Subject: [PATCH 2/5] Add files via upload --- icefall/utils.py | 477 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 477 insertions(+) create mode 100644 icefall/utils.py diff --git a/icefall/utils.py b/icefall/utils.py new file mode 100644 index 000000000..2c551d884 --- /dev/null +++ b/icefall/utils.py @@ -0,0 +1,477 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import collections +import os +import subprocess +from collections import defaultdict +from contextlib import contextmanager +from datetime import datetime +from pathlib import Path +from typing import Dict, Iterable, List, TextIO, Tuple, Union + +import k2 +import kaldialign +import torch +import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter + +Pathlike = Union[str, Path] + + +@contextmanager +def get_executor(): + # We'll either return a process pool or a distributed worker pool. 
+ # Note that this has to be a context manager because we might use multiple + # context manager ("with" clauses) inside, and this way everything will + # free up the resources at the right time. + try: + # If this is executed on the CLSP grid, we will try to use the + # Grid Engine to distribute the tasks. + # Other clusters can also benefit from that, provided a + # cluster-specific wrapper. + # (see https://github.com/pzelasko/plz for reference) + # + # The following must be installed: + # $ pip install dask distributed + # $ pip install git+https://github.com/pzelasko/plz + name = subprocess.check_output("hostname -f", shell=True, text=True) + if name.strip().endswith(".clsp.jhu.edu"): + import plz + from distributed import Client + + with plz.setup_cluster() as cluster: + cluster.scale(80) + yield Client(cluster) + return + except Exception: + pass + # No need to return anything - compute_and_store_features + # will just instantiate the pool itself. + yield None + + +def str2bool(v): + """Used in argparse.ArgumentParser.add_argument to indicate + that a type is a bool type and user can enter + + - yes, true, t, y, 1, to represent True + - no, false, f, n, 0, to represent False + + See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def setup_logger( + log_filename: Pathlike, log_level: str = "info", use_console: bool = True +) -> None: + """Setup log level. + + Args: + log_filename: + The filename to save the log. + log_level: + The log level to use, e.g., "debug", "info", "warning", "error", + "critical" + """ + now = datetime.now() + date_time = now.strftime("%Y-%m-%d-%H-%M-%S") + + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa + log_filename = f"{log_filename}-{date_time}-{rank}" + else: + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + log_filename = f"{log_filename}-{date_time}" + + os.makedirs(os.path.dirname(log_filename), exist_ok=True) + + level = logging.ERROR + if log_level == "debug": + level = logging.DEBUG + elif log_level == "info": + level = logging.INFO + elif log_level == "warning": + level = logging.WARNING + elif log_level == "critical": + level = logging.CRITICAL + + logging.basicConfig( + filename=log_filename, format=formatter, level=level, filemode="w" + ) + if use_console: + console = logging.StreamHandler() + console.setLevel(level) + console.setFormatter(logging.Formatter(formatter)) + logging.getLogger("").addHandler(console) + + +def get_env_info(): + """ + TODO: + """ + return { + "k2-git-sha1": None, + "k2-version": None, + "lhotse-version": None, + "torch-version": None, + "icefall-sha1": None, + "icefall-version": None, + } + + +class AttributeDict(dict): + def __getattr__(self, key): + if key in self: + return self[key] + raise AttributeError(f"No such attribute '{key}'") + + def __setattr__(self, key, value): + self[key] = value + + def __delattr__(self, key): + if key in self: + del self[key] + return + raise AttributeError(f"No such attribute '{key}'") + + +def encode_supervisions( + supervisions: dict, subsampling_factor: int +) -> Tuple[torch.Tensor, 
List[str]]: + """Encodes Lhotse's ``batch["supervisions"]`` dict into + a pair of torch Tensor, and a list of transcription strings. + + The supervision tensor has shape ``(batch_size, 3)``. + Its second dimension contains information about sequence index [0], + start frames [1] and num frames [2]. + + The batch items might become re-ordered during this operation -- the + returned tensor and list of strings are guaranteed to be consistent with + each other. + """ + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // subsampling_factor, + supervisions["num_frames"] // subsampling_factor, + ), + 1, + ).to(torch.int32) + + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + texts = supervisions["text"] + texts = [texts[idx] for idx in indices] + + return supervision_segments, texts + + +def get_texts( + best_paths: k2.Fsa, return_ragged: bool = False +) -> Union[List[List[int]], k2.RaggedTensor]: + """Extract the texts (as word IDs) from the best-path FSAs. + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). + return_ragged: + True to return a ragged tensor with two axes [utt][word_id]. + False to return a list-of-list word IDs. + Returns: + Returns a list of lists of int, containing the label sequences we + decoded. + """ + if isinstance(best_paths.aux_labels, k2.RaggedTensor): + # remove 0's and -1's. + aux_labels = best_paths.aux_labels.remove_values_leq(0) + # TODO: change arcs.shape() to arcs.shape + aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) + + # remove the states and arcs axes. + aux_shape = aux_shape.remove_axis(1) + aux_shape = aux_shape.remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values) + else: + # remove axis corresponding to states. + aux_shape = best_paths.arcs.shape().remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) + # remove 0's and -1's. + aux_labels = aux_labels.remove_values_leq(0) + + assert aux_labels.num_axes == 2 + if return_ragged: + return aux_labels + else: + return aux_labels.tolist() + + +def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str]]) -> None: + """Save predicted results and reference transcripts to a file. + + Args: + filename: + File to save the results to. + texts: + An iterable of tuples. The first element is the reference transcript + while the second element is the predicted result. + Returns: + Return None. + """ + with open(filename, "w") as f: + for ref, hyp in texts: + print(f"ref={ref}", file=f) + print(f"hyp={hyp}", file=f) + + +def write_error_stats( + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, +) -> float: + """Write statistics based on predicted results and reference transcripts. + + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. 
+ An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). + results: + An iterable of tuples. The first element is the reference transcript + while the second element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + Returns: + Return None. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + for ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + for ref_word, hyp_word in ali: + if ref_word == ERR: + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [ + [ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] + for x, y in ali + ] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [ + [ERR if x == [] else " ".join(x), ERR if y == [] else " ".join(y),] + for x, y in ali + ] + + print( + " ".join( + ( + ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali + ) + ), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + 
print(f"{count} {hyp}", file=f) + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + for _, word, counts in sorted( + [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True + ): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + return float(tot_err_rate) + + +class LossRecord(collections.defaultdict): + def __init__(self): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. + super(LossRecord, self).__init__(int) + + def __add__(self, other: "LossRecord") -> "LossRecord": + ans = LossRecord() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> "LossRecord": + ans = LossRecord() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans = "" + for k, v in self.norm_items(): + norm_value = "%.4g" % v + ans += str(k) + "=" + str(norm_value) + ", " + frames = str(self["frames"]) + ans += "over " + frames + " frames." + return ans + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('ctc_loss', 0.1), ('att_loss', 0.07)] + """ + num_frames = self["frames"] if "frames" in self else 1 + ans = [] + for k, v in self.items(): + if k != "frames": + norm_value = float(v) / num_frames + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([float(self[k]) for k in keys], device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary( + self, tb_writer: SummaryWriter, prefix: str, batch_idx: int, + ) -> None: + """Add logging information to a TensorBoard writer. + + Args: + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. 
+ """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) From bdd890bab9dea20b448a4d01743303f886d8262c Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 29 Sep 2021 19:23:11 +0800 Subject: [PATCH 3/5] Add files via upload --- egs/librispeech/ASR/conformer_ctc/train.py | 63 ++++++---------------- 1 file changed, 17 insertions(+), 46 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 34b99cd2d..98bd47bc1 100644 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -59,10 +59,7 @@ def get_parser(): ) parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", + "--world-size", type=int, default=1, help="Number of GPUs for DDP training.", ) parser.add_argument( @@ -80,10 +77,7 @@ def get_parser(): ) parser.add_argument( - "--num-epochs", - type=int, - default=35, - help="Number of epochs to train.", + "--num-epochs", type=int, default=35, help="Number of epochs to train.", ) parser.add_argument( @@ -230,10 +224,7 @@ def load_checkpoint_if_available( filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, + filename, model=model, optimizer=optimizer, scheduler=scheduler, ) keys = [ @@ -335,9 +326,7 @@ def compute_loss( decoding_graph = graph_compiler.compile(token_ids) dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, + nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1, ) ctc_loss = k2.ctc_loss( @@ -374,12 +363,12 @@ def compute_loss( assert loss.requires_grad == is_training info = LossRecord() - info['frames'] = supervision_segments[:, 2].sum().item() - info['ctc_loss'] = ctc_loss.detach().cpu().item() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: - info['att_loss'] = att_loss.detach().cpu().item() + info["att_loss"] = att_loss.detach().cpu().item() - info['loss'] = loss.detach().cpu().item() + info["loss"] = loss.detach().cpu().item() return loss, info @@ -410,7 +399,7 @@ def compute_validation_loss( if world_size > 1: tot_loss.reduce(loss.device) - loss_value = tot_loss['loss'] / tot_loss['frames'] + loss_value = tot_loss["loss"] / tot_loss["frames"] if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch params.best_valid_loss = loss_value @@ -489,15 +478,9 @@ def train_one_epoch( if tb_writer is not None: loss_info.write_summary( - tb_writer, - "train/current_", - params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, - "train/tot_", - params.batch_idx_train + tb_writer, "train/current_", params.batch_idx_train ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -509,17 +492,13 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train + tb_writer, "train/valid_", params.batch_idx_train ) - loss_value = tot_loss['loss'] / 
tot_loss['frames'] + loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch @@ -563,10 +542,7 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", + params.lang_dir, device=device, sos_token="", eos_token="", ) logging.info("About to create model") @@ -607,9 +583,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: @@ -629,10 +603,7 @@ def run(rank, world_size, args): ) save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, + params=params, model=model, optimizer=optimizer, rank=rank, ) logging.info("Done!") From 279dc74b4e49e0c6d940afb4ea0329e9976f26ef Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 29 Sep 2021 19:23:54 +0800 Subject: [PATCH 4/5] Add files via upload --- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 53 ++++++---------------- 1 file changed, 14 insertions(+), 39 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 016d51e2c..2b22e4e0f 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -58,10 +58,7 @@ def get_parser(): ) parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", + "--world-size", type=int, default=1, help="Number of GPUs for DDP training.", ) parser.add_argument( @@ -79,10 +76,7 @@ def get_parser(): ) parser.add_argument( - "--num-epochs", - type=int, - default=20, - help="Number of epochs to train.", + "--num-epochs", type=int, default=20, help="Number of epochs to train.", ) parser.add_argument( @@ -209,10 +203,7 @@ def load_checkpoint_if_available( filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, + filename, model=model, optimizer=optimizer, scheduler=scheduler, ) keys = [ @@ -312,9 +303,7 @@ def compute_loss( decoding_graph = graph_compiler.compile(texts) dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, + nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1, ) loss = k2.ctc_loss( @@ -328,8 +317,8 @@ def compute_loss( assert loss.requires_grad == is_training info = LossRecord() - info['frames'] = supervision_segments[:, 2].sum().item() - info['loss'] = loss.detach().cpu().item() + info["frames"] = supervision_segments[:, 2].sum().item() + info["loss"] = loss.detach().cpu().item() return loss, info @@ -363,7 +352,7 @@ def compute_validation_loss( if world_size > 1: tot_loss.reduce(loss.device) - loss_value = tot_loss['loss'] / tot_loss['frames'] + loss_value = tot_loss["loss"] / tot_loss["frames"] if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch @@ -439,15 +428,9 @@ def train_one_epoch( if tb_writer is not None: loss_info.write_summary( - tb_writer, - "train/current_", - params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, 
- "train/tot_", - params.batch_idx_train + tb_writer, "train/current_", params.batch_idx_train ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( @@ -458,17 +441,13 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation {valid_info}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") if tb_writer is not None: valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train, + tb_writer, "train/valid_", params.batch_idx_train, ) - loss_value = tot_loss['loss'] / tot_loss['frames'] + loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: @@ -526,9 +505,7 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) optimizer = optim.AdamW( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, + model.parameters(), lr=params.lr, weight_decay=params.weight_decay, ) scheduler = StepLR(optimizer, step_size=8, gamma=0.1) @@ -548,9 +525,7 @@ def run(rank, world_size, args): if tb_writer is not None: tb_writer.add_scalar( - "train/lr", - scheduler.get_last_lr()[0], - params.batch_idx_train, + "train/lr", scheduler.get_last_lr()[0], params.batch_idx_train, ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) From 79fd09e3e5c6e8b9d248a874bb68aa7a3129a22f Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 29 Sep 2021 19:24:40 +0800 Subject: [PATCH 5/5] Add files via upload --- egs/yesno/ASR/tdnn/train.py | 59 ++++++++++--------------------------- 1 file changed, 15 insertions(+), 44 deletions(-) diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 582f3e822..f8e8538ca 100644 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -33,10 +33,7 @@ def get_parser(): ) parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", + "--world-size", type=int, default=1, help="Number of GPUs for DDP training.", ) parser.add_argument( @@ -54,10 +51,7 @@ def get_parser(): ) parser.add_argument( - "--num-epochs", - type=int, - default=15, - help="Number of epochs to train.", + "--num-epochs", type=int, default=15, help="Number of epochs to train.", ) parser.add_argument( @@ -187,10 +181,7 @@ def load_checkpoint_if_available( filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, + filename, model=model, optimizer=optimizer, scheduler=scheduler, ) keys = [ @@ -287,16 +278,12 @@ def compute_loss( batch_size = nnet_output.shape[0] supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], dtype=torch.int32, ) decoding_graph = graph_compiler.compile(texts) - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - ) + dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments,) loss = k2.ctc_loss( decoding_graph=decoding_graph, @@ -309,8 +296,8 @@ def compute_loss( assert loss.requires_grad == is_training info = LossRecord() - info['frames'] = supervision_segments[:, 2].sum().item() - info['loss'] = loss.detach().cpu().item() + info["frames"] = supervision_segments[:, 2].sum().item() 
+ info["loss"] = loss.detach().cpu().item() return loss, info @@ -344,7 +331,7 @@ def compute_validation_loss( if world_size > 1: tot_loss.reduce(loss.device) - loss_value = tot_loss['loss'] / tot_loss['frames'] + loss_value = tot_loss["loss"] / tot_loss["frames"] if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch @@ -420,15 +407,9 @@ def train_one_epoch( if tb_writer is not None: loss_info.write_summary( - tb_writer, - "train/current_", - params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, - "train/tot_", - params.batch_idx_train + tb_writer, "train/current_", params.batch_idx_train ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( @@ -439,17 +420,13 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation {valid_info}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") if tb_writer is not None: valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train, + tb_writer, "train/valid_", params.batch_idx_train, ) - loss_value = tot_loss['loss'] / tot_loss['frames'] + loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: @@ -506,9 +483,7 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) optimizer = optim.SGD( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, + model.parameters(), lr=params.lr, weight_decay=params.weight_decay, ) if checkpoints: @@ -542,11 +517,7 @@ def run(rank, world_size, args): ) save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=None, - rank=rank, + params=params, model=model, optimizer=optimizer, scheduler=None, rank=rank, ) logging.info("Done!")