Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00

Commit 4159d6fbf6: Merge remote-tracking branch 'dan/master' into rnnt-modified
@@ -82,17 +82,17 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
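The hunk above fixes the `embeding_out` misspelling throughout `Decoder.forward`. For readers skimming the diff, a minimal standalone sketch of the padding logic this code path implements (toy sizes, not the recipe's actual module):

import torch
import torch.nn.functional as F

# Toy sizes for illustration only.
N, U, embedding_dim, context_size = 1, 3, 4, 2

embedding_out = torch.randn(N, U, embedding_dim)

# Conv1d expects (N, C, T), so the time axis is moved last ...
embedding_out = embedding_out.permute(0, 2, 1)  # (N, embedding_dim, U)

# ... and F.pad(..., pad=(context_size - 1, 0)) left-pads the time axis,
# so a kernel of width context_size at step t only sees tokens <= t.
# This is the need_pad=True (training) branch; at inference the input is
# already exactly context_size frames wide, hence the assert instead.
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
print(padded.shape)  # torch.Size([1, 4, 4])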
@@ -48,6 +48,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
 import argparse
 import logging
 import math
-from typing import List
 from pathlib import Path
+from typing import List
 
 import kaldifeat
 import torch
+import torch.nn as nn
 import torchaudio
 from beam_search import beam_search, greedy_search
 from conformer import Conformer
@@ -57,10 +58,10 @@ from joiner import Joiner
 from model import Transducer
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.env import get_env_info
-from icefall.utils import AttributeDict
-from icefall.lexicon import Lexicon
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict
 
 
 def get_parser():
@@ -150,7 +151,7 @@ def get_params() -> AttributeDict:
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -164,7 +165,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -174,7 +175,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -182,7 +183,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -204,7 +204,7 @@ def get_params() -> AttributeDict:
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -219,7 +219,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -229,7 +229,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -237,7 +237,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -123,6 +124,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layers of the transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model).
+        """,
+    )
+
     parser.add_argument(
         "--lr-factor",
         type=float,
@@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
         "use_feat_batchnorm": True,
         "attention_dim": 512,
         "nhead": 8,
-        "num_decoder_layers": 6,
         # parameters for loss
         "beam_size": 10,
         "reduction": "sum",
@@ -357,9 +366,17 @@ def compute_loss(
         supervisions, subsampling_factor=params.subsampling_factor
     )
 
-    token_ids = graph_compiler.texts_to_ids(texts)
-
-    decoding_graph = graph_compiler.compile(token_ids)
+    if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
+        # Works with a BPE model
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+    elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
+        # Works with a phone lexicon
+        decoding_graph = graph_compiler.compile(texts)
+    else:
+        raise ValueError(
+            f"Unsupported type of graph compiler: {type(graph_compiler)}"
+        )
 
     dense_fsa_vec = k2.DenseFsaVec(
         nnet_output,
@@ -584,12 +601,38 @@ def run(rank, world_size, args):
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
 
-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
+    if "lang_bpe" in params.lang_dir:
+        graph_compiler = BpeCtcTrainingGraphCompiler(
+            params.lang_dir,
+            device=device,
+            sos_token="<sos/eos>",
+            eos_token="<sos/eos>",
+        )
+    elif "lang_phone" in params.lang_dir:
+        assert params.att_rate == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
+            "for pure CTC training when using a phone-based lang dir."
+        )
+        assert params.num_decoder_layers == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. "
+            "Set --num-decoder-layers=0 for pure CTC training when using "
+            "a phone-based lang dir."
+        )
+        graph_compiler = CtcTrainingGraphCompiler(
+            lexicon,
+            device=device,
+        )
+        # Manually add the sos/eos ID with their default values
+        # from the BPE recipe which we're adapting here.
+        graph_compiler.sos_id = 1
+        graph_compiler.eos_id = 1
+    else:
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
+        )
 
     logging.info("About to create model")
     model = Conformer(
@@ -607,7 +650,9 @@ def run(rank, world_size, args):
 
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        # Note: find_unused_parameters=True is needed in case we
+        # want to set params.att_rate = 0 (i.e. att decoder is not trained)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     optimizer = Noam(
         model.parameters(),
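A note on the DDP change above: by default, DDP's gradient reducer expects every parameter to receive a gradient in each backward pass. With --att-rate=0 (and now --num-decoder-layers=0 allowed), the attention decoder's parameters never enter the loss, which would make the default reducer error out. find_unused_parameters=True tells DDP to detect such parameters each iteration and mark them ready. A toy sketch of the failure mode (hypothetical model, not the recipe's Conformer):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """A model with a branch the loss may never touch."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 8)  # unused when att_rate == 0

    def forward(self, x: torch.Tensor, att_rate: float) -> torch.Tensor:
        out = self.encoder(x)
        if att_rate > 0:
            out = out + att_rate * self.decoder(out)
        return out

# Under DDP, self.decoder receives no gradient when att_rate == 0, so the
# wrapper needs find_unused_parameters=True (as in the hunk above):
#   model = DDP(model, device_ids=[rank], find_unused_parameters=True)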
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
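The explicit dtype=torch.int64 above is presumably defensive: a plain Python int already infers int64, but a blank_id arriving from elsewhere (e.g. a numpy integer or a float) could silently change the dtype, and nn.Embedding only accepts integer index tensors. A quick standalone check of that constraint:

import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)

ok = torch.tensor([0], dtype=torch.int64).reshape(1, 1)
print(embedding(ok).shape)  # torch.Size([1, 1, 4])

try:
    bad = torch.tensor([0.0]).reshape(1, 1)  # float32 indices
    embedding(bad)
except RuntimeError as e:
    print("embedding rejects non-integer indices:", e)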
@@ -99,6 +99,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
 
         decoder_out, _ = self.decoder(sos_y_padded)
 
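For context on the hunk above: y is a k2.RaggedTensor of token IDs, add_sos prepends a symbol to every row, and .pad() densifies the ragged rows; the new .to(torch.int64) then pins the dtype the decoder's embedding expects. A hedged sketch of the pad step (assumes k2 is installed; add_sos is icefall's helper, its effect written out by hand here):

import torch
import k2

blank_id = 0
# Two utterances of different lengths, after add_sos(y, sos_id=blank_id)
# has prepended blank_id to each row.
sos_y = k2.RaggedTensor([[blank_id, 2, 5, 9], [blank_id, 7, 3]])

sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)  # dtype nn.Embedding expects
print(sos_y_padded)
# tensor([[0, 2, 5, 9],
#         [0, 7, 3, 0]])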
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -101,6 +101,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=sos_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
 
         decoder_out, _ = self.decoder(sos_y_padded)
 
@@ -47,7 +47,7 @@ def greedy_search(
     device = model.device
 
     decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size, device=device, dtype=torch.int64
     ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
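This is the stateless-decoder counterpart of the greedy-search fix above: decoding starts by priming the decoder with context_size blanks and calling it with need_pad=False, because the input is already exactly as wide as the convolution kernel. A toy sketch of why that width yields a single output frame (the grouped conv shape here is an assumption, not necessarily the recipe's exact layer):

import torch
import torch.nn as nn

blank_id, context_size, embedding_dim = 0, 2, 4

decoder_input = torch.tensor(
    [blank_id] * context_size, dtype=torch.int64
).reshape(1, context_size)

embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim)
conv = nn.Conv1d(
    embedding_dim, embedding_dim, kernel_size=context_size,
    groups=embedding_dim,
)

x = embedding(decoder_input).permute(0, 2, 1)  # (1, embedding_dim, 2)
assert x.size(-1) == context_size  # the need_pad=False branch's assert
out = conv(x).permute(0, 2, 1)
print(out.shape)  # torch.Size([1, 1, 4]): exactly one output frame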
@@ -48,6 +48,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -98,6 +98,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
 
         decoder_out = self.decoder(sos_y_padded)
 
@@ -59,6 +59,7 @@ from typing import List
 import kaldifeat
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 import torchaudio
 from beam_search import beam_search, greedy_search, modified_beam_search
 from conformer import Conformer
@@ -159,7 +160,7 @@ def get_params() -> AttributeDict:
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -173,7 +174,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -183,7 +184,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -191,7 +192,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -224,7 +224,7 @@ def get_params() -> AttributeDict:
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -239,7 +239,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -249,7 +249,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -257,7 +257,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
 
         return decoding_graph
 
+    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
+        """Convert a list of texts to a list-of-list of word IDs.
+
+        Args:
+          texts:
+            It is a list of strings. Each string consists of space(s)
+            separated words. An example containing two strings is given below:
+
+                ['HELLO ICEFALL', 'HELLO k2']
+        Returns:
+          Return a list-of-list of word IDs.
+        """
+        word_ids_list = []
+        for text in texts:
+            word_ids = []
+            for word in text.split():
+                if word in self.word_table:
+                    word_ids.append(self.word_table[word])
+                else:
+                    word_ids.append(self.oov_id)
+            word_ids_list.append(word_ids)
+        return word_ids_list
+
     def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
         """Convert a list of transcript texts to an FsaVec.
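The new texts_to_ids gives CtcTrainingGraphCompiler the same interface as BpeCtcTrainingGraphCompiler.texts_to_ids, which is what the isinstance dispatch in compute_loss above relies on. A hedged usage sketch with a made-up word table (in icefall, word_table and oov_id come from the lexicon):

# Hypothetical word table and IDs, for illustration only.
word_table = {"HELLO": 10, "ICEFALL": 11}
oov_id = 1

def texts_to_ids(texts):
    # Same logic as the method above, written as a free function.
    return [
        [word_table.get(word, oov_id) for word in text.split()]
        for text in texts
    ]

print(texts_to_ids(["HELLO ICEFALL", "HELLO k2"]))
# [[10, 11], [10, 1]]  ('k2' is out of vocabulary, so it maps to oov_id)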