Fix building for cuda 12.6

This commit is contained in:
Fangjun Kuang 2025-01-31 15:44:12 +08:00
parent 3f79fbbd6d
commit c8b4ad639c

View File

@ -1,6 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
"""
See also
https://github.com/pytorch/test-infra/blob/main/.github/workflows/test_build_wheels_linux_with_cuda.yml
https://github.com/pytorch/test-infra/blob/main/.github/workflows/test_build_wheels_linux_without_cuda.yml
https://github.com/pytorch/test-infra/actions/workflows/test_build_wheels_linux_with_cuda.yml
"""
import argparse import argparse
import json import json
@ -318,6 +326,22 @@ def generate_build_matrix(enable_cuda, for_windows, for_macos, test_only_latest_
if c in ["10.1", "11.0"]: if c in ["10.1", "11.0"]:
# no docker image for cuda 10.1 and 11.0 # no docker image for cuda 10.1 and 11.0
continue continue
if version_ge(torch, "2.7.0") or (
version_ge(torch, "2.6.0") and c == "12.6"
):
# case 1: torch >= 2.7
# case 2: torch == 2.6.0 && cuda == 12.6
ans.append(
{
"torch": torch,
"python-version": p,
"cuda": c,
"image": f"pytorch/manylinux2_28-builder:cuda{c}",
}
)
continue
ans.append( ans.append(
{ {
"torch": torch, "torch": torch,