@@ -141,6 +141,14 @@ def reset_state_variables(self) -> None:
         Contains resetting logic for the connection.
         """
 
+    def cast_dtype_if_needed(self, w, w_dtype):
+        # Shared helper: warn, then cast, when the given weight tensor's
+        # dtype differs from w_dtype (assumes ``warnings`` is imported).
+        if w.dtype != w_dtype:
+            warnings.warn(f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}")
+            return w.to(dtype=w_dtype)
+        return w
+
 
 class AbstractMulticompartmentConnection(ABC, Module):
     # language=rst
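Note: the new helper relies on `warnings` being imported at module level, which this hunk does not show. A minimal standalone sketch of the intended cast-or-warn behaviour, under that assumption:

```python
import warnings

import torch

def cast_dtype_if_needed(w: torch.Tensor, w_dtype: torch.dtype) -> torch.Tensor:
    # Warn and cast only when the provided tensor's dtype differs.
    if w.dtype != w_dtype:
        warnings.warn(f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}")
        return w.to(dtype=w_dtype)
    return w

w64 = torch.rand(3, 3, dtype=torch.float64)
w32 = cast_dtype_if_needed(w64, torch.float32)  # emits a UserWarning
assert w32.dtype == torch.float32
```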
@@ -261,6 +269,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -275,6 +284,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for the :code:`w` tensor.
 
         Keyword arguments:
 
@@ -296,9 +306,11 @@ def __init__(
                 w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
             else:
                 w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
                 w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
 
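With the new parameter, callers can request a storage dtype at construction time. A minimal usage sketch, assuming the usual bindsnet imports (`w_dtype` is the parameter added in this commit):

```python
import torch
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection

source = Input(n=100)
target = LIFNodes(n=50)

# Randomly initialized weights are cast to the requested dtype.
conn = Connection(source=source, target=target, w_dtype=torch.float16)
assert conn.w.dtype == torch.float16
```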
@@ -525,6 +537,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -543,6 +556,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for the :code:`w` tensor.
 
         Keyword arguments:
 
@@ -595,9 +609,11 @@ def __init__(
                 self.out_channels, self.in_channels, self.kernel_size
             )
             w += self.wmin
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin == -inf).any() or (self.wmax == inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
         self.b = Parameter(
@@ -667,6 +683,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -685,6 +702,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for the :code:`w` tensor.
 
         Keyword arguments:
 
@@ -750,9 +768,11 @@ def __init__(
                 self.out_channels, self.in_channels, *self.kernel_size
             )
             w += self.wmin
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin == -inf).any() or (self.wmax == inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
         self.b = Parameter(
@@ -824,6 +844,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -842,6 +863,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for the :code:`w` tensor.
 
         Keyword arguments:
 
@@ -926,9 +948,11 @@ def __init__(
                 self.out_channels, self.in_channels, *self.kernel_size
             )
             w += self.wmin
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin == -inf).any() or (self.wmax == inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
         self.b = Parameter(
@@ -1276,6 +1300,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -1299,6 +1324,7 @@ def __init__(
         :param reduction: Method for reducing parameter updates along the minibatch
             dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for the :code:`w` tensor.
 
         Keyword arguments:
 
@@ -1378,10 +1404,11 @@ def __init__(
                 w = torch.clamp(w, self.wmin, self.wmax)
             else:
                 w = self.wmin + w * (self.wmax - self.wmin)
-
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         self.w = Parameter(w, requires_grad=False)
 
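The same cast-or-warn path applies whenever a pre-built weight tensor is supplied to one of these constructors. Sketched here with the dense `Connection` and hypothetical sizes:

```python
import torch
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection

source = Input(n=10)
target = LIFNodes(n=5)
w = torch.rand(source.n, target.n, dtype=torch.float64)

# The dtype differs from the default w_dtype=torch.float32, so
# cast_dtype_if_needed warns and converts before wrapping in Parameter.
conn = Connection(source=source, target=target, w=w)
assert conn.w.dtype == torch.float32
```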
@@ -1456,6 +1483,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         """
@@ -1474,6 +1502,7 @@ def __init__(
             In this case, their shape should be the same size as the connection weights.
         :param reduction: Method for reducing parameter updates along the minibatch dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for the :code:`w` tensor.
         Keyword arguments:
         :param LearningRule update_rule: Modifies connection parameters according to some rule.
         :param torch.Tensor w: Strengths of synapses.
@@ -1507,12 +1536,14 @@ def __init__(
             w = torch.rand(
                 self.in_channels, self.n_filters * self.conv_size, self.kernel_size
             )
+            w = w.to(dtype=w_dtype)
         else:
             assert w.shape == (
                 self.in_channels,
                 self.out_channels * self.conv_size,
                 self.kernel_size,
             ), error
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         if self.wmin != -np.inf or self.wmax != np.inf:
             w = torch.clamp(w, self.wmin, self.wmax)
@@ -1588,6 +1619,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         """
@@ -1606,6 +1638,7 @@ def __init__(
             In this case, their shape should be the same size as the connection weights.
         :param reduction: Method for reducing parameter updates along the minibatch dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for the :code:`w` tensor.
         Keyword arguments:
         :param LearningRule update_rule: Modifies connection parameters according to some rule.
         :param torch.Tensor w: Strengths of synapses.
@@ -1649,12 +1682,14 @@ def __init__(
             w = torch.rand(
                 self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod
             )
+            w = w.to(dtype=w_dtype)
         else:
             assert w.shape == (
                 self.in_channels,
                 self.out_channels * self.conv_prod,
                 self.kernel_prod,
             ), error
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         if self.wmin != -np.inf or self.wmax != np.inf:
             w = torch.clamp(w, self.wmin, self.wmax)
@@ -1731,6 +1766,7 @@ def __init__(
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         reduction: Optional[callable] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         """
@@ -1749,6 +1785,7 @@ def __init__(
             In this case, their shape should be the same size as the connection weights.
         :param reduction: Method for reducing parameter updates along the minibatch dimension.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for the :code:`w` tensor.
         Keyword arguments:
         :param LearningRule update_rule: Modifies connection parameters according to some rule.
         :param torch.Tensor w: Strengths of synapses.
@@ -1794,12 +1831,14 @@ def __init__(
             w = torch.rand(
                 self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod
             )
+            w = w.to(dtype=w_dtype)
         else:
             assert w.shape == (
                 self.in_channels,
                 self.out_channels * self.conv_prod,
                 self.kernel_prod,
             ), error
+            w = self.cast_dtype_if_needed(w, w_dtype)
 
         if self.wmin != -np.inf or self.wmax != np.inf:
             w = torch.clamp(w, self.wmin, self.wmax)
@@ -1875,6 +1914,7 @@ def __init__(
         target: Nodes,
         nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
         weight_decay: float = 0.0,
+        w_dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> None:
         # language=rst
@@ -1886,6 +1926,7 @@ def __init__(
             accepts a pair of tensors to individualize learning rates of each neuron.
             In this case, their shape should be the same size as the connection weights.
         :param weight_decay: Constant multiple to decay weights by on each iteration.
+        :param w_dtype: Data type for the :code:`w` tensor.
         Keyword arguments:
         :param LearningRule update_rule: Modifies connection parameters according to
             some rule.
@@ -1904,10 +1945,11 @@ def __init__(
                 w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax)
             else:
                 w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin)
+            w = w.to(dtype=w_dtype)
         else:
             if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
                 w = torch.clamp(w, self.wmin, self.wmax)
-
+            w = self.cast_dtype_if_needed(w, w_dtype)
         self.w = Parameter(w, requires_grad=False)
 
     def compute(self, s: torch.Tensor) -> torch.Tensor: