mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Minor fixes.
This commit is contained in:
parent
b7676ca1f2
commit
3e2dbc9ab5
@ -452,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
def load_checkpoint_if_available(
|
def load_checkpoint_if_available(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
model_avg: nn.Module = None,
|
model_avg: Optional[nn.Module] = None,
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
scheduler: Optional[LRSchedulerType] = None,
|
scheduler: Optional[LRSchedulerType] = None,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import warnings
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -79,9 +78,8 @@ class LstmEncoder(EncoderInterface):
|
|||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
|
|
||||||
# Caution: We assume the subsampling factor is 4!
|
# Caution: We assume the subsampling factor is 4!
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
|
||||||
assert x.size(1) == lengths.max().item(), (
|
assert x.size(1) == lengths.max().item(), (
|
||||||
x.size(1),
|
x.size(1),
|
||||||
lengths.max(),
|
lengths.max(),
|
||||||
|
@ -20,11 +20,9 @@
|
|||||||
To run this file, do:
|
To run this file, do:
|
||||||
|
|
||||||
cd icefall/egs/librispeech/ASR
|
cd icefall/egs/librispeech/ASR
|
||||||
python ./transducer_lstm/test_model.py
|
python ./transducer_lstm/test_encoder.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from train import get_encoder_model, get_params
|
from train import get_encoder_model, get_params
|
||||||
|
|
||||||
@ -47,9 +45,7 @@ def test_encoder_model():
|
|||||||
|
|
||||||
y, y_lens = encoder(x, x_lens)
|
y, y_lens = encoder(x, x_lens)
|
||||||
print(y.shape)
|
print(y.shape)
|
||||||
with warnings.catch_warnings():
|
expected_y_lens = (((x_lens - 1) >> 1) - 1) >> 1
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
expected_y_lens = ((x_lens - 1) // 2 - 1) // 2
|
|
||||||
|
|
||||||
assert torch.all(torch.eq(y_lens, expected_y_lens)), (
|
assert torch.all(torch.eq(y_lens, expected_y_lens)), (
|
||||||
y_lens,
|
y_lens,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user