Mirror of https://github.com/k2-fsa/icefall.git
Merge 752e16be1038211bada5d4f15eb4b59d3f6ae9f6 into c401a2646b347bf1fff0c2ce1a4ee13b0f482448
This commit is contained in commit 6c98fbc309.
egs/librispeech/ASR/zipformer/frame_reducer.py (new file, 177 lines)
@@ -0,0 +1,177 @@
#!/usr/bin/env python3
#
# Copyright    2022  Xiaomi Corp.        (authors: Yifan Yang,
#                                                  Zengwei Yao,
#                                                  Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from icefall.utils import make_pad_mask

# A frame is considered blank if its blank posterior exceeds this value.
NON_BLANK_THRES = 0.9


class FrameReducer(nn.Module):
    """The encoder output is first used to compute the CTC posterior
    probabilities; then, for each output frame, if its blank posterior
    exceeds the threshold, the frame is simply discarded from the
    encoder output.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        ctc_output: torch.Tensor,
        y_lens: Optional[torch.Tensor] = None,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The shared encoder output with shape [N, T, C].
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames
            in `x` before padding.
          ctc_output:
            The CTC output with shape [N, T, vocab_size].
          y_lens:
            A tensor of shape (batch_size,) containing the number of tokens
            in `y` before padding.
          blank_id:
            The blank id of ctc_output.
        Returns:
          out:
            The frame-reduced encoder output with shape [N, T', C].
          out_lens:
            A tensor of shape (batch_size,) containing the number of frames
            in `out` before padding.
        """
        N, T, C = x.size()

        padding_mask = make_pad_mask(x_lens)
        # Keep a frame if its blank log-posterior is below log(0.9)
        # and it is not a padding frame.
        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) * (
            ~padding_mask
        )

        if y_lens is not None or not self.training:
            # Limit the maximum number of reduced frames, so that at least
            # y_lens frames survive per utterance.
            if y_lens is not None:
                limit_lens = T - y_lens
            else:
                # In eval mode, ensure that completely silent audio
                # does not cause errors.
                limit_lens = T - torch.ones_like(x_lens)
            max_limit_len = limit_lens.max().int()
            # The frames with the highest blank posteriors are the only
            # candidates for removal.
            fake_limit_indexes = torch.topk(
                ctc_output[:, :, blank_id], max_limit_len
            ).indices
            _T = (
                torch.arange(max_limit_len)
                .expand_as(
                    fake_limit_indexes,
                )
                .to(device=x.device)
            )
            # Wrap indices modulo each utterance's own limit, so every row
            # only selects its first `limit_lens` top-blank frames.
            _T = torch.remainder(_T, limit_lens.unsqueeze(1))
            limit_indexes = torch.gather(fake_limit_indexes, 1, _T)
            limit_mask = (
                torch.full_like(
                    non_blank_mask,
                    0,
                    device=x.device,
                ).scatter_(1, limit_indexes, 1)
                == 1
            )

            # Frames outside the candidate set are always kept.
            non_blank_mask = non_blank_mask | ~limit_mask

        out_lens = non_blank_mask.sum(dim=1)
        max_len = out_lens.max()
        pad_lens_list = (
            torch.full_like(
                out_lens,
                max_len.item(),
                device=x.device,
            )
            - out_lens
        )
        max_pad_len = int(pad_lens_list.max().item())

        out = F.pad(x, (0, 0, 0, max_pad_len))

        valid_pad_mask = ~make_pad_mask(pad_lens_list)
        total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)

        out = out[total_valid_mask].reshape(N, -1, C)

        return out, out_lens


if __name__ == "__main__":
    import time

    test_times = 10000
    device = "cuda:0"
    frame_reducer = FrameReducer()

    # non-zero case
    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    # Use log_softmax over random logits so the posteriors are valid
    # log-probabilities (torch.log of raw normal samples yields NaNs
    # for the negative values).
    ctc_output = torch.randn(
        15, 498, 500, dtype=torch.float32, device=device
    ).log_softmax(dim=-1)

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
    print(x_fr.shape)
    print(x_lens_fr)
    print(avg_time / test_times)

    # all-zero case
    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
    print(x_fr.shape)
    print(x_lens_fr)
    print(avg_time / test_times)
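The limit mask built in forward() above is the subtlest part of the module: at most T - y_len frames may be dropped per utterance, so only the T - y_len frames with the highest blank posteriors are candidates, and every other frame is forcibly kept. A minimal single-utterance sketch with toy numbers (plain probabilities instead of log-probabilities for readability; the remainder/gather step, which tiles each utterance's own limit across a batch, is omitted):

import torch

# One utterance, T = 6 frames; y_len = 4 tokens, so at most
# T - y_len = 2 frames may be removed.
blank_post = torch.tensor([[0.9, 0.1, 0.8, 0.2, 0.95, 0.3]])  # (N=1, T)
limit_len = 2

# Only the `limit_len` frames with the highest blank posteriors are
# candidates for removal (here frames 4 and 0).
top = torch.topk(blank_post, limit_len).indices
limit_mask = torch.zeros(1, 6, dtype=torch.bool).scatter_(1, top, True)

# Keep a frame if it looks non-blank OR it is outside the candidate set.
keep = (blank_post < 0.9) | ~limit_mask
print(keep)               # tensor([[False,  True,  True,  True, False,  True]])
print(keep.sum().item())  # 4 -- at least y_len frames always survive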
egs/librispeech/ASR/zipformer/lconv.py (new file, 113 lines)
@@ -0,0 +1,113 @@
# Copyright    2022  Xiaomi Corp.        (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
from scaling import Balancer, ScaledConv1d


class LConv(nn.Module):
    """A convolution module to prevent information loss."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 7,
        bias: bool = True,
    ):
        """
        Args:
          channels:
            Dimension of the input embedding, and of the lconv output.
        """
        super().__init__()
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )

        self.deriv_balancer1 = Balancer(
            2 * channels,
            channel_dim=1,
            min_abs=0.05,
            max_abs=10.0,
            min_positive=0.05,
            max_positive=1.0,
        )

        self.depthwise_conv = nn.Conv1d(
            2 * channels,
            2 * channels,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=2 * channels,
            bias=bias,
        )

        self.deriv_balancer2 = Balancer(
            2 * channels,
            channel_dim=1,
            min_positive=0.05,
            max_positive=1.0,
            min_abs=0.05,
            max_abs=20.0,
        )

        self.pointwise_conv2 = ScaledConv1d(
            2 * channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            initial_scale=0.05,
        )

    def forward(
        self,
        x: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          src_key_padding_mask:
            An optional bool tensor of shape (N, T); True marks padding
            frames, which are zeroed out before the depthwise convolution.
        Returns:
          Return a tensor of shape (N, T, C).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.permute(0, 2, 1)  # (batch, channels, time)

        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)

        x = self.deriv_balancer1(x)

        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        x = self.depthwise_conv(x)

        x = self.deriv_balancer2(x)

        x = self.pointwise_conv2(x)  # (batch, channels, time)

        return x.permute(0, 2, 1)
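LConv keeps the sequence length and width unchanged, so it can sit between the encoder and the frame reducer without touching any other shapes. A minimal shape check on a hypothetical toy batch (assumes this file and the recipe's scaling.py are on the import path):

import torch

from lconv import LConv

lconv = LConv(channels=384)
x = torch.randn(2, 100, 384)  # (N, T, C)

# True marks padding; the second utterance has only 80 valid frames.
pad = torch.zeros(2, 100, dtype=torch.bool)
pad[1, 80:] = True

y = lconv(x, src_key_padding_mask=pad)
print(y.shape)  # torch.Size([2, 100, 384])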
egs/librispeech/ASR/zipformer/model.py
@@ -16,16 +16,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+import warnings
+from typing import List, Optional, Tuple
 
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 
-from icefall.utils import add_sos, make_pad_mask
 from scaling import ScaledLinear
 
+from icefall.utils import add_sos, encode_supervisions, make_pad_mask
 
 
 class AsrModel(nn.Module):
     def __init__(
@@ -34,11 +35,14 @@ class AsrModel(nn.Module):
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
+        lconv: Optional[nn.Module] = None,
+        frame_reducer: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
         vocab_size: int = 500,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        use_bs: bool = True,
     ):
         """A joint CTC & Transducer ASR model.
 
@@ -77,6 +81,10 @@ class AsrModel(nn.Module):
             use_transducer or use_ctc
         ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
 
+        assert (
+            not use_bs or use_ctc
+        ), "Blank Skip needs CTC"
+
         assert isinstance(encoder, EncoderInterface), type(encoder)
 
         self.encoder_embed = encoder_embed
@@ -111,6 +119,11 @@ class AsrModel(nn.Module):
                 nn.LogSoftmax(dim=-1),
             )
 
+        self.use_bs = use_bs
+        if self.use_bs:
+            self.lconv = lconv
+            self.frame_reducer = frame_reducer
+
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -146,8 +159,13 @@ class AsrModel(nn.Module):
         self,
         encoder_out: torch.Tensor,
         encoder_out_lens: torch.Tensor,
-        targets: torch.Tensor,
+        targets: List[List[int]],
         target_lengths: torch.Tensor,
+        supervisions: dict,
+        subsampling_factor: int,
+        ctc_beam_size: int,
+        reduction: str = "sum",
+        warmup: float = 1.0,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute CTC loss.
         Args:
@@ -158,18 +176,60 @@ class AsrModel(nn.Module):
           targets:
-            Target Tensor of shape (sum(target_lengths)). The targets are assumed
-            to be un-padded and concatenated within 1 dimension.
+            A list of token-id lists, one list per utterance, used to build
+            the CTC decoding graphs.
+          supervisions:
+            A dict containing a pair of torch Tensors and a list of transcription strings or token indexes.
+          reduction:
+            Specifies the reduction to apply to the output.
         """
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
+        encoder_out_fr = encoder_out
+        encoder_out_lens_fr = encoder_out_lens
 
-        ctc_loss = torch.nn.functional.ctc_loss(
-            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
-            targets=targets,
-            input_lengths=encoder_out_lens,
-            target_lengths=target_lengths,
-            reduction="sum",
-        )
-        return ctc_loss
+        if self.use_bs and warmup >= 2.0:
+            # lconv
+            encoder_out = self.lconv(
+                x=encoder_out,
+                src_key_padding_mask=make_pad_mask(encoder_out_lens),
+            )
+
+            # frame reduce
+            encoder_out_fr, encoder_out_lens_fr = self.frame_reducer(
+                encoder_out,
+                encoder_out_lens,
+                ctc_output,
+                target_lengths,
+                self.decoder.blank_id,
+            )
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            supervision_segments, token_ids = encode_supervisions(
+                supervisions,
+                subsampling_factor=subsampling_factor,
+                token_ids=targets,
+            )
+
+        # k2.DenseFsaVec requires supervision_segments to be on the CPU
+        supervision_segments = supervision_segments.to("cpu")
+        decoding_graph = k2.ctc_graph(
+            token_ids, modified=False, device=encoder_out.device
+        )
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output,
+            supervision_segments,
+            allow_truncate=subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=ctc_beam_size,
+            reduction=reduction,
+            use_double_scores=True,
+        )
+
+        return ctc_loss, encoder_out_fr, encoder_out_lens_fr
 
     def forward_transducer(
         self,
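For readers used to torch.nn.functional.ctc_loss, which this hunk replaces: encode_supervisions converts the supervision dict into rows of (utterance_index, start_frame, num_frames) plus the token ids, k2.ctc_graph compiles one CTC decoding FSA per utterance, and k2.DenseFsaVec pairs the log-posteriors with those segments. A minimal standalone sketch with random inputs (assumes k2 is installed; every frame is valid here, so allow_truncate is left at its default):

import k2
import torch

N, T, V = 2, 50, 10
log_probs = torch.randn(N, T, V).log_softmax(dim=-1)

# Rows of (utterance_index, start_frame, num_frames). k2 requires an
# int32 tensor on the CPU, which is why the diff above moves
# supervision_segments to "cpu" before building the DenseFsaVec.
supervision_segments = torch.tensor([[0, 0, T], [1, 0, T]], dtype=torch.int32)
token_ids = [[1, 3, 2], [4, 4, 5]]

decoding_graph = k2.ctc_graph(token_ids, modified=False)
dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
loss = k2.ctc_loss(
    decoding_graph=decoding_graph,
    dense_fsa_vec=dense_fsa_vec,
    output_beam=10,
    reduction="sum",
    use_double_scores=True,
)
print(loss)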
@@ -180,6 +240,8 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        delay_penalty: float = 0.0,
+        reduction: str = "sum",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -199,6 +261,8 @@ class AsrModel(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          reduction:
+            Specifies the reduction to apply to the output.
         """
         # Now for the decoder, i.e., the prediction network
         blank_id = self.decoder.blank_id
@@ -226,11 +290,6 @@ class AsrModel(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -240,7 +299,8 @@ class AsrModel(nn.Module):
                 lm_only_scale=lm_scale,
                 am_only_scale=am_scale,
                 boundary=boundary,
-                reduction="sum",
+                delay_penalty=delay_penalty,
+                reduction=reduction,
                 return_grad=True,
             )
 
@@ -273,7 +333,8 @@ class AsrModel(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
-            reduction="sum",
+            delay_penalty=delay_penalty,
+            reduction=reduction,
         )
 
         return simple_loss, pruned_loss
@@ -283,9 +344,15 @@ class AsrModel(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        supervisions: dict,
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        subsampling_factor: int = 4,
+        ctc_beam_size: int = 10,
+        delay_penalty: float = 0.0,
+        reduction: str = "sum",
+        warmup: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -306,10 +373,11 @@ class AsrModel(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          reduction:
+            Specifies the reduction to apply to the output.
         Returns:
           Return the transducer losses and CTC loss,
           in form of (simple_loss, pruned_loss, ctc_loss)
-
         Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
@@ -320,7 +388,7 @@ class AsrModel(nn.Module):
         assert x_lens.ndim == 1, x_lens.shape
         assert y.num_axes == 2, y.num_axes
 
-        assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
+        assert x.size(0) == x_lens.size(0) == y.dim0
 
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
@@ -328,6 +396,22 @@ class AsrModel(nn.Module):
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
 
+        if self.use_ctc:
+            # Compute CTC loss
+            ctc_loss, encoder_out, encoder_out_lens = self.forward_ctc(
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                targets=y.tolist(),
+                target_lengths=y_lens,
+                supervisions=supervisions,
+                subsampling_factor=subsampling_factor,
+                ctc_beam_size=ctc_beam_size,
+                reduction=reduction,
+                warmup=warmup,
+            )
+        else:
+            ctc_loss = torch.empty(0, device=encoder_out.device)
+
         if self.use_transducer:
             # Compute transducer loss
             simple_loss, pruned_loss = self.forward_transducer(
@@ -338,21 +422,11 @@ class AsrModel(nn.Module):
                 prune_range=prune_range,
                 am_scale=am_scale,
                 lm_scale=lm_scale,
+                reduction=reduction,
+                delay_penalty=delay_penalty,
             )
         else:
             simple_loss = torch.empty(0)
             pruned_loss = torch.empty(0)
 
-        if self.use_ctc:
-            # Compute CTC loss
-            targets = y.values
-            ctc_loss = self.forward_ctc(
-                encoder_out=encoder_out,
-                encoder_out_lens=encoder_out_lens,
-                targets=targets,
-                target_lengths=y_lens,
-            )
-        else:
-            ctc_loss = torch.empty(0)
-
         return simple_loss, pruned_loss, ctc_loss
egs/librispeech/ASR/zipformer/train.py
@@ -67,7 +67,9 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
+from frame_reducer import FrameReducer
 from joiner import Joiner
+from lconv import LConv
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
@@ -258,6 +260,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use CTC head.",
     )
 
+    parser.add_argument(
+        "--use-bs",
+        type=str2bool,
+        default=False,
+        help="If True, use blank-skip.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -529,6 +538,7 @@ def get_params() -> AttributeDict:
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
+            "ctc_beam_size": 10,
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 2000,
@@ -583,6 +593,16 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     return encoder
 
 
+def get_lconv(params: AttributeDict) -> nn.Module:
+    lconv = LConv(channels=max(_to_int_tuple(params.encoder_dim)))
+    return lconv
+
+
+def get_frame_reducer(params: AttributeDict) -> nn.Module:
+    frame_reducer = FrameReducer()
+    return frame_reducer
+
+
 def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
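Note that encoder_dim is a comma-separated string in this recipe, which is why get_lconv goes through _to_int_tuple, exactly as get_model does below; calling max() on the raw string would compare characters rather than layer widths. A quick standalone illustration (hypothetical values in the recipe's default format):

encoder_dim = "192,256,384,512,384,256"

def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

print(max(encoder_dim))                 # '9'  -- the largest character
print(max(_to_int_tuple(encoder_dim)))  # 512  -- the largest width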
@@ -620,16 +640,24 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None
 
+    lconv, frame_reducer = None, None
+    if params.use_bs:
+        lconv = get_lconv(params)
+        frame_reducer = get_frame_reducer(params)
+
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        lconv=lconv,
+        frame_reducer=frame_reducer,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        use_bs=params.use_bs,
     )
     return model
 
@@ -756,6 +784,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    warmup: float,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -796,9 +825,12 @@ def compute_loss(
             x=feature,
             x_lens=feature_lens,
             y=y,
+            supervisions=supervisions,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            ctc_beam_size=params.ctc_beam_size,
+            warmup=warmup,
         )
 
     loss = 0.0
@@ -859,6 +891,7 @@ def compute_validation_loss(
             sp=sp,
             batch=batch,
             is_training=False,
+            warmup=(params.batch_idx_train / params.warm_step),
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -953,6 +986,7 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
+                warmup=(params.batch_idx_train / params.warm_step),
             )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
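Putting the warmup plumbing together: compute_loss passes warmup = params.batch_idx_train / params.warm_step, and forward_ctc only enables the lconv plus frame-reducer path once warmup >= 2.0, i.e. after two warm-up periods (4000 batches with the default warm_step of 2000 from get_params()). A tiny sketch of the gating:

warm_step = 2000  # default in get_params() above

def blank_skip_active(batch_idx_train: int) -> bool:
    # Mirrors the `warmup >= 2.0` check in AsrModel.forward_ctc.
    warmup = batch_idx_train / warm_step
    return warmup >= 2.0

for step in (0, 2000, 3999, 4000):
    print(step, blank_skip_active(step))  # switches on at step 4000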