@@ -312,7 +312,11 @@ def merge_short_text_in_array(texts, threshold):
312312 result [len (result ) - 1 ] += text
313313 return result
314314
315- def get_tts_wav (ref_wav_path , prompt_text , prompt_language , text , text_language , how_to_cut = i18n ("不切" ), top_k = 20 , top_p = 0.6 , temperature = 0.6 , ref_free = False ,speed = 1 ):
315+ ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
316+ # cache_tokens={}#暂未实现清理机制
317+ cache = None
318+ def get_tts_wav (ref_wav_path , prompt_text , prompt_language , text , text_language , how_to_cut = i18n ("不切" ), top_k = 20 , top_p = 0.6 , temperature = 0.6 , ref_free = False ,speed = 1 ,if_freeze = False ):
319+ global cache
316320 if prompt_text is None or len (prompt_text ) == 0 :
317321 ref_free = True
318322 t0 = ttime ()
@@ -395,38 +399,30 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
395399 all_phoneme_len = torch .tensor ([all_phoneme_ids .shape [- 1 ]]).to (device )
396400
397401 t2 = ttime ()
398- with torch .no_grad ():
399- # pred_semantic = t2s_model.model.infer(
400- pred_semantic , idx = t2s_model .model .infer_panel (
401- all_phoneme_ids ,
402- all_phoneme_len ,
403- None if ref_free else prompt ,
404- bert ,
405- # prompt_phone_len=ph_offset,
406- top_k = top_k ,
407- top_p = top_p ,
408- temperature = temperature ,
409- early_stop_num = hz * max_sec ,
410- )
402+ # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
403+ if (type (cache )!= type (None )and if_freeze == True ):pred_semantic = cache
404+ else :
405+ with torch .no_grad ():
406+ pred_semantic , idx = t2s_model .model .infer_panel (
407+ all_phoneme_ids ,
408+ all_phoneme_len ,
409+ None if ref_free else prompt ,
410+ bert ,
411+ # prompt_phone_len=ph_offset,
412+ top_k = top_k ,
413+ top_p = top_p ,
414+ temperature = temperature ,
415+ early_stop_num = hz * max_sec ,
416+ )
417+ pred_semantic = pred_semantic [:, - idx :].unsqueeze (0 )
418+ cache = pred_semantic
411419 t3 = ttime ()
412- # print(pred_semantic.shape,idx)
413- pred_semantic = pred_semantic [:, - idx :].unsqueeze (
414- 0
415- ) # .unsqueeze(0)#mq要多unsqueeze一次
416420 refer = get_spepc (hps , ref_wav_path ) # .to(device)
417421 if is_half == True :
418422 refer = refer .half ().to (device )
419423 else :
420424 refer = refer .to (device )
421- # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
422- audio = (
423- vq_model .decode (
424- pred_semantic , torch .LongTensor (phones2 ).to (device ).unsqueeze (0 ), refer ,speed = speed
425- )
426- .detach ()
427- .cpu ()
428- .numpy ()[0 , 0 ]
429- ) ###试试重建不带上prompt部分
425+ audio = (vq_model .decode (pred_semantic , torch .LongTensor (phones2 ).to (device ).unsqueeze (0 ), refer ,speed = speed ).detach ().cpu ().numpy ()[0 , 0 ])
430426 max_audio = np .abs (audio ).max ()#简单防止16bit爆音
431427 if max_audio > 1 :audio /= max_audio
432428 audio_opt .append (audio )
@@ -611,29 +607,36 @@ def get_weights_names():
611607 )
612608 gr .Markdown (value = i18n ("*请填写需要合成的目标文本和语种模式" ))
613609 with gr .Row ():
614- text = gr .Textbox (label = i18n ("需要合成的文本" ), value = "" )
615- text_language = gr .Dropdown (
616- label = i18n ("需要合成的语种" ), choices = [i18n ("中文" ), i18n ("英文" ), i18n ("日文" ), i18n ("中英混合" ), i18n ("日英混合" ), i18n ("多语种混合" )], value = i18n ("中文" )
617- )
618- how_to_cut = gr .Radio (
619- label = i18n ("怎么切" ),
620- choices = [i18n ("不切" ), i18n ("凑四句一切" ), i18n ("凑50字一切" ), i18n ("按中文句号。切" ), i18n ("按英文句号.切" ), i18n ("按标点符号切" ), ],
621- value = i18n ("凑四句一切" ),
622- interactive = True ,
623- )
624- with gr .Row ():
625- gr .Markdown (value = i18n ("gpt采样参数(无参考文本时不要太低):" ))
610+ with gr .Column ():
611+ text = gr .Textbox (label = i18n ("需要合成的文本" ), value = "" )
612+ text_language = gr .Dropdown (
613+ label = i18n ("需要合成的语种" ), choices = [i18n ("中文" ), i18n ("英文" ), i18n ("日文" ), i18n ("中英混合" ), i18n ("日英混合" ), i18n ("多语种混合" )], value = i18n ("中文" )
614+ )
615+ how_to_cut = gr .Radio (
616+ label = i18n ("怎么切" ),
617+ choices = [i18n ("不切" ), i18n ("凑四句一切" ), i18n ("凑50字一切" ), i18n ("按中文句号。切" ), i18n ("按英文句号.切" ), i18n ("按标点符号切" ), ],
618+ value = i18n ("凑四句一切" ),
619+ interactive = True ,
620+ )
621+ with gr .Column ():
622+ gr .Markdown (value = i18n ("gpt采样参数(无参考文本时不要太低。不懂就用默认):" ))
626623 top_k = gr .Slider (minimum = 1 ,maximum = 100 ,step = 1 ,label = i18n ("top_k" ),value = 10 ,interactive = True )
627624 top_p = gr .Slider (minimum = 0 ,maximum = 1 ,step = 0.05 ,label = i18n ("top_p" ),value = 1 ,interactive = True )
628625 temperature = gr .Slider (minimum = 0 ,maximum = 1 ,step = 0.05 ,label = i18n ("temperature" ),value = 1 ,interactive = True )
629- with gr .Row ():
630- speed = gr .Slider (minimum = 0.5 ,maximum = 2 ,step = 0.05 ,label = i18n ("speed" ),value = 1 ,interactive = True )
626+ with gr .Column ():
627+ gr .Markdown (value = i18n ("语速调整,高为更快" ))
628+ if_freeze = gr .Checkbox (label = i18n ("是否直接对上次合成结果调整语速。防止随机性。" ), value = False , interactive = True , show_label = True )
629+ speed = gr .Slider (minimum = 0.6 ,maximum = 1.65 ,step = 0.05 ,label = i18n ("语速" ),value = 1 ,interactive = True )
630+ # with gr.Column():
631+ # gr.Markdown(value=i18n("手工调整音素。当音素框不为空时使用手工音素输入推理,无视目标文本框。"))
632+ # phoneme=gr.Textbox(label=i18n("音素框"), value="")
633+ # get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
631634 inference_button = gr .Button (i18n ("合成语音" ), variant = "primary" )
632635 output = gr .Audio (label = i18n ("输出的语音" ))
633636
634637 inference_button .click (
635638 get_tts_wav ,
636- [inp_ref , prompt_text , prompt_language , text , text_language , how_to_cut , top_k , top_p , temperature , ref_text_free ,speed ],
639+ [inp_ref , prompt_text , prompt_language , text , text_language , how_to_cut , top_k , top_p , temperature , ref_text_free ,speed , if_freeze ],
637640 [output ],
638641 )
639642
0 commit comments