Skip to content

Commit 92961c3

Browse files
authored
支持24k音频超分48k采样率
支持24k音频超分48k采样率
1 parent 25a6829 commit 92961c3

1 file changed

Lines changed: 49 additions & 0 deletions

File tree

tools/audio_sr.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
import sys,os
3+
import traceback
4+
# Absolute path of the 'AP_BWE_main' directory located next to this file.
AP_BWE_main_dir_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'AP_BWE_main')
5+
# Make the AP_BWE_main directory importable; this must run before the
# project imports further down (presumably `datasets1` and `models` live
# there -- confirm against the repository layout).
sys.path.append(AP_BWE_main_dir_path)
6+
import glob
7+
import argparse
8+
import json
9+
from re import S
10+
import torch
11+
import numpy as np
12+
import torchaudio
13+
import time
14+
import torchaudio.functional as aF
15+
from attrdict import AttrDict
16+
from datasets1.dataset import amp_pha_stft, amp_pha_istft
17+
from models.model import APNet_BWE_Model
18+
import soundfile as sf
19+
import matplotlib.pyplot as plt
20+
from rich.progress import track
21+
22+
class AP_BWE():
    """Audio bandwidth-extension wrapper around ``APNet_BWE_Model``.

    Loads a generator checkpoint plus the ``config.json`` stored in the same
    directory and, when called, resamples the input to the configured
    ``hr_sampling_rate`` and runs it through the model.
    """

    def __init__(self, device, checkpoint_file=None):
        """Build the model and load its weights.

        Args:
            device: Device the model is moved to (passed to ``Module.to``).
            checkpoint_file: Path of the generator checkpoint. When ``None``,
                ``24kto48k/g_24kto48k.zip`` under ``AP_BWE_main_dir_path``
                is used.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if checkpoint_file is None:
            checkpoint_file = "%s/24kto48k/g_24kto48k.zip" % (AP_BWE_main_dir_path)
        # Fail early, with the offending path in the message, for both the
        # default and an explicitly supplied checkpoint.
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(
                "AP-BWE checkpoint not found: %s" % checkpoint_file
            )

        # The model hyper-parameters live next to the checkpoint.
        config_file = os.path.join(os.path.dirname(checkpoint_file), 'config.json')
        with open(config_file, encoding="utf-8") as f:
            json_config = json.load(f)
        h = AttrDict(json_config)

        model = APNet_BWE_Model(h).to(device)
        # NOTE(review): weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from a trusted source.
        state_dict = torch.load(checkpoint_file, map_location="cpu", weights_only=False)
        model.load_state_dict(state_dict['generator'])
        model.eval()

        self.device = device
        self.model = model
        self.h = h

    def __call__(self, audio, orig_sampling_rate):
        """Run bandwidth extension on ``audio``.

        Args:
            audio: Input waveform tensor (assumed to be a ``torch.Tensor``
                laid out the way ``amp_pha_stft`` expects -- TODO confirm).
            orig_sampling_rate: Sampling rate of ``audio``.

        Returns:
            A tuple ``(waveform, sampling_rate)`` where ``waveform`` is the
            squeezed model output as a NumPy array on the CPU and
            ``sampling_rate`` is ``self.h.hr_sampling_rate``.
        """
        with torch.no_grad():
            # The model lives on self.device; moving the input there avoids a
            # device mismatch and does nothing if it is already on that device.
            audio = audio.to(self.device)
            audio = aF.resample(
                audio,
                orig_freq=orig_sampling_rate,
                new_freq=self.h.hr_sampling_rate,
            )
            amp_nb, pha_nb, _ = amp_pha_stft(
                audio, self.h.n_fft, self.h.hop_size, self.h.win_size
            )
            amp_wb_g, pha_wb_g, _ = self.model(amp_nb, pha_nb)
            audio_hr_g = amp_pha_istft(
                amp_wb_g, pha_wb_g, self.h.n_fft, self.h.hop_size, self.h.win_size
            )
        return audio_hr_g.squeeze().cpu().numpy(), self.h.hr_sampling_rate

0 commit comments

Comments
 (0)