diff --git a/egs/tedlium3/ASR/README.md b/egs/tedlium3/ASR/README.md
new file mode 100644
index 000000000..57bd9458b
--- /dev/null
+++ b/egs/tedlium3/ASR/README.md
@@ -0,0 +1,18 @@
+
+# Introduction
+
+This recipe contains various ASR models trained with the TedLium3 dataset.
+
+# Transducers
+
+There are various folders in this directory whose names contain `transducer`.
+The following table lists the differences among them.
+
+|                        | Encoder   | Decoder            |
+|------------------------|-----------|--------------------|
+| `transducer_stateless` | Conformer | Embedding + Conv1d |
+
+
+The decoder in `transducer_stateless` is modified from the paper
+[RNN-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md
new file mode 100644
index 000000000..0a98e54f1
--- /dev/null
+++ b/egs/tedlium3/ASR/RESULTS.md
@@ -0,0 +1,68 @@
+## Results
+
+### TedLium3 BPE training results (Transducer)
+
+#### Conformer encoder + embedding decoder
+
+Using the code from this commit .
+
+Conformer encoder + non-recurrent decoder. The decoder
+contains only an embedding layer and a Conv1d (with kernel size 2).
+
+The WERs are
+
+|                                    | dev        | test       | comment                                   |
+|------------------------------------|------------|------------|-------------------------------------------|
+| greedy search                      | 7.31       | 6.73       | --epoch 71, --avg 15, --max-duration 100  |
+| beam search (beam size 4)          | 7.12       | 6.58       | --epoch 71, --avg 15, --max-duration 100  |
+| modified beam search (beam size 4) | 7.20       | 6.65       | --epoch 71, --avg 15, --max-duration 100  |
+
+The training command to reproduce the above results is:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./transducer_stateless/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --exp-dir transducer_stateless/exp \
+  --max-duration 180
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/DnRwoZF8RRyod4kkfG5q5Q/#scalars
+
+The decoding command is:
+```
+epoch=29
+avg=15
+
+## greedy search
+./transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir transducer_stateless/exp \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --max-duration 100
+
+## beam search
+./transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir transducer_stateless/exp \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --max-duration 100 \
+  --decoding-method beam_search \
+  --beam-size 4
+
+## modified beam search
+./transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir transducer_stateless/exp \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --max-duration 100 \
+  --decoding-method modified_beam_search \
+  --beam-size 4
+```
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 13660a074..fb5c611f6 100644
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -17,7 +17,7 @@
 """
-This file computes fbank features of the LibriSpeech dataset.
+This file computes fbank features of the TedLium3 dataset.
 It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/fbank.
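Note on the renamed `compute_fbank_tedlium.py`: it follows the usual lhotse feature-extraction pattern (load recording/supervision manifests, build a `CutSet`, compute and store fbank features). The sketch below is only an illustration of that pattern, not the recipe's actual code; the manifest filenames, the 80-bin configuration, and the number of jobs are assumptions.

```python
# Minimal sketch (assumed names): compute and store fbank features for one split.
from pathlib import Path

from lhotse import CutSet, Fbank, FbankConfig, load_manifest


def compute_fbank_for_split(split: str) -> None:
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    extractor = Fbank(FbankConfig(num_mel_bins=80))  # assumed configuration

    # Manifest filenames below are hypothetical; prepare.sh decides the real names.
    recordings = load_manifest(src_dir / f"recordings_{split}.json.gz")
    supervisions = load_manifest(src_dir / f"supervisions_{split}.json.gz")

    cut_set = CutSet.from_manifests(
        recordings=recordings, supervisions=supervisions
    )
    # Extract fbank features in parallel and attach them to the cuts.
    cut_set = cut_set.compute_and_store_features(
        extractor=extractor,
        storage_path=f"{output_dir}/feats_{split}",
        num_jobs=4,
    )
    cut_set.to_json(output_dir / f"cuts_{split}.json.gz")


if __name__ == "__main__":
    for split in ["train", "dev", "test"]:
        compute_fbank_for_split(split)
```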
@@ -43,7 +43,7 @@
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
-def compute_fbank_librispeech():
+def compute_fbank_tedlium():
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -96,4 +96,4 @@
 if __name__ == "__main__":
     logging.basicConfig(format=formatter, level=logging.INFO)
-    compute_fbank_librispeech()
+    compute_fbank_tedlium()
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
new file mode 100644
index 000000000..edbf09abe
--- /dev/null
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
+"""
+Convert a word-based transcript to a list of BPE ids using the given BPE model.
+
+For example, if we use 2 as the encoding id of <unk>, here are some examples:
+
+texts = ['this is a <unk> day and in the room there are three <unk> laying in the bed']
+spm_ids = [[38, 33, 6, 2, 316, 8, 16, 5, 257, 193, 103, 61, 331, 2, 196, 21, 14, 16, 5, 47, 12]]
+
+texts = ['<unk> this is a sunny day and in the room there are three people in the <unk>']
+spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316, 8, 16, 5, 257, 193, 103, 61, 331, 107, 16, 5, 2]]
+
+texts = ['<unk>']
+spm_ids = [[2]]
+"""
+
+import argparse
+import logging
+import sentencepiece as spm
+from typing import List
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--texts", type=str, nargs="+", help="The input transcripts list."
+    )
+    parser.add_argument(
+        "--unk-id",
+        type=int,
+        default=2,
+        help="The numeric id of the token '<unk>'.",
+    )
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    return parser.parse_args()
+
+
+def convert_texts_into_ids(
+    texts: List[str],
+    unk_id: int,
+    sp: spm.SentencePieceProcessor,
+) -> List[List[int]]:
+    """
+    Args:
+      texts:
+        A list of transcript strings, such as ["Today is Monday", "It is sunny"].
+      unk_id:
+        The numeric id of the token '<unk>'.
+    Returns:
+      Return a list of BPE id lists, one per transcript.
+    """
+    y = []
+    for text in texts:
+        y_ids = []
+        if "<unk>" in text:
+            text_segments = text.split("<unk>")
+            id_segments = sp.encode(text_segments, out_type=int)
+            for i in range(len(id_segments)):
+                if i != len(id_segments) - 1:
+                    y_ids.extend(id_segments[i] + [unk_id])
+                else:
+                    y_ids.extend(id_segments[i])
+        else:
+            y_ids = sp.encode([text], out_type=int)[0]
+        y.append(y_ids)
+
+    return y
+
+
+def main():
+    args = get_args()
+    texts = args.texts
+    bpe_model = args.bpe_model
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(bpe_model)
+    unk_id = sp.piece_to_id("<unk>")
+
+    y = convert_texts_into_ids(
+        texts=texts,
+        unk_id=unk_id,
+        sp=sp,
+    )
+    logging.info(f"The input texts: {texts}")
+    logging.info(f"The encoding ids: {y}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/tedlium3/ASR/local/display_manifest_statistics.py b/egs/tedlium3/ASR/local/display_manifest_statistics.py
new file mode 100644
index 000000000..1573dbf7e
--- /dev/null
+++ b/egs/tedlium3/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                        Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. +""" + +import numpy as np +from lhotse import load_manifest + + +def describe(cuts) -> None: + """ + Print a message describing details about the ``CutSet`` - the number of cuts and the + duration statistics, including the total duration and the percentage of speech segments. + + Example output: + Cuts count: 804789 + Total duration (hours): 1370.6 + Speech duration (hours): 1370.6 (100.0%) + *** + Duration statistics (seconds): + mean 6.1 + std 3.1 + min 0.5 + 25% 3.7 + 50% 6.0 + 75% 8.3 + 99.5% 14.9 + 99.9% 16.6 + max 33.3 + + In the above example, we set 15(>14.9) as the maximum duration of training samples. + """ + durations = np.array([c.duration for c in cuts]) + speech_durations = np.array( + [s.duration for c in cuts for s in c.trimmed_supervisions] + ) + total_sum = durations.sum() + speech_sum = speech_durations.sum() + print("Cuts count:", len(cuts)) + print(f"Total duration (hours): {total_sum / 3600:.1f}") + print( + f"Speech duration (hours): {speech_sum / 3600:.1f} ({speech_sum / total_sum:.1%})" + ) + print("***") + print("Duration statistics (seconds):") + print(f"mean\t{np.mean(durations):.1f}") + print(f"std\t{np.std(durations):.1f}") + print(f"min\t{np.min(durations):.1f}") + print(f"25%\t{np.percentile(durations, 25):.1f}") + print(f"50%\t{np.median(durations):.1f}") + print(f"75%\t{np.percentile(durations, 75):.1f}") + print(f"99.5%\t{np.percentile(durations, 99.5):.1f}") + print(f"99.9%\t{np.percentile(durations, 99.9):.1f}") + print(f"max\t{np.max(durations):.1f}") + + +def main(): + path = "./data/fbank/cuts_train.json.gz" + # path = "./data/fbank/cuts_dev.json.gz" + # path = "./data/fbank/cuts_test.json.gz" + + cuts = load_manifest(path) + describe(cuts) + + +if __name__ == "__main__": + main() diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh index 053cc3941..e1ec0c162 100644 --- a/egs/tedlium3/ASR/prepare.sh +++ b/egs/tedlium3/ASR/prepare.sh @@ -151,6 +151,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Generate data for BPE training" cat data/lang_phone/train.text | cut -d " " -f 2- > $lang_dir/transcript_words.txt + # remove the for transcript_words.txt sed -i 's/ //g' $lang_dir/transcript_words.txt sed -i 's/ //g' $lang_dir/transcript_words.txt sed -i 's///g' $lang_dir/transcript_words.txt diff --git a/egs/tedlium3/ASR/transducer_stateless/README.md b/egs/tedlium3/ASR/transducer_stateless/README.md index 964bddfab..670b2b1e0 100644 --- a/egs/tedlium3/ASR/transducer_stateless/README.md +++ b/egs/tedlium3/ASR/transducer_stateless/README.md @@ -7,7 +7,7 @@ https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 You can use the following command to start the training: ```bash -cd egs/librispeech/ASR +cd egs/tedlium3/ASR export CUDA_VISIBLE_DEVICES="0,1,2,3" @@ -16,7 +16,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --num-epochs 30 \ 
--start-epoch 0 \ --exp-dir transducer_stateless/exp \ - --full-libri 1 \ - --max-duration 250 \ - --lr-factor 2.5 + --max-duration 180 \ + --lr-factor 5.0 ``` diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index f6c58c8db..379c9d77e 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -1,4 +1,5 @@ # Copyright 2021 Piotr Żelasko +# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py index 989caa802..77caf6460 100644 --- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py +++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -17,7 +18,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional -import numpy as np import torch from model import Transducer @@ -43,12 +43,13 @@ def greedy_search( assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id context_size = model.decoder.context_size device = model.device decoder_input = torch.tensor( - [blank_id] * context_size, device=device + [blank_id] * context_size, device=device, dtype=torch.int64 ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -84,7 +85,7 @@ def greedy_search( # logits is (1, 1, 1, vocab_size) y = logits.argmax().item() - if y != blank_id: + if y != blank_id and y != unk_id: hyp.append(y) decoder_input = torch.tensor( [hyp[-context_size:]], device=device @@ -108,8 +109,9 @@ class Hypothesis: # Newly predicted tokens are appended to `ys`. ys: List[int] - # The log prob of ys - log_prob: float + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor @property def key(self) -> str: @@ -118,7 +120,7 @@ class Hypothesis: class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: """ Args: data: @@ -130,11 +132,10 @@ class HypothesisList(object): self._data = data @property - def data(self): + def data(self) -> Dict[str, Hypothesis]: return self._data - # def add(self, ys: List[int], log_prob: float): - def add(self, hyp: Hypothesis): + def add(self, hyp: Hypothesis) -> None: """Add a Hypothesis to `self`. If `hyp` already exists in `self`, its probability is updated using @@ -146,8 +147,10 @@ class HypothesisList(object): """ key = hyp.key if key in self: - old_hyp = self._data[key] - old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob) + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -159,7 +162,8 @@ class HypothesisList(object): length_norm: If True, the `log_prob` of a hypothesis is normalized by the number of tokens in it. - + Returns: + Return the hypothesis that has the largest `log_prob`. """ if length_norm: return max( @@ -171,6 +175,9 @@ class HypothesisList(object): def remove(self, hyp: Hypothesis) -> None: """Remove a given hypothesis. 
+ Caution: + `self` is modified **in-place**. + Args: hyp: The hypothesis to be removed from `self`. @@ -181,7 +188,7 @@ class HypothesisList(object): assert key in self, f"{key} does not exist" del self._data[key] - def filter(self, threshold: float) -> "HypothesisList": + def filter(self, threshold: torch.Tensor) -> "HypothesisList": """Remove all Hypotheses whose log_prob is less than threshold. Caution: @@ -189,10 +196,10 @@ class HypothesisList(object): Returns: Return a new HypothesisList containing all hypotheses from `self` - that have `log_prob` being greater than the given `threshold`. + with `log_prob` being greater than the given `threshold`. """ ans = HypothesisList() - for key, hyp in self._data.items(): + for _, hyp in self._data.items(): if hyp.log_prob > threshold: ans.add(hyp) # shallow copy return ans @@ -222,6 +229,201 @@ class HypothesisList(object): return ", ".join(s) +def run_decoder( + ys: List[int], + model: Transducer, + decoder_cache: Dict[str, torch.Tensor], +) -> torch.Tensor: + """Run the neural decoder model for a given hypothesis. + + Args: + ys: + The current hypothesis. + model: + The transducer model. + decoder_cache: + Cache to save computations. + Returns: + Return a 1-D tensor of shape (decoder_out_dim,) containing + output of `model.decoder`. + """ + context_size = model.decoder.context_size + key = "_".join(map(str, ys[-context_size:])) + if key in decoder_cache: + return decoder_cache[key] + + device = model.device + + decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape( + 1, context_size + ) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_cache[key] = decoder_out + + return decoder_out + + +def run_joiner( + key: str, + model: Transducer, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + encoder_out_len: torch.Tensor, + decoder_out_len: torch.Tensor, + joint_cache: Dict[str, torch.Tensor], +): + """Run the joint network given outputs from the encoder and decoder. + + Args: + key: + A key into the `joint_cache`. + model: + The transducer model. + encoder_out: + A tensor of shape (1, 1, encoder_out_dim). + decoder_out: + A tensor of shape (1, 1, decoder_out_dim). + encoder_out_len: + A tensor with value [1]. + decoder_out_len: + A tensor with value [1]. + joint_cache: + A dict to save computations. + Returns: + Return a tensor from the output of log-softmax. + Its shape is (vocab_size,). + """ + if key in joint_cache: + return joint_cache[key] + + logits = model.joiner( + encoder_out, + decoder_out, + encoder_out_len, + decoder_out_len, + ) + + # TODO(fangjun): Scale the blank posterior + log_prob = logits.log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + + joint_cache[key] = log_prob + + return log_prob + + +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. 
+ """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :] + # current_encoder_out is of shape (1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_hyps, 1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, -1 + ) + + logits = model.joiner( + current_encoder_out, + decoder_out, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) + # logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id and new_token != unk_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + def beam_search( model: Transducer, encoder_out: torch.Tensor, @@ -247,6 +449,7 @@ def beam_search( # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id context_size = model.decoder.context_size device = model.device @@ -261,7 +464,12 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) max_sym_per_utt = 20000 @@ -281,58 +489,43 @@ def beam_search( joint_cache: Dict[str, torch.Tensor] = {} - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - while True: y_star = A.get_most_probable() A.remove(y_star) - cached_key = y_star.key + decoder_out = run_decoder( + ys=y_star.ys, model=model, decoder_cache=decoder_cache + ) - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], device=device - ).reshape(1, 
context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out, - encoder_out_len, - decoder_out_len, - ) - - # TODO(fangjun): Ccale the blank posterior - - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] + key = "_".join(map(str, y_star.ys[-context_size:])) + key += f"-t-{t}" + log_prob = run_joiner( + key=key, + model=model, + encoder_out=current_encoder_out, + decoder_out=decoder_out, + encoder_out_len=encoder_out_len, + decoder_out_len=decoder_out_len, + joint_cache=joint_cache, + ) # First, process the blank symbol skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() + new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i == blank_id: + for idx in range(values.size(0)): + i = indices[idx].item() + if i == blank_id or i == unk_id: continue + new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v + + new_log_prob = y_star.log_prob + values[idx] A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) # Check whether B contains more than "beam" elements more probable diff --git a/egs/tedlium3/ASR/transducer_stateless/conformer.py b/egs/tedlium3/ASR/transducer_stateless/conformer.py index 59a29382f..81d7708f9 100644 --- a/egs/tedlium3/ASR/transducer_stateless/conformer.py +++ b/egs/tedlium3/ASR/transducer_stateless/conformer.py @@ -615,7 +615,7 @@ class RelPositionMultiheadAttention(nn.Module): E is the embedding dimension. - attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. 
- """ # noqa + """ tgt_len, bsz, embed_dim = query.size() assert embed_dim == embed_dim_to_check @@ -635,7 +635,7 @@ class RelPositionMultiheadAttention(nn.Module): elif torch.equal(key, value): # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias # noqa + # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = 0 _end = embed_dim @@ -643,7 +643,7 @@ class RelPositionMultiheadAttention(nn.Module): if _b is not None: _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias # noqa + # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim _end = None @@ -653,7 +653,7 @@ class RelPositionMultiheadAttention(nn.Module): k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) else: - # This is inline in_proj function with in_proj_weight and in_proj_bias # noqa + # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = 0 _end = embed_dim @@ -662,7 +662,7 @@ class RelPositionMultiheadAttention(nn.Module): _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias # noqa + # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim _end = embed_dim * 2 @@ -671,7 +671,7 @@ class RelPositionMultiheadAttention(nn.Module): _b = _b[_start:_end] k = nn.functional.linear(key, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias # noqa + # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim * 2 _end = None @@ -687,12 +687,12 @@ class RelPositionMultiheadAttention(nn.Module): or attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( # noqa + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( attn_mask.dtype ) if attn_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." # noqa + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." ) attn_mask = attn_mask.to(torch.bool) @@ -725,7 +725,7 @@ class RelPositionMultiheadAttention(nn.Module): and key_padding_mask.dtype == torch.uint8 ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." # noqa + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -760,7 +760,7 @@ class RelPositionMultiheadAttention(nn.Module): # compute attention score # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) matrix_ac = torch.matmul( q_with_bias_u, k @@ -832,7 +832,7 @@ class RelPositionMultiheadAttention(nn.Module): class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model. 
- Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py # noqa + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py Args: channels (int): The number of channels of conv layers. diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index e5987b75e..1c1bf1fd1 100644 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,16 +20,16 @@ Usage: (1) greedy search ./transducer_stateless/decode.py \ - --epoch 14 \ - --avg 7 \ + --epoch 29 \ + --avg 15 \ --exp-dir ./transducer_stateless/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search ./transducer_stateless/decode.py \ - --epoch 14 \ - --avg 7 \ + --epoch 29 \ + --avg 15 \ --exp-dir ./transducer_stateless/exp \ --max-duration 100 \ --decoding-method beam_search \ @@ -45,8 +46,8 @@ from typing import Dict, List, Tuple import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search +from asr_datamodule import TedLiumAsrDataModule +from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -77,7 +78,7 @@ def get_parser(): parser.add_argument( "--avg", type=int, - default=13, + default=15, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. 
", @@ -169,6 +170,7 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, + unk_id=params.unk_id, context_size=params.context_size, ) return decoder @@ -256,6 +258,10 @@ def decode_one_batch( hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -382,14 +388,18 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + TedLiumAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) params = get_params() params.update(vars(args)) - assert params.decoding_method in ("greedy_search", "beam_search") + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" @@ -413,6 +423,7 @@ def main(): # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) @@ -439,16 +450,12 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - librispeech = LibriSpeechAsrDataModule(args) + tedlium = TedLiumAsrDataModule(args) + test_cuts = tedlium.test_cuts() + test_dl = tedlium.test_dataloaders(test_cuts) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets = ["test"] + test_dl = [test_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py index dca084477..d2b532904 100644 --- a/egs/tedlium3/ASR/transducer_stateless/decoder.py +++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -37,6 +38,7 @@ class Decoder(nn.Module): vocab_size: int, embedding_dim: int, blank_id: int, + unk_id: int, context_size: int, ): """ @@ -47,6 +49,8 @@ class Decoder(nn.Module): Dimension of the input embedding. blank_id: The ID of the blank symbol. + unk_id: + The ID of the unk symbol. context_size: Number of previous words to use to predict the next word. 1 means bigram; 2 means trigram. n means (n+1)-gram. 
@@ -58,6 +62,7 @@ class Decoder(nn.Module): padding_idx=blank_id, ) self.blank_id = blank_id + self.unk_id = unk_id assert context_size >= 1, context_size self.context_size = context_size diff --git a/egs/tedlium3/ASR/transducer_stateless/model.py b/egs/tedlium3/ASR/transducer_stateless/model.py index 7aac290d9..98a6f0f37 100644 --- a/egs/tedlium3/ASR/transducer_stateless/model.py +++ b/egs/tedlium3/ASR/transducer_stateless/model.py @@ -120,7 +120,6 @@ class Transducer(nn.Module): target_lengths=y_lens, blank=blank_id, reduction="sum", - from_log_softmax=False, ) return loss diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py index e5dba8f0e..f070b2e75 100644 --- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py @@ -50,7 +50,7 @@ import kaldifeat import sentencepiece as spm import torch import torchaudio -from beam_search import beam_search, greedy_search +from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -167,6 +167,7 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, + unk_id=params.unk_id, context_size=params.context_size, ) return decoder @@ -230,6 +231,7 @@ def main(): # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(f"{params}") @@ -300,6 +302,10 @@ def main(): hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError(f"Unsupported method: {params.method}") diff --git a/egs/tedlium3/ASR/transducer_stateless/test_decoder.py b/egs/tedlium3/ASR/transducer_stateless/test_decoder.py index 3a653c1b7..906938046 100644 --- a/egs/tedlium3/ASR/transducer_stateless/test_decoder.py +++ b/egs/tedlium3/ASR/transducer_stateless/test_decoder.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang
+#                                             Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -18,7 +19,7 @@
 """
 To run this file, do:
-    cd icefall/egs/librispeech/ASR
+    cd icefall/egs/tedlium3/ASR
     python ./transducer_stateless/test_decoder.py
 """
@@ -29,6 +30,7 @@ from decoder import Decoder
 def test_decoder():
     vocab_size = 3
     blank_id = 0
+    unk_id = 2
     embedding_dim = 128
     context_size = 4
     decoder = Decoder(
         vocab_size=vocab_size,
         embedding_dim=embedding_dim,
         blank_id=blank_id,
+        unk_id=unk_id,
         context_size=context_size,
     )
     N = 100
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index e8f636768..57707f55d 100644
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -26,9 +26,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --num-epochs 30 \
   --start-epoch 0 \
   --exp-dir transducer_stateless/exp \
-  --full-libri 1 \
-  --max-duration 250 \
-  --lr-factor 2.5
+  --max-duration 180 \
+  --lr-factor 5.0
 """
@@ -56,6 +55,8 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
+from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
+
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
@@ -233,6 +234,7 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
+        unk_id=params.unk_id,
         context_size=params.context_size,
     )
     return decoder
@@ -379,7 +381,9 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)[: feature.size(0)]
     texts = batch["supervisions"]["text"][: feature.size(0)]
-    y = sp.encode(texts, out_type=int)
+
+    unk_id = params.unk_id
+    y = convert_texts_into_ids(texts, unk_id, sp=sp)
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
@@ -565,6 +569,7 @@ def run(rank, world_size, args):
     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
     logging.info(params)
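For context on the `compute_loss()` change above: TedLium3 transcripts may contain literal `<unk>` marks, and `convert_texts_into_ids()` maps each of them to the BPE `<unk>` id instead of letting SentencePiece encode the characters of the string "<unk>". The snippet below is a hypothetical, self-contained usage example (the example transcripts are made up; run it from `egs/tedlium3/ASR` so that `local/` is importable):

```python
# Hypothetical usage of the transcript -> BPE-id path used in train.py.
import k2
import sentencepiece as spm

from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")
unk_id = sp.piece_to_id("<unk>")  # 2 for this model, per the script's docstring

texts = ["this is a <unk> day", "hello world"]
y = convert_texts_into_ids(texts=texts, unk_id=unk_id, sp=sp)
# y is a list of int lists; every "<unk>" becomes the single id `unk_id`.

# train.py then wraps the ids in a ragged tensor for the transducer loss:
y = k2.RaggedTensor(y)
```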