This commit is contained in:
yfyeung 2024-07-10 00:16:24 -07:00
parent f96bbdef41
commit 7961b6bf23
5 changed files with 35 additions and 2 deletions

12
egs/librispeech/ASR/decode.sh Executable file
View File

@ -0,0 +1,12 @@
#!/usr/bin/env bash
# Sweep checkpoint-averaging settings for the zipformer_lstm recipe.
#
# For every epoch in the range below, run greedy-search decoding once for
# each --avg value from 1 up to epoch-1, reading checkpoints from
# ./zipformer_lstm/exp_dropout0.2.
#
# All paths are relative, so run this from the directory that contains
# zipformer_lstm/.
#
# The interpreter line matters: brace expansion and the C-style
# for (( )) loop are bash features, so the script must not be handed to a
# plain POSIX sh.

# Restrict decoding to GPU index 2.
export CUDA_VISIBLE_DEVICES=2

# Single-epoch range for now; widen {30..30} to sweep several epochs.
for epoch in {30..30}; do
  # avg cannot exceed epoch-1, hence the upper bound.
  for ((avg = 1; avg <= epoch - 1; avg++)); do
    ./zipformer_lstm/decode.py \
      --epoch "$epoch" \
      --avg "$avg" \
      --exp-dir ./zipformer_lstm/exp_dropout0.2 \
      --max-duration 2000 \
      --decoding-method greedy_search
  done
done

View File

@ -0,0 +1,8 @@
#!/usr/bin/env bash
# Run beam-search decoding for the zipformer_lstm recipe.
#
# Usage: <script> <gpu-ids> <epoch> <avg>
#   gpu-ids  exported verbatim as CUDA_VISIBLE_DEVICES
#   epoch    forwarded to decode.py as --epoch
#   avg      forwarded to decode.py as --avg
#
# All paths are relative, so run this from the directory that contains
# zipformer_lstm/.

# Fail fast on missing arguments; otherwise decode.py would be started with
# empty option values and an empty CUDA_VISIBLE_DEVICES.
if [ "$#" -lt 3 ]; then
  echo "Usage: $0 <gpu-ids> <epoch> <avg>" >&2
  exit 1
fi

export CUDA_VISIBLE_DEVICES="$1"
./zipformer_lstm/decode.py \
  --epoch "$2" \
  --avg "$3" \
  --exp-dir ./zipformer_lstm/exp \
  --max-duration 2000 \
  --decoding-method beam_search

11
egs/librispeech/ASR/sync.sh Executable file
View File

@ -0,0 +1,11 @@
#!/usr/bin/env bash
# Continuously push a run's TensorBoard logs to Weights & Biases.
#
# One `wandb sync` is issued without --append, then the same directory is
# re-synced with --append every 60 seconds until the script is interrupted
# (Ctrl-C / kill).
#
# All paths are relative, so run this from the directory that contains
# the recipe directory.

project=icefall-asr-librispeech-zipformer-2023-11-04
run=4V10032G_lstm1_decoderdropout0.2_bpe500
recipe=zipformer_lstm

# Single definition of the log directory shared by both sync commands.
tb_dir="${recipe}/exp_dropout0.2/tensorboard/"

# Initial sync (no --append).
wandb sync "$tb_dir" --sync-tensorboard -p "$project" --id "$run"

# Intentional endless loop: keep appending new events once a minute.
while true; do
  wandb sync "$tb_dir" --sync-tensorboard -p "$project" --id "$run" --append
  sleep 60
done

View File

@ -36,7 +36,7 @@ def greedy_search(model: nn.Module, encoder_out: torch.Tensor) -> List[int]:
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
device = model.encoder_embed.device
device = next(model.parameters()).device
sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)

View File

@ -76,7 +76,7 @@ class Decoder(nn.Module):
self.vocab_size = vocab_size
# self.embedding_dropout = nn.Dropout(embedding_dropout)
self.embedding_dropout = nn.Dropout(embedding_dropout)
self.rnn = nn.LSTM(
input_size=decoder_dim,
@ -113,6 +113,8 @@ class Decoder(nn.Module):
# at utterance start, we use negative ids in beam_search.py
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
embedding_out = self.embedding_dropout(embedding_out)
embedding_out = self.balancer(embedding_out)
rnn_out, (h, c) = self.rnn(embedding_out, states)