Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 87 additions & 47 deletions plots/heatmap-cohort-retention/implementations/python/plotnine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
""" pyplots.ai
"""anyplot.ai
heatmap-cohort-retention: Cohort Retention Heatmap
Library: plotnine 0.15.3 | Python 3.14.3
Quality: 90/100 | Created: 2026-03-16
Library: plotnine 0.15.7 | Python 3.13.14
"""

# Remove script dir from sys.path first to prevent self-shadowing the 'plotnine' library import
import sys as _sys


_sys.path = [p for p in _sys.path if p != __file__.rsplit("/", 1)[0]]
del _sys

import os

import numpy as np
import pandas as pd
from plotnine import (
Expand All @@ -17,15 +25,23 @@
ggplot,
labs,
scale_color_identity,
scale_fill_gradientn,
scale_fill_gradient,
scale_x_continuous,
scale_y_discrete,
theme,
theme_minimal,
)


# Data
# Theme-adaptive chrome — Imprint palette
THEME = os.getenv("ANYPLOT_THEME", "light")
PAGE_BG = "#FAF8F1" if THEME == "light" else "#1A1A17"
ELEVATED_BG = "#FFFDF6" if THEME == "light" else "#242420"
INK = "#1A1A17" if THEME == "light" else "#F0EFE8"
INK_SOFT = "#4A4A44" if THEME == "light" else "#B8B7B0"
INK_MUTED = "#6B6A63" if THEME == "light" else "#A8A79F"

# Data — SaaS monthly cohort retention, Jan–Oct 2024
np.random.seed(42)
cohorts = [
"Jan 2024",
Expand All @@ -49,79 +65,103 @@
if period == 0:
retention = 100.0
else:
base_decay = 100 * np.exp(-0.25 * period)
noise = np.random.uniform(-3, 3)
trend_bonus = i * 1.5
base_decay = 100 * np.exp(-0.28 * period)
noise = np.random.uniform(-2, 2)
# Stronger trend: newer cohorts retain better (product maturity effect)
trend_bonus = i * 2.8
retention = np.clip(base_decay + noise + trend_bonus, 5, 100)
rows.append(
{"cohort": cohort, "period": period, "retention_rate": round(retention, 1), "cohort_size": cohort_sizes[i]}
)

df = pd.DataFrame(rows)

# Create y-axis labels with cohort size
# y-axis: cohort labels with cohort size; reversed so Jan 2024 appears at top
df["cohort_label"] = df.apply(lambda r: f"{r['cohort']} (n={r['cohort_size']:,})", axis=1)

# Preserve ordering
cohort_labels = [f"{c} (n={s:,})" for c, s in zip(cohorts, cohort_sizes, strict=True)]
df["cohort_label"] = pd.Categorical(df["cohort_label"], categories=cohort_labels[::-1], ordered=True)

# Text color: white on dark cells (viridis dark end), dark on light cells
df["text_color"] = df["retention_rate"].apply(lambda v: "#ffffff" if v < 60 else "#1a1a2e")

# Format retention text
# Cell text: dark ink on bright green tiles (high retention), light on dark blue (low)
# Imprint seq: low → #4467A3 (blue), high → #009E73 (green)
df["text_color"] = df["retention_rate"].apply(lambda v: "#1A1A17" if v >= 60 else "#FFFDF6")
df["label"] = df["retention_rate"].apply(lambda v: f"{v:.0f}%")

# Compare earliest vs latest cohort at same period for storytelling
# Storytelling: M4 retention improvement Jan→Jun 2024 (both cohorts have data at period 4)
compare_period = 4
earliest = df[(df["cohort"] == "Jan 2024") & (df["period"] == compare_period)]["retention_rate"].values[0]
latest = df[(df["cohort"] == "Jun 2024") & (df["period"] == compare_period)]["retention_rate"].values[0]
improvement = latest - earliest
jan_m4 = df[(df["cohort"] == "Jan 2024") & (df["period"] == compare_period)]["retention_rate"].values[0]
jun_m4 = df[(df["cohort"] == "Jun 2024") & (df["period"] == compare_period)]["retention_rate"].values[0]
improvement = jun_m4 - jan_m4

# Perceptually uniform sequential palette (viridis-inspired: dark purple → teal → yellow)
colors = ["#440154", "#31688e", "#35b779", "#fde725"]
# Discrete y positions: reversed categories → Oct 2024 at pos 1 (bottom), Jan 2024 at pos 10 (top)
# Jun 2024 is index 5 in original → position (n_cohorts - 5) = 5 from bottom
JAN_Y = n_cohorts # 10
JUN_Y = n_cohorts - cohorts.index("Jun 2024") # 10 - 5 = 5

# Plot
# Build plot — square 2400×2400 canvas
plot = (
ggplot(df, aes(x="period", y="cohort_label", fill="retention_rate"))
+ geom_tile(color="#f8f9fa", size=0.6)
+ geom_text(aes(label="label", color="text_color"), size=13, fontweight="bold")
+ scale_fill_gradientn(colors=colors, limits=(0, 100), name="Retention %")
+ geom_tile(color=PAGE_BG, size=0.5)
+ geom_text(aes(label="label", color="text_color"), size=2.8, fontweight="bold")
+ scale_fill_gradient(low="#4467A3", high="#009E73", limits=(0, 100), name="Retention %")
+ scale_color_identity()
+ scale_x_continuous(breaks=range(n_cohorts), labels=[f"M{i}" for i in range(n_cohorts)])
+ scale_y_discrete(expand=(0.05, 0))
+ scale_x_continuous(breaks=list(range(n_cohorts)), labels=[f"M{i}" for i in range(n_cohorts)])
+ scale_y_discrete(expand=(0, 0))
# Highlight the two cells being compared — borders only (fill="none")
+ annotate(
"rect",
xmin=compare_period - 0.5,
xmax=compare_period + 0.5,
ymin=JAN_Y - 0.5,
ymax=JAN_Y + 0.5,
fill="none",
color=INK,
size=2.0,
)
+ annotate(
"rect",
xmin=compare_period - 0.5,
xmax=compare_period + 0.5,
ymin=JUN_Y - 0.5,
ymax=JUN_Y + 0.5,
fill="none",
color=INK,
size=2.0,
)
+ annotate(
"text",
x=n_cohorts - 2,
y=3,
label=f"Month {compare_period} retention improved\n+{improvement:.0f}pp from Jan→Jun 2024",
size=11,
color="#2d2d2d",
x=7.2,
y=2.3,
label=f"M{compare_period}: +{improvement:.0f}pp\nJan→Jun 2024",
size=3.8,
color=INK_MUTED,
ha="center",
fontweight="bold",
)
+ labs(
x="Months Since Signup",
y="",
title="heatmap-cohort-retention · plotnine · pyplots.ai",
subtitle="Monthly cohort retention — newer cohorts retain significantly better over time",
title="heatmap-cohort-retention · python · plotnine · anyplot.ai",
subtitle="SaaS monthly cohorts — newer signups retain significantly better month over month",
)
+ theme_minimal()
+ theme(
figure_size=(16, 9),
plot_title=element_text(size=26, ha="center", weight="bold", color="#0d1b2a"),
plot_subtitle=element_text(size=18, ha="center", color="#555555", style="italic"),
axis_title_x=element_text(size=20, color="#333333"),
axis_text_x=element_text(size=16, color="#444444"),
axis_text_y=element_text(size=16, color="#444444"),
legend_title=element_text(size=16, weight="bold"),
legend_text=element_text(size=14),
figure_size=(6, 6),
plot_title=element_text(size=12, ha="center", weight="bold", color=INK),
plot_subtitle=element_text(size=9, ha="center", color=INK_SOFT, style="italic"),
axis_title_x=element_text(size=10, color=INK),
axis_text_x=element_text(size=8, color=INK_SOFT),
axis_text_y=element_text(size=8, color=INK_SOFT),
legend_title=element_text(size=8, weight="bold", color=INK),
legend_text=element_text(size=8, color=INK_SOFT),
legend_position=(0.87, 0.58),
panel_grid_major=element_blank(),
panel_grid_minor=element_blank(),
plot_background=element_rect(fill="#fafafa", color="none"),
panel_background=element_rect(fill="#fafafa", color="none"),
panel_border=element_rect(color=INK_SOFT, size=0.8),
axis_ticks=element_blank(),
plot_background=element_rect(fill=PAGE_BG, color="none"),
panel_background=element_rect(fill=PAGE_BG, color="none"),
legend_background=element_rect(fill=ELEVATED_BG, color=INK_SOFT),
)
)

# Save
plot.save("plot.png", dpi=300, width=16, height=9)
# Save — square 2400×2400 at dpi=400 (6 in × 400 dpi = 2400 px)
plot.save(f"plot-{THEME}.png", dpi=400, width=6, height=6, units="in")
Loading