fix for black

marcoyang 2023-10-10 17:44:48 +08:00
parent e32bda6a7b
commit 96977c9ddd
6 changed files with 45 additions and 1911 deletions
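These are mechanical formatting changes: the Python sources below are rewritten to satisfy the black formatter, with import reordering in the style of isort; behaviour is unchanged. As a rough illustration of the rewrite black applies to long call expressions, here is a minimal sketch that is not taken from these files and assumes only that torch is installed:

import torch.nn as nn

# Before black: arguments packed onto hand-aligned continuation lines.
proj = nn.Linear(512, 256,
                 bias=False)

# After black: one argument per line with a trailing comma, so future
# diffs touch only the line that actually changes.
proj = nn.Linear(
    512,
    256,
    bias=False,
)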

View File

@@ -17,9 +17,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from scaling import Balancer


 class Decoder(nn.Module):
     """This class modifies the stateless decoder from the following paper:
@@ -61,10 +61,15 @@ class Decoder(nn.Module):
         )
         # the balancers are to avoid any drift in the magnitude of the
         # embeddings, which would interact badly with parameter averaging.
-        self.balancer = Balancer(decoder_dim, channel_dim=-1,
-                                 min_positive=0.0, max_positive=1.0,
-                                 min_abs=0.5, max_abs=1.0,
-                                 prob=0.05)
+        self.balancer = Balancer(
+            decoder_dim,
+            channel_dim=-1,
+            min_positive=0.0,
+            max_positive=1.0,
+            min_abs=0.5,
+            max_abs=1.0,
+            prob=0.05,
+        )

         self.blank_id = blank_id
@@ -78,14 +83,18 @@ class Decoder(nn.Module):
             out_channels=decoder_dim,
             kernel_size=context_size,
             padding=0,
-            groups=decoder_dim//4, # group size == 4
+            groups=decoder_dim // 4,  # group size == 4
             bias=False,
         )
-        self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
-                                  min_positive=0.0, max_positive=1.0,
-                                  min_abs=0.5, max_abs=1.0,
-                                  prob=0.05)
+        self.balancer2 = Balancer(
+            decoder_dim,
+            channel_dim=-1,
+            min_positive=0.0,
+            max_positive=1.0,
+            min_abs=0.5,
+            max_abs=1.0,
+            prob=0.05,
+        )

     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
@@ -108,9 +117,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
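The F.pad call that black collapsed onto a single line left-pads only the last (time) axis, so the decoder's Conv1d, which uses kernel_size=context_size and padding=0, sees context_size - 1 frames of history before the first real token. A minimal sketch with hypothetical shapes:

import torch
import torch.nn.functional as F

context_size = 2
# (batch, decoder_dim, time), as after the permute(0, 2, 1) above
embedding_out = torch.randn(1, 512, 10)
# pad=(left, right) applies to the last dimension only
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
print(padded.shape)  # torch.Size([1, 512, 11])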

View File

@@ -16,9 +16,7 @@
 import torch
 import torch.nn as nn
-from scaling import (
-    ScaledLinear
-)
+from scaling import ScaledLinear


 class Joiner(nn.Module):
@@ -28,7 +26,7 @@ class Joiner(nn.Module):
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
-        context_dim: int=512,
+        context_dim: int = 512,
         context_injection: bool = False,
     ):
         super().__init__()
@@ -37,7 +35,9 @@ class Joiner(nn.Module):
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)
         if context_injection:
-            self.context_proj = ScaledLinear(context_dim, joiner_dim, initial_scale=0.25)
+            self.context_proj = ScaledLinear(
+                context_dim, joiner_dim, initial_scale=0.25
+            )
         else:
             self.context_proj = None
@@ -68,7 +68,11 @@ class Joiner(nn.Module):
         if project_input:
             if context:
-                logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + self.context_proj(context)
+                logit = (
+                    self.encoder_proj(encoder_out)
+                    + self.decoder_proj(decoder_out)
+                    + self.context_proj(context)
+                )
             else:
                 logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
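For reference, the reformatted sum in the context branch simply adds three projections into the shared joiner space. A minimal sketch with hypothetical dimensions, using plain nn.Linear where the file uses icefall's ScaledLinear:

import torch
import torch.nn as nn

encoder_dim, decoder_dim, context_dim, joiner_dim = 384, 512, 512, 512
encoder_proj = nn.Linear(encoder_dim, joiner_dim)
decoder_proj = nn.Linear(decoder_dim, joiner_dim)
context_proj = nn.Linear(context_dim, joiner_dim)

encoder_out = torch.randn(8, encoder_dim)
decoder_out = torch.randn(8, decoder_dim)
context = torch.randn(8, context_dim)

# All three projections map into joiner_dim, so they add elementwise.
logit = encoder_proj(encoder_out) + decoder_proj(decoder_out) + context_proj(context)
print(logit.shape)  # torch.Size([8, 512])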

View File

@@ -111,7 +111,7 @@ def decoding_normalization(text: str) -> str:
     # Only keep alphanumeric characters, hyphen and apostrophe
     text = text.replace("-", " ")
-    text = re.sub("[^a-zA-Z0-9\s']+", "", text)
+    text = re.sub(r"[^a-zA-Z0-9\s']+", "", text)
     return text
@@ -130,9 +130,10 @@ def word_normalization(word: str) -> str:
     if word.isnumeric():
         word = num_to_words(int(word))
         return str(word).upper()
-    if word[-2:] == "TH" and word[0].isnumeric(): # e.g 9TH, 6TH
+    # e.g. 9TH, 6TH
+    if word[-2:] == "TH" and word[0].isnumeric():
         return num_to_ordinal_word(int(word[:-2])).upper()
-    if word[0] == "\'":
+    if word[0] == "'":
         return word[1:]
     return word
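Beyond formatting, the two string fixes here silence Python's invalid-escape warnings: "\s" is only a regex escape, so the pattern should be a raw string, and "\'" inside double quotes is a redundant escape. A minimal check of both behaviours:

import re

text = "it's a test-case, #42!"
text = text.replace("-", " ")
# Raw string: the "\s" reaches the regex engine unmangled, and the
# Python parser emits no invalid-escape warning.
print(re.sub(r"[^a-zA-Z0-9\s']+", "", text))  # it's a test case 42

word = "'EM"
if word[0] == "'":  # equivalent to the old "\'", minus the escape
    print(word[1:])  # EM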

File diff suppressed because it is too large.

View File

@@ -15,11 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Tuple
 import warnings
+from typing import Tuple

 import torch
-from torch import Tensor, nn
 from scaling import (
     Balancer,
     BiasNorm,
@@ -33,6 +32,7 @@ from scaling import (
     SwooshR,
     Whiten,
 )
+from torch import Tensor, nn


 class ConvNeXt(nn.Module):
@@ -106,9 +106,7 @@ class ConvNeXt(nn.Module):
         if layerdrop_rate != 0.0:
             batch_size = x.shape[0]
             mask = (
-                torch.rand(
-                    (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
-                )
+                torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                 > layerdrop_rate
             )
         else:
@@ -227,9 +225,7 @@ class Conv2dSubsampling(nn.Module):
         # many copies of this extra gradient term.
         self.out_whiten = Whiten(
             num_groups=1,
-            whitening_limit=ScheduledFloat(
-                (0.0, 4.0), (20000.0, 8.0), default=4.0
-            ),
+            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
             prob=(0.025, 0.25),
             grad_scale=0.02,
         )
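The collapsed torch.rand call above draws one uniform sample per batch element; the (batch_size, 1, 1, 1) shape then broadcasts each keep/drop decision across that element's channels and positions. A minimal sketch with hypothetical shapes:

import torch

layerdrop_rate = 0.1
x = torch.randn(4, 3, 8, 8)  # hypothetical (batch, channels, height, width)
batch_size = x.shape[0]
mask = (
    torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
    > layerdrop_rate
)
# Each batch element is kept (True) or dropped (False) as a whole.
print(mask.shape)  # torch.Size([4, 1, 1, 1])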

View File

@@ -23,9 +23,9 @@ To run this file, do:
     python ./pruned_transducer_stateless4/test_model.py
 """

-from train_subformer import get_params, get_transducer_model, get_text_encoder
-from zipformer import Zipformer2
 from scaling import ScheduledFloat
+from train_subformer import get_params, get_text_encoder, get_transducer_model
+from zipformer import Zipformer2


 def test_model_1():
@@ -55,8 +55,8 @@ def test_model_M():
     params.encoder_unmasked_dims = "256,256,256,256,256"
     params.zipformer_downsampling_factors = "1,2,4,8,2"
     params.cnn_module_kernels = "31,31,15,15"
-    params.text_encoder_dim = (192,192,256,384)
+    params.text_encoder_dim = (192, 192, 256, 384)
     params.decoder_dim = 512
     params.joiner_dim = 512
     model = Zipformer2(
@@ -83,12 +83,12 @@ def test_model_M():
     )
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")

     model = Zipformer2(
         output_downsampling_factor=8,
         downsampling_factor=(1, 2, 4, 8),
         num_encoder_layers=(2, 4, 6, 6),
-        encoder_dim=(256,256,384,512),
+        encoder_dim=(256, 256, 384, 512),
         encoder_unmasked_dim=(196, 196, 256, 256),
         query_head_dim=(32, 32, 32, 32),
         pos_head_dim=(4, 4, 4, 4),