Mirror of https://github.com/k2-fsa/icefall.git
Commit ae2c7c73f6 (parent 2f4eb18466): remove/rename files
@@ -31,10 +31,6 @@ from typing import Optional, Tuple, Dict
class PromptedTransducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
Note that this is a PromptedTransducer, meaning that the transducer is able to decode
with prompts.
It has a text encoder of BERT type model.
This transducer also has a special context fuser.
"""

def __init__(
@@ -105,13 +101,8 @@ class PromptedTransducer(nn.Module):
self.use_BERT = use_BERT  # if the text encoder is a pre-trained BERT
self.context_fuser = context_fuser

assert text_encoder_type in (
"BERT",
"DistilBERT",
"BERT-UNCASED",
"BERT-LARGE-UNCASED",
), f"Unseen text_encoder type {text_encoder_type}"
self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED", "BERT-LARGE-UNCASED") else self.text_encoder.config.dim
assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim

if text_encoder_adapter:
self.text_encoder_adapter = nn.Sequential(
@@ -120,12 +111,15 @@ class PromptedTransducer(nn.Module):
)
else:
self.text_encoder_adapter = None

self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))

def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
encoded_inputs: Dict,
style_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
@@ -189,7 +183,10 @@ class PromptedTransducer(nn.Module):
# freeze the BERT text encoder

if use_pre_text:
memory, memory_key_padding_mask = self.encode_text(encoded_inputs)
memory, memory_key_padding_mask = self.encode_text(
encoded_inputs,
style_lens=style_lens
)
else:
memory = None
memory_key_padding_mask = None
@@ -279,12 +276,7 @@ class PromptedTransducer(nn.Module):
else:
context = None

logits = self.joiner(
am_pruned,
lm_pruned,
context=context,
project_input=False,
)
logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)

with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
@@ -305,8 +297,8 @@ class PromptedTransducer(nn.Module):
scale of the embedding vector can adjust to compensate.

Args:
memory: (memory_len, batch_size, embed_dim)
style_lens: (batch_size,), a vector of lengths of the style prompt.
memory: (memory_len, batch_size, embed_dim)
style_lens: (batch_size,), a vector of lengths of the style prompt.
"""

(memory_len, batch_size, embed_dim) = memory.shape
@@ -318,13 +310,14 @@ class PromptedTransducer(nn.Module):
indicator = indicator.to(memory.dtype)

extra_term = torch.zeros_like(memory)
extra_term[..., 0] += indicator
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)

return memory + extra_term

def encode_text(
self,
encoded_inputs: Dict,
style_lens: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Get the embeddings of text

@@ -335,18 +328,21 @@ class PromptedTransducer(nn.Module):
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
text_encoder and the attention mask
"""
text_lens = encoded_inputs["attention_mask"].sum(1)
text_lens = encoded_inputs.pop("length") # need to use pop to remove this item

# Freeze the pre-trained text encoder
with torch.no_grad():
memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C)
memory = memory.permute(1,0,2)

memory_key_padding_mask = make_pad_mask(text_lens)

# Text encoder adapter
if self.text_encoder_adapter is not None:
memory = self.text_encoder_adapter(memory)

memory = self._add_style_indicator(memory, style_lens)

memory_key_padding_mask = make_pad_mask(text_lens)

return memory, memory_key_padding_mask

def encode_audio(
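The _add_style_indicator hunk above marks the positions of the text-encoder memory that belong to the style prompt and adds a learned embedding at exactly those positions. A minimal standalone sketch of that masking logic in plain PyTorch (toy shapes; the variable names here are illustrative, not icefall's):

import torch

memory_len, batch_size, embed_dim = 7, 2, 4
memory = torch.randn(memory_len, batch_size, embed_dim)   # (T, N, C) text-encoder output
style_lens = torch.tensor([3, 5])                         # style-prompt length per utterance
style_prompt_embedding = torch.full((embed_dim,), 0.5)    # stands in for the learned nn.Parameter

# 1.0 for positions that belong to the style prompt, 0.0 elsewhere: (T, N)
indicator = (torch.arange(memory_len).unsqueeze(-1) < style_lens).to(memory.dtype)

# broadcast the style embedding onto the flagged positions and add it to memory
extra_term = indicator.unsqueeze(-1) * style_prompt_embedding.expand(memory_len, batch_size, embed_dim)
memory = memory + extra_term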
@@ -1,382 +0,0 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import random
|
||||
import warnings
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
from scaling import penalize_abs_values_gt, ScaledLinear
|
||||
from torch import Tensor
|
||||
from typing import Optional, Tuple, Dict
|
||||
|
||||
|
||||
class PromptedTransducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
text_encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
use_BERT: bool = True,
|
||||
text_encoder_type: str = "BERT",
|
||||
text_encoder_adapter: bool = False,
|
||||
context_fuser: nn.Module = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder_embed:
|
||||
It is a Convolutional 2D subsampling module. It converts
|
||||
an input of shape (N, T, idim) to an output of of shape
|
||||
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
text_encoder:
|
||||
This is a encoder that processes text information (e.g content prompt
|
||||
and style prompt). The input is `x` of (N,T) and `x_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
text_encoder_type:
|
||||
The type of the text_encoder. Supported are (BERT, DistilBERT)
|
||||
context_fuser
|
||||
A optional module that fuses the embeddings of text encoder. The fused embedding
|
||||
will be added to the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder = encoder
|
||||
self.text_encoder = text_encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
|
||||
self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT
|
||||
self.context_fuser = context_fuser
|
||||
|
||||
assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
|
||||
self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim
|
||||
|
||||
if text_encoder_adapter:
|
||||
self.text_encoder_adapter = nn.Sequential(
|
||||
nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False),
|
||||
nn.Tanh(),
|
||||
)
|
||||
else:
|
||||
self.text_encoder_adapter = None
|
||||
|
||||
self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
encoded_inputs: Dict,
|
||||
style_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
use_pre_text: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
text:
|
||||
A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
|
||||
It is exptected to contain the style prompt (first) and then the content
|
||||
prompt.
|
||||
text_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of elements (bytes)
|
||||
in `text` before padding, which will include the lengths of the
|
||||
style plus the content prompt.
|
||||
style_lens:
|
||||
A 1-D tensor of shape (N,), containing the number of elements (bytes)
|
||||
within each row of `text` that correspond to the style prompt (these
|
||||
are expected to come first).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
# freeze the BERT text encoder
|
||||
|
||||
if use_pre_text:
|
||||
memory, memory_key_padding_mask = self.encode_text(
|
||||
encoded_inputs,
|
||||
style_lens=style_lens
|
||||
)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
encoder_out, x_lens = self.encoder(
|
||||
x,
|
||||
x_lens,
|
||||
src_key_padding_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
if self.context_fuser is not None and memory is not None:
|
||||
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
|
||||
context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
|
||||
context = self.joiner.context_proj(context)
|
||||
else:
|
||||
context = None
|
||||
|
||||
logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
||||
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
|
||||
"""
|
||||
Adds to `memory` an indicator that is 1.0 for positions that correspond to
|
||||
the `style prompt` and 0 elsewhere. The scale can be fixed because the
|
||||
scale of the embedding vector can adjust to compensate.
|
||||
|
||||
Args:
|
||||
memory: (memory_len, batch_size, embed_dim)
|
||||
style_lens: (batch_size,), a vector of lengths of the style prompt.
|
||||
"""
|
||||
|
||||
(memory_len, batch_size, embed_dim) = memory.shape
|
||||
|
||||
indicator = (
|
||||
torch.arange(memory_len, device=memory.device).unsqueeze(-1)
|
||||
< style_lens
|
||||
)
|
||||
indicator = indicator.to(memory.dtype)
|
||||
|
||||
extra_term = torch.zeros_like(memory)
|
||||
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
|
||||
|
||||
return memory + extra_term
|
||||
|
||||
def encode_text(
|
||||
self,
|
||||
encoded_inputs: Dict,
|
||||
style_lens: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Get the embeddings of text
|
||||
|
||||
Args:
|
||||
encoded_inputs: The encoded inputs generated by a tokenizer (Dict)
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
|
||||
text_encoder and the attention mask
|
||||
"""
|
||||
text_lens = encoded_inputs.pop("length") # need to use pop to remove this item
|
||||
|
||||
# Freeze the pre-trained text encoder
|
||||
with torch.no_grad():
|
||||
memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C)
|
||||
memory = memory.permute(1,0,2)
|
||||
|
||||
# Text encoder adapter
|
||||
if self.text_encoder_adapter is not None:
|
||||
memory = self.text_encoder_adapter(memory)
|
||||
|
||||
memory = self._add_style_indicator(memory, style_lens)
|
||||
|
||||
memory_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
return memory, memory_key_padding_mask
|
||||
|
||||
def encode_audio(
|
||||
self,
|
||||
feature: Tensor,
|
||||
feature_lens: Tensor,
|
||||
memory: Optional[Tensor],
|
||||
memory_key_padding_mask: Optional[Tensor],
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Encode the input audio features
|
||||
|
||||
Args:
|
||||
feature (Tensor): Input audio (N,T,C)
|
||||
feature_lens (Tensor): Length of input audio (N,)
|
||||
memory (Tensor): Embeddings from the text encoder
|
||||
memory_key_padding_mask (Tensor): _description_
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: _description_
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(feature, feature_lens)
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
Transducer = PromptedTransducer # for decoding
|
@@ -61,13 +61,22 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriHeavyAsrDataModule
|
||||
from dataset import triplet_text_sampling, naive_triplet_text_sampling, random_shuffle_subset, joint_triplet_text_sampling, get_substring
|
||||
from dataset2 import (
|
||||
triplet_text_sampling,
|
||||
triplet_text_sampling_with_context_list,
|
||||
naive_triplet_text_sampling,
|
||||
random_shuffle_subset,
|
||||
joint_triplet_text_sampling,
|
||||
triplet_style_text_sampling,
|
||||
)
|
||||
from dataset import multi_ref_text_triplet_text_sampling
|
||||
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model_with_BERT import PromptedTransducer
|
||||
from model_with_BERT_with_style import PromptedTransducer
|
||||
from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat, Balancer, BiasNorm, Dropout3, ScaleGrad, SwooshR
|
||||
from subsampling import Conv2dSubsampling
|
||||
@@ -107,11 +116,6 @@ style_transforms = [
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
rare_words_file = "data/context_biasing/small_rare_words_5.txt"
|
||||
with open(rare_words_file, "r") as f:
|
||||
rare_words = f.read()
|
||||
rare_words_list = rare_words.split("\n")
|
||||
|
||||
def random_sampling(texts: List[str]) -> str:
|
||||
return random.choice(texts)
|
||||
|
||||
@@ -126,18 +130,6 @@ def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
|
||||
}
|
||||
return out
|
||||
|
||||
def joint_random_sampling_mixed_recog(texts: List[str], pre_texts: List[str]) -> str:
|
||||
# Randomly choose from the ground truth (mixed-cased trans) and the recog_text
|
||||
i = random.randint(0, 1)
|
||||
trans = style_transforms[i]
|
||||
out = {
|
||||
"text": trans(texts[0]),
|
||||
"pre_text": trans(pre_texts[0]),
|
||||
"style_text": "",
|
||||
"transform_ids": i,
|
||||
}
|
||||
return out
|
||||
|
||||
def get_first(texts: List[str], pre_texts: List[str]) -> str:
|
||||
out = {
|
||||
"text": texts[0],
|
||||
@@ -153,159 +145,7 @@ def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
|
||||
"text": upper_only_alpha(texts[0]),
|
||||
"pre_text": upper_only_alpha(pre_texts[0]),
|
||||
"style_text": "",
|
||||
"transform_ids": 1,
|
||||
}
|
||||
return out
|
||||
|
||||
def get_upper_only_alpha_with_multiple_ref_texts(texts: List[str], pre_texts: List[str]) -> str:
|
||||
# Choose between the first and the last one in texts (gt and decoding results)
|
||||
# But return the upper_only_alpha version
|
||||
i = random.sample([0,2], 1)[0]
|
||||
out = {
|
||||
"text": upper_only_alpha(texts[i]), # either the first or the last
|
||||
"pre_text": upper_only_alpha(pre_texts[0]),
|
||||
"style_text": "",
|
||||
"transform_ids": i,
|
||||
}
|
||||
return out
|
||||
|
||||
def get_upper_only_alpha_with_multiple_pre_texts(texts: List[str], pre_texts: List[str]) -> str:
|
||||
# Choose between the first and the last one in texts (gt and decoding results)
|
||||
# But return the upper_only_alpha version
|
||||
v = random.random()
|
||||
if v < 0.5: # The normal case
|
||||
out = {
|
||||
"text": upper_only_alpha(texts[0]),
|
||||
"pre_text": upper_only_alpha(pre_texts[0]),
|
||||
"style_text": "",
|
||||
"transform_ids": 0,
|
||||
}
|
||||
else: # Use the decoded output as pre_text
|
||||
out = {
|
||||
"text": upper_only_alpha(texts[0]),
|
||||
"pre_text": upper_only_alpha(texts[2]),
|
||||
"style_text": "",
|
||||
"transform_ids": 1,
|
||||
}
|
||||
return out
|
||||
|
||||
def get_upper_only_alpha_with_random_ref_text(texts: List[str], pre_texts: List[str]) -> str:
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||
|
||||
text = upper_only_alpha(texts[0])
|
||||
if random.random() < 0.1:
|
||||
if random.random() < 0.5:
|
||||
pre_text = get_substring(text, min_len=15, max_len=80)
|
||||
else:
|
||||
pre_text = text.split()
|
||||
random.shuffle(pre_text) # shuffle the words
|
||||
i = random.randint(5, 20) # random sample the number of words to be included
|
||||
pre_text = " ".join(pre_text[:i])
|
||||
else:
|
||||
pre_text = upper_only_alpha(pre_texts[0])
|
||||
out = {
|
||||
"text": text,
|
||||
"pre_text": pre_text,
|
||||
"style_text": "",
|
||||
"transform_ids": 1,
|
||||
}
|
||||
return out
|
||||
|
||||
def get_upper_only_alpha_with_random_ref_text_v2(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
) -> str:
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||
|
||||
text = upper_only_alpha(texts[0])
|
||||
if random.random() < 0.5 and len(text.split()) > 8:
|
||||
v = random.random()
|
||||
if v < 0.4: # Use phrases from ref_text as content prompt
|
||||
splitted = text.split()
|
||||
num_phrases = numpy.random.randint(3) + 1 # 1 to 3 context phrases
|
||||
start_pos = numpy.random.choice(len(splitted) - 3, num_phrases, replace=False)
|
||||
phrases = [" ".join(splitted[start_pos[i]: start_pos[i]+random.randint(0,4) + 1]) for i in range(num_phrases)]
|
||||
num_distractors = random.randint(0,60)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
phrases += distractors
|
||||
random.shuffle(phrases)
|
||||
pre_text = " ".join(phrases)
|
||||
elif v < 0.8: # Use random discrete words
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 8))
|
||||
splitted = list(numpy.random.choice(splitted, i, p=sampling_weights))
|
||||
|
||||
num_distractors = random.randint(0,60)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the word list
|
||||
pre_text = " ".join(splitted)
|
||||
else:
|
||||
pre_text = get_substring(text, min_len=40, max_len=120)
|
||||
else:
|
||||
pre_text = pre_texts[0]
|
||||
|
||||
out = {
|
||||
"text": text,
|
||||
"pre_text": upper_only_alpha(pre_text),
|
||||
"style_text": "",
|
||||
"transform_ids": 1,
|
||||
}
|
||||
return out
|
||||
|
||||
def get_upper_only_alpha_with_context_list(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
context_list: str,
|
||||
) -> str:
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||
|
||||
text = upper_only_alpha(texts[0])
|
||||
if context_list != "":
|
||||
if random.random() < 0.5:
|
||||
# correct + distractors
|
||||
# sample distractors
|
||||
num_distractors = random.randint(0, 50)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
# sample correct
|
||||
correct = context_list.split()
|
||||
i = random.randint(1, len(correct))
|
||||
correct = random.sample(correct, i)
|
||||
# combine correct and distractors
|
||||
pre_text = distractors + correct
|
||||
random.shuffle(pre_text)
|
||||
pre_text = " ".join(pre_text)
|
||||
else:
|
||||
pre_text = upper_only_alpha(pre_texts[0])
|
||||
else:
|
||||
v = random.random()
|
||||
if v < 0.1:
|
||||
splitted = text.split()
|
||||
random.shuffle(splitted)
|
||||
i = random.randint(5, 20)
|
||||
splitted = splitted[:i]
|
||||
pre_text = " ".join(splitted)
|
||||
elif v < 0.2:
|
||||
# full distractors
|
||||
num_distractors = random.randint(5, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
pre_text = " ".join(distractors)
|
||||
|
||||
elif v < 0.3:
|
||||
pre_text = get_substring(text, min_len=15, max_len=80)
|
||||
else:
|
||||
pre_text = upper_only_alpha(pre_texts[0])
|
||||
|
||||
out = {
|
||||
"text": text,
|
||||
"pre_text": pre_text,
|
||||
"style_text": "",
|
||||
"transform_ids": 1,
|
||||
"transform_ids": 0,
|
||||
}
|
||||
return out
|
||||
|
||||
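As a usage-style illustration of the "correct + distractors" branch of get_upper_only_alpha_with_context_list above: a biasing pre_text can be built by keeping a random subset of the true context words, mixing in rare-word distractors and shuffling. This is a simplified sketch; rare_words_list is a tiny placeholder here, whereas the script loads it from data/context_biasing/small_rare_words_5.txt.

import random

rare_words_list = ["CHIMERA", "OBSIDIAN", "QUAGMIRE", "ZEPHYR"]  # placeholder; the script reads a word list file

def build_biasing_pre_text(context_list: str, max_distractors: int = 50) -> str:
    # keep a random subset of the true context words
    correct = context_list.split()
    correct = random.sample(correct, random.randint(1, len(correct)))
    # add some confusable filler words as distractors
    num_distractors = random.randint(0, min(max_distractors, len(rare_words_list)))
    distractors = random.sample(rare_words_list, num_distractors)
    words = correct + distractors
    random.shuffle(words)
    return " ".join(words)

print(build_biasing_pre_text("JOHN SMITH NEW YORK"))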
@@ -373,13 +213,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="By which probability, dropout the memory when doing cross-attention."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--memory-dim",
|
||||
type=int,
|
||||
default=768,
|
||||
help="The embedding dimension of the text encoder"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--memory-layer",
|
||||
type=int,
|
||||
@@ -430,14 +263,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
|
||||
"a single int or comma-separated list.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-dim",
|
||||
@@ -455,6 +280,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
to this dimension before adding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--causal",
|
||||
@@ -484,7 +317,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
"--text-encoder-type",
|
||||
type=str,
|
||||
default="BERT",
|
||||
choices=["BERT","DistilBERT","BERT-UNCASED", "BERT-LARGE-UNCASED"],
|
||||
choices=["BERT","DistilBERT"],
|
||||
help="Type of the text encoder",
|
||||
)
|
||||
|
||||
@@ -733,6 +566,18 @@ def get_parser():
|
||||
default=0.05,
|
||||
help="The probability of masking prompts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze-text-encoder",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--forced-upper-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Forced format of pre-text",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
@@ -825,30 +670,89 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
||||
)
|
||||
return encoder_embed
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int=256,
|
||||
embedding_dim: int=256,
|
||||
kernel_size: int=3,
|
||||
layer1_channels: int = 256,
|
||||
layer2_channels: int = 256,
|
||||
bias: bool=True,
|
||||
dropout: float = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
self.embed = nn.Embedding(
|
||||
num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes
|
||||
embedding_dim=embedding_dim, #
|
||||
)
|
||||
|
||||
assert embedding_dim == layer1_channels # for depth wise convolution
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv1d(
|
||||
embedding_dim,
|
||||
layer1_channels, # depthwise convolution
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=layer1_channels,
|
||||
bias=True,
|
||||
),
|
||||
ScaleGrad(0.2),
|
||||
Balancer(layer1_channels, channel_dim=1, min_positive=0.1, max_abs=1.0),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(
|
||||
layer1_channels,
|
||||
layer2_channels,
|
||||
kernel_size=1, # pointwise convolution
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=True,
|
||||
),
|
||||
Balancer(layer2_channels, channel_dim=1, min_positive=0.1, max_abs=1.0),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.out_norm = BiasNorm(layer2_channels)
|
||||
self.dropout = Dropout3(dropout, shared_dim=1)
|
||||
|
||||
def forward(self, text: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function of the text embedding
|
||||
|
||||
Args:
|
||||
text (torch.Tensor): Text in UTF-8 bytes (T,N)
|
||||
Returns:
|
||||
The embeddings of text (T,N,C)
|
||||
"""
|
||||
text = self.embed(text) # (T,N,C)
|
||||
|
||||
#src = text
|
||||
text = text.permute(1,2,0) # (T,N,C) -> (N,C,T)
|
||||
text = self.conv(text)
|
||||
text = text.permute(2,0,1) # (N,C,T) -> (T,N,C)
|
||||
#src = src + text
|
||||
|
||||
text = self.out_norm(text)
|
||||
text = self.dropout(text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
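The TextEmbedding class above embeds UTF-8 bytes and then applies a depthwise convolution followed by a pointwise (1x1) convolution. A stripped-down sketch of that pattern in plain PyTorch, without icefall's ScaleGrad/Balancer/BiasNorm pieces (shapes and sizes here are illustrative):

import torch
import torch.nn as nn

embed_dim = 256
embed = nn.Embedding(256, embed_dim)  # one embedding per byte value
depthwise = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1, groups=embed_dim)
pointwise = nn.Conv1d(embed_dim, embed_dim, kernel_size=1)

text = torch.randint(0, 256, (50, 4))   # (T, N) UTF-8 bytes
x = embed(text)                         # (T, N, C)
x = x.permute(1, 2, 0)                  # (N, C, T) for Conv1d
x = torch.relu(pointwise(torch.relu(depthwise(x))))
x = x.permute(2, 0, 1)                  # back to (T, N, C)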
def get_text_encoder(params: AttributeDict) -> nn.Module:
|
||||
# Return a text encoder
|
||||
if params.text_encoder_type == "BERT": # # This is a BERT-base-cased
|
||||
from transformers import BertModel
|
||||
assert params.memory_dim == 768
|
||||
if params.text_encoder_type == "BERT":
|
||||
from transformers import BertModel
|
||||
# This is a BERT-base-cased
|
||||
logging.info("Loading pre-trained BERT-base-cased as text encoder")
|
||||
model = BertModel.from_pretrained("bert-base-cased")
|
||||
elif params.text_encoder_type == "BERT-UNCASED": # This is a BERT-base-uncased
|
||||
from transformers import BertModel
|
||||
assert params.memory_dim == 768
|
||||
logging.info("Loading pre-trained BERT-base-uncased as text encoder")
|
||||
model = BertModel.from_pretrained("bert-base-uncased")
|
||||
elif params.text_encoder_type == "BERT-LARGE-UNCASED": # This is a BERT-large-uncased
|
||||
from transformers import BertModel
|
||||
assert params.memory_dim == 1024
|
||||
logging.info("Loading pre-trained BERT-large-uncased as text encoder")
|
||||
model = BertModel.from_pretrained("bert-large-uncased")
|
||||
elif params.text_encoder_type == "DistilBERT": # This is a DistilBERT-base-cased
|
||||
elif params.text_encoder_type == "DistilBERT":
|
||||
from transformers import DistilBertModel
|
||||
assert params.memory_dim == 768
|
||||
# This is a DistilBERT-base-cased
|
||||
logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
|
||||
model = DistilBertModel.from_pretrained("distilbert-base-cased")
|
||||
else:
|
||||
raise ValueError(f"Unknown text encoder type: {params.text_encoder_type}")
|
||||
raise ValueError()
|
||||
|
||||
return model
|
||||
|
||||
@@ -858,14 +762,6 @@ def get_tokenizer(params: AttributeDict):
|
||||
from transformers import BertTokenizer
|
||||
# This is a BERT-base-cased
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||
elif params.text_encoder_type == "BERT-UNCASED":
|
||||
from transformers import BertTokenizer
|
||||
# This is a BERT-base-uncased
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
elif params.text_encoder_type == "BERT-LARGE-UNCASED":
|
||||
from transformers import BertTokenizer
|
||||
# This is a BERT-large-uncased
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
|
||||
elif params.text_encoder_type == "DistilBERT":
|
||||
from transformers import DistilBertTokenizer
|
||||
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
||||
@@ -893,7 +789,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
causal=params.causal,
|
||||
chunk_size=_to_int_tuple(params.chunk_size),
|
||||
left_context_frames=_to_int_tuple(params.left_context_frames),
|
||||
memory_dim=params.memory_dim, # This is fixed as the BERT base model is 768-D
|
||||
memory_dim=768, # This is fixed as the BERT base model is 768-D
|
||||
memory_layer=params.memory_layer,
|
||||
memory_dropout_rate=params.memory_dropout_rate,
|
||||
)
|
||||
@@ -934,7 +830,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
if params.context_injection:
|
||||
from context_fuser import ContextFuser, SelfAttContextFuser
|
||||
context_fuser = SelfAttContextFuser(
|
||||
embed_dim=params.memory_dim,
|
||||
embed_dim=768,
|
||||
nhead=4,
|
||||
context_dropout_rate=params.context_dropout_rate,
|
||||
)
|
||||
@@ -1082,6 +978,50 @@ def save_checkpoint(
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
def _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts: List[str],
|
||||
style_texts: List[str],
|
||||
tokenizer,
|
||||
device: torch.device,
|
||||
max_len: int=500,
|
||||
no_limit: bool=False
|
||||
) -> Tuple[Dict, Tensor]:
|
||||
"""
|
||||
Encode texts as bytes and then integer tensors.
|
||||
Note that the style text will be added to the beginning of texts.
|
||||
"""
|
||||
batch_size = len(pre_texts)
|
||||
max_len = min(max_len, 500)
|
||||
|
||||
if no_limit:
|
||||
allowed_lens = [5000 - len(s) for s in style_texts]
|
||||
else:
|
||||
allowed_lens = [1000 - len(s) for s in style_texts]
|
||||
truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
|
||||
combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
|
||||
|
||||
encoded_style_texts = tokenizer(
|
||||
style_texts,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
max_length=max_len,
|
||||
)
|
||||
style_lens = encoded_style_texts["length"].to(device)
|
||||
|
||||
# Use tokenizer to prepare input for text encoder
|
||||
encoded_inputs = tokenizer(
|
||||
combined_text,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
max_length=max_len,
|
||||
).to(device)
|
||||
|
||||
return encoded_inputs, style_lens
|
||||
|
||||
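A usage-style sketch of _encode_texts_as_bytes_with_tokenizer above: the style prompt is prepended to the (truncated) content prompt with a ' [SEP] ' separator, and return_length=True yields the style lengths used later as style_lens. The example strings are made up; only the Hugging Face tokenizer calls mirror the helper above.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

style_texts = ["Mixed-case text, with punctuation."]
pre_texts = ["SOME EARLIER DECODED CONTEXT WITHOUT PUNCTUATION"]
combined = [s + " [SEP] " + p for s, p in zip(style_texts, pre_texts)]

# lengths of the style prompts alone (per utterance)
style_lens = tokenizer(
    style_texts, return_tensors="pt", padding=True,
    truncation=True, return_length=True, max_length=500,
)["length"]

# tokenized style + content prompt, fed to the text encoder
encoded_inputs = tokenizer(
    combined, return_tensors="pt", padding=True,
    truncation=True, return_length=True, max_length=500,
)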
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
@@ -1117,6 +1057,7 @@ def compute_loss(
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
batch_size = feature.size(0)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
@@ -1137,13 +1078,16 @@ def compute_loss(
|
||||
y = sp.encode(texts, out_type=int) # sp.encode treats consecutive space as a single space
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
if params.forced_upper_pre_text:
|
||||
pre_texts = [upper_only_alpha(p) for p in pre_texts]
|
||||
|
||||
# only shuffle the pre_text and style texts if during training, and use style prompt
|
||||
if is_training:
|
||||
# randomly shuffle&mask the pre_text
|
||||
pre_texts = random_shuffle_subset(
|
||||
pre_texts,
|
||||
p=params.pre_text_shuffle_prob,
|
||||
p_mask=params.prompt_mask_prob
|
||||
p_mask=params.prompt_mask_prob,
|
||||
)
|
||||
|
||||
if params.use_style_prompt:
|
||||
@@ -1156,28 +1100,26 @@ def compute_loss(
|
||||
p_mask=params.prompt_mask_prob
|
||||
)
|
||||
|
||||
assert len(transform_ids) == len(style_texts)
|
||||
assert len(transform_ids) == len(style_texts)
|
||||
|
||||
for i in range(len(style_texts)):
|
||||
t = transform_ids[i] # get the transform id
|
||||
style_texts[i] = style_transforms[t](style_texts[i])
|
||||
for i in range(len(style_texts)):
|
||||
t = transform_ids[i] # get the transform id
|
||||
style_texts[i] = style_transforms[t](style_texts[i])
|
||||
|
||||
if not params.use_style_prompt:
|
||||
style_texts = ["" for _ in style_texts] # use empty string for style texts if don't use style prompt
|
||||
|
||||
if random.random() < 0.01:
|
||||
logging.info(f"Pre_texts: {pre_texts[0]}")
|
||||
if random.random() < 0.05:
|
||||
logging.info(f"Pre texts: {pre_texts[0]}")
|
||||
logging.info(f"Ref texts: {texts[0]}")
|
||||
logging.info(f"Style texts: {style_texts[0]}")
|
||||
|
||||
# Use tokenizer to prepare input for text encoder
|
||||
encoded_inputs = tokenizer(
|
||||
pre_texts,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=min(500, max(supervisions["num_frames"])//4,),
|
||||
).to(device)
|
||||
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts=pre_texts,
|
||||
style_texts=style_texts,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if random.random() < 0.02:
|
||||
logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
|
||||
@@ -1187,6 +1129,7 @@ def compute_loss(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
encoded_inputs=encoded_inputs,
|
||||
style_lens=style_lens,
|
||||
y=y,
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
@@ -1559,9 +1502,15 @@ def run(rank, world_size, args):
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
if params.freeze_text_encoder:
|
||||
freeze_modules = ["text_encoder"]
|
||||
logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
|
||||
else:
|
||||
freeze_modules = []
|
||||
|
||||
optimizer = ScaledAdam(
|
||||
get_parameter_groups_with_lrs(
|
||||
model, lr=params.base_lr, include_names=True
|
||||
model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules
|
||||
),
|
||||
lr=params.base_lr, # should have no effect
|
||||
clipping_scale=2.0,
|
||||
@@ -1582,6 +1531,7 @@ def run(rank, world_size, args):
|
||||
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||
|
||||
if params.print_diagnostics:
|
||||
args.max_duration = 100
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
) # allow 4 megabytes per sub-module
|
||||
@@ -1637,7 +1587,7 @@ def run(rank, world_size, args):
|
||||
else:
|
||||
sampler_state_dict = None
|
||||
|
||||
text_sampling_func = get_upper_only_alpha_with_multiple_pre_texts
|
||||
text_sampling_func = triplet_text_sampling
|
||||
logging.info(f"Text sampling: {text_sampling_func}")
|
||||
|
||||
train_dl = libriheavy.train_dataloaders(
|
||||
@@ -1650,7 +1600,7 @@ def run(rank, world_size, args):
|
||||
valid_cuts = libriheavy.dev_cuts()
|
||||
valid_dl = libriheavy.valid_dataloaders(
|
||||
valid_cuts,
|
||||
text_sampling_func=text_sampling_func
|
||||
text_sampling_func=naive_triplet_text_sampling
|
||||
)
|
||||
|
||||
# if not params.print_diagnostics:
|
||||
|