Merge remote-tracking branch 'dan/master' into rnnt-modified

This commit is contained in:
Fangjun Kuang 2022-02-07 16:43:14 +08:00
commit 4159d6fbf6
15 changed files with 129 additions and 50 deletions

View File

@ -82,17 +82,17 @@ class Decoder(nn.Module):
Returns: Returns:
Return a tensor of shape (N, U, embedding_dim). Return a tensor of shape (N, U, embedding_dim).
""" """
embeding_out = self.embedding(y) embedding_out = self.embedding(y)
if self.context_size > 1: if self.context_size > 1:
embeding_out = embeding_out.permute(0, 2, 1) embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True: if need_pad is True:
embeding_out = F.pad( embedding_out = F.pad(
embeding_out, pad=(self.context_size - 1, 0) embedding_out, pad=(self.context_size - 1, 0)
) )
else: else:
# During inference time, there is no need to do extra padding # During inference time, there is no need to do extra padding
# as we only need one output # as we only need one output
assert embeding_out.size(-1) == self.context_size assert embedding_out.size(-1) == self.context_size
embeding_out = self.conv(embeding_out) embedding_out = self.conv(embedding_out)
embeding_out = embeding_out.permute(0, 2, 1) embedding_out = embedding_out.permute(0, 2, 1)
return embeding_out return embedding_out

View File

