Skip to content

Commit 1b1fa93

Browse files
committed
Fix: Respect user-specified color_continuous_scale when template has autocolorscale=True
When a user explicitly provides color_continuous_scale, it should always be respected, even if the template has coloraxis_autocolorscale=True. Previously, the template's autocolorscale setting would override the user's explicit colorscale. The fix tracks whether color_continuous_scale was explicitly provided by the user (before apply_default_cascade fills it from template/defaults), and only sets autocolorscale=False when the user explicitly provided a colorscale. This preserves automatic diverging palette selection when colorscale comes from template/defaults. Changes: - plotly/express/_core.py: Track user_provided_colorscale and conditionally set autocolorscale=False only when user explicitly provides colorscale - plotly/express/_imshow.py: Same fix for imshow() function
1 parent f083977 commit 1b1fa93

2 files changed

Lines changed: 24 additions & 2 deletions

File tree

plotly/express/_core.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2486,6 +2486,12 @@ def get_groups_and_orders(args, grouper):
24862486
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
24872487
trace_patch = trace_patch or {}
24882488
layout_patch = layout_patch or {}
2489+
# Track if color_continuous_scale was explicitly provided by user
2490+
# (before apply_default_cascade fills it from template/defaults)
2491+
user_provided_colorscale = (
2492+
"color_continuous_scale" in args
2493+
and args["color_continuous_scale"] is not None
2494+
)
24892495
apply_default_cascade(args)
24902496

24912497
args = build_dataframe(args, constructor)
@@ -2704,7 +2710,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
27042710
range_color = args["range_color"] or [None, None]
27052711

27062712
colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
2707-
layout_patch["coloraxis1"] = dict(
2713+
coloraxis_dict = dict(
27082714
colorscale=colorscale_validator.validate_coerce(
27092715
args["color_continuous_scale"]
27102716
),
@@ -2715,6 +2721,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
27152721
title_text=get_decorated_label(args, args[colorvar], colorvar)
27162722
),
27172723
)
2724+
# Set autocolorscale=False if user explicitly provided colorscale. Otherwise a template
2725+
# that sets autocolorscale=True would override the user provided colorscale.
2726+
if user_provided_colorscale:
2727+
coloraxis_dict["autocolorscale"] = False
2728+
layout_patch["coloraxis1"] = coloraxis_dict
27182729
for v in ["height", "width"]:
27192730
if args[v]:
27202731
layout_patch[v] = args[v]

plotly/express/_imshow.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ def imshow(
233233
axes labels and ticks.
234234
"""
235235
args = locals()
236+
# Track if color_continuous_scale was explicitly provided by user
237+
# (before apply_default_cascade fills it from template/defaults)
238+
user_provided_colorscale = (
239+
"color_continuous_scale" in args
240+
and args["color_continuous_scale"] is not None
241+
)
236242
apply_default_cascade(args)
237243
labels = labels.copy()
238244
nslices_facet = 1
@@ -419,14 +425,19 @@ def imshow(
419425
layout["xaxis"] = dict(scaleanchor="y", constrain="domain")
420426
layout["yaxis"]["constrain"] = "domain"
421427
colorscale_validator = ColorscaleValidator("colorscale", "imshow")
422-
layout["coloraxis1"] = dict(
428+
coloraxis_dict = dict(
423429
colorscale=colorscale_validator.validate_coerce(
424430
args["color_continuous_scale"]
425431
),
426432
cmid=color_continuous_midpoint,
427433
cmin=zmin,
428434
cmax=zmax,
429435
)
436+
# Set autocolorscale=False if user explicitly provided colorscale. Otherwise a template
437+
# that sets autocolorscale=True would override the user provided colorscale.
438+
if user_provided_colorscale:
439+
coloraxis_dict["autocolorscale"] = False
440+
layout["coloraxis1"] = coloraxis_dict
430441
if labels["color"]:
431442
layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])
432443

0 commit comments

Comments
 (0)