address comments from @csukuangfj
This commit is contained in: parent 75e5c2775d, commit d1c0ab606d
@@ -41,7 +41,7 @@ The model is a combination of a speech separation model and a speech recognition
 but trained end-to-end with a single loss function. The overall architecture is shown
 in the figure below. Note that this architecture is slightly different from the one
 in the above papers. A detailed description of the model can be found in the following
-paper: [SURT 2.0: Advances in transducer-based multi-talker ASR]().
+paper: [SURT 2.0: Advances in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559).

 <p align="center">

@@ -50,7 +50,7 @@ paper: [SURT 2.0: Advances in transducer-based multi-talker ASR]().

 </p>

-In the `dprnn_zipformer` recipe, for example, we use a DPRNN-based masking network
+In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
 and a Zipformer-based recognition network. But other combinations are possible as well.

 ## Training objective
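The code hunk below (from the recipe's `decode_one_batch` function) applies the predicted masks and runs the recognition encoder on the masked copies. As a rough orientation for reading that diff, here is a minimal sketch of the same flow with dummy tensors; the shapes, the sigmoid stand-in for the masking-network output, and the concrete values are assumptions for illustration only, not part of the recipe.

```python
import torch

# Hypothetical sizes: B utterances, T frames, F mel bins,
# and num_channels output branches (2 in the SURT recipes).
B, T, F, num_channels = 2, 100, 80, 2

feature = torch.randn(B, T, F)          # input fbank features
feature_lens = torch.tensor([100, 73])  # valid frames per utterance

# Stand-in for the masking-network output: one mask per channel, in [0, 1].
processed = torch.sigmoid(torch.randn(B, T, F * num_channels))

# Same steps as in decode_one_batch: split the masks, apply them, and
# concatenate the masked copies along the batch axis for the recognition encoder.
masks = processed.view(B, T, F, num_channels).unbind(dim=-1)
x_masked = [feature * m for m in masks]

h = torch.cat(x_masked, dim=0)              # (B * num_channels, T, F)
h_lens = feature_lens.repeat(num_channels)  # (B * num_channels,)

print(h.shape, h_lens)  # torch.Size([4, 100, 80]) tensor([100, 73, 100, 73])
```

Because the masked copies are concatenated along the batch axis, the shared encoder processes `B * num_channels` sequences per batch in this setup.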
@@ -233,22 +233,23 @@ def decode_one_batch(
     masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
     x_masked = [feature * m for m in masks]

-    # To save the masks, we split them by batch and trim each mask to the length of
-    # the corresponding feature. We save them in a dict, where the key is the
-    # cut ID and the value is the mask.
     masks_dict = {}
-    for i in range(B):
-        mask = torch.cat(
-            [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
-            dim=-1,
-        )
-        mask = mask.cpu().numpy()
-        masks_dict[batch["cuts"][i].id] = mask
+    if params.save_masks:
+        # To save the masks, we split them by batch and trim each mask to the length of
+        # the corresponding feature. We save them in a dict, where the key is the
+        # cut ID and the value is the mask.
+        for i in range(B):
+            mask = torch.cat(
+                [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
+                dim=-1,
+            )
+            mask = mask.cpu().numpy()
+            masks_dict[batch["cuts"][i].id] = mask

     # Recognition
-    # Stack the inputs along the batch axis
+    # Concatenate the inputs along the batch axis
     h = torch.cat(x_masked, dim=0)
-    h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
+    h_lens = feature_lens.repeat(params.num_channels)
     encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)

     if model.joint_encoder_layer is not None:
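Two notes on this hunk. First, mask saving is now guarded by `params.save_masks`, so the per-utterance loop is skipped when masks are not requested. Second, replacing `torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)` with `feature_lens.repeat(params.num_channels)` is purely a simplification: for a 1-D tensor the two produce the same result. The snippet below is a quick sanity check of that equivalence, plus the shape produced by the mask-trimming step; the tensor sizes are made up for illustration.

```python
import torch

num_channels = 2
feature_lens = torch.tensor([100, 73])

# For a 1-D tensor, repeat() tiles the values, which matches concatenation.
a = torch.cat([feature_lens for _ in range(num_channels)], dim=0)
b = feature_lens.repeat(num_channels)
assert torch.equal(a, b)  # tensor([100, 73, 100, 73])

# The entry saved in masks_dict for cut i: each channel's masked features are
# trimmed to feature_lens[i] frames and concatenated along the feature axis,
# giving an array of shape (feature_lens[i], F * num_channels).
B, T, F = 2, 100, 80
x_masked = [torch.randn(B, T, F) for _ in range(num_channels)]
i = 1
mask = torch.cat(
    [x_masked[j][i, : feature_lens[i]] for j in range(num_channels)], dim=-1
)
print(mask.shape)  # torch.Size([73, 160])
```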