Skip to content

Commit 895fde4

Browse files
authored
dpo改实验性勾选而非必须。勾选后batch size自动减半。
dpo改实验性勾选而非必须。勾选后batch size自动减半。
1 parent 9b5231a commit 895fde4

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

GPT_SoVITS/AR/models/t2s_lightning_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
1212
from AR.modules.optim import ScaledAdam
1313

14-
1514
class Text2SemanticLightningModule(LightningModule):
1615
def __init__(self, config, output_dir, is_train=True):
1716
super().__init__()
@@ -35,7 +34,8 @@ def __init__(self, config, output_dir, is_train=True):
3534
def training_step(self, batch: Dict, batch_idx: int):
3635
opt = self.optimizers()
3736
scheduler = self.lr_schedulers()
38-
loss, acc = self.model.forward(
37+
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
38+
loss, acc = forward(
3939
batch["phoneme_ids"],
4040
batch["phoneme_ids_len"],
4141
batch["semantic_ids"],

0 commit comments

Comments
 (0)