Add text embeddings, but use actual text for now

Daniel Povey 2023-05-01 22:09:27 +08:00
parent fa696e919b
commit 1ab2a4c662
3 changed files with 49 additions and 15 deletions

View File

@@ -24,7 +24,7 @@ from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import penalize_abs_values_gt, ScaledLinear
+from torch import Tensor
class PromptedTransducer(nn.Module):
    """It implements https://arxiv.org/pdf/1211.3711.pdf
@@ -90,8 +90,8 @@ class PromptedTransducer(nn.Module):
        x: torch.Tensor,
        x_lens: torch.Tensor,
        text: torch.Tensor,
-        style_lens: torch.Tensor,
        text_lens: torch.Tensor,
+        style_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
@@ -111,14 +111,14 @@ class PromptedTransducer(nn.Module):
            A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
            It is expected to contain the style prompt (first) and then the content
            prompt.
-          style_lens:
-            A 1-D tensor of shape (N,), containing the number of elements (bytes)
-            within each row of `text` that correspond to the style prompt (these
-            are expected to come first).
          text_lens:
            A 1-D tensor of shape (N,). It contains the number of elements (bytes)
            in `text` before padding, which will include the lengths of the
            style plus the content prompt.
+          style_lens:
+            A 1-D tensor of shape (N,), containing the number of elements (bytes)
+            within each row of `text` that correspond to the style prompt (these
+            are expected to come first).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
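To make the prompt layout concrete, here is a small illustrative sketch of how the three tensors relate (the strings and numbers below are invented for this note, not taken from the commit): each row of text holds the style-prompt bytes first, then the content-prompt bytes, then zero padding.

import torch

# Invented example: "Hi!" / "OK" as style prompts, "hello" / "goodbye" as content prompts.
text = torch.tensor([
    [72, 105, 33, 104, 101, 108, 108, 111, 0, 0],    # "Hi!" + "hello" + padding
    [79, 75, 103, 111, 111, 100, 98, 121, 101, 0],   # "OK" + "goodbye" + padding
])
text_lens = torch.tensor([8, 9])    # style + content bytes per row, before padding
style_lens = torch.tensor([3, 2])   # how many leading bytes of each row belong to the style prompt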

View File

@@ -51,9 +51,10 @@ import random
import warnings
from pathlib import Path
from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

import k2
+import numpy
import optim
import sentencepiece as spm
import torch
@@ -68,7 +69,7 @@ from subsampling import Conv2dSubsampling
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
-from model import Transducer
+from model import PromptedTransducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch import nn
@@ -635,7 +636,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
    decoder = get_decoder_model(params)
    joiner = get_joiner_model(params)

-    model = Transducer(
+    model = PromptedTransducer(
        encoder_embed=encoder_embed,
        encoder=encoder,
        text_embed=text_embed,
@@ -768,7 +769,36 @@ def save_checkpoint(
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)

+def _encode_texts_as_bytes(texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Encode texts as bytes and then integer tensors.
+    Args:
+        texts: the texts to encode, as a list of strings
+        device: the PyTorch device we want the texts on
+    Returns:
+        (text, text_lens, style_lens), where:
+        text: a torch.Tensor of shape (batch_size, text_len) containing integers
+            0 <= i < 256
+        text_lens: a torch.Tensor of shape (batch_size,), giving the length of each byte
+            sequence
+        style_lens: a torch.Tensor of shape (batch_size,), giving the length of each
+            style prompt (style prompts are supposed to come first). Since there is no
+            style prompt here, this is just all zeros.
+    """
+    texts = [ bytes(s, 'UTF-8') for s in texts ]
+    N = len(texts)
+    lengths = [ len(s) for s in texts ]
+    max_len = max(lengths)
+    texts = [ s + (b'\0' * (max_len - len(s))) for s in texts ]
+    text = b''.join(texts)  # bytes array containing all of the texts
+    text = torch.Tensor(numpy.frombuffer(text, dtype=numpy.uint8)).to(device)
+    text = text.to(dtype=torch.long)
+    text = text.reshape(N, max_len)
+    text_lens = torch.tensor(lengths).to(device)
+    style_lens = torch.zeros(N, dtype=torch.long, device=device)
+    # print(f"text={text}, text_lens={text_lens}, style_lens={style_lens}")
+    return text, text_lens, style_lens

def compute_loss(
    params: AttributeDict,
@@ -815,10 +845,15 @@ def compute_loss(
    y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y).to(device)

+    text, text_lens, style_lens = _encode_texts_as_bytes(texts, device)

    with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss = model(
            x=feature,
            x_lens=feature_lens,
+            text=text,
+            text_lens=text_lens,
+            style_lens=style_lens,
            y=y,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
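In compute_loss the transcripts are now encoded twice: with SentencePiece into y (the training labels, as before) and as raw UTF-8 bytes via _encode_texts_as_bytes into text, text_lens and style_lens for the prompt path. A hand-worked illustration of the byte side, assuming CPU tensors and made-up strings (not part of the commit):

import torch

texts = ["abc", "hi"]   # made-up transcripts
text, text_lens, style_lens = _encode_texts_as_bytes(texts, torch.device("cpu"))
# text       -> tensor([[ 97,  98,  99],
#                       [104, 105,   0]])  # UTF-8 byte values, zero-padded to the max length
# text_lens  -> tensor([3, 2])             # byte counts before padding
# style_lens -> tensor([0, 0])             # no style prompt yet, so all zeros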

View File

@@ -316,7 +316,8 @@ class Zipformer2(EncoderInterface):
            # setting memory to zero should be equivalent to not using the
            # memory input at all, since the Attention module has no biases.
            memory_dropout_rate = 0.05
-            memory = memory * (torch.rand(batch_size, 1) > memory_dropout_rate)
+            memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
+                               memory_dropout_rate)

        for i, module in enumerate(self.encoders):
            ds = self.downsampling_factor[i]
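A note on the device= change (my reading of it, not stated in the commit): torch.rand creates its output on the CPU by default, so when memory lives on a GPU the comparison would produce a CPU mask and the multiply would raise a device-mismatch error. A minimal sketch of the corrected pattern, with invented shapes:

import torch

# Invented shapes for illustration: memory is (num_frames, batch_size, embed_dim).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
memory = torch.randn(10, 4, 256, device=device)

# Build the dropout mask on the same device as memory; a (batch_size, 1) mask broadcasts
# over the embedding dimension, so an entire utterance's memory is zeroed at once.
keep = torch.rand(4, 1, device=memory.device) > 0.05
memory = memory * keep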
@@ -354,11 +355,9 @@
        # most recent output that has it present.
        x = get_full_dim_output()
        x = self.downsample_output(x)
-        # class Downsample has this rounding behavior..
-        assert self.output_downsampling_factor == 2
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            lengths = (x_lens + 1) // 2
+        d = self.output_downsampling_factor
+        lengths = (x_lens + d - 1) // d

        return x, lengths
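The replacement formula is plain ceiling division by the output downsampling factor; with the previously hard-coded factor of 2 it reproduces the old (x_lens + 1) // 2 exactly, and it also covers other factors without the assert or the warnings filter. A quick illustrative check (not part of the commit):

import torch

x_lens = torch.tensor([7, 8, 9])
d = 2
new = (x_lens + d - 1) // d   # ceiling division: [4, 4, 5]
old = (x_lens + 1) // 2       # previous hard-coded version: [4, 4, 5]
assert torch.equal(new, old)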