From 396aaefbaa19e137baf5dbc15e57e0b59b75a129 Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Fri, 11 Mar 2022 13:44:46 +0800
Subject: [PATCH] Use the k2 pruned RNN-T loss in transducer_stateless and
 update the TedLium3 recipe

---
 .../ASR/transducer_stateless/joiner.py        |  50 +++------
 .../ASR/transducer_stateless/model.py         | 102 +++++++++++-------
 .../ASR/local/compute_fbank_tedlium.py        |   4 +-
 egs/tedlium3/ASR/prepare.sh                   |   6 +-
 .../asr_datamodule.py                         |   3 +-
 .../ASR/pruned_transducer_stateless/train.py  |   6 +-
 6 files changed, 87 insertions(+), 84 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 9fd9da4f1..7c5a93a86 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -16,57 +16,35 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
         super().__init__()
 
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.output_linear = nn.Linear(input_dim, output_dim)
+        self.inner_linear = nn.Linear(input_dim, inner_dim)
+        self.output_linear = nn.Linear(inner_dim, output_dim)
 
     def forward(
-        self,
-        encoder_out: torch.Tensor,
-        decoder_out: torch.Tensor,
-        encoder_out_len: torch.Tensor,
-        decoder_out_len: torch.Tensor,
+        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, self.input_dim).
+            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
-            Output from the decoder. Its shape is (N, U, self.input_dim).
+            Output from the decoder. Its shape is (N, T, s_range, C).
         Returns:
-          Return a tensor of shape (sum_all_TU, self.output_dim).
+          Return a tensor of shape (N, T, s_range, output_dim).
""" - assert encoder_out.ndim == decoder_out.ndim == 3 - assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == self.input_dim - assert decoder_out.size(2) == self.input_dim + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape == decoder_out.shape - N = encoder_out.size(0) + logit = encoder_out + decoder_out - encoder_out_list = [ - encoder_out[i, : encoder_out_len[i], :] for i in range(N) - ] + logit = self.inner_linear(torch.tanh(logit)) - decoder_out_list = [ - decoder_out[i, : decoder_out_len[i], :] for i in range(N) - ] + output = self.output_linear(F.relu(logit)) - x = [ - e.unsqueeze(1) + d.unsqueeze(0) - for e, d in zip(encoder_out_list, decoder_out_list) - ] - - x = [p.reshape(-1, self.input_dim) for p in x] - x = torch.cat(x) - - activations = torch.tanh(x) - - logits = self.output_linear(activations) - - return logits + return output diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 8281e1fb5..2f019bcdb 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -1,4 +1,4 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random import k2 import torch @@ -64,7 +63,9 @@ class Transducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, - modified_transducer_prob: float = 0.0, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, ) -> torch.Tensor: """ Args: @@ -76,10 +77,23 @@ class Transducer(nn.Module): y: A ragged tensor with 2 axes [utt][label]. It contains labels of each utterance. - modified_transducer_prob: - The probability to use modified transducer loss. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part Returns: Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs """ assert x.ndim == 3, x.shape assert x_lens.ndim == 1, x_lens.shape @@ -97,47 +111,59 @@ class Transducer(nn.Module): blank_id = self.decoder.blank_id sos_y = add_sos(y, sos_id=blank_id) + # sos_y_padded: [B, S + 1], start with SOS. sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - sos_y_padded = sos_y_padded.to(torch.int64) + # decoder_out: [B, S + 1, C] decoder_out = self.decoder(sos_y_padded) - # +1 here since a blank is prepended to each utterance. 
-        logits = self.joiner(
-            encoder_out=encoder_out,
-            decoder_out=decoder_out,
-            encoder_out_len=x_lens,
-            decoder_out_len=y_lens + 1,
-        )
-
-        # rnnt_loss requires 0 padded targets
         # Note: y does not start with SOS
+        # y_padded : [B, S]
         y_padded = y.pad(mode="constant", padding_value=0)
 
-        # We don't put this `import` at the beginning of the file
-        # as it is required only in the training, not during the
-        # reference stage
-        import optimized_transducer
+        y_padded = y_padded.to(torch.int64)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens
 
-        assert 0 <= modified_transducer_prob <= 1
-
-        if modified_transducer_prob == 0:
-            one_sym_per_frame = False
-        elif random.random() < modified_transducer_prob:
-            # random.random() returns a float in the range [0, 1)
-            one_sym_per_frame = True
-        else:
-            one_sym_per_frame = False
-
-        loss = optimized_transducer.transducer_loss(
-            logits=logits,
-            targets=y_padded,
-            logit_lengths=x_lens,
-            target_lengths=y_lens,
-            blank=blank_id,
+        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+            lm=decoder_out,
+            am=encoder_out,
+            symbols=y_padded,
+            termination_symbol=blank_id,
+            lm_only_scale=lm_scale,
+            am_only_scale=am_scale,
+            boundary=boundary,
             reduction="sum",
-            one_sym_per_frame=one_sym_per_frame,
-            from_log_softmax=False,
+            return_grad=True,
         )
 
-        return loss
+        # ranges : [B, T, prune_range]
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad=px_grad,
+            py_grad=py_grad,
+            boundary=boundary,
+            s_range=prune_range,
+        )
+
+        # am_pruned : [B, T, prune_range, C]
+        # lm_pruned : [B, T, prune_range, C]
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
+            am=encoder_out, lm=decoder_out, ranges=ranges
+        )
+
+        # logits : [B, T, prune_range, C]
+        logits = self.joiner(am_pruned, lm_pruned)
+
+        pruned_loss = k2.rnnt_loss_pruned(
+            logits=logits,
+            symbols=y_padded,
+            ranges=ranges,
+            termination_symbol=blank_id,
+            boundary=boundary,
+            reduction="sum",
+        )
+
+        return (simple_loss, pruned_loss)
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 915197594..2e1830626 100644
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -71,6 +71,8 @@ def compute_fbank_tedlium():
             recordings=m["recordings"],
             supervisions=m["supervisions"],
         )
+        # Split long cuts into shorter, non-overlapping cuts
+        cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
         if "train" in partition:
             cut_set = (
                 cut_set
@@ -85,8 +87,6 @@ def compute_fbank_tedlium():
                 executor=ex,
                 storage_type=ChunkedLilcomHdf5Writer,
             )
-            # Split long cuts into many short and un-overlapping cuts
-            cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
         cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
 
 
diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh
index 4f2269430..7f9a9b25d 100644
--- a/egs/tedlium3/ASR/prepare.sh
+++ b/egs/tedlium3/ASR/prepare.sh
@@ -29,9 +29,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  5000
-  2000
-  1000
+  #5000
+  #2000
+  #1000
   500
 )
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py
index 97dc05775..518eadca1 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py
@@ -71,7 +71,7 @@ class TedLiumAsrDataModule:
         group.add_argument(
             "--manifest-dir",
             type=Path,
-            default=Path("data/fbank"),
+            default=Path("data/fbank_overlap_false"),
             help="Path to directory with train/valid/test cuts.",
         )
         group.add_argument(
@@ -348,7 +348,6 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        print(self.args.manifest_dir)
         return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
 
     @lru_cache()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 4cb324a89..70ed0691e 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -77,7 +77,7 @@ def get_parser():
     parser.add_argument(
         "--master-port",
         type=int,
-        default=12354,
+        default=12350,
         help="Master port to use for DDP training.",
     )
 
@@ -108,7 +108,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless/exp",
+        default="pruned_transducer_stateless/exp-4-gpus-300",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -658,7 +658,7 @@ def run(rank, world_size, args):
         # Keep only utterances with duration between 1 second and max seconds
-        # Here, we set max as 20.0.
-        # If you want to use a big max-duration, you can set it as 17.0.
-        return 1.0 <= c.duration <= 20.0
+        # Here, we set max as 17.0.
+        # Shorter cuts make it possible to use a larger --max-duration.
+        return 1.0 <= c.duration <= 17.0
 
     num_in_total = len(train_cuts)
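
-- 
Usage note: the reworked Joiner consumes pruned encoder/decoder outputs of
shape (N, T, s_range, C) instead of full (N, T, C) / (N, U, C) pairs. Below
is a minimal sketch of a shape check; all dimensions (batch 2, 10 frames,
prune range 5, input_dim/inner_dim 512, vocab 500) are made-up values used
only for illustration:

    import torch

    from joiner import Joiner  # egs/librispeech/ASR/transducer_stateless

    joiner = Joiner(input_dim=512, inner_dim=512, output_dim=500)
    am_pruned = torch.randn(2, 10, 5, 512)  # (N, T, s_range, C)
    lm_pruned = torch.randn(2, 10, 5, 512)  # (N, T, s_range, C)
    logits = joiner(am_pruned, lm_pruned)
    assert logits.shape == (2, 10, 5, 500)  # (N, T, s_range, output_dim)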
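
Usage note: Transducer.forward() now returns a pair (simple_loss,
pruned_loss) instead of a single loss. A sketch of how a training loop
might combine the two is shown below; `model`, `feature`, `feature_lens`,
`token_ids` and the 0.5 weight on the simple loss are illustrative
assumptions, not values taken from this patch:

    import k2

    # feature: (N, T, C) fbank features; feature_lens: (N,) frame counts;
    # token_ids: a list of lists of token IDs, one list per utterance.
    simple_loss, pruned_loss = model(
        x=feature,
        x_lens=feature_lens,
        y=k2.RaggedTensor(token_ids),
        prune_range=5,
        am_scale=0.0,
        lm_scale=0.0,
    )

    # The smoothed simple loss acts as a regularizer for the pruned loss,
    # which is the main objective; the weighting is a tunable choice.
    loss = 0.5 * simple_loss + pruned_loss
    loss.backward()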