remove/rename files

marcoyang1998 2023-09-15 10:54:58 +08:00
parent 2f4eb18466
commit ae2c7c73f6
9 changed files with 202 additions and 8050 deletions

File diff suppressed because it is too large

View File

@ -31,10 +31,6 @@ from typing import Optional, Tuple, Dict
class PromptedTransducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
Note that this is a PromptedTransducer, meaning that the transducer is able to
decode with prompts.
It uses a text encoder of BERT type, and also has a special context fuser.
"""
def __init__(
@ -105,13 +101,8 @@ class PromptedTransducer(nn.Module):
self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT
self.context_fuser = context_fuser
assert text_encoder_type in (
"BERT",
"DistilBERT",
"BERT-UNCASED",
"BERT-LARGE-UNCASED",
), f"Unseen text_encoder type {text_encoder_type}"
self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED", "BERT-LARGE-UNCASED") else self.text_encoder.config.dim
assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim
if text_encoder_adapter:
self.text_encoder_adapter = nn.Sequential(
@ -120,12 +111,15 @@ class PromptedTransducer(nn.Module):
)
else:
self.text_encoder_adapter = None
self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
encoded_inputs: Dict,
style_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
@ -189,7 +183,10 @@ class PromptedTransducer(nn.Module):
# freeze the BERT text encoder
if use_pre_text:
memory, memory_key_padding_mask = self.encode_text(encoded_inputs)
memory, memory_key_padding_mask = self.encode_text(
encoded_inputs,
style_lens=style_lens
)
else:
memory = None
memory_key_padding_mask = None
@ -279,12 +276,7 @@ class PromptedTransducer(nn.Module):
else:
context = None
logits = self.joiner(
am_pruned,
lm_pruned,
context=context,
project_input=False,
)
logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
@ -305,8 +297,8 @@ class PromptedTransducer(nn.Module):
scale of the embedding vector can adjust to compensate.
Args:
memory: (memory_len, batch_size, embed_dim)
style_lens: (batch_size,), a vector of lengths of the style prompt.
"""
(memory_len, batch_size, embed_dim) = memory.shape
@ -318,13 +310,14 @@ class PromptedTransducer(nn.Module):
indicator = indicator.to(memory.dtype)
extra_term = torch.zeros_like(memory)
extra_term[..., 0] += indicator
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
return memory + extra_term
def encode_text(
self,
encoded_inputs: Dict,
style_lens: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Get the embeddings of text
@ -335,18 +328,21 @@ class PromptedTransducer(nn.Module):
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
text_encoder and the attention mask
"""
text_lens = encoded_inputs["attention_mask"].sum(1)
text_lens = encoded_inputs.pop("length") # need to use pop to remove this item
# Freeze the pre-trained text encoder
with torch.no_grad():
memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C)
memory = memory.permute(1, 0, 2) # (B,T,C) -> (T,B,C)
memory_key_padding_mask = make_pad_mask(text_lens)
# Text encoder adapter
if self.text_encoder_adapter is not None:
memory = self.text_encoder_adapter(memory)
memory = self._add_style_indicator(memory, style_lens)
memory_key_padding_mask = make_pad_mask(text_lens)
return memory, memory_key_padding_mask
def encode_audio(

View File

@ -1,382 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
import torch.nn as nn
import random
import warnings
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import penalize_abs_values_gt, ScaledLinear
from torch import Tensor
from typing import Optional, Tuple, Dict
class PromptedTransducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
text_encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
use_BERT: bool = True,
text_encoder_type: str = "BERT",
text_encoder_adapter: bool = False,
context_fuser: nn.Module = None,
):
"""
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
text_encoder:
This is an encoder that processes text information (e.g. content prompt
and style prompt). The input is `x` of (N,T) and `x_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
text_encoder_type:
The type of the text_encoder. Supported are (BERT, DistilBERT)
context_fuser:
An optional module that fuses the embeddings of text encoder. The fused embedding
will be added to the joiner.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder_embed = encoder_embed
self.encoder = encoder
self.text_encoder = text_encoder
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim,
vocab_size,
initial_scale=0.25,
)
self.simple_lm_proj = ScaledLinear(
decoder_dim,
vocab_size,
initial_scale=0.25,
)
self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT
self.context_fuser = context_fuser
assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim
if text_encoder_adapter:
self.text_encoder_adapter = nn.Sequential(
nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False),
nn.Tanh(),
)
else:
self.text_encoder_adapter = None
self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
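# A quick recap of the shape contract documented above (toy sizes, for
# illustration only): encoder_embed maps (N, T, idim) to (N, (T-7)//2, odim),
# so e.g. T = 107 input frames yield (107 - 7) // 2 = 50 subsampled frames;
# the decoder maps (N, U) to (N, U, decoder_dim); and the joiner combines
# (N, T, encoder_dim) with (N, U, decoder_dim) into (N, T, U, vocab_size).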
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
encoded_inputs: Dict,
style_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
use_pre_text: bool = True,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
encoded_inputs:
A dict generated by the tokenizer, containing the encoded prompt text
(e.g. input_ids, attention_mask and length). It is expected to contain
the style prompt (first) and then the content prompt.
style_lens:
A 1-D tensor of shape (N,), containing the number of tokens within
each row of the encoded prompt that correspond to the style prompt
(these are expected to come first).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# freeze the BERT text encoder
if use_pre_text:
memory, memory_key_padding_mask = self.encode_text(
encoded_inputs,
style_lens=style_lens
)
else:
memory = None
memory_key_padding_mask = None
encoder_out, x_lens = self.encoder(
x,
x_lens,
src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
# boundary : [B, 4]; each row is [begin_symbol, begin_frame, end_symbol, end_frame]
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
if self.context_fuser is not None and memory is not None:
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
context = self.joiner.context_proj(context)
else:
context = None
logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)
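# How the returned pair is typically consumed by the caller (a sketch; the
# 0.5 scale below is a hypothetical hyper-parameter, not taken from this file):
#
#   simple_loss, pruned_loss = model(x, x_lens, encoded_inputs, style_lens, y)
#   loss = 0.5 * simple_loss + pruned_loss
#
# k2.rnnt_loss_smoothed provides the gradients used to derive the pruning
# ranges; k2.rnnt_loss_pruned is the loss evaluated only inside those ranges.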
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
"""
Adds to `memory` an indicator that is 1.0 for positions that correspond to
the `style prompt` and 0 elsewhere. The scale can be fixed because the
scale of the embedding vector can adjust to compensate.
Args:
memory: (memory_len, batch_size, embed_dim)
style_lens: (batch_size,), a vector of lengths of the style prompt.
"""
(memory_len, batch_size, embed_dim) = memory.shape
indicator = (
torch.arange(memory_len, device=memory.device).unsqueeze(-1)
< style_lens
)
indicator = indicator.to(memory.dtype)
extra_term = torch.zeros_like(memory)
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
return memory + extra_term
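# Standalone sketch of the indicator logic above (toy sizes, no model needed):
#
#   import torch
#   memory_len, batch_size = 5, 2
#   style_lens = torch.tensor([2, 4])
#   indicator = torch.arange(memory_len).unsqueeze(-1) < style_lens
#   # indicator[:, 0] == [True, True, False, False, False]
#   # indicator[:, 1] == [True, True, True, True, False]
#
# Every position flagged True then receives the learned style_prompt_embedding,
# so the audio encoder can tell style-prompt tokens from content-prompt tokens.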
def encode_text(
self,
encoded_inputs: Dict,
style_lens: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Get the embeddings of text
Args:
encoded_inputs: The encoded inputs generated by a tokenizer (Dict)
style_lens: The number of tokens of the style prompt in each row (Tensor)
Returns:
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
text_encoder and the attention mask
"""
text_lens = encoded_inputs.pop("length") # need to use pop to remove this item
# Freeze the pre-trained text encoder
with torch.no_grad():
memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C)
memory = memory.permute(1, 0, 2) # (B,T,C) -> (T,B,C)
# Text encoder adapter
if self.text_encoder_adapter is not None:
memory = self.text_encoder_adapter(memory)
memory = self._add_style_indicator(memory, style_lens)
memory_key_padding_mask = make_pad_mask(text_lens)
return memory, memory_key_padding_mask
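# Hedged usage sketch for encode_text (the tokenizer settings mirror this
# recipe's training script; the values below are made up):
#
#   from transformers import BertTokenizer
#   tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
#   encoded_inputs = tokenizer(
#       ["MIXED CASE STYLE [SEP] the content prompt goes here"],
#       return_tensors="pt", padding=True, truncation=True, return_length=True,
#   )
#   style_lens = torch.tensor([4])  # hypothetical: first 4 tokens are style
#   memory, mask = model.encode_text(encoded_inputs, style_lens=style_lens)
#
# Note that encode_text pops the "length" entry to build the key padding mask.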
def encode_audio(
self,
feature: Tensor,
feature_lens: Tensor,
memory: Optional[Tensor],
memory_key_padding_mask: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
"""Encode the input audio features
Args:
feature (Tensor): Input audio (N,T,C)
feature_lens (Tensor): Length of input audio (N,)
memory (Tensor): Embeddings from the text encoder (T,N,C)
memory_key_padding_mask (Tensor): Key padding mask of the text embeddings (N,T)
Returns:
Tuple[Tensor, Tensor]: The encoder output (N,T,C) and its lengths (N,)
"""
x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(
x=x,
x_lens=x_lens,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
Transducer = PromptedTransducer # for decoding
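# Decode-time sketch (hypothetical variable names): the prompts are encoded
# once, then passed as cross-attention memory to the audio encoder.
#
#   memory, memory_mask = model.encode_text(encoded_inputs, style_lens=style_lens)
#   encoder_out, encoder_out_lens = model.encode_audio(
#       feature, feature_lens, memory, memory_mask
#   )
#   # encoder_out (N, T, C) then feeds a standard transducer search (e.g.
#   # greedy or modified beam search) built on model.decoder and model.joiner.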

View File

@ -61,13 +61,22 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriHeavyAsrDataModule
from dataset import triplet_text_sampling, naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
from dataset2 import (
triplet_text_sampling,
triplet_text_sampling_with_context_list,
naive_triplet_text_sampling,
random_shuffle_subset,
joint_triplet_text_sampling,
triplet_style_text_sampling,
)
from dataset import multi_ref_text_triplet_text_sampling
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model_with_BERT import PromptedTransducer
from model_with_BERT_with_style import PromptedTransducer
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat, Balancer, BiasNorm, Dropout3, ScaleGrad, SwooshR
from subsampling import Conv2dSubsampling
@ -107,11 +116,6 @@ style_transforms = [
lower_all_char,
]
rare_words_file = "data/context_biasing/small_rare_words_5.txt"
with open(rare_words_file, "r") as f:
rare_words = f.read()
rare_words_list = rare_words.split("\n")
def random_sampling(texts: List[str]) -> str:
return random.choice(texts)
@ -126,18 +130,6 @@ def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
}
return out
def joint_random_sampling_mixed_recog(texts: List[str], pre_texts: List[str]) -> dict:
# Randomly choose from the ground truth (mixed-cased trans) and the recog_text
i = random.randint(0, 1)
trans = style_transforms[i]
out = {
"text": trans(texts[0]),
"pre_text": trans(pre_texts[0]),
"style_text": "",
"transform_ids": i,
}
return out
def get_first(texts: List[str], pre_texts: List[str]) -> dict:
out = {
"text": texts[0],
@ -153,159 +145,7 @@ def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
"text": upper_only_alpha(texts[0]),
"pre_text": upper_only_alpha(pre_texts[0]),
"style_text": "",
"transform_ids": 1,
}
return out
def get_upper_only_alpha_with_multiple_ref_texts(texts: List[str], pre_texts: List[str]) -> dict:
# Choose between the first and the last one in texts (gt and decoding results)
# But return the upper_only_alpha version
i = random.sample([0,2], 1)[0]
out = {
"text": upper_only_alpha(texts[i]), # either the first or the last
"pre_text": upper_only_alpha(pre_texts[0]),
"style_text": "",
"transform_ids": i,
}
return out
def get_upper_only_alpha_with_multiple_pre_texts(texts: List[str], pre_texts: List[str]) -> dict:
# With probability 0.5, use the decoded output (texts[2]) as the pre_text
# instead of pre_texts[0]; always return the upper_only_alpha version
v = random.random()
if v < 0.5: # The normal case
out = {
"text": upper_only_alpha(texts[0]),
"pre_text": upper_only_alpha(pre_texts[0]),
"style_text": "",
"transform_ids": 0,
}
else: # Use the decoded output as pre_text
out = {
"text": upper_only_alpha(texts[0]),
"pre_text": upper_only_alpha(texts[2]),
"style_text": "",
"transform_ids": 1,
}
return out
def get_upper_only_alpha_with_random_ref_text(texts: List[str], pre_texts: List[str]) -> str:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# For a small proportion of the time, use a substring of ref_text as the pre_text
text = upper_only_alpha(texts[0])
if random.random() < 0.1:
if random.random() < 0.5:
pre_text = get_substring(text, min_len=15, max_len=80)
else:
pre_text = text.split()
random.shuffle(pre_text) # shuffle the words
i = random.randint(5, 20) # random sample the number of words to be included
pre_text = " ".join(pre_text[:i])
else:
pre_text = upper_only_alpha(pre_texts[0])
out = {
"text": text,
"pre_text": pre_text,
"style_text": "",
"transform_ids": 1,
}
return out
def get_upper_only_alpha_with_random_ref_text_v2(
texts: List[str],
pre_texts: List[str],
) -> dict:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# For a small proportion of the time, use a substring of ref_text as the pre_text
text = upper_only_alpha(texts[0])
if random.random() < 0.5 and len(text.split()) > 8:
v = random.random()
if v < 0.4: # Use phrases from ref_text as content prompt
splitted = text.split()
num_phrases = numpy.random.randint(3) + 1 # 1 to 3 context phrases
start_pos = numpy.random.choice(len(splitted) - 3, num_phrases, replace=False)
phrases = [" ".join(splitted[start_pos[i]: start_pos[i]+random.randint(0,4) + 1]) for i in range(num_phrases)]
num_distractors = random.randint(0,60)
distractors = random.sample(rare_words_list, num_distractors)
phrases += distractors
random.shuffle(phrases)
pre_text = " ".join(phrases)
elif v < 0.8: # Use random discrete words
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 8))
splitted = list(numpy.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(0,60)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the word list
pre_text = " ".join(splitted)
else:
pre_text = get_substring(text, min_len=40, max_len=120)
else:
pre_text = pre_texts[0]
out = {
"text": text,
"pre_text": upper_only_alpha(pre_text),
"style_text": "",
"transform_ids": 1,
}
return out
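# The length-weighted word sampling used above, in isolation (a sketch; the
# 1.2 exponent is taken from this file, the example sentence is made up):
#
#   import numpy
#   splitted = "SHE SELLS SEA SHELLS BY THE SEASHORE".split()
#   weights = [len(w) ** 1.2 for w in splitted]
#   probs = [w / sum(weights) for w in weights]
#   picked = list(numpy.random.choice(splitted, 3, p=probs))
#   # longer words such as "SEASHORE" are proportionally more likely to be drawn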
def get_upper_only_alpha_with_context_list(
texts: List[str],
pre_texts: List[str],
context_list: str,
) -> dict:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# For a small proportion of the time, use a substring of ref_text as the pre_text
text = upper_only_alpha(texts[0])
if context_list != "":
if random.random() < 0.5:
# correct + distractors
# sample distractors
num_distractors = random.randint(0, 50)
distractors = random.sample(rare_words_list, num_distractors)
# sample correct
correct = context_list.split()
i = random.randint(1, len(correct))
correct = random.sample(correct, i)
# combine correct and distractors
pre_text = distractors + correct
random.shuffle(pre_text)
pre_text = " ".join(pre_text)
else:
pre_text = upper_only_alpha(pre_texts[0])
else:
v = random.random()
if v < 0.1:
splitted = text.split()
random.shuffle(splitted)
i = random.randint(5, 20)
splitted = splitted[:i]
pre_text = " ".join(splitted)
elif v < 0.2:
# full distractors
num_distractors = random.randint(5, 100)
distractors = random.sample(rare_words_list, num_distractors)
pre_text = " ".join(distractors)
elif v < 0.3:
pre_text = get_substring(text, min_len=15, max_len=80)
else:
pre_text = upper_only_alpha(pre_texts[0])
out = {
"text": text,
"pre_text": pre_text,
"style_text": "",
"transform_ids": 1,
"transform_ids": 0,
}
return out
@ -373,13 +213,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="By which probability, dropout the memory when doing cross-attention."
)
parser.add_argument(
"--memory-dim",
type=int,
default=768,
help="The embedding dimension of the text encoder"
)
parser.add_argument(
"--memory-layer",
type=int,
@ -430,14 +263,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
"a single int or comma-separated list.",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--decoder-dim",
@ -455,6 +280,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
to this dimension before adding.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--causal",
@ -484,7 +317,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--text-encoder-type",
type=str,
default="BERT",
choices=["BERT","DistilBERT","BERT-UNCASED", "BERT-LARGE-UNCASED"],
choices=["BERT","DistilBERT"],
help="Type of the text encoder",
)
@ -733,6 +566,18 @@ def get_parser():
default=0.05,
help="The probability of masking prompts",
)
parser.add_argument(
"--freeze-text-encoder",
type=str2bool,
default=True,
)
parser.add_argument(
"--forced-upper-pre-text",
type=str2bool,
default=False,
help="Forced format of pre-text",
)
add_model_arguments(parser)
@ -825,30 +670,89 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
)
return encoder_embed
class TextEmbedding(nn.Module):
def __init__(
self,
num_embeddings: int=256,
embedding_dim: int=256,
kernel_size: int=3,
layer1_channels: int = 256,
layer2_channels: int = 256,
bias: bool=True,
dropout: float = 0.1
):
super().__init__()
self.embed = nn.Embedding(
num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes
embedding_dim=embedding_dim,
)
assert embedding_dim == layer1_channels # for depthwise convolution
self.conv = nn.Sequential(
nn.Conv1d(
embedding_dim,
layer1_channels, # depthwise convolution
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=layer1_channels,
bias=True,
),
ScaleGrad(0.2),
Balancer(layer1_channels, channel_dim=1, min_positive=0.1, max_abs=1.0),
nn.ReLU(),
nn.Conv1d(
layer1_channels,
layer2_channels,
kernel_size=1, # pointwise convolution
stride=1,
padding=0,
bias=True,
),
Balancer(layer2_channels, channel_dim=1, min_positive=0.1, max_abs=1.0),
nn.ReLU(),
)
self.out_norm = BiasNorm(layer2_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(self, text: torch.Tensor) -> torch.Tensor:
"""Forward function of the text embedding
Args:
text (torch.Tensor): Text in UTF-8 bytes (T,N)
Returns:
The embeddings of text (T,N,C)
"""
text = self.embed(text) # (T,N,C)
#src = text
text = text.permute(1,2,0) # (T,N,C) -> (N,C,T)
text = self.conv(text)
text = text.permute(2,0,1) # (N,C,T) -> (T,N,C)
#src = src + text
text = self.out_norm(text)
text = self.dropout(text)
return text
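# Shape check for TextEmbedding (a minimal sketch with toy sizes; it assumes
# the scaling.py helpers used above preserve the (N, C, T) shape, as they do
# in icefall):
#
#   import torch
#   embed = TextEmbedding(num_embeddings=256, embedding_dim=256)
#   text = torch.randint(0, 256, (17, 4))  # (T, N) UTF-8 byte ids
#   out = embed(text)                      # (T, N, 256)
#   assert out.shape == (17, 4, 256)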
def get_text_encoder(params: AttributeDict) -> nn.Module:
# Return a text encoder
if params.text_encoder_type == "BERT": # This is a BERT-base-cased
from transformers import BertModel
assert params.memory_dim == 768
if params.text_encoder_type == "BERT":
from transformers import BertModel
# This is a BERT-base-cased
logging.info("Loading pre-trained BERT-base-cased as text encoder")
model = BertModel.from_pretrained("bert-base-cased")
elif params.text_encoder_type == "BERT-UNCASED": # This is a BERT-base-uncased
from transformers import BertModel
assert params.memory_dim == 768
logging.info("Loading pre-trained BERT-base-uncased as text encoder")
model = BertModel.from_pretrained("bert-base-uncased")
elif params.text_encoder_type == "BERT-LARGE-UNCASED": # This is a BERT-large-uncased
from transformers import BertModel
assert params.memory_dim == 1024
logging.info("Loading pre-trained BERT-large-uncased as text encoder")
model = BertModel.from_pretrained("bert-large-uncased")
elif params.text_encoder_type == "DistilBERT": # This is a DistilBERT-base-cased
elif params.text_encoder_type == "DistilBERT":
from transformers import DistilBertModel
assert params.memory_dim == 768
# This is a DistilBERT-base-cased
logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
model = DistilBertModel.from_pretrained("distilbert-base-cased")
else:
raise ValueError(f"Unknown text encoder type: {params.text_encoder_type}")
raise ValueError()
return model
@ -858,14 +762,6 @@ def get_tokenizer(params: AttributeDict):
from transformers import BertTokenizer
# This is a BERT-base-cased
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
elif params.text_encoder_type == "BERT-UNCASED":
from transformers import BertTokenizer
# This is a BERT-base-uncased
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
elif params.text_encoder_type == "BERT-LARGE-UNCASED":
from transformers import BertTokenizer
# This is a BERT-large-uncased
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
elif params.text_encoder_type == "DistilBERT":
from transformers import DistilBertTokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
@ -893,7 +789,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
causal=params.causal,
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
memory_dim=params.memory_dim, # This is fixed as the BERT base model is 768-D
memory_dim=768, # This is fixed as the BERT base model is 768-D
memory_layer=params.memory_layer,
memory_dropout_rate=params.memory_dropout_rate,
)
@ -934,7 +830,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
if params.context_injection:
from context_fuser import ContextFuser, SelfAttContextFuser
context_fuser = SelfAttContextFuser(
embed_dim=params.memory_dim,
embed_dim=768,
nhead=4,
context_dropout_rate=params.context_dropout_rate,
)
@ -1082,6 +978,50 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def _encode_texts_as_bytes_with_tokenizer(
pre_texts: List[str],
style_texts: List[str],
tokenizer,
device: torch.device,
max_len: int=500,
no_limit: bool=False
) -> Tuple[Dict, Tensor]:
"""
Tokenize the texts into integer tensors with the given tokenizer.
Note that the style text will be prepended to the pre_text, separated by ' [SEP] '.
"""
batch_size = len(pre_texts)
max_len = min(max_len, 500)
if no_limit:
allowed_lens = [5000 - len(s) for s in style_texts]
else:
allowed_lens = [1000 - len(s) for s in style_texts]
truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
encoded_style_texts = tokenizer(
style_texts,
return_tensors='pt',
padding=True,
truncation=True,
return_length=True,
max_length=max_len,
)
style_lens = encoded_style_texts["length"].to(device)
# Use tokenizer to prepare input for text encoder
encoded_inputs = tokenizer(
combined_text,
return_tensors='pt',
padding=True,
truncation=True,
return_length=True,
max_length=max_len,
).to(device)
return encoded_inputs, style_lens
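# Usage sketch for the helper above (toy inputs; the tokenizer comes from
# get_tokenizer(params) elsewhere in this script):
#
#   encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
#       pre_texts=["previously decoded context goes here"],
#       style_texts=["Mixed-case style, with punctuation."],
#       tokenizer=tokenizer,
#       device=torch.device("cpu"),
#   )
#   # encoded_inputs holds input_ids / attention_mask for the combined
#   # "style [SEP] pre_text" strings; style_lens marks the style span only.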
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
@ -1117,6 +1057,7 @@ def compute_loss(
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
batch_size = feature.size(0)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
@ -1137,13 +1078,16 @@ def compute_loss(
y = sp.encode(texts, out_type=int) # sp.encode treats consecutive spaces as a single space
y = k2.RaggedTensor(y).to(device)
if params.forced_upper_pre_text:
pre_texts = [upper_only_alpha(p) for p in pre_texts]
# only shuffle the pre_text and style texts during training, when using the style prompt
if is_training:
# randomly shuffle&mask the pre_text
pre_texts = random_shuffle_subset(
pre_texts,
p=params.pre_text_shuffle_prob,
p_mask=params.prompt_mask_prob
p_mask=params.prompt_mask_prob,
)
if params.use_style_prompt:
@ -1156,28 +1100,26 @@ def compute_loss(
p_mask=params.prompt_mask_prob
)
assert len(transform_ids) == len(style_texts)
for i in range(len(style_texts)):
t = transform_ids[i] # get the transform id
style_texts[i] = style_transforms[t](style_texts[i])
if not params.use_style_prompt:
style_texts = ["" for _ in style_texts] # use empty string for style texts if don't use style prompt
if random.random() < 0.01:
logging.info(f"Pre_texts: {pre_texts[0]}")
if random.random() < 0.05:
logging.info(f"Pre texts: {pre_texts[0]}")
logging.info(f"Ref texts: {texts[0]}")
logging.info(f"Style texts: {style_texts[0]}")
# Use tokenizer to prepare input for text encoder
encoded_inputs = tokenizer(
pre_texts,
return_tensors='pt',
padding=True,
truncation=True,
max_length=min(500, max(supervisions["num_frames"])//4,),
).to(device)
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
pre_texts=pre_texts,
style_texts=style_texts,
tokenizer=tokenizer,
device=device,
)
if random.random() < 0.02:
logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
@ -1187,6 +1129,7 @@ def compute_loss(
x=feature,
x_lens=feature_lens,
encoded_inputs=encoded_inputs,
style_lens=style_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
@ -1559,9 +1502,15 @@ def run(rank, world_size, args):
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
if params.freeze_text_encoder:
freeze_modules = ["text_encoder"]
logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
else:
freeze_modules = []
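# freeze_modules tells get_parameter_groups_with_lrs to leave the text
# encoder's parameters out of the optimizer. A plain-PyTorch equivalent
# (a sketch, not what this script does) would be:
#
#   for p in model.text_encoder.parameters():
#       p.requires_grad_(False)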
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True
model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules
),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
@ -1582,6 +1531,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
args.max_duration = 100
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
@ -1637,7 +1587,7 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
text_sampling_func = get_upper_only_alpha_with_multiple_pre_texts
text_sampling_func = triplet_text_sampling
logging.info(f"Text sampling: {text_sampling_func}")
train_dl = libriheavy.train_dataloaders(
@ -1650,7 +1600,7 @@ def run(rank, world_size, args):
valid_cuts = libriheavy.dev_cuts()
valid_dl = libriheavy.valid_dataloaders(
valid_cuts,
text_sampling_func=text_sampling_func
text_sampling_func=naive_triplet_text_sampling
)
# if not params.print_diagnostics:

File diff suppressed because it is too large