mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
adding ILM beam search and decoding
This commit is contained in:
parent
109354b6b8
commit
5c4a7ffd1e
1
egs/librispeech/ASR/zipformer_hat/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../tdnn_lstm_ctc/asr_datamodule.py
|
803
egs/librispeech/ASR/zipformer_hat/beam_search.py
Normal file
803
egs/librispeech/ASR/zipformer_hat/beam_search.py
Normal file
@ -0,0 +1,803 @@
|
|||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
|
# Xiaoyu Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
|
||||||
|
from icefall.decode import Nbest, one_best_decoding
|
||||||
|
from icefall.lm_wrapper import LmScorer
|
||||||
|
from icefall.rnn_lm.model import RnnLmModel
|
||||||
|
from icefall.transformer_lm.model import TransformerLM
|
||||||
|
from icefall.utils import (
|
||||||
|
DecodingResults,
|
||||||
|
add_eos,
|
||||||
|
add_sos,
|
||||||
|
get_texts,
|
||||||
|
get_texts_with_timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search_batch(
|
||||||
|
model: nn.Module,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
blank_penalty: float = 0,
|
||||||
|
return_timestamps: bool = False,
|
||||||
|
) -> Union[List[List[int]], DecodingResults]:
|
||||||
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
|
return_timestamps:
|
||||||
|
Whether to return timestamps.
|
||||||
|
Returns:
|
||||||
|
If return_timestamps is False, return the decoded result.
|
||||||
|
Else, return a DecodingResults object containing
|
||||||
|
decoded result and corresponding timestamps.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
|
||||||
|
|
||||||
|
# timestamp[n][i] is the frame index after subsampling
|
||||||
|
# on which hyp[n][i] is decoded
|
||||||
|
timestamps = [[] for _ in range(N)]
|
||||||
|
# scores[n][i] is the logits on which hyp[n][i] is decoded
|
||||||
|
scores = [[] for _ in range(N)]
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
hyps,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (N, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
# decoder_out: (N, 1, decoder_out_dim)
|
||||||
|
|
||||||
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
for t, batch_size in enumerate(batch_size_list):
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
|
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
decoder_out = decoder_out[:batch_size]
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||||
|
)
|
||||||
|
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||||
|
assert logits.ndim == 2, logits.shape
|
||||||
|
|
||||||
|
if blank_penalty != 0:
|
||||||
|
logits[:, 0] -= blank_penalty
|
||||||
|
|
||||||
|
# If logit for blank token is positive, the output should be blank (Bernoulli)
|
||||||
|
y = torch.zeros_like(logits[:, 0], dtype=torch.int64, device=device)
|
||||||
|
# If logit for blank token is negative, the output should be the argmax
|
||||||
|
# of the rest of the logits
|
||||||
|
y += torch.where(logits[:, 0] <= 0, logits[:, 1:].argmax(dim=1) + 1, 0)
|
||||||
|
# Convert y to list
|
||||||
|
y = y.tolist()
|
||||||
|
|
||||||
|
emitted = False
|
||||||
|
for i, v in enumerate(y):
|
||||||
|
if v not in (blank_id, unk_id):
|
||||||
|
hyps[i].append(v)
|
||||||
|
timestamps[i].append(t)
|
||||||
|
scores[i].append(logits[i, v].item())
|
||||||
|
emitted = True
|
||||||
|
if emitted:
|
||||||
|
# update decoder output
|
||||||
|
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
decoder_input,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
|
||||||
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
|
ans = []
|
||||||
|
ans_timestamps = []
|
||||||
|
ans_scores = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
ans_timestamps.append(timestamps[unsorted_indices[i]])
|
||||||
|
ans_scores.append(scores[unsorted_indices[i]])
|
||||||
|
|
||||||
|
if not return_timestamps:
|
||||||
|
return ans
|
||||||
|
else:
|
||||||
|
return DecodingResults(
|
||||||
|
hyps=ans,
|
||||||
|
timestamps=ans_timestamps,
|
||||||
|
scores=ans_scores,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Hypothesis:
|
||||||
|
# The predicted tokens so far.
|
||||||
|
# Newly predicted tokens are appended to `ys`.
|
||||||
|
ys: List[int]
|
||||||
|
|
||||||
|
# The log prob of ys.
|
||||||
|
# It contains only one entry.
|
||||||
|
log_prob: torch.Tensor
|
||||||
|
|
||||||
|
# timestamp[i] is the frame index after subsampling
|
||||||
|
# on which ys[i] is decoded
|
||||||
|
timestamp: List[int] = field(default_factory=list)
|
||||||
|
|
||||||
|
# the lm score for next token given the current ys
|
||||||
|
lm_score: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# the RNNLM states (h and c in LSTM)
|
||||||
|
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||||
|
|
||||||
|
# N-gram LM state
|
||||||
|
state_cost: Optional[NgramLmStateCost] = None
|
||||||
|
|
||||||
|
# Context graph state
|
||||||
|
context_state: Optional[ContextState] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def key(self) -> str:
|
||||||
|
"""Return a string representation of self.ys"""
|
||||||
|
return "_".join(map(str, self.ys))
|
||||||
|
|
||||||
|
|
||||||
|
class HypothesisList(object):
|
||||||
|
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
data:
|
||||||
|
A dict of Hypotheses. Its key is its `value.key`.
|
||||||
|
"""
|
||||||
|
if data is None:
|
||||||
|
self._data = {}
|
||||||
|
else:
|
||||||
|
self._data = data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> Dict[str, Hypothesis]:
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
def add(self, hyp: Hypothesis) -> None:
|
||||||
|
"""Add a Hypothesis to `self`.
|
||||||
|
|
||||||
|
If `hyp` already exists in `self`, its probability is updated using
|
||||||
|
`log-sum-exp` with the existed one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hyp:
|
||||||
|
The hypothesis to be added.
|
||||||
|
"""
|
||||||
|
key = hyp.key
|
||||||
|
if key in self:
|
||||||
|
old_hyp = self._data[key] # shallow copy
|
||||||
|
torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
|
||||||
|
else:
|
||||||
|
self._data[key] = hyp
|
||||||
|
|
||||||
|
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
|
||||||
|
"""Get the most probable hypothesis, i.e., the one with
|
||||||
|
the largest `log_prob`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length_norm:
|
||||||
|
If True, the `log_prob` of a hypothesis is normalized by the
|
||||||
|
number of tokens in it.
|
||||||
|
Returns:
|
||||||
|
Return the hypothesis that has the largest `log_prob`.
|
||||||
|
"""
|
||||||
|
if length_norm:
|
||||||
|
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
|
||||||
|
else:
|
||||||
|
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
|
||||||
|
|
||||||
|
def remove(self, hyp: Hypothesis) -> None:
|
||||||
|
"""Remove a given hypothesis.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
`self` is modified **in-place**.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hyp:
|
||||||
|
The hypothesis to be removed from `self`.
|
||||||
|
Note: It must be contained in `self`. Otherwise,
|
||||||
|
an exception is raised.
|
||||||
|
"""
|
||||||
|
key = hyp.key
|
||||||
|
assert key in self, f"{key} does not exist"
|
||||||
|
del self._data[key]
|
||||||
|
|
||||||
|
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||||
|
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
`self` is not modified. Instead, a new HypothesisList is returned.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a new HypothesisList containing all hypotheses from `self`
|
||||||
|
with `log_prob` being greater than the given `threshold`.
|
||||||
|
"""
|
||||||
|
ans = HypothesisList()
|
||||||
|
for _, hyp in self._data.items():
|
||||||
|
if hyp.log_prob > threshold:
|
||||||
|
ans.add(hyp) # shallow copy
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
|
||||||
|
"""Return the top-k hypothesis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length_norm:
|
||||||
|
If True, the `log_prob` of a hypothesis is normalized by the
|
||||||
|
number of tokens in it.
|
||||||
|
"""
|
||||||
|
hyps = list(self._data.items())
|
||||||
|
|
||||||
|
if length_norm:
|
||||||
|
hyps = sorted(
|
||||||
|
hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
|
||||||
|
)[:k]
|
||||||
|
else:
|
||||||
|
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
|
||||||
|
|
||||||
|
ans = HypothesisList(dict(hyps))
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def __contains__(self, key: str):
|
||||||
|
return key in self._data
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._data.values())
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
s = []
|
||||||
|
for key in self:
|
||||||
|
s.append(key)
|
||||||
|
return ", ".join(s)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||||
|
"""Return a ragged shape with axes [utt][num_hyps].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hyps:
|
||||||
|
len(hyps) == batch_size. It contains the current hypothesis for
|
||||||
|
each utterance in the batch.
|
||||||
|
Returns:
|
||||||
|
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
|
||||||
|
the shape is on CPU.
|
||||||
|
"""
|
||||||
|
num_hyps = [len(h) for h in hyps]
|
||||||
|
|
||||||
|
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
|
||||||
|
# to get exclusive sum later.
|
||||||
|
num_hyps.insert(0, 0)
|
||||||
|
|
||||||
|
num_hyps = torch.tensor(num_hyps)
|
||||||
|
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
|
||||||
|
ans = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
|
||||||
|
)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search(
|
||||||
|
model: nn.Module,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: int = 4,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
return_timestamps: bool = False,
|
||||||
|
) -> Union[List[List[int]], DecodingResults]:
|
||||||
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
|
beam:
|
||||||
|
Number of active paths during the beam search.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
|
return_timestamps:
|
||||||
|
Whether to return timestamps.
|
||||||
|
Returns:
|
||||||
|
If return_timestamps is False, return the decoded result.
|
||||||
|
Else, return a DecodingResults object containing
|
||||||
|
decoded result and corresponding timestamps.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
|
B[i].add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
timestamp=[],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
finalized_B = []
|
||||||
|
for (t, batch_size) in enumerate(batch_size_list):
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
|
||||||
|
hyps_shape = get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
|
A = [list(b) for b in B]
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat(
|
||||||
|
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||||
|
) # (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||||
|
|
||||||
|
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||||
|
# as index, so we use `to(torch.int64)` below.
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
current_encoder_out,
|
||||||
|
dim=0,
|
||||||
|
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
project_input=False,
|
||||||
|
) # (num_hyps, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
# For blank symbol, log-prob is log-sigmoid of the score
|
||||||
|
logp_b = torch.nn.functional.logsigmoid(logits[..., 0])
|
||||||
|
# Additionally, to ensure the the probs of blank and non-blank sum to 1, we
|
||||||
|
# need to add the following term to the log-probs of non-blank symbols. This
|
||||||
|
# is equivalent to log(1 - sigmoid(logits[..., 0])).
|
||||||
|
nb_shift = logp_b - logits[..., 0]
|
||||||
|
nb_shift = nb_shift.unsqueeze(-1)
|
||||||
|
log_probs1 = (logits[..., 1:] / temperature).log_softmax(
|
||||||
|
dim=-1
|
||||||
|
) + nb_shift # (num_hyps, vocab_size-1)
|
||||||
|
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
vocab_size = log_probs.size(-1)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
|
||||||
|
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||||
|
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||||
|
)
|
||||||
|
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
new_ys = hyp.ys[:]
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
new_timestamp = hyp.timestamp[:]
|
||||||
|
if new_token not in (blank_id, unk_id):
|
||||||
|
new_ys.append(new_token)
|
||||||
|
new_timestamp.append(t)
|
||||||
|
|
||||||
|
new_log_prob = topk_log_probs[k]
|
||||||
|
new_hyp = Hypothesis(
|
||||||
|
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||||
|
)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
B = B + finalized_B
|
||||||
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
|
|
||||||
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
sorted_timestamps = [h.timestamp for h in best_hyps]
|
||||||
|
ans = []
|
||||||
|
ans_timestamps = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
|
||||||
|
|
||||||
|
if not return_timestamps:
|
||||||
|
return ans
|
||||||
|
else:
|
||||||
|
return DecodingResults(
|
||||||
|
hyps=ans,
|
||||||
|
timestamps=ans_timestamps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search_lm_shallow_fusion(
|
||||||
|
model: nn.Module,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
LM: LmScorer,
|
||||||
|
beam: int = 4,
|
||||||
|
return_timestamps: bool = False,
|
||||||
|
subtract_ilm: bool = True,
|
||||||
|
ilm_scale: float = 0.1,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Modified_beam_search + NN LM shallow fusion
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (Transducer):
|
||||||
|
The transducer model
|
||||||
|
encoder_out (torch.Tensor):
|
||||||
|
Encoder output in (N,T,C)
|
||||||
|
encoder_out_lens (torch.Tensor):
|
||||||
|
A 1-D tensor of shape (N,), containing the number of
|
||||||
|
valid frames in encoder_out before padding.
|
||||||
|
sp:
|
||||||
|
Sentence piece generator.
|
||||||
|
LM (LmScorer):
|
||||||
|
A neural net LM, e.g RNN or Transformer
|
||||||
|
beam (int, optional):
|
||||||
|
Beam size. Defaults to 4.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||||
|
for the i-th utterance.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
assert LM is not None
|
||||||
|
lm_scale = LM.lm_scale
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
sos_id = getattr(LM, "sos_id", 1)
|
||||||
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
# get initial lm score and lm state by scoring the "sos" token
|
||||||
|
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
|
||||||
|
lens = torch.tensor([1]).to(device)
|
||||||
|
init_score, init_states = LM.score_token(sos_token, lens)
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
|
B[i].add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
state=init_states,
|
||||||
|
lm_score=init_score.reshape(-1),
|
||||||
|
timestamp=[],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
finalized_B = []
|
||||||
|
for (t, batch_size) in enumerate(batch_size_list):
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end] # get batch
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
|
||||||
|
hyps_shape = get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
|
A = [list(b) for b in B]
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat(
|
||||||
|
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
current_encoder_out,
|
||||||
|
dim=0,
|
||||||
|
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
project_input=False,
|
||||||
|
) # (num_hyps, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
# For blank symbol, log-prob is log-sigmoid of the score
|
||||||
|
logp_b = torch.nn.functional.logsigmoid(logits[..., 0])
|
||||||
|
# Additionally, to ensure the the probs of blank and non-blank sum to 1, we
|
||||||
|
# need to add the following term to the log-probs of non-blank symbols. This
|
||||||
|
# is equivalent to log(1 - sigmoid(logits[..., 0])).
|
||||||
|
nb_shift = logp_b - logits[..., 0]
|
||||||
|
nb_shift = nb_shift.unsqueeze(-1)
|
||||||
|
log_probs1 = (logits[..., 1:]).log_softmax(dim=-1) + nb_shift
|
||||||
|
if subtract_ilm:
|
||||||
|
ilm_logits = model.joiner(
|
||||||
|
torch.zeros_like(
|
||||||
|
current_encoder_out, device=current_encoder_out.device
|
||||||
|
),
|
||||||
|
decoder_out,
|
||||||
|
project_input=False,
|
||||||
|
)
|
||||||
|
ilm_logits = ilm_logits.squeeze(1).squeeze(1)
|
||||||
|
ilm_logp_b = torch.nn.functional.logsigmoid(ilm_logits[..., 0])
|
||||||
|
ilm_nb_shift = ilm_logp_b - ilm_logits[..., 0]
|
||||||
|
ilm_nb_shift = ilm_nb_shift.unsqueeze(-1)
|
||||||
|
ilm_log_probs = (ilm_logits[..., 1:]).log_softmax(dim=-1) + ilm_nb_shift
|
||||||
|
log_probs1 -= ilm_scale * ilm_log_probs
|
||||||
|
|
||||||
|
log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1)
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
vocab_size = log_probs.size(-1)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
|
||||||
|
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||||
|
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||||
|
)
|
||||||
|
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
|
||||||
|
|
||||||
|
"""
|
||||||
|
for all hyps with a non-blank new token, score this token.
|
||||||
|
It is a little confusing here because this for-loop
|
||||||
|
looks very similar to the one below. Here, we go through all
|
||||||
|
top-k tokens and only add the non-blanks ones to the token_list.
|
||||||
|
`LM` will score those tokens given the LM states. Note that
|
||||||
|
the variable `scores` is the LM score after seeing the new
|
||||||
|
non-blank token.
|
||||||
|
"""
|
||||||
|
token_list = [] # a list of list
|
||||||
|
hs = []
|
||||||
|
cs = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
if new_token not in (blank_id, unk_id):
|
||||||
|
if LM.lm_type == "rnn":
|
||||||
|
token_list.append([new_token])
|
||||||
|
# store the LSTM states
|
||||||
|
hs.append(hyp.state[0])
|
||||||
|
cs.append(hyp.state[1])
|
||||||
|
else:
|
||||||
|
# for transformer LM
|
||||||
|
token_list.append(
|
||||||
|
[sos_id] + hyp.ys[context_size:] + [new_token]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(token_list) != 0:
|
||||||
|
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
|
||||||
|
if LM.lm_type == "rnn":
|
||||||
|
tokens_to_score = (
|
||||||
|
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
|
||||||
|
)
|
||||||
|
hs = torch.cat(hs, dim=1).to(device)
|
||||||
|
cs = torch.cat(cs, dim=1).to(device)
|
||||||
|
state = (hs, cs)
|
||||||
|
else:
|
||||||
|
# for transformer LM
|
||||||
|
tokens_list = [torch.tensor(tokens) for tokens in token_list]
|
||||||
|
tokens_to_score = (
|
||||||
|
torch.nn.utils.rnn.pad_sequence(
|
||||||
|
tokens_list, batch_first=True, padding_value=0.0
|
||||||
|
)
|
||||||
|
.to(device)
|
||||||
|
.to(torch.int64)
|
||||||
|
)
|
||||||
|
|
||||||
|
state = None
|
||||||
|
|
||||||
|
scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
|
||||||
|
|
||||||
|
count = 0 # index, used to locate score and lm states
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
ys = hyp.ys[:]
|
||||||
|
|
||||||
|
lm_score = hyp.lm_score
|
||||||
|
state = hyp.state
|
||||||
|
|
||||||
|
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
new_timestamp = hyp.timestamp[:]
|
||||||
|
if new_token not in (blank_id, unk_id):
|
||||||
|
|
||||||
|
ys.append(new_token)
|
||||||
|
new_timestamp.append(t)
|
||||||
|
|
||||||
|
hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score
|
||||||
|
|
||||||
|
lm_score = scores[count]
|
||||||
|
if LM.lm_type == "rnn":
|
||||||
|
state = (
|
||||||
|
lm_states[0][:, count, :].unsqueeze(1),
|
||||||
|
lm_states[1][:, count, :].unsqueeze(1),
|
||||||
|
)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
new_hyp = Hypothesis(
|
||||||
|
ys=ys,
|
||||||
|
log_prob=hyp_log_prob,
|
||||||
|
state=state,
|
||||||
|
lm_score=lm_score,
|
||||||
|
timestamp=new_timestamp,
|
||||||
|
)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
B = B + finalized_B
|
||||||
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
|
|
||||||
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
sorted_timestamps = [h.timestamp for h in best_hyps]
|
||||||
|
ans = []
|
||||||
|
ans_timestamps = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
|
||||||
|
|
||||||
|
if not return_timestamps:
|
||||||
|
return ans
|
||||||
|
else:
|
||||||
|
return DecodingResults(
|
||||||
|
hyps=ans,
|
||||||
|
timestamps=ans_timestamps,
|
||||||
|
)
|
908
egs/librispeech/ASR/zipformer_hat/decode.py
Executable file
908
egs/librispeech/ASR/zipformer_hat/decode.py
Executable file
@ -0,0 +1,908 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
(1) greedy search
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
|
||||||
|
(2) beam search (not recommended)
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(3) modified beam search
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) fast beam search (one best)
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
|
|
||||||
|
(5) fast beam search (nbest)
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (nbest oracle WER)
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(7) fast beam search (with LG)
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_LG \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from beam_search import (
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
modified_beam_search_lm_shallow_fusion,
|
||||||
|
)
|
||||||
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
|
from icefall import ContextGraph, LmScorer, NgramLm
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
make_pad_mask,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipformer/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=Path,
|
||||||
|
default="data/lang_bpe_500",
|
||||||
|
help="The lang dir containing word table and LG graph",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="""Possible values are:
|
||||||
|
- greedy_search
|
||||||
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
|
- modified_beam_search_LODR
|
||||||
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest
|
||||||
|
- fast_beam_search_nbest_oracle
|
||||||
|
- fast_beam_search_nbest_LG
|
||||||
|
If you use fast_beam_search_nbest_LG, you have to specify
|
||||||
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam-size",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=20.0,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
|
and fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding-method is fast_beam_search_nbest_LG.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
|
and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
|
and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Maximum number of symbols per frame.
|
||||||
|
Used only when --decoding-method is greedy_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Number of paths for nbest decoding.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-shallow-fusion",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Use neural network LM for shallow fusion.
|
||||||
|
If you want to use LODR, you will also need to set this to true
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-type",
|
||||||
|
type=str,
|
||||||
|
default="rnn",
|
||||||
|
help="Type of NN lm",
|
||||||
|
choices=["rnn", "transformer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="""The scale of the neural network LM
|
||||||
|
Used only when `--use-shallow-fusion` is set to True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--subtract-ilm",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Subtract the ILME LM score from the NN LM score.
|
||||||
|
Used only when `--use-shallow-fusion` is set to True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ilm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="""The scale of the ILME LM that will be subtracted.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens-ngram",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="""The order of the ngram lm.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--backoff-id",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="ID of the backoff symbol in the ngram LM",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-score",
|
||||||
|
type=float,
|
||||||
|
default=2,
|
||||||
|
help="""
|
||||||
|
The bonus score of each token for the context biasing words/phrases.
|
||||||
|
Used only when --decoding-method is modified_beam_search and
|
||||||
|
modified_beam_search_LODR.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-file",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The path of the context biasing lists, one word/phrase each line
|
||||||
|
Used only when --decoding-method is modified_beam_search and
|
||||||
|
modified_beam_search_LODR.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
batch: dict,
|
||||||
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
|
ngram_lm=None,
|
||||||
|
ngram_lm_scale: float = 0.0,
|
||||||
|
) -> Dict[str, List[List[str]]]:
|
||||||
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
|
following format:
|
||||||
|
|
||||||
|
- key: It indicates the setting used for decoding. For example,
|
||||||
|
if greedy_search is used, it would be "greedy_search"
|
||||||
|
If beam search with a beam size of 7 is used, it would be
|
||||||
|
"beam_7"
|
||||||
|
- value: It contains the decoding result. `len(value)` equals to
|
||||||
|
batch size. `value[i]` is the decoding result for the i-th
|
||||||
|
utterance in the given batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
|
LM:
|
||||||
|
A neural network language model.
|
||||||
|
ngram_lm:
|
||||||
|
A ngram language model
|
||||||
|
ngram_lm_scale:
|
||||||
|
The scale for the ngram language model.
|
||||||
|
Returns:
|
||||||
|
Return the decoding result. See above description for the format of
|
||||||
|
the returned dict.
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
if params.causal:
|
||||||
|
# this seems to cause insertions at the end of the utterance if used with zipformer.
|
||||||
|
pad_len = 30
|
||||||
|
feature_lens += pad_len
|
||||||
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, pad_len),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
|
||||||
|
|
||||||
|
hyps = []
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
# context_graph=context_graph,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
|
||||||
|
hyp_tokens = modified_beam_search_lm_shallow_fusion(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LM=LM,
|
||||||
|
subtract_ilm=params.subtract_ilm,
|
||||||
|
ilm_scale=params.ilm_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
return {"greedy_search": hyps}
|
||||||
|
elif "modified_beam_search" in params.decoding_method:
|
||||||
|
prefix = f"beam_size_{params.beam_size}"
|
||||||
|
if params.has_contexts:
|
||||||
|
prefix += f"-context-score-{params.context_score}"
|
||||||
|
return {prefix: hyps}
|
||||||
|
else:
|
||||||
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
|
ngram_lm=None,
|
||||||
|
ngram_lm_scale: float = 0.0,
|
||||||
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
log_interval = 50
|
||||||
|
else:
|
||||||
|
log_interval = 20
|
||||||
|
|
||||||
|
results = defaultdict(list)
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
|
hyps_dict = decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
|
word_table=word_table,
|
||||||
|
batch=batch,
|
||||||
|
LM=LM,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, hyps in hyps_dict.items():
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
|
ref_words = ref_text.split()
|
||||||
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
|
):
|
||||||
|
test_set_wers = dict()
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = (
|
||||||
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = (
|
||||||
|
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = (
|
||||||
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
LmScorer.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
assert params.decoding_method in (
|
||||||
|
"greedy_search",
|
||||||
|
"modified_beam_search",
|
||||||
|
"modified_beam_search_lm_shallow_fusion",
|
||||||
|
)
|
||||||
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
params.has_contexts = True
|
||||||
|
else:
|
||||||
|
params.has_contexts = False
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
if params.causal:
|
||||||
|
assert (
|
||||||
|
"," not in params.chunk_size
|
||||||
|
), "chunk_size should be one value in decoding."
|
||||||
|
assert (
|
||||||
|
"," not in params.left_context_frames
|
||||||
|
), "left_context_frames should be one value in decoding."
|
||||||
|
params.suffix += f"-chunk-{params.chunk_size}"
|
||||||
|
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||||
|
|
||||||
|
if "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
if params.decoding_method in (
|
||||||
|
"modified_beam_search",
|
||||||
|
"modified_beam_search_LODR",
|
||||||
|
):
|
||||||
|
if params.has_contexts:
|
||||||
|
params.suffix += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
params.suffix += f"-context-{params.context_size}"
|
||||||
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
|
||||||
|
|
||||||
|
if "LODR" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# only load the neural network LM if required
|
||||||
|
if (
|
||||||
|
params.use_shallow_fusion
|
||||||
|
or params.decoding_method == "modified_beam_search_lm_shallow_fusion"
|
||||||
|
):
|
||||||
|
LM = LmScorer(
|
||||||
|
lm_type=params.lm_type,
|
||||||
|
params=params,
|
||||||
|
device=device,
|
||||||
|
lm_scale=params.lm_scale,
|
||||||
|
)
|
||||||
|
LM.to(device)
|
||||||
|
LM.eval()
|
||||||
|
else:
|
||||||
|
LM = None
|
||||||
|
|
||||||
|
# only load N-gram LM when needed
|
||||||
|
if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
|
||||||
|
try:
|
||||||
|
import kenlm
|
||||||
|
except ImportError:
|
||||||
|
print("Please install kenlm first. You can use")
|
||||||
|
print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
|
||||||
|
print("to install it")
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.exit(-1)
|
||||||
|
ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
|
||||||
|
logging.info(f"lm filename: {ngram_file_name}")
|
||||||
|
ngram_lm = kenlm.Model(ngram_file_name)
|
||||||
|
ngram_lm_scale = None # use a list to search
|
||||||
|
|
||||||
|
elif params.decoding_method == "modified_beam_search_LODR":
|
||||||
|
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||||
|
logging.info(f"Loading token level lm: {lm_filename}")
|
||||||
|
ngram_lm = NgramLm(
|
||||||
|
str(params.lang_dir / lm_filename),
|
||||||
|
backoff_id=params.backoff_id,
|
||||||
|
is_binary=False,
|
||||||
|
)
|
||||||
|
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||||
|
ngram_lm_scale = params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
ngram_lm = None
|
||||||
|
ngram_lm_scale = None
|
||||||
|
|
||||||
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
word_table = lexicon.word_table
|
||||||
|
lg_filename = params.lang_dir / "LG.pt"
|
||||||
|
logging.info(f"Loading {lg_filename}")
|
||||||
|
decoding_graph = k2.Fsa.from_dict(
|
||||||
|
torch.load(lg_filename, map_location=device)
|
||||||
|
)
|
||||||
|
decoding_graph.scores *= params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
word_table = None
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
word_table = None
|
||||||
|
|
||||||
|
if "modified_beam_search" in params.decoding_method:
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
contexts = []
|
||||||
|
for line in open(params.context_file).readlines():
|
||||||
|
contexts.append(line.strip())
|
||||||
|
context_graph = ContextGraph(params.context_score)
|
||||||
|
context_graph.build(sp.encode(contexts))
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
|
|
||||||
|
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
|
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
|
||||||
|
test_sets = ["test-clean", "test-other"]
|
||||||
|
test_dl = [test_clean_dl, test_other_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
sp=sp,
|
||||||
|
word_table=word_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
|
LM=LM,
|
||||||
|
ngram_lm=ngram_lm,
|
||||||
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_results(
|
||||||
|
params=params,
|
||||||
|
test_set_name=test_set,
|
||||||
|
results_dict=results_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
egs/librispeech/ASR/zipformer_hat/decoder.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/decoder.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/decoder.py
|
1
egs/librispeech/ASR/zipformer_hat/encoder_interface.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/encoder_interface.py
|
1
egs/librispeech/ASR/zipformer_hat/export-onnx-streaming.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/export-onnx-streaming.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/export-onnx-streaming.py
|
1
egs/librispeech/ASR/zipformer_hat/export-onnx.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/export-onnx.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/export-onnx.py
|
1
egs/librispeech/ASR/zipformer_hat/export.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/export.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/export.py
|
1
egs/librispeech/ASR/zipformer_hat/generate_averaged_model.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/generate_averaged_model.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/generate_averaged_model.py
|
1
egs/librispeech/ASR/zipformer_hat/jit_pretrained.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/jit_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/jit_pretrained.py
|
1
egs/librispeech/ASR/zipformer_hat/jit_pretrained_ctc.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/jit_pretrained_ctc.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/jit_pretrained_ctc.py
|
1
egs/librispeech/ASR/zipformer_hat/jit_pretrained_streaming.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/jit_pretrained_streaming.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/jit_pretrained_streaming.py
|
1
egs/librispeech/ASR/zipformer_hat/joiner.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/joiner.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/joiner.py
|
359
egs/librispeech/ASR/zipformer_hat/model.py
Normal file
359
egs/librispeech/ASR/zipformer_hat/model.py
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
|
from icefall.utils import add_sos, make_pad_mask
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
|
class AsrModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder_embed: nn.Module,
|
||||||
|
encoder: EncoderInterface,
|
||||||
|
decoder: Optional[nn.Module] = None,
|
||||||
|
joiner: Optional[nn.Module] = None,
|
||||||
|
encoder_dim: int = 384,
|
||||||
|
decoder_dim: int = 512,
|
||||||
|
vocab_size: int = 500,
|
||||||
|
use_transducer: bool = True,
|
||||||
|
use_ctc: bool = False,
|
||||||
|
):
|
||||||
|
"""A joint CTC & Transducer ASR model.
|
||||||
|
|
||||||
|
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
|
||||||
|
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
|
||||||
|
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_embed:
|
||||||
|
It is a Convolutional 2D subsampling module. It converts
|
||||||
|
an input of shape (N, T, idim) to an output of of shape
|
||||||
|
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||||
|
encoder:
|
||||||
|
It is the transcription network in the paper. Its accepts
|
||||||
|
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||||
|
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||||
|
`logit_lens` of shape (N,).
|
||||||
|
decoder:
|
||||||
|
It is the prediction network in the paper. Its input shape
|
||||||
|
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||||
|
It should contain one attribute: `blank_id`.
|
||||||
|
It is used when use_transducer is True.
|
||||||
|
joiner:
|
||||||
|
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||||
|
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||||
|
unnormalized probs, i.e., not processed by log-softmax.
|
||||||
|
It is used when use_transducer is True.
|
||||||
|
use_transducer:
|
||||||
|
Whether use transducer head. Default: True.
|
||||||
|
use_ctc:
|
||||||
|
Whether use CTC head. Default: False.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
use_transducer or use_ctc
|
||||||
|
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
|
||||||
|
|
||||||
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
|
|
||||||
|
self.encoder_embed = encoder_embed
|
||||||
|
self.encoder = encoder
|
||||||
|
|
||||||
|
self.use_transducer = use_transducer
|
||||||
|
if use_transducer:
|
||||||
|
# Modules for Transducer head
|
||||||
|
assert decoder is not None
|
||||||
|
assert hasattr(decoder, "blank_id")
|
||||||
|
assert joiner is not None
|
||||||
|
|
||||||
|
self.decoder = decoder
|
||||||
|
self.joiner = joiner
|
||||||
|
|
||||||
|
self.simple_am_proj = ScaledLinear(
|
||||||
|
encoder_dim, vocab_size, initial_scale=0.25
|
||||||
|
)
|
||||||
|
self.simple_lm_proj = ScaledLinear(
|
||||||
|
decoder_dim, vocab_size, initial_scale=0.25
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert decoder is None
|
||||||
|
assert joiner is None
|
||||||
|
|
||||||
|
self.use_ctc = use_ctc
|
||||||
|
if use_ctc:
|
||||||
|
# Modules for CTC head
|
||||||
|
self.ctc_output = nn.Sequential(
|
||||||
|
nn.Dropout(p=0.1),
|
||||||
|
nn.Linear(encoder_dim, vocab_size),
|
||||||
|
nn.LogSoftmax(dim=-1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_encoder(
|
||||||
|
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute encoder outputs.
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C).
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||||
|
before padding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
encoder_out:
|
||||||
|
Encoder output, of shape (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
Encoder output lengths, of shape (N,).
|
||||||
|
"""
|
||||||
|
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
|
||||||
|
x, x_lens = self.encoder_embed(x, x_lens)
|
||||||
|
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
|
||||||
|
|
||||||
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||||
|
|
||||||
|
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
|
||||||
|
|
||||||
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
def forward_ctc(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
targets: torch.Tensor,
|
||||||
|
target_lengths: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute CTC loss.
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
Encoder output, of shape (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
Encoder output lengths, of shape (N,).
|
||||||
|
targets:
|
||||||
|
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||||
|
to be un-padded and concatenated within 1 dimension.
|
||||||
|
"""
|
||||||
|
# Compute CTC log-prob
|
||||||
|
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||||
|
|
||||||
|
ctc_loss = torch.nn.functional.ctc_loss(
|
||||||
|
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||||
|
targets=targets,
|
||||||
|
input_lengths=encoder_out_lens,
|
||||||
|
target_lengths=target_lengths,
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
return ctc_loss
|
||||||
|
|
||||||
|
def forward_transducer(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
y: k2.RaggedTensor,
|
||||||
|
y_lens: torch.Tensor,
|
||||||
|
prune_range: int = 5,
|
||||||
|
am_scale: float = 0.0,
|
||||||
|
lm_scale: float = 0.0,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute Transducer loss.
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
Encoder output, of shape (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
Encoder output lengths, of shape (N,).
|
||||||
|
y:
|
||||||
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
|
utterance.
|
||||||
|
prune_range:
|
||||||
|
The prune range for rnnt loss, it means how many symbols(context)
|
||||||
|
we are considering for each frame to compute the loss.
|
||||||
|
am_scale:
|
||||||
|
The scale to smooth the loss with am (output of encoder network)
|
||||||
|
part
|
||||||
|
lm_scale:
|
||||||
|
The scale to smooth the loss with lm (output of predictor network)
|
||||||
|
part
|
||||||
|
"""
|
||||||
|
# Now for the decoder, i.e., the prediction network
|
||||||
|
blank_id = self.decoder.blank_id
|
||||||
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
|
|
||||||
|
# sos_y_padded: [B, S + 1], start with SOS.
|
||||||
|
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
|
|
||||||
|
# decoder_out: [B, S + 1, decoder_dim]
|
||||||
|
decoder_out = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
|
# Note: y does not start with SOS
|
||||||
|
# y_padded : [B, S]
|
||||||
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
|
||||||
|
y_padded = y_padded.to(torch.int64)
|
||||||
|
boundary = torch.zeros(
|
||||||
|
(encoder_out.size(0), 4),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=encoder_out.device,
|
||||||
|
)
|
||||||
|
boundary[:, 2] = y_lens
|
||||||
|
boundary[:, 3] = encoder_out_lens
|
||||||
|
|
||||||
|
lm = self.simple_lm_proj(decoder_out)
|
||||||
|
am = self.simple_am_proj(encoder_out)
|
||||||
|
|
||||||
|
# if self.training and random.random() < 0.25:
|
||||||
|
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||||
|
# if self.training and random.random() < 0.25:
|
||||||
|
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
|
lm=lm.float(),
|
||||||
|
am=am.float(),
|
||||||
|
symbols=y_padded,
|
||||||
|
termination_symbol=blank_id,
|
||||||
|
lm_only_scale=lm_scale,
|
||||||
|
am_only_scale=am_scale,
|
||||||
|
boundary=boundary,
|
||||||
|
reduction="sum",
|
||||||
|
return_grad=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ranges : [B, T, prune_range]
|
||||||
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
|
px_grad=px_grad,
|
||||||
|
py_grad=py_grad,
|
||||||
|
boundary=boundary,
|
||||||
|
s_range=prune_range,
|
||||||
|
)
|
||||||
|
|
||||||
|
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||||
|
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||||
|
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||||
|
am=self.joiner.encoder_proj(encoder_out),
|
||||||
|
lm=self.joiner.decoder_proj(decoder_out),
|
||||||
|
ranges=ranges,
|
||||||
|
)
|
||||||
|
|
||||||
|
# logits : [B, T, prune_range, vocab_size]
|
||||||
|
|
||||||
|
# project_input=False since we applied the decoder's input projections
|
||||||
|
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||||
|
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
pruned_loss = k2.rnnt_loss_pruned(
|
||||||
|
logits=logits.float(),
|
||||||
|
symbols=y_padded,
|
||||||
|
ranges=ranges,
|
||||||
|
termination_symbol=blank_id,
|
||||||
|
boundary=boundary,
|
||||||
|
reduction="sum",
|
||||||
|
use_hat_loss=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return simple_loss, pruned_loss
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
y: k2.RaggedTensor,
|
||||||
|
prune_range: int = 5,
|
||||||
|
am_scale: float = 0.0,
|
||||||
|
lm_scale: float = 0.0,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C).
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||||
|
before padding.
|
||||||
|
y:
|
||||||
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
|
utterance.
|
||||||
|
prune_range:
|
||||||
|
The prune range for rnnt loss, it means how many symbols(context)
|
||||||
|
we are considering for each frame to compute the loss.
|
||||||
|
am_scale:
|
||||||
|
The scale to smooth the loss with am (output of encoder network)
|
||||||
|
part
|
||||||
|
lm_scale:
|
||||||
|
The scale to smooth the loss with lm (output of predictor network)
|
||||||
|
part
|
||||||
|
Returns:
|
||||||
|
Return the transducer losses and CTC loss,
|
||||||
|
in form of (simple_loss, pruned_loss, ctc_loss)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||||
|
the form:
|
||||||
|
lm_scale * lm_probs + am_scale * am_probs +
|
||||||
|
(1-lm_scale-am_scale) * combined_probs
|
||||||
|
"""
|
||||||
|
assert x.ndim == 3, x.shape
|
||||||
|
assert x_lens.ndim == 1, x_lens.shape
|
||||||
|
assert y.num_axes == 2, y.num_axes
|
||||||
|
|
||||||
|
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
|
||||||
|
|
||||||
|
# Compute encoder outputs
|
||||||
|
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
||||||
|
|
||||||
|
row_splits = y.shape.row_splits(1)
|
||||||
|
y_lens = row_splits[1:] - row_splits[:-1]
|
||||||
|
|
||||||
|
if self.use_transducer:
|
||||||
|
# Compute transducer loss
|
||||||
|
simple_loss, pruned_loss = self.forward_transducer(
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
y=y.to(x.device),
|
||||||
|
y_lens=y_lens,
|
||||||
|
prune_range=prune_range,
|
||||||
|
am_scale=am_scale,
|
||||||
|
lm_scale=lm_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
simple_loss = torch.empty(0)
|
||||||
|
pruned_loss = torch.empty(0)
|
||||||
|
|
||||||
|
if self.use_ctc:
|
||||||
|
# Compute CTC loss
|
||||||
|
targets = y.values
|
||||||
|
ctc_loss = self.forward_ctc(
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
targets=targets,
|
||||||
|
target_lengths=y_lens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ctc_loss = torch.empty(0)
|
||||||
|
|
||||||
|
return simple_loss, pruned_loss, ctc_loss
|
1
egs/librispeech/ASR/zipformer_hat/onnx_check.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/onnx_check.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/onnx_check.py
|
1
egs/librispeech/ASR/zipformer_hat/onnx_decode.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/onnx_decode.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/onnx_decode.py
|
1
egs/librispeech/ASR/zipformer_hat/onnx_pretrained-streaming.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/onnx_pretrained-streaming.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/onnx_pretrained-streaming.py
|
1
egs/librispeech/ASR/zipformer_hat/onnx_pretrained.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/onnx_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/onnx_pretrained.py
|
1
egs/librispeech/ASR/zipformer_hat/optim.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/optim.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/optim.py
|
1
egs/librispeech/ASR/zipformer_hat/pretrained.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/pretrained.py
|
1
egs/librispeech/ASR/zipformer_hat/pretrained_ctc.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/pretrained_ctc.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/pretrained_ctc.py
|
1
egs/librispeech/ASR/zipformer_hat/profile.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/profile.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/profile.py
|
1
egs/librispeech/ASR/zipformer_hat/scaling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/scaling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/scaling.py
|
1
egs/librispeech/ASR/zipformer_hat/scaling_converter.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/scaling_converter.py
|
1
egs/librispeech/ASR/zipformer_hat/subsampling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/subsampling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/subsampling.py
|
1
egs/librispeech/ASR/zipformer_hat/test_scaling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/test_scaling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/test_scaling.py
|
1
egs/librispeech/ASR/zipformer_hat/test_subsampling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/test_subsampling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/test_subsampling.py
|
1389
egs/librispeech/ASR/zipformer_hat/train.py
Executable file
1389
egs/librispeech/ASR/zipformer_hat/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/ASR/zipformer_hat/zipformer.py
Symbolic link
1
egs/librispeech/ASR/zipformer_hat/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/zipformer.py
|
Loading…
x
Reference in New Issue
Block a user