From e04ec2fcceea81813d7f1bb27fbc544056b04d9e Mon Sep 17 00:00:00 2001 From: charSLee013 Date: Fri, 28 Jun 2024 18:01:20 +0800 Subject: [PATCH] refactor: remove unnecessary transpositions (#488) --- ChatTTS/model/dvae.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 491477e82..c3cc91ee2 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -154,9 +154,8 @@ def __init__( ) self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) - def forward(self, input: torch.Tensor, conditioning=None) -> torch.Tensor: - # B, T, C - x = input.transpose_(1, 2) + def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: + # B, C, T y = self.conv_in(x) del x for f in self.decoder_block: @@ -164,7 +163,7 @@ def forward(self, input: torch.Tensor, conditioning=None) -> torch.Tensor: x = self.conv_out(y) del y - return x.transpose_(1, 2) + return x class DVAE(nn.Module): @@ -214,8 +213,8 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: dec_out = self.out_conv( self.decoder( - input=vq_feats.transpose_(1, 2), - ).transpose_(1, 2), + x=vq_feats, + ), ) return torch.mul(dec_out, self.coef, out=dec_out)