masked codebook indices

This commit is contained in:
Guo Liyong 2022-05-07 17:31:16 +08:00
parent 391bd82e92
commit 0396944f9d
2 changed files with 195 additions and 1 deletions

View File

@ -0,0 +1,159 @@
from typing import Optional, Tuple
import torch
import numpy as np
# copied from: https://github.com/pytorch/fairseq/blob/c8d6fb198cd58d433cadf178b22afdba40401f13/fairseq/data/data_utils.py#L393
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
require_same_masks: bool = True,
mask_dropout: float = 0.0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from possion distribution with lambda = mask length
min_masks: minimum number of masked spans
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
mask_dropout: randomly dropout this percentage of masks in each example
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = np.random.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = np.random.randint(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
np.int,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = np.random.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len and require_same_masks:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
if mask_dropout > 0:
num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
mask_idc = np.random.choice(
mask_idc, len(mask_idc) - num_holes, replace=False
)
mask[i, mask_idc] = True
return mask
if __name__ == "__main__":
shape = [2, 100]
mask_prob = 0.8
mask_length = 2
mask = compute_mask_indices(shape=shape, padding_mask=None, mask_prob=mask_prob, mask_length=mask_length, no_overlap=True)
mask = torch.tensor(mask)
indices = torch.randint(0, 256, shape)
# Only the masked region is "PREDICTED",
# So unmasked region is assigned as ignore id using ~mask
masked_indices = indices.masked_fill(~mask, -100)
print(mask)
print(masked_indices)
m = indices == masked_indices
# this ratio is smaller than mask_prob,
# detailed in function compute_mask_indices
print(m.sum() / m.numel())

View File

@ -54,8 +54,10 @@ from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut, MonoCut
from lhotse.dataset.sampling.base import CutSampler
from icefall.utils import make_pad_mask
from lhotse.utils import fix_random_seed
from lhotse.dataset.collation import collate_custom_field
from mask import compute_mask_indices
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
@ -272,6 +274,18 @@ def get_parser():
help="number of code books",
)
parser.add_argument(
"--mask-codebook-indices",
type=str2bool,
default="False",
)
parser.add_argument(
"--mask-prob",
type=float,
default="0.8",
)
return parser
@ -553,6 +567,7 @@ def compute_loss(
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
info = MetricsTracker()
if is_training:
cuts = batch["supervisions"]["cut"]
# -100 is identical to ignore_value in CE loss computation.
@ -562,6 +577,27 @@ def compute_loss(
codebook_indices, codebook_indices_lens = collate_custom_field(
cuts_pre_mixed, "codebook_indices", pad_value=-100
)
if params.mask_codebook_indices:
# codebook_loss.shape == (N, T, C)
# Only (N, T) is needed to compute mask region
shape = codebook_indices.shape[:2]
mask_length=10
# length of current encoder output is:
# lengths = ((feature_lens - 1) // 2 - 1) // 2
# output rate of hubert is 2 * times of that
with warnings.catch_warnings():
warnings.simplefilter("ignore")
lengths = (feature_lens - 1) // 2 - 1
# True means padded frames
padding_mask = make_pad_mask(lengths)
mask = compute_mask_indices(shape, padding_mask=padding_mask,mask_prob=params.mask_prob, mask_length=10, no_overlap=True)
ori_numel = (codebook_indices != -100).sum()
codebook_indices= codebook_indices.masked_fill(~torch.tensor(mask).unsqueeze(2), -100)
masked_numel = (codebook_indices != -100).sum()
info["ori_numel"] = ori_numel
info["masked_numel"] = masked_numel
codebook_indices = codebook_indices.to(device)
else:
codebook_indices = None
@ -595,7 +631,6 @@ def compute_loss(
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (