Skip to content

Commit 7d70852

Browse files
authored
fix precision auto detection
fix precision auto detection
1 parent dbf7702 commit 7d70852

1 file changed

Lines changed: 5 additions & 9 deletions

File tree

config.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -144,7 +144,7 @@ def change_choices():
144144

145145
api_port = 9880
146146

147-
147+
#Thanks to the contribution of @Karasukaigan and @XXXXRT666
148148
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
149149
cpu = torch.device("cpu")
150150
cuda = torch.device(f"cuda:{idx}")
@@ -157,14 +157,10 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
157157
mem_gb = mem_bytes / (1024**3) + 0.4
158158
major, minor = capability
159159
sm_version = major + minor / 10.0
160-
is_16_series = bool(re.search(r"16\d{2}", name))
161-
if mem_gb < 4:
162-
return cpu, torch.float32, 0.0, 0.0
163-
if (sm_version >= 7.0 and sm_version != 7.5) or (5.3 <= sm_version <= 6.0):
164-
if is_16_series and sm_version == 7.5:
165-
return cuda, torch.float32, sm_version, mem_gb # 16系卡除外
166-
else:
167-
return cuda, torch.float16, sm_version, mem_gb
160+
is_16_series = bool(re.search(r"16\d{2}", name))and sm_version == 7.5
161+
if mem_gb < 4 or sm_version < 5.3:return cpu, torch.float32, 0.0, 0.0
162+
if sm_version == 6.1 or is_16_series==True:return cuda, torch.float32, sm_version, mem_gb
163+
if sm_version > 6.1:return cuda, torch.float16, sm_version, mem_gb
168164
return cpu, torch.float32, 0.0, 0.0
169165

170166

0 commit comments

Comments
 (0)