import base64
import gzip
import hashlib
import io
import os
import urllib.request
import warnings
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm

from whisper.decoding import decode as decode_function
from whisper.transcribe import transcribe as transcribe_function


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
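
# Note (illustrative only, not enforced by the code): for the "base" configuration the
# audio encoder uses n_audio_ctx=1500 and n_audio_state=512, so `sinusoids(1500, 512)`
# returns a (1500, 512) tensor whose first 256 channels are sines and last 256 are
# cosines of geometrically spaced frequencies; this is the buffer registered in
# `AudioEncoder` below.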
class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        # modified: allow Whisper to process audio of any length by slicing the
        # positional embedding to the actual number of audio frames
        x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x
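
# Shape sketch for `AudioEncoder.forward` (the concrete numbers assume Whisper's standard
# 30-second window with a 10 ms hop and are illustrative only):
#   mel input:                           (batch, n_mels, 3000)
#   after conv2 (stride 2) and permute:  (batch, 1500, n_state)
# With the positional-embedding slicing above, shorter inputs simply use the first
# x.shape[1] rows of the sinusoidal table instead of requiring exactly n_ctx frames.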
class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half layers for alignment by default; see `set_alignment_heads()` below
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to their cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks from being called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    transcribe = transcribe_function
    decode = decode_function
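
# A minimal sketch of how the kv-cache hooks are meant to be used for incremental decoding
# (`mel`, `tokens`, and `next_token` are placeholder tensors, not names defined in this file):
#
#   kv_cache, hooks = model.install_kv_cache_hooks()
#   audio_features = model.embed_audio(mel)                                # encoder runs once
#   logits = model.decoder(tokens, audio_features, kv_cache=kv_cache)      # full prompt
#   logits = model.decoder(next_token, audio_features, kv_cache=kv_cache)  # one new token
#   for hook in hooks:
#       hook.remove()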
_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target
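
# `load_model` below refers to `available_models()`; a minimal definition is provided here,
# mirroring the helper of the same name in the upstream whisper package.
def available_models() -> list:
    """Returns the names of the available official models."""
    return list(_MODELS.keys())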
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = "cpu",
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """
    # if device is None:
    #     device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        # the _ALIGNMENT_HEADS table is not included in this standalone module, so the
        # default alignment heads set in Whisper.__init__ are kept.
        # alignment_heads = _ALIGNMENT_HEADS[name]
        alignment_heads = None
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    return model.to(device)
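

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: "tiny" is only an example model name, the audio
    # path below is a placeholder, and `transcribe` relies on the rest of the whisper
    # package (audio loading, tokenizer) imported at the top of this file.
    model = load_model("tiny", device="cpu")
    print(f"multilingual: {model.is_multilingual}, num_languages: {model.num_languages}")
    # result = model.transcribe("path/to/audio.wav")  # placeholder path
    # print(result["text"])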