
Commit e0f9bf2 (1 parent: 01bc1b0)

Lowering precision for old connections

1 file changed: 44 additions & 2 deletions

bindsnet/network/topology.py

@@ -141,6 +141,14 @@ def reset_state_variables(self) -> None:
         Contains resetting logic for the connection.
         """
 
+    @abstractmethod
+    def cast_dtype_if_needed(self, w, w_dtype):
+        if w.dtype != w_dtype:
+            warnings.warn(f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}")
+            return w.to(dtype=w_dtype)
+        else:
+            return w
+
 
 class AbstractMulticompartmentConnection(ABC, Module):
     # language=rst
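Note: the @abstractmethod decorator above looks stray, since the method ships a concrete body and the subclasses below call it without overriding it. A minimal standalone sketch of the cast-and-warn behavior (illustrative only, with the decorator dropped):

import warnings

import torch


def cast_dtype_if_needed(w: torch.Tensor, w_dtype: torch.dtype) -> torch.Tensor:
    # Warn and cast only when the provided tensor's dtype differs from the target.
    if w.dtype != w_dtype:
        warnings.warn(f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}")
        return w.to(dtype=w_dtype)
    return w


w = cast_dtype_if_needed(torch.rand(3, 3, dtype=torch.float64), torch.float32)  # warns
assert w.dtype == torch.float32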
@@ -261,6 +269,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -275,6 +284,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for :code:`w` tensor
 
         Keyword arguments:
 
@@ -296,9 +306,11 @@ def __init__(
                 w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
             else:
                 w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
                 w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
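Note: a minimal usage sketch of the new w_dtype argument (assumes bindsnet's Input and LIFNodes node types, and that the helper above is effectively concrete): randomly initialized weights are sampled as usual and then stored at the requested precision.

import torch

from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection

# Weights are sampled at default precision, then cast to w_dtype.
conn = Connection(source=Input(n=100), target=LIFNodes(n=10), w_dtype=torch.float16)
assert conn.w.dtype == torch.float16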

@@ -525,6 +537,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -543,6 +556,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for :code:`w` tensor
 
         Keyword arguments:
 
@@ -595,9 +609,11 @@ def __init__(
                 self.out_channels, self.in_channels, self.kernel_size
             )
             w += self.wmin
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin == -inf).any() or (self.wmax == inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
         self.b = Parameter(
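Note: the random-initialization branch above follows a sample, shift, cast pattern. A small illustrative sketch (shapes and bounds are hypothetical stand-ins for the instance attributes):

import torch

out_channels, in_channels, kernel_size = 8, 1, 5  # hypothetical layer shape
wmin, w_dtype = 0.0, torch.float16
w = torch.rand(out_channels, in_channels, kernel_size)
w += wmin  # shift by the lower weight bound, as in the constructor
w = w.to(dtype=w_dtype)  # the cast this commit adds
assert w.dtype == torch.float16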
@@ -667,6 +683,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -685,6 +702,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for :code:`w` tensor
 
         Keyword arguments:
 
@@ -750,9 +768,11 @@ def __init__(
                 self.out_channels, self.in_channels, *self.kernel_size
             )
             w += self.wmin
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin == -inf).any() or (self.wmax == inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
         self.b = Parameter(
@@ -824,6 +844,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -842,6 +863,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for :code:`w` tensor
 
         Keyword arguments:
 
@@ -926,9 +948,11 @@ def __init__(
                 self.out_channels, self.in_channels, *self.kernel_size
             )
             w += self.wmin
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin == -inf).any() or (self.wmax == inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
         self.b = Parameter(
@@ -1276,6 +1300,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -1299,6 +1324,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for :code:`w` tensor
 
         Keyword arguments:
 
@@ -1378,10 +1404,11 @@ def __init__(
                 w = torch.clamp(w, self.wmin, self.wmax)
             else:
                 w = self.wmin + w * (self.wmax - self.wmin)
-
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
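Note: a sketch of the mismatch path when a pre-built weight tensor is supplied, shown with the dense Connection for brevity (same Input/LIFNodes assumption as above): a float64 tensor triggers the warning and is cast down to the float32 default.

import torch

from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection

w64 = torch.rand(100, 10, dtype=torch.float64)  # mismatches the default w_dtype
conn = Connection(source=Input(n=100), target=LIFNodes(n=10), w=w64)
assert conn.w.dtype == torch.float32  # warned about, then cast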

@@ -1456,6 +1483,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         """
@@ -1474,6 +1502,7 @@ def __init__(
             In this case, their shape should be the same size as the connection weights.
         :param reduction: Method for reducing parameter updates along the minibatch dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for :code:`w` tensor
         Keyword arguments:
         :param LearningRule update_rule: Modifies connection parameters according to some rule.
         :param torch.Tensor w: Strengths of synapses.
@@ -1507,12 +1536,14 @@ def __init__(
             w = torch.rand(
                 self.in_channels, self.n_filters * self.conv_size, self.kernel_size
             )
+            w = w.to(dtype=w_dtype)
         else:
             assert w.shape == (
                 self.in_channels,
                 self.out_channels * self.conv_size,
                 self.kernel_size,
             ), error
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         if self.wmin != -np.inf or self.wmax != np.inf:
             w = torch.clamp(w, self.wmin, self.wmax)
@@ -1588,6 +1619,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         """
@@ -1606,6 +1638,7 @@ def __init__(
             In this case, their shape should be the same size as the connection weights.
         :param reduction: Method for reducing parameter updates along the minibatch dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for :code:`w` tensor
         Keyword arguments:
         :param LearningRule update_rule: Modifies connection parameters according to some rule.
         :param torch.Tensor w: Strengths of synapses.
@@ -1649,12 +1682,14 @@ def __init__(
             w = torch.rand(
                 self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod
             )
+            w = w.to(dtype=w_dtype)
         else:
             assert w.shape == (
                 self.in_channels,
                 self.out_channels * self.conv_prod,
                 self.kernel_prod,
             ), error
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         if self.wmin != -np.inf or self.wmax != np.inf:
             w = torch.clamp(w, self.wmin, self.wmax)
@@ -1731,6 +1766,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         """
@@ -1749,6 +1785,7 @@ def __init__(
             In this case, their shape should be the same size as the connection weights.
         :param reduction: Method for reducing parameter updates along the minibatch dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for :code:`w` tensor
         Keyword arguments:
         :param LearningRule update_rule: Modifies connection parameters according to some rule.
         :param torch.Tensor w: Strengths of synapses.
@@ -1794,12 +1831,14 @@ def __init__(
             w = torch.rand(
                 self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod
             )
+            w = w.to(dtype=w_dtype)
         else:
             assert w.shape == (
                 self.in_channels,
                 self.out_channels * self.conv_prod,
                 self.kernel_prod,
             ), error
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         if self.wmin != -np.inf or self.wmax != np.inf:
             w = torch.clamp(w, self.wmin, self.wmax)
@@ -1875,6 +1914,7 @@ def __init__(
         target: Nodes,
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -1886,6 +1926,7 @@ def __init__(
             accepts a pair of tensors to individualize learning rates of each neuron.
             In this case, their shape should be the same size as the connection weights.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for :code:`w` tensor
         Keyword arguments:
         :param LearningRule update_rule: Modifies connection parameters according to
             some rule.
@@ -1904,10 +1945,11 @@ def __init__(
                 w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax)
             else:
                 w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin)
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
-
+            w = self.cast_dtype_if_needed(w, w_dtype)
         self.w = Parameter(w, requires_grad=False)
 
     def compute(self, s: torch.Tensor) -> torch.Tensor:
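Note: a quick check on the commit's motivation, halving weight precision halves weight storage:

import torch

w32 = torch.rand(1000, 1000, dtype=torch.float32)
w16 = w32.to(dtype=torch.float16)
print(w32.nelement() * w32.element_size())  # 4000000 bytes
print(w16.nelement() * w16.element_size())  # 2000000 bytes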
