Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-03 22:24:19 +00:00
Merge 752e16be1038211bada5d4f15eb4b59d3f6ae9f6 into c401a2646b347bf1fff0c2ce1a4ee13b0f482448
commit 6c98fbc309
egs/librispeech/ASR/zipformer/frame_reducer.py (new file, 177 lines)
@@ -0,0 +1,177 @@
#!/usr/bin/env python3
#
# Copyright    2022  Xiaomi Corp.        (authors: Yifan Yang,
#                                                  Zengwei Yao,
#                                                  Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from icefall.utils import make_pad_mask

NON_BLANK_THRES = 0.9


class FrameReducer(nn.Module):
    """The encoder output is first used to calculate
    the CTC posterior probability; then, for each output frame,
    if its blank posterior is bigger than a threshold,
    it is simply discarded from the encoder output.
    """

    def __init__(
        self,
    ):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        ctc_output: torch.Tensor,
        y_lens: Optional[torch.Tensor] = None,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The shared encoder output with shape [N, T, C].
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          ctc_output:
            The CTC output with shape [N, T, vocab_size].
          y_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `y` before padding.
          blank_id:
            The blank id of ctc_output.
        Returns:
          out:
            The frame-reduced encoder output with shape [N, T', C].
          out_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `out` before padding.
        """
        N, T, C = x.size()

        padding_mask = make_pad_mask(x_lens)
        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) * (
            ~padding_mask
        )

        if y_lens is not None or self.training is False:
            # Limit the maximum number of reduced frames
            if y_lens is not None:
                limit_lens = T - y_lens
            else:
                # In eval mode, ensure that completely silent audio does not
                # cause any errors
                limit_lens = T - torch.ones_like(x_lens)
            max_limit_len = limit_lens.max().int()
            fake_limit_indexes = torch.topk(
                ctc_output[:, :, blank_id], max_limit_len
            ).indices
            _T = (
                torch.arange(max_limit_len)
                .expand_as(
                    fake_limit_indexes,
                )
                .to(device=x.device)
            )
            _T = torch.remainder(_T, limit_lens.unsqueeze(1))
            limit_indexes = torch.gather(fake_limit_indexes, 1, _T)
            limit_mask = (
                torch.full_like(
                    non_blank_mask,
                    0,
                    device=x.device,
                ).scatter_(1, limit_indexes, 1)
                == 1
            )

            non_blank_mask = non_blank_mask | ~limit_mask

        out_lens = non_blank_mask.sum(dim=1)
        max_len = out_lens.max()
        pad_lens_list = (
            torch.full_like(
                out_lens,
                max_len.item(),
                device=x.device,
            )
            - out_lens
        )
        max_pad_len = int(pad_lens_list.max().item())

        out = F.pad(x, (0, 0, 0, max_pad_len))

        valid_pad_mask = ~make_pad_mask(pad_lens_list)
        total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)

        out = out[total_valid_mask].reshape(N, -1, C)

        return out, out_lens


if __name__ == "__main__":
    import time

    test_times = 10000
    device = "cuda:0"
    frame_reducer = FrameReducer()

    # non zero case
    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    ctc_output = torch.log(
        torch.randn(15, 498, 500, dtype=torch.float32, device=device),
    )

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
    print(x_fr.shape)
    print(x_lens_fr)
    print(avg_time / test_times)

    # all zero case
    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
    print(x_fr.shape)
    print(x_lens_fr)
    print(avg_time / test_times)
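For orientation, here is a minimal CPU-only sketch of the reduction idea (illustrative shapes and a hand-rolled re-packing loop; the vectorized module above is what the commit actually adds):

import math

import torch

N, T, C, V = 2, 6, 4, 5
x = torch.randn(N, T, C)                                # stand-in encoder output
ctc_output = torch.randn(N, T, V).log_softmax(dim=-1)   # stand-in CTC log-probs

blank_id = 0
# Keep a frame only if its blank log-posterior is below log(0.9).
keep = ctc_output[:, :, blank_id] < math.log(0.9)
out_lens = keep.sum(dim=1)

# Re-pack the kept frames and right-pad every utterance to a common length.
out = x.new_zeros(N, int(out_lens.max()), C)
for i in range(N):
    kept = x[i][keep[i]]
    out[i, : kept.size(0)] = kept

print(out.shape, out_lens)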
egs/librispeech/ASR/zipformer/lconv.py (new file, 113 lines)
@@ -0,0 +1,113 @@
# Copyright    2022  Xiaomi Corp.        (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
from scaling import Balancer, ScaledConv1d


class LConv(nn.Module):
    """A convolution module to prevent information loss."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 7,
        bias: bool = True,
    ):
        """
        Args:
          channels:
            Dimension of the input embedding, and of the lconv output.
        """
        super().__init__()
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )

        self.deriv_balancer1 = Balancer(
            2 * channels,
            channel_dim=1,
            min_abs=0.05,
            max_abs=10.0,
            min_positive=0.05,
            max_positive=1.0,
        )

        self.depthwise_conv = nn.Conv1d(
            2 * channels,
            2 * channels,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=2 * channels,
            bias=bias,
        )

        self.deriv_balancer2 = Balancer(
            2 * channels,
            channel_dim=1,
            min_positive=0.05,
            max_positive=1.0,
            min_abs=0.05,
            max_abs=20.0,
        )

        self.pointwise_conv2 = ScaledConv1d(
            2 * channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            initial_scale=0.05,
        )

    def forward(
        self,
        x: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
          x: A 3-D tensor of shape (N, T, C).
        Returns:
          Return a tensor of shape (N, T, C).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.permute(0, 2, 1)  # (#batch, channels, time).

        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)

        x = self.deriv_balancer1(x)

        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        x = self.depthwise_conv(x)

        x = self.deriv_balancer2(x)

        x = self.pointwise_conv2(x)  # (batch, channels, time)

        return x.permute(0, 2, 1)
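A usage sketch for LConv (assumes this file is importable as `lconv` alongside the recipe's `scaling.py`; the shapes here are illustrative, not taken from the commit):

import torch

from lconv import LConv

lconv = LConv(channels=384)
x = torch.randn(4, 100, 384)                            # (N, T, C)
# True marks padding positions; here nothing is padded.
src_key_padding_mask = torch.zeros(4, 100, dtype=torch.bool)

y = lconv(x, src_key_padding_mask)
print(y.shape)  # torch.Size([4, 100, 384]) -- same shape as the input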
@@ -16,16 +16,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple
import warnings
from typing import List, Optional, Tuple

import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface

from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear

from icefall.utils import add_sos, encode_supervisions, make_pad_mask


class AsrModel(nn.Module):
    def __init__(
@@ -34,11 +35,14 @@ class AsrModel(nn.Module):
        encoder: EncoderInterface,
        decoder: Optional[nn.Module] = None,
        joiner: Optional[nn.Module] = None,
        lconv: Optional[nn.Module] = None,
        frame_reducer: Optional[nn.Module] = None,
        encoder_dim: int = 384,
        decoder_dim: int = 512,
        vocab_size: int = 500,
        use_transducer: bool = True,
        use_ctc: bool = False,
        use_bs: bool = True,
    ):
        """A joint CTC & Transducer ASR model.

@@ -77,6 +81,10 @@
            use_transducer or use_ctc
        ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"

        assert (
            not use_bs or use_ctc
        ), "Blank Skip needs CTC"

        assert isinstance(encoder, EncoderInterface), type(encoder)

        self.encoder_embed = encoder_embed
@@ -111,6 +119,11 @@
            nn.LogSoftmax(dim=-1),
        )

        self.use_bs = use_bs
        if self.use_bs:
            self.lconv = lconv
            self.frame_reducer = frame_reducer

    def forward_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -142,217 +155,278 @@ class AsrModel(nn.Module):

        return encoder_out, encoder_out_lens

    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: List[int],
        target_lengths: torch.Tensor,
        supervisions: dict,
        subsampling_factor: int,
        ctc_beam_size: int,
        reduction: str = "sum",
        warmup: float = 1.0,
    ) -> torch.Tensor:
        """Compute CTC loss.
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
          supervisions:
            Dict into a pair of torch Tensor, and a list of transcription strings or token indexes
          reduction:
            Specifies the reduction to apply to the output
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
        encoder_out_fr = encoder_out
        encoder_out_lens_fr = encoder_out_lens

        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets,
            input_lengths=encoder_out_lens,
            target_lengths=target_lengths,
            reduction="sum",
        )
        return ctc_loss
        if self.use_bs and warmup >= 2.0:
            # lconv
            encoder_out = self.lconv(
                x=encoder_out,
                src_key_padding_mask=make_pad_mask(encoder_out_lens),
            )

    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Transducer loss.
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
        """
        # Now for the decoder, i.e., the prediction network
        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)
            # frame reduce
            encoder_out_fr, encoder_out_lens_fr = self.frame_reducer(
                encoder_out,
                encoder_out_lens,
                ctc_output,
                target_lengths,
                self.decoder.blank_id,
            )

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            supervision_segments, token_ids = encode_supervisions(
                supervisions,
                subsampling_factor=subsampling_factor,
                token_ids=targets,
            )

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # if self.training and random.random() < 0.25:
        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
        # if self.training and random.random() < 0.25:
        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
        # TODO: Crash without this line
        supervision_segments = supervision_segments.to("cpu")
        decoding_graph = k2.ctc_graph(
            token_ids, modified=False, device=encoder_out.device
        )
        dense_fsa_vec = k2.DenseFsaVec(
            ctc_output,
            supervision_segments,
            allow_truncate=subsampling_factor - 1,
        )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )
        ctc_loss = k2.ctc_loss(
            decoding_graph=decoding_graph,
            dense_fsa_vec=dense_fsa_vec,
            output_beam=ctc_beam_size,
            reduction=reduction,
            use_double_scores=True,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )
        return ctc_loss, encoder_out_fr, encoder_out_lens_fr

        # logits : [B, T, prune_range, vocab_size]
    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        delay_penalty: float = 0.0,
        reduction: str = "sum",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Transducer loss.
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
          reduction:
            Specifies the reduction to apply to the output
        """
        # Now for the decoder, i.e., the prediction network
        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                delay_penalty=delay_penalty,
                reduction=reduction,
                return_grad=True,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
            )

        return simple_loss, pruned_loss
        # logits : [B, T, prune_range, vocab_size]

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
        Returns:
          Return the transducer losses and CTC loss,
          in form of (simple_loss, pruned_loss, ctc_loss)
        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes
        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                delay_penalty=delay_penalty,
                reduction=reduction,
            )

        assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
        return simple_loss, pruned_loss

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        supervisions: dict,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        subsampling_factor: int = 4,
        ctc_beam_size: int = 10,
        delay_penalty: float = 0.0,
        reduction: str = "sum",
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
          reduction:
            Specifies the reduction to apply to the output
        Returns:
          Return the transducer losses and CTC loss,
          in form of (simple_loss, pruned_loss, ctc_loss)
        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes

        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]
        assert x.size(0) == x_lens.size(0) == y.dim0

        if self.use_transducer:
            # Compute transducer loss
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y.to(x.device),
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
        else:
            simple_loss = torch.empty(0)
            pruned_loss = torch.empty(0)
        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

        if self.use_ctc:
            # Compute CTC loss
            targets = y.values
            ctc_loss = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=targets,
                target_lengths=y_lens,
            )
        else:
            ctc_loss = torch.empty(0)
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        return simple_loss, pruned_loss, ctc_loss
        if self.use_ctc:
            # Compute CTC loss
            ctc_loss, encoder_out, encoder_out_lens = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=y.tolist(),
                target_lengths=y_lens,
                supervisions=supervisions,
                subsampling_factor=subsampling_factor,
                ctc_beam_size=ctc_beam_size,
                reduction=reduction,
                warmup=warmup,
            )
        else:
            ctc_loss = torch.empty(0, device=encoder_out.device)

        if self.use_transducer:
            # Compute transducer loss
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y.to(x.device),
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
                reduction=reduction,
                delay_penalty=delay_penalty,
            )
        else:
            simple_loss = torch.empty(0)
            pruned_loss = torch.empty(0)

        return simple_loss, pruned_loss, ctc_loss
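In short, the model-level change is a warmup gate: before warmup reaches 2.0 the full encoder output feeds the transducer loss, and afterwards only the frames the CTC head considers non-blank do. A small self-contained sketch of that gating (stand-in tensors and threshold, not the recipe's modules):

import math

import torch

torch.manual_seed(0)
N, T, V = 2, 8, 5
ctc_output = torch.randn(N, T, V).log_softmax(dim=-1)  # stand-in CTC log-probs
encoder_out_lens = torch.tensor([8, 6])
padding_mask = torch.arange(T).unsqueeze(0) >= encoder_out_lens.unsqueeze(1)

for warmup in (1.0, 2.5):
    if warmup >= 2.0:  # blank skip active
        non_blank = (ctc_output[:, :, 0] < math.log(0.9)) & ~padding_mask
        lens = non_blank.sum(dim=1)
    else:              # blank skip not yet active
        lens = encoder_out_lens
    print(f"warmup={warmup}: frames fed to the transducer = {lens.tolist()}")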
@@ -67,7 +67,9 @@ import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from frame_reducer import FrameReducer
from joiner import Joiner
from lconv import LConv
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
@@ -258,6 +260,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="If True, use CTC head.",
    )

    parser.add_argument(
        "--use-bs",
        type=str2bool,
        default=False,
        help="If True, use blank-skip.",
    )


def get_parser():
    parser = argparse.ArgumentParser(
@@ -529,6 +538,7 @@ def get_params() -> AttributeDict:
            "reset_interval": 200,
            "valid_interval": 3000,  # For the 100h subset, use 800
            # parameters for zipformer
            "ctc_beam_size": 10,
            "feature_dim": 80,
            "subsampling_factor": 4,  # not passed in, this is fixed.
            "warm_step": 2000,
@@ -583,6 +593,16 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
    return encoder


def get_lconv(params: AttributeDict) -> nn.Module:
    lconv = LConv(channels=max(_to_int_tuple(params.encoder_dim)))
    return lconv


def get_frame_reducer(params: AttributeDict) -> nn.Module:
    frame_reducer = FrameReducer()
    return frame_reducer


def get_decoder_model(params: AttributeDict) -> nn.Module:
    decoder = Decoder(
        vocab_size=params.vocab_size,
@@ -620,16 +640,24 @@ def get_model(params: AttributeDict) -> nn.Module:
        decoder = None
        joiner = None

    lconv, frame_reducer = None, None
    if params.use_bs:
        lconv = get_lconv(params)
        frame_reducer = get_frame_reducer(params)

    model = AsrModel(
        encoder_embed=encoder_embed,
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        lconv=lconv,
        frame_reducer=frame_reducer,
        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
        decoder_dim=params.decoder_dim,
        vocab_size=params.vocab_size,
        use_transducer=params.use_transducer,
        use_ctc=params.use_ctc,
        use_bs=params.use_bs,
    )
    return model

@@ -756,6 +784,7 @@ def compute_loss(
    sp: spm.SentencePieceProcessor,
    batch: dict,
    is_training: bool,
    warmup: float,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute loss given the model and its inputs.
@@ -796,9 +825,12 @@ def compute_loss(
            x=feature,
            x_lens=feature_lens,
            y=y,
            supervisions=supervisions,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
            ctc_beam_size=params.ctc_beam_size,
            warmup=warmup,
        )

        loss = 0.0
@@ -859,6 +891,7 @@ def compute_validation_loss(
            sp=sp,
            batch=batch,
            is_training=False,
            warmup=(params.batch_idx_train / params.warm_step),
        )
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info
@@ -953,6 +986,7 @@ def train_one_epoch(
                sp=sp,
                batch=batch,
                is_training=True,
                warmup=(params.batch_idx_train / params.warm_step),
            )
        # summary stats
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info