Mirror of https://github.com/k2-fsa/icefall.git

Commit f4fb761129 (parent de10e1d5fe)
@@ -1,7 +1,7 @@
# Introduction

Please refer to <https://icefall.readthedocs.io/en/latest/recipes/aishell/index.html>
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/aishell/index.html>
for how to run models in this recipe.
@@ -31,7 +31,7 @@ stop_stage=10
# - noise
# - speech

dl_dir=/home/work/workspace/aishell
dl_dir=$PWD/download

. shared/parse_options.sh || exit 1
@@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module):
            mean of the output is taken. (3) "sum": the output will be summed.
        """
        super().__init__()
        assert 0.0 <= label_smoothing < 1.0
        assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
        assert reduction in ("none", "sum", "mean"), reduction
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction
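For orientation, LabelSmoothingLoss mixes the one-hot target with a uniform distribution over the output classes before taking the negative log-likelihood. Below is a minimal standalone sketch of that computation (illustrative names only, not the icefall implementation):

    import torch

    def smoothed_nll(logits: torch.Tensor, target: torch.Tensor, smoothing: float = 0.1) -> torch.Tensor:
        # logits: (N, C), target: (N,)
        log_probs = logits.log_softmax(dim=-1)
        nll = -log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        uniform = -log_probs.mean(dim=-1)  # cross-entropy against a uniform target
        return ((1.0 - smoothing) * nll + smoothing * uniform).mean()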
@@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from

Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_3_gram.fst.txt
    - G, the LM, built from data/lm/G_n_gram.fst.txt

The generated HLG is saved in $lang_dir/HLG.pt
"""
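For readers who have not seen the HLG build before, the script composes the token topology H with the lexicon L and the n-gram grammar G. The following is a heavily simplified sketch of that pipeline; file names and the max_token value are assumptions, and steps such as disambiguation-symbol removal are omitted, so it is not the actual compile_hlg.py:

    import k2
    import torch

    # Assumed inputs; real paths come from --lang-dir and data/lm.
    L = k2.arc_sort(k2.Fsa.from_dict(torch.load("data/lang_phone/L_disambig.pt")))
    with open("data/lm/G_3_gram.fst.txt") as f:
        G = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=False))

    LG = k2.connect(k2.determinize(k2.connect(k2.compose(L, G))))
    H = k2.ctc_topo(max_token=500)  # the real max token id comes from the lexicon
    HLG = k2.arc_sort(k2.connect(k2.compose(H, k2.arc_sort(LG), inner_labels="tokens")))
    torch.save(HLG.as_dict(), "data/lang_phone/HLG.pt")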
@@ -28,7 +28,7 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
@@ -41,6 +41,10 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def is_cut_long(c: MonoCut) -> bool:
    return c.duration > 5


def compute_fbank_musan():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
@@ -86,7 +90,7 @@ def compute_fbank_musan():
                recordings=combine(part["recordings"] for part in manifests.values())
            )
            .cut_into_windows(10.0)
            .filter(lambda c: c.duration > 5)
            .filter(is_cut_long)
            .compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/musan_feats",
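A brief aside on why the lambda above is swapped for the module-level is_cut_long: lhotse may pickle the CutSet transforms when features are computed with worker processes, and Python cannot pickle a lambda. The toy snippet below (not icefall code) shows the difference:

    import pickle

    def is_long(duration: float) -> bool:
        return duration > 5

    pickle.dumps(is_long)  # works: a module-level function is picklable by reference

    try:
        pickle.dumps(lambda d: d > 5)  # fails: a lambda has no importable name
    except (pickle.PicklingError, AttributeError) as e:
        print("cannot pickle a lambda:", e)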
@@ -26,7 +26,9 @@ from model import Transducer

from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.lm_wrapper import LmScorer
from icefall.rnn_lm.model import RnnLmModel
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import (
    DecodingResults,
    add_eos,
@@ -1846,254 +1848,14 @@ def modified_beam_search_ngram_rescoring(
    return ans


def modified_beam_search_rnnlm_shallow_fusion(
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    sp: spm.SentencePieceProcessor,
    rnnlm: RnnLmModel,
    rnnlm_scale: float,
    beam: int = 4,
    return_timestamps: bool = False,
) -> List[List[int]]:
    """Modified_beam_search + RNNLM shallow fusion

    Args:
        model (Transducer):
            The transducer model
        encoder_out (torch.Tensor):
            Encoder output in (N,T,C)
        encoder_out_lens (torch.Tensor):
            A 1-D tensor of shape (N,), containing the number of
            valid frames in encoder_out before padding.
        sp:
            Sentence piece generator.
        rnnlm (RnnLmModel):
            RNNLM
        rnnlm_scale (float):
            scale of RNNLM in shallow fusion
        beam (int, optional):
            Beam size. Defaults to 4.

    Returns:
        Return a list-of-list of token IDs. ans[i] is the decoding results
        for the i-th utterance.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)
    assert rnnlm is not None
    lm_scale = rnnlm_scale
    vocab_size = rnnlm.vocab_size

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    sos_id = sp.piece_to_id("<sos/eos>")
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # get initial lm score and lm state by scoring the "sos" token
    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
    init_score, init_states = rnnlm.score_token(sos_token)

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                state=init_states,
                lm_score=init_score.reshape(-1),
                timestamp=[],
            )
        )

    rnnlm.clean_cache()
    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]  # get batch
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)

        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
        """
        for all hyps with a non-blank new token, score this token.
        It is a little confusing here because this for-loop
        looks very similar to the one below. Here, we go through all
        top-k tokens and only add the non-blanks ones to the token_list.
        The RNNLM will score those tokens given the LM states. Note that
        the variable `scores` is the LM score after seeing the new
        non-blank token.
        """
        token_list = []
        hs = []
        cs = []
        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()
            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_token = topk_token_indexes[k]
                if new_token not in (blank_id, unk_id):
                    assert new_token != 0, new_token
                    token_list.append([new_token])
                    # store the LSTM states
                    hs.append(hyp.state[0])
                    cs.append(hyp.state[1])

        # forward RNNLM to get new states and scores
        if len(token_list) != 0:
            tokens_to_score = (
                torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
            )

            hs = torch.cat(hs, dim=1).to(device)
            cs = torch.cat(cs, dim=1).to(device)
            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))

        count = 0  # index, used to locate score and lm states
        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                ys = hyp.ys[:]

                lm_score = hyp.lm_score
                state = hyp.state

                hyp_log_prob = topk_log_probs[k]  # get score of current hyp
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                if new_token not in (blank_id, unk_id):

                    ys.append(new_token)
                    new_timestamp.append(t)
                    hyp_log_prob += lm_score[new_token] * lm_scale  # add the lm score

                    lm_score = scores[count]
                    state = (
                        lm_states[0][:, count, :].unsqueeze(1),
                        lm_states[1][:, count, :].unsqueeze(1),
                    )
                    count += 1

                new_hyp = Hypothesis(
                    ys=ys,
                    log_prob=hyp_log_prob,
                    state=state,
                    lm_score=lm_score,
                    timestamp=new_timestamp,
                )
                B[i].add(new_hyp)

    B = B + finalized_B
    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    sorted_ans = [h.ys[context_size:] for h in best_hyps]
    sorted_timestamps = [h.timestamp for h in best_hyps]
    ans = []
    ans_timestamps = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])
        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

    if not return_timestamps:
        return ans
    else:
        return DecodingResults(
            tokens=ans,
            timestamps=ans_timestamps,
        )

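Before the LODR variant below, it may help to state what the function above computes: each beam candidate's transducer score is augmented with a scaled neural-LM score for non-blank tokens. A stripped-down sketch of that fusion rule (illustrative only, not part of beam_search.py):

    import torch

    def fuse_scores(
        am_log_probs: torch.Tensor,  # (num_hyps, vocab_size) transducer log-probs
        lm_log_probs: torch.Tensor,  # (num_hyps, vocab_size) neural LM log-probs
        lm_scale: float,
        blank_id: int = 0,
    ) -> torch.Tensor:
        # Shallow fusion: score = log p_AM + lm_scale * log p_LM for real tokens.
        fused = am_log_probs + lm_scale * lm_log_probs
        # The LM knows nothing about the transducer's blank symbol, so blank
        # keeps its acoustic score only.
        fused[:, blank_id] = am_log_probs[:, blank_id]
        return fused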
def modified_beam_search_rnnlm_LODR(
def modified_beam_search_LODR(
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    sp: spm.SentencePieceProcessor,
    LODR_lm: NgramLm,
    LODR_lm_scale: float,
    rnnlm: RnnLmModel,
    rnnlm_scale: float,
    LM: LmScorer,
    beam: int = 4,
) -> List[List[int]]:
    """This function implements LODR (https://arxiv.org/abs/2203.16776) with
@@ -2113,13 +1875,11 @@ def modified_beam_search_rnnlm_LODR(
        sp:
            Sentence piece generator.
        LODR_lm:
            A low order n-gram LM
            A low order n-gram LM, whose score will be subtracted during shallow fusion
        LODR_lm_scale:
            The scale of the LODR_lm
        rnnlm (RnnLmModel):
            RNNLM, the external language model
        rnnlm_scale (float):
            scale of RNNLM in shallow fusion
        LM:
            A neural net LM, e.g an RNNLM or transformer LM
        beam (int, optional):
            Beam size. Defaults to 4.
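In equation form, the LODR rule documented above extends shallow fusion by subtracting a scaled low-order n-gram estimate of the model's internal LM (see https://arxiv.org/abs/2203.16776). A schematic sketch of the per-token scoring, not the function body:

    def lodr_token_score(
        am_log_prob: float,     # transducer log-probability of the candidate token
        nnlm_log_prob: float,   # neural LM log-probability of the same token
        ngram_log_prob: float,  # low-order n-gram (LODR) log-probability
        lm_scale: float,
        lodr_scale: float,      # expected to be negative, so the n-gram term is subtracted
    ) -> float:
        return am_log_prob + lm_scale * nnlm_log_prob + lodr_scale * ngram_log_prob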
@@ -2130,9 +1890,8 @@ def modified_beam_search_rnnlm_LODR(
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)
    assert rnnlm is not None
    lm_scale = rnnlm_scale
    vocab_size = rnnlm.vocab_size
    assert LM is not None
    lm_scale = LM.lm_scale

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
@@ -2154,7 +1913,8 @@ def modified_beam_search_rnnlm_LODR(

    # get initial lm score and lm state by scoring the "sos" token
    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
    init_score, init_states = rnnlm.score_token(sos_token)
    lens = torch.tensor([1]).to(device)
    init_score, init_states = LM.score_token(sos_token, lens)

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
@@ -2162,7 +1922,7 @@ def modified_beam_search_rnnlm_LODR(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                state=init_states,  # state of the RNNLM
                state=init_states,  # state of the NN LM
                lm_score=init_score.reshape(-1),
                state_cost=NgramLmStateCost(
                    LODR_lm
@@ -2170,7 +1930,6 @@ def modified_beam_search_rnnlm_LODR(
            )
        )

    rnnlm.clean_cache()
    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
@@ -2236,7 +1995,7 @@ def modified_beam_search_rnnlm_LODR(
        It is a little confusing here because this for-loop
        looks very similar to the one below. Here, we go through all
        top-k tokens and only add the non-blanks ones to the token_list.
        The RNNLM will score those tokens given the LM states. Note that
        LM will score those tokens given the LM states. Note that
        the variable `scores` is the LM score after seeing the new
        non-blank token.
        """
@@ -2256,21 +2015,41 @@ def modified_beam_search_rnnlm_LODR(

                new_token = topk_token_indexes[k]
                if new_token not in (blank_id, unk_id):
                    assert new_token != 0, new_token
                    token_list.append([new_token])
                    # store the LSTM states
                    hs.append(hyp.state[0])
                    cs.append(hyp.state[1])
                    if LM.lm_type == "rnn":
                        token_list.append([new_token])
                        # store the LSTM states
                        hs.append(hyp.state[0])
                        cs.append(hyp.state[1])
                    else:
                        # for transformer LM
                        token_list.append(
                            [sos_id] + hyp.ys[context_size:] + [new_token]
                        )

        # forward RNNLM to get new states and scores
        # forward NN LM to get new states and scores
        if len(token_list) != 0:
            tokens_to_score = (
                torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
            )
            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
            if LM.lm_type == "rnn":
                tokens_to_score = (
                    torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
                )
                hs = torch.cat(hs, dim=1).to(device)
                cs = torch.cat(cs, dim=1).to(device)
                state = (hs, cs)
            else:
                # for transformer LM
                tokens_list = [torch.tensor(tokens) for tokens in token_list]
                tokens_to_score = (
                    torch.nn.utils.rnn.pad_sequence(
                        tokens_list, batch_first=True, padding_value=0.0
                    )
                    .to(device)
                    .to(torch.int64)
                )

            hs = torch.cat(hs, dim=1).to(device)
            cs = torch.cat(cs, dim=1).to(device)
            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
                state = None

            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)

        count = 0  # index, used to locate score and lm states
        for i in range(batch_size):
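The RNN/transformer split above exists because the two LM types consume different inputs: an RNN LM is fed only the newest token plus its carried (h, c) state, while a transformer LM must rescore the whole padded prefix. A small self-contained illustration of the padding step (token IDs are made up):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    prefixes = [torch.tensor([1, 7, 9]), torch.tensor([1, 4]), torch.tensor([1, 2, 3, 8])]
    x_lens = torch.tensor([len(p) for p in prefixes])
    batch = pad_sequence(prefixes, batch_first=True, padding_value=0)
    print(batch.shape)  # torch.Size([3, 4]); each row is one hypothesis prefix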
@@ -2305,18 +2084,19 @@ def modified_beam_search_rnnlm_LODR(
                        state_cost.lm_score,
                        hyp.state_cost.lm_score,
                    )
                    # score = score + RNNLM_score - LODR_score
                    # LODR_LM_scale is a negative number here
                    # score = score + TDLM_score - LODR_score
                    # LODR_LM_scale should be a negative number here
                    hyp_log_prob += (
                        lm_score[new_token] * lm_scale
                        + LODR_lm_scale * current_ngram_score
                    )  # add the lm score

                    lm_score = scores[count]
                    state = (
                        lm_states[0][:, count, :].unsqueeze(1),
                        lm_states[1][:, count, :].unsqueeze(1),
                    )
                    if LM.lm_type == "rnn":
                        state = (
                            lm_states[0][:, count, :].unsqueeze(1),
                            lm_states[1][:, count, :].unsqueeze(1),
                        )
                    count += 1
                else:
                    state_cost = hyp.state_cost
@@ -2340,3 +2120,263 @@ def modified_beam_search_rnnlm_LODR(
        ans.append(sorted_ans[unsorted_indices[i]])

    return ans


def modified_beam_search_lm_shallow_fusion(
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    sp: spm.SentencePieceProcessor,
    LM: LmScorer,
    beam: int = 4,
    return_timestamps: bool = False,
) -> List[List[int]]:
    """Modified_beam_search + NN LM shallow fusion

    Args:
        model (Transducer):
            The transducer model
        encoder_out (torch.Tensor):
            Encoder output in (N,T,C)
        encoder_out_lens (torch.Tensor):
            A 1-D tensor of shape (N,), containing the number of
            valid frames in encoder_out before padding.
        sp:
            Sentence piece generator.
        LM (LmScorer):
            A neural net LM, e.g RNN or Transformer
        beam (int, optional):
            Beam size. Defaults to 4.

    Returns:
        Return a list-of-list of token IDs. ans[i] is the decoding results
        for the i-th utterance.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)
    assert LM is not None
    lm_scale = LM.lm_scale

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    sos_id = sp.piece_to_id("<sos/eos>")
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # get initial lm score and lm state by scoring the "sos" token
    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
    lens = torch.tensor([1]).to(device)
    init_score, init_states = LM.score_token(sos_token, lens)

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                state=init_states,
                lm_score=init_score.reshape(-1),
                timestamp=[],
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]  # get batch
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )

        lm_scores = torch.cat(
            [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps]
        )

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)

        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
        """
        for all hyps with a non-blank new token, score this token.
        It is a little confusing here because this for-loop
        looks very similar to the one below. Here, we go through all
        top-k tokens and only add the non-blanks ones to the token_list.
        `LM` will score those tokens given the LM states. Note that
        the variable `scores` is the LM score after seeing the new
        non-blank token.
        """
        token_list = []  # a list of list
        hs = []
        cs = []
        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()
            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_token = topk_token_indexes[k]
                if new_token not in (blank_id, unk_id):
                    if LM.lm_type == "rnn":
                        token_list.append([new_token])
                        # store the LSTM states
                        hs.append(hyp.state[0])
                        cs.append(hyp.state[1])
                    else:
                        # for transformer LM
                        token_list.append(
                            [sos_id] + hyp.ys[context_size:] + [new_token]
                        )

        if len(token_list) != 0:
            x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
            if LM.lm_type == "rnn":
                tokens_to_score = (
                    torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
                )
                hs = torch.cat(hs, dim=1).to(device)
                cs = torch.cat(cs, dim=1).to(device)
                state = (hs, cs)
            else:
                # for transformer LM
                tokens_list = [torch.tensor(tokens) for tokens in token_list]
                tokens_to_score = (
                    torch.nn.utils.rnn.pad_sequence(
                        tokens_list, batch_first=True, padding_value=0.0
                    )
                    .to(device)
                    .to(torch.int64)
                )

                state = None

            scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)

        count = 0  # index, used to locate score and lm states
        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                ys = hyp.ys[:]

                lm_score = hyp.lm_score
                state = hyp.state

                hyp_log_prob = topk_log_probs[k]  # get score of current hyp
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                if new_token not in (blank_id, unk_id):

                    ys.append(new_token)
                    new_timestamp.append(t)

                    hyp_log_prob += lm_score[new_token] * lm_scale  # add the lm score

                    lm_score = scores[count]
                    if LM.lm_type == "rnn":
                        state = (
                            lm_states[0][:, count, :].unsqueeze(1),
                            lm_states[1][:, count, :].unsqueeze(1),
                        )
                    count += 1

                new_hyp = Hypothesis(
                    ys=ys,
                    log_prob=hyp_log_prob,
                    state=state,
                    lm_score=lm_score,
                    timestamp=new_timestamp,
                )
                B[i].add(new_hyp)

    B = B + finalized_B
    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    sorted_ans = [h.ys[context_size:] for h in best_hyps]
    sorted_timestamps = [h.timestamp for h in best_hyps]
    ans = []
    ans_timestamps = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])
        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

    if not return_timestamps:
        return ans
    else:
        return DecodingResults(
            tokens=ans,
            timestamps=ans_timestamps,
        )

@@ -33,7 +33,6 @@ from scaling import (
from torch import Tensor, nn

from icefall.utils import make_pad_mask, subsequent_chunk_mask
import random


class Conformer(EncoderInterface):
@@ -693,10 +692,7 @@ class ConformerEncoder(nn.Module):
        output = src

        outputs = []
        residual = None

        '''

        for i, mod in enumerate(self.layers):
            output = mod(
                output,
@@ -705,33 +701,10 @@ class ConformerEncoder(nn.Module):
                src_key_padding_mask=src_key_padding_mask,
                warmup=warmup,
            )
        '''
            if i in self.aux_layers:
                outputs.append(output)

        for i, mod in enumerate(self.layers):
            if i == 0:
                residual = output
            elif i in [2, 5, 8, 11, 14, 17]:
                output = mod(
                    output,
                    pos_emb,
                    src_mask=mask,
                    src_key_padding_mask=src_key_padding_mask,
                    warmup=warmup,
                )
                output += residual
                residual = output
            else:
                output = mod(
                    output,
                    pos_emb,
                    src_mask=mask,
                    src_key_padding_mask=src_key_padding_mask,
                    warmup=warmup,
                )
            # if i in self.aux_layers:
            #     outputs.append(output)

        # output = self.combiner(outputs)
        output = self.combiner(outputs)

        return output
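The encoder change above toggles an experiment that adds an extra residual path around every third layer (layers 2, 5, 8, ...). A toy module showing the same wiring pattern, not the icefall ConformerEncoder:

    import torch
    import torch.nn as nn

    class PeriodicResidualStack(nn.Module):
        # Adds a long skip connection after every third layer of a simple stack.
        def __init__(self, num_layers: int = 6, d_model: int = 16):
            super().__init__()
            self.layers = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            residual = x
            for i, layer in enumerate(self.layers):
                x = torch.relu(layer(x))
                if i % 3 == 2:  # layers 2, 5, 8, ...
                    x = x + residual
                    residual = x
            return x

    # y = PeriodicResidualStack()(torch.randn(4, 16))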
icefall/shared/convert-k2-to-openfst.py (new executable file, 102 lines)
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script takes as input an FST in k2 format and convert it
to an FST in OpenFST format.

The generated FST is saved into a binary file and its type is
StdVectorFst.

Usage examples:
(1) Convert an acceptor

  ./convert-k2-to-openfst.py in.pt binary.fst

(2) Convert a transducer

  ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
"""

import argparse
import logging
from pathlib import Path

import k2
import kaldifst.utils
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--olabels",
        type=str,
        default=None,
        help="""If not empty, the input FST is assumed to be a transducer
        and we use its attribute specified by "olabels" as the output labels.
        """,
    )
    parser.add_argument(
        "input_filename",
        type=str,
        help="Path to the input FST in k2 format",
    )

    parser.add_argument(
        "output_filename",
        type=str,
        help="Path to the output FST in OpenFst format",
    )

    return parser.parse_args()


def main():
    args = get_args()
    logging.info(f"{vars(args)}")

    input_filename = args.input_filename
    output_filename = args.output_filename
    olabels = args.olabels

    if Path(output_filename).is_file():
        logging.info(f"{output_filename} already exists - skipping")
        return

    assert Path(input_filename).is_file(), f"{input_filename} does not exist"
    logging.info(f"Loading {input_filename}")
    k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
    if olabels:
        assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"

    p = Path(output_filename).parent
    if not p.is_dir():
        logging.info(f"Creating {p}")
        p.mkdir(parents=True)

    logging.info("Converting (May take some time if the input FST is large)")
    fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
    logging.info(f"Saving to {output_filename}")
    fst.write(output_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
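To see the conversion path end to end, the snippet below builds a toy k2 acceptor and runs it through the same kaldifst helper the script uses; the FSA and the output file name are made up, and it assumes k2 and kaldifst are installed:

    import k2
    import kaldifst.utils

    s = """
    0 1 10 0.1
    1 2 -1 0.2
    2
    """
    fsa = k2.Fsa.from_str(s.strip())
    fst = kaldifst.utils.k2_to_openfst(fsa, olabels=None)  # acceptor: no output labels
    fst.write("toy.fst")  # same binary StdVectorFst format the script produces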