diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/mask.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/mask.py new file mode 100644 index 000000000..6983f548f --- /dev/null +++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/mask.py @@ -0,0 +1,159 @@ +from typing import Optional, Tuple +import torch +import numpy as np +# copied from: https://github.com/pytorch/fairseq/blob/c8d6fb198cd58d433cadf178b22afdba40401f13/fairseq/data/data_utils.py#L393 +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len and require_same_masks: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + if mask_dropout > 0: + num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) + mask_idc = np.random.choice( + mask_idc, len(mask_idc) - num_holes, replace=False + ) + + mask[i, mask_idc] = True + + return mask + + +if __name__ == "__main__": + shape = [2, 100] + mask_prob = 0.8 + mask_length = 2 + mask = compute_mask_indices(shape=shape, padding_mask=None, mask_prob=mask_prob, mask_length=mask_length, no_overlap=True) + mask = torch.tensor(mask) + + indices = torch.randint(0, 256, shape) + # Only the masked region is "PREDICTED", + # So unmasked region is assigned as ignore id using ~mask + masked_indices = indices.masked_fill(~mask, -100) + print(mask) + + print(masked_indices) + m = indices == masked_indices + # this ratio is smaller than mask_prob, + # detailed in function compute_mask_indices + print(m.sum() / m.numel()) + diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/train.py index d976c5533..3fc2a5a9b 100755 --- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/train.py @@ -54,8 +54,10 @@ from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut, MonoCut from lhotse.dataset.sampling.base import CutSampler +from icefall.utils import make_pad_mask from lhotse.utils import fix_random_seed from lhotse.dataset.collation import collate_custom_field +from mask import compute_mask_indices from model import Transducer from optim import Eden, Eve from torch import Tensor @@ -272,6 +274,18 @@ def get_parser(): help="number of code books", ) + parser.add_argument( + "--mask-codebook-indices", + type=str2bool, + default="False", + ) + + parser.add_argument( + "--mask-prob", + type=float, + default="0.8", + ) + return parser @@ -553,6 +567,7 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y).to(device) + info = MetricsTracker() if is_training: cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. @@ -562,6 +577,27 @@ def compute_loss( codebook_indices, codebook_indices_lens = collate_custom_field( cuts_pre_mixed, "codebook_indices", pad_value=-100 ) + if params.mask_codebook_indices: + # codebook_loss.shape == (N, T, C) + # Only (N, T) is needed to compute mask region + shape = codebook_indices.shape[:2] + mask_length=10 + + # length of current encoder output is: + # lengths = ((feature_lens - 1) // 2 - 1) // 2 + # output rate of hubert is 2 * times of that + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (feature_lens - 1) // 2 - 1 + # True means padded frames + padding_mask = make_pad_mask(lengths) + mask = compute_mask_indices(shape, padding_mask=padding_mask,mask_prob=params.mask_prob, mask_length=10, no_overlap=True) + ori_numel = (codebook_indices != -100).sum() + codebook_indices= codebook_indices.masked_fill(~torch.tensor(mask).unsqueeze(2), -100) + masked_numel = (codebook_indices != -100).sum() + info["ori_numel"] = ori_numel + info["masked_numel"] = masked_numel + codebook_indices = codebook_indices.to(device) else: codebook_indices = None @@ -595,7 +631,6 @@ def compute_loss( assert loss.requires_grad == is_training - info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") info["frames"] = (