@ -48,6 +48,7 @@ from pathlib import Path
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
return params return params
def get_encoder_model(params: AttributeDict): def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer( encoder = Conformer(
num_features=params.feature_dim, num_features=params.feature_dim,
output_dim=params.encoder_out_dim, output_dim=params.encoder_out_dim,
@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
return encoder return encoder
def get_decoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder( decoder = Decoder(
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim, embedding_dim=params.encoder_out_dim,
@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
return decoder return decoder
def get_joiner_model(params: AttributeDict): def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner( joiner = Joiner(
input_dim=params.encoder_out_dim, input_dim=params.encoder_out_dim,
output_dim=params.vocab_size, output_dim=params.vocab_size,
@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
return joiner return joiner
def get_transducer_model(params: AttributeDict): def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)

View File

@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
import argparse import argparse
import logging import logging
import math import math
from typing import List
from pathlib import Path from pathlib import Path
from typing import List
import kaldifeat import kaldifeat
import torch import torch
import torch.nn as nn
import torchaudio import torchaudio
from beam_search import beam_search, greedy_search from beam_search import beam_search, greedy_search
from conformer import Conformer from conformer import Conformer
@ -57,10 +58,10 @@ from joiner import Joiner
from model import Transducer from model import Transducer
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
from icefall.lexicon import Lexicon
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict
def get_parser(): def get_parser():
@ -150,7 +151,7 @@ def get_params() -> AttributeDict:
return params return params
def get_encoder_model(params: AttributeDict): def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer( encoder = Conformer(
num_features=params.feature_dim, num_features=params.feature_dim,
output_dim=params.encoder_out_dim, output_dim=params.encoder_out_dim,
@ -164,7 +165,7 @@ def get_encoder_model(params: AttributeDict):
return encoder return encoder
def get_decoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder( decoder = Decoder(
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim, embedding_dim=params.encoder_out_dim,
@ -174,7 +175,7 @@ def get_decoder_model(params: AttributeDict):
return decoder return decoder
def get_joiner_model(params: AttributeDict): def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner( joiner = Joiner(
input_dim=params.encoder_out_dim, input_dim=params.encoder_out_dim,
output_dim=params.vocab_size, output_dim=params.vocab_size,
@ -182,7 +183,7 @@ def get_joiner_model(params: AttributeDict):
return joiner return joiner
def get_transducer_model(params: AttributeDict): def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)

View File

@ -204,7 +204,7 @@ def get_params() -> AttributeDict:
return params return params
def get_encoder_model(params: AttributeDict): def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer # TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer( encoder = Conformer(
num_features=params.feature_dim, num_features=params.feature_dim,
@ -219,7 +219,7 @@ def get_encoder_model(params: AttributeDict):
return encoder return encoder
def get_decoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder( decoder = Decoder(
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim, embedding_dim=params.encoder_out_dim,
@ -229,7 +229,7 @@ def get_decoder_model(params: AttributeDict):
return decoder return decoder
def get_joiner_model(params: AttributeDict): def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner( joiner = Joiner(
input_dim=params.encoder_out_dim, input_dim=params.encoder_out_dim,
output_dim=params.vocab_size, output_dim=params.vocab_size,
@ -237,7 +237,7 @@ def get_joiner_model(params: AttributeDict):
return joiner return joiner
def get_transducer_model(params: AttributeDict): def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)

View File

@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -123,6 +124,15 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--num-decoder-layers",
type=int,
default=6,
help="""Number of decoder layer of transformer decoder.
Setting this to 0 will not create the decoder at all (pure CTC model)
""",
)
parser.add_argument( parser.add_argument(
"--lr-factor", "--lr-factor",
type=float, type=float,
@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
"use_feat_batchnorm": True, "use_feat_batchnorm": True,
"attention_dim": 512, "attention_dim": 512,
"nhead": 8, "nhead": 8,
"num_decoder_layers": 6,
# parameters for loss # parameters for loss
"beam_size": 10, "beam_size": 10,
"reduction": "sum", "reduction": "sum",
@ -357,9 +366,17 @@ def compute_loss(
supervisions, subsampling_factor=params.subsampling_factor supervisions, subsampling_factor=params.subsampling_factor
) )
token_ids = graph_compiler.texts_to_ids(texts) if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
# Works with a BPE model
decoding_graph = graph_compiler.compile(token_ids) token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
# Works with a phone lexicon
decoding_graph = graph_compiler.compile(texts)
else:
raise ValueError(
f"Unsupported type of graph compiler: {type(graph_compiler)}"
)
dense_fsa_vec = k2.DenseFsaVec( dense_fsa_vec = k2.DenseFsaVec(
nnet_output, nnet_output,
@ -584,12 +601,38 @@ def run(rank, world_size, args):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", rank) device = torch.device("cuda", rank)
graph_compiler = BpeCtcTrainingGraphCompiler( if "lang_bpe" in params.lang_dir:
params.lang_dir, graph_compiler = BpeCtcTrainingGraphCompiler(
device=device, params.lang_dir,
sos_token="<sos/eos>", device=device,
eos_token="<sos/eos>", sos_token="<sos/eos>",
) eos_token="<sos/eos>",
)
elif "lang_phone" in params.lang_dir:
assert params.att_rate == 0, (
"Attention decoder training does not support phone lang dirs "
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
"for pure CTC training when using a phone-based lang dir."
)
assert params.num_decoder_layers == 0, (
"Attention decoder training does not support phone lang dirs "
"at this time due to a missing <sos/eos> symbol. "
"Set --num-decoder-layers=0 for pure CTC training when using "
"a phone-based lang dir."
)
graph_compiler = CtcTrainingGraphCompiler(
lexicon,
device=device,
)
# Manually add the sos/eos ID with their default values
# from the BPE recipe which we're adapting here.
graph_compiler.sos_id = 1
graph_compiler.eos_id = 1
else:
raise ValueError(
f"Unsupported type of lang dir (we expected it to have "
f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
)
logging.info("About to create model") logging.info("About to create model")
model = Conformer( model = Conformer(
@ -607,7 +650,9 @@ def run(rank, world_size, args):
model.to(device) model.to(device)
if world_size > 1: if world_size > 1:
model = DDP(model, device_ids=[rank]) # Note: find_unused_parameters=True is needed in case we
# want to set params.att_rate = 0 (i.e. att decoder is not trained)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = Noam( optimizer = Noam(
model.parameters(), model.parameters(),

View File

@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
device = model.device device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1) sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
1, 1
)
decoder_out, (h, c) = model.decoder(sos) decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1) T = encoder_out.size(1)
t = 0 t = 0

View File

@ -99,6 +99,7 @@ class Transducer(nn.Module):
sos_y = add_sos(y, sos_id=blank_id) sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out, _ = self.decoder(sos_y_padded) decoder_out, _ = self.decoder(sos_y_padded)

View File

@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
device = model.device device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1) sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
1, 1
)
decoder_out, (h, c) = model.decoder(sos) decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1) T = encoder_out.size(1)
t = 0 t = 0

View File

@ -101,6 +101,7 @@ class Transducer(nn.Module):
sos_y = add_sos(y, sos_id=sos_id) sos_y = add_sos(y, sos_id=sos_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out, _ = self.decoder(sos_y_padded) decoder_out, _ = self.decoder(sos_y_padded)

View File

@ -47,7 +47,7 @@ def greedy_search(
device = model.device device = model.device
decoder_input = torch.tensor( decoder_input = torch.tensor(
[blank_id] * context_size, device=device [blank_id] * context_size, device=device, dtype=torch.int64
).reshape(1, context_size) ).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)

View File

@ -48,6 +48,7 @@ from pathlib import Path
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
return params return params
def get_encoder_model(params: AttributeDict): def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer( encoder = Conformer(
num_features=params.feature_dim, num_features=params.feature_dim,
output_dim=params.encoder_out_dim, output_dim=params.encoder_out_dim,
@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
return encoder return encoder
def get_decoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder( decoder = Decoder(
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim, embedding_dim=params.encoder_out_dim,
@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
return decoder return decoder
def get_joiner_model(params: AttributeDict): def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner( joiner = Joiner(
input_dim=params.encoder_out_dim, input_dim=params.encoder_out_dim,
output_dim=params.vocab_size, output_dim=params.vocab_size,
@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
return joiner return joiner
def get_transducer_model(params: AttributeDict): def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)

View File

@ -98,6 +98,7 @@ class Transducer(nn.Module):
sos_y = add_sos(y, sos_id=blank_id) sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out = self.decoder(sos_y_padded) decoder_out = self.decoder(sos_y_padded)

View File

@ -59,6 +59,7 @@ from typing import List
import kaldifeat import kaldifeat
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn
import torchaudio import torchaudio
from beam_search import beam_search, greedy_search, modified_beam_search from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer from conformer import Conformer
@ -159,7 +160,7 @@ def get_params() -> AttributeDict:
return params return params
def get_encoder_model(params: AttributeDict): def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer( encoder = Conformer(
num_features=params.feature_dim, num_features=params.feature_dim,
output_dim=params.encoder_out_dim, output_dim=params.encoder_out_dim,
@ -173,7 +174,7 @@ def get_encoder_model(params: AttributeDict):
return encoder return encoder
def get_decoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder( decoder = Decoder(
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim, embedding_dim=params.encoder_out_dim,
@ -183,7 +184,7 @@ def get_decoder_model(params: AttributeDict):
return decoder return decoder
def get_joiner_model(params: AttributeDict): def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner( joiner = Joiner(
input_dim=params.encoder_out_dim, input_dim=params.encoder_out_dim,
output_dim=params.vocab_size, output_dim=params.vocab_size,
@ -191,7 +192,7 @@ def get_joiner_model(params: AttributeDict):
return joiner return joiner
def get_transducer_model(params: AttributeDict): def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)

View File

@ -224,7 +224,7 @@ def get_params() -> AttributeDict:
return params return params
def get_encoder_model(params: AttributeDict): def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer # TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer( encoder = Conformer(
num_features=params.feature_dim, num_features=params.feature_dim,
@ -239,7 +239,7 @@ def get_encoder_model(params: AttributeDict):
return encoder return encoder
def get_decoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder( decoder = Decoder(
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim, embedding_dim=params.encoder_out_dim,
@ -249,7 +249,7 @@ def get_decoder_model(params: AttributeDict):
return decoder return decoder
def get_joiner_model(params: AttributeDict): def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner( joiner = Joiner(
input_dim=params.encoder_out_dim, input_dim=params.encoder_out_dim,
output_dim=params.vocab_size, output_dim=params.vocab_size,
@ -257,7 +257,7 @@ def get_joiner_model(params: AttributeDict):
return joiner return joiner
def get_transducer_model(params: AttributeDict): def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)

View File

@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
return decoding_graph return decoding_graph
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
    """Convert a list of texts to a list-of-list of word IDs.

    Args:
      texts:
        It is a list of strings. Each string consists of space(s)
        separated words. An example containing two strings is given below:

            ['HELLO ICEFALL', 'HELLO k2']

    Returns:
      Return a list-of-list of word IDs, one inner list per input string
      and one ID per whitespace-separated word. A word found in
      ``self.word_table`` is mapped to its table entry; any other word is
      mapped to ``self.oov_id``.
    """
    # `str.split()` with no argument collapses runs of whitespace, so an
    # empty or all-blank string yields an empty inner list.
    # Membership test + indexing is kept (rather than `.get`) so that only
    # `in` and `[]` are required of the word table.
    return [
        [
            self.word_table[word] if word in self.word_table else self.oov_id
            for word in text.split()
        ]
        for text in texts
    ]
def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa: def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
"""Convert a list of transcript texts to an FsaVec. """Convert a list of transcript texts to an FsaVec.