Skip to content

Commit 118d4f2

Browse files
committed
Added new feature tester
1 parent ed52fc5 commit 118d4f2

1 file changed

Lines changed: 362 additions & 0 deletions

File tree

examples/mnist/reservoir_delays.py

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
import os
2+
import numpy as np
3+
import torch
4+
import torch.nn as nn
5+
import argparse
6+
import matplotlib.pyplot as plt
7+
from bindsnet import network
8+
9+
from torchvision import transforms
10+
from tqdm import tqdm
11+
12+
from bindsnet.analysis.plotting import (
13+
plot_input,
14+
plot_spikes,
15+
plot_voltages,
16+
plot_weights,
17+
)
18+
from bindsnet.datasets import MNIST
19+
from bindsnet.encoding import PoissonEncoder
20+
from bindsnet.network import Network
21+
from bindsnet.network.nodes import Input
22+
23+
# Build a simple two-layer, input-output network.
24+
from bindsnet.network.monitors import Monitor
25+
from bindsnet.network.nodes import LIFNodes
26+
from bindsnet.network.topology import MulticompartmentConnection
27+
28+
from bindsnet.network.topology_features import Delay, Mask, Probability, Weight
29+
from bindsnet.learning.MCC_learning import PostPre, MSTDP, NoOp
30+
31+
32+
### Command-line configuration ###
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--n_neurons", type=int, default=500)
parser.add_argument("--n_epochs", type=int, default=500)
parser.add_argument("--examples", type=int, default=500)
parser.add_argument("--n_workers", type=int, default=-1)
parser.add_argument("--time", type=int, default=250)
# BUG FIX: dt is a (possibly fractional) simulation timestep; the original
# `type=int` rejected values such as "--dt 0.5" even though the default was
# the float 1.0.  Integer command-line values still parse fine as floats.
parser.add_argument("--dt", type=float, default=1.0)
parser.add_argument("--intensity", type=float, default=64)
parser.add_argument("--progress_interval", type=int, default=10)
parser.add_argument("--update_interval", type=int, default=250)
parser.add_argument("--plot", dest="plot", action="store_true")
parser.add_argument("--gpu", dest="gpu", action="store_true")
parser.set_defaults(plot=False, gpu=True, train=True)

args = parser.parse_args()

# Unpack the parsed namespace into the module-level names used below.
seed = args.seed
n_neurons = args.n_neurons
n_epochs = args.n_epochs
examples = args.examples
n_workers = args.n_workers
time = args.time
dt = args.dt
intensity = args.intensity
progress_interval = args.progress_interval
update_interval = args.update_interval
train = args.train
plot = args.plot
gpu = args.gpu
# Seed every RNG the simulation touches, once, so runs are reproducible.
# (The original seeded torch twice more inside the branches below; re-seeding
# with the same value is a no-op, so the duplicates are dropped.)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable

# Sets up Gpu use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not (gpu and torch.cuda.is_available()):
    # Either the user asked for CPU or no CUDA device exists: fall back.
    device = "cpu"
    gpu = False
# BUG FIX: keep at least one worker thread — the original
# os.cpu_count() - 1 requests 0 threads on a single-core host, and
# torch.set_num_threads(0) raises.
torch.set_num_threads(max(1, os.cpu_count() - 1))
print("Running on Device = ", device)
### Base model ###
model = Network()
model.to(device)


### Layers ###
# 28x28 MNIST input, one input neuron per pixel.
input_l = Input(n=784, shape=(1, 28, 28), traces=True)
# LIF reservoir whose firing thresholds are jittered per neuron around -52 mV.
output_l = LIFNodes(
    n=n_neurons,
    thresh=-52 + np.random.randn(n_neurons).astype(float),
    traces=True,
)

model.add_layer(input_l, name="X")
model.add_layer(output_l, name="Y")
## Connections ###
# Input -> reservoir: random weights plus per-synapse conduction delays,
# composed as a feature pipeline on one multicompartment connection.
in_weights = Weight(name="my_weights", value=torch.rand(input_l.n, output_l.n))
in_delays = Delay(name="my_delay", value=torch.rand(input_l.n, output_l.n))

input_con = MulticompartmentConnection(
    source=input_l,
    target=output_l,
    device=device,
    pipeline=[in_weights, in_delays],
)

