Replace [] with () for shapes.

Fangjun Kuang 2021-09-18 17:04:57 +08:00
parent 306c9e1398
commit a9a8448d0f
11 changed files with 62 additions and 58 deletions

View File

@@ -98,7 +98,7 @@ class Conformer(Transformer):
         """
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
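Note: the `supervisions` argument in this docstring is the batched dict produced by lhotse's speech-recognition dataset (see the link above). A minimal sketch of its layout, with made-up values for illustration:

    import torch

    # Toy "supervisions" dict in lhotse format: one row per supervision in the batch.
    supervisions = {
        "sequence_idx": torch.tensor([0, 1]),  # index of each cut within the batch
        "start_frame": torch.tensor([0, 0]),   # first valid frame of each supervision
        "num_frames": torch.tensor([95, 80]),  # number of valid frames per cut
        "text": ["HELLO WORLD", "FOO BAR"],    # transcripts
    }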

View File

@@ -213,12 +213,12 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)

     supervisions = batch["supervisions"]

     nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is [N, T, C]
+    # nnet_output is (N, T, C)

     supervision_segments = torch.stack(
         (

View File

@ -22,8 +22,8 @@ import torch.nn as nn
class Conv2dSubsampling(nn.Module): class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length). """Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape [N, T, idim] to an output Convert an input of shape (N, T, idim) to an output
with shape [N, T', odim], where with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on It is based on
@@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module):
         """
         Args:
           idim:
-            Input dim. The input shape is [N, T, idim].
+            Input dim. The input shape is (N, T, idim).
             Caution: It requires: T >=7, idim >=7
           odim:
-            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
         """
         assert idim >= 7
         super().__init__()
@@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module):
         Args:
           x:
-            Its shape is [N, T, idim].
+            Its shape is (N, T, idim).

         Returns:
-          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
         """
-        # On entry, x is [N, T, idim]
-        x = x.unsqueeze(1)  # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
         x = self.conv(x)
-        # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
-        # Now x is of shape [N, ((T-1)//2 - 1)//2, odim]
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
         return x
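A quick sanity check of the T' = ((T-1)//2 - 1)//2 formula used throughout this hunk; a minimal sketch assuming the `Conv2dSubsampling` class above is in scope:

    import torch

    N, T, idim, odim = 2, 100, 80, 512
    T_prime = ((T - 1) // 2 - 1) // 2
    print(T_prime)  # 24, which approximates T // 4 = 25

    # Assuming Conv2dSubsampling as defined above:
    # model = Conv2dSubsampling(idim=idim, odim=odim)
    # y = model(torch.randn(N, T, idim))   # input is (N, T, idim)
    # assert y.shape == (N, T_prime, odim)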
@@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
     This paper is not 100% explicit so I am guessing to some extent,
     and trying to compare with other VGG implementations.

-    Convert an input of shape [N, T, idim] to an output
-    with shape [N, T', odim], where
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
     T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
     """
@@ -93,10 +93,10 @@ class VggSubsampling(nn.Module):
         Args:
           idim:
-            Input dim. The input shape is [N, T, idim].
+            Input dim. The input shape is (N, T, idim).
             Caution: It requires: T >=7, idim >=7
          odim:
-            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
         """
         super().__init__()
@@ -149,10 +149,10 @@ class VggSubsampling(nn.Module):
         Args:
           x:
-            Its shape is [N, T, idim].
+            Its shape is (N, T, idim).

         Returns:
-          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
         """
         x = x.unsqueeze(1)
         x = self.layers(x)

View File

@@ -310,14 +310,14 @@ def compute_loss(
     """
     device = graph_compiler.device
     feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)

     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
+        # nnet_output is (N, T, C)

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by

View File

@ -83,8 +83,8 @@ class Transformer(nn.Module):
if subsampling_factor != 4: if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.") raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape [N, T, num_classes] # self.encoder_embed converts the input of shape (N, T, num_classes)
# to the shape [N, T//subsampling_factor, d_model]. # to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously: # That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor # (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_classes -> d_model # (2) embedding: num_classes -> d_model
@@ -162,7 +162,7 @@ class Transformer(nn.Module):
         """
         Args:
           x:
-            The input tensor. Its shape is [N, T, C].
+            The input tensor. Its shape is (N, T, C).
           supervision:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -171,17 +171,17 @@ class Transformer(nn.Module):
         Returns:
           Return a tuple containing 3 tensors:
-            - CTC output for ctc decoding. Its shape is [N, T, C]
-            - Encoder output with shape [T, N, C]. It can be used as key and
+            - CTC output for ctc decoding. Its shape is (N, T, C)
+            - Encoder output with shape (T, N, C). It can be used as key and
               value for the decoder.
             - Encoder output padding mask. It can be used as
-              memory_key_padding_mask for the decoder. Its shape is [N, T].
+              memory_key_padding_mask for the decoder. Its shape is (N, T).
               It is None if `supervision` is None.
         """
         if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
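The permute/batchnorm/permute sequence in this hunk is needed because `nn.BatchNorm1d` normalizes over dimension 1, i.e., it expects (N, C, T) rather than (N, T, C). A standalone sketch:

    import torch
    import torch.nn as nn

    N, T, C = 2, 50, 80
    x = torch.randn(N, T, C)            # features arrive as (N, T, C)
    feat_batchnorm = nn.BatchNorm1d(C)  # BatchNorm1d wants (N, C, T)

    x = x.permute(0, 2, 1)              # (N, T, C) -> (N, C, T)
    x = feat_batchnorm(x)
    x = x.permute(0, 2, 1)              # (N, C, T) -> (N, T, C)
    assert x.shape == (N, T, C)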
@@ -195,7 +195,7 @@ class Transformer(nn.Module):
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -206,8 +206,8 @@ class Transformer(nn.Module):
             padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
-            - The encoder output, with shape [T, N, C]
-            - encoder padding mask, with shape [N, T].
+            - The encoder output, with shape (T, N, C)
+            - encoder padding mask, with shape (N, T).
               The mask is None if `supervisions` is None.
               It is used as memory key padding mask in the decoder.
         """
@@ -225,11 +225,11 @@ class Transformer(nn.Module):
         Args:
           x:
             The output tensor from the transformer encoder.
-            Its shape is [T, N, C]
+            Its shape is (T, N, C)

         Returns:
           Return a tensor that can be used for CTC decoding.
-          Its shape is [N, T, C]
+          Its shape is (N, T, C)
         """
         x = self.encoder_output_layer(x)
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
@@ -247,7 +247,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -312,7 +312,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@ -654,13 +654,13 @@ class PositionalEncoding(nn.Module):
def extend_pe(self, x: torch.Tensor) -> None: def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required. """Extend the time t in the positional encoding if required.
The shape of `self.pe` is [1, T1, d_model]. The shape of the input x The shape of `self.pe` is (1, T1, d_model). The shape of the input x
is [N, T, d_model]. If T > T1, then we change the shape of self.pe is (N, T, d_model). If T > T1, then we change the shape of self.pe
to [N, T, d_model]. Otherwise, nothing is done. to (N, T, d_model). Otherwise, nothing is done.
Args: Args:
x: x:
It is a tensor of shape [N, T, C]. It is a tensor of shape (N, T, C).
Returns: Returns:
Return None. Return None.
""" """
@@ -678,7 +678,7 @@ class PositionalEncoding(nn.Module):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0)
-        # Now pe is of shape [1, T, d_model], where T is x.size(1)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -687,10 +687,10 @@ class PositionalEncoding(nn.Module):
         Args:
           x:
-            Its shape is [N, T, C]
+            Its shape is (N, T, C)

         Returns:
-          Return a tensor of shape [N, T, C]
+          Return a tensor of shape (N, T, C)
         """
         self.extend_pe(x)
         x = x * self.xscale + self.pe[:, : x.size(1), :]
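For reference, a self-contained sketch of the sinusoidal table that `extend_pe` builds (mirroring the sin/cos interleaving shown in the previous hunk; the d_model and T values are arbitrary):

    import math
    import torch

    d_model, T = 512, 100
    position = torch.arange(T, dtype=torch.float32).unsqueeze(1)  # (T, 1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / d_model)
    )  # (d_model / 2,)
    pe = torch.zeros(T, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even columns
    pe[:, 1::2] = torch.cos(position * div_term)  # odd columns
    pe = pe.unsqueeze(0)  # (1, T, d_model), broadcastable against (N, T, d_model)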

View File

@@ -190,12 +190,12 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
-    # at entry, feature is [N, T, C]
-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
+    # at entry, feature is (N, T, C)
+    feature = feature.permute(0, 2, 1)  # now feature is (N, C, T)

     nnet_output = model(feature)
-    # nnet_output is [N, T, C]
+    # nnet_output is (N, T, C)

     supervisions = batch["supervisions"]

View File

@ -218,11 +218,11 @@ def main():
features = pad_sequence( features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10) features, batch_first=True, padding_value=math.log(1e-10)
) )
features = features.permute(0, 2, 1) # now features is [N, C, T] features = features.permute(0, 2, 1) # now features is (N, C, T)
with torch.no_grad(): with torch.no_grad():
nnet_output = model(features) nnet_output = model(features)
# nnet_output is [N, T, C] # nnet_output is (N, T, C)
batch_size = nnet_output.shape[0] batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor( supervision_segments = torch.tensor(
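`pad_sequence` here turns a list of per-utterance feature matrices into one padded batch tensor; a minimal sketch with made-up shapes:

    import math
    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Two utterances of different lengths, each of shape (T_i, C) with C = 80.
    features = [torch.randn(95, 80), torch.randn(80, 80)]
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )
    print(features.shape)  # torch.Size([2, 95, 80]), i.e., (N, T, C)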

View File

@@ -290,14 +290,14 @@ def compute_loss(
     """
     device = graph_compiler.device
     feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
+    # at entry, feature is (N, T, C)
+    feature = feature.permute(0, 2, 1)  # now feature is (N, C, T)
     assert feature.ndim == 3
     feature = feature.to(device)

     with torch.set_grad_enabled(is_training):
         nnet_output = model(feature)
-        # nnet_output is [N, T, C]
+        # nnet_output is (N, T, C)

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by

View File

@@ -111,10 +111,10 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)

     nnet_output = model(feature)
-    # nnet_output is [N, T, C]
+    # nnet_output is (N, T, C)

     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(

View File

@@ -268,13 +268,13 @@ def compute_loss(
     """
     device = graph_compiler.device
     feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)

     with torch.set_grad_enabled(is_training):
         nnet_output = model(feature)
-        # nnet_output is [N, T, C]
+        # nnet_output is (N, T, C)

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by

View File

@ -78,7 +78,7 @@ def get_lattice(
network output. network output.
Args: Args:
nnet_output: nnet_output:
It is the output of a neural model of shape `[N, T, C]`. It is the output of a neural model of shape `(N, T, C)`.
HLG: HLG:
An Fsa, the decoding graph. See also `compile_HLG.py`. An Fsa, the decoding graph. See also `compile_HLG.py`.
supervision_segments: supervision_segments:
@@ -108,10 +108,12 @@ def get_lattice(
      subsampling_factor:
        The subsampling factor of the model.

    Returns:
-      A lattice containing the decoding result.
+      An FsaVec containing the decoding result. It has axes [utt][state][arc].
    """
    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
+        nnet_output,
+        supervision_segments,
+        allow_truncate=subsampling_factor - 1,
    )

    lattice = k2.intersect_dense_pruned(
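For context, a sketch of how the two k2 calls in this hunk fit together; the beam and active-state values below are illustrative, not taken from this diff:

    # nnet_output: (N, T, C) log-probs; supervision_segments: (num_seqs, 3) int32
    # rows of [sequence_idx, start_frame, duration].
    dense_fsa_vec = k2.DenseFsaVec(
        nnet_output,
        supervision_segments,
        allow_truncate=subsampling_factor - 1,
    )
    lattice = k2.intersect_dense_pruned(
        HLG,                 # the decoding graph
        dense_fsa_vec,
        search_beam=20.0,    # illustrative values
        output_beam=8.0,
        min_active_states=30,
        max_active_states=10000,
    )
    # lattice is an FsaVec with axes [utt][state][arc], per the Returns note above.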
@@ -138,6 +140,8 @@ def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
     Args:
       symbol_ids:
         A list of symbol IDs (excluding 0 and -1)
+    Returns:
+      Return an Fsa (with 2 axes [state][arc]).
     """
     assert 0 not in symbol_ids
     assert -1 not in symbol_ids
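A hedged usage sketch for `levenshtein_graph` (the IDs are arbitrary; any values other than 0 and -1 pass the asserts above):

    # Build an FSA against which edit distance can be computed via intersection.
    graph = levenshtein_graph([2, 5, 7])
    # graph is a k2.Fsa with 2 axes [state][arc], per the Returns note above.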