mirror of https://github.com/k2-fsa/icefall.git

commit 013cafdd6d ("updates")
parent 522273f97e
@@ -61,14 +61,15 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from dataset import (
+from dataset2 import (
     triplet_text_sampling,
+    multi_ref_text_triplet_text_sampling,
     triplet_text_sampling_with_context_list,
     naive_triplet_text_sampling,
     random_shuffle_subset,
     joint_triplet_text_sampling,
     triplet_style_text_sampling,
 )
-from dataset import multi_ref_text_triplet_text_sampling
+
 from decoder import Decoder
 from joiner import Joiner
@@ -211,6 +212,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default=0.05,
         help="By which probability, dropout the memory when doing cross-attention."
     )

+    parser.add_argument(
+        "--memory-layer",
+        type=int,
+        default=0,
+        help="Start doing cross-attention from which layer. Zero-indexed"
+    )
+
     parser.add_argument(
         "--query-head-dim",
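A minimal sketch (not icefall's actual code) of what a zero-indexed --memory-layer gate can look like inside an encoder stack; EncoderStack and the per-layer memory argument are hypothetical stand-ins:

import torch
import torch.nn as nn

class EncoderStack(nn.Module):
    # Hypothetical stack: layers with index >= memory_layer cross-attend
    # to the text-encoder memory; earlier layers ignore it.
    def __init__(self, layers: nn.ModuleList, memory_layer: int = 0):
        super().__init__()
        self.layers = layers
        self.memory_layer = memory_layer  # zero-indexed, as in --memory-layer

    def forward(self, x: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            mem = memory if i >= self.memory_layer else None
            x = layer(x, memory=mem)
        return x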
@@ -558,6 +566,11 @@ def get_parser():
         default=0.05,
         help="The probability of masking prompts",
     )
+    parser.add_argument(
+        "--freeze-text-encoder",
+        type=str2bool,
+        default=True,
+    )

     parser.add_argument(
         "--forced-upper-pre-text",
@@ -731,12 +744,12 @@ def get_text_encoder(params: AttributeDict) -> nn.Module:
     if params.text_encoder_type == "BERT":
         from transformers import BertModel
         # This is a BERT-base-cased
-        logging.info("Loading pre-trained BERT-base-cased as text encaoder")
+        logging.info("Loading pre-trained BERT-base-cased as text encoder")
         model = BertModel.from_pretrained("bert-base-cased")
     elif params.text_encoder_type == "DistilBERT":
         from transformers import DistilBertModel
         # This is a DistilBERT-base-cased
-        logging.info("Loading pre-trained DistilBERT-base-cased as text encaoder")
+        logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
         model = DistilBertModel.from_pretrained("distilbert-base-cased")
     else:
         raise ValueError()
@@ -777,6 +790,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
         memory_dim=768,  # This is fixed as the BERT base model is 768-D
+        memory_layer=params.memory_layer,
         memory_dropout_rate=params.memory_dropout_rate,
     )
     return encoder
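The hard-coded memory_dim=768 matches the hidden size of BERT-base. If one wanted to avoid the magic number, the value can be read from the Hugging Face config instead (a small sketch, assuming the transformers package is available):

from transformers import BertConfig

# bert-base-cased reports hidden_size == 768, which is what memory_dim must match.
memory_dim = BertConfig.from_pretrained("bert-base-cased").hidden_size
assert memory_dim == 768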
@@ -970,14 +984,19 @@ def _encode_texts_as_bytes_with_tokenizer(
     tokenizer,
     device: torch.device,
+    max_len: int=500,
+    no_limit: bool=False
 ) -> Tuple[Dict, Tensor]:
     """
     Encode texts as bytes and then integer tensors.
     Note that the style text will be added to the beginning of texts.
     """
     batch_size = len(pre_texts)
+    max_len = min(max_len, 500)

-    allowed_lens = [1000 - len(s) for s in style_texts]
+    if no_limit:
+        allowed_lens = [5000 - len(s) for s in style_texts]
+    else:
+        allowed_lens = [1000 - len(s) for s in style_texts]
     truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
     combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]

@@ -987,7 +1006,7 @@ def _encode_texts_as_bytes_with_tokenizer(
         padding=True,
         truncation=True,
         return_length=True,
-        max_length=500,
+        max_length=max_len,
     )
     style_lens = encoded_style_texts["length"].to(device)

@@ -998,7 +1017,7 @@ def _encode_texts_as_bytes_with_tokenizer(
         padding=True,
         truncation=True,
         return_length=True,
-        max_length=500,
+        max_length=max_len,
     ).to(device)

     return encoded_inputs, style_lens
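Together, the three hunks above give the caller a max_len knob (capped at 500) and a no_limit escape hatch for the character budget: the pre-text is truncated from the left so that style text plus pre-text stays within 1000 characters (5000 with no_limit), and both tokenizer calls now honor max_len. A self-contained sketch of the same truncate-and-combine logic; the tokenizer choice is illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

style_texts = ["Mixed-case, punctuated."]
pre_texts = ["some very long preceding transcript " * 100]

no_limit = False
budget = 5000 if no_limit else 1000
# Reserve room for the style text, then keep the tail of the pre-text,
# since the most recent context is the most relevant.
allowed_lens = [budget - len(s) for s in style_texts]
truncated = [p[-n:] for p, n in zip(pre_texts, allowed_lens)]
combined = [s + " [SEP] " + t for s, t in zip(style_texts, truncated)]

encoded = tokenizer(
    combined,
    return_tensors="pt",
    padding=True,
    truncation=True,
    return_length=True,
    max_length=500,
)
print(encoded["input_ids"].shape, encoded["length"])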
@@ -1090,8 +1109,8 @@ def compute_loss(
     if not params.use_style_prompt:
         style_texts = ["" for _ in style_texts]  # use empty string for style texts if don't use style prompt

-    if random.random() < 0.03:
-        logging.info(f"Pre_texts: {pre_texts[0]}")
+    if random.random() < 0.05:
+        logging.info(f"Pre texts: {pre_texts[0]}")
         logging.info(f"Ref texts: {texts[0]}")
         logging.info(f"Style texts: {style_texts[0]}")

@@ -1483,9 +1502,15 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

+    if params.freeze_text_encoder:
+        freeze_modules = ["text_encoder"]
+        logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
+    else:
+        freeze_modules = []
+
     optimizer = ScaledAdam(
         get_parameter_groups_with_lrs(
-            model, lr=params.base_lr, include_names=True
+            model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules
         ),
         lr=params.base_lr,  # should have no effect
         clipping_scale=2.0,
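get_parameter_groups_with_lrs is icefall's helper; with freeze_modules set, the text encoder's parameters never reach ScaledAdam. A rough sketch of the underlying idea, filtering parameters by module-name prefix (names here are illustrative, not icefall's implementation):

import torch
import torch.nn as nn

def trainable_params(model: nn.Module, freeze_modules: list):
    # Skip any parameter whose dotted name starts with a frozen prefix,
    # and stop its gradients so it cannot drift.
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in freeze_modules):
            param.requires_grad_(False)
            continue
        yield param

# Usage with a generic optimizer:
# optimizer = torch.optim.AdamW(trainable_params(model, ["text_encoder"]), lr=0.045)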
@@ -1506,6 +1531,7 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

+
     if params.print_diagnostics:
         args.max_duration = 100
         opts = diagnostics.TensorDiagnosticOptions(
             2 ** 22
         )  # allow 4 megabytes per sub-module

@@ -61,8 +61,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from dataset import triplet_text_sampling, naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
-from dataset2 import triplet_text_sampling_with_context_list
+from dataset import naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
+from dataset2 import triplet_text_sampling, triplet_text_sampling_with_context_list
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -1093,9 +1093,8 @@ def _encode_text_as_tokens(
     style_texts: List[str],
     bpe_model: spm.SentencePieceProcessor,
     device: torch.device,
-    max_tokens: int=500,
+    max_tokens: int=800,
 ) -> Tuple[Tensor, Tensor]:
-    max_tokens = min(500, max_tokens)
     batch_size = len(pre_texts)

     # encoded style texts
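The cap max_tokens = min(500, max_tokens) is removed at the same time the default rises to 800, so the larger token budget actually takes effect. Token-level truncation with SentencePiece, keeping the tail of the pre-text, looks roughly like this (a sketch; the model path is hypothetical):

import sentencepiece as spm

bpe_model = spm.SentencePieceProcessor(model_file="bpe.model")  # hypothetical path

max_tokens = 800
pre_text = "a long preceding transcript ..."
token_ids = bpe_model.encode(pre_text)
# Keep the most recent context: truncate from the left.
token_ids = token_ids[-max_tokens:]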
@@ -1216,6 +1215,7 @@ def compute_loss(
     feature = feature.to(device)

     supervisions = batch["supervisions"]
+    cut_ids = [c.id for c in supervisions["cut"]]
     feature_lens = supervisions["num_frames"].to(device)

     batch_idx_train = params.batch_idx_train
@@ -1266,6 +1266,9 @@ def compute_loss(
         logging.info(f"Pre texts: {pre_texts[0]}")
         logging.info(f"Ref texts: {texts[0]}")
         logging.info(f"Style texts: {style_texts[0]}")
+        orig_style_texts = batch["supervisions"]["style_text"]
+        logging.info(f"Orig style texts: {orig_style_texts[0]}")
+        logging.info(f"Cut ID: {cut_ids[0]}")

     pre_texts, pre_texts_lens, style_text_lens = _encode_text_as_tokens(
         pre_texts=pre_texts,
@@ -1752,7 +1755,7 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    text_sampling_func = triplet_text_sampling_with_context_list
+    text_sampling_func = triplet_text_sampling
     logging.info(f"Text sampling: {text_sampling_func}")

     train_dl = libriheavy.train_dataloaders(