Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-11 02:52:18 +00:00)
Commit d5c3ac833c: Merge 6e0133902f633dad5c9d28e040da660ed2c184e3 into abd9437e6d5419a497707748eb935e50976c3b7b
@@ -96,12 +96,15 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    DecodingResults,
+    parse_hyp_and_timestamp,
     setup_logger,
-    store_transcripts,
+    store_transcripts_and_timestamps,
     str2bool,
-    write_error_stats,
+    write_error_stats_with_timestamps,
 )

 LOG_EPS = math.log(1e-10)
@@ -165,6 +168,13 @@ def get_parser():
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -237,7 +247,7 @@ def decode_one_batch(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:

@@ -284,10 +294,12 @@ def decode_one_batch(
     )

     encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    if isinstance(encoder_out, list):
+        encoder_out = encoder_out[-1]  # the last item is final output
     hyps = []

     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search_one_best(
+        res = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -295,63 +307,72 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
-        hyp_tokens = greedy_search_batch(
+        res = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
+        res = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
+        tokens = []
+        timestamps = []
         for i in range(batch_size):
             # fmt: off
             encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
             # fmt: on
             if params.decoding_method == "greedy_search":
-                hyp = greedy_search(
+                res = greedy_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    return_timestamps=True,
                 )
             elif params.decoding_method == "beam_search":
-                hyp = beam_search(
+                res = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    return_timestamps=True,
                 )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            tokens.extend(res.tokens)
+            timestamps.extend(res.timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)

+    hyps, timestamps = parse_hyp_and_timestamp(
+        res=res,
+        sp=sp,
+        subsampling_factor=params.subsampling_factor,
+        frame_shift_ms=params.frame_shift_ms,
+    )

     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": (hyps, timestamps)}
     elif params.decoding_method == "fast_beam_search":
         return {
             (
                 f"beam_{params.beam}_"
                 f"max_contexts_{params.max_contexts}_"
                 f"max_states_{params.max_states}"
-            ): hyps
+            ): (hyps, timestamps)
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": (hyps, timestamps)}


 def decode_dataset(
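Note: the timestamps carried by DecodingResults are indexes of encoder output frames; parse_hyp_and_timestamp converts them to seconds using params.subsampling_factor and params.frame_shift_ms (the latter is added to get_params() further down in this diff). A minimal sketch of that conversion, assuming the usual subsampling factor of 4 and a 10 ms frame shift (frames_to_seconds is a hypothetical helper, not the icefall implementation):

    # Sketch: convert encoder-frame indexes to seconds.
    def frames_to_seconds(frames, subsampling_factor=4, frame_shift_ms=10.0):
        return [f * subsampling_factor * frame_shift_ms / 1000.0 for f in frames]

    frames_to_seconds([0, 5, 12])  # -> [0.0, 0.2, 0.48]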
@@ -360,7 +381,7 @@ def decode_dataset(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
     """Decode dataset.

     Args:
@@ -378,9 +399,12 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
+      Its value is a list of tuples. Each tuple contains five elements:
+        - cut_id
+        - reference transcript
+        - predicted result
+        - timestamp of reference transcript
+        - timestamp of predicted result
     """
     num_cuts = 0

@@ -390,15 +414,27 @@ def decode_dataset(
         num_batches = "?"

     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 20

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

+        timestamps_ref = []
+        for cut in batch["supervisions"]["cut"]:
+            for s in cut.supervisions:
+                time = []
+                if s.alignment is not None and "word" in s.alignment:
+                    time = [
+                        aliword.start
+                        for aliword in s.alignment["word"]
+                        if aliword.symbol != ""
+                    ]
+                timestamps_ref.append(time)
+
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
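Note: the reference timestamps rely on Lhotse word-level alignments: supervision.alignment maps an alignment type such as "word" to a list of AlignmentItem entries carrying symbol, start and duration, so the comprehension above keeps the start time of every non-empty word. A toy illustration with made-up values (assuming the Lhotse AlignmentItem fields named here):

    from lhotse.supervision import AlignmentItem

    ali = [
        AlignmentItem(symbol="HELLO", start=0.32, duration=0.41),
        AlignmentItem(symbol="", start=0.73, duration=0.10),  # silence entry
        AlignmentItem(symbol="WORLD", start=0.83, duration=0.55),
    ]
    time = [w.start for w in ali if w.symbol != ""]  # -> [0.32, 0.83]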
@@ -407,12 +443,16 @@ def decode_dataset(
             batch=batch,
         )

-        for name, hyps in hyps_dict.items():
+        for name, (hyps, timestamps_hyp) in hyps_dict.items():
             this_batch = []
-            assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
+                timestamps_ref
+            )
+            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
+                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
+            ):
                 ref_words = ref_text.split()
-                this_batch.append((cut_id, ref_words, hyp_words))
+                this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp))

             results[name].extend(this_batch)

@@ -428,23 +468,28 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[
+        str,
+        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+    ],
 ):
     test_set_wers = dict()
+    test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)

         logging.info("Wrote detailed error stats to {}".format(errs_filename))

@@ -455,6 +500,19 @@ def save_results(
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)

+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
     s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:
@@ -462,6 +520,13 @@ def save_results(
             note = ""
     logging.info(s)

+    s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+

 @torch.no_grad()
 def main():
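Note: besides WER, write_error_stats_with_timestamps is expected to report the mean and variance of the symbol delay, i.e. how much later each hypothesis word is emitted relative to its reference word. Conceptually it reduces to something like the following sketch over paired timestamps (not the actual icefall implementation, which works on aligned word pairs):

    def symbol_delay(time_hyp, time_ref):
        delays = [h - r for h, r in zip(time_hyp, time_ref)]
        mean = sum(delays) / len(delays)
        var = sum((d - mean) ** 2 for d in delays) / len(delays)
        return mean, var

    symbol_delay([0.52, 1.10], [0.40, 1.00])  # -> approx (0.11, 0.0001)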
@@ -511,7 +576,7 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <unk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
@@ -580,9 +645,9 @@ def main():
             )
         )
     else:
-        assert params.avg > 0
+        assert params.avg > 0, params.avg
         start = params.epoch - params.avg
-        assert start >= 1
+        assert start >= 1, start
         filename_start = f"{params.exp_dir}/epoch-{start}.pt"
         filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
         logging.info(
@@ -606,6 +671,9 @@ def main():
     else:
         decoding_graph = None

+    lexicon = Lexicon(params.lang_dir)
+    word_table = lexicon.word_table
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -1133,7 +1133,10 @@ class EmformerEncoder(nn.Module):
       tanh_on_mem (bool, optional):
         If ``true``, applies tanh to memory elements. (default: ``false``)
       negative_inf (float, optional):
         Value to use for negative infinity in attention weights. (default: -1e8)
+      output_layers:
+        A list of integers containing the id of emformer layers whose activations
+        will be returned
     """

     def __init__(
@@ -1151,6 +1154,7 @@ class EmformerEncoder(nn.Module):
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        output_layers: List[int] = None,
     ):
         super().__init__()

@@ -1188,6 +1192,7 @@ class EmformerEncoder(nn.Module):
         self.chunk_length = chunk_length
         self.memory_size = memory_size
         self.cnn_module_kernel = cnn_module_kernel
+        self.output_layers = output_layers

     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
         """Hard copy each chunk's right context and concat them."""
@@ -1361,7 +1366,8 @@ class EmformerEncoder(nn.Module):
         padding_mask = make_pad_mask(attention_mask.shape[1] - U + output_lengths)

         output = utterance
-        for layer in self.emformer_layers:
+        layer_results = []
+        for layer_index, layer in enumerate(self.emformer_layers):
             output, right_context = layer(
                 output,
                 right_context,
@@ -1369,8 +1375,11 @@ class EmformerEncoder(nn.Module):
                 padding_mask=padding_mask,
                 warmup=warmup,
             )
+            if layer_index in self.output_layers:
+                # (T, N, C) --> (N, T, C)
+                layer_results.append(output.permute(1, 0, 2))

-        return output, output_lengths
+        return layer_results, output_lengths

     @torch.jit.export
     def infer(
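Note: EmformerEncoder.forward now returns a list of per-layer outputs instead of a single tensor, which is why the Emformer wrapper below no longer permutes, and why model.py takes layer_results[-1] as the final output and layer_results[0] as the distillation tap. A self-contained toy mirroring the new contract (TinyEncoder is hypothetical, not icefall code):

    import torch
    import torch.nn as nn

    class TinyEncoder(nn.Module):
        def __init__(self, output_layers=(0, 2)):
            super().__init__()
            self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
            self.output_layers = set(output_layers)

        def forward(self, x, x_lens):
            layer_results = []
            for layer_index, layer in enumerate(self.layers):
                x = layer(x)
                if layer_index in self.output_layers:
                    layer_results.append(x)
            return layer_results, x_lens

    enc = TinyEncoder()
    outs, lens = enc(torch.randn(2, 5, 8), torch.tensor([5, 5]))
    middle, final = outs[0], outs[-1]  # distillation tap and final output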
@@ -1540,6 +1549,7 @@ class Emformer(EncoderInterface):
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        middle_output_layer: int = None,  # 0-based layer index
     ):
         super().__init__()

@@ -1568,6 +1578,17 @@ class Emformer(EncoderInterface):
         # (2) embedding: num_features -> d_model
         self.encoder_embed = Conv2dSubsampling(num_features, d_model)

+        output_layers = []
+        if middle_output_layer is not None:
+            assert (
+                middle_output_layer >= 0
+                and middle_output_layer < num_encoder_layers
+            ), "Invalid middle output layer"
+            output_layers.append(middle_output_layer)
+
+        # The last layer is always needed.
+        output_layers.append(num_encoder_layers - 1)
+
         self.encoder = EmformerEncoder(
             chunk_length=chunk_length // subsampling_factor,
             d_model=d_model,
@@ -1582,6 +1603,7 @@ class Emformer(EncoderInterface):
             memory_size=memory_size,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
+            output_layers=output_layers,  # for distillation
         )

     def forward(
@@ -1619,9 +1641,7 @@ class Emformer(EncoderInterface):
         x_lens = (((x_lens - 1) >> 1) - 1) >> 1
         assert x.size(0) == x_lens.max().item()

-        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (T, N, C)
+        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (N, T, C)

-        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-
         return output, output_lengths

@@ -74,7 +74,8 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from emformer import Emformer
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
@@ -165,6 +166,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of entries in the memory for the Emformer",
     )

+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=8,
+        help="On which encoder layer to perform KD"
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks"
+    )
+
+    # distillation related args
+    parser.add_argument(
+        "--distil-delta",
+        type=int,
+        default=None,
+        help="Offset when doing KD"
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of codebook loss.",
+    )

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -408,6 +444,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -446,6 +483,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         left_context_length=params.left_context_length,
         right_context_length=params.right_context_length,
         memory_size=params.memory_size,
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder

@@ -483,6 +523,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        distil_delta=params.distil_delta if params.enable_distillation else 0,
     )
     return model

@@ -602,6 +644,19 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+def extract_codebook_indexes(batch):
+    cuts = batch["supervisions"]["cut"]
+    # -100 is identical to ignore_value in CE loss computation.
+    cuts_pre_mixed = [
+        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+    ]
+    for cut in cuts_pre_mixed:
+        cb = cut.codebook_indexes
+    print("All cuts have codebook indexes")
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens

 def compute_loss(
     params: AttributeDict,
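Note: extract_codebook_indexes assumes every (possibly mixed) cut carries a custom codebook_indexes field, which lhotse's collate_custom_field pads into a single batch tensor; pad_value=-100 matches the ignore index used when masking teacher targets in model.py. A shape-level sketch of the padding it produces (sizes made up):

    import torch

    # Two cuts with 48 and 40 frames of 16 codebook indexes each,
    # padded with -100 up to the longest cut in the batch.
    a = torch.randint(0, 256, (48, 16))
    b = torch.randint(0, 256, (40, 16))
    batch = torch.full((2, 48, 16), -100)
    batch[0, :48] = a
    batch[1, :40] = b  # frames 40..47 of the second cut stay -100
    lens = torch.tensor([48, 40])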
@@ -642,8 +697,14 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -651,6 +712,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indexes=codebook_indexes,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -661,6 +723,10 @@ def compute_loss(
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

+        if is_training and params.enable_distillation:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss
+
         assert loss.requires_grad == is_training

     info = MetricsTracker()
@@ -681,6 +747,8 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info

@@ -894,6 +962,11 @@ def run(rank, world_size, args):
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")

+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    if params.enable_distillation:
+        assert args.spec_aug_time_warp_factor < 1, "You need to disable time warp in MVQ KD"
+
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
@@ -959,10 +1032,10 @@ def run(rank, world_size, args):

     librispeech = LibriSpeechAsrDataModule(args)

+    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
-    else:
-        train_cuts = librispeech.train_clean_100_cuts()
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -992,14 +1065,14 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1,4 +1,5 @@
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+#           2022 Xiaomi Corp. (authors: Zengwei Yao, Liyong Guo, Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -40,6 +41,8 @@ class Transducer(nn.Module):
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
+        num_codebooks: int = 0,
+        distil_delta: int = None,
     ):
         """
         Args:
@@ -69,6 +72,16 @@ class Transducer(nn.Module):
         self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

+        from multi_quantization.prediction import JointCodebookLoss
+        self.distil_delta = distil_delta
+
+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=encoder_dim,
+                num_codebooks=num_codebooks,
+                is_joint=False,
+            )
+
     def forward(
         self,
         x: torch.Tensor,
@@ -80,6 +93,7 @@ class Transducer(nn.Module):
         warmup: float = 1.0,
         reduction: str = "sum",
         delay_penalty: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -112,6 +126,8 @@ class Transducer(nn.Module):
             streaming models to emit symbols earlier.
             See https://github.com/k2-fsa/k2/issues/955 and
             https://arxiv.org/pdf/2211.00490.pdf for more details.
+          codebook_indexes:
+            codebook_indexes extracted from a teacher model.
         Returns:
         Returns:
           Return the transducer loss.
@@ -129,7 +145,35 @@ class Transducer(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        encoder_out = layer_results[-1]  # the last item is the final output
+
+        middle_layer_output = layer_results[0]
+        if self.training and codebook_indexes is not None:
+            assert hasattr(self, "codebook_loss_net")
+            # due to different subsampling ratio between hubert teacher and emformer
+            if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
+                codebook_indexes = self.concat_successive_codebook_indexes(
+                    middle_layer_output, codebook_indexes
+                )
+            if self.distil_delta is not None:
+                N = codebook_indexes.shape[0]
+                T = codebook_indexes.shape[1]
+                cur_distil_delta = self.distil_delta
+                # align (teacher) with (student + self.distil_delta);
+                # suppose self.distil_delta == 2
+                unvalid_teacher_mask = codebook_indexes == -100
+                # 1,2,3,4,5,6,7,8,-100,-100 --> 1,2,1,2,3,4,5,6,7,8
+                codebook_indexes[:, cur_distil_delta:, :] = codebook_indexes.clone()[:, : T - cur_distil_delta, :]
+                unvalid_teacher_mask[:, :cur_distil_delta] = True
+                codebook_indexes.masked_fill_(unvalid_teacher_mask, -100)
+                # --> -100,-100,1,2,3,4,5,6,-100,-100
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        else:
+            # when codebook index is not available.
+            codebook_loss = None
         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
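Note: the shift-and-mask block above delays the teacher targets by distil_delta frames, so the student at frame t is trained against teacher content from frame t - distil_delta; positions without a shifted target are filled with -100 and ignored by the loss. A runnable toy check of the example in the comments (delta = 2, a single codebook, hypothetical values):

    import torch

    ci = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, -100, -100]).view(1, 10, 1)
    delta = 2
    mask = ci == -100                # marks the original padding
    ci[:, delta:, :] = ci.clone()[:, : 10 - delta, :]
    mask[:, :delta] = True           # first `delta` frames get no target
    ci.masked_fill_(mask, -100)
    print(ci.view(-1).tolist())      # [-100, -100, 1, 2, 3, 4, 5, 6, -100, -100]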
@@ -204,4 +248,32 @@ class Transducer(nn.Module):
             reduction=reduction,
         )

-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, codebook_loss)
+
+    @staticmethod
+    def concat_successive_codebook_indexes(
+        middle_layer_output, codebook_indexes
+    ):
+        # Output rate of hubert is 50 frames per second,
+        # while that of current encoder is 25.
+        # The following code handles two issues:
+        # 1.
+        # Roughly speaking, to generate another frame of output,
+        # hubert needs two extra frames,
+        # while the current encoder needs four extra frames.
+        # Suppose only three extra frames are provided;
+        # hubert will generate another frame while the current encoder does nothing.
+        # 2.
+        # Codebook loss is a frame-wise loss; to enable the 25-frame student output
+        # to learn from the 50-frame teacher output, two successive frames of the
+        # teacher model output are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handling issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        # Handling issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
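Note: since the hubert teacher emits 50 frames per second and the subsampled Emformer student 25, the reshape folds every two consecutive teacher frames into one student frame by doubling the codebook dimension. A shape-level sketch (sizes made up):

    import torch

    codebook_indexes = torch.randint(0, 256, (4, 100, 16))  # teacher: 100 frames
    t_expected = 50                                         # student frames
    N, T, C = codebook_indexes.shape
    codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
    codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
    print(codebook_indexes.shape)  # torch.Size([4, 50, 32])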