refactor init states for stream

parent 5b19011edb
commit 42e3e883fd
@@ -1630,6 +1630,30 @@ class EmformerEncoder(nn.Module):
        )
        return output, output_lengths, output_states

    def init_states(self, device: torch.device = torch.device("cpu")):
        """Create initial states."""
        attn_caches = [
            [
                torch.zeros(self.memory_size, self.d_model, device=device),
                torch.zeros(
                    self.left_context_length, self.d_model, device=device
                ),
                torch.zeros(
                    self.left_context_length, self.d_model, device=device
                ),
            ]
            for _ in range(self.num_encoder_layers)
        ]
        conv_caches = [
            torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device)
            for _ in range(self.num_encoder_layers)
        ]
        states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = (
            attn_caches,
            conv_caches,
        )
        return states


class Emformer(EncoderInterface):
    def __init__(
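For reference, here is a minimal standalone sketch of the cache layout that init_states() returns. The layer count and dimensions below are placeholder values, not taken from this commit: each layer gets an attention cache (one memory cache plus two left-context caches) and a convolution cache holding the last cnn_module_kernel - 1 frames.

import torch

# Placeholder configuration, for illustration only; the real values come from
# the EmformerEncoder instance (self.num_encoder_layers, self.d_model, ...).
num_encoder_layers = 12
d_model = 512
memory_size = 32
left_context_length = 8  # measured after subsampling
cnn_module_kernel = 31

# Per layer: one memory cache and two left-context caches.
attn_caches = [
    [
        torch.zeros(memory_size, d_model),
        torch.zeros(left_context_length, d_model),
        torch.zeros(left_context_length, d_model),
    ]
    for _ in range(num_encoder_layers)
]
# Per layer: the last (cnn_module_kernel - 1) frames kept for the causal convolution.
conv_caches = [
    torch.zeros(d_model, cnn_module_kernel - 1)
    for _ in range(num_encoder_layers)
]
states = (attn_caches, conv_caches)

print(len(states[0]), [tuple(t.shape) for t in states[0][0]])  # 12 [(32, 512), (8, 512), (8, 512)]
print(len(states[1]), tuple(states[1][0].shape))  # 12 (512, 30)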
@@ -1802,6 +1826,10 @@ class Emformer(EncoderInterface):

        return output, output_lengths, output_states

    def init_states(self, device: torch.device = torch.device("cpu")):
        """Create initial states."""
        return self.encoder.init_states(device)


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

@@ -43,15 +43,12 @@ class Stream(object):
          device:
            The device to run this stream.
        """
        self.device = device
        self.LOG_EPS = LOG_EPS

        # Containing attention caches and convolution caches
        self.states: Optional[
            Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
        ] = None
        # Initialize zero states.
        self.init_states(params)

        # It uses different attributes for different decoding methods.
        self.context_size = params.context_size
@@ -107,34 +104,11 @@ class Stream(object):
    def set_ground_truth(self, ground_truth: str) -> None:
        self.ground_truth = ground_truth

    def init_states(self, params: AttributeDict) -> None:
        attn_caches = [
            [
                torch.zeros(
                    params.memory_size, params.encoder_dim, device=self.device
                ),
                torch.zeros(
                    params.left_context_length // params.subsampling_factor,
                    params.encoder_dim,
                    device=self.device,
                ),
                torch.zeros(
                    params.left_context_length // params.subsampling_factor,
                    params.encoder_dim,
                    device=self.device,
                ),
            ]
            for _ in range(params.num_encoder_layers)
        ]
        conv_caches = [
            torch.zeros(
                params.encoder_dim,
                params.cnn_module_kernel - 1,
                device=self.device,
            )
            for _ in range(params.num_encoder_layers)
        ]
        self.states = (attn_caches, conv_caches)

    def set_states(
        self, states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
    ) -> None:
        """Set states."""
        self.states = states

    def get_feature_chunk(self) -> torch.Tensor:
        """Get a chunk of feature frames.

@@ -683,6 +683,8 @@ def decode_dataset(
            LOG_EPS=LOG_EPSILON,
        )

        stream.set_states(model.encoder.init_states(device))

        audio: np.ndarray = cut.load_audio()
        # audio.shape: (1, num_samples)
        assert len(audio.shape) == 2

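The stream.set_states(model.encoder.init_states(device)) call above is the point of the refactor: the zero caches are now built by the encoder itself and handed to the stream, so Stream no longer has to reproduce the model's cache geometry from params. A minimal sketch of that contract, using hypothetical TinyEncoder/TinyStream stand-ins rather than the real classes:

from typing import List, Optional, Tuple

import torch

StateT = Tuple[List[List[torch.Tensor]], List[torch.Tensor]]


class TinyEncoder:
    """Hypothetical stand-in for EmformerEncoder; only init_states() matters here."""

    def __init__(self, num_layers: int = 2, d_model: int = 8) -> None:
        self.num_layers = num_layers
        self.d_model = d_model

    def init_states(self, device: torch.device = torch.device("cpu")) -> StateT:
        # Three attention caches and one convolution cache per layer, all zeros.
        attn_caches = [
            [torch.zeros(4, self.d_model, device=device) for _ in range(3)]
            for _ in range(self.num_layers)
        ]
        conv_caches = [
            torch.zeros(self.d_model, 2, device=device)
            for _ in range(self.num_layers)
        ]
        return (attn_caches, conv_caches)


class TinyStream:
    """Hypothetical stand-in for Stream; it only stores whatever states it is given."""

    def __init__(self) -> None:
        self.states: Optional[StateT] = None

    def set_states(self, states: StateT) -> None:
        self.states = states


device = torch.device("cpu")
encoder = TinyEncoder()
stream = TinyStream()
# As in decode_dataset above: each new stream is seeded from the encoder's zero caches.
stream.set_states(encoder.init_states(device))
assert stream.states is not None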
@@ -19,10 +19,10 @@
"""
Usage:
(1) greedy search
./conv_emformer_transducer_stateless/decode.py \
./conv_emformer_transducer_stateless2/decode.py \
  --epoch 30 \
  --avg 10 \
  --exp-dir conv_emformer_transducer_stateless/exp \
  --exp-dir conv_emformer_transducer_stateless2/exp \
  --max-duration 300 \
  --num-encoder-layers 12 \
  --chunk-length 32 \
@@ -34,10 +34,10 @@ Usage:
  --use-averaged-model True

(2) modified beam search
./conv_emformer_transducer_stateless/decode.py \
./conv_emformer_transducer_stateless2/decode.py \
  --epoch 30 \
  --avg 10 \
  --exp-dir conv_emformer_transducer_stateless/exp \
  --exp-dir conv_emformer_transducer_stateless2/exp \
  --max-duration 300 \
  --num-encoder-layers 12 \
  --chunk-length 32 \
@@ -50,10 +50,10 @@ Usage:
  --beam-size 4

(3) fast beam search
./conv_emformer_transducer_stateless/decode.py \
./conv_emformer_transducer_stateless2/decode.py \
  --epoch 30 \
  --avg 10 \
  --exp-dir conv_emformer_transducer_stateless/exp \
  --exp-dir conv_emformer_transducer_stateless2/exp \
  --max-duration 300 \
  --num-encoder-layers 12 \
  --chunk-length 32 \

@@ -1529,6 +1529,30 @@ class EmformerEncoder(nn.Module):
        )
        return output, output_lengths, output_states

    def init_states(self, device: torch.device = torch.device("cpu")):
        """Create initial states."""
        attn_caches = [
            [
                torch.zeros(self.memory_size, self.d_model, device=device),
                torch.zeros(
                    self.left_context_length, self.d_model, device=device
                ),
                torch.zeros(
                    self.left_context_length, self.d_model, device=device
                ),
            ]
            for _ in range(self.num_encoder_layers)
        ]
        conv_caches = [
            torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device)
            for _ in range(self.num_encoder_layers)
        ]
        states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = (
            attn_caches,
            conv_caches,
        )
        return states


class Emformer(EncoderInterface):
    def __init__(
@@ -1701,6 +1725,10 @@ class Emformer(EncoderInterface):

        return output, output_lengths, output_states

    def init_states(self, device: torch.device = torch.device("cpu")):
        """Create initial states."""
        return self.encoder.init_states(device)


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

@@ -19,10 +19,10 @@
"""
Usage:
(1) greedy search
./conv_emformer_transducer_stateless/streaming_decode.py \
./conv_emformer_transducer_stateless2/streaming_decode.py \
  --epoch 30 \
  --avg 10 \
  --exp-dir conv_emformer_transducer_stateless/exp \
  --exp-dir conv_emformer_transducer_stateless2/exp \
  --num-decode-streams 2000 \
  --num-encoder-layers 12 \
  --chunk-length 32 \
@@ -34,10 +34,10 @@ Usage:
  --use-averaged-model True

(2) modified beam search
./conv_emformer_transducer_stateless/streaming_decode.py \
./conv_emformer_transducer_stateless2/streaming_decode.py \
  --epoch 30 \
  --avg 10 \
  --exp-dir conv_emformer_transducer_stateless/exp \
  --exp-dir conv_emformer_transducer_stateless2/exp \
  --num-decode-streams 2000 \
  --num-encoder-layers 12 \
  --chunk-length 32 \
@@ -50,10 +50,10 @@ Usage:
  --beam-size 4

(3) fast beam search
./conv_emformer_transducer_stateless/streaming_decode.py \
./conv_emformer_transducer_stateless2/streaming_decode.py \
  --epoch 30 \
  --avg 10 \
  --exp-dir conv_emformer_transducer_stateless/exp \
  --exp-dir conv_emformer_transducer_stateless2/exp \
  --num-decode-streams 2000 \
  --num-encoder-layers 12 \
  --chunk-length 32 \
@@ -683,6 +683,8 @@ def decode_dataset(
            LOG_EPS=LOG_EPSILON,
        )

        stream.set_states(model.encoder.init_states(device))

        audio: np.ndarray = cut.load_audio()
        # audio.shape: (1, num_samples)
        assert len(audio.shape) == 2

@@ -22,11 +22,11 @@ Usage:

export CUDA_VISIBLE_DEVICES="0,1,2,3"

./conv_emformer_transducer_stateless/train.py \
./conv_emformer_transducer_stateless2/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --exp-dir conv_emformer_transducer_stateless/exp \
  --exp-dir conv_emformer_transducer_stateless2/exp \
  --full-libri 1 \
  --max-duration 300 \
  --master-port 12321 \
@@ -38,12 +38,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  --memory-size 32

# For mixed precision training:
./conv_emformer_transducer_stateless/train.py \
./conv_emformer_transducer_stateless2/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir conv_emformer_transducer_stateless/exp \
  --exp-dir conv_emformer_transducer_stateless2/exp \
  --full-libri 1 \
  --max-duration 300 \
  --master-port 12321 \
