diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 327cba2d3..753e5c473 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -1630,6 +1630,30 @@ class EmformerEncoder(nn.Module): ) return output, output_lengths, output_states + def init_states(self, device: torch.device = torch.device("cpu")): + """Create initial states.""" + attn_caches = [ + [ + torch.zeros(self.memory_size, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + ] + for _ in range(self.num_encoder_layers) + ] + conv_caches = [ + torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device) + for _ in range(self.num_encoder_layers) + ] + states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = ( + attn_caches, + conv_caches, + ) + return states + class Emformer(EncoderInterface): def __init__( @@ -1802,6 +1826,10 @@ class Emformer(EncoderInterface): return output, output_lengths, output_states + def init_states(self, device: torch.device = torch.device("cpu")): + """Create initial states.""" + return self.encoder.init_states(device) + class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length). diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 31ad3f50a..69ee7ee9a 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -43,15 +43,12 @@ class Stream(object): device: The device to run this stream. """ - self.device = device self.LOG_EPS = LOG_EPS # Containing attention caches and convolution caches self.states: Optional[ Tuple[List[List[torch.Tensor]], List[torch.Tensor]] ] = None - # Initailize zero states. - self.init_states(params) # It uses different attributes for different decoding methods. self.context_size = params.context_size @@ -107,34 +104,11 @@ class Stream(object): def set_ground_truth(self, ground_truth: str) -> None: self.ground_truth = ground_truth - def init_states(self, params: AttributeDict) -> None: - attn_caches = [ - [ - torch.zeros( - params.memory_size, params.encoder_dim, device=self.device - ), - torch.zeros( - params.left_context_length // params.subsampling_factor, - params.encoder_dim, - device=self.device, - ), - torch.zeros( - params.left_context_length // params.subsampling_factor, - params.encoder_dim, - device=self.device, - ), - ] - for _ in range(params.num_encoder_layers) - ] - conv_caches = [ - torch.zeros( - params.encoder_dim, - params.cnn_module_kernel - 1, - device=self.device, - ) - for _ in range(params.num_encoder_layers) - ] - self.states = (attn_caches, conv_caches) + def set_states( + self, states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] + ) -> None: + """Set states.""" + self.states = states def get_feature_chunk(self) -> torch.Tensor: """Get a chunk of feature frames. diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 4fac405b0..0a6bbfa8b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -683,6 +683,8 @@ def decode_dataset( LOG_EPS=LOG_EPSILON, ) + stream.set_states(model.encoder.init_states(device)) + audio: np.ndarray = cut.load_audio() # audio.shape: (1, num_samples) assert len(audio.shape) == 2 diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 287fb94df..402ec4293 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -19,10 +19,10 @@ """ Usage: (1) greedy search -./conv_emformer_transducer_stateless/decode.py \ +./conv_emformer_transducer_stateless2/decode.py \ --epoch 30 \ --avg 10 \ - --exp-dir conv_emformer_transducer_stateless/exp \ + --exp-dir conv_emformer_transducer_stateless2/exp \ --max-duration 300 \ --num-encoder-layers 12 \ --chunk-length 32 \ @@ -34,10 +34,10 @@ Usage: --use-averaged-model True (2) modified beam search -./conv_emformer_transducer_stateless/decode.py \ +./conv_emformer_transducer_stateless2/decode.py \ --epoch 30 \ --avg 10 \ - --exp-dir conv_emformer_transducer_stateless/exp \ + --exp-dir conv_emformer_transducer_stateless2/exp \ --max-duration 300 \ --num-encoder-layers 12 \ --chunk-length 32 \ @@ -50,10 +50,10 @@ Usage: --beam-size 4 (3) fast beam search -./conv_emformer_transducer_stateless/decode.py \ +./conv_emformer_transducer_stateless2/decode.py \ --epoch 30 \ --avg 10 \ - --exp-dir conv_emformer_transducer_stateless/exp \ + --exp-dir conv_emformer_transducer_stateless2/exp \ --max-duration 300 \ --num-encoder-layers 12 \ --chunk-length 32 \ diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index 45ca03dd2..e3a598b0e 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -1529,6 +1529,30 @@ class EmformerEncoder(nn.Module): ) return output, output_lengths, output_states + def init_states(self, device: torch.device = torch.device("cpu")): + """Create initial states.""" + attn_caches = [ + [ + torch.zeros(self.memory_size, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + ] + for _ in range(self.num_encoder_layers) + ] + conv_caches = [ + torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device) + for _ in range(self.num_encoder_layers) + ] + states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = ( + attn_caches, + conv_caches, + ) + return states + class Emformer(EncoderInterface): def __init__( @@ -1701,6 +1725,10 @@ class Emformer(EncoderInterface): return output, output_lengths, output_states + def init_states(self, device: torch.device = torch.device("cpu")): + """Create initial states.""" + return self.encoder.init_states(device) + class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length). diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 4fac405b0..0f687898f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -19,10 +19,10 @@ """ Usage: (1) greedy search -./conv_emformer_transducer_stateless/streaming_decode.py \ +./conv_emformer_transducer_stateless2/streaming_decode.py \ --epoch 30 \ --avg 10 \ - --exp-dir conv_emformer_transducer_stateless/exp \ + --exp-dir conv_emformer_transducer_stateless2/exp \ --num-decode-streams 2000 \ --num-encoder-layers 12 \ --chunk-length 32 \ @@ -34,10 +34,10 @@ Usage: --use-averaged-model True (2) modified beam search -./conv_emformer_transducer_stateless/streaming_decode.py \ +./conv_emformer_transducer_stateless2/streaming_decode.py \ --epoch 30 \ --avg 10 \ - --exp-dir conv_emformer_transducer_stateless/exp \ + --exp-dir conv_emformer_transducer_stateless2/exp \ --num-decode-streams 2000 \ --num-encoder-layers 12 \ --chunk-length 32 \ @@ -50,10 +50,10 @@ Usage: --beam-size 4 (3) fast beam search -./conv_emformer_transducer_stateless/streaming_decode.py \ +./conv_emformer_transducer_stateless2/streaming_decode.py \ --epoch 30 \ --avg 10 \ - --exp-dir conv_emformer_transducer_stateless/exp \ + --exp-dir conv_emformer_transducer_stateless2/exp \ --num-decode-streams 2000 \ --num-encoder-layers 12 \ --chunk-length 32 \ @@ -683,6 +683,8 @@ def decode_dataset( LOG_EPS=LOG_EPSILON, ) + stream.set_states(model.encoder.init_states(device)) + audio: np.ndarray = cut.load_audio() # audio.shape: (1, num_samples) assert len(audio.shape) == 2 diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 106f3e511..716ecc8b1 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -22,11 +22,11 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./conv_emformer_transducer_stateless/train.py \ +./conv_emformer_transducer_stateless2/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ - --exp-dir conv_emformer_transducer_stateless/exp \ + --exp-dir conv_emformer_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 300 \ --master-port 12321 \ @@ -38,12 +38,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --memory-size 32 # For mix precision training: -./conv_emformer_transducer_stateless/train.py \ +./conv_emformer_transducer_stateless2/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir conv_emformer_transducer_stateless/exp \ + --exp-dir conv_emformer_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 300 \ --master-port 12321 \