147147# mem.append(psutil.virtual_memory().total/ 1024 / 1024 / 1024) # 实测使用系统内存作为显存不会爆显存
148148
149149
150+ v3v4set = {"v3" ,"v4" }
150151def set_default ():
151152 global \
152153 default_batch_size , \
@@ -181,23 +182,23 @@ def set_default():
181182 # minmem = 14
182183 # except RuntimeError as _:
183184 # print("显存不足以开启V3训练")
184- default_batch_size = minmem // 2 if version != "v3" else minmem // 8
185+ default_batch_size = minmem // 2 if version not in v3v4set else minmem // 8
185186 default_batch_size_s1 = minmem // 2
186187 else :
187188 gpu_info = "%s\t %s" % ("0" , "CPU" )
188189 gpu_infos .append ("%s\t %s" % ("0" , "CPU" ))
189190 set_gpu_numbers .add (0 )
190191 default_batch_size = default_batch_size_s1 = int (psutil .virtual_memory ().total / 1024 / 1024 / 1024 / 4 )
191- if version != "v3" :
192+ if version not in v3v4set :
192193 default_sovits_epoch = 8
193194 default_sovits_save_every_epoch = 4
194195 max_sovits_epoch = 25 # 40
195196 max_sovits_save_every_epoch = 25 # 10
196197 else :
197198 default_sovits_epoch = 2
198199 default_sovits_save_every_epoch = 1
199- max_sovits_epoch = 3 # 40
200- max_sovits_save_every_epoch = 3 # 10
200+ max_sovits_epoch = 50 # 40 # 3
201+ max_sovits_save_every_epoch = 10 # 10 # 3
201202
202203 default_batch_size = max (1 , default_batch_size )
203204 default_batch_size_s1 = max (1 , default_batch_size_s1 )
@@ -233,11 +234,13 @@ def fix_gpu_numbers(inputs):
233234 "GPT_SoVITS/pretrained_models/s2G488k.pth" ,
234235 "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" ,
235236 "GPT_SoVITS/pretrained_models/s2Gv3.pth" ,
237+ "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth" ,
236238]
237239pretrained_gpt_name = [
238240 "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" ,
239241 "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" ,
240242 "GPT_SoVITS/pretrained_models/s1v3.ckpt" ,
243+ "GPT_SoVITS/pretrained_models/s1v3.ckpt" ,
241244]
242245
243246pretrained_model_list = (
@@ -256,7 +259,7 @@ def fix_gpu_numbers(inputs):
256259 print ("warning: " , i18n ("以下模型不存在:" ) + _ )
257260
258261_ = [[], []]
259- for i in range (3 ):
262+ for i in range (4 ):
260263 if os .path .exists (pretrained_gpt_name [i ]):
261264 _ [0 ].append (pretrained_gpt_name [i ])
262265 else :
@@ -267,8 +270,8 @@ def fix_gpu_numbers(inputs):
267270 _ [- 1 ].append ("" )
268271pretrained_gpt_name , pretrained_sovits_name = _
269272
270- SoVITS_weight_root = ["SoVITS_weights" , "SoVITS_weights_v2" , "SoVITS_weights_v3" ]
271- GPT_weight_root = ["GPT_weights" , "GPT_weights_v2" , "GPT_weights_v3" ]
273+ SoVITS_weight_root = ["SoVITS_weights" , "SoVITS_weights_v2" , "SoVITS_weights_v3" , "SoVITS_weights_v4" ]
274+ GPT_weight_root = ["GPT_weights" , "GPT_weights_v2" , "GPT_weights_v3" , "GPT_weights_v4" ]
272275for root in SoVITS_weight_root + GPT_weight_root :
273276 os .makedirs (root , exist_ok = True )
274277
@@ -1287,7 +1290,6 @@ def close1abc():
12871290 {"__type__" : "update" , "visible" : False },
12881291 )
12891292
1290-
12911293def switch_version (version_ ):
12921294 os .environ ["version" ] = version_
12931295 global version
@@ -1306,15 +1308,15 @@ def switch_version(version_):
13061308 {"__type__" : "update" , "value" : default_batch_size , "maximum" : default_max_batch_size },
13071309 {"__type__" : "update" , "value" : default_sovits_epoch , "maximum" : max_sovits_epoch },
13081310 {"__type__" : "update" , "value" : default_sovits_save_every_epoch , "maximum" : max_sovits_save_every_epoch },
1309- {"__type__" : "update" , "visible" : True if version != "v3" else False },
1311+ {"__type__" : "update" , "visible" : True if version not in v3v4set else False },
13101312 {
13111313 "__type__" : "update" ,
13121314 "value" : False if not if_force_ckpt else True ,
13131315 "interactive" : True if not if_force_ckpt else False ,
13141316 },
13151317 {"__type__" : "update" , "interactive" : True , "value" : False },
1316- {"__type__" : "update" , "visible" : True if version == "v3" else False },
1317- ) # {'__type__': 'update', "interactive": False if version == "v3" else True, "value": False}, \ ####batch infer
1318+ {"__type__" : "update" , "visible" : True if version in v3v4set else False },
1319+ ) # {'__type__': 'update', "interactive": False if version in v3v4set else True, "value": False}, \ ####batch infer
13181320
13191321
13201322if os .path .exists ("GPT_SoVITS/text/G2PWModel" ):
@@ -1489,7 +1491,7 @@ def change_precision_choices(key): # 根据选择的模型修改可选的语言
14891491 with gr .Row ():
14901492 exp_name = gr .Textbox (label = i18n ("*实验/模型名" ), value = "xxx" , interactive = True )
14911493 gpu_info = gr .Textbox (label = i18n ("显卡信息" ), value = gpu_info , visible = True , interactive = False )
1492- version_checkbox = gr .Radio (label = i18n ("版本" ), value = version , choices = ["v1" , "v2" , "v3 " ])
1494+ version_checkbox = gr .Radio (label = i18n ("版本" ), value = version , choices = ["v1" , "v2" , "v4 " ])#, "v3"
14931495 with gr .Row ():
14941496 pretrained_s2G = gr .Textbox (
14951497 label = i18n ("预训练SoVITS-G模型路径" ),
@@ -1716,13 +1718,13 @@ def change_precision_choices(key): # 根据选择的模型修改可选的语言
17161718 step = 0.05 ,
17171719 label = i18n ("文本模块学习率权重" ),
17181720 value = 0.4 ,
1719- visible = True if version != "v3" else False ,
1721+ visible = True if version not in v3v4set else False ,
17201722 ) # v3 not need
17211723 lora_rank = gr .Radio (
17221724 label = i18n ("LoRA秩" ),
17231725 value = "32" ,
17241726 choices = ["16" , "32" , "64" , "128" ],
1725- visible = True if version == "v3" else False ,
1727+ visible = True if version in v3v4set else False ,
17261728 ) # v1v2 not need
17271729 save_every_epoch = gr .Slider (
17281730 minimum = 1 ,
@@ -1749,7 +1751,7 @@ def change_precision_choices(key): # 根据选择的模型修改可选的语言
17491751 if_grad_ckpt = gr .Checkbox (
17501752 label = "v3是否开启梯度检查点节省显存占用" ,
17511753 value = False ,
1752- interactive = True if version == "v3" else False ,
1754+ interactive = True if version in v3v4set else False ,
17531755 show_label = True ,
17541756 visible = False ,
17551757 ) # 只有V3s2可以用
0 commit comments