Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
change some files to symlinks
commit 2d3063becd (parent 93a5c878f1)
egs/libricss/SURT/dprnn_zipformer/decoder.py (old file, deleted)
@@ -1,102 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """This class modifies the stateless decoder from the following paper:

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the
    prediction network. Unlike the above paper, it adds an extra Conv1d
    right after the embedding layer.

    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int,
        blank_id: int,
        context_size: int,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit, including blank.
          decoder_dim:
            Dimension of the input embedding and of the decoder output.
          blank_id:
            The ID of the blank symbol.
          context_size:
            Number of previous words used to predict the next word.
            1 means bigram; 2 means trigram; n means (n+1)-gram.
        """
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=decoder_dim,
            padding_idx=blank_id,
        )
        self.blank_id = blank_id

        assert context_size >= 1, context_size
        self.context_size = context_size
        self.vocab_size = vocab_size
        if context_size > 1:
            self.conv = nn.Conv1d(
                in_channels=decoder_dim,
                out_channels=decoder_dim,
                kernel_size=context_size,
                padding=0,
                groups=decoder_dim // 4,  # group size == 4
                bias=False,
            )

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U).
          need_pad:
            True to left-pad the input. Should be True during training.
            False to not pad the input. Should be False during inference.
        Returns:
          Return a tensor of shape (N, U, decoder_dim).
        """
        y = y.to(torch.int64)
        # The clamp() is a temporary fix for a mismatch at utterance start:
        # we use negative ids in beam_search.py.
        embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
        if self.context_size > 1:
            embedding_out = embedding_out.permute(0, 2, 1)
            if need_pad is True:
                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
            else:
                # During inference there is no need for extra padding,
                # as we only need one output frame.
                assert embedding_out.size(-1) == self.context_size
            embedding_out = self.conv(embedding_out)
            embedding_out = embedding_out.permute(0, 2, 1)
        embedding_out = F.relu(embedding_out)
        return embedding_out
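For readers following the recipe, here is a minimal, hedged usage sketch of the Decoder above. It is not part of the commit: the sizes are made up, and the recipe's real values come from its training script.

import torch

decoder = Decoder(vocab_size=500, decoder_dim=512, blank_id=0, context_size=2)

# Training: full label sequences; left-padding gives every position
# `context_size` tokens of history, so output length equals input length.
y = torch.randint(low=1, high=500, size=(8, 10))  # (N, U)
out = decoder(y, need_pad=True)
assert out.shape == (8, 10, 512)

# Inference: feed exactly `context_size` tokens, get one output frame.
y_ctx = torch.randint(low=1, high=500, size=(8, 2))
out_one = decoder(y_ctx, need_pad=False)
assert out_one.shape == (8, 1, 512)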
egs/libricss/SURT/dprnn_zipformer/decoder.py (new symbolic link, 1 line)
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
egs/libricss/SURT/dprnn_zipformer/encoder_interface.py (old file, deleted)
@@ -1,43 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
import torch.nn as nn


class EncoderInterface(nn.Module):
    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A tensor of shape (batch_size, input_seq_len, num_features)
            containing the input features.
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames
            in `x` before padding.
        Returns:
          Return a tuple containing two tensors:
            - encoder_out, a tensor of shape (batch_size, out_seq_len,
              output_dim) containing unnormalized probabilities, i.e.,
              the output of a linear layer.
            - encoder_out_lens, a tensor of shape (batch_size,) containing
              the number of frames in `encoder_out` before padding.
        """
        raise NotImplementedError("Please implement it in a subclass")
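As a hedged illustration of this contract (TinyEncoder is hypothetical and not a class from the recipe), a minimal subclass might look like the following; a single linear layer stands in for a real DPRNN/Zipformer encoder, and the dimensions are made up:

from typing import Tuple

import torch
import torch.nn as nn


class TinyEncoder(EncoderInterface):
    """Toy encoder that satisfies EncoderInterface; for illustration only."""

    def __init__(self, num_features: int, output_dim: int):
        super().__init__()
        self.proj = nn.Linear(num_features, output_dim)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # No subsampling here, so output lengths equal input lengths.
        return self.proj(x), x_lens


encoder = TinyEncoder(num_features=80, output_dim=384)
x = torch.randn(4, 100, 80)      # (batch_size, input_seq_len, num_features)
x_lens = torch.full((4,), 100)   # frame counts before padding
encoder_out, encoder_out_lens = encoder(x, x_lens)
assert encoder_out.shape == (4, 100, 384)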
egs/libricss/SURT/dprnn_zipformer/encoder_interface.py (new symbolic link, 1 line)
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
egs/libricss/SURT/dprnn_zipformer/joiner.py (old file, deleted)
@@ -1,65 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class Joiner(nn.Module):
    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
            Output from the decoder. Its shape is (N, T, s_range, C).
          project_input:
            If True, apply the input projections encoder_proj and
            decoder_proj. If False, the caller is responsible for applying
            them beforehand.
        Returns:
          Return a tensor of shape (N, T, s_range, vocab_size).
        """
        assert encoder_out.ndim == decoder_out.ndim
        assert encoder_out.ndim in (2, 4)
        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]

        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            logit = encoder_out + decoder_out

        logit = self.output_linear(torch.tanh(logit))

        return logit
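A hedged usage sketch of the Joiner above, with made-up dimensions (N = batch, T = frames, s_range = the pruning width used by pruned-transducer training); it is not part of the commit:

import torch

joiner = Joiner(encoder_dim=384, decoder_dim=512, joiner_dim=512, vocab_size=500)

# Pruned-transducer style inputs of shape (N, T, s_range, C).
encoder_out = torch.randn(2, 50, 5, 384)
decoder_out = torch.randn(2, 50, 5, 512)

# With project_input=True, both inputs are first projected to joiner_dim,
# then summed, passed through tanh, and mapped to vocabulary logits.
logit = joiner(encoder_out, decoder_out, project_input=True)
assert logit.shape == (2, 50, 5, 500)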
egs/libricss/SURT/dprnn_zipformer/joiner.py (new symbolic link, 1 line)
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
egs/libricss/SURT/dprnn_zipformer/optim.py (old file, deleted; diff suppressed because it is too large)
egs/libricss/SURT/dprnn_zipformer/optim.py (new symbolic link, 1 line)
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
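Taken together, the commit replaces four duplicated files with links to the shared librispeech implementations. A hedged Python sketch of the equivalent on-disk operation, assuming it is run from the root of an icefall checkout (paths are taken from the diff above):

from pathlib import Path

recipe = Path("egs/libricss/SURT/dprnn_zipformer")
target = Path("../../../librispeech/ASR/pruned_transducer_stateless7")

for name in ["decoder.py", "encoder_interface.py", "joiner.py", "optim.py"]:
    link = recipe / name
    link.unlink()                   # drop the duplicated copy
    link.symlink_to(target / name)  # point at the shared implementation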