refactor init states for stream

This commit is contained in:
yaozengwei 2022-06-21 22:07:36 +08:00
parent 5b19011edb
commit 42e3e883fd
7 changed files with 81 additions and 47 deletions

View File

@ -1630,6 +1630,30 @@ class EmformerEncoder(nn.Module):
)
return output, output_lengths, output_states
def init_states(self, device: torch.device = torch.device("cpu")):
    """Return zero-valued streaming caches for every encoder layer.

    For each of the ``num_encoder_layers`` layers this creates:
      - an attention cache ``[memory, left-context key, left-context value]``,
      - a convolution cache of shape ``(d_model, cnn_module_kernel - 1)``.

    Args:
      device: device on which the cache tensors are allocated.

    Returns:
      A tuple ``(attn_caches, conv_caches)``.
    """

    def _layer_attn_cache() -> List[torch.Tensor]:
        # [memory bank, left-context key, left-context value] for one layer.
        return [
            torch.zeros(self.memory_size, self.d_model, device=device),
            torch.zeros(self.left_context_length, self.d_model, device=device),
            torch.zeros(self.left_context_length, self.d_model, device=device),
        ]

    attn_caches = [_layer_attn_cache() for _ in range(self.num_encoder_layers)]
    conv_caches = [
        torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device)
        for _ in range(self.num_encoder_layers)
    ]
    states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = (
        attn_caches,
        conv_caches,
    )
    return states
class Emformer(EncoderInterface):
def __init__(
@ -1802,6 +1826,10 @@ class Emformer(EncoderInterface):
return output, output_lengths, output_states
def init_states(self, device: torch.device = torch.device("cpu")):
    """Create initial streaming states by delegating to the wrapped encoder."""
    initial_states = self.encoder.init_states(device)
    return initial_states
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).

View File

@ -43,15 +43,12 @@ class Stream(object):
device:
The device to run this stream.
"""
self.device = device
self.LOG_EPS = LOG_EPS
# Containing attention caches and convolution caches
self.states: Optional[
Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
] = None
# Initialize zero states.
self.init_states(params)
# It uses different attributes for different decoding methods.
self.context_size = params.context_size
@ -107,34 +104,11 @@ class Stream(object):
def set_ground_truth(self, ground_truth: str) -> None:
self.ground_truth = ground_truth
def init_states(self, params: AttributeDict) -> None:
    """Reset ``self.states`` to zero-filled caches derived from ``params``.

    One attention cache ``[memory, left-context key, left-context value]``
    and one convolution cache are created per encoder layer; all tensors
    live on ``self.device``.
    """
    # Left-context length is given in frames before subsampling.
    left_len = params.left_context_length // params.subsampling_factor
    dim = params.encoder_dim
    attn_caches = []
    conv_caches = []
    for _ in range(params.num_encoder_layers):
        memory = torch.zeros(params.memory_size, dim, device=self.device)
        left_key = torch.zeros(left_len, dim, device=self.device)
        left_val = torch.zeros(left_len, dim, device=self.device)
        attn_caches.append([memory, left_key, left_val])
        conv_caches.append(
            torch.zeros(dim, params.cnn_module_kernel - 1, device=self.device)
        )
    self.states = (attn_caches, conv_caches)
def set_states(
    self, states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
) -> None:
    """Store externally created states on this stream.

    Args:
      states:
        A tuple ``(attn_caches, conv_caches)`` — per-layer attention caches
        and convolution caches, typically produced by the encoder's
        ``init_states`` or returned by a streaming forward pass.
    """
    self.states = states
def get_feature_chunk(self) -> torch.Tensor:
"""Get a chunk of feature frames.

View File

@ -683,6 +683,8 @@ def decode_dataset(
LOG_EPS=LOG_EPSILON,
)
stream.set_states(model.encoder.init_states(device))
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2

View File

@ -19,10 +19,10 @@
"""
Usage:
(1) greedy search
./conv_emformer_transducer_stateless/decode.py \
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless/exp \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 32 \
@ -34,10 +34,10 @@ Usage:
--use-averaged-model True
(2) modified beam search
./conv_emformer_transducer_stateless/decode.py \
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless/exp \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 32 \
@ -50,10 +50,10 @@ Usage:
--beam-size 4
(3) fast beam search
./conv_emformer_transducer_stateless/decode.py \
./conv_emformer_transducer_stateless2/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless/exp \
--exp-dir conv_emformer_transducer_stateless2/exp \
--max-duration 300 \
--num-encoder-layers 12 \
--chunk-length 32 \

View File

@ -1529,6 +1529,30 @@ class EmformerEncoder(nn.Module):
)
return output, output_lengths, output_states
def init_states(self, device: torch.device = torch.device("cpu")):
    """Build the zero-filled per-layer caches used to start streaming.

    Args:
      device: device on which the cache tensors are allocated.

    Returns:
      A tuple ``(attn_caches, conv_caches)`` where ``attn_caches`` holds,
      per layer, ``[memory, left-context key, left-context value]`` and
      ``conv_caches`` holds one ``(d_model, cnn_module_kernel - 1)`` tensor
      per layer.
    """
    attn_caches: List[List[torch.Tensor]] = []
    conv_caches: List[torch.Tensor] = []
    for _ in range(self.num_encoder_layers):
        memory = torch.zeros(self.memory_size, self.d_model, device=device)
        left_key = torch.zeros(
            self.left_context_length, self.d_model, device=device
        )
        left_val = torch.zeros(
            self.left_context_length, self.d_model, device=device
        )
        attn_caches.append([memory, left_key, left_val])
        conv_caches.append(
            torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device)
        )
    states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = (
        attn_caches,
        conv_caches,
    )
    return states
class Emformer(EncoderInterface):
def __init__(
@ -1701,6 +1725,10 @@ class Emformer(EncoderInterface):
return output, output_lengths, output_states
def init_states(self, device: torch.device = torch.device("cpu")):
    """Create initial streaming states; forwarded to ``self.encoder``."""
    return self.encoder.init_states(device=device)
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).

View File

@ -19,10 +19,10 @@
"""
Usage:
(1) greedy search
./conv_emformer_transducer_stateless/streaming_decode.py \
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless/exp \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 32 \
@ -34,10 +34,10 @@ Usage:
--use-averaged-model True
(2) modified beam search
./conv_emformer_transducer_stateless/streaming_decode.py \
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless/exp \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 32 \
@ -50,10 +50,10 @@ Usage:
--beam-size 4
(3) fast beam search
./conv_emformer_transducer_stateless/streaming_decode.py \
./conv_emformer_transducer_stateless2/streaming_decode.py \
--epoch 30 \
--avg 10 \
--exp-dir conv_emformer_transducer_stateless/exp \
--exp-dir conv_emformer_transducer_stateless2/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--chunk-length 32 \
@ -683,6 +683,8 @@ def decode_dataset(
LOG_EPS=LOG_EPSILON,
)
stream.set_states(model.encoder.init_states(device))
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2

View File

@ -22,11 +22,11 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./conv_emformer_transducer_stateless/train.py \
./conv_emformer_transducer_stateless2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir conv_emformer_transducer_stateless/exp \
--exp-dir conv_emformer_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 300 \
--master-port 12321 \
@ -38,12 +38,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--memory-size 32
# For mix precision training:
./conv_emformer_transducer_stateless/train.py \
./conv_emformer_transducer_stateless2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir conv_emformer_transducer_stateless/exp \
--exp-dir conv_emformer_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 300 \
--master-port 12321 \