Skip to content

Commit 72b08f5

Browse files
authored
Merge pull request #1274 from HubSpot/chained-filters-optimization
Chained filters optimization
2 parents 0b610c9 + eb83265 commit 72b08f5

8 files changed

Lines changed: 1110 additions & 24 deletions

File tree

src/main/java/com/hubspot/jinjava/JinjavaConfig.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ public class JinjavaConfig {
8888
private final ExecutionMode executionMode;
8989
private final LegacyOverrides legacyOverrides;
9090
private final boolean enablePreciseDivideFilter;
91+
private final boolean enableFilterChainOptimization;
9192
private final ObjectMapper objectMapper;
9293

9394
private final Features features;
@@ -151,6 +152,7 @@ private JinjavaConfig(Builder builder) {
151152
legacyOverrides = builder.legacyOverrides;
152153
dateTimeProvider = builder.dateTimeProvider;
153154
enablePreciseDivideFilter = builder.enablePreciseDivideFilter;
155+
enableFilterChainOptimization = builder.enableFilterChainOptimization;
154156
objectMapper = setupObjectMapper(builder.objectMapper);
155157
objectUnwrapper = builder.objectUnwrapper;
156158
processors = builder.processors;
@@ -307,6 +309,10 @@ public boolean getEnablePreciseDivideFilter() {
307309
return enablePreciseDivideFilter;
308310
}
309311

312+
public boolean isEnableFilterChainOptimization() {
313+
return enableFilterChainOptimization;
314+
}
315+
310316
public DateTimeProvider getDateTimeProvider() {
311317
return dateTimeProvider;
312318
}
@@ -349,6 +355,7 @@ public static class Builder {
349355
private ExecutionMode executionMode = DefaultExecutionMode.instance();
350356
private LegacyOverrides legacyOverrides = LegacyOverrides.NONE;
351357
private boolean enablePreciseDivideFilter = false;
358+
private boolean enableFilterChainOptimization = false;
352359
private ObjectMapper objectMapper = null;
353360

354361
private ObjectUnwrapper objectUnwrapper = new JinjavaObjectUnwrapper();
@@ -520,6 +527,13 @@ public Builder withEnablePreciseDivideFilter(boolean enablePreciseDivideFilter)
520527
return this;
521528
}
522529

530+
public Builder withEnableFilterChainOptimization(
531+
boolean enableFilterChainOptimization
532+
) {
533+
this.enableFilterChainOptimization = enableFilterChainOptimization;
534+
return this;
535+
}
536+
523537
public Builder withObjectMapper(ObjectMapper objectMapper) {
524538
this.objectMapper = objectMapper;
525539
return this;
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
package com.hubspot.jinjava.el.ext;
2+
3+
import com.hubspot.jinjava.interpret.DisabledException;
4+
import com.hubspot.jinjava.interpret.JinjavaInterpreter;
5+
import com.hubspot.jinjava.interpret.TemplateError;
6+
import com.hubspot.jinjava.interpret.TemplateError.ErrorItem;
7+
import com.hubspot.jinjava.interpret.TemplateError.ErrorReason;
8+
import com.hubspot.jinjava.interpret.TemplateError.ErrorType;
9+
import com.hubspot.jinjava.lib.filter.Filter;
10+
import com.hubspot.jinjava.objects.SafeString;
11+
import de.odysseus.el.tree.Bindings;
12+
import de.odysseus.el.tree.impl.ast.AstNode;
13+
import de.odysseus.el.tree.impl.ast.AstParameters;
14+
import de.odysseus.el.tree.impl.ast.AstRightValue;
15+
import java.util.ArrayList;
16+
import java.util.LinkedHashMap;
17+
import java.util.List;
18+
import java.util.Map;
19+
import java.util.Objects;
20+
import javax.el.ELContext;
21+
import javax.el.ELException;
22+
23+
/**
24+
* AST node for a chain of filters applied to an input expression.
25+
* Instead of creating nested AstMethod calls for each filter in a chain like:
26+
* filter:length.filter(filter:lower.filter(filter:trim.filter(input)))
27+
*
28+
* This node represents the entire chain as a single evaluation unit:
29+
* input|trim|lower|length
30+
*
31+
* This optimization reduces:
32+
* - Filter lookups (done once per filter instead of per AST node traversal)
33+
* - Method invocation overhead
34+
* - Object wrapping/unwrapping between filters
35+
* - Context operations
36+
*/
37+
public class AstFilterChain extends AstRightValue {
38+
39+
protected final AstNode input;
40+
protected final List<FilterSpec> filterSpecs;
41+
42+
public AstFilterChain(AstNode input, List<FilterSpec> filterSpecs) {
43+
this.input = Objects.requireNonNull(input, "Input node cannot be null");
44+
this.filterSpecs = Objects.requireNonNull(filterSpecs, "Filter specs cannot be null");
45+
if (filterSpecs.isEmpty()) {
46+
throw new IllegalArgumentException("Filter chain must have at least one filter");
47+
}
48+
}
49+
50+
public AstNode getInput() {
51+
return input;
52+
}
53+
54+
public List<FilterSpec> getFilterSpecs() {
55+
return filterSpecs;
56+
}
57+
58+
@Override
59+
public Object eval(Bindings bindings, ELContext context) {
60+
JinjavaInterpreter interpreter = getInterpreter(context);
61+
62+
if (interpreter.getContext().isValidationMode()) {
63+
return "";
64+
}
65+
66+
Object value = input.eval(bindings, context);
67+
68+
for (FilterSpec spec : filterSpecs) {
69+
String filterKey = ExtendedParser.FILTER_PREFIX + spec.getName();
70+
interpreter.getContext().addResolvedValue(filterKey);
71+
72+
Filter filter;
73+
try {
74+
filter = interpreter.getContext().getFilter(spec.getName());
75+
} catch (DisabledException e) {
76+
interpreter.addError(
77+
new TemplateError(
78+
ErrorType.FATAL,
79+
ErrorReason.DISABLED,
80+
ErrorItem.FILTER,
81+
e.getMessage(),
82+
spec.getName(),
83+
interpreter.getLineNumber(),
84+
-1,
85+
e
86+
)
87+
);
88+
return null;
89+
}
90+
if (filter == null) {
91+
return null;
92+
}
93+
94+
Object[] args = evaluateFilterArgs(spec, bindings, context);
95+
Map<String, Object> kwargs = extractNamedParams(args);
96+
Object[] positionalArgs = extractPositionalArgs(args);
97+
98+
boolean wasSafeString = value instanceof SafeString;
99+
if (wasSafeString) {
100+
value = value.toString();
101+
}
102+
103+
try {
104+
value = filter.filter(value, interpreter, positionalArgs, kwargs);
105+
} catch (ELException e) {
106+
throw e;
107+
} catch (RuntimeException e) {
108+
throw new ELException(
109+
String.format("Error in filter '%s': %s", spec.getName(), e.getMessage()),
110+
e
111+
);
112+
}
113+
114+
if (wasSafeString && filter.preserveSafeString() && value instanceof String) {
115+
value = new SafeString((String) value);
116+
}
117+
}
118+
119+
return value;
120+
}
121+
122+
protected JinjavaInterpreter getInterpreter(ELContext context) {
123+
return (JinjavaInterpreter) context
124+
.getELResolver()
125+
.getValue(context, null, ExtendedParser.INTERPRETER);
126+
}
127+
128+
protected Object[] evaluateFilterArgs(
129+
FilterSpec spec,
130+
Bindings bindings,
131+
ELContext context
132+
) {
133+
AstParameters params = spec.getParams();
134+
if (params == null || params.getCardinality() == 0) {
135+
return new Object[0];
136+
}
137+
138+
Object[] args = new Object[params.getCardinality()];
139+
for (int i = 0; i < params.getCardinality(); i++) {
140+
args[i] = params.getChild(i).eval(bindings, context);
141+
}
142+
return args;
143+
}
144+
145+
private Map<String, Object> extractNamedParams(Object[] args) {
146+
Map<String, Object> kwargs = new LinkedHashMap<>();
147+
for (Object arg : args) {
148+
if (arg instanceof NamedParameter) {
149+
NamedParameter namedParam = (NamedParameter) arg;
150+
kwargs.put(namedParam.getName(), namedParam.getValue());
151+
}
152+
}
153+
return kwargs;
154+
}
155+
156+
private Object[] extractPositionalArgs(Object[] args) {
157+
List<Object> positional = new ArrayList<>();
158+
for (Object arg : args) {
159+
if (!(arg instanceof NamedParameter)) {
160+
positional.add(arg);
161+
}
162+
}
163+
return positional.toArray();
164+
}
165+
166+
@Override
167+
public void appendStructure(StringBuilder builder, Bindings bindings) {
168+
input.appendStructure(builder, bindings);
169+
for (FilterSpec spec : filterSpecs) {
170+
builder.append('|').append(spec.getName());
171+
AstParameters params = spec.getParams();
172+
if (params != null && params.getCardinality() > 0) {
173+
params.appendStructure(builder, bindings);
174+
}
175+
}
176+
}
177+
178+
@Override
179+
public String toString() {
180+
StringBuilder sb = new StringBuilder();
181+
sb.append(input.toString());
182+
for (FilterSpec spec : filterSpecs) {
183+
sb.append('|').append(spec.toString());
184+
}
185+
return sb.toString();
186+
}
187+
188+
@Override
189+
public int getCardinality() {
190+
return 1 + filterSpecs.size();
191+
}
192+
193+
@Override
194+
public AstNode getChild(int i) {
195+
if (i == 0) {
196+
return input;
197+
}
198+
int filterIndex = i - 1;
199+
if (filterIndex < filterSpecs.size()) {
200+
FilterSpec spec = filterSpecs.get(filterIndex);
201+
return spec.getParams();
202+
}
203+
return null;
204+
}
205+
}

src/main/java/com/hubspot/jinjava/el/ext/ExtendedParser.java

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -531,30 +531,11 @@ protected AstNode value() throws ScanException, ParseException {
531531

532532
private AstNode parseOperators(AstNode left) throws ScanException, ParseException {
533533
if ("|".equals(getToken().getImage()) && lookahead(0).getSymbol() == IDENTIFIER) {
534-
AstNode v = left;
535-
536-
do {
537-
consumeToken(); // '|'
538-
String filterName = consumeToken().getImage();
539-
List<AstNode> filterParams = Lists.newArrayList(v, interpreter());
540-
541-
// optional filter args
542-
if (getToken().getSymbol() == Symbol.LPAREN) {
543-
AstParameters astParameters = params();
544-
for (int i = 0; i < astParameters.getCardinality(); i++) {
545-
filterParams.add(astParameters.getChild(i));
546-
}
547-
}
548-
549-
AstProperty filterProperty = createAstDot(
550-
identifier(FILTER_PREFIX + filterName),
551-
"filter",
552-
true
553-
);
554-
v = createAstMethod(filterProperty, createAstParameters(filterParams)); // function("filter:" + filterName, new AstParameters(filterParams));
555-
} while ("|".equals(getToken().getImage()));
556-
557-
return v;
534+
if (shouldUseFilterChainOptimization()) {
535+
return parseFiltersAsChain(left);
536+
} else {
537+
return parseFiltersAsNestedMethods(left);
538+
}
558539
} else if (
559540
"is".equals(getToken().getImage()) &&
560541
"not".equals(lookahead(0).getImage()) &&
@@ -577,6 +558,68 @@ protected AstParameters createAstParameters(List<AstNode> nodes) {
577558
return new AstParameters(nodes);
578559
}
579560

561+
protected AstFilterChain createAstFilterChain(
562+
AstNode input,
563+
List<FilterSpec> filterSpecs
564+
) {
565+
return new AstFilterChain(input, filterSpecs);
566+
}
567+
568+
private AstNode parseFiltersAsChain(AstNode left) throws ScanException, ParseException {
569+
List<FilterSpec> filterSpecs = new ArrayList<>();
570+
571+
do {
572+
consumeToken(); // '|'
573+
String filterName = consumeToken().getImage();
574+
AstParameters filterParams = null;
575+
576+
// optional filter args
577+
if (getToken().getSymbol() == Symbol.LPAREN) {
578+
filterParams = params();
579+
}
580+
581+
filterSpecs.add(new FilterSpec(filterName, filterParams));
582+
} while ("|".equals(getToken().getImage()));
583+
584+
return createAstFilterChain(left, filterSpecs);
585+
}
586+
587+
protected AstNode parseFiltersAsNestedMethods(AstNode left)
588+
throws ScanException, ParseException {
589+
AstNode v = left;
590+
591+
do {
592+
consumeToken(); // '|'
593+
String filterName = consumeToken().getImage();
594+
List<AstNode> filterParams = Lists.newArrayList(v, interpreter());
595+
596+
// optional filter args
597+
if (getToken().getSymbol() == Symbol.LPAREN) {
598+
AstParameters astParameters = params();
599+
for (int i = 0; i < astParameters.getCardinality(); i++) {
600+
filterParams.add(astParameters.getChild(i));
601+
}
602+
}
603+
604+
AstProperty filterProperty = createAstDot(
605+
identifier(FILTER_PREFIX + filterName),
606+
"filter",
607+
true
608+
);
609+
v = createAstMethod(filterProperty, createAstParameters(filterParams));
610+
} while ("|".equals(getToken().getImage()));
611+
612+
return v;
613+
}
614+
615+
protected boolean shouldUseFilterChainOptimization() {
616+
return JinjavaInterpreter
617+
.getCurrentMaybe()
618+
.map(JinjavaInterpreter::getConfig)
619+
.map(JinjavaConfig::isEnableFilterChainOptimization)
620+
.orElse(false);
621+
}
622+
580623
private boolean isPossibleExpTest(Symbol symbol) {
581624
return VALID_SYMBOLS_FOR_EXP_TEST.contains(symbol);
582625
}

0 commit comments

Comments
 (0)