Skip to content

Commit 93c47cd

Browse files
authored
fix nan issue(causing sovits zerodivision)
fix nan issue(which will cause sovits zerodivision)
1 parent 948e7fc commit 93c47cd

1 file changed

Lines changed: 18 additions & 4 deletions

File tree

GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,36 +49,41 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
4949
alpha=0.5
5050
device="cuda:0"
5151
model=cnhubert.get_model()
52+
# is_half=False
5253
if(is_half==True):
5354
model=model.half().to(device)
5455
else:
5556
model = model.to(device)
57+
58+
nan_fails=[]
5659
def name2go(wav_name):
5760
hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
5861
if(os.path.exists(hubert_path)):return
5962
wav_path="%s/%s"%(inp_wav_dir,wav_name)
6063
tmp_audio = load_audio(wav_path, 32000)
6164
tmp_max = np.abs(tmp_audio).max()
6265
if tmp_max > 2.2:
63-
print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max))
66+
print("%s-filtered" % (wav_name, tmp_max))
6467
return
6568
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
6669
tmp_audio = librosa.resample(
6770
tmp_audio32, orig_sr=32000, target_sr=16000
68-
)
71+
)#不是重采样问题
6972
tensor_wav16 = torch.from_numpy(tmp_audio)
7073
if (is_half == True):
7174
tensor_wav16=tensor_wav16.half().to(device)
7275
else:
7376
tensor_wav16 = tensor_wav16.to(device)
7477
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
75-
if np.isnan(ssl.detach().numpy()).sum()!= 0:return
78+
if np.isnan(ssl.detach().numpy()).sum()!= 0:
79+
nan_fails.append(wav_name)
80+
print("nan filtered:%s"%wav_name)
81+
return
7682
wavfile.write(
7783
"%s/%s"%(wav32dir,wav_name),
7884
32000,
7985
tmp_audio32.astype("int16"),
8086
)
81-
# torch.save(ssl,hubert_path )
8287
my_save(ssl,hubert_path )
8388

8489
with open(inp_text,"r",encoding="utf8")as f:
@@ -92,3 +97,12 @@ def name2go(wav_name):
9297
name2go(wav_name)
9398
except:
9499
print(line,traceback.format_exc())
100+
101+
if(len(nan_fails)>0 and is_half==True):
102+
is_half=False
103+
model=model.float()
104+
for wav_name in nan_fails:
105+
try:
106+
name2go(wav_name)
107+
except:
108+
print(wav_name,traceback.format_exc())

0 commit comments

Comments
 (0)