remove/rename files

marcoyang1998 2023-09-15 10:54:58 +08:00
parent 2f4eb18466
commit ae2c7c73f6
9 changed files with 202 additions and 8050 deletions

File diff suppressed because it is too large

View File

@ -31,10 +31,6 @@ from typing import Optional, Tuple, Dict
 class PromptedTransducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
-    Note that this is a PromptedTransducer, meaning that the transducer is able to decode
-    with prompts.
-    It has a text encoder of BERT type model.
-    This transducer also has a special context fuser.
     """

     def __init__(
@ -105,13 +101,8 @@ class PromptedTransducer(nn.Module):
         self.use_BERT = use_BERT  # if the text encoder is a pre-trained BERT
         self.context_fuser = context_fuser

-        assert text_encoder_type in (
-            "BERT",
-            "DistilBERT",
-            "BERT-UNCASED",
-            "BERT-LARGE-UNCASED",
-        ), f"Unseen text_encoder type {text_encoder_type}"
-        self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED", "BERT-LARGE-UNCASED") else self.text_encoder.config.dim
+        assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
+        self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim

         if text_encoder_adapter:
             self.text_encoder_adapter = nn.Sequential(
@ -121,11 +112,14 @@ class PromptedTransducer(nn.Module):
         else:
             self.text_encoder_adapter = None

+        self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
+
     def forward(
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
         encoded_inputs: Dict,
+        style_lens: torch.Tensor,
         y: k2.RaggedTensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
@ -189,7 +183,10 @@ class PromptedTransducer(nn.Module):
         # freeze the BERT text encoder
         if use_pre_text:
-            memory, memory_key_padding_mask = self.encode_text(encoded_inputs)
+            memory, memory_key_padding_mask = self.encode_text(
+                encoded_inputs,
+                style_lens=style_lens
+            )
         else:
             memory = None
             memory_key_padding_mask = None
@ -279,12 +276,7 @@ class PromptedTransducer(nn.Module):
         else:
             context = None

-        logits = self.joiner(
-            am_pruned,
-            lm_pruned,
-            context=context,
-            project_input=False,
-        )
+        logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)

         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
@ -305,8 +297,8 @@ class PromptedTransducer(nn.Module):
         scale of the embedding vector can adjust to compensate.

         Args:
             memory: (memory_len, batch_size, embed_dim)
             style_lens: (batch_size,), a vector of lengths of the style prompt.
         """
         (memory_len, batch_size, embed_dim) = memory.shape
@ -318,13 +310,14 @@ class PromptedTransducer(nn.Module):
         indicator = indicator.to(memory.dtype)

         extra_term = torch.zeros_like(memory)
-        extra_term[..., 0] += indicator
+        extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)

         return memory + extra_term

     def encode_text(
         self,
         encoded_inputs: Dict,
+        style_lens: Tensor,
     ) -> Tuple[Tensor, Tensor]:
         """Get the embeddings of text
@ -335,18 +328,21 @@ class PromptedTransducer(nn.Module):
             Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
                 text_encoder and the attention mask
         """
-        text_lens = encoded_inputs["attention_mask"].sum(1)
+        text_lens = encoded_inputs.pop("length")  # need to use pop to remove this item

         # Freeze the pre-trained text encoder
         with torch.no_grad():
             memory = self.text_encoder(**encoded_inputs)["last_hidden_state"]  # (B,T,C)
             memory = memory.permute(1, 0, 2)
-        memory_key_padding_mask = make_pad_mask(text_lens)

         # Text encoder adapter
         if self.text_encoder_adapter is not None:
             memory = self.text_encoder_adapter(memory)

+        memory = self._add_style_indicator(memory, style_lens)
+
+        memory_key_padding_mask = make_pad_mask(text_lens)
+
         return memory, memory_key_padding_mask

     def encode_audio(
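
For reference, the new style-indicator logic above can be exercised in isolation; a minimal sketch (shapes follow the docstring, sizes and values are made up for illustration):

import torch

# memory is (memory_len, batch, embed_dim); style_lens is (batch,), as in the docstring above.
memory_len, batch, embed_dim = 6, 2, 4
memory = torch.zeros(memory_len, batch, embed_dim)
style_lens = torch.tensor([2, 4])
style_prompt_embedding = torch.full((embed_dim,), 0.5)  # mirrors the new nn.Parameter

# 1.0 for frames that belong to the style prompt, 0.0 elsewhere
indicator = (torch.arange(memory_len).unsqueeze(-1) < style_lens).to(memory.dtype)
extra_term = indicator.unsqueeze(-1) * style_prompt_embedding.expand(memory_len, batch, embed_dim)
marked = memory + extra_term
print(marked[:, 0, 0])  # tensor([0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000])

The learned style_prompt_embedding is simply added to every frame of the style prompt, so the text encoder output carries an explicit "this is style, not content" signal into the cross-attention.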

View File

@ -1,382 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
import torch.nn as nn
import random
import warnings
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import penalize_abs_values_gt, ScaledLinear
from torch import Tensor
from typing import Optional, Tuple, Dict
class PromptedTransducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
text_encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
use_BERT: bool = True,
text_encoder_type: str = "BERT",
text_encoder_adapter: bool = False,
context_fuser: nn.Module = None,
):
"""
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
text_encoder:
This is an encoder that processes text information (e.g. content prompts
and style prompts). The input is `x` of (N, T) and `x_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
text_encoder_type:
The type of the text_encoder. Supported types are BERT, DistilBERT and BERT-UNCASED.
context_fuser:
An optional module that fuses the embeddings of the text encoder. The fused
embedding will be added to the joiner.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder_embed = encoder_embed
self.encoder = encoder
self.text_encoder = text_encoder
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim,
vocab_size,
initial_scale=0.25,
)
self.simple_lm_proj = ScaledLinear(
decoder_dim,
vocab_size,
initial_scale=0.25,
)
self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT
self.context_fuser = context_fuser
assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim
if text_encoder_adapter:
self.text_encoder_adapter = nn.Sequential(
nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False),
nn.Tanh(),
)
else:
self.text_encoder_adapter = None
self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
encoded_inputs: Dict,
style_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
use_pre_text: bool = True,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
text:
A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
It is expected to contain the style prompt (first) and then the content
prompt.
text_lens:
A 1-D tensor of shape (N,). It contains the number of elements (bytes)
in `text` before padding, which will include the lengths of the
style plus the content prompt.
style_lens:
A 1-D tensor of shape (N,), containing the number of elements (bytes)
within each row of `text` that correspond to the style prompt (these
are expected to come first).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we consider for each frame when computing the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# freeze the BERT text encoder
if use_pre_text:
memory, memory_key_padding_mask = self.encode_text(
encoded_inputs,
style_lens=style_lens
)
else:
memory = None
memory_key_padding_mask = None
encoder_out, x_lens = self.encoder(
x,
x_lens,
src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
if self.context_fuser is not None and memory is not None:
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
context = self.joiner.context_proj(context)
else:
context = None
logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
"""
Adds to `memory` an indicator that is 1.0 for positions that correspond to
the `style prompt` and 0 elsewhere. The scale can be fixed because the
scale of the embedding vector can adjust to compensate.
Args:
memory: (memory_len, batch_size, embed_dim)
style_lens: (batch_size,), a vector of lengths of the style prompt.
"""
(memory_len, batch_size, embed_dim) = memory.shape
indicator = (
torch.arange(memory_len, device=memory.device).unsqueeze(-1)
< style_lens
)
indicator = indicator.to(memory.dtype)
extra_term = torch.zeros_like(memory)
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
return memory + extra_term
def encode_text(
self,
encoded_inputs: Dict,
style_lens: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Get the embeddings of text
Args:
encoded_inputs: The encoded inputs generated by a tokenizer (Dict)
Returns:
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
text_encoder and the attention mask
"""
text_lens = encoded_inputs.pop("length") # need to use pop to remove this item
# Freeze the pre-trained text encoder
with torch.no_grad():
memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C)
memory = memory.permute(1,0,2)
# Text encoder adapter
if self.text_encoder_adapter is not None:
memory = self.text_encoder_adapter(memory)
memory = self._add_style_indicator(memory, style_lens)
memory_key_padding_mask = make_pad_mask(text_lens)
return memory, memory_key_padding_mask
def encode_audio(
self,
feature: Tensor,
feature_lens: Tensor,
memory: Optional[Tensor],
memory_key_padding_mask: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
"""Encode the input audio features
Args:
feature (Tensor): Input audio (N,T,C)
feature_lens (Tensor): Length of input audio (N,)
memory (Tensor): Embeddings from the text encoder
memory_key_padding_mask (Tensor): The padding mask of `memory` (True for padded positions)
Returns:
Tuple[Tensor, Tensor]: The encoder output (N,T,C) and its output lengths (N,)
"""
x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(
x=x,
x_lens=x_lens,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
Transducer = PromptedTransducer # for decoding
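
The forward() method above follows k2's two-pass pruned RNN-T recipe: a smoothed full-sum loss is computed first, and its gradients select the (T, s_range) region on which the pruned loss is evaluated. A minimal self-contained sketch of that first pass, with random tensors standing in for the projected encoder/decoder outputs (all sizes here are made up):

import k2
import torch

B, T, S, V, prune_range, blank_id = 2, 20, 5, 10, 3, 0
am = torch.randn(B, T, V)       # stands in for simple_am_proj(encoder_out)
lm = torch.randn(B, S + 1, V)   # stands in for simple_lm_proj(decoder_out)
symbols = torch.randint(1, V, (B, S))
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S  # y_lens
boundary[:, 3] = T  # x_lens

simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm.float(), am=am.float(), symbols=symbols,
    termination_symbol=blank_id, lm_only_scale=0.0, am_only_scale=0.0,
    boundary=boundary, reduction="sum", return_grad=True,
)
# ranges: [B, T, prune_range]; which symbol positions survive pruning per frame
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range,
)
print(simple_loss.item(), ranges.shape)

k2.do_rnnt_pruning and k2.rnnt_loss_pruned then operate on the joiner projections exactly as in the file above.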

View File

@ -61,13 +61,22 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from dataset import triplet_text_sampling, naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
+from dataset2 import (
+    triplet_text_sampling,
+    triplet_text_sampling_with_context_list,
+    naive_triplet_text_sampling,
+    random_shuffle_subset,
+    joint_triplet_text_sampling,
+    triplet_style_text_sampling,
+)
+from dataset import multi_ref_text_triplet_text_sampling
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from model_with_BERT import PromptedTransducer
+from model_with_BERT_with_style import PromptedTransducer
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat, Balancer, BiasNorm, Dropout3, ScaleGrad, SwooshR
 from subsampling import Conv2dSubsampling
@ -107,11 +116,6 @@ style_transforms = [
     lower_all_char,
 ]

-rare_words_file = "data/context_biasing/small_rare_words_5.txt"
-with open(rare_words_file, "r") as f:
-    rare_words = f.read()
-rare_words_list = rare_words.split("\n")

 def random_sampling(texts: List[str]) -> str:
     return random.choice(texts)
@ -126,18 +130,6 @@ def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
     }
     return out

-def joint_random_sampling_mixed_recog(texts: List[str], pre_texts: List[str]) -> str:
-    # Randomly choose from the ground truth (mixed-cased trans) and the recog_text
-    i = random.randint(0, 1)
-    trans = style_transforms[i]
-    out = {
-        "text": trans(texts[0]),
-        "pre_text": trans(pre_texts[0]),
-        "style_text": "",
-        "transform_ids": i,
-    }
-    return out

 def get_first(texts: List[str], pre_texts: List[str]) -> str:
     out = {
         "text": texts[0],
@ -153,159 +145,7 @@ def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
"text": upper_only_alpha(texts[0]), "text": upper_only_alpha(texts[0]),
"pre_text": upper_only_alpha(pre_texts[0]), "pre_text": upper_only_alpha(pre_texts[0]),
"style_text": "", "style_text": "",
"transform_ids": 1, "transform_ids": 0,
}
return out
def get_upper_only_alpha_with_multiple_ref_texts(texts: List[str], pre_texts: List[str]) -> str:
# Choose between the first and the last one in texts (gt and decoding results)
# But return the upper_only_alpha version
i = random.sample([0,2], 1)[0]
out = {
"text": upper_only_alpha(texts[i]), # either the first or the last
"pre_text": upper_only_alpha(pre_texts[0]),
"style_text": "",
"transform_ids": i,
}
return out
def get_upper_only_alpha_with_multiple_pre_texts(texts: List[str], pre_texts: List[str]) -> str:
# Choose between the first and the last one in texts (gt and decoding results)
# But return the upper_only_alpha version
v = random.random()
if v < 0.5: # The normal case
out = {
"text": upper_only_alpha(texts[0]),
"pre_text": upper_only_alpha(pre_texts[0]),
"style_text": "",
"transform_ids": 0,
}
else: # Use the decoded output as pre_text
out = {
"text": upper_only_alpha(texts[0]),
"pre_text": upper_only_alpha(texts[2]),
"style_text": "",
"transform_ids": 1,
}
return out
def get_upper_only_alpha_with_random_ref_text(texts: List[str], pre_texts: List[str]) -> str:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# By a small proportion of time, use the substring of ref_text as pre_text
text = upper_only_alpha(texts[0])
if random.random() < 0.1:
if random.random() < 0.5:
pre_text = get_substring(text, min_len=15, max_len=80)
else:
pre_text = text.split()
random.shuffle(pre_text) # shuffle the words
i = random.randint(5, 20) # random sample the number of words to be included
pre_text = " ".join(pre_text[:i])
else:
pre_text = upper_only_alpha(pre_texts[0])
out = {
"text": text,
"pre_text": pre_text,
"style_text": "",
"transform_ids": 1,
}
return out
def get_upper_only_alpha_with_random_ref_text_v2(
texts: List[str],
pre_texts: List[str],
) -> str:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# By a small proportion of time, use the substring of ref_text as pre_text
text = upper_only_alpha(texts[0])
if random.random() < 0.5 and len(text.split()) > 8:
v = random.random()
if v < 0.4: # Use phrases from ref_text as content prompt
splitted = text.split()
num_phrases = numpy.random.randint(3) + 1 # 1 to 3 context phrases
start_pos = numpy.random.choice(len(splitted) - 3, num_phrases, replace=False)
phrases = [" ".join(splitted[start_pos[i]: start_pos[i]+random.randint(0,4) + 1]) for i in range(num_phrases)]
num_distractors = random.randint(0,60)
distractors = random.sample(rare_words_list, num_distractors)
phrases += distractors
random.shuffle(phrases)
pre_text = " ".join(phrases)
elif v < 0.8: # Use random discrete words
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 8))
splitted = list(numpy.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(0,60)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the word list
pre_text = " ".join(splitted)
else:
pre_text = get_substring(text, min_len=40, max_len=120)
else:
pre_text = pre_texts[0]
out = {
"text": text,
"pre_text": upper_only_alpha(pre_text),
"style_text": "",
"transform_ids": 1,
}
return out
def get_upper_only_alpha_with_context_list(
texts: List[str],
pre_texts: List[str],
context_list: str,
) -> str:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# By a small proportion of time, use the substring of ref_text as pre_text
text = upper_only_alpha(texts[0])
if context_list != "":
if random.random() < 0.5:
# correct + distractors
# sample distractors
num_distractors = random.randint(0, 50)
distractors = random.sample(rare_words_list, num_distractors)
# sample correct
correct = context_list.split()
i = random.randint(1, len(correct))
correct = random.sample(correct, i)
# combine correct and distractors
pre_text = distractors + correct
random.shuffle(pre_text)
pre_text = " ".join(pre_text)
else:
pre_text = upper_only_alpha(pre_texts[0])
else:
v = random.random()
if v < 0.1:
splitted = text.split()
random.shuffle(splitted)
i = random.randint(5, 20)
splitted = splitted[:i]
pre_text = " ".join(splitted)
elif v < 0.2:
# full distractors
num_distractors = random.randint(5, 100)
distractors = random.sample(rare_words_list, num_distractors)
pre_text = " ".join(distractors)
elif v < 0.3:
pre_text = get_substring(text, min_len=15, max_len=80)
else:
pre_text = upper_only_alpha(pre_texts[0])
out = {
"text": text,
"pre_text": pre_text,
"style_text": "",
"transform_ids": 1,
    }
    return out
@ -373,13 +213,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="By which probability, dropout the memory when doing cross-attention." help="By which probability, dropout the memory when doing cross-attention."
) )
parser.add_argument(
"--memory-dim",
type=int,
default=768,
help="The embedding dimension of the text encoder"
)
parser.add_argument( parser.add_argument(
"--memory-layer", "--memory-layer",
type=int, type=int,
@ -431,14 +264,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"a single int or comma-separated list.", "a single int or comma-separated list.",
) )
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument( parser.add_argument(
"--decoder-dim", "--decoder-dim",
type=int, type=int,
@ -456,6 +281,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
""", """,
) )
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument( parser.add_argument(
"--causal", "--causal",
type=str2bool, type=str2bool,
@ -484,7 +317,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--text-encoder-type", "--text-encoder-type",
type=str, type=str,
default="BERT", default="BERT",
choices=["BERT","DistilBERT","BERT-UNCASED", "BERT-LARGE-UNCASED"], choices=["BERT","DistilBERT"],
help="Type of the text encoder", help="Type of the text encoder",
) )
@ -733,6 +566,18 @@ def get_parser():
         default=0.05,
         help="The probability of masking prompts",
     )

+    parser.add_argument(
+        "--freeze-text-encoder",
+        type=str2bool,
+        default=True,
+    )
+
+    parser.add_argument(
+        "--forced-upper-pre-text",
+        type=str2bool,
+        default=False,
+        help="Forced format of pre-text",
+    )

     add_model_arguments(parser)
@ -825,30 +670,89 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
    )
    return encoder_embed
class TextEmbedding(nn.Module):
def __init__(
self,
num_embeddings: int=256,
embedding_dim: int=256,
kernel_size: int=3,
layer1_channels: int = 256,
layer2_channels: int = 256,
bias: bool=True,
dropout: float = 0.1
):
super().__init__()
self.embed = nn.Embedding(
num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes
embedding_dim=embedding_dim, #
)
assert embedding_dim == layer1_channels # for depth wise convolution
self.conv = nn.Sequential(
nn.Conv1d(
embedding_dim,
layer1_channels, # depthwise convolution
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=layer1_channels,
bias=True,
),
ScaleGrad(0.2),
Balancer(layer1_channels, channel_dim=1, min_positive=0.1, max_abs=1.0),
nn.ReLU(),
nn.Conv1d(
layer1_channels,
layer2_channels,
kernel_size=1, # pointwise convolution
stride=1,
padding=0,
bias=True,
),
Balancer(layer2_channels, channel_dim=1, min_positive=0.1, max_abs=1.0),
nn.ReLU(),
)
self.out_norm = BiasNorm(layer2_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(self, text: torch.Tensor) -> torch.Tensor:
"""Forward function of the text embedding
Args:
text (torch.Tensor): Text in UTF-8 bytes (T,N)
Returns:
The embeddings of text (T,N,C)
"""
text = self.embed(text) # (T,N,C)
#src = text
text = text.permute(1,2,0) # (T,N,C) -> (N,C,T)
text = self.conv(text)
text = text.permute(2,0,1) # (N,C,T) -> (T,N,C)
#src = src + text
text = self.out_norm(text)
text = self.dropout(text)
return text
 def get_text_encoder(params: AttributeDict) -> nn.Module:
     # Return a text encoder
-    if params.text_encoder_type == "BERT":  # This is a BERT-base-cased
+    if params.text_encoder_type == "BERT":
         from transformers import BertModel
-        assert params.memory_dim == 768
+        # This is a BERT-base-cased
         logging.info("Loading pre-trained BERT-base-cased as text encoder")
         model = BertModel.from_pretrained("bert-base-cased")
-    elif params.text_encoder_type == "BERT-UNCASED":  # This is a BERT-base-uncased
-        from transformers import BertModel
-        assert params.memory_dim == 768
-        logging.info("Loading pre-trained BERT-base-uncased as text encoder")
-        model = BertModel.from_pretrained("bert-base-uncased")
-    elif params.text_encoder_type == "BERT-LARGE-UNCASED":  # This is a BERT-large-uncased
-        from transformers import BertModel
-        assert params.memory_dim == 1024
-        logging.info("Loading pre-trained BERT-large-uncased as text encoder")
-        model = BertModel.from_pretrained("bert-large-uncased")
-    elif params.text_encoder_type == "DistilBERT":  # This is a DistilBERT-base-cased
+    elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertModel
-        assert params.memory_dim == 768
+        # This is a DistilBERT-base-cased
         logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
         model = DistilBertModel.from_pretrained("distilbert-base-cased")
     else:
-        raise ValueError(f"Unknown text encoder type: {params.text_encoder_type}")
+        raise ValueError()
     return model
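
A quick check of the hidden sizes these branches rely on (a sketch, not part of the recipe; it downloads the HuggingFace checkpoints):

from transformers import BertModel, DistilBertModel

bert = BertModel.from_pretrained("bert-base-cased")
print(bert.config.hidden_size)   # 768, matching the hard-coded memory_dim=768 used further down

distilbert = DistilBertModel.from_pretrained("distilbert-base-cased")
print(distilbert.config.dim)     # also 768, but exposed under a different config attribute

This is why text_encoder_dim is read from config.hidden_size for BERT but from config.dim for DistilBERT in the model file.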
@ -858,14 +762,6 @@ def get_tokenizer(params: AttributeDict):
         from transformers import BertTokenizer
         # This is a BERT-base-cased
         tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
-    elif params.text_encoder_type == "BERT-UNCASED":
-        from transformers import BertTokenizer
-        # This is a BERT-base-uncased
-        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-    elif params.text_encoder_type == "BERT-LARGE-UNCASED":
-        from transformers import BertTokenizer
-        # This is a BERT-base-uncased
-        tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
     elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertTokenizer
         tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
@ -893,7 +789,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
-        memory_dim=params.memory_dim,
+        memory_dim=768,  # This is fixed as the BERT base model is 768-D
         memory_layer=params.memory_layer,
         memory_dropout_rate=params.memory_dropout_rate,
     )
@ -934,7 +830,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     if params.context_injection:
         from context_fuser import ContextFuser, SelfAttContextFuser

         context_fuser = SelfAttContextFuser(
-            embed_dim=params.memory_dim,
+            embed_dim=768,
             nhead=4,
             context_dropout_rate=params.context_dropout_rate,
         )
@ -1082,6 +978,50 @@ def save_checkpoint(
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)
def _encode_texts_as_bytes_with_tokenizer(
pre_texts: List[str],
style_texts: List[str],
tokenizer,
device: torch.device,
max_len: int=500,
no_limit: bool=False
) -> Tuple[Dict, Tensor]:
"""
Encode texts as bytes and then integer tensors.
Note that the style text will be added to the beginning of texts.
"""
batch_size = len(pre_texts)
max_len = min(max_len, 500)
if no_limit:
allowed_lens = [5000 - len(s) for s in style_texts]
else:
allowed_lens = [1000 - len(s) for s in style_texts]
truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
encoded_style_texts = tokenizer(
style_texts,
return_tensors='pt',
padding=True,
truncation=True,
return_length=True,
max_length=max_len,
)
style_lens = encoded_style_texts["length"].to(device)
# Use tokenizer to prepare input for text encoder
encoded_inputs = tokenizer(
combined_text,
return_tensors='pt',
padding=True,
truncation=True,
return_length=True,
max_length=max_len,
).to(device)
return encoded_inputs, style_lens
def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
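
A standalone sketch of what the _encode_texts_as_bytes_with_tokenizer helper above produces (the tokenizer name matches get_tokenizer for the BERT case; the texts are made up, and despite the helper's name the encoding is done with the HuggingFace tokenizer, not raw bytes):

import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

style_texts = ["Mixed-case style prompt, with punctuation."]
pre_texts = ["UPPER CASE CONTENT PROMPT WITHOUT PUNCTUATION"]

# style text first, then ' [SEP] ' and the (truncated) content prompt
combined_text = [s + " [SEP] " + p for s, p in zip(style_texts, pre_texts)]

encoded_style = tokenizer(style_texts, return_tensors="pt", padding=True,
                          truncation=True, return_length=True, max_length=500)
encoded_inputs = tokenizer(combined_text, return_tensors="pt", padding=True,
                           truncation=True, return_length=True, max_length=500)

style_lens = encoded_style["length"]  # per-utterance token count of the style prompt
print(encoded_inputs["input_ids"].shape, style_lens)

style_lens is what the model's _add_style_indicator uses to mark which leading positions of the encoded prompt belong to the style prompt.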
@ -1117,6 +1057,7 @@ def compute_loss(
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3

     feature = feature.to(device)
+    batch_size = feature.size(0)

     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
@ -1137,13 +1078,16 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)  # sp.encode treats consecutive space as a single space
     y = k2.RaggedTensor(y).to(device)

+    if params.forced_upper_pre_text:
+        pre_texts = [upper_only_alpha(p) for p in pre_texts]
+
     # only shuffle the pre_text and style texts if during training, and use style prompt
     if is_training:
         # randomly shuffle&mask the pre_text
         pre_texts = random_shuffle_subset(
             pre_texts,
             p=params.pre_text_shuffle_prob,
-            p_mask=params.prompt_mask_prob
+            p_mask=params.prompt_mask_prob,
         )

         if params.use_style_prompt:
@ -1156,28 +1100,26 @@ def compute_loss(
                 p_mask=params.prompt_mask_prob
             )

         assert len(transform_ids) == len(style_texts)

         for i in range(len(style_texts)):
             t = transform_ids[i]  # get the transform id
             style_texts[i] = style_transforms[t](style_texts[i])

     if not params.use_style_prompt:
         style_texts = ["" for _ in style_texts]  # use empty string for style texts if don't use style prompt

-    if random.random() < 0.01:
-        logging.info(f"Pre_texts: {pre_texts[0]}")
+    if random.random() < 0.05:
+        logging.info(f"Pre texts: {pre_texts[0]}")
         logging.info(f"Ref texts: {texts[0]}")
         logging.info(f"Style texts: {style_texts[0]}")

-    # Use tokenizer to prepare input for text encoder
-    encoded_inputs = tokenizer(
-        pre_texts,
-        return_tensors='pt',
-        padding=True,
-        truncation=True,
-        max_length=min(500, max(supervisions["num_frames"])//4,),
-    ).to(device)
+    encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
+        pre_texts=pre_texts,
+        style_texts=style_texts,
+        tokenizer=tokenizer,
+        device=device,
+    )

     if random.random() < 0.02:
         logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
@ -1187,6 +1129,7 @@ def compute_loss(
             x=feature,
             x_lens=feature_lens,
             encoded_inputs=encoded_inputs,
+            style_lens=style_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
@ -1559,9 +1502,15 @@ def run(rank, world_size, args):
logging.info("Using DDP") logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True) model = DDP(model, device_ids=[rank], find_unused_parameters=True)
if params.freeze_text_encoder:
freeze_modules = ["text_encoder"]
logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
else:
freeze_modules = []
optimizer = ScaledAdam( optimizer = ScaledAdam(
get_parameter_groups_with_lrs( get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules
), ),
lr=params.base_lr, # should have no effect lr=params.base_lr, # should have no effect
clipping_scale=2.0, clipping_scale=2.0,
@ -1582,6 +1531,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
+        args.max_duration = 100
         opts = diagnostics.TensorDiagnosticOptions(
             2 ** 22
         )  # allow 4 megabytes per sub-module
@ -1637,7 +1587,7 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    text_sampling_func = get_upper_only_alpha_with_multiple_pre_texts
+    text_sampling_func = triplet_text_sampling
     logging.info(f"Text sampling: {text_sampling_func}")

     train_dl = libriheavy.train_dataloaders(
@ -1650,7 +1600,7 @@ def run(rank, world_size, args):
     valid_cuts = libriheavy.dev_cuts()
     valid_dl = libriheavy.valid_dataloaders(
         valid_cuts,
-        text_sampling_func=text_sampling_func
+        text_sampling_func=naive_triplet_text_sampling
     )

     # if not params.print_diagnostics:

File diff suppressed because it is too large