Mirror of https://github.com/k2-fsa/icefall.git
Synced 2025-09-06 23:54:17 +00:00

Commit 96977c9ddd: "fix for black"
Parent: e32bda6a7b
@@ -17,9 +17,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from scaling import Balancer


 class Decoder(nn.Module):
     """This class modifies the stateless decoder from the following paper:

@@ -61,10 +61,15 @@ class Decoder(nn.Module):
         )
         # the balancers are to avoid any drift in the magnitude of the
         # embeddings, which would interact badly with parameter averaging.
-        self.balancer = Balancer(decoder_dim, channel_dim=-1,
-                                 min_positive=0.0, max_positive=1.0,
-                                 min_abs=0.5, max_abs=1.0,
-                                 prob=0.05)
+        self.balancer = Balancer(
+            decoder_dim,
+            channel_dim=-1,
+            min_positive=0.0,
+            max_positive=1.0,
+            min_abs=0.5,
+            max_abs=1.0,
+            prob=0.05,
+        )

         self.blank_id = blank_id

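The comment in this hunk explains the purpose: the Balancer keeps the per-channel statistics of the embedding output in a healthy range (its arguments suggest a mean absolute value between min_abs=0.5 and max_abs=1.0), so that magnitude drift does not interact badly with parameter averaging. The sketch below is a hypothetical helper that only measures those statistics on an embedding output; it is not icefall's Balancer, which acts on gradients during training.

import torch
import torch.nn as nn

# Hypothetical helper: it only *measures* the per-channel statistics that the
# Balancer above constrains; it does not modify activations or gradients.
def embedding_channel_stats(x: torch.Tensor, channel_dim: int = -1):
    num_channels = x.shape[channel_dim]
    x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
    mean_abs = x.abs().mean(dim=0)               # compare with min_abs=0.5 .. max_abs=1.0
    frac_positive = (x > 0).float().mean(dim=0)  # compare with min_positive .. max_positive
    return mean_abs, frac_positive

decoder_dim, vocab_size = 512, 500               # illustrative sizes
embedding = nn.Embedding(vocab_size, decoder_dim)
y = torch.randint(0, vocab_size, (4, 10))        # (batch, U) token ids
mean_abs, frac_pos = embedding_channel_stats(embedding(y))
print(mean_abs.mean().item(), frac_pos.mean().item())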
@@ -78,14 +83,18 @@ class Decoder(nn.Module):
             out_channels=decoder_dim,
             kernel_size=context_size,
             padding=0,
-            groups=decoder_dim//4, # group size == 4
+            groups=decoder_dim // 4,  # group size == 4
             bias=False,
         )
-        self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
-                                  min_positive=0.0, max_positive=1.0,
-                                  min_abs=0.5, max_abs=1.0,
-                                  prob=0.05)
-
+        self.balancer2 = Balancer(
+            decoder_dim,
+            channel_dim=-1,
+            min_positive=0.0,
+            max_positive=1.0,
+            min_abs=0.5,
+            max_abs=1.0,
+            prob=0.05,
+        )

     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """

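The groups=decoder_dim // 4 argument (unchanged here apart from spacing) makes the context convolution grouped, with four input channels per group, as the inline comment notes. A minimal plain-PyTorch sketch, assuming decoder_dim = 512 and context_size = 2 purely for illustration:

import torch
import torch.nn as nn

decoder_dim, context_size = 512, 2  # illustrative values

# groups = decoder_dim // 4 means each output channel mixes only 4 input
# channels, which is far cheaper than a full 512 x 512 convolution kernel.
conv = nn.Conv1d(
    in_channels=decoder_dim,
    out_channels=decoder_dim,
    kernel_size=context_size,
    padding=0,
    groups=decoder_dim // 4,  # group size == 4
    bias=False,
)

x = torch.randn(8, decoder_dim, 10)  # (batch, channels, U) after permute(0, 2, 1)
y = conv(x)
print(y.shape)                       # torch.Size([8, 512, 9]): U shrinks by context_size - 1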
@@ -108,9 +117,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
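The call being reformatted here left-pads the time axis with context_size - 1 frames, so that the causal context convolution produces one output per input position during training. A self-contained sketch of that call (shapes are illustrative):

import torch
import torch.nn.functional as F

context_size = 2
embedding_out = torch.randn(8, 512, 10)  # (batch, decoder_dim, U) after permute(0, 2, 1)

# pad=(left, right) applies to the last dimension, so this prepends
# context_size - 1 zero frames along U before the causal convolution.
padded = F.pad(embedding_out, pad=(context_size - 1, 0))
print(padded.shape)  # torch.Size([8, 512, 11])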
@@ -16,9 +16,7 @@

 import torch
 import torch.nn as nn
-from scaling import (
-    ScaledLinear
-)
+from scaling import ScaledLinear


 class Joiner(nn.Module):
@@ -28,7 +26,7 @@ class Joiner(nn.Module):
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
-        context_dim: int=512,
+        context_dim: int = 512,
         context_injection: bool = False,
     ):
         super().__init__()
@@ -37,7 +35,9 @@ class Joiner(nn.Module):
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)
         if context_injection:
-            self.context_proj = ScaledLinear(context_dim, joiner_dim, initial_scale=0.25)
+            self.context_proj = ScaledLinear(
+                context_dim, joiner_dim, initial_scale=0.25
+            )
         else:
             self.context_proj = None

@@ -68,7 +68,11 @@ class Joiner(nn.Module):

         if project_input:
             if context:
-                logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + self.context_proj(context)
+                logit = (
+                    self.encoder_proj(encoder_out)
+                    + self.decoder_proj(decoder_out)
+                    + self.context_proj(context)
+                )
             else:
                 logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
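The joiner change is again purely cosmetic: the logit is the sum of the projected encoder, decoder, and optional context representations. A minimal sketch using nn.Linear as a stand-in for icefall's ScaledLinear (assumed here to differ only in its initialization scale); the trailing tanh and output projection follow the usual icefall joiner design and are an assumption, not part of this diff:

import torch
import torch.nn as nn

encoder_dim, decoder_dim, context_dim, joiner_dim, vocab_size = 384, 512, 512, 512, 500

# nn.Linear stands in for ScaledLinear (assumed: same forward, scaled init).
encoder_proj = nn.Linear(encoder_dim, joiner_dim)
decoder_proj = nn.Linear(decoder_dim, joiner_dim)
context_proj = nn.Linear(context_dim, joiner_dim)
output_linear = nn.Linear(joiner_dim, vocab_size)

encoder_out = torch.randn(8, 20, encoder_dim)
decoder_out = torch.randn(8, 20, decoder_dim)
context = torch.randn(8, 20, context_dim)

logit = (
    encoder_proj(encoder_out)
    + decoder_proj(decoder_out)
    + context_proj(context)
)
logits = output_linear(torch.tanh(logit))  # assumption: tanh before the output layer
print(logits.shape)  # torch.Size([8, 20, 500])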
@@ -111,7 +111,7 @@ def decoding_normalization(text: str) -> str:

     # Only keep all alpha-numeric characters, hypen and apostrophe
     text = text.replace("-", " ")
-    text = re.sub("[^a-zA-Z0-9\s']+", "", text)
+    text = re.sub(r"[^a-zA-Z0-9\s']+", "", text)
     return text


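The regex fix switches to a raw string. Because \s is not a recognized Python string escape, the old and new patterns contain exactly the same characters; the raw string only avoids the invalid-escape warning that recent Python versions emit for "\s" in a normal string literal. A quick self-contained check:

import re

text = "Hello, world - it's 9:30 AM!"
text = text.replace("-", " ")
# r"[^a-zA-Z0-9\s']+" matches the exact same characters as the old non-raw
# pattern; the raw string just silences the invalid-escape warning for "\s".
cleaned = re.sub(r"[^a-zA-Z0-9\s']+", "", text)
print(cleaned)  # Hello world   it's 930 AM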
@@ -130,9 +130,10 @@ def word_normalization(word: str) -> str:
     if word.isnumeric():
         word = num_to_words(int(word))
         return str(word).upper()
-    if word[-2:] == "TH" and word[0].isnumeric():  # e.g 9TH, 6TH
+    # e.g 9TH, 6TH
+    if word[-2:] == "TH" and word[0].isnumeric():
         return num_to_ordinal_word(int(word[:-2])).upper()
-    if word[0] == "\'":
+    if word[0] == "'":
         return word[1:]

     return word
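The moved "# e.g 9TH, 6TH" comment and the "\'" to "'" change are cosmetic (the two string literals are identical). For context, a hedged sketch of the two branches touched here, with a hypothetical stand-in for num_to_ordinal_word (the real helper is defined elsewhere in this file):

# Hypothetical stand-in for the ordinal helper used in word_normalization().
def num_to_ordinal_word(n: int) -> str:
    ordinals = {6: "sixth", 9: "ninth"}
    return ordinals.get(n, f"{n}th")

def normalize_ordinal(word: str) -> str:
    # e.g. "9TH" -> "NINTH", mirroring the branch shown in the diff
    if word[-2:] == "TH" and word[0].isnumeric():
        return num_to_ordinal_word(int(word[:-2])).upper()
    if word[0] == "'":
        return word[1:]
    return word

print(normalize_ordinal("9TH"))     # NINTH
print(normalize_ordinal("'CAUSE"))  # CAUSE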
(File diff suppressed because it is too large.)
@@ -15,11 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Tuple
 import warnings
+from typing import Tuple

 import torch
-from torch import Tensor, nn
 from scaling import (
     Balancer,
     BiasNorm,
@@ -33,6 +32,7 @@ from scaling import (
     SwooshR,
     Whiten,
 )
+from torch import Tensor, nn


 class ConvNeXt(nn.Module):
@@ -106,9 +106,7 @@ class ConvNeXt(nn.Module):
         if layerdrop_rate != 0.0:
             batch_size = x.shape[0]
             mask = (
-                torch.rand(
-                    (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
-                )
+                torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                 > layerdrop_rate
             )
         else:
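The mask here implements per-sample layer drop: one Bernoulli draw per example in the batch, broadcast over channels, frequency and time, keeping the layer's output with probability 1 - layerdrop_rate. A self-contained sketch of the mechanism; the residual fallback at the end is an assumption about how such a mask is typically applied, not a copy of ConvNeXt's forward:

import torch

layerdrop_rate = 0.1
x = torch.randn(8, 384, 19, 100)  # (batch, channels, freq, time), illustrative shape

# One draw per example; keep the layer's output with probability 1 - layerdrop_rate.
mask = (
    torch.rand((x.shape[0], 1, 1, 1), dtype=x.dtype, device=x.device)
    > layerdrop_rate
)

layer_out = torch.relu(x)         # stand-in for the ConvNeXt block's output
out = x + mask * (layer_out - x)  # dropped examples fall back to the block's input
print(mask.flatten())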
@@ -227,9 +225,7 @@ class Conv2dSubsampling(nn.Module):
         # many copies of this extra gradient term.
         self.out_whiten = Whiten(
             num_groups=1,
-            whitening_limit=ScheduledFloat(
-                (0.0, 4.0), (20000.0, 8.0), default=4.0
-            ),
+            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
             prob=(0.025, 0.25),
             grad_scale=0.02,
         )

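ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0) is a value that depends on the global batch count: 4.0 at batch 0, rising to 8.0 by batch 20000, and the default when no batch count has been set. A hypothetical stand-in, assuming piecewise-linear interpolation between the given (batch, value) points as described in icefall's scaling.py:

# Hypothetical stand-in for scaling.ScheduledFloat: a piecewise-linear
# function of the global batch count, clamped at the end points.
def scheduled_float(batch_count, points, default=None):
    if batch_count is None:
        return default
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_count <= x1:
            t = (batch_count - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)

points = [(0.0, 4.0), (20000.0, 8.0)]
for step in (None, 0, 10000, 50000):
    print(step, scheduled_float(step, points, default=4.0))
# None -> 4.0 (default), 0 -> 4.0, 10000 -> 6.0, 50000 -> 8.0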
@@ -23,9 +23,9 @@ To run this file, do:
     python ./pruned_transducer_stateless4/test_model.py
 """

-from train_subformer import get_params, get_transducer_model, get_text_encoder
-from zipformer import Zipformer2
 from scaling import ScheduledFloat
+from train_subformer import get_params, get_text_encoder, get_transducer_model
+from zipformer import Zipformer2


 def test_model_1():
@@ -55,8 +55,8 @@ def test_model_M():
     params.encoder_unmasked_dims = "256,256,256,256,256"
     params.zipformer_downsampling_factors = "1,2,4,8,2"
     params.cnn_module_kernels = "31,31,15,15"
-
-    params.text_encoder_dim = (192,192,256,384)
-
+    params.text_encoder_dim = (192, 192, 256, 384)
     params.decoder_dim = 512
     params.joiner_dim = 512
     model = Zipformer2(
@@ -83,12 +83,12 @@ def test_model_M():
     )
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")


     model = Zipformer2(
         output_downsampling_factor=8,
         downsampling_factor=(1, 2, 4, 8),
         num_encoder_layers=(2, 4, 6, 6),
-        encoder_dim=(256,256,384,512),
+        encoder_dim=(256, 256, 384, 512),
         encoder_unmasked_dim=(196, 196, 256, 256),
         query_head_dim=(32, 32, 32, 32),
         pos_head_dim=(4, 4, 4, 4),
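In these tests, some hyper-parameters are comma-separated strings (e.g. params.zipformer_downsampling_factors = "1,2,4,8,2") while others are tuples with one entry per encoder stack (e.g. encoder_dim=(256, 256, 384, 512)). A small hypothetical helper showing how the string form can be turned into the tuple form, roughly what the training scripts do before constructing Zipformer2:

from typing import Tuple

# Hypothetical helper: turn "256,256,384,512" into (256, 256, 384, 512),
# one value per encoder stack.
def to_int_tuple(s: str) -> Tuple[int, ...]:
    return tuple(int(x) for x in s.split(","))

encoder_dim = to_int_tuple("256,256,384,512")
downsampling_factor = to_int_tuple("1,2,4,8")
assert len(encoder_dim) == len(downsampling_factor)  # one entry per stack
print(encoder_dim)          # (256, 256, 384, 512)
print(downsampling_factor)  # (1, 2, 4, 8)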