Add text embeddings, but use actual text for now
commit 1ab2a4c662 (parent fa696e919b)
@@ -24,7 +24,7 @@ from encoder_interface import EncoderInterface
 from icefall.utils import add_sos, make_pad_mask
 from scaling import penalize_abs_values_gt, ScaledLinear
+from torch import Tensor


 class PromptedTransducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf

@@ -90,8 +90,8 @@ class PromptedTransducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         text: torch.Tensor,
-        style_lens: torch.Tensor,
         text_lens: torch.Tensor,
+        style_lens: torch.Tensor,
         y: k2.RaggedTensor,
         prune_range: int = 5,
         am_scale: float = 0.0,

@@ -111,14 +111,14 @@ class PromptedTransducer(nn.Module):
             A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
             It is expected to contain the style prompt (first) and then the content
             prompt.
-          style_lens:
-            A 1-D tensor of shape (N,), containing the number of elements (bytes)
-            within each row of `text` that correspond to the style prompt (these
-            are expected to come first).
           text_lens:
             A 1-D tensor of shape (N,). It contains the number of elements (bytes)
             in `text` before padding, which will include the lengths of the
             style plus the content prompt.
+          style_lens:
+            A 1-D tensor of shape (N,), containing the number of elements (bytes)
+            within each row of `text` that correspond to the style prompt (these
+            are expected to come first).
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
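
For orientation, a minimal sketch (not part of the commit) of preparing the forward() arguments in the new order. Every shape and value below is hypothetical, and construction of the PromptedTransducer itself is omitted:

    # Hypothetical shapes for a batch of N=2 utterances; model construction omitted.
    import k2
    import torch

    N, T_feat, T_text = 2, 100, 32
    x = torch.randn(N, T_feat, 80)                 # acoustic features
    x_lens = torch.full((N,), T_feat)              # valid frames per utterance
    text = torch.randint(0, 256, (N, T_text))      # byte-encoded prompt text
    text_lens = torch.full((N,), T_text)           # style + content prompt lengths
    style_lens = torch.zeros(N, dtype=torch.long)  # no style prompt in this sketch
    y = k2.RaggedTensor([[3, 5, 7], [2, 4]])       # labels of each utterance

    # simple_loss, pruned_loss = model(
    #     x=x, x_lens=x_lens, text=text, text_lens=text_lens,
    #     style_lens=style_lens, y=y, prune_range=5,
    # )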
@@ -51,9 +51,10 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import k2
+import numpy
 import optim
 import sentencepiece as spm
 import torch

@@ -68,7 +69,7 @@ from subsampling import Conv2dSubsampling
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from model import Transducer
+from model import PromptedTransducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch import nn

@@ -635,7 +636,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

-    model = Transducer(
+    model = PromptedTransducer(
         encoder_embed=encoder_embed,
         encoder=encoder,
         text_embed=text_embed,

@@ -768,7 +769,36 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+
+def _encode_texts_as_bytes(texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Encode texts as bytes and then integer tensors.
+    Args:
+        texts: the texts to encode, as a list of strings
+        device: the PyTorch device we want the texts on
+    Returns:
+        (text, text_lens, style_lens), where:
+        text: a torch.Tensor of shape (batch_size, text_len) containing integers
+          0 <= i < 256
+        text_lens: a torch.Tensor of shape (batch_size,), giving the length of
+          each byte sequence
+        style_lens: a torch.Tensor of shape (batch_size,), giving the length of
+          each style prompt (style prompts are supposed to come first). Since
+          there is no style prompt here, this is just all zeros.
+    """
+    texts = [bytes(s, "UTF-8") for s in texts]
+    N = len(texts)
+    lengths = [len(s) for s in texts]
+    max_len = max(lengths)
+    texts = [s + (b"\0" * (max_len - len(s))) for s in texts]
+    text = b"".join(texts)  # bytes array containing all of the texts
+
+    text = torch.Tensor(numpy.frombuffer(text, dtype=numpy.uint8)).to(device)
+    text = text.to(dtype=torch.long)
+    text = text.reshape(N, max_len)
+    text_lens = torch.tensor(lengths).to(device)
+    style_lens = torch.zeros(N, dtype=torch.long, device=device)
+    # print(f"text={text}, text_lens={text_lens}, style_lens={style_lens}")
+    return text, text_lens, style_lens
+
+
 def compute_loss(
     params: AttributeDict,
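
As a quick sanity check of the helper above (not part of the commit), the padding and round-trip behavior can be exercised on CPU; the inputs here are hypothetical, and torch/numpy are assumed imported as in the module:

    device = torch.device("cpu")
    text, text_lens, style_lens = _encode_texts_as_bytes(["HELLO", "HI"], device)
    assert text.shape == (2, 5)                # padded to the longest byte sequence
    assert text_lens.tolist() == [5, 2]        # true lengths before padding
    assert style_lens.tolist() == [0, 0]       # no style prompt, so all zeros
    assert bytes(text[0].tolist()).decode("utf-8") == "HELLO"  # round-trips
    assert text[1, 2:].tolist() == [0, 0, 0]   # b'\0' padding bytes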
@@ -815,10 +845,15 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    text, text_lens, style_lens = _encode_texts_as_bytes(texts, device)
+
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
+            text=text,
+            text_lens=text_lens,
+            style_lens=style_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,

@@ -316,7 +316,8 @@ class Zipformer2(EncoderInterface):
             # setting memory to zero should be equivalent to not using the
             # memory input at all, since the Attention module has no biases.
             memory_dropout_rate = 0.05
-            memory = memory * (torch.rand(batch_size, 1) > memory_dropout_rate)
+            memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
+                               memory_dropout_rate)

         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
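
The added device= argument keeps the fresh random mask on the same device as memory. A small sketch (hypothetical shapes, only meaningful on a CUDA machine) of the mismatch it avoids:

    import torch

    if torch.cuda.is_available():
        # (memory_len, batch_size, memory_dim), hypothetical values
        memory = torch.randn(10, 2, 8, device="cuda")
        # A plain torch.rand(2, 1) would live on CPU, and multiplying it with
        # the CUDA tensor would raise a device-mismatch RuntimeError.
        mask = torch.rand(2, 1, device=memory.device) > 0.05
        memory = memory * mask  # broadcasts over memory_len and memory_dim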
@@ -354,11 +355,9 @@ class Zipformer2(EncoderInterface):
        # most recent output that has it present.
        x = get_full_dim_output()
        x = self.downsample_output(x)
-        # class Downsample has this rounding behavior..
-        assert self.output_downsampling_factor == 2
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            lengths = (x_lens + 1) // 2
+
+        d = self.output_downsampling_factor
+        lengths = (x_lens + d - 1) // d

        return x, lengths
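
The new length formula is ceiling division by the output downsampling factor, so for d == 2 it reproduces the old hard-coded (x_lens + 1) // 2 exactly while generalizing to other factors. A quick check:

    import torch

    x_lens = torch.tensor([7, 8, 9])
    d = 2
    assert torch.equal((x_lens + d - 1) // d, (x_lens + 1) // 2)  # tensor([4, 4, 5])
    d = 4
    print((x_lens + d - 1) // d)  # tensor([2, 2, 3]) == ceil(x_lens / 4)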