Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit 2fc7535de9: add train.py, model.py
Parent: c87f55671a
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -47,8 +47,7 @@ class AttentionDecoderModel(nn.Module):
     """
     Args:
         vocab_size (int): Number of classes.
-        encoder_dim (int):
-        d_model: (int,int): embedding dimension of 2 encoder stacks
+        decoder_dim: (int,int): embedding dimension of 2 encoder stacks
         attention_dim: (int,int): attention dimension of 2 encoder stacks
         nhead (int, int): number of heads
         dim_feedforward (int, int): feedforward dimension in 2 encoder stacks
@@ -62,15 +61,15 @@ class AttentionDecoderModel(nn.Module):
     def __init__(
         self,
         vocab_size: int,
-        d_model: int,
+        decoder_dim: int,
         unmasked_dim: int,
         num_decoder_layers: int,
         attention_dim: int,
         nhead: int,
         feedforward_dim: int,
-        dropout: float,
         sos_id: int,
         eos_id: int,
+        dropout: float = 0.1,
         ignore_id: int = -1,
         warmup_batches: float = 4000.0,
         label_smoothing: float = 0.1,
@@ -84,7 +83,7 @@ class AttentionDecoderModel(nn.Module):
         # layer learn something. Then we start to warm up the other encoders.
         self.decoder = TransformerDecoder(
             vocab_size,
-            d_model,
+            decoder_dim,
             unmasked_dim,
             num_decoder_layers,
             attention_dim,
@@ -103,7 +102,6 @@ class AttentionDecoderModel(nn.Module):
     def _pre_ys_in_out(self, token_ids: List[List[int]], device: torch.device):
         """Prepare ys_in_pad and ys_out_pad."""
         ys = k2.RaggedTensor(token_ids).to(device=device)
-
         row_splits = ys.shape.row_splits(1)
         ys_lens = row_splits[1:] - row_splits[:-1]
 
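For readers unfamiliar with k2's ragged tensors, here is a minimal sketch (values purely illustrative) of how row_splits(1) yields per-utterance token counts, which is what _pre_ys_in_out relies on:

import k2

# Illustrative only: two utterances with 3 and 2 tokens respectively.
token_ids = [[1, 2, 3], [4, 5]]
ys = k2.RaggedTensor(token_ids)

# row_splits(1) marks where each utterance starts/ends in the flattened token list.
row_splits = ys.shape.row_splits(1)         # tensor([0, 3, 5], dtype=torch.int32)
ys_lens = row_splits[1:] - row_splits[:-1]  # tensor([3, 2], dtype=torch.int32)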
@@ -168,10 +166,9 @@ class AttentionDecoderModel(nn.Module):
             decoder_out.view(-1, num_classes),
             ys_out_pad.view(-1),
             ignore_index=self.ignore_id,
-            reduction="None",
+            reduction="none",
         )
-        nll = nll.view(batch_size, -1)
-        nll = nll.sum(1)
+        nll = nll.view(batch_size, -1).sum(1)
         return nll


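The hunk above fixes a real bug: torch.nn.functional.cross_entropy only accepts the lowercase strings "none", "mean", or "sum" for reduction, so reduction="None" raises a ValueError. It also folds the per-utterance summation into one line. A minimal, self-contained sketch of the corrected computation (shapes and names are illustrative, not taken from the file):

import torch
import torch.nn.functional as F

batch_size, seq_len, num_classes = 2, 5, 10
decoder_out = torch.randn(batch_size, seq_len, num_classes)
ys_out_pad = torch.randint(0, num_classes, (batch_size, seq_len))

# reduction="none" keeps one loss value per token instead of averaging over the batch.
nll = F.cross_entropy(
    decoder_out.view(-1, num_classes),
    ys_out_pad.view(-1),
    ignore_index=-1,
    reduction="none",
)
# Collapse per-token losses into one negative log-likelihood per utterance.
nll = nll.view(batch_size, -1).sum(1)
print(nll.shape)  # torch.Size([2])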
@@ -181,7 +178,7 @@ class TransformerDecoder(nn.Module):
 
     Args:
         vocab_size: output dim
-        d_model: equal to encoder_dim
+        d_model: decoder dimension
         num_decoder_layers: number of decoder layers
         attention_dim: total dimension of multi head attention
         n_head: number of attention heads
@@ -715,7 +712,7 @@ def subsequent_mask(size, device="cpu", dtype=torch.bool):
 def _test_attention_decoder_model():
     m = AttentionDecoderModel(
         vocab_size=500,
-        d_model=384,
+        decoder_dim=384,
         unmasked_dim=256,
         num_decoder_layers=6,
         attention_dim=192,
@@ -733,6 +730,9 @@ def _test_attention_decoder_model():
     loss = m.calc_att_loss(encoder_out, encoder_out_lens, token_ids)
     print(loss)
 
+    nll = m.nll(encoder_out, encoder_out_lens, token_ids)
+    print(nll)
+
 
 if __name__ == "__main__":
     _test_attention_decoder_model()
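For reference, a sketch of constructing the model under the renamed signature. The values of vocab_size, decoder_dim, unmasked_dim, num_decoder_layers, and attention_dim come from _test_attention_decoder_model above; the import path, the remaining constructor arguments, and the dummy encoder output are assumptions that are not visible in this diff.

import torch

# Hypothetical import; the module defining AttentionDecoderModel is not named in this diff.
from attention_decoder import AttentionDecoderModel

m = AttentionDecoderModel(
    vocab_size=500,
    decoder_dim=384,       # renamed from d_model in this commit
    unmasked_dim=256,
    num_decoder_layers=6,
    attention_dim=192,
    nhead=8,               # assumed value
    feedforward_dim=1536,  # assumed value
    sos_id=1,              # assumed value
    eos_id=1,              # assumed value
    # dropout now defaults to 0.1 and can be omitted
)

# Dummy encoder output: (N, T, feature_dim) activations and per-utterance lengths (assumed shapes).
encoder_out = torch.randn(2, 50, 384)
encoder_out_lens = torch.tensor([50, 30])
token_ids = [[1, 2, 3], [4, 5, 6, 7]]

loss = m.calc_att_loss(encoder_out, encoder_out_lens, token_ids)
nll = m.nll(encoder_out, encoder_out_lens, token_ids)
print(loss, nll)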
egs/librispeech/ASR/zipformer_ctc_attn/model.py (new file, 95 lines)
@@ -0,0 +1,95 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Tuple

import torch
import torch.nn as nn
from encoder_interface import EncoderInterface


class CTCAttentionModel(nn.Module):
    """Hybrid CTC & Attention decoder model."""

    def __init__(
        self,
        encoder: EncoderInterface,
        decoder: nn.Module,
        encoder_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder:
            It is the Zipformer encoder model. It accepts
            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          decoder:
            It is the attention decoder.
          encoder_dim:
            The embedding dimension of the encoder.
          vocab_size:
            The vocabulary size.
        """
        super().__init__()
        assert isinstance(encoder, EncoderInterface), type(encoder)

        self.encoder = encoder
        self.ctc_output = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(encoder_dim, vocab_size),
            nn.LogSoftmax(dim=-1),
        )
        # Attention decoder
        self.decoder = decoder

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        token_ids: List[List[int]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          token_ids:
            A list of token-id lists.

        Returns:
          - ctc_output, CTC log-probs
          - att_loss, attention decoder loss
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert x.size(0) == x_lens.size(0) == len(token_ids)

        # encoder forward
        encoder_out, x_lens = self.encoder(x, x_lens)
        assert torch.all(x_lens > 0)

        # compute ctc log-probs
        ctc_output = self.ctc_output(encoder_out)

        # compute attention decoder loss
        att_loss = self.decoder.calc_att_loss(encoder_out, x_lens, token_ids)

        return ctc_output, att_loss
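To make the contract of the new file concrete, below is a minimal sketch of wiring CTCAttentionModel together with stand-in encoder and decoder objects that satisfy the documented interface. The stub classes, the input feature dimension of 80, and the `from model import ...` path are assumptions for illustration; the real wiring lives in train.py (diff suppressed below).

from typing import List, Tuple

import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from model import CTCAttentionModel


class DummyEncoder(EncoderInterface):
    """Stand-in encoder that maps 80-dim features to encoder_dim (illustrative only)."""

    def __init__(self, encoder_dim: int):
        super().__init__()
        self.proj = nn.Linear(80, encoder_dim)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.proj(x), x_lens


class DummyDecoder(nn.Module):
    """Stand-in attention decoder exposing calc_att_loss (illustrative only)."""

    def calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        x_lens: torch.Tensor,
        token_ids: List[List[int]],
    ) -> torch.Tensor:
        return encoder_out.sum() * 0.0


encoder_dim, vocab_size = 384, 500
model = CTCAttentionModel(
    encoder=DummyEncoder(encoder_dim),
    decoder=DummyDecoder(),
    encoder_dim=encoder_dim,
    vocab_size=vocab_size,
)

x = torch.randn(2, 100, 80)        # (N, T, C) acoustic features
x_lens = torch.tensor([100, 80])   # valid frames per utterance
token_ids = [[1, 2, 3], [4, 5]]

ctc_output, att_loss = model(x, x_lens, token_ids)
# ctc_output: (N, T, vocab_size) log-probs for a CTC loss; att_loss: attention decoder loss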
egs/librispeech/ASR/zipformer_ctc_attn/train.py (new executable file, 1268 lines)
File diff suppressed because it is too large.