mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Minor fixes
This commit is contained in:
parent
3b6d416c4f
commit
a432e356a5
@ -1,5 +1,53 @@
|
||||
## Results
|
||||
|
||||
### LibriSpeech BPE training results (Pruned Transducer)
|
||||
|
||||
#### Conformer encoder + embedding decoder
|
||||
|
||||
Conformer encoder + non-current decoder. The decoder
|
||||
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
|
||||
layer (to transform tensor dim).
|
||||
|
||||
The WERs are
|
||||
|
||||
| | test-clean | test-other | comment |
|
||||
|---------------------------|------------|------------|------------------------------------------|
|
||||
| greedy search | 2.85 | 6.98 | --epoch 28, --avg 15, --max-duration 100 |
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300 \
|
||||
--prune-range 5 \
|
||||
--lr-factor 5 \
|
||||
--lm-scale 0.25 \
|
||||
```
|
||||
|
||||
The tensorboard training log can be found at
|
||||
<https://tensorboard.dev/experiment/ejG7VpakRYePNNj6AbDEUw/#scalars>
|
||||
|
||||
The decoding command is:
|
||||
```
|
||||
epoch=28
|
||||
avg=15
|
||||
|
||||
## greedy search
|
||||
./pruned_transducer_stateless/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--max-duration 100
|
||||
```
|
||||
|
||||
|
||||
### LibriSpeech BPE training results (Transducer)
|
||||
|
||||
#### Conformer encoder + embedding decoder
|
||||
|
||||
@ -19,16 +19,16 @@
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless/decode.py \
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless/decode.py \
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
@ -70,14 +70,14 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=29,
|
||||
default=28,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=13,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
|
||||
@ -68,7 +68,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=20,
|
||||
default=28,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
@ -76,7 +76,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=10,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
@ -42,10 +43,8 @@ class Joiner(nn.Module):
|
||||
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.inner_linear(logit)
|
||||
logit = self.inner_linear(torch.tanh(logit))
|
||||
|
||||
logit = torch.tanh(logit)
|
||||
|
||||
output = self.output_linear(logit)
|
||||
output = self.output_linear(F.relu(logit))
|
||||
|
||||
return output
|
||||
|
||||
@ -33,9 +33,6 @@ class Transducer(nn.Module):
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
prune_range: int = 3,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -52,21 +49,6 @@ class Transducer(nn.Module):
|
||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||
output shape is (N, T, U, C). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
@ -75,15 +57,15 @@ class Transducer(nn.Module):
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
self.prune_range = prune_range
|
||||
self.lm_scale = lm_scale
|
||||
self.am_scale = am_scale
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -95,8 +77,23 @@ class Transducer(nn.Module):
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
@ -114,11 +111,14 @@ class Transducer(nn.Module):
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, C]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
@ -133,31 +133,37 @@ class Transducer(nn.Module):
|
||||
am=encoder_out,
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=self.lm_scale,
|
||||
am_only_scale=self.am_scale,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=self.prune_range,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, C]
|
||||
# lm_pruned : [B, T, prune_range, C]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=encoder_out, lm=decoder_out, ranges=ranges
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, C]
|
||||
logits = self.joiner(am_pruned, lm_pruned)
|
||||
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
joint=logits,
|
||||
logits=logits,
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
|
||||
return (simple_loss, pruned_loss)
|
||||
|
||||
@ -148,7 +148,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--prune-range",
|
||||
type=int,
|
||||
default=3,
|
||||
default=5,
|
||||
help="The prune range for rnnt loss, it means how many symbols(context)"
|
||||
"we are using to compute the loss",
|
||||
)
|
||||
@ -156,7 +156,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--lm-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
default=0.25,
|
||||
help="The scale to smooth the loss with lm "
|
||||
"(output of prediction network) part.",
|
||||
)
|
||||
@ -169,6 +169,16 @@ def get_parser():
|
||||
"part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simple-loss-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="To get pruning ranges, we will calculate a simple version"
|
||||
"loss(joiner is just addition), this simple loss also uses for"
|
||||
"training (as a regularization item). We will scale the simple loss"
|
||||
"with this parameter before adding to the final loss.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -289,9 +299,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
prune_range=params.prune_range,
|
||||
lm_scale=params.lm_scale,
|
||||
am_scale=params.am_scale,
|
||||
)
|
||||
return model
|
||||
|
||||
@ -420,8 +427,15 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
|
||||
loss = simple_loss + pruned_loss
|
||||
simple_loss, pruned_loss = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
loss = params.simple_loss_scale * simple_loss + pruned_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user