Merge 6e0133902f633dad5c9d28e040da660ed2c184e3 into abd9437e6d5419a497707748eb935e50976c3b7b

Authored by Xiaoyu Yang on 2025-06-27 11:32:43 +00:00; committed by GitHub.
commit d5c3ac833c
4 changed files with 290 additions and 57 deletions

View File

@@ -96,12 +96,15 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    DecodingResults,
+    parse_hyp_and_timestamp,
     setup_logger,
-    store_transcripts,
+    store_transcripts_and_timestamps,
     str2bool,
-    write_error_stats,
+    write_error_stats_with_timestamps,
 )

 LOG_EPS = math.log(1e-10)
@@ -165,6 +168,13 @@ def get_parser():
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -237,7 +247,7 @@ def decode_one_batch(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -284,10 +294,12 @@ def decode_one_batch(
     )
     encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    if isinstance(encoder_out, list):
+        encoder_out = encoder_out[-1]  # the last item is the final output

     hyps = []

     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search_one_best(
+        res = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -295,63 +307,72 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
-        hyp_tokens = greedy_search_batch(
+        res = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
+        res = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
+        tokens = []
+        timestamps = []
         for i in range(batch_size):
             # fmt: off
             encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
             # fmt: on
             if params.decoding_method == "greedy_search":
-                hyp = greedy_search(
+                res = greedy_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    return_timestamps=True,
                 )
             elif params.decoding_method == "beam_search":
-                hyp = beam_search(
+                res = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    return_timestamps=True,
                 )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            tokens.extend(res.tokens)
+            timestamps.extend(res.timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)
+
+    hyps, timestamps = parse_hyp_and_timestamp(
+        res=res,
+        sp=sp,
+        subsampling_factor=params.subsampling_factor,
+        frame_shift_ms=params.frame_shift_ms,
+    )

     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": (hyps, timestamps)}
     elif params.decoding_method == "fast_beam_search":
         return {
             (
                 f"beam_{params.beam}_"
                 f"max_contexts_{params.max_contexts}_"
                 f"max_states_{params.max_states}"
-            ): hyps
+            ): (hyps, timestamps)
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}


 def decode_dataset(
@@ -360,7 +381,7 @@ def decode_dataset(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
     """Decode dataset.

     Args:
@@ -378,9 +399,12 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      Its value is a list of tuples. Each tuple contains five elements:
+        - cut_id
+        - reference transcript
+        - predicted result
+        - timestamp of reference transcript
+        - timestamp of predicted result
     """
     num_cuts = 0
@@ -390,15 +414,27 @@ def decode_dataset(
         num_batches = "?"

     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 20

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        timestamps_ref = []
+        for cut in batch["supervisions"]["cut"]:
+            for s in cut.supervisions:
+                time = []
+                if s.alignment is not None and "word" in s.alignment:
+                    time = [
+                        aliword.start
+                        for aliword in s.alignment["word"]
+                        if aliword.symbol != ""
+                    ]
+                timestamps_ref.append(time)

         hyps_dict = decode_one_batch(
             params=params,
             model=model,
@@ -407,12 +443,16 @@ def decode_dataset(
             batch=batch,
         )

-        for name, hyps in hyps_dict.items():
+        for name, (hyps, timestamps_hyp) in hyps_dict.items():
             this_batch = []
-            assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
+                timestamps_ref
+            )
+            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
+                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
+            ):
                 ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp))

             results[name].extend(this_batch)
@@ -428,23 +468,28 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[
+        str,
+        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+    ],
 ):
     test_set_wers = dict()
+    test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)

         logging.info("Wrote detailed error stats to {}".format(errs_filename))
@@ -455,6 +500,19 @@ def save_results(
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)

+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
     s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:
@@ -462,6 +520,13 @@ def save_results(
         note = ""
     logging.info(s)

+    s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+

 @torch.no_grad()
 def main():
@@ -511,7 +576,7 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <unk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
@@ -580,9 +645,9 @@ def main():
             )
         )
     else:
-        assert params.avg > 0
+        assert params.avg > 0, params.avg
         start = params.epoch - params.avg
-        assert start >= 1
+        assert start >= 1, start
         filename_start = f"{params.exp_dir}/epoch-{start}.pt"
         filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
         logging.info(
@@ -606,6 +671,9 @@ def main():
     else:
         decoding_graph = None

+    lexicon = Lexicon(params.lang_dir)
+    word_table = lexicon.word_table
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

View File

@@ -1133,7 +1133,10 @@ class EmformerEncoder(nn.Module):
         tanh_on_mem (bool, optional):
             If ``true``, applies tanh to memory elements. (default: ``false``)
         negative_inf (float, optional):
             Value to use for negative infinity in attention weights. (default: -1e8)
+        output_layers:
+            A list of integers containing the ids of the Emformer layers whose
+            activations will be returned.
     """

     def __init__(
@@ -1151,6 +1154,7 @@ class EmformerEncoder(nn.Module):
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        output_layers: List[int] = None,
     ):
         super().__init__()
@@ -1188,6 +1192,7 @@ class EmformerEncoder(nn.Module):
         self.chunk_length = chunk_length
         self.memory_size = memory_size
         self.cnn_module_kernel = cnn_module_kernel
+        self.output_layers = output_layers

     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
         """Hard copy each chunk's right context and concat them."""
@@ -1361,7 +1366,8 @@ class EmformerEncoder(nn.Module):
         padding_mask = make_pad_mask(attention_mask.shape[1] - U + output_lengths)

         output = utterance
-        for layer in self.emformer_layers:
+        layer_results = []
+        for layer_index, layer in enumerate(self.emformer_layers):
             output, right_context = layer(
                 output,
                 right_context,
@@ -1369,8 +1375,11 @@ class EmformerEncoder(nn.Module):
                 padding_mask=padding_mask,
                 warmup=warmup,
             )
+            if layer_index in self.output_layers:
+                # (T, N, C) --> (N, T, C)
+                layer_results.append(output.permute(1, 0, 2))

-        return output, output_lengths
+        return layer_results, output_lengths

     @torch.jit.export
     def infer(
@@ -1540,6 +1549,7 @@ class Emformer(EncoderInterface):
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        middle_output_layer: int = None,  # 0-based layer index
     ):
         super().__init__()
@@ -1568,6 +1578,17 @@ class Emformer(EncoderInterface):
         # (2) embedding: num_features -> d_model
         self.encoder_embed = Conv2dSubsampling(num_features, d_model)

+        output_layers = []
+        if middle_output_layer is not None:
+            assert (
+                middle_output_layer >= 0
+                and middle_output_layer < num_encoder_layers
+            ), f"Invalid middle output layer: {middle_output_layer}"
+            output_layers.append(middle_output_layer)
+
+        # The last layer is always needed.
+        output_layers.append(num_encoder_layers - 1)
+
         self.encoder = EmformerEncoder(
             chunk_length=chunk_length // subsampling_factor,
             d_model=d_model,
@@ -1582,6 +1603,7 @@ class Emformer(EncoderInterface):
             memory_size=memory_size,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
+            output_layers=output_layers,  # for distillation
         )

     def forward(
@@ -1619,9 +1641,7 @@ class Emformer(EncoderInterface):
         x_lens = (((x_lens - 1) >> 1) - 1) >> 1
         assert x.size(0) == x_lens.max().item()

-        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (T, N, C)
-        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (N, T, C)

         return output, output_lengths
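
With output_layers set, EmformerEncoder.forward now returns a list of per-layer activations instead of a single tensor; callers take the last entry as the final encoder output and, when distillation is on, the first entry as the distillation target. A toy sketch of that bookkeeping, with dummy tensors standing in for the real Emformer layers (shapes and layer count are made up):

    import torch

    num_layers = 12
    middle_output_layer = 8                    # the --distillation-layer default in train.py
    output_layers = [middle_output_layer, num_layers - 1]

    x = torch.randn(20, 2, 16)                 # (T, N, C) dummy activations
    layer_results = []
    for layer_index in range(num_layers):
        x = x + 0.0                            # stand-in for one Emformer layer
        if layer_index in output_layers:
            layer_results.append(x.permute(1, 0, 2))  # (T, N, C) -> (N, T, C)

    middle_layer_output = layer_results[0]     # distillation target
    encoder_out = layer_results[-1]            # final layer, always appended last
    assert encoder_out.shape == (2, 20, 16)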

View File

@@ -74,7 +74,8 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from emformer import Emformer
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
@@ -165,6 +166,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of entries in the memory for the Emformer",
     )

+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=8,
+        help="On which encoder layer to perform KD",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks",
+    )
+
+    # distillation related args
+    parser.add_argument(
+        "--distil-delta",
+        type=int,
+        default=None,
+        help="Offset when doing KD",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of codebook loss.",
+    )
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -408,6 +444,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -446,6 +483,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         left_context_length=params.left_context_length,
         right_context_length=params.right_context_length,
         memory_size=params.memory_size,
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder
@@ -483,6 +523,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        distil_delta=params.distil_delta if params.enable_distillation else 0,
     )
     return model
@@ -602,6 +644,19 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+
+def extract_codebook_indexes(batch):
+    cuts = batch["supervisions"]["cut"]
+    # -100 is identical to ignore_value in CE loss computation.
+    cuts_pre_mixed = [
+        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+    ]
+    # Every cut is expected to carry pre-computed codebook indexes.
+    for cut in cuts_pre_mixed:
+        assert hasattr(cut, "codebook_indexes"), "All cuts should have codebook indexes"
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens
+

 def compute_loss(
     params: AttributeDict,
@@ -642,8 +697,14 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -651,6 +712,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indexes=codebook_indexes,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -661,6 +723,10 @@ def compute_loss(
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+        if is_training and params.enable_distillation:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
@@ -681,6 +747,8 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info
@@ -894,6 +962,11 @@ def run(rank, world_size, args):
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")

+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    if params.enable_distillation:
+        assert args.spec_aug_time_warp_factor < 1, "You need to disable time warp in MVQ KD"
+
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
@@ -959,10 +1032,10 @@ def run(rank, world_size, args):

     librispeech = LibriSpeechAsrDataModule(args)

+    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
-    else:
-        train_cuts = librispeech.train_clean_100_cuts()
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -992,14 +1065,14 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
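
After these training-script changes the objective gains a third term: the codebook (distillation) loss, scaled by --codebook-loss-scale and added only when distillation is enabled. A small worked example of the combination (the loss values and the simple/pruned scales are made up; only the 0.1 codebook scale is the default from this diff):

    import torch

    simple_loss = torch.tensor(12.0)    # made-up value
    pruned_loss = torch.tensor(30.0)    # made-up value
    codebook_loss = torch.tensor(50.0)  # made-up value

    simple_loss_scale = 0.5             # assumed params.simple_loss_scale
    pruned_loss_scale = 1.0             # after warmup; kept smaller early in training
    codebook_loss_scale = 0.1           # --codebook-loss-scale default

    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
    loss = loss + codebook_loss_scale * codebook_loss  # only when distillation is enabled
    print(loss)  # tensor(41.)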

View File

@@ -1,4 +1,5 @@
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+#           2022 Xiaomi Corp. (authors: Zengwei Yao, Liyong Guo, Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -40,6 +41,8 @@ class Transducer(nn.Module):
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
+        num_codebooks: int = 0,
+        distil_delta: int = None,
     ):
         """
         Args:
@@ -69,6 +72,16 @@ class Transducer(nn.Module):
         self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

+        from multi_quantization.prediction import JointCodebookLoss
+
+        self.distil_delta = distil_delta
+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=encoder_dim,
+                num_codebooks=num_codebooks,
+                is_joint=False,
+            )
+
     def forward(
         self,
         x: torch.Tensor,
@@ -80,6 +93,7 @@ class Transducer(nn.Module):
         warmup: float = 1.0,
         reduction: str = "sum",
         delay_penalty: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -112,6 +126,8 @@ class Transducer(nn.Module):
             streaming models to emit symbols earlier.
             See https://github.com/k2-fsa/k2/issues/955 and
             https://arxiv.org/pdf/2211.00490.pdf for more details.
+          codebook_indexes:
+            Codebook indexes extracted from a teacher model.
         Returns:
           Return the transducer loss.
@@ -129,7 +145,35 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        encoder_out = layer_results[-1]  # the last item is the final output
+        middle_layer_output = layer_results[0]
+
+        if self.training and codebook_indexes is not None:
+            assert hasattr(self, "codebook_loss_net")
+            # The hubert teacher and the emformer student use different
+            # subsampling ratios, so their frame rates may differ.
+            if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
+                codebook_indexes = self.concat_successive_codebook_indexes(
+                    middle_layer_output, codebook_indexes
+                )
+            if self.distil_delta is not None:
+                T = codebook_indexes.shape[1]
+                cur_distil_delta = self.distil_delta
+                # Align the teacher with (student + self.distil_delta).
+                # Suppose self.distil_delta == 2:
+                unvalid_teacher_mask = codebook_indexes == -100
+                # 1,2,3,4,5,6,7,8,-100,-100 --> 1,2,1,2,3,4,5,6,7,8
+                codebook_indexes[:, cur_distil_delta:, :] = codebook_indexes.clone()[
+                    :, : T - cur_distil_delta, :
+                ]
+                unvalid_teacher_mask[:, :cur_distil_delta] = True
+                codebook_indexes.masked_fill_(unvalid_teacher_mask, -100)
+                # --> -100,-100,1,2,3,4,5,6,-100,-100
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        else:
+            # When codebook indexes are not available.
+            codebook_loss = None

         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
@@ -204,4 +248,32 @@ class Transducer(nn.Module):
             reduction=reduction,
         )

-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, codebook_loss)
+
+    @staticmethod
+    def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes):
+        # The output rate of the hubert teacher is 50 frames per second,
+        # while that of the current encoder is 25.
+        # The following code handles two issues:
+        # 1.
+        #   Roughly speaking, to generate one extra output frame, hubert needs
+        #   two extra input frames, while the current encoder needs four.
+        #   If only three extra frames are provided, hubert will generate one
+        #   more frame while the current encoder does nothing.
+        # 2.
+        #   The codebook loss is a frame-wise loss. To let the 25 Hz student
+        #   output learn from the 50 Hz teacher output, every two successive
+        #   frames of the teacher output are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handle issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        # Handle issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
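
For intuition, a self-contained walk-through of the reshape performed by concat_successive_codebook_indexes, with made-up shapes: a roughly 50 frames/s teacher, a roughly 25 frames/s student, and the 16-codebook default, so every pair of successive teacher frames is merged into one 32-wide target per student frame:

    import torch

    N, T_teacher, C = 2, 206, 16        # batch, teacher frames, codebooks (made-up sizes)
    t_expected = 100                    # student (middle layer) frames for this batch
    codebook_indexes = torch.randint(0, 256, (N, T_teacher, C))

    # Issue 1: drop teacher frames beyond 2 * t_expected.
    codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
    # Issue 2: merge each pair of successive teacher frames into one student frame.
    codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
    assert codebook_indexes.shape == (2, 100, 32)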