Minor fixes

pkufool 2022-02-17 12:47:17 +08:00
parent 3b6d416c4f
commit a432e356a5
6 changed files with 112 additions and 45 deletions

View File

@@ -1,5 +1,53 @@
## Results
### LibriSpeech BPE training results (Pruned Transducer)
#### Conformer encoder + embedding decoder
Conformer encoder + non-recurrent decoder. The decoder
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
layer (to transform the tensor dimension).
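For intuition, a decoder of this shape fits in a few lines of PyTorch. The sketch below is illustrative only: the class name, constructor arguments, and the left-padding choice are assumptions, not the repository's exact implementation.
```
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessDecoder(nn.Module):
    """Embedding -> Conv1d (kernel size 2) -> Linear; no recurrence."""

    def __init__(self, vocab_size: int, embedding_dim: int,
                 output_dim: int, blank_id: int = 0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim,
                                      padding_idx=blank_id)
        # Kernel size 2: each position sees only the current and previous label.
        self.conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=2)
        self.linear = nn.Linear(embedding_dim, output_dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) label ids -> (N, U, output_dim)
        emb = self.embedding(y).permute(0, 2, 1)  # (N, embedding_dim, U)
        emb = F.pad(emb, pad=(1, 0))              # left-pad so length stays U
        out = self.conv(emb).permute(0, 2, 1)     # (N, U, embedding_dim)
        return self.linear(out)
```
With kernel size 2 the decoder conditions on at most the previous label, so it behaves like a bigram predictor rather than an RNN.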
The WERs are
| | test-clean | test-other | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.85 | 6.98 | --epoch 28, --avg 15, --max-duration 100 |
The command for reproducing this training is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--max-duration 300 \
--prune-range 5 \
--lr-factor 5 \
--lm-scale 0.25 \
```
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/ejG7VpakRYePNNj6AbDEUw/#scalars>
The decoding command is:
```
epoch=28
avg=15
## greedy search
./pruned_transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless/exp \
--max-duration 100
```
### LibriSpeech BPE training results (Transducer)
#### Conformer encoder + embedding decoder

View File

@@ -19,16 +19,16 @@
Usage:
(1) greedy search
./pruned_transducer_stateless/decode.py \
- --epoch 14 \
- --avg 7 \
+ --epoch 28 \
+ --avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method beam_search \
@@ -70,14 +70,14 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
- default=29,
+ default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
- default=13,
+ default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",

View File

@@ -68,7 +68,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
- default=20,
+ default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
@@ -76,7 +76,7 @@ def get_parser():
parser.add_argument(
"--avg",
type=int,
- default=10,
+ default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",

View File

@@ -16,6 +16,7 @@
import torch
import torch.nn as nn
+ import torch.nn.functional as F
class Joiner(nn.Module):
@@ -42,10 +43,8 @@ class Joiner(nn.Module):
logit = encoder_out + decoder_out
- logit = self.inner_linear(logit)
+ logit = self.inner_linear(torch.tanh(logit))
- logit = torch.tanh(logit)
- output = self.output_linear(logit)
+ output = self.output_linear(F.relu(logit))
return output
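Read together, the joiner's forward pass after this change is: add, tanh, inner linear, ReLU, output linear. No parameters change; only the activation placement differs. A minimal self-contained sketch (the constructor and dimension names are my assumptions; the forward order comes from the diff):
```
import torch
import torch.nn as nn
import torch.nn.functional as F


class Joiner(nn.Module):
    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
        super().__init__()
        self.inner_linear = nn.Linear(input_dim, inner_dim)
        self.output_linear = nn.Linear(inner_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor,
                decoder_out: torch.Tensor) -> torch.Tensor:
        # Both inputs are (B, T, s_range, C) after pruning.
        logit = encoder_out + decoder_out
        # tanh now comes before the inner projection ...
        logit = self.inner_linear(torch.tanh(logit))
        # ... and a ReLU is applied before the output projection.
        # Output is unnormalized; log-softmax happens inside the loss.
        return self.output_linear(F.relu(logit))
```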

View File

@@ -33,9 +33,6 @@ class Transducer(nn.Module):
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
- prune_range: int = 3,
- am_scale: float = 0.0,
- lm_scale: float = 0.0,
):
"""
Args:
@@ -52,21 +49,6 @@ class Transducer(nn.Module):
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
- prune_range:
- The prune range for rnnt loss, it means how many symbols(context)
- we are considering for each frame to compute the loss.
- am_scale:
- The scale to smooth the loss with am (output of encoder network)
- part
- lm_scale:
- The scale to smooth the loss with lm (output of predictor network)
- part
- Note:
- Regarding am_scale & lm_scale, it will make the loss-function one of
- the form:
- lm_scale * lm_probs + am_scale * am_probs +
- (1-lm_scale-am_scale) * combined_probs
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -75,15 +57,15 @@ class Transducer(nn.Module):
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
- self.prune_range = prune_range
- self.lm_scale = lm_scale
- self.am_scale = am_scale
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
+ prune_range: int = 5,
+ am_scale: float = 0.0,
+ lm_scale: float = 0.0,
) -> torch.Tensor:
"""
Args:
@@ -95,8 +77,23 @@ class Transducer(nn.Module):
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
+ prune_range:
+ The prune range for the RNN-T loss, i.e., how many symbols (context)
+ are considered for each frame when computing the loss.
+ am_scale:
+ The scale used to smooth the loss with the am (encoder output) part.
+ lm_scale:
+ The scale used to smooth the loss with the lm (predictor output) part.
Returns:
Return the transducer loss.
+ Note:
+ Regarding am_scale & lm_scale, they make the loss function take the form:
+ lm_scale * lm_probs + am_scale * am_probs + (1 - lm_scale - am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
@@ -114,11 +111,14 @@ class Transducer(nn.Module):
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, C]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
@@ -133,31 +133,37 @@ class Transducer(nn.Module):
am=encoder_out,
symbols=y_padded,
termination_symbol=blank_id,
- lm_only_scale=self.lm_scale,
- am_only_scale=self.am_scale,
+ lm_only_scale=lm_scale,
+ am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
+ # ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
- s_range=self.prune_range,
+ s_range=prune_range,
)
+ # am_pruned : [B, T, prune_range, C]
+ # lm_pruned : [B, T, prune_range, C]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=encoder_out, lm=decoder_out, ranges=ranges
)
+ # logits : [B, T, prune_range, C]
logits = self.joiner(am_pruned, lm_pruned)
pruned_loss = k2.rnnt_loss_pruned(
- joint=logits,
+ logits=logits,
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
- return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
+ return (simple_loss, pruned_loss)
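Assembled, the new forward pass is a two-pass scheme: a cheap "simple" loss over the whole lattice yields gradients that select a window of prune_range symbols per frame, and the full joiner then runs only inside that window. The sketch below condenses the hunks above into one function; it is not a drop-in replacement for the model file, and the call that produces simple_loss is cut off in the excerpt (its argument list matches k2.rnnt_loss_smoothed, which is what I assume here).
```
import k2
import torch


def pruned_rnnt_losses(
    encoder_out: torch.Tensor,   # (B, T, C) acoustic states
    decoder_out: torch.Tensor,   # (B, S + 1, C) label states
    y_padded: torch.Tensor,      # (B, S) int64 labels
    boundary: torch.Tensor,      # (B, 4) rows [0, 0, S, T], per k2's convention
    joiner: torch.nn.Module,
    blank_id: int,
    prune_range: int = 5,
    am_scale: float = 0.0,
    lm_scale: float = 0.25,
):
    # Pass 1: cheap smoothed loss; its gradients locate the promising
    # (frame, symbol) region.
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=decoder_out,
        am=encoder_out,
        symbols=y_padded,
        termination_symbol=blank_id,
        lm_only_scale=lm_scale,
        am_only_scale=am_scale,
        boundary=boundary,
        reduction="sum",
        return_grad=True,
    )
    # Keep only prune_range symbols per frame.
    ranges = k2.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range
    )
    am_pruned, lm_pruned = k2.do_rnnt_pruning(
        am=encoder_out, lm=decoder_out, ranges=ranges
    )
    # Pass 2: the full joiner runs only on the pruned (B, T, prune_range, C) region.
    logits = joiner(am_pruned, lm_pruned)
    pruned_loss = k2.rnnt_loss_pruned(
        logits=logits,
        symbols=y_padded,
        ranges=ranges,
        termination_symbol=blank_id,
        boundary=boundary,
        reduction="sum",
    )
    return simple_loss, pruned_loss
```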

View File

@@ -148,7 +148,7 @@ def get_parser():
parser.add_argument(
"--prune-range",
type=int,
- default=3,
+ default=5,
help="The prune range for the RNN-T loss, i.e., how many symbols (context) "
"are used to compute the loss",
)
@@ -156,7 +156,7 @@ def get_parser():
parser.add_argument(
"--lm-scale",
type=float,
- default=0.5,
+ default=0.25,
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
@@ -169,6 +169,16 @@
"part.",
)
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we calculate a simple version of the "
+ "loss (the joiner is just an addition); this simple loss is also used "
+ "for training (as a regularization term). We scale the simple loss "
+ "by this parameter before adding it to the final loss.",
+ )
return parser
@@ -289,9 +299,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
- prune_range=params.prune_range,
- lm_scale=params.lm_scale,
- am_scale=params.am_scale,
)
return model
@@ -420,8 +427,15 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
- simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
- loss = simple_loss + pruned_loss
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+ loss = params.simple_loss_scale * simple_loss + pruned_loss
assert loss.requires_grad == is_training
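The net effect on the training loop: the pruning and smoothing knobs now travel with each forward call instead of being baked into the model, and the simple loss enters the total as a down-weighted regularizer. A minimal usage sketch with the post-commit defaults (variable names follow compute_loss above; the am_scale default is taken from the forward signature, since its flag is cut off in this excerpt):
```
# Sketch of the new calling convention; not the actual training script.
simple_loss, pruned_loss = model(
    x=feature,
    x_lens=feature_lens,
    y=y,
    prune_range=5,   # --prune-range, new default
    am_scale=0.0,    # default in Transducer.forward (flag not shown here)
    lm_scale=0.25,   # --lm-scale, new default
)
# With --simple-loss-scale 0.5, the simple loss acts as a half-weight
# regularization term on top of the pruned loss.
loss = 0.5 * simple_loss + pruned_loss
```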