# Reservoir -> reservoir: recurrent weights trained with the PostPre rule.
rec_weights = Weight(
    name="my_weights2",
    value=torch.randn(output_l.n, output_l.n),
    norm=1,
    nu=[0.001, 0.002],
    learning_rule=PostPre,
)

recurrent_con = MulticompartmentConnection(
    source=output_l,
    target=output_l,
    device=device,
    pipeline=[rec_weights],
)

model.add_connection(input_con, source="X", target="Y")
model.add_connection(recurrent_con, source="Y", target="Y")

# Directs network to GPU
if gpu:
    model.to("cuda")
### MNIST ###
# Scale pixel values by `intensity` so the Poisson encoder fires at useful
# rates, then encode each image as a spike train of length `time`.
_mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
)
dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    root=os.path.join("../../test", "..", "data", "MNIST"),
    download=True,
    transform=_mnist_transform,
)
### Monitor setup ###
# Plot handles, lazily created/updated by the bindsnet plotting helpers.
inpt_axes = None
inpt_ims = None
spike_axes = None
spike_ims = None
weights_im = None
weights_im2 = None
voltage_ims = None
voltage_axes = None

# One spike monitor per layer; voltage monitors only for non-input layers
# (the Input layer has no membrane potential "v" to record).
spikes = {}
voltages = {}
for name, layer in model.layers.items():
    spikes[name] = Monitor(layer, ["s"], time=time, device=device)
    model.add_monitor(spikes[name], name="%s_spikes" % name)

    # `isinstance` replaces the original `type(...) != Input` comparison:
    # it is the idiomatic type check and also covers Input subclasses.
    if not isinstance(layer, Input):
        voltages[name] = Monitor(layer, ["v"], time=time, device=device)
        model.add_monitor(voltages[name], name="%s_voltages" % name)
### Running model on MNIST ###

# Dataloader yielding one encoded sample per step, in random order.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

n_iters = examples
# Connection tuning: run the reservoir over training images with learning
# enabled so the recurrent PostPre rule can adapt the weights.
pbar = tqdm(enumerate(dataloader))
model.train(True)
for i, dataPoint in pbar:
    # BUG FIX: `>=` stops after exactly n_iters samples; the original
    # `i > n_iters` processed one extra image.
    if i >= n_iters:
        break

    # Extract & resize the MNIST sample's image data for training:
    # int(time / dt) -> length of spike train, 28 x 28 -> size of sample,
    # giving a (time, batch, channel, height, width) view.
    datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
    label = dataPoint["label"]
    pbar.set_description_str("Train progress: (%d / %d)" % (i, n_iters))

    # Run network on sample image
    model.run(inputs={"X": datum}, time=time, input_time_dim=1, reward=1.0)

    # Plot spiking activity using monitors
    if plot:
        spike_ims, spike_axes = plot_spikes(
            {layer: spikes[layer].get("s").view(time, -1) for layer in spikes},
            axes=spike_axes,
            ims=spike_ims,
        )
        voltage_ims, voltage_axes = plot_voltages(
            {layer: voltages[layer].get("v").view(time, -1) for layer in voltages},
            ims=voltage_ims,
            axes=voltage_axes,
        )
        plt.pause(1e-8)
    # Clear spikes/voltages/state between samples.
    model.reset_state_variables()
# Run the model (learning disabled) over the data to collect reservoir
# outputs that will train the readout detector.
training_pairs = []
pbar = tqdm(enumerate(dataloader))
model.train(False)
for i, dataPoint in pbar:
    # BUG FIX: `>=` stops after exactly n_iters samples; the original
    # `i > n_iters` processed one extra image.
    if i >= n_iters:
        break

    # (time, batch, channel, height, width) spike-train view of the sample:
    # int(time / dt) -> length of spike train, 28 x 28 -> size of sample.
    datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
    label = dataPoint["label"]
    pbar.set_description_str("Data extraction progress: (%d / %d)" % (i, n_iters))

    # Run network on sample image
    model.run(inputs={"X": datum}, time=time, input_time_dim=1, reward=1.0)
    # Pair the reservoir's per-neuron summed output spikes with the label.
    training_pairs.append([spikes["Y"].get("s").sum(0), label])

    # Plot spiking activity using monitors
    if plot:
        spike_ims, spike_axes = plot_spikes(
            {layer: spikes[layer].get("s").view(time, -1) for layer in spikes},
            axes=spike_axes,
            ims=spike_ims,
        )
        voltage_ims, voltage_axes = plot_voltages(
            {layer: voltages[layer].get("v").view(time, -1) for layer in voltages},
            ims=voltage_ims,
            axes=voltage_axes,
        )
        plt.pause(1e-8)
    model.reset_state_variables()
