Skip to content

Commit ed89a02

Browse files
authored
修复“修复ge.sum数值可能爆炸的”可能导致的训练爆炸的问题
修复“修复ge.sum数值可能爆炸的”可能导致的训练爆炸的问题
1 parent cd6de73 commit ed89a02

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

GPT_SoVITS/module/modules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import math
2+
import pdb
3+
24
import numpy as np
35
import torch
46
from torch import nn
@@ -716,11 +718,11 @@ def temporal_avg_pool(self, x, mask=None):
716718
if mask is None:
717719
out = torch.mean(x, dim=1)
718720
else:
719-
len_ = (~mask).sum()
721+
len_ = (~mask).sum(dim=1).unsqueeze(1)
720722
x = x.masked_fill(mask.unsqueeze(-1), 0)
721723
dtype=x.dtype
722724
x = x.float()
723-
x=torch.div(x,len_)
725+
x=torch.div(x,len_.unsqueeze(1))
724726
out=x.sum(dim=1).to(dtype)
725727
return out
726728

0 commit comments

Comments
 (0)