Repository: https://github.com/k2-fsa/icefall.git

Commit ca61f189be (parent 3e2dbc9ab5)
Add teacher student loss.
@@ -15,13 +15,39 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos
+from icefall.utils import add_sos, make_pad_mask
 
 
+def compute_teacher_student_loss(
+    encoder_out: torch.Tensor,
+    teacher_encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Args:
+      encoder_out:
+        Encoder output of the student. Its shape is (N, T, C)
+      teacher_encoder_out:
+        Encoder output of the teacher. Its shape is also (N, T, C)
+      encoder_out_lens:
+        A 1-D tensor containing the number of valid frames in encoder_out before
+        padding.
+    Returns:
+      Return the l1 loss between encoder_out and teacher_encoder_out.
+    """
+    loss = (encoder_out - teacher_encoder_out).abs().sum(dim=-1)
+    mask = make_pad_mask(encoder_out_lens)
+    loss.masked_fill_(mask, 0)
+
+    return loss.sum() / encoder_out.size(-1)
+
+
 class Transducer(nn.Module):
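A minimal standalone sketch of what compute_teacher_student_loss computes (illustration only, not part of the commit). The helper _local_pad_mask and the tensor names below are hypothetical; the helper only mimics the assumed icefall.utils.make_pad_mask convention of True at padded frames:

    import torch


    def _local_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
        # True at padded frames, False at valid frames (assumed make_pad_mask behaviour).
        return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)


    N, T, C = 2, 5, 4
    student = torch.randn(N, T, C)   # stands in for encoder_out
    teacher = torch.randn(N, T, C)   # stands in for teacher_encoder_out
    lens = torch.tensor([5, 3])      # second utterance has 2 padded frames

    per_frame = (student - teacher).abs().sum(dim=-1)              # (N, T)
    per_frame = per_frame.masked_fill(_local_pad_mask(lens, T), 0)
    ts_loss = per_frame.sum() / student.size(-1)                   # scalar, as in the diff
    print(ts_loss)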
@@ -51,9 +77,10 @@ class Transducer(nn.Module):
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
          joiner:
-            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
-            Its output shape is (N, T, U, vocab_size). Note that its output contains
-            unnormalized probs, i.e., not processed by log-softmax.
+            It has two inputs with shapes: (N, T, encoder_dim) and
+            (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output
+            contains unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -76,6 +103,7 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        teacher_model: Optional[torch.jit.ScriptModule] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -96,6 +124,8 @@ class Transducer(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          teacher_model:
+            The teacher model.
         Returns:
           Return the transducer loss.
 
@@ -111,8 +141,20 @@ class Transducer(nn.Module):
 
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens)
-        assert torch.all(x_lens > 0)
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens)
+        assert torch.all(encoder_out_lens > 0)
 
+        if self.training is True:
+            with torch.no_grad():
+                teacher_encoder_out, _ = teacher_model.encoder(x, x_lens)
+
+            ts_loss = compute_teacher_student_loss(
+                encoder_out,
+                teacher_encoder_out,
+                encoder_out_lens,
+            )
+        else:
+            ts_loss = torch.tensor([0.0])
+
         # Now for the decoder, i.e., the prediction network
         row_splits = y.shape.row_splits(1)
@@ -136,7 +178,7 @@ class Transducer(nn.Module):
             (x.size(0), 4), dtype=torch.int64, device=x.device
         )
         boundary[:, 2] = y_lens
-        boundary[:, 3] = x_lens
+        boundary[:, 3] = encoder_out_lens
 
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
@@ -186,4 +228,4 @@ class Transducer(nn.Module):
                 reduction="sum",
             )
 
-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, ts_loss)
egs/librispeech/ASR/transducer_lstm/teacher_model.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def get_teacher_model() -> torch.jit.ScriptModule:
+    filename = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/pruned_transducer_stateless3/exp/cpu_jit.pt"
+    model = torch.jit.load(filename)
+
+    return model
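The checkpoint path above is hard-coded to the author's machine. As a hedged sketch only (not part of the commit), the same kind of TorchScript export could be loaded onto a chosen device and put into eval mode explicitly; load_jit_teacher is a hypothetical name:

    import torch


    def load_jit_teacher(filename: str, device: str = "cpu") -> torch.jit.ScriptModule:
        # torch.jit.load accepts map_location, so the exported model can be
        # placed directly on the training device.
        model = torch.jit.load(filename, map_location=device)
        model.eval()  # the teacher is only used for inference
        return model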
egs/librispeech/ASR/transducer_lstm/test_teacher_model.py (new executable file, 59 lines)
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./transducer_lstm/test_teacher_model.py
+"""
+
+import warnings
+
+import torch
+from teacher_model import get_teacher_model
+
+
+def test_teacher_model():
+    model = get_teacher_model()
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of encoder model parameters: {num_param}")
+
+    N = 3
+    T = 500
+    C = 80
+
+    x = torch.rand(N, T, C)
+    x_lens = torch.tensor([100, 500, 300])
+
+    y, y_lens = model.encoder(x, x_lens)
+    print(y.shape)
+    expected_y_lens = (((x_lens - 1) >> 1) - 1) >> 1
+
+    assert torch.all(torch.eq(y_lens, expected_y_lens)), (
+        y_lens,
+        expected_y_lens,
+    )
+
+
+def main():
+    test_teacher_model()
+
+
+if __name__ == "__main__":
+    main()
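The expected_y_lens formula applies two successive halvings to (x_lens - 1), presumably mirroring the teacher encoder's convolutional subsampling. For the lengths used in the test it evaluates to:

    import torch

    x_lens = torch.tensor([100, 500, 300])
    expected_y_lens = (((x_lens - 1) >> 1) - 1) >> 1
    print(expected_y_lens)  # tensor([ 24, 124,  74])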
@@ -65,6 +65,7 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, Eve
+from teacher_model import get_teacher_model
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -229,6 +230,16 @@ def get_parser():
         "with this parameter before adding to the final loss.",
     )
 
+    parser.add_argument(
+        "--ts-loss-scale",
+        type=float,
+        default=0.1,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
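Note that the help string for --ts-loss-scale is carried over verbatim from --simple-loss-scale; what the new flag actually controls is the weight of the teacher-student loss in the total loss, as the compute_loss change later in this diff shows. A small sketch of that combination, using made-up loss values and an assumed --simple-loss-scale of 0.5:

    import torch

    simple_loss = torch.tensor(120.0)  # placeholder value
    pruned_loss = torch.tensor(80.0)   # placeholder value
    ts_loss = torch.tensor(300.0)      # placeholder value

    simple_loss_scale = 0.5  # assumed --simple-loss-scale
    ts_loss_scale = 0.1      # --ts-loss-scale default in this commit

    loss = simple_loss_scale * simple_loss + pruned_loss
    loss = loss + ts_loss_scale * ts_loss
    print(loss)  # tensor(170.)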
@@ -548,6 +559,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    teacher_model: Optional[torch.jit.ScriptModule] = None,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -564,6 +576,8 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
+      teacher_model:
+        The teacher model.
     """
     device = params.device
     feature = batch["inputs"]
@@ -579,16 +593,18 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, ts_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            teacher_model=teacher_model,
         )
 
         loss = params.simple_loss_scale * simple_loss + pruned_loss
+        loss = loss + params.ts_loss_scale * ts_loss
 
     assert loss.requires_grad == is_training
 
@@ -603,6 +619,7 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    info["ts_loss"] = ts_loss.detach().cpu().item()
 
     return loss, info
 
@@ -623,6 +640,7 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
+            teacher_model=None,
             sp=sp,
             batch=batch,
             is_training=False,
@@ -644,6 +662,7 @@ def compute_validation_loss(
 def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
+    teacher_model: Optional[torch.jit.ScriptModule],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
@@ -666,6 +685,8 @@ def train_one_epoch(
         It is returned by :func:`get_params`.
       model:
         The model for training.
+      teacher_model:
+        The teacher model.
       optimizer:
         The optimizer we are using.
       scheduler:
@@ -729,6 +750,7 @@ def train_one_epoch(
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
+                teacher_model=teacher_model,
                 sp=sp,
                 batch=batch,
                 is_training=True,
@@ -901,10 +923,14 @@ def run(rank, world_size, args):
 
     logging.info("About to create model")
     model = get_transducer_model(params)
+    teacher_model = get_teacher_model()
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    num_teacher_param = sum([p.numel() for p in teacher_model.parameters()])
+    logging.info(f"Number of teacher model parameters: {num_teacher_param}")
+
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
     if rank == 0:
@@ -919,6 +945,7 @@ def run(rank, world_size, args):
     )
 
     model.to(device)
+    teacher_model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
@@ -983,6 +1010,7 @@ def run(rank, world_size, args):
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
+            teacher_model=teacher_model,
             train_dl=train_dl,
             optimizer=optimizer,
             sp=sp,
@@ -1007,6 +1035,7 @@ def run(rank, world_size, args):
         train_one_epoch(
             params=params,
             model=model,
+            teacher_model=teacher_model,
             model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
@@ -1075,6 +1104,7 @@ def display_and_save_batch(
 
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
+    teacher_model: torch.jit.ScriptModule,
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
@@ -1093,6 +1123,7 @@ def scan_pessimistic_batches_for_oom(
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
+                    teacher_model=teacher_model,
                     sp=sp,
                     batch=batch,
                     is_training=True,
@@ -1130,5 +1161,9 @@ def main():
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+torch._C._set_graph_executor_optimize(False)
+
 if __name__ == "__main__":
     main()
Loading…
x
Reference in New Issue
Block a user