add decode for wenetspeech

This commit is contained in:
Yuekai Zhang 2024-01-19 17:41:53 +08:00
parent 046e071ca3
commit 72c9d01724
3 changed files with 923 additions and 0 deletions

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1,491 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
# Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import whisper
from whisper.normalizers import BasicTextNormalizer
import k2
import torch
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from model import load_model
from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
write_error_stats,
)
from zhconv import convert
from tn.chinese.normalizer import Normalizer
import re
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
"""Average a list of checkpoints.
Args:
filenames:
Filenames of the checkpoints to be averaged. We assume all
checkpoints are saved by :func:`save_checkpoint`.
device:
Move checkpoints to this device before averaging.
Returns:
Return a dict (i.e., state_dict) which is the average of all
model state dicts contained in the checkpoints.
"""
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
avg = torch.load(filenames[0], map_location=device)["model"]
else:
avg = torch.load(filenames[0], map_location=device)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
state_dict = torch.load(filenames[i], map_location=device)["model"]
else:
state_dict = torch.load(filenames[i], map_location=device)
for k in uniqued_names:
avg[k] += state_dict[k]
for k in uniqued_names:
if avg[k].is_floating_point():
avg[k] /= n
else:
avg[k] //= n
return avg
def remove_punctuation(text: str or List[str]):
# https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
punctuation = '!,.;:?、!,。;:?'
if isinstance(text, str):
text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
result_text.append(t)
return result_text
else:
raise Exception(f'Not support type {type(text)}')
def to_simple(text: str or List[str]):
if isinstance(text, str):
text = convert(text, 'zh-cn')
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = convert(t, 'zh-cn')
result_text.append(t)
return result_text
else:
raise Exception(f'Not support type{type(text)}')
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=-1,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="beam-search",
help="""Decoding method.
Supported values are:
- beam-search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=1,
help="beam size for beam search decoding",
)
parser.add_argument(
"--exp-dir",
type=str,
default="whisper/exp",
help="The experiment dir",
)
parser.add_argument(
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
help="""The model name to use.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"env_info": get_env_info(),
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if decoding method is 1best, the key is the string `no_rescore`.
If attention rescoring is used, the key is the string
`ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
value of `lm_scale` and `attention_scale`. An example key is
`ngram_lm_scale_0.7_attention_scale_0.5`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "attention-decoder", it uses attention rescoring.
model:
The neural model.
HLG:
The decoding graph. Used when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains the token symbol table and the word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
dtype = torch.float16
device = torch.device("cuda")
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype).transpose(1, 2)
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
results = model.decode(feature, params.decoding_options)
hyps = [result.text for result in results]
hyps = remove_punctuation(hyps)
hyps = to_simple(hyps)
hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
key = "beam-search"
return {key: hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph. Used when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
lexicon:
It contains the token symbol table and the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
Returns:
Return a dict, whose key may be "no-rescore" if the decoding method is
1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention
rescoring is used. Its value is a list of tuples. Each tuple contains two
elements: The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
batch=batch,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
params.decoding_options = options
params.cleaner = BasicTextNormalizer()
params.normalizer = Normalizer()
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
logging.info(f"device: {device}")
model = load_model(params.model_name)
if params.epoch > 0:
if params.avg > 1:
start = params.epoch - params.avg
assert start >= 1, start
checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
if 'model' not in checkpoint:
filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
model.load_state_dict(average_checkpoints(filenames))
else:
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
# save checkpoints
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(model.state_dict(), filename)
else:
checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
if 'model' not in checkpoint:
model.load_state_dict(checkpoint, strict=True)
else:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args)
def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
)
return T > 0
# dev_cuts = wenetspeech.valid_cuts()
# dev_cuts = dev_cuts.filter(remove_short_utt)
# dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
# test_net_cuts = wenetspeech.test_net_cuts()
# test_net_cuts = test_net_cuts.filter(remove_short_utt)
# test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
test_meeting_cuts = wenetspeech.test_meeting_cuts()
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
# test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
# test_dls = [dev_dl, test_net_dl, test_meeting_dl]
test_sets = ["TEST_MEETING"]
test_dls = [test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,431 @@
import torch
import torch.nn as nn
import base64
import gzip
import warnings
from tqdm import tqdm
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Union
import os
import urllib
import hashlib
import numpy as np
import torch.nn.functional as F
from torch import Tensor
from whisper.decoding import decode as decode_function
from whisper.transcribe import transcribe as transcribe_function
@dataclass
class ModelDimensions:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
# change whisper to process audio with any length
x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
return x
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half layers for alignment by default; see `set_alignment_heads()` below
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = torch.from_numpy(array).reshape(
self.dims.n_text_layer, self.dims.n_text_head
)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
def device(self):
return next(self.parameters()).device
@property
def is_multilingual(self):
return self.dims.n_vocab >= 51865
@property
def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
# save as-is, for the first token or cross attention
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
transcribe = transcribe_function
decode = decode_function
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = 'cpu',
download_root: str = None,
in_memory: bool = False,
) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
# if device is None:
# device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
# alignment_heads = _ALIGNMENT_HEADS[name]
alignment_heads = None
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
return model.to(device)