Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-09 14:05:33 +00:00)
remove k2 dependency
This commit is contained in:
parent e41c1cabd5
commit 37db65984c
@@ -48,7 +48,7 @@ from lhotse.utils import fix_random_seed
 from speech_dataset import K2SpeechRecognitionDataset
 from torch.utils.data import DataLoader
-from icefall.utils import str2bool
+from utils import str2bool


 class _SeedWorkers:
@@ -70,10 +70,10 @@ from transformers import (
 )
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

-from icefall import diagnostics
-from icefall.dist import get_rank, get_world_size
-from icefall.env import get_env_info
-from icefall.utils import ( # filter_uneven_sized_batch,
+# from icefall import diagnostics
+from utils import get_rank, get_world_size
+# from icefall.env import get_env_info
+from utils import ( # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
     setup_logger,
@@ -270,7 +270,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 5000,
-            "env_info": get_env_info(),
+            # "env_info": get_env_info(),
         }
     )
egs/speech_llm/SPEECH2SPEECH/qwen_omni/utils.py (new file, 224 lines)
@@ -0,0 +1,224 @@
import argparse
import collections
import json
import logging
import os
import pathlib
import random
import re
import subprocess
from collections import defaultdict

# from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

# from shutil import copyfile
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

Pathlike = Union[str, Path]
def get_world_size():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    elif dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1


def get_rank():
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    elif dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    else:
        return 0
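
# Usage sketch (illustrative, not part of the original commit): launchers
# such as `torchrun` export RANK and WORLD_SIZE, so these helpers work
# even before torch.distributed is initialized:
#
#   # torchrun --nproc_per_node=4 train.py
#   world_size = get_world_size()  # -> 4
#   rank = get_rank()              # -> 0..3, one value per process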


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and the user can enter

    - yes, true, t, y, 1, to represent True
    - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
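
# Usage sketch (illustrative, not part of the original commit; the
# --use-fp16 flag is hypothetical): str2bool is meant to be passed as the
# `type=` of an argparse flag, so "true"/"no"/"0" all parse to a real bool:
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--use-fp16", type=str2bool, default=True)
#   args = parser.parse_args(["--use-fp16", "no"])
#   assert args.use_fp16 is False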


class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
            return
        raise AttributeError(f"No such attribute '{key}'")

    def __str__(self, indent: int = 2):
        tmp = {}
        for k, v in self.items():
            # PosixPath is not JSON serializable
            if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
                v = str(v)
            tmp[k] = v
        return json.dumps(tmp, indent=indent, sort_keys=True)
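
# Usage sketch (illustrative, not part of the original commit): values are
# reachable both as dict keys and as attributes, and __str__ serializes to
# JSON, stringifying Path/device values first:
#
#   params = AttributeDict({"log_interval": 50})
#   params.exp_dir = Path("exp")   # attribute write == key write
#   params["log_interval"]         # -> 50
#   print(params)                  # -> {"exp_dir": "exp", "log_interval": 50}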


def setup_logger(
    log_filename: Pathlike,
    log_level: str = "info",
    use_console: bool = True,
) -> None:
    """Setup log level.

    Args:
      log_filename:
        The filename to save the log.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical"
      use_console:
        True to also print logs to console.
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
        log_filename = f"{log_filename}-{date_time}-{rank}"
    else:
        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        log_filename = f"{log_filename}-{date_time}"

    os.makedirs(os.path.dirname(log_filename), exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename,
        format=formatter,
        level=level,
        filemode="w",
        force=True,
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)
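
# Usage sketch (illustrative, not part of the original commit; the path is
# hypothetical): the date and, under DDP, the rank are appended to the
# filename, so each process writes its own log file:
#
#   setup_logger("exp/log/log-train")
#   logging.info("Training started")
#   # rank 0 writes e.g. exp/log/log-train-2025-01-01-00-00-00-0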


class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class will play a role as metrics tracker.
        # It can record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            # `v - v == 0` is False for inf and NaN, so non-finite
            # values are skipped rather than corrupting the sum.
            if v - v == 0:
                ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans_frames = ""
        ans_utterances = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            if "utt_" not in k:
                ans_frames += str(k) + "=" + str(norm_value) + ", "
            else:
                ans_utterances += str(k) + "=" + str(norm_value)
                if k == "utt_duration":
                    ans_utterances += " frames, "
                elif k == "utt_pad_proportion":
                    ans_utterances += ", "
                else:
                    raise ValueError(f"Unexpected key: {k}")
        frames = "%.2f" % self["frames"]
        ans_frames += "over " + str(frames) + " frames. "
        if ans_utterances != "":
            utterances = "%.2f" % self["utterances"]
            ans_utterances += "over " + str(utterances) + " utterances."

        return ans_frames + ans_utterances

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('ctc_loss', 0.1), ('att_loss', 0.07)]
        """
        num_frames = self["frames"] if "frames" in self else 1
        num_utterances = self["utterances"] if "utterances" in self else 1
        ans = []
        for k, v in self.items():
            if k == "frames" or k == "utterances":
                continue
            norm_value = (
                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
            )
            ans.append((k, norm_value))
        return ans

    def reduce(self, device):
        """
        Reduce using torch.distributed, which I believe ensures that
        all processes get the total.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
            tb_writer: a TensorBoard writer
            prefix: a prefix for the name of the loss, e.g. "train/valid_",
                or "train/current_"
            batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)
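
# Usage sketch (illustrative, not part of the original commit; `dl`,
# `num_frames`, `loss`, and `tb_writer` are hypothetical): per-batch stats
# are folded into a decaying running sum and periodically written out:
#
#   tot = MetricsTracker()
#   for batch_idx, batch in enumerate(dl):
#       info = MetricsTracker()
#       info["frames"] = num_frames
#       info["loss"] = loss.detach().item() * num_frames
#       tot = tot * 0.9 + info            # uses __mul__ and __add__
#       if batch_idx % 50 == 0:
#           logging.info(f"tot_loss[{tot}]")
#           tot.write_summary(tb_writer, "train/tot_", batch_idx)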