Add text embeddings, but use actual text for now

Daniel Povey 2023-05-01 22:09:27 +08:00
parent fa696e919b
commit 1ab2a4c662
3 changed files with 49 additions and 15 deletions

View File

@@ -24,7 +24,7 @@ from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import penalize_abs_values_gt, ScaledLinear
from torch import Tensor
-class Transducer(nn.Module):
+class PromptedTransducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
@@ -90,8 +90,8 @@ class PromptedTransducer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
text: torch.Tensor,
-style_lens: torch.Tensor,
+text_lens: torch.Tensor,
+style_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
@@ -111,14 +111,14 @@ class PromptedTransducer(nn.Module):
A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
It is expected to contain the style prompt (first) and then the content
prompt.
-style_lens:
-A 1-D tensor of shape (N,), containing the number of elements (bytes)
-within each row of `text` that correspond to the style prompt (these
-are expected to come first).
+text_lens:
+A 1-D tensor of shape (N,). It contains the number of elements (bytes)
+in `text` before padding, which will include the lengths of the
+style plus the content prompt.
+style_lens:
+A 1-D tensor of shape (N,), containing the number of elements (bytes)
+within each row of `text` that correspond to the style prompt (these
+are expected to come first).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
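
To make the packing of the two prompts concrete, here is a small illustrative sketch (the byte values and lengths are invented, not taken from the source): each row of `text` holds the style-prompt bytes first, then the content-prompt bytes, then zero padding.

    import torch

    # Illustrative only: a batch of N=2 prompt rows.  Style bytes come
    # first, then content bytes, then zero padding out to the longest row.
    text = torch.tensor([[115, 49, 104, 101, 108, 108, 111],
                         [115, 50, 104, 105,   0,   0,   0]])
    text_lens = torch.tensor([7, 4])   # style + content bytes, before padding
    style_lens = torch.tensor([2, 2])  # the first 2 bytes of each row are style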

View File

@@ -51,9 +51,10 @@ import random
import warnings
from pathlib import Path
from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import k2
+import numpy
import optim
import sentencepiece as spm
import torch
@@ -68,7 +69,7 @@ from subsampling import Conv2dSubsampling
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
-from model import Transducer
+from model import PromptedTransducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch import nn
@@ -635,7 +636,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
-model = Transducer(
+model = PromptedTransducer(
encoder_embed=encoder_embed,
encoder=encoder,
text_embed=text_embed,
@@ -768,7 +769,36 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
+def _encode_texts_as_bytes(texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Encode texts as bytes and then as integer tensors.
+    Args:
+      texts: the texts to encode, as a list of strings
+      device: the PyTorch device we want the texts on
+    Returns:
+      (text, text_lens, style_lens), where:
+      text: a torch.Tensor of shape (batch_size, text_len) containing integers
+        0 <= i < 256
+      text_lens: a torch.Tensor of shape (batch_size,), giving the length of
+        each byte sequence
+      style_lens: a torch.Tensor of shape (batch_size,), giving the length of each
+        style prompt (style prompts are supposed to come first).  Since there is no
+        style prompt here, this is just all zeros.
+    """
+    texts = [bytes(s, 'UTF-8') for s in texts]
+    N = len(texts)
+    lengths = [len(s) for s in texts]
+    max_len = max(lengths)
+    # Pad every byte string with NUL bytes to a common length, so the batch
+    # can be viewed as a rectangular (N, max_len) array of bytes.
+    texts = [s + (b'\0' * (max_len - len(s))) for s in texts]
+    text = b''.join(texts)  # bytes array containing all of the texts
+    text = torch.Tensor(numpy.frombuffer(text, dtype=numpy.uint8)).to(device)
+    text = text.to(dtype=torch.long)
+    text = text.reshape(N, max_len)
+    text_lens = torch.tensor(lengths).to(device)
+    style_lens = torch.zeros(N, dtype=torch.long, device=device)
+    # print(f"text={text}, text_lens={text_lens}, style_lens={style_lens}")
+    return text, text_lens, style_lens
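
As a quick sanity check of the helper above, with the outputs worked out by hand (assuming a CPU device):

    text, text_lens, style_lens = _encode_texts_as_bytes(
        ["hi", "hello"], torch.device("cpu"))
    # text:       tensor([[104, 105,   0,   0,   0],
    #                     [104, 101, 108, 108, 111]])  # shape (2, 5), dtype torch.long
    # text_lens:  tensor([2, 5])
    # style_lens: tensor([0, 0])  # no style prompt yet, hence all zeros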
def compute_loss(
params: AttributeDict,
@@ -815,10 +845,15 @@ def compute_loss(
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
+text, text_lens, style_lens = _encode_texts_as_bytes(texts, device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
x=feature,
x_lens=feature_lens,
+text=text,
+text_lens=text_lens,
+style_lens=style_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,

View File

@@ -316,7 +316,8 @@ class Zipformer2(EncoderInterface):
# setting memory to zero should be equivalent to not using the
# memory input at all, since the Attention module has no biases.
memory_dropout_rate = 0.05
-memory = memory * (torch.rand(batch_size, 1) > memory_dropout_rate)
+memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
+memory_dropout_rate)
for i, module in enumerate(self.encoders):
ds = self.downsampling_factor[i]
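
The change above is a device-placement fix: torch.rand() without a device argument allocates its output on the CPU, and multiplying a CUDA-resident `memory` by a CPU mask raises a device-mismatch error. A minimal sketch of the pattern, with shapes assumed for illustration only:

    import torch

    memory_dropout_rate = 0.05
    memory = torch.randn(4, 512)  # stand-in for the real (batch, dim) memory tensor
    # Build the keep/drop mask on memory's own device; the (batch, 1) boolean
    # mask broadcasts across the feature dimension and zeroes whole rows.
    mask = torch.rand(memory.shape[0], 1, device=memory.device) > memory_dropout_rate
    memory = memory * mask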
@@ -354,11 +355,9 @@
# most recent output that has it present.
x = get_full_dim_output()
x = self.downsample_output(x)
# class Downsample has this rounding behavior..
-assert self.output_downsampling_factor == 2
-with warnings.catch_warnings():
-warnings.simplefilter("ignore")
-lengths = (x_lens + 1) // 2
+d = self.output_downsampling_factor
+lengths = (x_lens + d - 1) // d
return x, lengths
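
The new expression is just ceiling division: (n + d - 1) // d equals ceil(n / d) for positive integers, so it reproduces the old hardcoded (x_lens + 1) // 2 whenever d == 2, while also handling other downsampling factors. A quick check:

    import torch

    x_lens = torch.tensor([7, 8, 9])
    for d in (2, 4):
        assert torch.equal((x_lens + d - 1) // d,
                           torch.ceil(x_lens / d).to(torch.long))
    # d == 2 gives tensor([4, 4, 5]), matching (x_lens + 1) // 2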