address comments from @csukuangfj
parent 75e5c2775d · commit d1c0ab606d
@@ -41,7 +41,7 @@ The model is a combination of a speech separation model and a speech recognition
 but trained end-to-end with a single loss function. The overall architecture is shown
 in the figure below. Note that this architecture is slightly different from the one
 in the above papers. A detailed description of the model can be found in the following
-paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR]().
+paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559).
 
 <p align="center">
 
@@ -50,7 +50,7 @@ paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR]().
 
 </p>
 
-In the `dprnn_zipformer` recipe, for example, we use a DPRNN-based masking network
+In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
 and a Zipfomer-based recognition network. But other combinations are possible as well.
 
 ## Training objective
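To make the architecture description above concrete, here is a minimal, self-contained sketch of how a masking network and a shared recognition encoder can be composed as the README describes; it mirrors the decoding code in the diff below. The module names and shapes here (`mask_net`, `encoder`, `feat_dim`) are illustrative placeholders, not the actual DPRNN or Zipformer classes used in the dprnn_zipformer recipe.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the two components. In the recipe the masking network is a
# DPRNN and the recognition network is a Zipformer transducer; these placeholders
# just let the sketch run on its own.
num_channels = 2
feat_dim = 80

mask_net = nn.Sequential(  # predicts num_channels masks per time-frequency bin
    nn.Linear(feat_dim, feat_dim * num_channels),
    nn.Sigmoid(),
)
encoder = nn.Linear(feat_dim, 256)  # placeholder for the recognition encoder


def forward(feature: torch.Tensor, feature_lens: torch.Tensor):
    """feature: (B, T, F) mixture features; feature_lens: (B,) valid lengths."""
    B, T, F = feature.shape
    # Masking network: one mask per output channel.
    masks = mask_net(feature).view(B, T, F, num_channels).unbind(dim=-1)
    x_masked = [feature * m for m in masks]
    # The masked features of all channels are stacked along the batch axis and
    # passed through a single, shared recognition encoder.
    h = torch.cat(x_masked, dim=0)              # (num_channels * B, T, F)
    h_lens = feature_lens.repeat(num_channels)  # lengths repeat in the same order
    return encoder(h), h_lens


out, out_lens = forward(torch.randn(4, 100, feat_dim), torch.full((4,), 100))
print(out.shape, out_lens.shape)  # torch.Size([8, 100, 256]) torch.Size([8])
```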
@@ -233,10 +233,11 @@ def decode_one_batch(
     masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
     x_masked = [feature * m for m in masks]
+
+    masks_dict = {}
     if params.save_masks:
         # To save the masks, we split them by batch and trim each mask to the length of
         # the corresponding feature. We save them in a dict, where the key is the
         # cut ID and the value is the mask.
-        masks_dict = {}
         for i in range(B):
             mask = torch.cat(
                 [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
@@ -246,9 +247,9 @@ def decode_one_batch(
             masks_dict[batch["cuts"][i].id] = mask
 
     # Recognition
-    # Stack the inputs along the batch axis
+    # Concatenate the inputs along the batch axis
     h = torch.cat(x_masked, dim=0)
-    h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
+    h_lens = feature_lens.repeat(params.num_channels)
     encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
 
     if model.joint_encoder_layer is not None:
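The `h_lens` change in the last hunk is behavior-preserving: for a 1-D tensor, `Tensor.repeat(n)` tiles the whole tensor `n` times along dim 0, which is exactly what concatenating `n` copies does, so the lengths stay aligned with the channel-major order produced by `torch.cat(x_masked, dim=0)`. A standalone check (not part of the recipe):

```python
import torch

feature_lens = torch.tensor([100, 80, 120])
num_channels = 2

# Old formulation: concatenate num_channels copies along dim 0.
old = torch.cat([feature_lens for _ in range(num_channels)], dim=0)
# New formulation: tile the 1-D tensor num_channels times.
new = feature_lens.repeat(num_channels)

assert torch.equal(old, new)  # tensor([100,  80, 120, 100,  80, 120])
```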