# TODO: Delete this portion for fully delay/prob-conn dependent learning
### Classification ###


class NN(nn.Module):
    """Single-layer logistic-regression readout.

    Takes the reservoir's output (summed spike counts) as input and produces
    one sigmoid activation per class; it is trained to classify the images.
    """

    def __init__(self, input_size, num_classes):
        # Modern zero-argument super() (Python 3); same behavior as the
        # legacy super(NN, self) form.
        super().__init__()
        self.linear_1 = nn.Linear(input_size, num_classes)

    def forward(self, x):
        # Flatten whatever shape the reservoir output arrives in, cast to
        # float for the linear layer, and squash class scores into (0, 1).
        return torch.sigmoid(self.linear_1(x.float().view(-1)))
# Create and train logistic regression model on reservoir outputs.
learning_model = NN(n_neurons, 10).to(device)
optimizer = torch.optim.SGD(learning_model.parameters(), lr=1e-4, momentum=0.9)
criterion = nn.MSELoss(reduction="sum")
# Training the Model
print("\n Training the read out")
# The original wrapped enumerate(range(n_epochs)) in tqdm and discarded the
# second element; iterating the range directly is equivalent and clearer.
pbar = tqdm(range(n_epochs))
for epoch in pbar:
    avg_loss = 0

    # Extract spike outputs from reservoir for each training sample:
    # s -> reservoir output spikes, l -> image label.
    for s, l in training_pairs:
        # Reset gradients to 0
        optimizer.zero_grad()

        # Run spikes through logistic regression model
        outputs = learning_model(s)

        # One-hot target, then sum-of-squares error against the sigmoid output.
        label = torch.zeros(1, 1, 10).float().to(device)
        label[0, 0, l] = 1.0
        loss = criterion(outputs.view(1, 1, -1), label)
        # .item() detaches the scalar to a Python float (`.data` is legacy
        # autograd API and keeps a tensor alive needlessly).
        avg_loss += loss.item()

        # Optimize parameters
        loss.backward()
        optimizer.step()

    pbar.set_description_str(
        "Epoch: %d/%d, Loss: %.4f"
        % (epoch + 1, n_epochs, avg_loss / len(training_pairs))
    )
# Run same simulation on reservoir with testing data instead of training data
# (see training section for intuition).
# NOTE(review): this reuses the same dataloader built on the MNIST training
# split, so "test" accuracy is measured on training images — confirm whether
# a held-out split was intended.
n_iters = examples
test_pairs = []
pbar = tqdm(enumerate(dataloader))
for i, dataPoint in pbar:
    # BUG FIX: `>=` stops after exactly n_iters samples; the original
    # `i > n_iters` processed one extra image.
    if i >= n_iters:
        break
    # (time, batch, channel, height, width) spike-train view of the sample.
    datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
    label = dataPoint["label"]
    pbar.set_description_str("Testing progress: (%d / %d)" % (i, n_iters))

    # No reward here: the reservoir runs inference-only on test samples.
    model.run(inputs={"X": datum}, time=time, input_time_dim=1)
    test_pairs.append([spikes["Y"].get("s").sum(0), label])

    if plot:
        spike_ims, spike_axes = plot_spikes(
            {layer: spikes[layer].get("s").view(time, -1) for layer in spikes},
            axes=spike_axes,
            ims=spike_ims,
        )
        voltage_ims, voltage_axes = plot_voltages(
            {layer: voltages[layer].get("v").view(time, -1) for layer in voltages},
            ims=voltage_ims,
            axes=voltage_axes,
        )
        plt.pause(1e-8)
    model.reset_state_variables()
# Test learning model with previously trained logistic regression classifier
total = 0
correct = 0
for spike_sum, label in test_pairs:
    scores = learning_model(spike_sum)
    # argmax over the 10 class activations.
    predicted = torch.max(scores.data.unsqueeze(0), 1)[1]
    total += 1
    if bool(predicted == label.long().to(device)):
        correct += 1

print(
    "\n Accuracy of the model on %d test images: %.2f %%"
    % (n_iters, 100 * correct / total)
)

0 commit comments

Comments
 (0)