@@ -144,7 +144,7 @@ def change_choices():
144144
145145api_port = 9880
146146
147-
147+ #Thanks to the contribution of @Karasukaigan and @XXXXRT666
148148def get_device_dtype_sm (idx : int ) -> tuple [torch .device , torch .dtype , float , float ]:
149149 cpu = torch .device ("cpu" )
150150 cuda = torch .device (f"cuda:{ idx } " )
@@ -157,14 +157,10 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
157157 mem_gb = mem_bytes / (1024 ** 3 ) + 0.4
158158 major , minor = capability
159159 sm_version = major + minor / 10.0
160- is_16_series = bool (re .search (r"16\d{2}" , name ))
161- if mem_gb < 4 :
162- return cpu , torch .float32 , 0.0 , 0.0
163- if (sm_version >= 7.0 and sm_version != 7.5 ) or (5.3 <= sm_version <= 6.0 ):
164- if is_16_series and sm_version == 7.5 :
165- return cuda , torch .float32 , sm_version , mem_gb # 16系卡除外
166- else :
167- return cuda , torch .float16 , sm_version , mem_gb
160+ is_16_series = bool (re .search (r"16\d{2}" , name ))and sm_version == 7.5
161+ if mem_gb < 4 or sm_version < 5.3 :return cpu , torch .float32 , 0.0 , 0.0
162+ if sm_version == 6.1 or is_16_series == True :return cuda , torch .float32 , sm_version , mem_gb
163+ if sm_version > 6.1 :return cuda , torch .float16 , sm_version , mem_gb
168164 return cpu , torch .float32 , 0.0 , 0.0
169165
170166
0 commit comments