From 8cfc9753141164bb6873b0012e4810553f631ad4 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Thu, 29 Dec 2022 11:32:11 +0900 Subject: [PATCH] from local --- egs/librispeech/ASR/.run_adapter.sh.swp | Bin 12288 -> 12288 bytes .../ASR/distillation_with_hubert.sh | 15 +++++--- .../pruned_transducer_stateless2/scaling.py | 20 +++++------ .../scaling_converter.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 1 - .../ASR/pruned_transducer_stateless6/train.py | 9 ++--- .../pruned_transducer_stateless6/vq_utils.py | 34 +++++++++++++++--- 7 files changed, 53 insertions(+), 28 deletions(-) diff --git a/egs/librispeech/ASR/.run_adapter.sh.swp b/egs/librispeech/ASR/.run_adapter.sh.swp index 21f9804ba3990e82fb5e75fa1a682ad02783fa1e..fb3d61dd51c6267cb9f639c9b94a393c0a94233c 100644 GIT binary patch delta 32 mcmZojXh;xGG6?hZRj|}EU;qLE1_tIIYmzyS|KBM7QXc@0+X>bH delta 32 mcmZojXh;xGG6?hZRj|}EU;qLE28L@(Rwv)x@^7Q~OML*D@Cy_G diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index d7d599161..d5d3008aa 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -23,7 +23,7 @@ # To start from scratch, you can # set stage=0, stop_stage=4, use_extracted_codebook=False -stage=2 +stage=0 stop_stage=4 # Set the GPUs available. @@ -35,8 +35,7 @@ stop_stage=4 # export CUDA_VISIBLE_DEVICES="0" # # Suppose GPU 2,3,4,5 are available. -#export CUDA_VISIBLE_DEVICES="0,1,2,3" -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +# export CUDA_VISIBLE_DEVICES="0,1,2,3" exp_dir=./pruned_transducer_stateless6/exp mkdir -p $exp_dir @@ -50,7 +49,7 @@ full_libri=False # "True" -> stage 0 and stage 1 would be skipped, # and directly download the extracted codebook indexes for distillation # "False" -> start from scratch -use_extracted_codebook=False +use_extracted_codebook=True # teacher_model_id can be one of # "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use. @@ -156,8 +155,14 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi log "Downloading extracted codebook indexes to $codebook_download_dir" # Make sure you have git-lfs installed (https://git-lfs.github.com) + # The codebook indexes are generated using lhotse 1.11.0, to avoid + # potential issues, we recommend you to use lhotse version >= 1.11.0 + lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))") + if [ "$lhotse_version" == "False" ]; then + log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch." + fi git lfs install - git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir + git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir mkdir -p data/vq_fbank mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index c802ecf89..963ebdc2d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -652,16 +652,16 @@ class ActivationBalancer(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: if random.random() >= self.balance_prob: return x - else: - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor / self.balance_prob, - self.min_abs, - self.max_abs, - ) + + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor / self.balance_prob, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index b712eeda0..a6540c584 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -282,7 +282,7 @@ def convert_scaled_to_non_scaled( if not inplace: model = copy.deepcopy(model) - excluded_patterns = r"self_attn\.(in|out)_proj" + excluded_patterns = r"(self|src)_attn\.(in|out)_proj" p = re.compile(excluded_patterns) d = {} diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index f3c6df3ff..847c80ab0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -965,7 +965,6 @@ def run(rank, world_size, args): logging.info("About to create model") model = get_transducer_model(params) - logging.info(model) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 9600a8c3c..57753599a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -929,7 +929,6 @@ def run(rank, world_size, args): logging.info("About to create model") model = get_transducer_model(params) - logging.info(model) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -986,11 +985,9 @@ def run(rank, world_size, args): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 1.0 or c.duration > 20.0: - ''' logging.warning( f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) - ''' return False # In pruned RNN-T, we require that T >= S @@ -1001,9 +998,8 @@ def run(rank, world_size, args): # for subsampling T = ((c.num_frames - 1) // 2 - 1) // 2 tokens = sp.encode(c.supervisions[0].text, out_type=str) - + if T < len(tokens): - ''' logging.warning( f"Exclude cut with ID {c.id} from training. " f"Number of frames (before subsampling): {c.num_frames}. " @@ -1012,9 +1008,8 @@ def run(rank, world_size, args): f"Tokens: {tokens}. " f"Number of tokens: {len(tokens)}" ) - ''' return False - + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 97a83b974..bf072d865 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -244,10 +244,36 @@ class CodebookIndexExtractor: ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) - cuts_vq = cuts_vq.sort_like(cuts_ori) - for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)): - assert cut_vq.id == cut_ori.id - cut_ori.codebook_indexes = cut_vq.codebook_indexes + assert len(cuts_vq) == len(cuts_ori), "Cuts should have the same length!" + + if set(cuts_vq.ids) == set(cuts_ori.ids): + # IDs match exactly + cuts_vq = cuts_vq.sort_like(cuts_ori) + for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)): + assert cut_vq.id == cut_ori.id, (cut_vq.id, cut_ori.id) + cut_ori.codebook_indexes = cut_vq.codebook_indexes + else: + # in case of ID mismatch, remap them + # get the mapping between audio and cut ID + logging + ori_id_map = {} + for id in cuts_ori.ids: + # some text normalization + if "sp" in id: + clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1] + else: + clean_id = "-".join(id.split("-")[:3]) + ori_id_map[clean_id] = id + + for id in cuts_vq.ids: + if "sp" in id: + clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1] + else: + clean_id = "-".join(id.split("-")[:3]) + assert clean_id in ori_id_map, clean_id + cuts_ori[ori_id_map[clean_id]].codebook_indexes = cuts_vq[ + id + ].codebook_indexes CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path) logging.info(f"Processed {subset}.")