Merge remote-tracking branch 'dan/master' into rnnt-modified

This commit is contained in:
Fangjun Kuang 2022-02-07 16:43:14 +08:00
commit 4159d6fbf6
15 changed files with 129 additions and 50 deletions

View File

@ -82,17 +82,17 @@ class Decoder(nn.Module):
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
embeding_out = self.embedding(y)
embedding_out = self.embedding(y)
if self.context_size > 1:
embeding_out = embeding_out.permute(0, 2, 1)
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embeding_out = F.pad(
embeding_out, pad=(self.context_size - 1, 0)
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embeding_out.size(-1) == self.context_size
embeding_out = self.conv(embeding_out)
embeding_out = embeding_out.permute(0, 2, 1)
return embeding_out
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
return embedding_out

View File

@ -48,6 +48,7 @@ from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
return params
def get_encoder_model(params: AttributeDict):
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
return encoder
def get_decoder_model(params: AttributeDict):
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
return decoder
def get_joiner_model(params: AttributeDict):
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
return joiner
def get_transducer_model(params: AttributeDict):
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)

View File

@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
import argparse
import logging
import math
from typing import List
from pathlib import Path
from typing import List
import kaldifeat
import torch
import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search
from conformer import Conformer
@ -57,10 +58,10 @@ from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
from icefall.lexicon import Lexicon
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict
def get_parser():
@ -150,7 +151,7 @@ def get_params() -> AttributeDict:
return params
def get_encoder_model(params: AttributeDict):
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
@ -164,7 +165,7 @@ def get_encoder_model(params: AttributeDict):
return encoder
def get_decoder_model(params: AttributeDict):
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
@ -174,7 +175,7 @@ def get_decoder_model(params: AttributeDict):
return decoder
def get_joiner_model(params: AttributeDict):
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
@ -182,7 +183,7 @@ def get_joiner_model(params: AttributeDict):
return joiner
def get_transducer_model(params: AttributeDict):
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)

View File

@ -204,7 +204,7 @@ def get_params() -> AttributeDict:
return params
def get_encoder_model(params: AttributeDict):
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
@ -219,7 +219,7 @@ def get_encoder_model(params: AttributeDict):
return encoder
def get_decoder_model(params: AttributeDict):
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
@ -229,7 +229,7 @@ def get_decoder_model(params: AttributeDict):
return decoder
def get_joiner_model(params: AttributeDict):
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
@ -237,7 +237,7 @@ def get_joiner_model(params: AttributeDict):
return joiner
def get_transducer_model(params: AttributeDict):
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)

View File

@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -123,6 +124,15 @@ def get_parser():
""",
)
parser.add_argument(
"--num-decoder-layers",
type=int,
default=6,
help="""Number of decoder layer of transformer decoder.
Setting this to 0 will not create the decoder at all (pure CTC model)
""",
)
parser.add_argument(
"--lr-factor",
type=float,
@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
"use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
# parameters for loss
"beam_size": 10,
"reduction": "sum",
@ -357,9 +366,17 @@ def compute_loss(
supervisions, subsampling_factor=params.subsampling_factor
)
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
# Works with a BPE model
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
# Works with a phone lexicon
decoding_graph = graph_compiler.compile(texts)
else:
raise ValueError(
f"Unsupported type of graph compiler: {type(graph_compiler)}"
)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
@ -584,12 +601,38 @@ def run(rank, world_size, args):
if torch.cuda.is_available():
device = torch.device("cuda", rank)
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
if "lang_bpe" in params.lang_dir:
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
elif "lang_phone" in params.lang_dir:
assert params.att_rate == 0, (
"Attention decoder training does not support phone lang dirs "
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
"for pure CTC training when using a phone-based lang dir."
)
assert params.num_decoder_layers == 0, (
"Attention decoder training does not support phone lang dirs "
"at this time due to a missing <sos/eos> symbol. "
"Set --num-decoder-layers=0 for pure CTC training when using "
"a phone-based lang dir."
)
graph_compiler = CtcTrainingGraphCompiler(
lexicon,
device=device,
)
# Manually add the sos/eos ID with their default values
# from the BPE recipe which we're adapting here.
graph_compiler.sos_id = 1
graph_compiler.eos_id = 1
else:
raise ValueError(
f"Unsupported type of lang dir (we expected it to have "
f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
)
logging.info("About to create model")
model = Conformer(
@ -607,7 +650,9 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
# Note: find_unused_parameters=True is needed in case we
# want to set params.att_rate = 0 (i.e. att decoder is not trained)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = Noam(
model.parameters(),

View File

@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
1, 1
)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0

View File

@ -99,6 +99,7 @@ class Transducer(nn.Module):
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out, _ = self.decoder(sos_y_padded)

View File

@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
1, 1
)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0

View File

@ -101,6 +101,7 @@ class Transducer(nn.Module):
sos_y = add_sos(y, sos_id=sos_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out, _ = self.decoder(sos_y_padded)

View File

@ -47,7 +47,7 @@ def greedy_search(
device = model.device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device
[blank_id] * context_size, device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)

View File

@ -48,6 +48,7 @@ from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
return params
def get_encoder_model(params: AttributeDict):
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
return encoder
def get_decoder_model(params: AttributeDict):
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
return decoder
def get_joiner_model(params: AttributeDict):
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
return joiner
def get_transducer_model(params: AttributeDict):
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)

View File

@ -98,6 +98,7 @@ class Transducer(nn.Module):
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out = self.decoder(sos_y_padded)

View File

@ -59,6 +59,7 @@ from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
@ -159,7 +160,7 @@ def get_params() -> AttributeDict:
return params
def get_encoder_model(params: AttributeDict):
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
@ -173,7 +174,7 @@ def get_encoder_model(params: AttributeDict):
return encoder
def get_decoder_model(params: AttributeDict):
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
@ -183,7 +184,7 @@ def get_decoder_model(params: AttributeDict):
return decoder
def get_joiner_model(params: AttributeDict):
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
@ -191,7 +192,7 @@ def get_joiner_model(params: AttributeDict):
return joiner
def get_transducer_model(params: AttributeDict):
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)

View File

@ -224,7 +224,7 @@ def get_params() -> AttributeDict:
return params
def get_encoder_model(params: AttributeDict):
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
@ -239,7 +239,7 @@ def get_encoder_model(params: AttributeDict):
return encoder
def get_decoder_model(params: AttributeDict):
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
@ -249,7 +249,7 @@ def get_decoder_model(params: AttributeDict):
return decoder
def get_joiner_model(params: AttributeDict):
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
@ -257,7 +257,7 @@ def get_joiner_model(params: AttributeDict):
return joiner
def get_transducer_model(params: AttributeDict):
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)

View File

@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
return decoding_graph
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of word IDs.
Args:
texts:
It is a list of strings. Each string consists of space(s)
separated words. An example containing two strings is given below:
['HELLO ICEFALL', 'HELLO k2']
Returns:
Return a list-of-list of word IDs.
"""
word_ids_list = []
for text in texts:
word_ids = []
for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
word_ids.append(self.oov_id)
word_ids_list.append(word_ids)
return word_ids_list
def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
"""Convert a list of transcript texts to an FsaVec.