from local

This commit is contained in:
dohe0342 2023-02-02 12:39:09 +09:00
parent de10e1d5fe
commit f4fb761129
8 changed files with 428 additions and 308 deletions

View File

@ -1,7 +1,7 @@
# Introduction
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/aishell/index.html>
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/aishell/index.html>
for how to run models in this recipe.

View File

@ -31,7 +31,7 @@ stop_stage=10
# - noise
# - speech
dl_dir=/home/work/workspace/aishell
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1

View File

@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module):
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0
assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
assert reduction in ("none", "sum", "mean"), reduction
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction

View File

@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
- G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""

View File

@ -28,7 +28,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@ -41,6 +41,10 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
@ -86,7 +90,7 @@ def compute_fbank_musan():
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
.filter(is_cut_long)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/musan_feats",

View File

@ -26,7 +26,9 @@ from model import Transducer
from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.lm_wrapper import LmScorer
from icefall.rnn_lm.model import RnnLmModel
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import (
DecodingResults,
add_eos,
@ -1846,254 +1848,14 @@ def modified_beam_search_ngram_rescoring(
return ans
def modified_beam_search_rnnlm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
rnnlm: RnnLmModel,
rnnlm_scale: float,
beam: int = 4,
return_timestamps: bool = False,
) -> List[List[int]]:
"""Modified_beam_search + RNNLM shallow fusion
Args:
model (Transducer):
The transducer model
encoder_out (torch.Tensor):
Encoder output in (N,T,C)
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
rnnlm (RnnLmModel):
RNNLM
rnnlm_scale (float):
scale of RNNLM in shallow fusion
beam (int, optional):
Beam size. Defaults to 4.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert rnnlm is not None
lm_scale = rnnlm_scale
vocab_size = rnnlm.vocab_size
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
init_score, init_states = rnnlm.score_token(sos_token)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),
timestamp=[],
)
)
rnnlm.clean_cache()
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
"""
for all hyps with a non-blank new token, score this token.
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
The RNNLM will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
token_list = []
hs = []
cs = []
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
assert new_token != 0, new_token
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
# forward RNNLM to get new states and scores
if len(token_list) != 0:
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
ys = hyp.ys[:]
lm_score = hyp.lm_score
state = hyp.state
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score
lm_score = scores[count]
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
new_hyp = Hypothesis(
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestamp=new_timestamp,
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)
def modified_beam_search_rnnlm_LODR(
def modified_beam_search_LODR(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
LODR_lm: NgramLm,
LODR_lm_scale: float,
rnnlm: RnnLmModel,
rnnlm_scale: float,
LM: LmScorer,
beam: int = 4,
) -> List[List[int]]:
"""This function implements LODR (https://arxiv.org/abs/2203.16776) with
@ -2113,13 +1875,11 @@ def modified_beam_search_rnnlm_LODR(
sp:
Sentence piece generator.
LODR_lm:
A low order n-gram LM
A low order n-gram LM, whose score will be subtracted during shallow fusion
LODR_lm_scale:
The scale of the LODR_lm
rnnlm (RnnLmModel):
RNNLM, the external language model
rnnlm_scale (float):
scale of RNNLM in shallow fusion
LM:
A neural net LM, e.g an RNNLM or transformer LM
beam (int, optional):
Beam size. Defaults to 4.
@ -2130,9 +1890,8 @@ def modified_beam_search_rnnlm_LODR(
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert rnnlm is not None
lm_scale = rnnlm_scale
vocab_size = rnnlm.vocab_size
assert LM is not None
lm_scale = LM.lm_scale
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
@ -2154,7 +1913,8 @@ def modified_beam_search_rnnlm_LODR(
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
init_score, init_states = rnnlm.score_token(sos_token)
lens = torch.tensor([1]).to(device)
init_score, init_states = LM.score_token(sos_token, lens)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
@ -2162,7 +1922,7 @@ def modified_beam_search_rnnlm_LODR(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states, # state of the RNNLM
state=init_states, # state of the NN LM
lm_score=init_score.reshape(-1),
state_cost=NgramLmStateCost(
LODR_lm
@ -2170,7 +1930,6 @@ def modified_beam_search_rnnlm_LODR(
)
)
rnnlm.clean_cache()
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
@ -2236,7 +1995,7 @@ def modified_beam_search_rnnlm_LODR(
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
The RNNLM will score those tokens given the LM states. Note that
LM will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
@ -2256,21 +2015,41 @@ def modified_beam_search_rnnlm_LODR(
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
assert new_token != 0, new_token
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
if LM.lm_type == "rnn":
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
else:
# for transformer LM
token_list.append(
[sos_id] + hyp.ys[context_size:] + [new_token]
)
# forward RNNLM to get new states and scores
# forward NN LM to get new states and scores
if len(token_list) != 0:
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
if LM.lm_type == "rnn":
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
state = (hs, cs)
else:
# for transformer LM
tokens_list = [torch.tensor(tokens) for tokens in token_list]
tokens_to_score = (
torch.nn.utils.rnn.pad_sequence(
tokens_list, batch_first=True, padding_value=0.0
)
.to(device)
.to(torch.int64)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
state = None
scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
@ -2305,18 +2084,19 @@ def modified_beam_search_rnnlm_LODR(
state_cost.lm_score,
hyp.state_cost.lm_score,
)
# score = score + RNNLM_score - LODR_score
# LODR_LM_scale is a negative number here
# score = score + TDLM_score - LODR_score
# LODR_LM_scale should be a negative number here
hyp_log_prob += (
lm_score[new_token] * lm_scale
+ LODR_lm_scale * current_ngram_score
) # add the lm score
lm_score = scores[count]
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
if LM.lm_type == "rnn":
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
else:
state_cost = hyp.state_cost
@ -2340,3 +2120,263 @@ def modified_beam_search_rnnlm_LODR(
ans.append(sorted_ans[unsorted_indices[i]])
return ans
def modified_beam_search_lm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
LM: LmScorer,
beam: int = 4,
return_timestamps: bool = False,
) -> List[List[int]]:
"""Modified_beam_search + NN LM shallow fusion
Args:
model (Transducer):
The transducer model
encoder_out (torch.Tensor):
Encoder output in (N,T,C)
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
LM (LmScorer):
A neural net LM, e.g RNN or Transformer
beam (int, optional):
Beam size. Defaults to 4.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert LM is not None
lm_scale = LM.lm_scale
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
lens = torch.tensor([1]).to(device)
init_score, init_states = LM.score_token(sos_token, lens)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),
timestamp=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
)
lm_scores = torch.cat(
[hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps]
)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
"""
for all hyps with a non-blank new token, score this token.
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
`LM` will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
token_list = [] # a list of list
hs = []
cs = []
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
if LM.lm_type == "rnn":
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
else:
# for transformer LM
token_list.append(
[sos_id] + hyp.ys[context_size:] + [new_token]
)
if len(token_list) != 0:
x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device)
if LM.lm_type == "rnn":
tokens_to_score = (
torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
state = (hs, cs)
else:
# for transformer LM
tokens_list = [torch.tensor(tokens) for tokens in token_list]
tokens_to_score = (
torch.nn.utils.rnn.pad_sequence(
tokens_list, batch_first=True, padding_value=0.0
)
.to(device)
.to(torch.int64)
)
state = None
scores, lm_states = LM.score_token(tokens_to_score, x_lens, state)
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
ys = hyp.ys[:]
lm_score = hyp.lm_score
state = hyp.state
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score
lm_score = scores[count]
if LM.lm_type == "rnn":
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
new_hyp = Hypothesis(
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestamp=new_timestamp,
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)

View File

@ -33,7 +33,6 @@ from scaling import (
from torch import Tensor, nn
from icefall.utils import make_pad_mask, subsequent_chunk_mask
import random
class Conformer(EncoderInterface):
@ -693,9 +692,6 @@ class ConformerEncoder(nn.Module):
output = src
outputs = []
residual = None
'''
for i, mod in enumerate(self.layers):
output = mod(
@ -705,33 +701,10 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
'''
if i in self.aux_layers:
outputs.append(output)
for i, mod in enumerate(self.layers):
if i == 0:
residual = output
elif i in [2,5,8,11,14,17]:
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
output += residual
residual = output
else:
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
#if i in self.aux_layers:
# outputs.append(output)
#output = self.combiner(outputs)
output = self.combiner(outputs)
return output

View File

@ -0,0 +1,102 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input an FST in k2 format and convert it
to an FST in OpenFST format.
The generated FST is saved into a binary file and its type is
StdVectorFst.
Usage examples:
(1) Convert an acceptor
./convert-k2-to-openfst.py in.pt binary.fst
(2) Convert a transducer
./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
"""
import argparse
import logging
from pathlib import Path
import k2
import kaldifst.utils
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--olabels",
type=str,
default=None,
help="""If not empty, the input FST is assumed to be a transducer
and we use its attribute specified by "olabels" as the output labels.
""",
)
parser.add_argument(
"input_filename",
type=str,
help="Path to the input FST in k2 format",
)
parser.add_argument(
"output_filename",
type=str,
help="Path to the output FST in OpenFst format",
)
return parser.parse_args()
def main():
args = get_args()
logging.info(f"{vars(args)}")
input_filename = args.input_filename
output_filename = args.output_filename
olabels = args.olabels
if Path(output_filename).is_file():
logging.info(f"{output_filename} already exists - skipping")
return
assert Path(input_filename).is_file(), f"{input_filename} does not exist"
logging.info(f"Loading {input_filename}")
k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
if olabels:
assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
p = Path(output_filename).parent
if not p.is_dir():
logging.info(f"Creating {p}")
p.mkdir(parents=True)
logging.info("Converting (May take some time if the input FST is large)")
fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
logging.info(f"Saving to {output_filename}")
fst.write(output_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()