|
147 | 147 | spikes = {} |
148 | 148 | for layer in set(network.layers): |
149 | 149 | spikes[layer] = Monitor( |
150 | | - network.layers[layer], state_vars=["s"], time=int(time / dt), device=device |
| 150 | + network.layers[layer], state_vars=["s"], time=int(time / dt), device=device, sparse=True |
151 | 151 | ) |
152 | 152 | network.add_monitor(spikes[layer], name="%s_spikes" % layer) |
153 | 153 |
|
|
165 | 165 | perf_ax = None |
166 | 166 | voltage_axes, voltage_ims = None, None |
167 | 167 |
|
168 | | -spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device) |
| 168 | +spike_record = [torch.zeros((batch_size, int(time / dt), n_neurons), device=device).to_sparse() for _ in range(update_interval // batch_size)] |
| 169 | +spike_record_idx = 0 |
169 | 170 |
|
170 | 171 | # Train the network. |
171 | 172 | print("\nBegin training...") |
|
197 | 198 | # Convert the array of labels into a tensor |
198 | 199 | label_tensor = torch.tensor(labels, device=device) |
199 | 200 |
|
| 201 | + spike_record_tensor = torch.cat(spike_record, dim=0) |
200 | 202 | # Get network predictions. |
201 | 203 | all_activity_pred = all_activity( |
202 | | - spikes=spike_record, assignments=assignments, n_labels=n_classes |
| 204 | + spikes=spike_record_tensor, assignments=assignments, n_labels=n_classes |
203 | 205 | ) |
204 | 206 | proportion_pred = proportion_weighting( |
205 | | - spikes=spike_record, |
| 207 | + spikes=spike_record_tensor, |
206 | 208 | assignments=assignments, |
207 | 209 | proportions=proportions, |
208 | 210 | n_labels=n_classes, |
|
240 | 242 |
|
241 | 243 | # Assign labels to excitatory layer neurons. |
242 | 244 | assignments, proportions, rates = assign_labels( |
243 | | - spikes=spike_record, |
| 245 | + spikes=spike_record_tensor, |
244 | 246 | labels=label_tensor, |
245 | 247 | n_labels=n_classes, |
246 | 248 | rates=rates, |
|
261 | 263 |
|
262 | 264 | # Add to spikes recording. |
263 | 265 | s = spikes["Ae"].get("s").permute((1, 0, 2)) |
264 | | - spike_record[ |
265 | | - (step * batch_size) |
266 | | - % update_interval : (step * batch_size % update_interval) |
267 | | - + s.size(0) |
268 | | - ] = s |
| 266 | + spike_record[spike_record_idx] = s |
| 267 | + spike_record_idx += 1 |
| 268 | + if spike_record_idx == len(spike_record): |
| 269 | + spike_record_idx = 0 |
269 | 270 |
|
270 | 271 | # Get voltage recording. |
271 | 272 | exc_voltages = exc_voltage_monitor.get("v") |
|
0 commit comments