Fix formatting for black

marcoyang 2023-10-10 17:44:48 +08:00
parent e32bda6a7b
commit 96977c9ddd
6 changed files with 45 additions and 1911 deletions

View File

@@ -17,9 +17,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import Balancer
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
@@ -61,10 +61,15 @@ class Decoder(nn.Module):
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(decoder_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.5, max_abs=1.0,
prob=0.05)
self.balancer = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
self.blank_id = blank_id
@@ -78,14 +83,18 @@ class Decoder(nn.Module):
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim//4, # group size == 4
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.5, max_abs=1.0,
prob=0.05)
self.balancer2 = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
@@ -108,9 +117,7 @@ class Decoder(nn.Module):
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
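
As a reading aid for the hunks above: a minimal, hedged sketch of what this decoder fragment computes, with all dimensions invented for illustration and the two Balancer regularizers (which only constrain activation statistics during training) omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Invented sizes; the real model uses a much larger decoder_dim.
decoder_dim, context_size = 8, 2
conv = nn.Conv1d(
    in_channels=decoder_dim,
    out_channels=decoder_dim,
    kernel_size=context_size,
    padding=0,
    groups=decoder_dim // 4,  # group size == 4, as in the diff
    bias=False,
)

embedding_out = torch.randn(1, decoder_dim, 5)  # (N, C, T)
# need_pad=True path: left-pad the time axis by context_size - 1 so the
# causal convolution returns one output per input frame.
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
out = conv(padded)
assert out.shape == embedding_out.shape
# need_pad=False path (inference): pass exactly context_size frames and the
# convolution yields the single new output frame, so no padding is needed.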

View File

@@ -16,9 +16,7 @@
import torch
import torch.nn as nn
from scaling import (
ScaledLinear
)
from scaling import ScaledLinear
class Joiner(nn.Module):
@@ -28,7 +26,7 @@ class Joiner(nn.Module):
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
context_dim: int=512,
context_dim: int = 512,
context_injection: bool = False,
):
super().__init__()
@@ -37,7 +35,9 @@ class Joiner(nn.Module):
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
if context_injection:
self.context_proj = ScaledLinear(context_dim, joiner_dim, initial_scale=0.25)
self.context_proj = ScaledLinear(
context_dim, joiner_dim, initial_scale=0.25
)
else:
self.context_proj = None
@@ -68,7 +68,11 @@ class Joiner(nn.Module):
if project_input:
if context:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + self.context_proj(context)
logit = (
self.encoder_proj(encoder_out)
+ self.decoder_proj(decoder_out)
+ self.context_proj(context)
)
else:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
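
A hedged sketch of the projected sum in this joiner hunk, using plain nn.Linear in place of the repository's ScaledLinear, with the shapes, the context tensor, and the tanh before the output projection all assumed rather than taken from this diff.

import torch
import torch.nn as nn

# Invented dimensions for illustration only.
encoder_dim, decoder_dim, context_dim = 384, 512, 512
joiner_dim, vocab_size = 512, 500

encoder_proj = nn.Linear(encoder_dim, joiner_dim)  # repo: ScaledLinear(..., initial_scale=0.25)
decoder_proj = nn.Linear(decoder_dim, joiner_dim)
context_proj = nn.Linear(context_dim, joiner_dim)
output_linear = nn.Linear(joiner_dim, vocab_size)

encoder_out = torch.randn(2, 100, 1, encoder_dim)  # (N, T, 1, encoder_dim)
decoder_out = torch.randn(2, 1, 10, decoder_dim)   # (N, 1, U, decoder_dim)
context = torch.randn(2, 1, 10, context_dim)

logit = (
    encoder_proj(encoder_out)
    + decoder_proj(decoder_out)
    + context_proj(context)
)  # broadcasts to (N, T, U, joiner_dim)
logit = output_linear(torch.tanh(logit))  # (N, T, U, vocab_size); the tanh is an assumption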

View File

@@ -111,7 +111,7 @@ def decoding_normalization(text: str) -> str:
# Keep only alpha-numeric characters, whitespace and apostrophes; hyphens become spaces
text = text.replace("-", " ")
text = re.sub("[^a-zA-Z0-9\s']+", "", text)
text = re.sub(r"[^a-zA-Z0-9\s']+", "", text)
return text
@@ -130,9 +130,10 @@ def word_normalization(word: str) -> str:
if word.isnumeric():
word = num_to_words(int(word))
return str(word).upper()
if word[-2:] == "TH" and word[0].isnumeric(): # e.g 9TH, 6TH
# e.g 9TH, 6TH
if word[-2:] == "TH" and word[0].isnumeric():
return num_to_ordinal_word(int(word[:-2])).upper()
if word[0] == "\'":
if word[0] == "'":
return word[1:]
return word
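
The only behavioural change in these hunks is the raw-string prefix on the regex (the old literal relied on the unrecognized "\s" escape, which newer Pythons warn about). A small hedged example, with the helper name normalize made up, showing what the pattern keeps:

import re

def normalize(text: str) -> str:
    # Simplified mirror of the decoding_normalization body shown above.
    text = text.replace("-", " ")
    # The r-prefix makes the character class explicit and silences the
    # escape-sequence warning the old non-raw string produced.
    text = re.sub(r"[^a-zA-Z0-9\s']+", "", text)
    return text

print(normalize("well-known, e.g. the 9th!"))  # -> "well known eg the 9th"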

File diff suppressed because it is too large

View File

@@ -15,11 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import warnings
from typing import Tuple
import torch
from torch import Tensor, nn
from scaling import (
Balancer,
BiasNorm,
@@ -33,6 +32,7 @@ from scaling import (
SwooshR,
Whiten,
)
from torch import Tensor, nn
class ConvNeXt(nn.Module):
@@ -106,9 +106,7 @@ class ConvNeXt(nn.Module):
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand(
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
)
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
> layerdrop_rate
)
else:
@@ -227,9 +225,7 @@ class Conv2dSubsampling(nn.Module):
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat(
(0.0, 4.0), (20000.0, 8.0), default=4.0
),
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
prob=(0.025, 0.25),
grad_scale=0.02,
)
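
Of the changes in this file, the layer-drop mask in the ConvNeXt hunk is the one piece of behaviour-carrying logic worth illustrating; below is a minimal hedged sketch of how such a per-sample mask is built and applied, with the shapes and the residual-branch name invented.

import torch

layerdrop_rate, batch_size = 0.1, 4
x = torch.randn(batch_size, 64, 19, 50)  # (N, C, H, W); sizes invented
# Each sample survives with probability 1 - layerdrop_rate; the (N, 1, 1, 1)
# shape broadcasts the keep/drop decision over channels, height and width.
mask = (
    torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
    > layerdrop_rate
)
# In the module the mask gates only the residual branch, roughly
# x = bypass + mask * branch_out, where branch_out is a placeholder name here.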

View File

@@ -23,9 +23,9 @@ To run this file, do:
python ./pruned_transducer_stateless4/test_model.py
"""
from train_subformer import get_params, get_transducer_model, get_text_encoder
from zipformer import Zipformer2
from scaling import ScheduledFloat
from train_subformer import get_params, get_text_encoder, get_transducer_model
from zipformer import Zipformer2
def test_model_1():
@@ -55,8 +55,8 @@ def test_model_M():
params.encoder_unmasked_dims = "256,256,256,256,256"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,15,15"
params.text_encoder_dim = (192,192,256,384)
params.text_encoder_dim = (192, 192, 256, 384)
params.decoder_dim = 512
params.joiner_dim = 512
model = Zipformer2(
@@ -83,12 +83,12 @@ def test_model_M():
)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model = Zipformer2(
output_downsampling_factor=8,
downsampling_factor=(1, 2, 4, 8),
num_encoder_layers=(2, 4, 6, 6),
encoder_dim=(256,256,384,512),
encoder_dim=(256, 256, 384, 512),
encoder_unmasked_dim=(196, 196, 256, 256),
query_head_dim=(32, 32, 32, 32),
pos_head_dim=(4, 4, 4, 4),