Skip to content

Commit 7c3f34a

Browse files
authored
Merge pull request #212 from v3ucn/模型记忆功能
添加模型记忆功能,不用二次选择模型
2 parents 813cf96 + 0bcdf01 commit 7c3f34a

1 file changed

Lines changed: 21 additions & 4 deletions

File tree

GPT_SoVITS/inference_webui.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,25 @@
66
logging.getLogger("asyncio").setLevel(logging.ERROR)
77
import pdb
88

9-
gpt_path = os.environ.get(
10-
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
11-
)
12-
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
9+
if os.path.exists("./gweight.txt"):
10+
with open("./gweight.txt", 'r',encoding="utf-8") as file:
11+
gweight_data = file.read()
12+
gpt_path = os.environ.get(
13+
"gpt_path", gweight_data)
14+
else:
15+
gpt_path = os.environ.get(
16+
"gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
17+
18+
if os.path.exists("./sweight.txt"):
19+
with open("./sweight.txt", 'r',encoding="utf-8") as file:
20+
sweight_data = file.read()
21+
sovits_path = os.environ.get("sovits_path", sweight_data)
22+
else:
23+
sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
24+
# gpt_path = os.environ.get(
25+
# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
26+
# )
27+
# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
1328
cnhubert_base_path = os.environ.get(
1429
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
1530
)
@@ -124,6 +139,7 @@ def change_sovits_weights(sovits_path):
124139
vq_model = vq_model.to(device)
125140
vq_model.eval()
126141
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
142+
with open("./sweight.txt","w",encoding="utf-8")as f:f.write(sovits_path)
127143
change_sovits_weights(sovits_path)
128144

129145
def change_gpt_weights(gpt_path):
@@ -140,6 +156,7 @@ def change_gpt_weights(gpt_path):
140156
t2s_model.eval()
141157
total = sum([param.nelement() for param in t2s_model.parameters()])
142158
print("Number of parameter: %.2fM" % (total / 1e6))
159+
with open("./gweight.txt","w",encoding="utf-8")as f:f.write(gpt_path)
143160
change_gpt_weights(gpt_path)
144161

145162
def get_spepc(hps, filename):

0 commit comments

Comments
 (0)