-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_abc.py
More file actions
101 lines (82 loc) · 3.56 KB
/
_abc.py
File metadata and controls
101 lines (82 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Proof of concept code for MEP 30: Automatic subplot management."""
import itertools as it
from abc import ABCMeta, abstractmethod
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
class GridStrategy(metaclass=ABCMeta):
    """
    Abstract base class that turns a subplot count into a grid of subplot specs.

    Subclasses implement :meth:`get_grid_arrangement`, which maps a number of
    subplots ``n`` to a grid arrangement: a sequence with one entry per row,
    each entry giving the number of subplots in that row. This class converts
    such an arrangement into matplotlib ``SubplotSpec`` objects.

    Parameters
    ----------
    alignment : str, optional
        How rows narrower than the widest row are positioned:

        - ``"left"``: packed against the left edge;
        - ``"right"``: packed against the right edge;
        - ``"justified"``: every row is stretched to span the full width;
        - anything else (including the default ``"center"``): centered.
    """

    def __init__(self, alignment="center"):
        # Not validated: unrecognized values silently behave like "center".
        self.alignment = alignment

    def get_grid(self, n, figure=None):
        """
        Return subplot specs for ``n`` subplots laid out by this strategy.

        The arrangement comes from :meth:`get_grid_arrangement`: one entry per
        row, each entry giving the number of columns (subplots) in that row.
        For example ``(2, 3, 2)`` with centered alignment gives the shape::

             x x
            x x x
             x x

        where each ``x`` is a subplot.

        Parameters
        ----------
        n : int
            Number of subplots.
        figure : matplotlib.figure.Figure, optional
            Figure to build the grid on. If None, a new figure is created
            with ``plt.figure(constrained_layout=True)``.

        Returns
        -------
        list of matplotlib.gridspec.SubplotSpec
            One spec per subplot, ordered row by row and left to right within
            a row (e.g. pass each to ``figure.add_subplot``).

        Raises
        ------
        ValueError
            If the computed arrangement is empty.
        """
        grid_arrangement = self.get_grid_arrangement(n)
        return self.get_gridspec(grid_arrangement, figure)

    @classmethod
    @abstractmethod
    def get_grid_arrangement(cls, n):  # pragma: nocover
        """
        Return the grid arrangement for ``n`` subplots.

        Must be overridden by subclasses. The result is a sequence of
        positive ints, one per row, giving the number of subplots in each
        row. Presumably the entries should sum to ``n``; this is not checked
        here.
        """

    def get_gridspec(self, grid_arrangement, figure=None):
        """
        Build subplot specs from an explicit grid arrangement.

        Parameters
        ----------
        grid_arrangement : iterable of int
            Number of subplots in each row, top to bottom. Any iterable is
            accepted; it is materialized into a tuple first.
        figure : matplotlib.figure.Figure, optional
            Figure to build the grid on; a new constrained-layout figure is
            created if None.

        Returns
        -------
        list of matplotlib.gridspec.SubplotSpec
            One spec per subplot, row by row, left to right.

        Raises
        ------
        ValueError
            If ``grid_arrangement`` contains no rows.
        """
        # Materialize once so generators work and len()/max() see the same data.
        grid_arrangement = tuple(grid_arrangement)
        if not grid_arrangement:
            raise ValueError("grid_arrangement must contain at least one row")
        nrows = len(grid_arrangement)
        if self.alignment == "justified":
            # Every row spans the full width, so the widest row is irrelevant.
            return self._justified(nrows, grid_arrangement, figure)
        ncols = max(grid_arrangement)
        return self._ragged(nrows, ncols, grid_arrangement, figure)

    def _justified(self, nrows, grid_arrangement, figure=None):
        """
        Lay out the rows so that every row spans the full figure width.

        The grid is split into ``lcm(grid_arrangement)`` narrow columns, so
        each row's subplots can get an equal, whole number of them. For
        example, ``(2, 3)`` uses 6 narrow columns: each subplot in the first
        row spans 3, and each subplot in the second row spans 2.
        """
        # int() turns the NumPy scalar produced by the reduction into a plain int.
        num_small_cols = int(np.lcm.reduce(grid_arrangement))
        if figure is None:
            figure = plt.figure(constrained_layout=True)
        gs = gridspec.GridSpec(nrows, num_small_cols, figure=figure)
        ax_specs = []
        for row, row_cols in enumerate(grid_arrangement):
            span = num_small_cols // row_cols
            for col in range(row_cols):
                start = col * span
                ax_specs.append(gs[row, start:start + span])
        return ax_specs

    def _ragged(self, nrows, ncols, grid_arrangement, figure=None):
        """
        Lay out rows of unequal length with left, right or centered alignment.

        If any row is narrower than ``ncols``, each subplot spans two grid
        columns, so a short row can be shifted by half a subplot to center
        it. Otherwise each subplot spans a single grid column.
        """
        # Half-subplot offsets are only needed when some row is short.
        has_short_row = any(row_cols != ncols for row_cols in grid_arrangement)
        col_width = 2 if has_short_row else 1
        if figure is None:
            figure = plt.figure(constrained_layout=True)
        gs = gridspec.GridSpec(nrows, ncols * col_width, figure=figure)
        ax_specs = []
        for row, row_cols in enumerate(grid_arrangement):
            # Grid columns left empty in this row (0 for full-width rows).
            missing = (ncols - row_cols) * col_width
            if self.alignment == "left":
                skip = 0
            elif self.alignment == "right":
                skip = missing
            else:
                # Centered: the default, also used for unrecognized values.
                # When col_width == 2, missing is even, so this is exact.
                skip = missing // 2
            for col in range(row_cols):
                start = skip + col * col_width
                ax_specs.append(gs[row, start:start + col_width])
        return ax_specs