from local

dohe0342 2023-02-02 12:39:09 +09:00
parent de10e1d5fe
commit f4fb761129
8 changed files with 428 additions and 308 deletions

View File

@@ -1,7 +1,7 @@
# Introduction
- Please refer to <https://icefall.readthedocs.io/en/latest/recipes/aishell/index.html>
+ Please refer to <https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/aishell/index.html>
for how to run models in this recipe.

View File

@@ -31,7 +31,7 @@ stop_stage=10
# - noise
# - speech
- dl_dir=/home/work/workspace/aishell
+ dl_dir=$PWD/download
. shared/parse_options.sh || exit 1

View File

@@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module):
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
- assert 0.0 <= label_smoothing < 1.0
+ assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
+ assert reduction in ("none", "sum", "mean"), reduction
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
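The two assertions added above make bad constructor arguments fail immediately, with the offending value in the error message, instead of surfacing later as a confusing loss value. A minimal, self-contained sketch of the same guard pattern (the class name and values below are illustrative, not part of the recipe):

```python
import torch


class GuardedLoss(torch.nn.Module):
    """Illustrative stand-in for LabelSmoothingLoss's argument checks."""

    def __init__(self, label_smoothing: float = 0.1, reduction: str = "sum"):
        super().__init__()
        # Fail fast, and include the offending value in the error message.
        assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
        assert reduction in ("none", "sum", "mean"), reduction
        self.label_smoothing = label_smoothing
        self.reduction = reduction


GuardedLoss(reduction="mean")  # accepted
try:
    GuardedLoss(reduction="avg")
except AssertionError as e:
    print("rejected reduction:", e)  # prints: rejected reduction: avg
```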

View File

@@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from
Caution: We use a lexicon that contains disambiguation symbols
- - G, the LM, built from data/lm/G_3_gram.fst.txt
+ - G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""

View File

@@ -28,7 +28,7 @@ import os
from pathlib import Path
import torch
- from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
+ from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@@ -41,6 +41,10 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
+ def is_cut_long(c: MonoCut) -> bool:
+     return c.duration > 5
def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
@@ -86,7 +90,7 @@ def compute_fbank_musan():
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
- .filter(lambda c: c.duration > 5)
+ .filter(is_cut_long)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/musan_feats",
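The move from an inline lambda to the module-level is_cut_long above is most likely about picklability rather than style: when Lhotse farms feature extraction out to worker processes, the filter predicate has to be pickled, and Python cannot pickle a lambda. A minimal sketch of the difference, using only the standard library and independent of Lhotse:

```python
import pickle


def is_cut_long(duration: float) -> bool:
    # Module-level functions are pickled by reference, so they can be
    # shipped to a worker process.
    return duration > 5


pickle.dumps(is_cut_long)  # works

try:
    pickle.dumps(lambda d: d > 5)  # pickle refuses lambdas
except Exception as e:
    print("lambda rejected:", type(e).__name__)
```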

View File

@@ -26,7 +26,9 @@ from model import Transducer
from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
+ from icefall.lm_wrapper import LmScorer
from icefall.rnn_lm.model import RnnLmModel
+ from icefall.transformer_lm.model import TransformerLM
from icefall.utils import (
DecodingResults,
add_eos,
@@ -1846,254 +1848,14 @@ def modified_beam_search_ngram_rescoring(
return ans
- def modified_beam_search_rnnlm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
rnnlm: RnnLmModel,
rnnlm_scale: float,
beam: int = 4,
return_timestamps: bool = False,
) -> List[List[int]]:
"""Modified_beam_search + RNNLM shallow fusion
Args:
model (Transducer):
The transducer model
encoder_out (torch.Tensor):
Encoder output in (N,T,C)
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
rnnlm (RnnLmModel):
RNNLM
rnnlm_scale (float):
scale of RNNLM in shallow fusion
beam (int, optional):
Beam size. Defaults to 4.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert rnnlm is not None
lm_scale = rnnlm_scale
vocab_size = rnnlm.vocab_size
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
init_score, init_states = rnnlm.score_token(sos_token)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),
timestamp=[],
)
)
rnnlm.clean_cache()
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
"""
for all hyps with a non-blank new token, score this token.
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
The RNNLM will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
token_list = []
hs = []
cs = []
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
assert new_token != 0, new_token
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
# forward RNNLM to get new states and scores
if len(token_list) != 0:
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
ys = hyp.ys[:]
lm_score = hyp.lm_score
state = hyp.state
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score
lm_score = scores[count]
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
new_hyp = Hypothesis(
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestamp=new_timestamp,
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)
- def modified_beam_search_rnnlm_LODR(
+ def modified_beam_search_LODR(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
LODR_lm: NgramLm,
LODR_lm_scale: float,
- rnnlm: RnnLmModel,
- rnnlm_scale: float,
+ LM: LmScorer,
beam: int = 4,
) -> List[List[int]]:
"""This function implements LODR (https://arxiv.org/abs/2203.16776) with
@@ -2113,13 +1875,11 @@ def modified_beam_search_rnnlm_LODR(
sp:
Sentence piece generator.
LODR_lm:
- A low order n-gram LM
+ A low order n-gram LM, whose score will be subtracted during shallow fusion
LODR_lm_scale:
The scale of the LODR_lm
- rnnlm (RnnLmModel):
-     RNNLM, the external language model
- rnnlm_scale (float):
-     scale of RNNLM in shallow fusion
+ LM:
+     A neural net LM, e.g an RNNLM or transformer LM
beam (int, optional):
Beam size. Defaults to 4.
@@ -2130,9 +1890,8 @@ def modified_beam_search_rnnlm_LODR(
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
- assert rnnlm is not None
- lm_scale = rnnlm_scale
- vocab_size = rnnlm.vocab_size
+ assert LM is not None
+ lm_scale = LM.lm_scale
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
@@ -2154,7 +1913,8 @@ def modified_beam_search_rnnlm_LODR(
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
- init_score, init_states = rnnlm.score_token(sos_token)
+ lens = torch.tensor([1]).to(device)
+ init_score, init_states = LM.score_token(sos_token, lens)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
@@ -2162,7 +1922,7 @@ def modified_beam_search_rnnlm_LODR(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
- state=init_states,  # state of the RNNLM
+ state=init_states,  # state of the NN LM
lm_score=init_score.reshape(-1),
state_cost=NgramLmStateCost(
LODR_lm
@@ -2170,7 +1930,6 @@ def modified_beam_search_rnnlm_LODR(
)
)
- rnnlm.clean_cache()
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
@@ -2236,7 +1995,7 @@ def modified_beam_search_rnnlm_LODR(
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
- The RNNLM will score those tokens given the LM states. Note that
+ LM will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
@@ -2256,21 +2015,41 @@ def modified_beam_search_rnnlm_LODR(
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
- assert new_token != 0, new_token
+ if LM.lm_type == "rnn":
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
+ else:
+     # for transformer LM
+     token_list.append(
+         [sos_id] + hyp.ys[context_size:] + [new_token]
+     )
- # forward RNNLM to get new states and scores
+ # forward NN LM to get new states and scores
if len(token_list) != 0:
- tokens_to_score = (
-     torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
- )
- hs = torch.cat(hs, dim=1).to(device)
- cs = torch.cat(cs, dim=1).to(device)
- scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
+ x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
+ if LM.lm_type == "rnn":
+     tokens_to_score = (
+         torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
+     )
+     hs = torch.cat(hs, dim=1).to(device)
+     cs = torch.cat(cs, dim=1).to(device)
+     state = (hs, cs)
+ else:
+     # for transformer LM
+     tokens_list = [torch.tensor(tokens) for tokens in token_list]
+     tokens_to_score = (
+         torch.nn.utils.rnn.pad_sequence(
+             tokens_list, batch_first=True, padding_value=0.0
+         )
+         .to(device)
+         .to(torch.int64)
+     )
+     state = None
+ scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
count = 0  # index, used to locate score and lm states
for i in range(batch_size):
@@ -2305,18 +2084,19 @@ def modified_beam_search_rnnlm_LODR(
state_cost.lm_score,
hyp.state_cost.lm_score,
)
- # score = score + RNNLM_score - LODR_score
- # LODR_LM_scale is a negative number here
+ # score = score + TDLM_score - LODR_score
+ # LODR_LM_scale should be a negative number here
hyp_log_prob += (
lm_score[new_token] * lm_scale
+ LODR_lm_scale * current_ngram_score
)  # add the lm score
lm_score = scores[count]
- state = (
-     lm_states[0][:, count, :].unsqueeze(1),
-     lm_states[1][:, count, :].unsqueeze(1),
- )
+ if LM.lm_type == "rnn":
+     state = (
+         lm_states[0][:, count, :].unsqueeze(1),
+         lm_states[1][:, count, :].unsqueeze(1),
+     )
count += 1
else:
state_cost = hyp.state_cost
@@ -2340,3 +2120,263 @@ def modified_beam_search_rnnlm_LODR(
ans.append(sorted_ans[unsorted_indices[i]])
return ans
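For readers skimming the diff, the scoring rule that modified_beam_search_LODR applies to every non-blank hypothesis extension is the one visible in the hyp_log_prob update above: the transducer score plus a scaled neural-LM score, with a scaled low-order n-gram score subtracted because LODR_lm_scale is negative. A self-contained restatement of that per-token update (helper name and example scales are illustrative, not from the recipe):

```python
def lodr_token_score(
    am_log_prob: float,       # transducer (joiner) log-prob of the new token
    nn_lm_log_prob: float,    # neural LM log-prob of the new token
    ngram_log_prob: float,    # low-order n-gram LM log-prob of the new token
    lm_scale: float = 0.4,            # illustrative value
    LODR_lm_scale: float = -0.16,     # negative, so the n-gram score is subtracted
) -> float:
    # Mirrors: hyp_log_prob += lm_score[new_token] * lm_scale
    #                          + LODR_lm_scale * current_ngram_score
    return am_log_prob + lm_scale * nn_lm_log_prob + LODR_lm_scale * ngram_log_prob


print(lodr_token_score(-1.2, -2.0, -3.0))  # -1.52
```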
def modified_beam_search_lm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
LM: LmScorer,
beam: int = 4,
return_timestamps: bool = False,
) -> List[List[int]]:
"""Modified_beam_search + NN LM shallow fusion
Args:
model (Transducer):
The transducer model
encoder_out (torch.Tensor):
Encoder output in (N,T,C)
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
LM (LmScorer):
A neural net LM, e.g RNN or Transformer
beam (int, optional):
Beam size. Defaults to 4.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert LM is not None
lm_scale = LM.lm_scale
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
lens = torch.tensor([1]).to(device)
init_score, init_states = LM.score_token(sos_token, lens)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),
timestamp=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
)
lm_scores = torch.cat(
[hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps]
)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
"""
for all hyps with a non-blank new token, score this token.
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
`LM` will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
token_list = [] # a list of list
hs = []
cs = []
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
if LM.lm_type == "rnn":
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
else:
# for transformer LM
token_list.append(
[sos_id] + hyp.ys[context_size:] + [new_token]
)
if len(token_list) != 0:
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
if LM.lm_type == "rnn":
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
state = (hs, cs)
else:
# for transformer LM
tokens_list = [torch.tensor(tokens) for tokens in token_list]
tokens_to_score = (
torch.nn.utils.rnn.pad_sequence(
tokens_list, batch_first=True, padding_value=0.0
)
.to(device)
.to(torch.int64)
)
state = None
scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
ys = hyp.ys[:]
lm_score = hyp.lm_score
state = hyp.state
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score
lm_score = scores[count]
if LM.lm_type == "rnn":
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
new_hyp = Hypothesis(
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestamp=new_timestamp,
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)
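Both decoding functions above now talk to the external LM only through the LmScorer handle: LM.lm_type selects how tokens and states are prepared (for "rnn" the single new token plus the (h, c) states are fed; otherwise the full re-padded prefix is fed), and LM.score_token(tokens, lens, state) returns per-token log-probabilities together with new LM states. A toy sketch of that calling convention follows; the class below is an illustrative stand-in, not icefall's LmScorer implementation:

```python
import math

import torch


class ToyLmScorer:
    """Illustrative stand-in exposing the interface used by the decoders above."""

    def __init__(self, lm_type: str = "rnn", vocab_size: int = 500, lm_scale: float = 0.3):
        self.lm_type = lm_type
        self.lm_scale = lm_scale
        self.vocab_size = vocab_size

    def score_token(self, tokens: torch.Tensor, lens: torch.Tensor, state=None):
        # Uniform log-probs and a pass-through state; a real scorer would run an
        # RNN (consuming `state`) or a transformer (re-reading the whole prefix).
        n = tokens.size(0)
        scores = torch.full((n, self.vocab_size), -math.log(self.vocab_size))
        return scores, state if self.lm_type == "rnn" else None


LM = ToyLmScorer(lm_type="rnn")
scores, new_state = LM.score_token(torch.tensor([[17]]), torch.tensor([1]))
print(scores.shape)  # torch.Size([1, 500])
```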

View File

@@ -33,7 +33,6 @@ from scaling import (
from torch import Tensor, nn
from icefall.utils import make_pad_mask, subsequent_chunk_mask
- import random
class Conformer(EncoderInterface):
@@ -693,10 +692,7 @@ class ConformerEncoder(nn.Module):
output = src
outputs = []
- residual = None
- '''
for i, mod in enumerate(self.layers):
output = mod(
output,
@@ -705,33 +701,10 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
- '''
- for i, mod in enumerate(self.layers):
-     if i == 0:
-         residual = output
-     elif i in [2,5,8,11,14,17]:
-         output = mod(
-             output,
-             pos_emb,
-             src_mask=mask,
-             src_key_padding_mask=src_key_padding_mask,
-             warmup=warmup,
-         )
-         output += residual
-         residual = output
-     else:
-         output = mod(
-             output,
-             pos_emb,
-             src_mask=mask,
-             src_key_padding_mask=src_key_padding_mask,
-             warmup=warmup,
-         )
-     #if i in self.aux_layers:
-     #    outputs.append(output)
-     #output = self.combiner(outputs)
+ if i in self.aux_layers:
+     outputs.append(output)
+ output = self.combiner(outputs)
return output
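The revert above drops the fork's hand-rolled residual bypass (adding an earlier layer's output back in at layers 2, 5, 8, and so on) and restores the stock pattern: remember the output of a few designated aux layers while iterating, then let a combiner module mix them at the end. A schematic of that collect-then-combine pattern, with a plain average standing in for the model's actual combiner (all names, layer indices, and sizes below are illustrative):

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Schematic of the aux-layer collection pattern restored above."""

    def __init__(self, num_layers: int = 6, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        # Which intermediate outputs to remember (illustrative choice).
        self.aux_layers = {2, 4, num_layers - 1}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = []
        for i, mod in enumerate(self.layers):
            x = torch.relu(mod(x))
            if i in self.aux_layers:
                outputs.append(x)
        # Stand-in combiner: uniform average of the collected outputs.
        return torch.stack(outputs).mean(dim=0)


print(TinyEncoder()(torch.randn(3, 16)).shape)  # torch.Size([3, 16])
```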

View File

@@ -0,0 +1,102 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input an FST in k2 format and convert it
to an FST in OpenFST format.
The generated FST is saved into a binary file and its type is
StdVectorFst.
Usage examples:
(1) Convert an acceptor
./convert-k2-to-openfst.py in.pt binary.fst
(2) Convert a transducer
./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
"""
import argparse
import logging
from pathlib import Path
import k2
import kaldifst.utils
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--olabels",
type=str,
default=None,
help="""If not empty, the input FST is assumed to be a transducer
and we use its attribute specified by "olabels" as the output labels.
""",
)
parser.add_argument(
"input_filename",
type=str,
help="Path to the input FST in k2 format",
)
parser.add_argument(
"output_filename",
type=str,
help="Path to the output FST in OpenFst format",
)
return parser.parse_args()
def main():
args = get_args()
logging.info(f"{vars(args)}")
input_filename = args.input_filename
output_filename = args.output_filename
olabels = args.olabels
if Path(output_filename).is_file():
logging.info(f"{output_filename} already exists - skipping")
return
assert Path(input_filename).is_file(), f"{input_filename} does not exist"
logging.info(f"Loading {input_filename}")
k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
if olabels:
assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
p = Path(output_filename).parent
if not p.is_dir():
logging.info(f"Creating {p}")
p.mkdir(parents=True)
logging.info("Converting (May take some time if the input FST is large)")
fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
logging.info(f"Saving to {output_filename}")
fst.write(output_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()