mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00, from local)

This commit is contained in:
  parent de10e1d5fe
  commit f4fb761129
@@ -1,7 +1,7 @@

 # Introduction

-Please refer to <https://icefall.readthedocs.io/en/latest/recipes/aishell/index.html>
+Please refer to <https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/aishell/index.html>

 for how to run models in this recipe.
@@ -31,7 +31,7 @@ stop_stage=10
 # - noise
 # - speech

-dl_dir=/home/work/workspace/aishell
+dl_dir=$PWD/download

 . shared/parse_options.sh || exit 1
@@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module):
         mean of the output is taken. (3) "sum": the output will be summed.
         """
         super().__init__()
-        assert 0.0 <= label_smoothing < 1.0
+        assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
+        assert reduction in ("none", "sum", "mean"), reduction
         self.ignore_index = ignore_index
         self.label_smoothing = label_smoothing
         self.reduction = reduction
@@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from

     Caution: We use a lexicon that contains disambiguation symbols

-    - G, the LM, built from data/lm/G_3_gram.fst.txt
+    - G, the LM, built from data/lm/G_n_gram.fst.txt

 The generated HLG is saved in $lang_dir/HLG.pt
 """
@@ -28,7 +28,7 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor
@@ -41,6 +41,10 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


+def is_cut_long(c: MonoCut) -> bool:
+    return c.duration > 5
+
+
 def compute_fbank_musan():
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
@@ -86,7 +90,7 @@ def compute_fbank_musan():
                 recordings=combine(part["recordings"] for part in manifests.values())
             )
             .cut_into_windows(10.0)
-            .filter(lambda c: c.duration > 5)
+            .filter(is_cut_long)
             .compute_and_store_features(
                 extractor=extractor,
                 storage_path=f"{output_dir}/musan_feats",
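Background on the named predicate used in the hunk above: unlike an inline lambda, a module-level function can be pickled, which matters when work is farmed out to worker processes. The sketch below is independent of lhotse and only illustrates that Python-level difference; the FakeCut class is a hypothetical stand-in with just a duration field.

import pickle
from dataclasses import dataclass


@dataclass
class FakeCut:
    # Hypothetical stand-in for a lhotse cut; only the duration field is used.
    duration: float


def is_cut_long(c) -> bool:
    # Same predicate as in the diff above: keep cuts longer than 5 seconds.
    return c.duration > 5


# A module-level function survives a pickle round trip, so it can be shipped
# to worker processes used for parallel feature extraction.
restored = pickle.loads(pickle.dumps(is_cut_long))
print(restored(FakeCut(duration=7.2)))  # True

# The equivalent inline lambda cannot be pickled by the standard pickle module.
try:
    pickle.dumps(lambda c: c.duration > 5)
except (pickle.PicklingError, AttributeError) as e:
    print(f"lambda not picklable: {e}")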
@@ -26,7 +26,9 @@ from model import Transducer

 from icefall import NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
+from icefall.lm_wrapper import LmScorer
 from icefall.rnn_lm.model import RnnLmModel
+from icefall.transformer_lm.model import TransformerLM
 from icefall.utils import (
     DecodingResults,
     add_eos,
@@ -1846,254 +1848,14 @@ def modified_beam_search_ngram_rescoring(
     return ans


-def modified_beam_search_rnnlm_shallow_fusion(
-    model: Transducer,
-    encoder_out: torch.Tensor,
-    encoder_out_lens: torch.Tensor,
-    sp: spm.SentencePieceProcessor,
-    rnnlm: RnnLmModel,
-    rnnlm_scale: float,
-    beam: int = 4,
-    return_timestamps: bool = False,
-) -> List[List[int]]:
-    """Modified_beam_search + RNNLM shallow fusion
-
-    Args:
-      model (Transducer):
-        The transducer model
-      encoder_out (torch.Tensor):
-        Encoder output in (N,T,C)
-      encoder_out_lens (torch.Tensor):
-        A 1-D tensor of shape (N,), containing the number of
-        valid frames in encoder_out before padding.
-      sp:
-        Sentence piece generator.
-      rnnlm (RnnLmModel):
-        RNNLM
-      rnnlm_scale (float):
-        scale of RNNLM in shallow fusion
-      beam (int, optional):
-        Beam size. Defaults to 4.
-
-    Returns:
-      Return a list-of-list of token IDs. ans[i] is the decoding results
-      for the i-th utterance.
-    """
-    assert encoder_out.ndim == 3, encoder_out.shape
-    assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    assert rnnlm is not None
-    lm_scale = rnnlm_scale
-    vocab_size = rnnlm.vocab_size
-
-    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
-        input=encoder_out,
-        lengths=encoder_out_lens.cpu(),
-        batch_first=True,
-        enforce_sorted=False,
-    )
-
-    blank_id = model.decoder.blank_id
-    sos_id = sp.piece_to_id("<sos/eos>")
-    unk_id = getattr(model, "unk_id", blank_id)
-    context_size = model.decoder.context_size
-    device = next(model.parameters()).device
-
-    batch_size_list = packed_encoder_out.batch_sizes.tolist()
-    N = encoder_out.size(0)
-    assert torch.all(encoder_out_lens > 0), encoder_out_lens
-    assert N == batch_size_list[0], (N, batch_size_list)
-
-    # get initial lm score and lm state by scoring the "sos" token
-    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
-    init_score, init_states = rnnlm.score_token(sos_token)
-
-    B = [HypothesisList() for _ in range(N)]
-    for i in range(N):
-        B[i].add(
-            Hypothesis(
-                ys=[blank_id] * context_size,
-                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-                state=init_states,
-                lm_score=init_score.reshape(-1),
-                timestamp=[],
-            )
-        )
-
-    rnnlm.clean_cache()
-    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
-
-    offset = 0
-    finalized_B = []
-    for (t, batch_size) in enumerate(batch_size_list):
-        start = offset
-        end = offset + batch_size
-        current_encoder_out = encoder_out.data[start:end]  # get batch
-        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
-        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
-        offset = end
-
-        finalized_B = B[batch_size:] + finalized_B
-        B = B[:batch_size]
-
-        hyps_shape = get_hyps_shape(B).to(device)
-
-        A = [list(b) for b in B]
-        B = [HypothesisList() for _ in range(batch_size)]
-
-        ys_log_probs = torch.cat(
-            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
-        )
-
-        decoder_input = torch.tensor(
-            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
-            device=device,
-            dtype=torch.int64,
-        )  # (num_hyps, context_size)
-
-        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
-        decoder_out = model.joiner.decoder_proj(decoder_out)
-
-        current_encoder_out = torch.index_select(
-            current_encoder_out,
-            dim=0,
-            index=hyps_shape.row_ids(1).to(torch.int64),
-        )  # (num_hyps, 1, 1, encoder_out_dim)
-
-        logits = model.joiner(
-            current_encoder_out,
-            decoder_out,
-            project_input=False,
-        )  # (num_hyps, 1, 1, vocab_size)
-
-        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
-
-        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
-
-        log_probs.add_(ys_log_probs)
-
-        vocab_size = log_probs.size(-1)
-
-        log_probs = log_probs.reshape(-1)
-
-        row_splits = hyps_shape.row_splits(1) * vocab_size
-        log_probs_shape = k2.ragged.create_ragged_shape2(
-            row_splits=row_splits, cached_tot_size=log_probs.numel()
-        )
-        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
-        """
-        for all hyps with a non-blank new token, score this token.
-        It is a little confusing here because this for-loop
-        looks very similar to the one below. Here, we go through all
-        top-k tokens and only add the non-blanks ones to the token_list.
-        The RNNLM will score those tokens given the LM states. Note that
-        the variable `scores` is the LM score after seeing the new
-        non-blank token.
-        """
-        token_list = []
-        hs = []
-        cs = []
-        for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
-
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
-                topk_token_indexes = (topk_indexes % vocab_size).tolist()
-            for k in range(len(topk_hyp_indexes)):
-                hyp_idx = topk_hyp_indexes[k]
-                hyp = A[i][hyp_idx]
-
-                new_token = topk_token_indexes[k]
-                if new_token not in (blank_id, unk_id):
-                    assert new_token != 0, new_token
-                    token_list.append([new_token])
-                    # store the LSTM states
-                    hs.append(hyp.state[0])
-                    cs.append(hyp.state[1])
-
-        # forward RNNLM to get new states and scores
-        if len(token_list) != 0:
-            tokens_to_score = (
-                torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
-            )
-
-            hs = torch.cat(hs, dim=1).to(device)
-            cs = torch.cat(cs, dim=1).to(device)
-            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
-
-        count = 0  # index, used to locate score and lm states
-        for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
-
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
-                topk_token_indexes = (topk_indexes % vocab_size).tolist()
-
-            for k in range(len(topk_hyp_indexes)):
-                hyp_idx = topk_hyp_indexes[k]
-                hyp = A[i][hyp_idx]
-
-                ys = hyp.ys[:]
-
-                lm_score = hyp.lm_score
-                state = hyp.state
-
-                hyp_log_prob = topk_log_probs[k]  # get score of current hyp
-                new_token = topk_token_indexes[k]
-                new_timestamp = hyp.timestamp[:]
-                if new_token not in (blank_id, unk_id):
-
-                    ys.append(new_token)
-                    new_timestamp.append(t)
-                    hyp_log_prob += lm_score[new_token] * lm_scale  # add the lm score
-
-                    lm_score = scores[count]
-                    state = (
-                        lm_states[0][:, count, :].unsqueeze(1),
-                        lm_states[1][:, count, :].unsqueeze(1),
-                    )
-                    count += 1
-
-                new_hyp = Hypothesis(
-                    ys=ys,
-                    log_prob=hyp_log_prob,
-                    state=state,
-                    lm_score=lm_score,
-                    timestamp=new_timestamp,
-                )
-                B[i].add(new_hyp)
-
-    B = B + finalized_B
-    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
-
-    sorted_ans = [h.ys[context_size:] for h in best_hyps]
-    sorted_timestamps = [h.timestamp for h in best_hyps]
-    ans = []
-    ans_timestamps = []
-    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
-    for i in range(N):
-        ans.append(sorted_ans[unsorted_indices[i]])
-        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
-
-    if not return_timestamps:
-        return ans
-    else:
-        return DecodingResults(
-            tokens=ans,
-            timestamps=ans_timestamps,
-        )
-
-
-def modified_beam_search_rnnlm_LODR(
+def modified_beam_search_LODR(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
     sp: spm.SentencePieceProcessor,
     LODR_lm: NgramLm,
     LODR_lm_scale: float,
-    rnnlm: RnnLmModel,
-    rnnlm_scale: float,
+    LM: LmScorer,
     beam: int = 4,
 ) -> List[List[int]]:
     """This function implements LODR (https://arxiv.org/abs/2203.16776) with
@@ -2113,13 +1875,11 @@ def modified_beam_search_rnnlm_LODR(
       sp:
         Sentence piece generator.
       LODR_lm:
-        A low order n-gram LM
+        A low order n-gram LM, whose score will be subtracted during shallow fusion
       LODR_lm_scale:
         The scale of the LODR_lm
-      rnnlm (RnnLmModel):
-        RNNLM, the external language model
-      rnnlm_scale (float):
-        scale of RNNLM in shallow fusion
+      LM:
+        A neural net LM, e.g an RNNLM or transformer LM
       beam (int, optional):
         Beam size. Defaults to 4.
@@ -2130,9 +1890,8 @@ def modified_beam_search_rnnlm_LODR(
     """
     assert encoder_out.ndim == 3, encoder_out.shape
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    assert rnnlm is not None
-    lm_scale = rnnlm_scale
-    vocab_size = rnnlm.vocab_size
+    assert LM is not None
+    lm_scale = LM.lm_scale

     packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
         input=encoder_out,
@@ -2154,7 +1913,8 @@ def modified_beam_search_rnnlm_LODR(

     # get initial lm score and lm state by scoring the "sos" token
     sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
-    init_score, init_states = rnnlm.score_token(sos_token)
+    lens = torch.tensor([1]).to(device)
+    init_score, init_states = LM.score_token(sos_token, lens)

     B = [HypothesisList() for _ in range(N)]
     for i in range(N):
@@ -2162,7 +1922,7 @@ def modified_beam_search_rnnlm_LODR(
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-                state=init_states,  # state of the RNNLM
+                state=init_states,  # state of the NN LM
                 lm_score=init_score.reshape(-1),
                 state_cost=NgramLmStateCost(
                     LODR_lm
@@ -2170,7 +1930,6 @@ def modified_beam_search_rnnlm_LODR(
             )
         )

-    rnnlm.clean_cache()
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

     offset = 0
@@ -2236,7 +1995,7 @@ def modified_beam_search_rnnlm_LODR(
         It is a little confusing here because this for-loop
         looks very similar to the one below. Here, we go through all
         top-k tokens and only add the non-blanks ones to the token_list.
-        The RNNLM will score those tokens given the LM states. Note that
+        LM will score those tokens given the LM states. Note that
         the variable `scores` is the LM score after seeing the new
         non-blank token.
         """
@@ -2256,21 +2015,41 @@ def modified_beam_search_rnnlm_LODR(

                 new_token = topk_token_indexes[k]
                 if new_token not in (blank_id, unk_id):
-                    assert new_token != 0, new_token
-                    token_list.append([new_token])
-                    # store the LSTM states
-                    hs.append(hyp.state[0])
-                    cs.append(hyp.state[1])
+                    if LM.lm_type == "rnn":
+                        token_list.append([new_token])
+                        # store the LSTM states
+                        hs.append(hyp.state[0])
+                        cs.append(hyp.state[1])
+                    else:
+                        # for transformer LM
+                        token_list.append(
+                            [sos_id] + hyp.ys[context_size:] + [new_token]
+                        )

-        # forward RNNLM to get new states and scores
+        # forward NN LM to get new states and scores
         if len(token_list) != 0:
-            tokens_to_score = (
-                torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
-            )
-
-            hs = torch.cat(hs, dim=1).to(device)
-            cs = torch.cat(cs, dim=1).to(device)
-            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
+            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
+            if LM.lm_type == "rnn":
+                tokens_to_score = (
+                    torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
+                )
+                hs = torch.cat(hs, dim=1).to(device)
+                cs = torch.cat(cs, dim=1).to(device)
+                state = (hs, cs)
+            else:
+                # for transformer LM
+                tokens_list = [torch.tensor(tokens) for tokens in token_list]
+                tokens_to_score = (
+                    torch.nn.utils.rnn.pad_sequence(
+                        tokens_list, batch_first=True, padding_value=0.0
+                    )
+                    .to(device)
+                    .to(torch.int64)
+                )
+
+                state = None
+
+            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)

         count = 0  # index, used to locate score and lm states
         for i in range(batch_size):
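In the transformer-LM branch added above, the variable-length token sequences in token_list are padded into a single batch with torch.nn.utils.rnn.pad_sequence before scoring. A minimal standalone sketch of that padding step follows; the token IDs are made up:

import torch

# Hypothetical sequences: [sos] + previous tokens + candidate token.
token_list = [[1, 17, 23, 9], [1, 42, 5], [1, 8]]
x_lens = torch.tensor([len(t) for t in token_list])

tokens_to_score = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(t) for t in token_list], batch_first=True, padding_value=0.0
).to(torch.int64)

print(tokens_to_score)  # shape (3, 4); shorter rows padded with 0 on the right
print(x_lens)           # tensor([4, 3, 2]) records the true lengths for the LM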
@@ -2305,14 +2084,15 @@ def modified_beam_search_rnnlm_LODR(
                        state_cost.lm_score,
                        hyp.state_cost.lm_score,
                    )
-                    # score = score + RNNLM_score - LODR_score
-                    # LODR_LM_scale is a negative number here
+                    # score = score + TDLM_score - LODR_score
+                    # LODR_LM_scale should be a negative number here
                    hyp_log_prob += (
                        lm_score[new_token] * lm_scale
                        + LODR_lm_scale * current_ngram_score
                    )  # add the lm score

                    lm_score = scores[count]
+                    if LM.lm_type == "rnn":
                        state = (
                            lm_states[0][:, count, :].unsqueeze(1),
                            lm_states[1][:, count, :].unsqueeze(1),
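The score update shown in this hunk combines the transducer score, the neural-LM score, and the (negatively scaled) low-order n-gram score. The helper below restates that combination as a standalone function; its name and the example numbers are illustrative only, not part of the patch.

def lodr_token_score(
    hyp_log_prob: float,          # transducer score of the hypothesis + candidate token
    nn_lm_log_prob: float,        # lm_score[new_token] from the neural LM
    ngram_log_prob_delta: float,  # current_ngram_score from the low-order n-gram
    lm_scale: float,              # LM.lm_scale
    LODR_lm_scale: float,         # negative, so the n-gram contribution is subtracted
) -> float:
    # score = score + NN_LM_score - LODR_score
    return hyp_log_prob + lm_scale * nn_lm_log_prob + LODR_lm_scale * ngram_log_prob_delta


print(lodr_token_score(-3.2, -1.1, -0.9, lm_scale=0.4, LODR_lm_scale=-0.16))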
@@ -2340,3 +2120,263 @@ def modified_beam_search_rnnlm_LODR(
         ans.append(sorted_ans[unsorted_indices[i]])

     return ans
+
+
+def modified_beam_search_lm_shallow_fusion(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    sp: spm.SentencePieceProcessor,
+    LM: LmScorer,
+    beam: int = 4,
+    return_timestamps: bool = False,
+) -> List[List[int]]:
+    """Modified_beam_search + NN LM shallow fusion
+
+    Args:
+      model (Transducer):
+        The transducer model
+      encoder_out (torch.Tensor):
+        Encoder output in (N,T,C)
+      encoder_out_lens (torch.Tensor):
+        A 1-D tensor of shape (N,), containing the number of
+        valid frames in encoder_out before padding.
+      sp:
+        Sentence piece generator.
+      LM (LmScorer):
+        A neural net LM, e.g RNN or Transformer
+      beam (int, optional):
+        Beam size. Defaults to 4.
+
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding results
+      for the i-th utterance.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+    assert LM is not None
+    lm_scale = LM.lm_scale
+
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
+
+    blank_id = model.decoder.blank_id
+    sos_id = sp.piece_to_id("<sos/eos>")
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+    device = next(model.parameters()).device
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    # get initial lm score and lm state by scoring the "sos" token
+    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
+    lens = torch.tensor([1]).to(device)
+    init_score, init_states = LM.score_token(sos_token, lens)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                state=init_states,
+                lm_score=init_score.reshape(-1),
+                timestamp=[],
+            )
+        )
+
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+    offset = 0
+    finalized_B = []
+    for (t, batch_size) in enumerate(batch_size_list):
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]  # get batch
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
+
+        hyps_shape = get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )
+
+        lm_scores = torch.cat(
+            [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps]
+        )
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        decoder_out = model.joiner.decoder_proj(decoder_out)
+
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            project_input=False,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+        """
+        for all hyps with a non-blank new token, score this token.
+        It is a little confusing here because this for-loop
+        looks very similar to the one below. Here, we go through all
+        top-k tokens and only add the non-blanks ones to the token_list.
+        `LM` will score those tokens given the LM states. Note that
+        the variable `scores` is the LM score after seeing the new
+        non-blank token.
+        """
+        token_list = []  # a list of list
+        hs = []
+        cs = []
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_token = topk_token_indexes[k]
+                if new_token not in (blank_id, unk_id):
+                    if LM.lm_type == "rnn":
+                        token_list.append([new_token])
+                        # store the LSTM states
+                        hs.append(hyp.state[0])
+                        cs.append(hyp.state[1])
+                    else:
+                        # for transformer LM
+                        token_list.append(
+                            [sos_id] + hyp.ys[context_size:] + [new_token]
+                        )
+
+        if len(token_list) != 0:
+            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
+            if LM.lm_type == "rnn":
+                tokens_to_score = (
+                    torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
+                )
+                hs = torch.cat(hs, dim=1).to(device)
+                cs = torch.cat(cs, dim=1).to(device)
+                state = (hs, cs)
+            else:
+                # for transformer LM
+                tokens_list = [torch.tensor(tokens) for tokens in token_list]
+                tokens_to_score = (
+                    torch.nn.utils.rnn.pad_sequence(
+                        tokens_list, batch_first=True, padding_value=0.0
+                    )
+                    .to(device)
+                    .to(torch.int64)
+                )
+
+                state = None
+
+            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
+
+        count = 0  # index, used to locate score and lm states
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                ys = hyp.ys[:]
+
+                lm_score = hyp.lm_score
+                state = hyp.state
+
+                hyp_log_prob = topk_log_probs[k]  # get score of current hyp
+                new_token = topk_token_indexes[k]
+                new_timestamp = hyp.timestamp[:]
+                if new_token not in (blank_id, unk_id):
+
+                    ys.append(new_token)
+                    new_timestamp.append(t)
+
+                    hyp_log_prob += lm_score[new_token] * lm_scale  # add the lm score
+
+                    lm_score = scores[count]
+                    if LM.lm_type == "rnn":
+                        state = (
+                            lm_states[0][:, count, :].unsqueeze(1),
+                            lm_states[1][:, count, :].unsqueeze(1),
+                        )
+                    count += 1
+
+                new_hyp = Hypothesis(
+                    ys=ys,
+                    log_prob=hyp_log_prob,
+                    state=state,
+                    lm_score=lm_score,
+                    timestamp=new_timestamp,
+                )
+                B[i].add(new_hyp)
+
+    B = B + finalized_B
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_timestamps = [h.timestamp for h in best_hyps]
+    ans = []
+    ans_timestamps = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
+        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
+
+    if not return_timestamps:
+        return ans
+    else:
+        return DecodingResults(
+            tokens=ans,
+            timestamps=ans_timestamps,
+        )
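The new modified_beam_search_lm_shallow_fusion touches only a small surface of LmScorer: the lm_type and lm_scale attributes and score_token(tokens, lens, state), which is expected to return per-sequence token log-probabilities plus a state (an (h, c) pair for RNN LMs, judging from the [:, count, :] indexing above). The toy scorer below mirrors just that surface so the control flow can be read in isolation; it is an assumption-laden sketch, not the real icefall.lm_wrapper.LmScorer.

import math
import torch


class ToyUniformLmScorer:
    """Duck-typed stand-in exposing only what the decoding loops above touch."""

    def __init__(self, vocab_size: int, lm_scale: float = 0.3, hidden: int = 8):
        self.lm_type = "rnn"      # the loops branch on "rnn" vs. transformer
        self.lm_scale = lm_scale
        self.vocab_size = vocab_size
        self.hidden = hidden

    def score_token(self, tokens: torch.Tensor, lens: torch.Tensor, state=None):
        # Uniform log-probability for every token; shape (num_seqs, vocab_size).
        num = tokens.size(0)
        scores = torch.full((num, self.vocab_size), -math.log(self.vocab_size))
        # Fake LSTM-like state: (h, c), each (num_layers, num_seqs, hidden),
        # matching the [:, count, :] indexing used by the decoding loop.
        h = torch.zeros(1, num, self.hidden)
        c = torch.zeros(1, num, self.hidden)
        return scores, (h, c)


scorer = ToyUniformLmScorer(vocab_size=500)
scores, (h, c) = scorer.score_token(torch.tensor([[1], [2]]), torch.tensor([1, 1]))
print(scores.shape, h.shape)  # torch.Size([2, 500]) torch.Size([1, 2, 8])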
@@ -33,7 +33,6 @@ from scaling import (
 from torch import Tensor, nn

 from icefall.utils import make_pad_mask, subsequent_chunk_mask
-import random


 class Conformer(EncoderInterface):
@@ -693,9 +692,6 @@ class ConformerEncoder(nn.Module):
        output = src

        outputs = []
-        residual = None
-
-        '''

        for i, mod in enumerate(self.layers):
            output = mod(
@@ -705,33 +701,10 @@ class ConformerEncoder(nn.Module):
                src_key_padding_mask=src_key_padding_mask,
                warmup=warmup,
            )
-            '''
+            if i in self.aux_layers:
+                outputs.append(output)

-        for i, mod in enumerate(self.layers):
-            if i == 0:
-                residual = output
-            elif i in [2,5,8,11,14,17]:
-                output = mod(
-                    output,
-                    pos_emb,
-                    src_mask=mask,
-                    src_key_padding_mask=src_key_padding_mask,
-                    warmup=warmup,
-                )
-                output += residual
-                residual = output
-            else:
-                output = mod(
-                    output,
-                    pos_emb,
-                    src_mask=mask,
-                    src_key_padding_mask=src_key_padding_mask,
-                    warmup=warmup,
-                )
-            #if i in self.aux_layers:
-            #    outputs.append(output)
-
-            #output = self.combiner(outputs)
+        output = self.combiner(outputs)

        return output
icefall/shared/convert-k2-to-openfst.py (new executable file, 102 lines)
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes as input an FST in k2 format and convert it
+to an FST in OpenFST format.
+
+The generated FST is saved into a binary file and its type is
+StdVectorFst.
+
+Usage examples:
+(1) Convert an acceptor
+
+  ./convert-k2-to-openfst.py in.pt binary.fst
+
+(2) Convert a transducer
+
+  ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import kaldifst.utils
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--olabels",
+        type=str,
+        default=None,
+        help="""If not empty, the input FST is assumed to be a transducer
+        and we use its attribute specified by "olabels" as the output labels.
+        """,
+    )
+    parser.add_argument(
+        "input_filename",
+        type=str,
+        help="Path to the input FST in k2 format",
+    )
+
+    parser.add_argument(
+        "output_filename",
+        type=str,
+        help="Path to the output FST in OpenFst format",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    logging.info(f"{vars(args)}")
+
+    input_filename = args.input_filename
+    output_filename = args.output_filename
+    olabels = args.olabels
+
+    if Path(output_filename).is_file():
+        logging.info(f"{output_filename} already exists - skipping")
+        return
+
+    assert Path(input_filename).is_file(), f"{input_filename} does not exist"
+    logging.info(f"Loading {input_filename}")
+    k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
+    if olabels:
+        assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
+
+    p = Path(output_filename).parent
+    if not p.is_dir():
+        logging.info(f"Creating {p}")
+        p.mkdir(parents=True)
+
+    logging.info("Converting (May take some time if the input FST is large)")
+    fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
+    logging.info(f"Saving to {output_filename}")
+    fst.write(output_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()