Initial version of zipformer1 LM that runs, not sure whether it is working
This commit is contained in:
parent 75e9f1a34a
commit 3574e7dbb5
117  egs/libriheavy/LM/zipformer1/chunk_decoder.py  Normal file
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.  (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import nn, Tensor


class ChunkDecoder(nn.Module):
    """
    LSTM-based decoder that predicts each chunk of symbols from an embedding of the
    preceding context; see __init__ for details.
    """
    def __init__(self,
                 embed_dim: int,
                 chunk_size: int,
                 vocab_size: int,
                 hidden_size: int,
                 num_layers: int = 2):
        """
        A 'decoder' that computes the probability of symbols in a language modeling task.
        Conceptually it computes the probability of `chunk_size` symbols (e.g. 8 symbols)
        based on an embedding derived from all symbols preceding this chunk of 8 symbols.
        Also, within the chunk, we always see all previous symbols (plus the last symbol
        of the previous chunk).
        """
        super().__init__()
        self.chunk_size = chunk_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size=embed_dim,
                            hidden_size=hidden_size,
                            num_layers=num_layers)

        self.label_embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim)

        # Project to the hidden state and cell state at the beginning of each chunk.
        # (We don't run the LSTM continuously over the sequence, for both
        # training speed and stability; instead, we reconstruct the hidden
        # state for each chunk.)
        # The factor of 2 is to cover the hidden state and the cell state.
        self.init_proj = nn.Linear(embed_dim,
                                   2 * hidden_size * num_layers)

        self.out_proj = nn.Linear(hidden_size,
                                  vocab_size)

    def forward(self,
                labels: Tensor,
                encoder_embed: Tensor) -> Tensor:
        """
        Compute log-probs.
        Args:
          labels: the labels, a Tensor of integer type of shape (batch_size, seq_len);
              seq_len is expected to be a multiple of chunk_size.
          encoder_embed: the embeddings from the encoder, of shape
              (seq_len//chunk_size, batch_size, embed_dim)

        Returns:
          the log-probs for each symbol, in a Tensor of shape (batch_size, seq_len).
        """
        (batch_size, seq_len) = labels.shape
        (num_chunks, _batch_size, embed_dim) = encoder_embed.shape
        chunk_size = self.chunk_size
        assert batch_size == _batch_size
        assert num_chunks * chunk_size == seq_len

        # Shift the labels right by one position along the time axis, so that position t
        # within a chunk is conditioned on labels up to t-1; position 0 sees symbol 0.
        labels_shifted = torch.cat((torch.zeros_like(labels[:, 0:1]),
                                    labels[:, :-1]), dim=1)

        labels_embed = self.label_embed(labels_shifted.t())  # (seq_len, batch_size, embed_dim)

        init = self.init_proj(encoder_embed).reshape(num_chunks, batch_size,
                                                     2, self.num_layers, self.hidden_size)
        init = init.permute(2, 3, 0, 1, 4).reshape(2, self.num_layers,
                                                   num_chunks * batch_size,
                                                   self.hidden_size).contiguous()
        hidden = init[0]
        cell = init[1]

        # Fold the chunks into the batch dimension, so the LSTM processes each chunk
        # as an independent sequence of length chunk_size.
        labels_embed = labels_embed.reshape(num_chunks, chunk_size, batch_size, embed_dim).transpose(0, 1)
        labels_embed = labels_embed.contiguous().reshape(chunk_size, num_chunks * batch_size, embed_dim)
        encoder_embed = encoder_embed.reshape(1, num_chunks * batch_size, embed_dim)

        x = labels_embed + encoder_embed  # broadcasts encoder_embed over the chunk_size

        (x, _hidden) = self.lstm(x, hx=(hidden, cell))

        x = self.out_proj(x)

        vocab_size = x.shape[-1]
        # x: (chunk_size, num_chunks * batch_size, vocab_size)
        x = x.reshape(chunk_size, num_chunks, batch_size, vocab_size)
        x = x.permute(2, 1, 0, 3).contiguous().reshape(batch_size, num_chunks * chunk_size, vocab_size)

        x = x.log_softmax(dim=-1)

        logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)

        return logprobs
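As a quick shape check, here is a minimal, hypothetical smoke test for ChunkDecoder with random inputs; all sizes are made up for illustration (not taken from the recipe), and it assumes it is run from the zipformer1 directory so that chunk_decoder is importable.

# Hypothetical usage sketch, not part of the commit.
import torch
from chunk_decoder import ChunkDecoder

batch_size, chunk_size, num_chunks = 2, 8, 5
seq_len = chunk_size * num_chunks            # must be a multiple of chunk_size
embed_dim, vocab_size, hidden_size = 64, 256, 128

decoder = ChunkDecoder(embed_dim=embed_dim,
                       chunk_size=chunk_size,
                       vocab_size=vocab_size,
                       hidden_size=hidden_size)

labels = torch.randint(0, vocab_size, (batch_size, seq_len))
encoder_embed = torch.randn(num_chunks, batch_size, embed_dim)

logprobs = decoder(labels, encoder_embed)
assert logprobs.shape == (batch_size, seq_len)   # one log-prob per label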
1  egs/libriheavy/LM/zipformer1/encoder_interface.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless/encoder_interface.py
141  egs/libriheavy/LM/zipformer1/lm_datamodule.py  Normal file
@@ -0,0 +1,141 @@
# Copyright      2021  Piotr Żelasko
# Copyright      2022  Xiaomi Corporation     (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
import numpy as np
import random
from pathlib import Path
from typing import Any, Dict, Optional
from icefall.dist import get_world_size, get_rank

import torch

from torch.utils.data import DataLoader

from icefall.utils import str2bool


class LmDataset(torch.utils.data.IterableDataset):
    def __init__(self,
                 file_list_fn: Path,
                 bytes_per_segment: int = 200):
        """
        Initialize LmDataset object.
        Args:
          file_list_fn: a file in which each line contains: a number of bytes, then a space,
             then a filename; e.g. a line might contain the text "64324 foo/abc.txt".
          bytes_per_segment: the number of bytes in each segment of data.
        """
        self.files = []
        self.num_bytes = []
        self.bytes_per_segment = bytes_per_segment

        with open(file_list_fn) as f:
            for line in f.readlines():
                line = line.strip()  # remove newline
                num_bytes = line.split()[0]  # a str
                fn = line[len(num_bytes) + 1:]  # this works even if fn has spaces in it
                self.files.append(fn)
                self.num_bytes.append(int(num_bytes))
        tot_bytes = sum(self.num_bytes)
        self.probs = np.array([x / tot_bytes for x in self.num_bytes])

        # Note: when the dataset is constructed in the main process, get_worker_info()
        # returns None here, so num_workers is treated as 1 in this computation.
        worker_info = torch.utils.data.get_worker_info()
        num_workers = (1 if worker_info is None else worker_info.num_workers)

        tot_workers = num_workers * get_world_size()

        self.num_segments = tot_bytes // (bytes_per_segment * tot_workers)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # my_id includes both the worker (within the training job) and the rank of the training job
        my_id = (0 if worker_info is None else worker_info.id) + 1000 * get_rank()

        seed = random.randint(0, 10000) + my_id
        logging.info(f"seed={seed}, num_segments={self.num_segments}")
        rng = np.random.default_rng(seed=seed)
        for n in range(self.num_segments):
            # np.random.Generator.multinomial returns counts of the categories rather
            # than the chosen category, so we use np.nonzero to recover the chosen
            # category (i.e. the file index).  np.nonzero returns one array per dim,
            # and that array has a single element here, hence the two unpackings below.
            file_idx, = np.nonzero(rng.multinomial(1, self.probs))
            file_idx, = file_idx

            fn = self.files[file_idx]
            num_bytes = self.num_bytes[file_idx]

            # begin_pos, end_pos are the begin,end of a range from which we'll pick
            # randomly, for where the start of the segment might be.
            begin_pos = 0
            end_pos = max(1, num_bytes - self.bytes_per_segment)

            begin, = rng.integers(low=begin_pos, high=end_pos, size=1)

            with open(fn, "rb") as f:
                f.seek(begin)
                b = f.read(self.bytes_per_segment)  # b is a bytes object
            read_size = len(b)
            if read_size < self.bytes_per_segment:
                b = b + b'\0' * (self.bytes_per_segment - read_size)
            yield torch.tensor(np.frombuffer(b, dtype=np.uint8).copy(), dtype=torch.long)


def LmDataloader(dataset: LmDataset,
                 batch_size: int,
                 num_workers: int):

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True)


def _test():
    l = LmDataset('files.txt')

    d = LmDataloader(l, batch_size=5, num_workers=4)

    for batch in d:
        logging.info("batch shape: %s", str(batch.shape))


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    _test()


# cd libriheavy/LM
# find /ceph-data3/xiaoyu/librilight_text/output_text_large_cleaned -name text.txt -exec stat --printf='%s ' {} \; -print > files.txt
# head -n 2 files.txt > valid.txt
# tail -n +3 files.txt > train.txt
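For context, the dataset can also be exercised with a toy file list in the same "<num_bytes> <filename>" format that the find/stat commands above produce. Everything below (paths, sizes, loop bounds) is purely illustrative, and it assumes a non-distributed run in which icefall.dist.get_world_size() returns 1 and get_rank() returns 0, executed from the zipformer1 directory.

# Hypothetical usage sketch, not part of the commit.
import os
from lm_datamodule import LmDataset, LmDataloader

os.makedirs("toy_corpus", exist_ok=True)
with open("toy_corpus/a.txt", "w") as f:
    f.write("hello world, this is a tiny corpus for a byte-level LM\n" * 100)

size = os.path.getsize("toy_corpus/a.txt")
with open("files.txt", "w") as f:
    f.write(f"{size} toy_corpus/a.txt\n")      # one "<num_bytes> <filename>" line per file

dataset = LmDataset("files.txt", bytes_per_segment=200)
loader = LmDataloader(dataset, batch_size=4, num_workers=0)

for i, batch in enumerate(loader):
    print(batch.shape, batch.dtype)            # torch.Size([4, 200]) torch.int64
    if i == 2:
        break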
65  egs/libriheavy/LM/zipformer1/model.py  Normal file
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.  (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import nn, Tensor
from chunk_decoder import ChunkDecoder
from zipformer import Zipformer2


class Zipformer2LM(nn.Module):

    def __init__(self,
                 encoder_embed: nn.Module,
                 encoder: Zipformer2,
                 decoder: ChunkDecoder):
        super().__init__()
        self.encoder_embed = encoder_embed
        self.encoder = encoder  # does subsampling
        self.decoder = decoder


    def forward(self,
                labels: Tensor):
        """
        Compute array of log-probs.

        Args:
          labels: a Tensor containing the labels (in the range 0..num_symbols-1),
              of shape (batch_size, seq_len).
        Returns:
          a Tensor containing the log-probs for each label, of shape (batch_size, seq_len).
        """
        (batch_size, seq_len) = labels.shape

        chunk_size = self.decoder.chunk_size
        labels_shifted = labels.t()  # (time, batch)
        # Shift the encoder input right by one chunk, so the encoder embedding for a
        # chunk depends only on the symbols preceding that chunk.
        labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:chunk_size]),
                                    labels_shifted[:-chunk_size]),
                                   dim=0)

        x = self.encoder_embed(labels_shifted)
        x_lens = torch.full((batch_size,), seq_len,
                            dtype=torch.long, device=labels.device)
        # x_lens would nominally be the length after subsampling; we don't actually need it.

        (x, x_lens) = self.encoder(x, x_lens)

        logprobs = self.decoder(labels, x)
        return logprobs
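To make the chunk shift above concrete, here is a small standalone illustration (values chosen for readability) of what the torch.cat in forward does: the encoder input for chunk k contains only symbols from chunks before k, so the embedding the decoder conditions on never includes that chunk's own labels.

# Standalone illustration of the chunk shift used in Zipformer2LM.forward.
import torch

chunk_size = 4
labels = torch.arange(1, 13).reshape(1, 12)   # (batch=1, seq_len=12): three chunks
labels_shifted = labels.t()                   # (time, batch)
labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:chunk_size]),
                            labels_shifted[:-chunk_size]), dim=0)
print(labels_shifted.t())
# tensor([[0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]])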
1  egs/libriheavy/LM/zipformer1/optim.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer2/optim.py
1  egs/libriheavy/LM/zipformer1/scaling.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer2/scaling.py
1178  egs/libriheavy/LM/zipformer1/train.py  Executable file
(File diff suppressed because it is too large.)
1  egs/libriheavy/LM/zipformer1/zipformer.py  Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer2/zipformer.py
@@ -111,8 +111,8 @@ class Zipformer2(EncoderInterface):
         dropout: FloatLike = None,  # see code below for default
         warmup_batches: float = 4000.0,
         causal: bool = False,
-        chunk_size: Tuple[int] = [-1],
-        left_context_frames: Tuple[int] = [-1],
+        chunk_size: Tuple[int] = (-1,),
+        left_context_frames: Tuple[int] = (-1,),
     ) -> None:
         super(Zipformer2, self).__init__()

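The change above replaces a list default with a tuple, presumably so the defaults are immutable and match the Tuple[int] annotation. As a reminder of the general Python pitfall with mutable defaults (a generic illustration, not code from this commit):

# Generic illustration of the shared-mutable-default pitfall (not icefall code).
def bad(acc=[]):       # the single default list is created once and shared across calls
    acc.append(1)
    return acc

def good(acc=(-1,)):   # a tuple default cannot be mutated
    return acc

print(bad())    # [1]
print(bad())    # [1, 1]  -- state leaked between calls via the shared default
print(good())   # (-1,)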
1  egs/librispeech/ASR/zipformer2  Symbolic link
@@ -0,0 +1 @@
pruned_transducer_stateless7