diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 491477e82..c3cc91ee2 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -154,9 +154,8 @@ def __init__( ) self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) - def forward(self, input: torch.Tensor, conditioning=None) -> torch.Tensor: - # B, T, C - x = input.transpose_(1, 2) + def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: + # B, C, T y = self.conv_in(x) del x for f in self.decoder_block: @@ -164,7 +163,7 @@ def forward(self, input: torch.Tensor, conditioning=None) -> torch.Tensor: x = self.conv_out(y) del y - return x.transpose_(1, 2) + return x class DVAE(nn.Module): @@ -214,8 +213,8 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: dec_out = self.out_conv( self.decoder( - input=vq_feats.transpose_(1, 2), - ).transpose_(1, 2), + x=vq_feats, + ), ) return torch.mul(dec_out, self.coef, out=dec_out)