diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py index 1a513fe40..7bd0a174a 100644 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py @@ -48,7 +48,7 @@ from lhotse.utils import fix_random_seed from speech_dataset import K2SpeechRecognitionDataset from torch.utils.data import DataLoader -from icefall.utils import str2bool +from utils import str2bool class _SeedWorkers: diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py index 7665a7680..0b2642bf0 100755 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py @@ -70,10 +70,10 @@ from transformers import ( ) from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from icefall import diagnostics -from icefall.dist import get_rank, get_world_size -from icefall.env import get_env_info -from icefall.utils import ( # filter_uneven_sized_batch, +# from icefall import diagnostics +from utils import get_rank, get_world_size +# from icefall.env import get_env_info +from utils import ( # filter_uneven_sized_batch, AttributeDict, MetricsTracker, setup_logger, @@ -270,7 +270,7 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 5000, - "env_info": get_env_info(), + # "env_info": get_env_info(), } ) diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py new file mode 100644 index 000000000..fe65a8042 --- /dev/null +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py @@ -0,0 +1,224 @@ +import argparse +import collections +import json +import logging +import os +import pathlib +import random +import re +import subprocess +from collections import defaultdict +# from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +# from shutil import copyfile +from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union + +import torch +import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter + +Pathlike = Union[str, Path] + +def get_world_size(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + else: + return 1 + + +def get_rank(): + if "RANK" in os.environ: + return int(os.environ["RANK"]) + elif dist.is_available() and dist.is_initialized(): + return dist.get_rank() + else: + return 0 + +def str2bool(v): + """Used in argparse.ArgumentParser.add_argument to indicate + that a type is a bool type and user can enter + + - yes, true, t, y, 1, to represent True + - no, false, f, n, 0, to represent False + + See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + +class AttributeDict(dict): + def __getattr__(self, key): + if key in self: + return self[key] + raise AttributeError(f"No such attribute '{key}'") + + def __setattr__(self, key, value): + self[key] = value + + def __delattr__(self, key): + if key in self: + del self[key] + return + raise AttributeError(f"No such attribute '{key}'") + + def __str__(self, indent: int = 2): + tmp = {} + for k, v in self.items(): + # PosixPath is ont JSON serializable + if isinstance(v, pathlib.Path) or isinstance(v, torch.device): + v = str(v) + tmp[k] = v + return json.dumps(tmp, indent=indent, sort_keys=True) + +def setup_logger( + log_filename: Pathlike, + log_level: str = "info", + use_console: bool = True, +) -> None: + """Setup log level. + + Args: + log_filename: + The filename to save the log. + log_level: + The log level to use, e.g., "debug", "info", "warning", "error", + "critical" + use_console: + True to also print logs to console. + """ + now = datetime.now() + date_time = now.strftime("%Y-%m-%d-%H-%M-%S") + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa + log_filename = f"{log_filename}-{date_time}-{rank}" + else: + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + log_filename = f"{log_filename}-{date_time}" + + os.makedirs(os.path.dirname(log_filename), exist_ok=True) + + level = logging.ERROR + if log_level == "debug": + level = logging.DEBUG + elif log_level == "info": + level = logging.INFO + elif log_level == "warning": + level = logging.WARNING + elif log_level == "critical": + level = logging.CRITICAL + + logging.basicConfig( + filename=log_filename, + format=formatter, + level=level, + filemode="w", + force=True, + ) + if use_console: + console = logging.StreamHandler() + console.setLevel(level) + console.setFormatter(logging.Formatter(formatter)) + logging.getLogger("").addHandler(console) + +class MetricsTracker(collections.defaultdict): + def __init__(self): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. + # This class will play a role as metrics tracker. + # It can record many metrics, including but not limited to loss. + super(MetricsTracker, self).__init__(int) + + def __add__(self, other: "MetricsTracker") -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + if v - v == 0: + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans_frames = "" + ans_utterances = "" + for k, v in self.norm_items(): + norm_value = "%.4g" % v + if "utt_" not in k: + ans_frames += str(k) + "=" + str(norm_value) + ", " + else: + ans_utterances += str(k) + "=" + str(norm_value) + if k == "utt_duration": + ans_utterances += " frames, " + elif k == "utt_pad_proportion": + ans_utterances += ", " + else: + raise ValueError(f"Unexpected key: {k}") + frames = "%.2f" % self["frames"] + ans_frames += "over " + str(frames) + " frames. " + if ans_utterances != "": + utterances = "%.2f" % self["utterances"] + ans_utterances += "over " + str(utterances) + " utterances." + + return ans_frames + ans_utterances + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('ctc_loss', 0.1), ('att_loss', 0.07)] + """ + num_frames = self["frames"] if "frames" in self else 1 + num_utterances = self["utterances"] if "utterances" in self else 1 + ans = [] + for k, v in self.items(): + if k == "frames" or k == "utterances": + continue + norm_value = ( + float(v) / num_frames if "utt_" not in k else float(v) / num_utterances + ) + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([float(self[k]) for k in keys], device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary( + self, + tb_writer: SummaryWriter, + prefix: str, + batch_idx: int, + ) -> None: + """Add logging information to a TensorBoard writer. + + Args: + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. + """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) \ No newline at end of file