mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Minor fixes
This commit is contained in:
parent
3b6d416c4f
commit
a432e356a5
@ -1,5 +1,53 @@
|
|||||||
## Results
|
## Results
|
||||||
|
|
||||||
|
### LibriSpeech BPE training results (Pruned Transducer)
|
||||||
|
|
||||||
|
#### Conformer encoder + embedding decoder
|
||||||
|
|
||||||
|
Conformer encoder + non-recurrent decoder. The decoder
|
||||||
|
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
|
||||||
|
layer (to transform tensor dim).
|
||||||
|
|
||||||
|
The WERs are
|
||||||
|
|
||||||
|
| | test-clean | test-other | comment |
|
||||||
|
|---------------------------|------------|------------|------------------------------------------|
|
||||||
|
| greedy search | 2.85 | 6.98 | --epoch 28, --avg 15, --max-duration 100 |
|
||||||
|
|
||||||
|
The training command for reproducing is given below:
|
||||||
|
|
||||||
|
```
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--start-epoch 0 \
|
||||||
|
--exp-dir pruned_transducer_stateless/exp \
|
||||||
|
--full-libri 1 \
|
||||||
|
--max-duration 300 \
|
||||||
|
--prune-range 5 \
|
||||||
|
--lr-factor 5 \
|
||||||
|
--lm-scale 0.25
|
||||||
|
```
|
||||||
|
|
||||||
|
The tensorboard training log can be found at
|
||||||
|
<https://tensorboard.dev/experiment/ejG7VpakRYePNNj6AbDEUw/#scalars>
|
||||||
|
|
||||||
|
The decoding command is:
|
||||||
|
```
|
||||||
|
epoch=28
|
||||||
|
avg=15
|
||||||
|
|
||||||
|
## greedy search
|
||||||
|
./pruned_transducer_stateless/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--exp-dir pruned_transducer_stateless/exp \
|
||||||
|
--max-duration 100
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### LibriSpeech BPE training results (Transducer)
|
### LibriSpeech BPE training results (Transducer)
|
||||||
|
|
||||||
#### Conformer encoder + embedding decoder
|
#### Conformer encoder + embedding decoder
|
||||||
|
|||||||
@ -19,16 +19,16 @@
|
|||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless/decode.py \
|
./pruned_transducer_stateless/decode.py \
|
||||||
--epoch 14 \
|
--epoch 28 \
|
||||||
--avg 7 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless/decode.py \
|
./pruned_transducer_stateless/decode.py \
|
||||||
--epoch 14 \
|
--epoch 28 \
|
||||||
--avg 7 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
@ -70,14 +70,14 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=29,
|
default=28,
|
||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=13,
|
default=15,
|
||||||
help="Number of checkpoints to average. Automatically select "
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
|
|||||||
@ -68,7 +68,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=20,
|
default=28,
|
||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
@ -76,7 +76,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=10,
|
default=15,
|
||||||
help="Number of checkpoints to average. Automatically select "
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
|
|||||||
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class Joiner(nn.Module):
|
class Joiner(nn.Module):
|
||||||
@ -42,10 +43,8 @@ class Joiner(nn.Module):
|
|||||||
|
|
||||||
logit = encoder_out + decoder_out
|
logit = encoder_out + decoder_out
|
||||||
|
|
||||||
logit = self.inner_linear(logit)
|
logit = self.inner_linear(torch.tanh(logit))
|
||||||
|
|
||||||
logit = torch.tanh(logit)
|
output = self.output_linear(F.relu(logit))
|
||||||
|
|
||||||
output = self.output_linear(logit)
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@ -33,9 +33,6 @@ class Transducer(nn.Module):
|
|||||||
encoder: EncoderInterface,
|
encoder: EncoderInterface,
|
||||||
decoder: nn.Module,
|
decoder: nn.Module,
|
||||||
joiner: nn.Module,
|
joiner: nn.Module,
|
||||||
prune_range: int = 3,
|
|
||||||
am_scale: float = 0.0,
|
|
||||||
lm_scale: float = 0.0,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -52,21 +49,6 @@ class Transducer(nn.Module):
|
|||||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||||
output shape is (N, T, U, C). Note that its output contains
|
output shape is (N, T, U, C). Note that its output contains
|
||||||
unnormalized probs, i.e., not processed by log-softmax.
|
unnormalized probs, i.e., not processed by log-softmax.
|
||||||
prune_range:
|
|
||||||
The prune range for rnnt loss, it means how many symbols(context)
|
|
||||||
we are considering for each frame to compute the loss.
|
|
||||||
am_scale:
|
|
||||||
The scale to smooth the loss with am (output of encoder network)
|
|
||||||
part
|
|
||||||
lm_scale:
|
|
||||||
The scale to smooth the loss with lm (output of predictor network)
|
|
||||||
part
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
|
||||||
the form:
|
|
||||||
lm_scale * lm_probs + am_scale * am_probs +
|
|
||||||
(1-lm_scale-am_scale) * combined_probs
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
@ -75,15 +57,15 @@ class Transducer(nn.Module):
|
|||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.joiner = joiner
|
self.joiner = joiner
|
||||||
self.prune_range = prune_range
|
|
||||||
self.lm_scale = lm_scale
|
|
||||||
self.am_scale = am_scale
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
y: k2.RaggedTensor,
|
y: k2.RaggedTensor,
|
||||||
|
prune_range: int = 5,
|
||||||
|
am_scale: float = 0.0,
|
||||||
|
lm_scale: float = 0.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -95,8 +77,23 @@ class Transducer(nn.Module):
|
|||||||
y:
|
y:
|
||||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
utterance.
|
utterance.
|
||||||
|
prune_range:
|
||||||
|
The prune range for rnnt loss, it means how many symbols(context)
|
||||||
|
we are considering for each frame to compute the loss.
|
||||||
|
am_scale:
|
||||||
|
The scale to smooth the loss with am (output of encoder network)
|
||||||
|
part
|
||||||
|
lm_scale:
|
||||||
|
The scale to smooth the loss with lm (output of predictor network)
|
||||||
|
part
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return the transducer loss.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||||
|
the form:
|
||||||
|
lm_scale * lm_probs + am_scale * am_probs +
|
||||||
|
(1-lm_scale-am_scale) * combined_probs
|
||||||
"""
|
"""
|
||||||
assert x.ndim == 3, x.shape
|
assert x.ndim == 3, x.shape
|
||||||
assert x_lens.ndim == 1, x_lens.shape
|
assert x_lens.ndim == 1, x_lens.shape
|
||||||
@ -114,11 +111,14 @@ class Transducer(nn.Module):
|
|||||||
blank_id = self.decoder.blank_id
|
blank_id = self.decoder.blank_id
|
||||||
sos_y = add_sos(y, sos_id=blank_id)
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
|
|
||||||
|
# sos_y_padded: [B, S + 1], start with SOS.
|
||||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
|
|
||||||
|
# decoder_out: [B, S + 1, C]
|
||||||
decoder_out = self.decoder(sos_y_padded)
|
decoder_out = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
# Note: y does not start with SOS
|
# Note: y does not start with SOS
|
||||||
|
# y_padded : [B, S]
|
||||||
y_padded = y.pad(mode="constant", padding_value=0)
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
|
||||||
y_padded = y_padded.to(torch.int64)
|
y_padded = y_padded.to(torch.int64)
|
||||||
@ -133,31 +133,37 @@ class Transducer(nn.Module):
|
|||||||
am=encoder_out,
|
am=encoder_out,
|
||||||
symbols=y_padded,
|
symbols=y_padded,
|
||||||
termination_symbol=blank_id,
|
termination_symbol=blank_id,
|
||||||
lm_only_scale=self.lm_scale,
|
lm_only_scale=lm_scale,
|
||||||
am_only_scale=self.am_scale,
|
am_only_scale=am_scale,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
|
reduction="sum",
|
||||||
return_grad=True,
|
return_grad=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ranges : [B, T, prune_range]
|
||||||
ranges = k2.get_rnnt_prune_ranges(
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
px_grad=px_grad,
|
px_grad=px_grad,
|
||||||
py_grad=py_grad,
|
py_grad=py_grad,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
s_range=self.prune_range,
|
s_range=prune_range,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# am_pruned : [B, T, prune_range, C]
|
||||||
|
# lm_pruned : [B, T, prune_range, C]
|
||||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||||
am=encoder_out, lm=decoder_out, ranges=ranges
|
am=encoder_out, lm=decoder_out, ranges=ranges
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# logits : [B, T, prune_range, C]
|
||||||
logits = self.joiner(am_pruned, lm_pruned)
|
logits = self.joiner(am_pruned, lm_pruned)
|
||||||
|
|
||||||
pruned_loss = k2.rnnt_loss_pruned(
|
pruned_loss = k2.rnnt_loss_pruned(
|
||||||
joint=logits,
|
logits=logits,
|
||||||
symbols=y_padded,
|
symbols=y_padded,
|
||||||
ranges=ranges,
|
ranges=ranges,
|
||||||
termination_symbol=blank_id,
|
termination_symbol=blank_id,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
|
reduction="sum",
|
||||||
)
|
)
|
||||||
|
|
||||||
return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
|
return (simple_loss, pruned_loss)
|
||||||
|
|||||||
@ -148,7 +148,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prune-range",
|
"--prune-range",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=5,
|
||||||
help="The prune range for rnnt loss, it means how many symbols(context)"
|
help="The prune range for rnnt loss, it means how many symbols(context)"
|
||||||
"we are using to compute the loss",
|
"we are using to compute the loss",
|
||||||
)
|
)
|
||||||
@ -156,7 +156,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lm-scale",
|
"--lm-scale",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.5,
|
default=0.25,
|
||||||
help="The scale to smooth the loss with lm "
|
help="The scale to smooth the loss with lm "
|
||||||
"(output of prediction network) part.",
|
"(output of prediction network) part.",
|
||||||
)
|
)
|
||||||
@ -169,6 +169,16 @@ def get_parser():
|
|||||||
"part.",
|
"part.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--simple-loss-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="To get pruning ranges, we will calculate a simple version"
|
||||||
|
"loss(joiner is just addition), this simple loss also uses for"
|
||||||
|
"training (as a regularization item). We will scale the simple loss"
|
||||||
|
"with this parameter before adding to the final loss.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -289,9 +299,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
joiner=joiner,
|
joiner=joiner,
|
||||||
prune_range=params.prune_range,
|
|
||||||
lm_scale=params.lm_scale,
|
|
||||||
am_scale=params.am_scale,
|
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -420,8 +427,15 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
|
simple_loss, pruned_loss = model(
|
||||||
loss = simple_loss + pruned_loss
|
x=feature,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
y=y,
|
||||||
|
prune_range=params.prune_range,
|
||||||
|
am_scale=params.am_scale,
|
||||||
|
lm_scale=params.lm_scale,
|
||||||
|
)
|
||||||
|
loss = params.simple_loss_scale * simple_loss + pruned_loss
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user