This commit is contained in:
marcoyang1998 2023-09-08 10:00:00 +08:00
parent 522273f97e
commit 013cafdd6d
2 changed files with 44 additions and 15 deletions

View File

@@ -61,14 +61,15 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from dataset import (
+from dataset2 import (
     triplet_text_sampling,
-    multi_ref_text_triplet_text_sampling,
+    triplet_text_sampling_with_context_list,
     naive_triplet_text_sampling,
     random_shuffle_subset,
     joint_triplet_text_sampling,
     triplet_style_text_sampling,
 )
+from dataset import multi_ref_text_triplet_text_sampling
 from decoder import Decoder
 from joiner import Joiner
@@ -212,6 +213,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="By which probability, dropout the memory when doing cross-attention."
     )

+    parser.add_argument(
+        "--memory-layer",
+        type=int,
+        default=0,
+        help="Start doing cross-attention from which layer. Zero-indexed"
+    )
+
     parser.add_argument(
         "--query-head-dim",
         type=str,
@@ -558,6 +566,11 @@ def get_parser():
         default=0.05,
         help="The probability of masking prompts",
     )

+    parser.add_argument(
+        "--freeze-text-encoder",
+        type=str2bool,
+        default=True,
+    )

     parser.add_argument(
         "--forced-upper-pre-text",
@@ -731,12 +744,12 @@ def get_text_encoder(params: AttributeDict) -> nn.Module:
     if params.text_encoder_type == "BERT":
         from transformers import BertModel
         # This is a BERT-base-cased
-        logging.info("Loading pre-trained BERT-base-cased as text encaoder")
+        logging.info("Loading pre-trained BERT-base-cased as text encoder")
         model = BertModel.from_pretrained("bert-base-cased")
     elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertModel
         # This is a DistilBERT-base-cased
-        logging.info("Loading pre-trained DistilBERT-base-cased as text encaoder")
+        logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
         model = DistilBertModel.from_pretrained("distilbert-base-cased")
     else:
         raise ValueError()
@@ -777,6 +790,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
         memory_dim=768,  # This is fixed as the BERT base model is 768-D
+        memory_layer=params.memory_layer,
         memory_dropout_rate=params.memory_dropout_rate,
     )
     return encoder
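
Note: both text encoders that get_text_encoder() can load are 768-dimensional, which is why memory_dim is hard-coded above. An illustrative check (not part of the commit), assuming the transformers package is installed:

# Illustrative: both supported text encoders expose 768-dim hidden states,
# matching the hard-coded memory_dim=768 above.
from transformers import BertModel, DistilBertModel

bert = BertModel.from_pretrained("bert-base-cased")
assert bert.config.hidden_size == 768

distilbert = DistilBertModel.from_pretrained("distilbert-base-cased")
assert distilbert.config.hidden_size == 768  # alias for DistilBERT's `dim`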
@@ -970,14 +984,19 @@ def _encode_texts_as_bytes_with_tokenizer(
     tokenizer,
     device: torch.device,
     max_len: int=500,
+    no_limit: bool=False
 ) -> Tuple[Dict, Tensor]:
     """
     Encode texts as bytes and then integer tensors.
     Note that the style text will be added to the beginning of texts.
     """
     batch_size = len(pre_texts)
-    max_len = min(max_len, 500)
-    allowed_lens = [1000 - len(s) for s in style_texts]
+    if no_limit:
+        allowed_lens = [5000 - len(s) for s in style_texts]
+    else:
+        allowed_lens = [1000 - len(s) for s in style_texts]
     truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
     combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
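
For reference, a self-contained sketch of the character-budget logic this hunk introduces (the budgets 1000/5000 mirror the diff; the helper name is hypothetical, not the repo's code):

# Sketch of the character-budget truncation added above.
from typing import List

def truncate_and_combine(
    pre_texts: List[str],
    style_texts: List[str],
    no_limit: bool = False,
) -> List[str]:
    # Total character budget per utterance: 5000 if unlimited, else 1000,
    # minus the style text's own length.
    budget = 5000 if no_limit else 1000
    allowed_lens = [budget - len(s) for s in style_texts]
    # Keep the most recent characters of the pre_text (truncate from the left).
    truncated = [p[-n:] for p, n in zip(pre_texts, allowed_lens)]
    # The style text is prepended, separated by BERT's [SEP] token.
    return [s + " [SEP] " + t for s, t in zip(style_texts, truncated)]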
@@ -987,7 +1006,7 @@ def _encode_texts_as_bytes_with_tokenizer(
         padding=True,
         truncation=True,
         return_length=True,
-        max_length=500,
+        max_length=max_len,
     )
     style_lens = encoded_style_texts["length"].to(device)
@@ -998,7 +1017,7 @@ def _encode_texts_as_bytes_with_tokenizer(
         padding=True,
         truncation=True,
         return_length=True,
-        max_length=500,
+        max_length=max_len,
     ).to(device)

     return encoded_inputs, style_lens
@@ -1090,8 +1109,8 @@ def compute_loss(
     if not params.use_style_prompt:
         style_texts = ["" for _ in style_texts]  # use empty string for style texts if don't use style prompt

-    if random.random() < 0.03:
-        logging.info(f"Pre_texts: {pre_texts[0]}")
+    if random.random() < 0.05:
+        logging.info(f"Pre texts: {pre_texts[0]}")
         logging.info(f"Ref texts: {texts[0]}")
         logging.info(f"Style texts: {style_texts[0]}")
@@ -1483,9 +1502,15 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

+    if params.freeze_text_encoder:
+        freeze_modules = ["text_encoder"]
+        logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
+    else:
+        freeze_modules = []
+
     optimizer = ScaledAdam(
         get_parameter_groups_with_lrs(
-            model, lr=params.base_lr, include_names=True
+            model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules
         ),
         lr=params.base_lr,  # should have no effect
         clipping_scale=2.0,
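
The freeze_modules argument is consumed by get_parameter_groups_with_lrs in icefall's optim.py. A minimal sketch of the freezing idea, with a hypothetical helper name (the real implementation builds learning-rate groups as well):

# Sketch: exclude frozen submodules from the optimizer by name prefix.
from typing import Iterator, List
import torch
import torch.nn as nn

def trainable_parameters(
    model: nn.Module, freeze_modules: List[str]
) -> Iterator[torch.nn.Parameter]:
    # Parameters whose dotted name starts with a frozen prefix (e.g.
    # "text_encoder") are disabled and never yielded to the optimizer.
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in freeze_modules):
            param.requires_grad_(False)
            continue
        yield param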
@@ -1506,6 +1531,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
+        args.max_duration = 100
         opts = diagnostics.TensorDiagnosticOptions(
             2 ** 22
         )  # allow 4 megabytes per sub-module

View File

@@ -61,8 +61,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from dataset import triplet_text_sampling, naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
-from dataset2 import triplet_text_sampling_with_context_list
+from dataset import naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
+from dataset2 import triplet_text_sampling, triplet_text_sampling_with_context_list
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -1093,9 +1093,8 @@ def _encode_text_as_tokens(
     style_texts: List[str],
     bpe_model: spm.SentencePieceProcessor,
     device: torch.device,
-    max_tokens: int=500,
+    max_tokens: int=800,
 ) -> Tuple[Tensor, Tensor]:
-    max_tokens = min(500, max_tokens)
     batch_size = len(pre_texts)

     # encoded style texts
@@ -1216,6 +1215,7 @@ def compute_loss(
     feature = feature.to(device)

     supervisions = batch["supervisions"]
+    cut_ids = [c.id for c in supervisions["cut"]]
     feature_lens = supervisions["num_frames"].to(device)

     batch_idx_train = params.batch_idx_train
@@ -1266,6 +1266,9 @@ def compute_loss(
         logging.info(f"Pre texts: {pre_texts[0]}")
         logging.info(f"Ref texts: {texts[0]}")
         logging.info(f"Style texts: {style_texts[0]}")
+        orig_style_texts = batch["supervisions"]["style_text"]
+        logging.info(f"Orig style texts: {orig_style_texts[0]}")
+        logging.info(f"Cut ID: {cut_ids[0]}")

     pre_texts, pre_texts_lens, style_text_lens = _encode_text_as_tokens(
         pre_texts=pre_texts,
@@ -1752,7 +1755,7 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    text_sampling_func = triplet_text_sampling_with_context_list
+    text_sampling_func = triplet_text_sampling
     logging.info(f"Text sampling: {text_sampling_func}")

     train_dl = libriheavy.train_dataloaders(