Skip to content

Commit d283ec8

Browse files
committed
fix(decoder): channel input from args not hard code.
1 parent ae885d8 commit d283ec8

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

scripts/network/models/basic/decoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,14 @@ class ConvGRUDecoder(nn.Module):
142142
def __init__(self, pseudoimage_channels: int = 64, num_iters: int = 4):
143143
super().__init__()
144144

145-
self.offset_encoder = nn.Linear(3, 64)
145+
self.offset_encoder = nn.Linear(3, pseudoimage_channels)
146146

147147
# NOTE: voxel feature is hidden input, point offset is input, check paper's Fig. 3
148-
self.gru = ConvGRU(input_dim=64, hidden_dim=pseudoimage_channels*2)
148+
self.gru = ConvGRU(input_dim=pseudoimage_channels, hidden_dim=pseudoimage_channels*2)
149149

150150
self.decoder = nn.Sequential(
151-
nn.Linear(pseudoimage_channels*3, 32), nn.GELU(),
152-
nn.Linear(32, 3))
151+
nn.Linear(pseudoimage_channels*3, pseudoimage_channels//2), nn.GELU(),
152+
nn.Linear(pseudoimage_channels//2, 3))
153153
self.num_iters = num_iters
154154

155155
def forward_single(self, before_pseudoimage: torch.Tensor,

0 commit comments

Comments
 (0)