Minor fixes

pkufool 2022-02-17 12:47:17 +08:00
parent 3b6d416c4f
commit a432e356a5
6 changed files with 112 additions and 45 deletions

View File

@@ -1,5 +1,53 @@
 ## Results
+### LibriSpeech BPE training results (Pruned Transducer)
+
+#### Conformer encoder + embedding decoder
+
+Conformer encoder + non-recurrent decoder. The decoder
+contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
+layer (to transform tensor dim).
+
+The WERs are
+
+|               | test-clean | test-other | comment                                  |
+|---------------|------------|------------|------------------------------------------|
+| greedy search | 2.85       | 6.98       | --epoch 28, --avg 15, --max-duration 100 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./pruned_transducer_stateless/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --exp-dir pruned_transducer_stateless/exp \
+  --full-libri 1 \
+  --max-duration 300 \
+  --prune-range 5 \
+  --lr-factor 5 \
+  --lm-scale 0.25
+```
+
+The tensorboard training log can be found at
+<https://tensorboard.dev/experiment/ejG7VpakRYePNNj6AbDEUw/#scalars>
+
+The decoding command is:
+
+```
+epoch=28
+avg=15
+## greedy search
+./pruned_transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir pruned_transducer_stateless/exp \
+  --max-duration 100
+```
+
 ### LibriSpeech BPE training results (Transducer)
 
 #### Conformer encoder + embedding decoder
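The decoder described above is deliberately stateless. A minimal PyTorch sketch of such a decoder — class name, dimensions, and the causal left-padding are illustrative assumptions, not the recipe's exact decoder.py:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class StatelessDecoder(nn.Module):
    """Embedding -> causal Conv1d (kernel size 2) -> Linear; no recurrence,
    so each output position sees only the current and previous label."""

    def __init__(self, vocab_size: int, embed_dim: int, out_dim: int,
                 blank_id: int = 0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim,
                                      padding_idx=blank_id)
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=2)
        self.linear = nn.Linear(embed_dim, out_dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) label indices
        emb = self.embedding(y).permute(0, 2, 1)  # (N, C, U)
        emb = F.pad(emb, (1, 0))  # left-pad by 1 so the conv is causal
        out = self.conv(emb).permute(0, 2, 1)     # (N, U, C)
        return self.linear(out)                   # transform tensor dim
```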

View File

@@ -19,16 +19,16 @@
 Usage:
 (1) greedy search
 ./pruned_transducer_stateless/decode.py \
-        --epoch 14 \
-        --avg 7 \
+        --epoch 28 \
+        --avg 15 \
         --exp-dir ./pruned_transducer_stateless/exp \
         --max-duration 100 \
         --decoding-method greedy_search
 
 (2) beam search
 ./pruned_transducer_stateless/decode.py \
-        --epoch 14 \
-        --avg 7 \
+        --epoch 28 \
+        --avg 15 \
         --exp-dir ./pruned_transducer_stateless/exp \
         --max-duration 100 \
         --decoding-method beam_search \
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=29,
+        default=28,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
-        default=13,
+        default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",

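For orientation, `--epoch 28 --avg 15` means decoding uses the element-wise average of the weights from epoch-14.pt through epoch-28.pt. A minimal sketch of that averaging, assuming each checkpoint stores its weights under a "model" key — an illustrative helper, not the repo's exact implementation:

```
import torch

def average_checkpoints(filenames):
    """Element-wise average of the weights stored in several checkpoints."""
    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")["model"]
        for k in avg:
            avg[k] += state[k]
    n = len(filenames)
    for k in avg:
        # integer buffers (e.g. step counters) need integer division
        avg[k] = avg[k] / n if avg[k].is_floating_point() else avg[k] // n
    return avg

# --epoch 28 --avg 15 selects epoch-14.pt .. epoch-28.pt
ckpts = [
    f"pruned_transducer_stateless/exp/epoch-{i}.pt" for i in range(14, 29)
]
# model.load_state_dict(average_checkpoints(ckpts))
```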
View File

@@ -68,7 +68,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=20,
+        default=28,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
@@ -76,7 +76,7 @@ def get_parser():
     parser.add_argument(
         "--avg",
         type=int,
-        default=10,
+        default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",

View File

@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class Joiner(nn.Module):
@@ -42,10 +43,8 @@ class Joiner(nn.Module):
         logit = encoder_out + decoder_out
 
-        logit = self.inner_linear(logit)
-        logit = torch.tanh(logit)
-        output = self.output_linear(logit)
+        logit = self.inner_linear(torch.tanh(logit))
+        output = self.output_linear(F.relu(logit))
 
         return output
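After this change the joiner applies tanh to the summed embeddings before the inner projection, and a ReLU before the output projection. A self-contained sketch of the resulting module — the constructor names and dimensions are illustrative assumptions:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class Joiner(nn.Module):
    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
        super().__init__()
        self.inner_linear = nn.Linear(input_dim, inner_dim)
        self.output_linear = nn.Linear(inner_dim, output_dim)

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        # The two inputs must have broadcastable shapes, e.g. the pruned
        # (B, T, prune_range, C) tensors produced by k2.do_rnnt_pruning.
        logit = encoder_out + decoder_out
        # tanh now comes before the inner projection...
        logit = self.inner_linear(torch.tanh(logit))
        # ...and a ReLU is applied before the output projection.
        # The result is unnormalized logits (no log-softmax).
        return self.output_linear(F.relu(logit))
```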

View File

@@ -33,9 +33,6 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
-        prune_range: int = 3,
-        am_scale: float = 0.0,
-        lm_scale: float = 0.0,
     ):
         """
         Args:
@@ -52,21 +49,6 @@ class Transducer(nn.Module):
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
-          prune_range:
-            The prune range for rnnt loss, it means how many symbols(context)
-            we are considering for each frame to compute the loss.
-          am_scale:
-            The scale to smooth the loss with am (output of encoder network)
-            part
-          lm_scale:
-            The scale to smooth the loss with lm (output of predictor network)
-            part
-
-        Note:
-          Regarding am_scale & lm_scale, it will make the loss-function one of
-          the form:
-            lm_scale * lm_probs + am_scale * am_probs +
-            (1-lm_scale-am_scale) * combined_probs
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -75,15 +57,15 @@ class Transducer(nn.Module):
         self.encoder = encoder
         self.decoder = decoder
         self.joiner = joiner
-        self.prune_range = prune_range
-        self.lm_scale = lm_scale
-        self.am_scale = am_scale
 
     def forward(
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        prune_range: int = 5,
+        am_scale: float = 0.0,
+        lm_scale: float = 0.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -95,8 +77,23 @@ class Transducer(nn.Module):
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
+          prune_range:
+            The prune range for rnnt loss, it means how many symbols(context)
+            we are considering for each frame to compute the loss.
+          am_scale:
+            The scale to smooth the loss with am (output of encoder network)
+            part
+          lm_scale:
+            The scale to smooth the loss with lm (output of predictor network)
+            part
+
         Returns:
           Return the transducer loss.
+
+        Note:
+          Regarding am_scale & lm_scale, it will make the loss-function one of
+          the form:
+            lm_scale * lm_probs + am_scale * am_probs +
+            (1-lm_scale-am_scale) * combined_probs
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
@@ -114,11 +111,14 @@ class Transducer(nn.Module):
         blank_id = self.decoder.blank_id
         sos_y = add_sos(y, sos_id=blank_id)
 
+        # sos_y_padded: [B, S + 1], start with SOS.
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
+        # decoder_out: [B, S + 1, C]
         decoder_out = self.decoder(sos_y_padded)
 
         # Note: y does not start with SOS
+        # y_padded : [B, S]
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
@@ -133,31 +133,37 @@
             am=encoder_out,
             symbols=y_padded,
             termination_symbol=blank_id,
-            lm_only_scale=self.lm_scale,
-            am_only_scale=self.am_scale,
+            lm_only_scale=lm_scale,
+            am_only_scale=am_scale,
             boundary=boundary,
+            reduction="sum",
             return_grad=True,
         )
 
+        # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
             px_grad=px_grad,
             py_grad=py_grad,
             boundary=boundary,
-            s_range=self.prune_range,
+            s_range=prune_range,
         )
 
+        # am_pruned : [B, T, prune_range, C]
+        # lm_pruned : [B, T, prune_range, C]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
             am=encoder_out, lm=decoder_out, ranges=ranges
         )
 
+        # logits : [B, T, prune_range, C]
         logits = self.joiner(am_pruned, lm_pruned)
 
         pruned_loss = k2.rnnt_loss_pruned(
-            joint=logits,
+            logits=logits,
             symbols=y_padded,
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            reduction="sum",
         )
 
-        return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
+        return (simple_loss, pruned_loss)
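Taken together, the forward pass is a two-stage pruned RNN-T computation: a cheap "simple" loss whose px/py gradients pick out a narrow band of (t, u) lattice cells, then the full joiner and exact loss evaluated only on that band. A condensed sketch of the flow; the k2 argument names follow the diff above, but the call that produces px_grad/py_grad is cut off in the hunk, so naming it k2.rnnt_loss_smoothed here is an assumption:

```
import k2
import torch
import torch.nn as nn

def pruned_rnnt_losses(
    am: torch.Tensor,        # (B, T, C) encoder output
    lm: torch.Tensor,        # (B, S + 1, C) decoder output
    symbols: torch.Tensor,   # (B, S) int64 label sequences, 0-padded
    boundary: torch.Tensor,  # (B, 4) rows of [0, 0, S, T] for k2
    blank_id: int,
    joiner: nn.Module,
    prune_range: int = 5,
    am_scale: float = 0.0,
    lm_scale: float = 0.0,
):
    # Stage 1: "simple" loss where the joiner is just am + lm; its px/py
    # gradients indicate which (t, u) cells of the lattice matter.
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=blank_id,
        lm_only_scale=lm_scale,
        am_only_scale=am_scale,
        boundary=boundary,
        reduction="sum",
        return_grad=True,
    )
    # Stage 2: keep only `prune_range` symbol positions per frame...
    ranges = k2.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary,
        s_range=prune_range,
    )
    am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
    # ...and evaluate the real joiner + exact loss only on those cells.
    logits = joiner(am_pruned, lm_pruned)  # (B, T, prune_range, vocab)
    pruned_loss = k2.rnnt_loss_pruned(
        logits=logits,
        symbols=symbols,
        ranges=ranges,
        termination_symbol=blank_id,
        boundary=boundary,
        reduction="sum",
    )
    return simple_loss, pruned_loss
```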

View File

@@ -148,7 +148,7 @@ def get_parser():
     parser.add_argument(
         "--prune-range",
         type=int,
-        default=3,
+        default=5,
         help="The prune range for rnnt loss, it means how many symbols(context)"
         "we are using to compute the loss",
     )
@@ -156,7 +156,7 @@ def get_parser():
     parser.add_argument(
         "--lm-scale",
         type=float,
-        default=0.5,
+        default=0.25,
         help="The scale to smooth the loss with lm "
         "(output of prediction network) part.",
     )
@@ -169,6 +169,16 @@ def get_parser():
         "part.",
     )
 
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
     return parser
@@ -289,9 +299,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        prune_range=params.prune_range,
-        lm_scale=params.lm_scale,
-        am_scale=params.am_scale,
     )
     return model
@@ -420,8 +427,15 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
-        loss = simple_loss + pruned_loss
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss
 
     assert loss.requires_grad == is_training
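The net effect of this commit is that prune_range, am_scale, and lm_scale become per-call arguments of forward() instead of constructor state, and the simple loss is kept in the objective only as a down-weighted regularizer. A sketch mirroring the updated compute_loss, with the new flag defaults noted in comments (the surrounding variables are assumed, as in the hunk above):

```
import torch

def compute_loss_sketch(model, feature, feature_lens, y, params):
    """Mirrors the updated compute_loss: the pruning/smoothing knobs are
    now forward() arguments rather than model constructor state."""
    simple_loss, pruned_loss = model(
        x=feature,
        x_lens=feature_lens,
        y=y,
        prune_range=params.prune_range,  # default 5 after this commit
        am_scale=params.am_scale,        # default 0.0
        lm_scale=params.lm_scale,        # default 0.25 after this commit
    )
    # The simple loss mainly exists to define the pruning ranges, but it
    # is kept in the objective, scaled by simple_loss_scale (default 0.5).
    return params.simple_loss_scale * simple_loss + pruned_loss
```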