mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-08 08:34:19 +00:00)

fix for black

parent e32bda6a7b
commit 96977c9ddd
@@ -17,9 +17,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from scaling import Balancer
 
 
 class Decoder(nn.Module):
     """This class modifies the stateless decoder from the following paper:
@@ -61,10 +61,15 @@ class Decoder(nn.Module):
         )
         # the balancers are to avoid any drift in the magnitude of the
         # embeddings, which would interact badly with parameter averaging.
-        self.balancer = Balancer(decoder_dim, channel_dim=-1,
-                                 min_positive=0.0, max_positive=1.0,
-                                 min_abs=0.5, max_abs=1.0,
-                                 prob=0.05)
+        self.balancer = Balancer(
+            decoder_dim,
+            channel_dim=-1,
+            min_positive=0.0,
+            max_positive=1.0,
+            min_abs=0.5,
+            max_abs=1.0,
+            prob=0.05,
+        )
 
         self.blank_id = blank_id
 
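As context for the change above: the Balancer (defined in this recipe's scaling.py) returns its input unchanged in the forward direction and only modifies gradients, nudging the embedding activations to stay within the configured ranges. A minimal sketch, assuming scaling.py is importable and using a hypothetical decoder_dim of 512:

    import torch
    from scaling import Balancer  # this recipe's scaling.py

    balancer = Balancer(
        512,  # decoder_dim, hypothetical
        channel_dim=-1,
        min_positive=0.0,
        max_positive=1.0,
        min_abs=0.5,
        max_abs=1.0,
        prob=0.05,
    )
    emb = torch.randn(4, 10, 512)  # (batch, time, decoder_dim)
    out = balancer(emb)  # same values as emb; the constraints act on the backward pass
    print(out.shape)  # torch.Size([4, 10, 512])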
@@ -78,14 +83,18 @@ class Decoder(nn.Module):
                 out_channels=decoder_dim,
                 kernel_size=context_size,
                 padding=0,
-                groups=decoder_dim//4, # group size == 4
+                groups=decoder_dim // 4,  # group size == 4
                 bias=False,
             )
-            self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
-                                      min_positive=0.0, max_positive=1.0,
-                                      min_abs=0.5, max_abs=1.0,
-                                      prob=0.05)
+            self.balancer2 = Balancer(
+                decoder_dim,
+                channel_dim=-1,
+                min_positive=0.0,
+                max_positive=1.0,
+                min_abs=0.5,
+                max_abs=1.0,
+                prob=0.05,
+            )
 
     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
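One detail worth unpacking in the hunk above: groups=decoder_dim // 4 splits the convolution into decoder_dim / 4 groups, so each group spans 4 input channels, which is what the "group size == 4" comment refers to. A minimal sketch with hypothetical sizes (decoder_dim=512, context_size=2):

    import torch
    import torch.nn as nn

    decoder_dim, context_size = 512, 2
    conv = nn.Conv1d(
        in_channels=decoder_dim,
        out_channels=decoder_dim,
        kernel_size=context_size,
        padding=0,
        groups=decoder_dim // 4,  # 128 groups, 4 input channels each
        bias=False,
    )
    x = torch.randn(1, decoder_dim, 10)  # (batch, channels, time)
    print(conv(x).shape)      # torch.Size([1, 512, 9]); no padding, so time shrinks by kernel_size - 1
    print(conv.weight.shape)  # torch.Size([512, 4, 2]); 4 input channels per group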
@@ -108,9 +117,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
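The call being collapsed onto one line left-pads the time axis by context_size - 1, so the causal Conv1d declared earlier yields one output per input frame during training. A minimal sketch with hypothetical shapes:

    import torch
    import torch.nn.functional as F

    context_size = 2
    embedding_out = torch.randn(1, 512, 10)  # (batch, decoder_dim, time), after the permute
    padded = F.pad(embedding_out, pad=(context_size - 1, 0))  # pad only the left side of the last dim
    print(padded.shape)  # torch.Size([1, 512, 11]); a kernel of size 2 then produces 10 outputs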
@@ -16,9 +16,7 @@
 
 import torch
 import torch.nn as nn
-from scaling import (
-    ScaledLinear
-)
+from scaling import ScaledLinear
 
 
 class Joiner(nn.Module):
@@ -28,7 +26,7 @@ class Joiner(nn.Module):
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
-        context_dim: int=512,
+        context_dim: int = 512,
         context_injection: bool = False,
     ):
         super().__init__()
@@ -37,7 +35,9 @@ class Joiner(nn.Module):
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)
         if context_injection:
-            self.context_proj = ScaledLinear(context_dim, joiner_dim, initial_scale=0.25)
+            self.context_proj = ScaledLinear(
+                context_dim, joiner_dim, initial_scale=0.25
+            )
         else:
             self.context_proj = None
 
@@ -68,7 +68,11 @@ class Joiner(nn.Module):
 
         if project_input:
             if context:
-                logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + self.context_proj(context)
+                logit = (
+                    self.encoder_proj(encoder_out)
+                    + self.decoder_proj(decoder_out)
+                    + self.context_proj(context)
+                )
             else:
                 logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
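The parenthesized form above is a plain broadcast sum of the per-branch projections into joiner_dim ahead of the final classifier. A minimal sketch with hypothetical dimensions, using ordinary nn.Linear layers as stand-ins for ScaledLinear:

    import torch
    import torch.nn as nn

    encoder_dim, decoder_dim, joiner_dim, vocab_size = 384, 512, 512, 500
    encoder_proj = nn.Linear(encoder_dim, joiner_dim)
    decoder_proj = nn.Linear(decoder_dim, joiner_dim)
    output_linear = nn.Linear(joiner_dim, vocab_size)

    enc = torch.randn(2, 100, 1, encoder_dim)  # (N, T, 1, encoder_dim)
    dec = torch.randn(2, 1, 5, decoder_dim)    # (N, 1, U, decoder_dim)
    logit = encoder_proj(enc) + decoder_proj(dec)  # broadcasts to (2, 100, 5, joiner_dim)
    print(output_linear(torch.tanh(logit)).shape)  # torch.Size([2, 100, 5, 500])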
@@ -111,7 +111,7 @@ def decoding_normalization(text: str) -> str:
 
     # Only keep all alpha-numeric characters, hypen and apostrophe
     text = text.replace("-", " ")
-    text = re.sub("[^a-zA-Z0-9\s']+", "", text)
+    text = re.sub(r"[^a-zA-Z0-9\s']+", "", text)
     return text
 
 
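The only change in this hunk is switching to a raw string: it keeps the \s inside the character class from being parsed as an invalid string escape (a warning on recent Python versions) while leaving the regex behaviour untouched. A small example of the normalization on a hypothetical input:

    import re

    text = "well-known (test)!"
    text = text.replace("-", " ")
    text = re.sub(r"[^a-zA-Z0-9\s']+", "", text)
    print(text)  # well known test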
@@ -130,9 +130,10 @@ def word_normalization(word: str) -> str:
     if word.isnumeric():
         word = num_to_words(int(word))
         return str(word).upper()
-    if word[-2:] == "TH" and word[0].isnumeric(): # e.g 9TH, 6TH
+    # e.g 9TH, 6TH
+    if word[-2:] == "TH" and word[0].isnumeric():
         return num_to_ordinal_word(int(word[:-2])).upper()
-    if word[0] == "\'":
+    if word[0] == "'":
         return word[1:]
 
     return word
File diff suppressed because it is too large
@@ -15,11 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
 import warnings
+from typing import Tuple
 
 import torch
-from torch import Tensor, nn
 from scaling import (
     Balancer,
     BiasNorm,
@@ -33,6 +32,7 @@ from scaling import (
     SwooshR,
     Whiten,
 )
+from torch import Tensor, nn
 
 
 class ConvNeXt(nn.Module):
@@ -106,9 +106,7 @@ class ConvNeXt(nn.Module):
         if layerdrop_rate != 0.0:
             batch_size = x.shape[0]
             mask = (
-                torch.rand(
-                    (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
-                )
+                torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                 > layerdrop_rate
             )
         else:
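The reformatted expression draws one keep/drop decision per utterance: torch.rand over a (batch, 1, 1, 1) shape compared against layerdrop_rate yields a boolean mask that broadcasts across the remaining dimensions. A minimal sketch with hypothetical values:

    import torch

    layerdrop_rate = 0.1
    batch_size = 4
    mask = torch.rand((batch_size, 1, 1, 1)) > layerdrop_rate
    print(mask.shape, mask.dtype)  # torch.Size([4, 1, 1, 1]) torch.bool
    # On average 90% of the entries are True (keep); multiplying by a
    # (batch, C, H, W) tensor spreads each per-utterance decision over
    # all of its channels and positions.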
@@ -227,9 +225,7 @@ class Conv2dSubsampling(nn.Module):
         # many copies of this extra gradient term.
         self.out_whiten = Whiten(
             num_groups=1,
-            whitening_limit=ScheduledFloat(
-                (0.0, 4.0), (20000.0, 8.0), default=4.0
-            ),
+            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
             prob=(0.025, 0.25),
             grad_scale=0.02,
         )
@@ -23,9 +23,9 @@ To run this file, do:
     python ./pruned_transducer_stateless4/test_model.py
 """
 
-from train_subformer import get_params, get_transducer_model, get_text_encoder
-from zipformer import Zipformer2
 from scaling import ScheduledFloat
+from train_subformer import get_params, get_text_encoder, get_transducer_model
+from zipformer import Zipformer2
 
 
 def test_model_1():
@@ -56,7 +56,7 @@ def test_model_M():
     params.zipformer_downsampling_factors = "1,2,4,8,2"
     params.cnn_module_kernels = "31,31,15,15"
 
-    params.text_encoder_dim = (192,192,256,384)
+    params.text_encoder_dim = (192, 192, 256, 384)
     params.decoder_dim = 512
     params.joiner_dim = 512
     model = Zipformer2(
@@ -88,7 +88,7 @@ def test_model_M():
         output_downsampling_factor=8,
         downsampling_factor=(1, 2, 4, 8),
         num_encoder_layers=(2, 4, 6, 6),
-        encoder_dim=(256,256,384,512),
+        encoder_dim=(256, 256, 384, 512),
         encoder_unmasked_dim=(196, 196, 256, 256),
         query_head_dim=(32, 32, 32, 32),
         pos_head_dim=(4, 4, 4, 4),