mirror of https://github.com/k2-fsa/icefall.git
Add text embeddings, but use actual text for now
This commit is contained in:
parent fa696e919b
commit 1ab2a4c662
@@ -24,7 +24,7 @@ from encoder_interface import EncoderInterface
 from icefall.utils import add_sos, make_pad_mask
 from scaling import penalize_abs_values_gt, ScaledLinear
 from torch import Tensor


-class Transducer(nn.Module):
+class PromptedTransducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
@@ -90,8 +90,8 @@ class PromptedTransducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         text: torch.Tensor,
-        style_lens: torch.Tensor,
+        text_lens: torch.Tensor,
+        style_lens: torch.Tensor,
         y: k2.RaggedTensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
@@ -111,14 +111,14 @@ class PromptedTransducer(nn.Module):
           A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
           It is expected to contain the style prompt (first) and then the content
           prompt.
-        style_lens:
-          A 1-D tensor of shape (N,), containing the number of elements (bytes)
-          within each row of `text` that correspond to the style prompt (these
-          are expected to come first).
+        text_lens:
+          A 1-D tensor of shape (N,). It contains the number of elements (bytes)
+          in `text` before padding, which will include the lengths of the
+          style plus the content prompt.
+        style_lens:
+          A 1-D tensor of shape (N,), containing the number of elements (bytes)
+          within each row of `text` that correspond to the style prompt (these
+          are expected to come first).
         y:
           A ragged tensor with 2 axes [utt][label]. It contains labels of each
           utterance.
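[Editor's note: an illustrative example of the prompt arguments documented above; shapes and values are hypothetical, chosen only to show the layout. For a batch of N = 2 rows padded to T = 7 bytes, with the style prompt occupying the leading bytes of each row:

    text = torch.randint(0, 256, (2, 7))  # (N, T): style bytes first, then content bytes, zero-padded
    text_lens = torch.tensor([7, 5])      # style + content bytes of each row, before padding
    style_lens = torch.tensor([3, 2])     # how many leading bytes of each row are style prompt
]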
@@ -51,9 +51,10 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import k2
+import numpy
 import optim
 import sentencepiece as spm
 import torch
@@ -68,7 +69,7 @@ from subsampling import Conv2dSubsampling
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from model import Transducer
+from model import PromptedTransducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch import nn
@@ -635,7 +636,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

-    model = Transducer(
+    model = PromptedTransducer(
         encoder_embed=encoder_embed,
         encoder=encoder,
         text_embed=text_embed,
@@ -768,7 +769,36 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+
+def _encode_texts_as_bytes(texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Encode texts as bytes and then integer tensors.
+    Args:
+        texts: the texts to encode, as a list of strings
+        device: the PyTorch device we want the texts on
+    Returns:
+        (text, text_lens, style_lens), where:
+        text: a torch.Tensor of shape (batch_size, text_len) containing integers
+          0 <= i < 256
+        text_lens: a torch.Tensor of shape (batch_size,), giving the length of each byte
+          sequence
+        style_lens: a torch.Tensor of shape (batch_size,), giving the length of each
+          style prompt (style prompts are supposed to come first). Since there is no
+          style prompt here, this is just all zeros.
+    """
+    texts = [bytes(s, 'UTF-8') for s in texts]
+    N = len(texts)
+    lengths = [len(s) for s in texts]
+    max_len = max(lengths)
+    texts = [s + (b'\0' * (max_len - len(s))) for s in texts]
+    text = b''.join(texts)  # bytes array containing all of the texts
+
+    text = torch.Tensor(numpy.frombuffer(text, dtype=numpy.uint8)).to(device)
+    text = text.to(dtype=torch.long)
+    text = text.reshape(N, max_len)
+    text_lens = torch.tensor(lengths).to(device)
+    style_lens = torch.zeros(N, dtype=torch.long, device=device)
+    # print(f"text={text}, text_lens={text_lens}, style_lens={style_lens}")
+    return text, text_lens, style_lens
+
+
 def compute_loss(
     params: AttributeDict,
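[Editor's note: a minimal usage sketch of the new helper, with outputs worked out by hand; not part of the commit:

    text, text_lens, style_lens = _encode_texts_as_bytes(["hi", "abc"], torch.device("cpu"))
    # text       -> tensor([[104, 105,   0],
    #                       [ 97,  98,  99]])  # UTF-8 byte values, NUL-padded to the batch max
    # text_lens  -> tensor([2, 3])
    # style_lens -> tensor([0, 0])             # no style prompts yet, hence all zeros
]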
@@ -815,10 +845,15 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    text, text_lens, style_lens = _encode_texts_as_bytes(texts, device)
+
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
+            text=text,
+            text_lens=text_lens,
+            style_lens=style_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
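[Editor's note: the same `texts` batch is now encoded twice: `sp.encode` produces SentencePiece token IDs for the transducer loss (`y`), while `_encode_texts_as_bytes` produces raw UTF-8 bytes for the prompt input. A sketch with a hypothetical batch:

    texts = ["hello world"]
    y = k2.RaggedTensor(sp.encode(texts, out_type=int)).to(device)       # token IDs for the loss
    text, text_lens, style_lens = _encode_texts_as_bytes(texts, device)  # byte-level prompt input
]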
@@ -316,7 +316,8 @@ class Zipformer2(EncoderInterface):
             # setting memory to zero should be equivalent to not using the
             # memory input at all, since the Attention module has no biases.
             memory_dropout_rate = 0.05
-            memory = memory * (torch.rand(batch_size, 1) > memory_dropout_rate)
+            memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
+                               memory_dropout_rate)

         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
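[Editor's note: the added device argument prevents a device mismatch: torch.rand with no device allocates on the CPU, and multiplying a CUDA `memory` tensor by a CPU mask raises a RuntimeError. A small sketch of the failure mode, assuming a CUDA machine and illustrative shapes:

    memory = torch.randn(4, 1, device="cuda")                  # GPU-resident tensor
    memory * (torch.rand(4, 1) > 0.05)                         # RuntimeError: tensors on different devices
    memory * (torch.rand(4, 1, device=memory.device) > 0.05)   # mask created on memory's device: OK
]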
@@ -354,11 +355,9 @@ class Zipformer2(EncoderInterface):
         # most recent output that has it present.
         x = get_full_dim_output()
         x = self.downsample_output(x)
-        # class Downsample has this rounding behavior..
-        assert self.output_downsampling_factor == 2
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            lengths = (x_lens + 1) // 2
+
+        d = self.output_downsampling_factor
+        lengths = (x_lens + d - 1) // d

         return x, lengths
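[Editor's note: the new length rule is plain ceiling division: (n + d - 1) // d equals ceil(n / d), so with the previously hard-coded d == 2 it reproduces (x_lens + 1) // 2 exactly while also handling other downsampling factors. A quick check, illustrative and not from the commit:

    x_lens = torch.tensor([7, 8, 9])
    d = 2
    assert torch.equal((x_lens + d - 1) // d, (x_lens + 1) // 2)  # both give tensor([4, 4, 5])
]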