mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add train.py, model.py
This commit is contained in:
parent
c87f55671a
commit
2fc7535de9
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -47,8 +47,7 @@ class AttentionDecoderModel(nn.Module):
    """
    Args:
      vocab_size (int): Number of classes.
      encoder_dim (int):
      d_model: (int,int): embedding dimension of 2 encoder stacks
      decoder_dim: (int,int): embedding dimension of 2 encoder stacks
      attention_dim: (int,int): attention dimension of 2 encoder stacks
      nhead (int, int): number of heads
      dim_feedforward (int, int): feedforward dimension in 2 encoder stacks
@@ -62,15 +61,15 @@ class AttentionDecoderModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        decoder_dim: int,
        unmasked_dim: int,
        num_decoder_layers: int,
        attention_dim: int,
        nhead: int,
        feedforward_dim: int,
        dropout: float,
        sos_id: int,
        eos_id: int,
        dropout: float = 0.1,
        ignore_id: int = -1,
        warmup_batches: float = 4000.0,
        label_smoothing: float = 0.1,
@@ -84,7 +83,7 @@ class AttentionDecoderModel(nn.Module):
        # layer learn something. Then we start to warm up the other encoders.
        self.decoder = TransformerDecoder(
            vocab_size,
            d_model,
            decoder_dim,
            unmasked_dim,
            num_decoder_layers,
            attention_dim,
@@ -103,7 +102,6 @@ class AttentionDecoderModel(nn.Module):
    def _pre_ys_in_out(self, token_ids: List[List[int]], device: torch.device):
        """Prepare ys_in_pad and ys_out_pad."""
        ys = k2.RaggedTensor(token_ids).to(device=device)

        row_splits = ys.shape.row_splits(1)
        ys_lens = row_splits[1:] - row_splits[:-1]

@@ -168,10 +166,9 @@ class AttentionDecoderModel(nn.Module):
            decoder_out.view(-1, num_classes),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="None",
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(1)
        nll = nll.view(batch_size, -1).sum(1)
        return nll
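The last hunk above replaces `reduction="None"` with the lowercase `reduction="none"` that `F.cross_entropy` actually accepts, and fuses the per-utterance reduction into a single line. A minimal standalone sketch of that pattern (tensor names and sizes below are illustrative, not taken from the repo):

import torch
import torch.nn.functional as F

batch_size, seq_len, num_classes = 2, 5, 10
decoder_out = torch.randn(batch_size, seq_len, num_classes)
ys_out_pad = torch.randint(0, num_classes, (batch_size, seq_len))
ignore_id = -1

# Per-token negative log-likelihood; positions labeled with ignore_id contribute 0.
nll = F.cross_entropy(
    decoder_out.view(-1, num_classes),
    ys_out_pad.view(-1),
    ignore_index=ignore_id,
    reduction="none",
)
# One NLL value per utterance, as in the fused view(...).sum(1) line above.
nll = nll.view(batch_size, -1).sum(1)  # shape: (batch_size,)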

@@ -181,7 +178,7 @@ class TransformerDecoder(nn.Module):

    Args:
      vocab_size: output dim
      d_model: equal to encoder_dim
      d_model: decoder dimension
      num_decoder_layers: number of decoder layers
      attention_dim: total dimension of multi head attention
      n_head: number of attention heads
@@ -715,7 +712,7 @@ def subsequent_mask(size, device="cpu", dtype=torch.bool):
def _test_attention_decoder_model():
    m = AttentionDecoderModel(
        vocab_size=500,
        d_model=384,
        decoder_dim=384,
        unmasked_dim=256,
        num_decoder_layers=6,
        attention_dim=192,
@@ -733,6 +730,9 @@ def _test_attention_decoder_model():
    loss = m.calc_att_loss(encoder_out, encoder_out_lens, token_ids)
    print(loss)

    nll = m.nll(encoder_out, encoder_out_lens, token_ids)
    print(nll)


if __name__ == "__main__":
    _test_attention_decoder_model()
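The hunk header above references `subsequent_mask(size, device="cpu", dtype=torch.bool)`. Such a helper is conventionally a causal mask for decoder self-attention; a minimal sketch under that assumption (not necessarily the exact implementation in this file):

import torch

def subsequent_mask(size, device="cpu", dtype=torch.bool):
    # Position i may attend to positions <= i: a (size, size)
    # lower-triangular matrix of True values.
    return torch.tril(torch.ones(size, size, device=device)).to(dtype)

# Example: a 4x4 causal mask for a decoder input of length 4.
print(subsequent_mask(4))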
95 egs/librispeech/ASR/zipformer_ctc_attn/model.py Normal file
@@ -0,0 +1,95 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Tuple

import torch
import torch.nn as nn
from encoder_interface import EncoderInterface


class CTCAttentionModel(nn.Module):
    """Hybrid CTC & attention decoder model."""

    def __init__(
        self,
        encoder: EncoderInterface,
        decoder: nn.Module,
        encoder_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder:
            It is the Zipformer encoder model. It accepts
            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          decoder:
            It is the attention decoder.
          encoder_dim:
            The embedding dimension of the encoder.
          vocab_size:
            The vocabulary size.
        """
        super().__init__()
        assert isinstance(encoder, EncoderInterface), type(encoder)

        self.encoder = encoder
        self.ctc_output = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(encoder_dim, vocab_size),
            nn.LogSoftmax(dim=-1),
        )
        # Attention decoder
        self.decoder = decoder

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        token_ids: List[List[int]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          token_ids:
            A list of token id lists.

        Returns:
          - ctc_output, CTC log-probs
          - att_loss, attention decoder loss
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert x.size(0) == x_lens.size(0) == len(token_ids)

        # encoder forward
        encoder_out, x_lens = self.encoder(x, x_lens)
        assert torch.all(x_lens > 0)

        # compute ctc log-probs
        ctc_output = self.ctc_output(encoder_out)

        # compute attention decoder loss
        att_loss = self.decoder.calc_att_loss(encoder_out, x_lens, token_ids)

        return ctc_output, att_loss
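To make the expected shapes concrete, here is a minimal usage sketch of `CTCAttentionModel`. The `DummyEncoder` and `DummyDecoder` below are hypothetical stand-ins (in the recipe the encoder is the Zipformer and the decoder is the attention decoder changed above), and the CTC-loss lines only illustrate one possible way the returned log-probs could be consumed, since train.py is not shown here:

import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface

class DummyEncoder(EncoderInterface):
    """Stand-in encoder: projects 80-dim features to encoder_dim, keeps lengths."""
    def __init__(self, encoder_dim: int):
        super().__init__()
        self.proj = nn.Linear(80, encoder_dim)

    def forward(self, x, x_lens):
        return self.proj(x), x_lens

class DummyDecoder(nn.Module):
    """Stand-in attention decoder exposing the calc_att_loss interface."""
    def calc_att_loss(self, encoder_out, encoder_out_lens, token_ids):
        return encoder_out.sum() * 0.0  # placeholder scalar loss

model = CTCAttentionModel(
    encoder=DummyEncoder(encoder_dim=384),
    decoder=DummyDecoder(),
    encoder_dim=384,
    vocab_size=500,
)

x = torch.randn(2, 100, 80)       # (N, T, C) features
x_lens = torch.tensor([100, 90])  # valid frames per utterance
token_ids = [[1, 2, 3], [4, 5]]

ctc_output, att_loss = model(x, x_lens, token_ids)
print(ctc_output.shape)  # (2, 100, 500) log-probs
print(att_loss)

# One possible way to turn the log-probs into a CTC loss (illustrative only):
targets = torch.tensor([t for ids in token_ids for t in ids])
target_lens = torch.tensor([len(ids) for ids in token_ids])
ctc_loss = F.ctc_loss(
    ctc_output.permute(1, 0, 2),  # (T, N, C)
    targets,
    x_lens,  # the dummy encoder does not subsample, so input lengths are unchanged
    target_lens,
    blank=0,
    reduction="sum",
)
print(ctc_loss)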
1268 egs/librispeech/ASR/zipformer_ctc_attn/train.py Executable file
File diff suppressed because it is too large