This commit is contained in:
marcoyang1998 2023-09-08 10:00:00 +08:00
parent 522273f97e
commit 013cafdd6d
2 changed files with 44 additions and 15 deletions

View File

@@ -61,14 +61,15 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriHeavyAsrDataModule
from dataset import (
from dataset2 import (
triplet_text_sampling,
multi_ref_text_triplet_text_sampling,
triplet_text_sampling_with_context_list,
naive_triplet_text_sampling,
random_shuffle_subset,
joint_triplet_text_sampling,
triplet_style_text_sampling,
)
from dataset import multi_ref_text_triplet_text_sampling
from decoder import Decoder
from joiner import Joiner
@@ -211,6 +212,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default=0.05,
help="By which probability, dropout the memory when doing cross-attention."
)
parser.add_argument(
"--memory-layer",
type=int,
default=0,
help="Start doing cross-attention from which layer. Zero-indexed"
)
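A minimal sketch of the mechanism this flag controls, assuming a simplified encoder rather than the actual Zipformer: layers whose index is at or above memory_layer cross-attend to the text-encoder memory, earlier layers do not.

import torch.nn as nn

class ToyLayer(nn.Module):
    # Hedged illustration: self-attention everywhere, cross-attention only
    # when this layer's index is >= memory_layer.
    def __init__(self, dim: int, use_memory: bool):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.cross_attn = (
            nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
            if use_memory
            else None
        )

    def forward(self, x, memory=None):
        x = x + self.self_attn(x, x, x, need_weights=False)[0]
        if self.cross_attn is not None and memory is not None:
            x = x + self.cross_attn(x, memory, memory, need_weights=False)[0]
        return x

memory_layer = 0  # the default above: cross-attend from the first layer
layers = nn.ModuleList(ToyLayer(256, use_memory=(i >= memory_layer)) for i in range(6))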
parser.add_argument(
"--query-head-dim",
@@ -558,6 +566,11 @@ def get_parser():
default=0.05,
help="The probability of masking prompts",
)
parser.add_argument(
"--freeze-text-encoder",
type=str2bool,
default=True,
help="If True, freeze the parameters of the text encoder and exclude them from the optimizer.",
)
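str2bool itself is not defined in this diff; it is the argparse helper shipped with icefall (icefall.utils.str2bool). A minimal stand-in sketch of the expected behavior:

def str2bool(v: str) -> bool:
    # Stand-in for icefall.utils.str2bool; accepts the usual spellings.
    if v.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if v.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"Not a boolean string: {v!r}")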
parser.add_argument(
"--forced-upper-pre-text",
@@ -731,12 +744,12 @@ def get_text_encoder(params: AttributeDict) -> nn.Module:
if params.text_encoder_type == "BERT":
from transformers import BertModel
# This is a BERT-base-cased
logging.info("Loading pre-trained BERT-base-cased as text encaoder")
logging.info("Loading pre-trained BERT-base-cased as text encoder")
model = BertModel.from_pretrained("bert-base-cased")
elif params.text_encoder_type == "DistilBERT":
from transformers import DistilBertModel
# This is a DistilBERT-base-cased
logging.info("Loading pre-trained DistilBERT-base-cased as text encaoder")
logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
model = DistilBertModel.from_pretrained("distilbert-base-cased")
else:
raise ValueError()
@@ -777,6 +790,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
memory_dim=768, # This is fixed as the BERT base model is 768-D
memory_layer=params.memory_layer,
memory_dropout_rate=params.memory_dropout_rate,
)
return encoder
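memory_dim=768 matches the hidden size of both text encoders selected earlier. A quick hedged check via the standard transformers config API (fetches the configs from the Hub):

from transformers import BertConfig, DistilBertConfig

# BERT-base calls this hidden_size; DistilBERT's config calls it dim.
assert BertConfig.from_pretrained("bert-base-cased").hidden_size == 768
assert DistilBertConfig.from_pretrained("distilbert-base-cased").dim == 768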
@@ -970,14 +984,19 @@ def _encode_texts_as_bytes_with_tokenizer(
tokenizer,
device: torch.device,
max_len: int=500,
no_limit: bool=False
) -> Tuple[Dict, Tensor]:
"""
Encode texts as bytes and then integer tensors.
Note that the style text will be added to the beginning of texts.
"""
batch_size = len(pre_texts)
max_len = min(max_len, 500)
allowed_lens = [1000 - len(s) for s in style_texts]
if no_limit:
allowed_lens = [5000 - len(s) for s in style_texts]
else:
allowed_lens = [1000 - len(s) for s in style_texts]
truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
@@ -987,7 +1006,7 @@ def _encode_texts_as_bytes_with_tokenizer(
padding=True,
truncation=True,
return_length=True,
max_length=500,
max_length=max_len,
)
style_lens = encoded_style_texts["length"].to(device)
@@ -998,7 +1017,7 @@ def _encode_texts_as_bytes_with_tokenizer(
padding=True,
truncation=True,
return_length=True,
max_length=500,
max_length=max_len,
).to(device)
return encoded_inputs, style_lens
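A hedged illustration of the truncation budget introduced above: with no_limit=True the pre-text keeps its last 5000 - len(style_text) characters instead of 1000 - len(style_text), while max_len stays clamped to at most 500 for the tokenizer.

style_text = "Mixed-case style prompt, with punctuation."
pre_text = "x" * 6000
for no_limit, cap in ((False, 1000), (True, 5000)):
    allowed = cap - len(style_text)
    truncated = pre_text[-allowed:]  # keep the most recent characters
    combined = style_text + " [SEP] " + truncated
    assert len(truncated) == allowed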
@@ -1090,8 +1109,8 @@ def compute_loss(
if not params.use_style_prompt:
style_texts = ["" for _ in style_texts] # use empty string for style texts if don't use style prompt
if random.random() < 0.03:
logging.info(f"Pre_texts: {pre_texts[0]}")
if random.random() < 0.05:
logging.info(f"Pre texts: {pre_texts[0]}")
logging.info(f"Ref texts: {texts[0]}")
logging.info(f"Style texts: {style_texts[0]}")
@@ -1483,9 +1502,15 @@ def run(rank, world_size, args):
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
if params.freeze_text_encoder:
freeze_modules = ["text_encoder"]
logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
else:
freeze_modules = []
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True
model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules
),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
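The internals of get_parameter_groups_with_lrs are not part of this diff (it lives in icefall's optim.py); a hedged sketch of the freeze_modules filtering being relied on here:

import torch.nn as nn

def param_groups_excluding(model: nn.Module, lr: float, freeze_modules) -> list:
    # Assumed behavior: skip, and stop gradients for, any parameter whose
    # name starts with a frozen module prefix such as "text_encoder".
    groups = []
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in freeze_modules):
            param.requires_grad_(False)
            continue
        groups.append({"params": [param], "lr": lr, "names": [name]})
    return groups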
@@ -1506,6 +1531,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
args.max_duration = 100
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module

View File

@@ -61,8 +61,8 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriHeavyAsrDataModule
from dataset import triplet_text_sampling, naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
from dataset2 import triplet_text_sampling_with_context_list
from dataset import naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
from dataset2 import triplet_text_sampling, triplet_text_sampling_with_context_list
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -1093,9 +1093,8 @@ def _encode_text_as_tokens(
style_texts: List[str],
bpe_model: spm.SentencePieceProcessor,
device: torch.device,
max_tokens: int=500,
max_tokens: int=800,
) -> Tuple[Tensor, Tensor]:
max_tokens = min(500, max_tokens)
batch_size = len(pre_texts)
# encoded style texts
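A hedged illustration of why the clamp above was removed: min(500, max_tokens) silently capped the prompt at 500 BPE tokens, so the new default of 800 only takes effect without it.

max_tokens = 800
old_cap = min(500, max_tokens)  # the removed clamp: always <= 500
new_cap = max_tokens            # after this commit: honors 800
assert (old_cap, new_cap) == (500, 800)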
@@ -1216,6 +1215,7 @@ def compute_loss(
feature = feature.to(device)
supervisions = batch["supervisions"]
cut_ids = [c.id for c in supervisions["cut"]]
feature_lens = supervisions["num_frames"].to(device)
batch_idx_train = params.batch_idx_train
@@ -1266,6 +1266,9 @@ def compute_loss(
logging.info(f"Pre texts: {pre_texts[0]}")
logging.info(f"Ref texts: {texts[0]}")
logging.info(f"Style texts: {style_texts[0]}")
orig_style_texts = batch["supervisions"]["style_text"]
logging.info(f"Orig style texts: {orig_style_texts[0]}")
logging.info(f"Cut ID: {cut_ids[0]}")
pre_texts, pre_texts_lens, style_text_lens = _encode_text_as_tokens(
pre_texts=pre_texts,
@@ -1752,7 +1755,7 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
text_sampling_func = triplet_text_sampling_with_context_list
text_sampling_func = triplet_text_sampling
logging.info(f"Text sampling: {text_sampling_func}")
train_dl = libriheavy.train_dataloaders(