Merge 752e16be1038211bada5d4f15eb4b59d3f6ae9f6 into c401a2646b347bf1fff0c2ce1a4ee13b0f482448

This commit is contained in:
Erwan Zerhouni 2024-01-26 16:24:46 +08:00 committed by GitHub
commit 6c98fbc309
4 changed files with 594 additions and 196 deletions


@@ -0,0 +1,177 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from icefall.utils import make_pad_mask
NON_BLANK_THRES = 0.9
class FrameReducer(nn.Module):
"""The encoder output is first used to calculate
the CTC posterior probability; then for each output frame,
if its blank posterior is bigger than some thresholds,
it will be simply discarded from the encoder output.
"""
def __init__(
self,
):
super().__init__()
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
ctc_output: torch.Tensor,
y_lens: Optional[torch.Tensor] = None,
blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The shared encoder output with shape [N, T, C].
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
ctc_output:
The CTC output with shape [N, T, vocab_size].
y_lens:
A tensor of shape (batch_size,) containing the number of frames in
`y` before padding.
blank_id:
The blank id of ctc_output.
Returns:
out:
The frame reduced encoder output with shape [N, T', C].
out_lens:
A tensor of shape (batch_size,) containing the number of frames in
`out` before padding.
"""
N, T, C = x.size()
padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) * (
~padding_mask
)
if y_lens is not None or self.training is False:
# Limit the maximum number of reduced frames
if y_lens is not None:
limit_lens = T - y_lens
else:
# In eval mode, always keep at least one frame per utterance so that completely silent audio does not cause errors
limit_lens = T - torch.ones_like(x_lens)
max_limit_len = limit_lens.max().int()
fake_limit_indexes = torch.topk(
ctc_output[:, :, blank_id], max_limit_len
).indices
_T = (
torch.arange(max_limit_len)
.expand_as(
fake_limit_indexes,
)
.to(device=x.device)
)
_T = torch.remainder(_T, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, _T)
limit_mask = (
torch.full_like(
non_blank_mask,
0,
device=x.device,
).scatter_(1, limit_indexes, 1)
== 1
)
non_blank_mask = non_blank_mask | ~limit_mask
out_lens = non_blank_mask.sum(dim=1)
max_len = out_lens.max()
pad_lens_list = (
torch.full_like(
out_lens,
max_len.item(),
device=x.device,
)
- out_lens
)
max_pad_len = int(pad_lens_list.max().item())
out = F.pad(x, (0, 0, 0, max_pad_len))
valid_pad_mask = ~make_pad_mask(pad_lens_list)
total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
out = out[total_valid_mask].reshape(N, -1, C)
return out, out_lens
if __name__ == "__main__":
import time
test_times = 10000
device = "cuda:0"
frame_reducer = FrameReducer()
# non zero case
x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
ctc_output = torch.log(
torch.rand(15, 498, 500, dtype=torch.float32, device=device),  # rand (not randn) so that log() stays finite
)
avg_time = 0
for i in range(test_times):
torch.cuda.synchronize(device=x.device)
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
torch.cuda.synchronize(device=x.device)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)
print(x_lens_fr)
print(avg_time / test_times)
# all zero case
x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)
avg_time = 0
for i in range(test_times):
torch.cuda.synchronize(device=x.device)
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
torch.cuda.synchronize(device=x.device)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)
print(x_lens_fr)
print(avg_time / test_times)
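
As a quick illustration of the keep/drop rule applied by the FrameReducer above, here is a small standalone sketch. The posterior values are made up for illustration, and only the thresholding step is shown; padding and the limit on removable frames are ignored.

# Illustrative only: the frame-keeping rule used by FrameReducer.
import math

import torch

NON_BLANK_THRES = 0.9
blank_id = 0

# Hypothetical CTC log-posteriors for one utterance: 5 frames, 3 tokens
# (blank, "a", "b"); each row sums to 1 before the log.
ctc_output = torch.tensor(
    [
        [0.95, 0.03, 0.02],  # confident blank -> dropped
        [0.20, 0.70, 0.10],  # non-blank       -> kept
        [0.91, 0.05, 0.04],  # confident blank -> dropped
        [0.50, 0.30, 0.20],  # uncertain       -> kept
        [0.10, 0.10, 0.80],  # non-blank       -> kept
    ]
).log()

# A frame survives when its blank log-posterior is below log(0.9),
# i.e. when the blank probability is below NON_BLANK_THRES.
non_blank_mask = ctc_output[:, blank_id] < math.log(NON_BLANK_THRES)
print(non_blank_mask)  # tensor([False,  True, False,  True,  True])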


@@ -0,0 +1,113 @@
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from acoustic_model.utils_py.scaling_zipformer import Balancer, ScaledConv1d
class LConv(nn.Module):
"""A convolution module to prevent information loss."""
def __init__(
self,
channels: int,
kernel_size: int = 7,
bias: bool = True,
):
"""
Args:
channels:
Dimension of the input embedding, and of the lconv output.
"""
super().__init__()
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.deriv_balancer1 = Balancer(
2 * channels,
channel_dim=1,
min_abs=0.05,
max_abs=10.0,
min_positive=0.05,
max_positive=1.0,
)
self.depthwise_conv = nn.Conv1d(
2 * channels,
2 * channels,
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=2 * channels,
bias=bias,
)
self.deriv_balancer2 = Balancer(
2 * channels,
channel_dim=1,
min_positive=0.05,
max_positive=1.0,
min_abs=0.05,
max_abs=20.0,
)
self.pointwise_conv2 = ScaledConv1d(
2 * channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
initial_scale=0.05,
)
def forward(
self,
x: torch.Tensor,
src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x: A 3-D tensor of shape (N, T, C).
Returns:
Return a tensor of shape (N, T, C).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(0, 2, 1) # (#batch, channels, time).
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = self.deriv_balancer1(x)
if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)
x = self.pointwise_conv2(x) # (batch, channels, time)
return x.permute(0, 2, 1)
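
A minimal smoke test for the LConv module above, assuming its scaling imports resolve and that the module is importable as in train.py below (from lconv import LConv); the shapes here are arbitrary.

# Illustrative only: run a batch with padding through LConv.
import torch

from icefall.utils import make_pad_mask  # same helper used by frame_reducer.py
from lconv import LConv                   # import path as used in train.py

lconv = LConv(channels=384)
x = torch.randn(2, 100, 384)              # (N, T, C) encoder output
x_lens = torch.tensor([100, 80])          # valid frames per utterance
pad_mask = make_pad_mask(x_lens)          # (N, T), True at padded positions
y = lconv(x, src_key_padding_mask=pad_mask)
print(y.shape)                            # torch.Size([2, 100, 384])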


@@ -16,16 +16,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import warnings
from typing import List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
from icefall.utils import add_sos, encode_supervisions, make_pad_mask
class AsrModel(nn.Module):
def __init__(
@@ -34,11 +35,14 @@ class AsrModel(nn.Module):
encoder: EncoderInterface,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
lconv: Optional[nn.Module] = None,
frame_reducer: Optional[nn.Module] = None,
encoder_dim: int = 384,
decoder_dim: int = 512,
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
use_bs: bool = True,
):
"""A joint CTC & Transducer ASR model.
@@ -77,6 +81,10 @@ class AsrModel(nn.Module):
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
assert (
use_ctc or not use_bs
), "Blank Skip needs CTC"
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder_embed = encoder_embed
@@ -111,6 +119,11 @@ class AsrModel(nn.Module):
nn.LogSoftmax(dim=-1),
)
self.use_bs = use_bs
if self.use_bs:
self.lconv = lconv
self.frame_reducer = frame_reducer
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -142,217 +155,278 @@ class AsrModel(nn.Module):
return encoder_out, encoder_out_lens
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: List[int],
target_lengths: torch.Tensor,
supervisions: dict,
subsampling_factor: int,
ctc_beam_size: int,
reduction: str = "sum",
warmup: float = 1.0,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Token ids of the reference transcripts, one list per utterance; they are
used (via encode_supervisions) to build the CTC decoding graph.
supervisions:
A supervision dict; encode_supervisions() converts it into a tensor of
supervision segments and a list of transcription strings or token indexes.
reduction:
Specifies the reduction to apply to the output
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
encoder_out_fr = encoder_out
encoder_out_lens_fr = encoder_out_lens
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
if self.use_bs and warmup >= 2.0:
# lconv: smooth the encoder output to limit information loss before frames are dropped
encoder_out = self.lconv(
x=encoder_out,
src_key_padding_mask=make_pad_mask(encoder_out_lens),
)
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for the rnnt loss; it specifies how many symbols (context)
we consider for each frame when computing the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# frame reduction: drop frames whose blank posterior exceeds the threshold
encoder_out_fr, encoder_out_lens_fr = self.frame_reducer(
encoder_out,
encoder_out_lens,
ctc_output,
target_lengths,
self.decoder.blank_id,
)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
supervision_segments, token_ids = encode_supervisions(
supervisions,
subsampling_factor=subsampling_factor,
token_ids=targets,
)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
# Note: k2.DenseFsaVec expects supervision_segments on the CPU; it crashes without this line
supervision_segments = supervision_segments.to("cpu")
decoding_graph = k2.ctc_graph(
token_ids, modified=False, device=encoder_out.device
)
dense_fsa_vec = k2.DenseFsaVec(
ctc_output,
supervision_segments,
allow_truncate=subsampling_factor - 1,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=ctc_beam_size,
reduction=reduction,
use_double_scores=True,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
return ctc_loss, encoder_out_fr, encoder_out_lens_fr
# logits : [B, T, prune_range, vocab_size]
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
delay_penalty: float = 0.0,
reduction: str = "sum",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for the rnnt loss; it specifies how many symbols (context)
we consider for each frame when computing the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
reduction:
Specifies the reduction to apply to the output
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return simple_loss, pruned_loss
# logits : [B, T, prune_range, vocab_size]
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for the rnnt loss; it specifies how many symbols (context)
we consider for each frame when computing the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer losses and the CTC loss,
in the form (simple_loss, pruned_loss, ctc_loss)
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
Note:
Regarding am_scale & lm_scale, they make the loss function take the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
)
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
return simple_loss, pruned_loss
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
supervisions: dict,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
subsampling_factor: int = 4,
ctc_beam_size: int = 10,
delay_penalty: float = 0.0,
reduction: str = "sum",
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for the rnnt loss; it specifies how many symbols (context)
we consider for each frame when computing the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
reduction:
Specifies the reduction to apply to the output
Returns:
Return the transducer losses and the CTC loss,
in the form (simple_loss, pruned_loss, ctc_loss)
Note:
Regarding am_scale & lm_scale, they make the loss function take the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
assert x.size(0) == x_lens.size(0) == y.dim0
if self.use_transducer:
# Compute transducer loss
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
else:
ctc_loss = torch.empty(0)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
return simple_loss, pruned_loss, ctc_loss
if self.use_ctc:
# Compute CTC loss
ctc_loss, encoder_out, encoder_out_lens = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=y.tolist(),
target_lengths=y_lens,
supervisions=supervisions,
subsampling_factor=subsampling_factor,
ctc_beam_size=ctc_beam_size,
reduction=reduction,
warmup=warmup,
)
else:
ctc_loss = torch.empty(0, device=encoder_out.device)
if self.use_transducer:
# Compute transducer loss
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
reduction=reduction,
delay_penalty=delay_penalty,
)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss
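
To make the new forward_ctc easier to follow, here is a hedged, simplified sketch of its blank-skip branch (taken when use_bs is set and warmup >= 2.0). The helper name blank_skip_path is hypothetical; the calls mirror the diff above, and the k2 CTC-loss computation itself is left out.

from icefall.utils import make_pad_mask


def blank_skip_path(model, encoder_out, encoder_out_lens, target_lengths):
    """Hypothetical helper mirroring the blank-skip branch of forward_ctc."""
    # 1. CTC posteriors are computed from the full-rate encoder output.
    ctc_output = model.ctc_output(encoder_out)  # (N, T, vocab_size)

    # 2. The lconv module smooths the encoder output so that dropping
    #    frames afterwards loses less information.
    encoder_out = model.lconv(
        x=encoder_out,
        src_key_padding_mask=make_pad_mask(encoder_out_lens),
    )

    # 3. Frames whose blank posterior exceeds the threshold are removed;
    #    the reduced output is what forward_transducer consumes afterwards.
    encoder_out_fr, encoder_out_lens_fr = model.frame_reducer(
        encoder_out,
        encoder_out_lens,
        ctc_output,
        target_lengths,
        model.decoder.blank_id,
    )

    # The CTC loss itself is still computed from the full-rate ctc_output
    # (via k2.ctc_graph / k2.DenseFsaVec / k2.ctc_loss in the diff).
    return ctc_output, encoder_out_fr, encoder_out_lens_fr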


@@ -67,7 +67,9 @@ import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from frame_reducer import FrameReducer
from joiner import Joiner
from lconv import LConv
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
@@ -258,6 +260,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="If True, use CTC head.",
)
parser.add_argument(
"--use-bs",
type=str2bool,
default=False,
help="If True, use blank-skip.",
)
def get_parser():
parser = argparse.ArgumentParser(
@@ -529,6 +538,7 @@ def get_params() -> AttributeDict:
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for zipformer
"ctc_beam_size": 10,
"feature_dim": 80,
"subsampling_factor": 4, # not passed in, this is fixed.
"warm_step": 2000,
@@ -583,6 +593,16 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
return encoder
def get_lconv(params: AttributeDict) -> nn.Module:
lconv = LConv(channels=max(params.encoder_dim))
return lconv
def get_frame_reducer(params: AttributeDict) -> nn.Module:
frame_reducer = FrameReducer()
return frame_reducer
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
@@ -620,16 +640,24 @@ def get_model(params: AttributeDict) -> nn.Module:
decoder = None
joiner = None
lconv, frame_reducer = None, None
if params.use_bs:
lconv = get_lconv(params)
frame_reducer = get_frame_reducer(params)
model = AsrModel(
encoder_embed=encoder_embed,
encoder=encoder,
decoder=decoder,
joiner=joiner,
lconv=lconv,
frame_reducer=frame_reducer,
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
use_transducer=params.use_transducer,
use_ctc=params.use_ctc,
use_bs=params.use_bs,
)
return model
@@ -756,6 +784,7 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
warmup: float,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute loss given the model and its inputs.
@@ -796,9 +825,12 @@ def compute_loss(
x=feature,
x_lens=feature_lens,
y=y,
supervisions=supervisions,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
ctc_beam_size=params.ctc_beam_size,
warmup=warmup,
)
loss = 0.0
@@ -859,6 +891,7 @@ def compute_validation_loss(
sp=sp,
batch=batch,
is_training=False,
warmup=(params.batch_idx_train / params.warm_step),
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
@@ -953,6 +986,7 @@ def train_one_epoch(
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.warm_step),
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
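
As a sanity check on when blank skip becomes active: compute_loss now passes warmup = batch_idx_train / warm_step, and forward_ctc only takes the lconv / frame-reducer path once warmup >= 2.0, so with the warm_step of 2000 from get_params() the skipping starts after roughly 4000 training batches. A tiny, purely illustrative calculation:

# Illustrative only: when does the blank-skip branch turn on?
warm_step = 2000  # value from get_params() in the diff above
for batch_idx_train in (1000, 2000, 4000, 5000):
    warmup = batch_idx_train / warm_step
    blank_skip_active = warmup >= 2.0  # gate used in forward_ctc
    print(batch_idx_train, warmup, blank_skip_active)
# 1000 0.5 False
# 2000 1.0 False
# 4000 2.0 True
# 5000 2.5 True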