From ab08201f6c6ff1447cfd0ccc9f37cb4a55e62efc Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Mon, 22 Jan 2024 16:15:56 +0800
Subject: [PATCH] remove model file

---
 egs/aishell/ASR/whisper/model.py | 431 -------------------------------
 1 file changed, 431 deletions(-)
 delete mode 100755 egs/aishell/ASR/whisper/model.py

diff --git a/egs/aishell/ASR/whisper/model.py b/egs/aishell/ASR/whisper/model.py
deleted file mode 100755
index 9ec412513..000000000
--- a/egs/aishell/ASR/whisper/model.py
+++ /dev/null
@@ -1,431 +0,0 @@
-import torch
-import torch.nn as nn
-import base64
-import gzip
-import warnings
-from tqdm import tqdm
-from dataclasses import dataclass
-from typing import Dict, Iterable, Optional, Union
-import os
-import urllib
-import hashlib
-import numpy as np
-
-import torch.nn.functional as F
-from torch import Tensor
-
-from whisper.decoding import decode as decode_function
-from whisper.transcribe import transcribe as transcribe_function
-
-
-@dataclass
-class ModelDimensions:
-    n_mels: int
-    n_audio_ctx: int
-    n_audio_state: int
-    n_audio_head: int
-    n_audio_layer: int
-    n_vocab: int
-    n_text_ctx: int
-    n_text_state: int
-    n_text_head: int
-    n_text_layer: int
-
-
-class LayerNorm(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        return super().forward(x.float()).type(x.dtype)
-
-
-class Linear(nn.Linear):
-    def forward(self, x: Tensor) -> Tensor:
-        return F.linear(
-            x,
-            self.weight.to(x.dtype),
-            None if self.bias is None else self.bias.to(x.dtype),
-        )
-
-
-class Conv1d(nn.Conv1d):
-    def _conv_forward(
-        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
-    ) -> Tensor:
-        return super()._conv_forward(
-            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
-        )
-
-
-def sinusoids(length, channels, max_timescale=10000):
-    """Returns sinusoids for positional embedding"""
-    assert channels % 2 == 0
-    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
-    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
-    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
-    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
-
-
-class MultiHeadAttention(nn.Module):
-    def __init__(self, n_state: int, n_head: int):
-        super().__init__()
-        self.n_head = n_head
-        self.query = Linear(n_state, n_state)
-        self.key = Linear(n_state, n_state, bias=False)
-        self.value = Linear(n_state, n_state)
-        self.out = Linear(n_state, n_state)
-
-    def forward(
-        self,
-        x: Tensor,
-        xa: Optional[Tensor] = None,
-        mask: Optional[Tensor] = None,
-        kv_cache: Optional[dict] = None,
-    ):
-        q = self.query(x)
-
-        if kv_cache is None or xa is None or self.key not in kv_cache:
-            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
-            # otherwise, perform key/value projections for self- or cross-attention as usual.
-            k = self.key(x if xa is None else xa)
-            v = self.value(x if xa is None else xa)
-        else:
-            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache[self.key]
-            v = kv_cache[self.value]
-
-        wv, qk = self.qkv_attention(q, k, v, mask)
-        return self.out(wv), qk
-
-    def qkv_attention(
-        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
-    ):
-        n_batch, n_ctx, n_state = q.shape
-        scale = (n_state // self.n_head) ** -0.25
-        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
-        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
-        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
-
-        qk = q @ k
-        if mask is not None:
-            qk = qk + mask[:n_ctx, :n_ctx]
-        qk = qk.float()
-
-        w = F.softmax(qk, dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
-
-
-class ResidualAttentionBlock(nn.Module):
-    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
-        super().__init__()
-
-        self.attn = MultiHeadAttention(n_state, n_head)
-        self.attn_ln = LayerNorm(n_state)
-
-        self.cross_attn = (
-            MultiHeadAttention(n_state, n_head) if cross_attention else None
-        )
-        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
-
-        n_mlp = n_state * 4
-        self.mlp = nn.Sequential(
-            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
-        )
-        self.mlp_ln = LayerNorm(n_state)
-
-    def forward(
-        self,
-        x: Tensor,
-        xa: Optional[Tensor] = None,
-        mask: Optional[Tensor] = None,
-        kv_cache: Optional[dict] = None,
-    ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
-        if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
-        x = x + self.mlp(self.mlp_ln(x))
-        return x
-
-
-class AudioEncoder(nn.Module):
-    def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
-    ):
-        super().__init__()
-        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
-        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
-        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
-
-        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
-            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
-        )
-        self.ln_post = LayerNorm(n_state)
-
-    def forward(self, x: Tensor):
-        """
-        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
-            the mel spectrogram of the audio
-        """
-        x = F.gelu(self.conv1(x))
-        x = F.gelu(self.conv2(x))
-        x = x.permute(0, 2, 1)
-
-        # change whisper to process audio with any length
-        x = (x + self.positional_embedding[:x.shape[1], :]).to(x.dtype)
-
-        for block in self.blocks:
-            x = block(x)
-
-        x = self.ln_post(x)
-        return x
-
-
-class TextDecoder(nn.Module):
-    def __init__(
-        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
-    ):
-        super().__init__()
-
-        self.token_embedding = nn.Embedding(n_vocab, n_state)
-        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
-
-        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
-            [
-                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
-                for _ in range(n_layer)
-            ]
-        )
-        self.ln = LayerNorm(n_state)
-
-        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
-        self.register_buffer("mask", mask, persistent=False)
-
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-            the text tokens
-        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-            the encoded audio features to be attended on
-        """
-        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
-        x = x.to(xa.dtype)
-
-        for block in self.blocks:
-            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
-
-        x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
-
-        return logits
-
-class Whisper(nn.Module):
-    def __init__(self, dims: ModelDimensions):
-        super().__init__()
-        self.dims = dims
-        self.encoder = AudioEncoder(
-            self.dims.n_mels,
-            self.dims.n_audio_ctx,
-            self.dims.n_audio_state,
-            self.dims.n_audio_head,
-            self.dims.n_audio_layer,
-        )
-        self.decoder = TextDecoder(
-            self.dims.n_vocab,
-            self.dims.n_text_ctx,
-            self.dims.n_text_state,
-            self.dims.n_text_head,
-            self.dims.n_text_layer,
-        )
-        # use the last half layers for alignment by default; see `set_alignment_heads()` below
-        all_heads = torch.zeros(
-            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
-        )
-        all_heads[self.dims.n_text_layer // 2 :] = True
-        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
-
-    def set_alignment_heads(self, dump: bytes):
-        array = np.frombuffer(
-            gzip.decompress(base64.b85decode(dump)), dtype=bool
-        ).copy()
-        mask = torch.from_numpy(array).reshape(
-            self.dims.n_text_layer, self.dims.n_text_head
-        )
-        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
-
-    def embed_audio(self, mel: torch.Tensor):
-        return self.encoder(mel)
-
-    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
-        return self.decoder(tokens, audio_features)
-
-    def forward(
-        self, mel: torch.Tensor, tokens: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
-        return self.decoder(tokens, self.encoder(mel))
-
-    @property
-    def device(self):
-        return next(self.parameters()).device
-
-    @property
-    def is_multilingual(self):
-        return self.dims.n_vocab >= 51865
-
-    @property
-    def num_languages(self):
-        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
-
-    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
-        """
-        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
-        tensors calculated for the previous positions. This method returns a dictionary that stores
-        all caches, and the necessary hooks for the key and value projection modules that save the
-        intermediate tensors to be reused during later calculations.
-
-        Returns
-        -------
-        cache : Dict[nn.Module, torch.Tensor]
-            A dictionary object mapping the key/value projection modules to its cache
-        hooks : List[RemovableHandle]
-            List of PyTorch RemovableHandle objects to stop the hooks to be called
-        """
-        cache = {**cache} if cache is not None else {}
-        hooks = []
-
-        def save_to_cache(module, _, output):
-            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
-                # save as-is, for the first token or cross attention
-                cache[module] = output
-            else:
-                cache[module] = torch.cat([cache[module], output], dim=1).detach()
-            return cache[module]
-
-        def install_hooks(layer: nn.Module):
-            if isinstance(layer, MultiHeadAttention):
-                hooks.append(layer.key.register_forward_hook(save_to_cache))
-                hooks.append(layer.value.register_forward_hook(save_to_cache))
-
-        self.decoder.apply(install_hooks)
-        return cache, hooks
-
-    transcribe = transcribe_function
-    decode = decode_function
-
-_MODELS = {
-    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
-    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
-    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
-    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
-    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
-    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
-    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
-    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
-    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
-    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
-    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
-}
-
-def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
-    os.makedirs(root, exist_ok=True)
-
-    expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, os.path.basename(url))
-
-    if os.path.exists(download_target) and not os.path.isfile(download_target):
-        raise RuntimeError(f"{download_target} exists and is not a regular file")
-
-    if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes if in_memory else download_target
-        else:
-            warnings.warn(
-                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
-            )
-
-    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(
-            total=int(source.info().get("Content-Length")),
-            ncols=80,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as loop:
-            while True:
-                buffer = source.read(8192)
-                if not buffer:
-                    break
-
-                output.write(buffer)
-                loop.update(len(buffer))
-
-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-        raise RuntimeError(
-            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
-        )
-
-    return model_bytes if in_memory else download_target
-
-def load_model(
-    name: str,
-    device: Optional[Union[str, torch.device]] = 'cpu',
-    download_root: str = None,
-    in_memory: bool = False,
-) -> Whisper:
-    """
-    Load a Whisper ASR model
-
-    Parameters
-    ----------
-    name : str
-        one of the official model names listed by `whisper.available_models()`, or
-        path to a model checkpoint containing the model dimensions and the model state_dict.
-    device : Union[str, torch.device]
-        the PyTorch device to put the model into
-    download_root: str
-        path to download the model files; by default, it uses "~/.cache/whisper"
-    in_memory: bool
-        whether to preload the model weights into host memory
-
-    Returns
-    -------
-    model : Whisper
-        The Whisper ASR model instance
-    """
-
-    # if device is None:
-    #     device = "cuda" if torch.cuda.is_available() else "cpu"
-    if download_root is None:
-        default = os.path.join(os.path.expanduser("~"), ".cache")
-        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
-
-    if name in _MODELS:
-        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
-        # alignment_heads = _ALIGNMENT_HEADS[name]
-        alignment_heads = None
-    elif os.path.isfile(name):
-        checkpoint_file = open(name, "rb").read() if in_memory else name
-        alignment_heads = None
-    else:
-        raise RuntimeError(
-            f"Model {name} not found; available models = {available_models()}"
-        )
-
-    with (
-        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
-    ) as fp:
-        checkpoint = torch.load(fp, map_location=device)
-    del checkpoint_file
-
-    dims = ModelDimensions(**checkpoint["dims"])
-    model = Whisper(dims)
-    model.load_state_dict(checkpoint["model_state_dict"])
-
-    return model.to(device)
\ No newline at end of file
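
For reference, the file removed above is a lightly modified copy of openai-whisper's
model.py (the encoder slices its positional embedding so audio of any length can be
processed, and load_model defaults to CPU). Before this removal it could be driven
roughly as follows. This is a minimal sketch, not part of the patch, assuming a recent
openai-whisper package is installed, the removed file is importable as `model`, and a
hypothetical test.wav exists; the greedy loop illustrates the install_kv_cache_hooks()
mechanism described in that method's docstring:

    import torch
    import whisper  # openai-whisper: audio utilities and tokenizer

    from model import load_model  # the module removed by this patch

    model = load_model("base", device="cpu")

    # Encode the (padded/trimmed) mel spectrogram once.
    audio = whisper.pad_or_trim(whisper.load_audio("test.wav"))  # hypothetical file
    mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)  # (1, n_mels, 3000)
    audio_features = model.embed_audio(mel)

    tokenizer = whisper.tokenizer.get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language="zh",  # aishell is Mandarin
        task="transcribe",
    )
    tokens = torch.tensor([list(tokenizer.sot_sequence)])

    # The forward hooks append each step's key/value projections to `cache`,
    # so after the first step only the newest token is fed to the decoder.
    cache, hooks = model.install_kv_cache_hooks()
    try:
        for _ in range(64):
            logits = model.decoder(
                tokens[:, -1:] if cache else tokens, audio_features, kv_cache=cache
            )
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
            if next_token.item() == tokenizer.eot:
                break
    finally:
        for hook in hooks:  # detach the forward hooks once decoding is done
            hook.remove()

    print(tokenizer.decode(tokens[0, len(tokenizer.sot_sequence):].tolist()))

In practice the higher-level model.transcribe(...) / model.decode(...) entry points,
bound from whisper.transcribe and whisper.decoding at the bottom of the Whisper class,
wrap this loop and add language detection and timestamp handling.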