@@ -49,36 +49,41 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
# Blend weight used by name2go when mixing peak-normalised audio with the raw signal.
alpha = 0.5
device = "cuda:0"

# Load the cnhubert model and move it to the target device, converting it to
# half precision first when is_half is set.
model = cnhubert.get_model()
model = (model.half() if is_half == True else model).to(device)

# wav names whose extracted features contained NaN (filled in by name2go).
nan_fails = []
def name2go(wav_name):
    """Extract cnhubert features for one wav file and cache them on disk.

    Loads ``inp_wav_dir/wav_name`` at 32 kHz, rescales it, resamples it to
    16 kHz and runs it through the global ``model``. On success the rescaled
    32 kHz audio is written to ``wav32dir/wav_name`` as int16 and the
    features are saved to ``hubert_dir/<wav_name>.pt`` via ``my_save``.

    The file is skipped (nothing is written) when:
      * the feature file already exists,
      * the peak absolute amplitude exceeds 2.2, or
      * the model output contains NaN; in that case ``wav_name`` is also
        appended to the global ``nan_fails`` list.

    Args:
        wav_name: File name of the wav, relative to ``inp_wav_dir``.
    """
    hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
    if os.path.exists(hubert_path):
        # Features were already extracted; nothing to do.
        return
    wav_path = "%s/%s" % (inp_wav_dir, wav_name)
    tmp_audio = load_audio(wav_path, 32000)
    tmp_max = np.abs(tmp_audio).max()
    if tmp_max > 2.2:
        # Two placeholders for two values: the previous "%s-filtered" format
        # raised TypeError instead of printing this message.
        print("%s-%s-filtered" % (wav_name, tmp_max))
        return
    # NOTE(review): if tmp_max is 0 (all-zero audio) the division below
    # presumably yields NaN, so such files end up in the NaN-filter path.
    # Blend of peak-normalised audio (weight alpha) and the raw signal
    # (weight 1 - alpha), scaled to the int16 range.
    tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
    # Original author's note: resampling is not the cause of the problem
    # (presumably the NaN outputs handled below).
    tmp_audio = librosa.resample(
        tmp_audio32, orig_sr=32000, target_sr=16000
    )
    tensor_wav16 = torch.from_numpy(tmp_audio)
    if is_half == True:
        tensor_wav16 = tensor_wav16.half().to(device)
    else:
        tensor_wav16 = tensor_wav16.to(device)
    # Example output shape noted by the original author: torch.Size([1, 768, 215])
    ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu()
    if np.isnan(ssl.detach().numpy()).sum() != 0:
        # Record the failure so the caller can retry this file later.
        nan_fails.append(wav_name)
        print("nan filtered:%s" % wav_name)
        return
    wavfile.write(
        "%s/%s" % (wav32dir, wav_name),
        32000,
        tmp_audio32.astype("int16"),
    )
    # my_save is used instead of torch.save (see my_save for the reason).
    my_save(ssl, hubert_path)
8388
8489with open (inp_text ,"r" ,encoding = "utf8" )as f :
@@ -92,3 +97,12 @@ def name2go(wav_name):
9297 name2go (wav_name )
9398 except :
9499 print (line ,traceback .format_exc ())
100+
# Second pass: files whose features came out as NaN in half precision are
# re-processed with the model converted to float32.
if len(nan_fails) > 0 and is_half == True:
    is_half = False
    model = model.float()
    # Iterate over a snapshot: name2go appends to nan_fails again when the
    # output is still NaN, and looping over the live list would then re-queue
    # that file forever.
    for wav_name in list(nan_fails):
        try:
            name2go(wav_name)
        except Exception:
            # Best effort: report the failure and move on to the next file.
            # KeyboardInterrupt is deliberately not caught so the run can
            # still be stopped with Ctrl-C.
            print(wav_name, traceback.format_exc())
0 commit comments