1+ from __future__ import absolute_import , division , print_function , unicode_literals
2+ import sys ,os
3+ import traceback
4+ AP_BWE_main_dir_path = os .path .join (os .path .dirname (os .path .abspath (__file__ )), 'AP_BWE_main' )
5+ sys .path .append (AP_BWE_main_dir_path )
6+ import glob
7+ import argparse
8+ import json
9+ from re import S
10+ import torch
11+ import numpy as np
12+ import torchaudio
13+ import time
14+ import torchaudio .functional as aF
15+ from attrdict import AttrDict
16+ from datasets1 .dataset import amp_pha_stft , amp_pha_istft
17+ from models .model import APNet_BWE_Model
18+ import soundfile as sf
19+ import matplotlib .pyplot as plt
20+ from rich .progress import track
21+
22+ class AP_BWE ():
23+ def __init__ (self ,device ,checkpoint_file = None ):
24+ if checkpoint_file == None :
25+ checkpoint_file = "%s/24kto48k/g_24kto48k.zip" % (AP_BWE_main_dir_path )
26+ if os .path .exists (checkpoint_file )== False :
27+ raise FileNotFoundError
28+ config_file = os .path .join (os .path .split (checkpoint_file )[0 ], 'config.json' )
29+ with open (config_file ) as f :data = f .read ()
30+ json_config = json .loads (data )
31+ h = AttrDict (json_config )
32+ model = APNet_BWE_Model (h ).to (device )
33+ state_dict = torch .load (checkpoint_file ,map_location = "cpu" ,weights_only = False )
34+ model .load_state_dict (state_dict ['generator' ])
35+ model .eval ()
36+ self .device = device
37+ self .model = model
38+ self .h = h
39+
40+ def __call__ (self , audio ,orig_sampling_rate ):
41+ with torch .no_grad ():
42+ # audio, orig_sampling_rate = torchaudio.load(inp_path)
43+ # audio = audio.to(self.device)
44+ audio = aF .resample (audio , orig_freq = orig_sampling_rate , new_freq = self .h .hr_sampling_rate )
45+ amp_nb , pha_nb , com_nb = amp_pha_stft (audio , self .h .n_fft , self .h .hop_size , self .h .win_size )
46+ amp_wb_g , pha_wb_g , com_wb_g = self .model (amp_nb , pha_nb )
47+ audio_hr_g = amp_pha_istft (amp_wb_g , pha_wb_g , self .h .n_fft , self .h .hop_size , self .h .win_size )
48+ # sf.write(opt_path, audio_hr_g.squeeze().cpu().numpy(), self.h.hr_sampling_rate, 'PCM_16')
49+ return audio_hr_g .squeeze ().cpu ().numpy (),self .h .hr_sampling_rate
0 commit comments