diff --git a/README.rst b/README.rst index fc6c48ca..978ac203 100644 --- a/README.rst +++ b/README.rst @@ -63,6 +63,7 @@ Current features * Thread-safety. * **Per-call max age:** Specify a maximum age for cached values per call. +* **Cache analytics and observability:** Track cache performance metrics including hit rates, latencies, and more. Cachier is **NOT**: @@ -325,6 +326,102 @@ Cache `None` Values By default, ``cachier`` does not cache ``None`` values. You can override this behaviour by passing ``allow_none=True`` to the function call. +Cache Analytics and Observability +================================== + +Cachier provides built-in metrics collection to monitor cache performance in production environments. This feature is particularly useful for understanding cache effectiveness, identifying optimization opportunities, and debugging performance issues. + +Enabling Metrics +---------------- + +Enable metrics by setting ``enable_metrics=True`` when decorating a function: + +.. code-block:: python + + from cachier import cachier + + @cachier(backend='memory', enable_metrics=True) + def expensive_operation(x): + return x ** 2 + + # Access metrics + stats = expensive_operation.metrics.get_stats() + print(f"Hit rate: {stats.hit_rate}%") + print(f"Avg latency: {stats.avg_latency_ms}ms") + +Tracked Metrics +--------------- + +The metrics system tracks: + +* **Cache hits and misses**: Number of cache hits/misses and hit rate percentage +* **Operation latencies**: Average time for cache operations +* **Stale cache hits**: Number of times stale cache entries were accessed +* **Recalculations**: Count of cache recalculations triggered +* **Wait timeouts**: Timeouts during concurrent calculation waits +* **Size limit rejections**: Entries rejected due to ``entry_size_limit`` +* **Cache size (memory backend only)**: Number of entries and total size in bytes for the in-memory cache core + +Sampling Rate +------------- + +For high-traffic functions, you can reduce overhead 
by sampling a fraction of operations: + +.. code-block:: python + + @cachier(enable_metrics=True, metrics_sampling_rate=0.1) # Sample 10% of calls + def high_traffic_function(x): + return x * 2 + +Exporting to Prometheus +------------------------ + +Export metrics to Prometheus for monitoring and alerting: + +.. code-block:: python + + from cachier import cachier + from cachier.exporters import PrometheusExporter + + @cachier(backend='redis', enable_metrics=True) + def my_operation(x): + return x ** 2 + + # Set up Prometheus exporter + # use_prometheus_client controls whether metrics are exposed via the prometheus_client + # registry (True) or via Cachier's own HTTP handler (False). In both modes, metrics for + # registered functions are collected live at scrape time. + exporter = PrometheusExporter(port=9090, use_prometheus_client=True) + exporter.register_function(my_operation) + exporter.start() + + # Metrics available at http://localhost:9090/metrics + +The exporter provides metrics in Prometheus text format, compatible with standard Prometheus scraping, in both ``use_prometheus_client=True`` and ``use_prometheus_client=False`` modes. When ``use_prometheus_client=True``, Cachier registers a custom collector with ``prometheus_client`` that pulls live statistics from registered functions at scrape time, so scraped values reflect the current state of the cache. When ``use_prometheus_client=False``, Cachier serves the same metrics directly without requiring the ``prometheus_client`` dependency. + +Programmatic Access +------------------- + +Access metrics programmatically for custom monitoring: + +.. code-block:: python + + stats = my_function.metrics.get_stats() + + if stats.hit_rate < 70.0: + print(f"Warning: Cache hit rate is {stats.hit_rate}%") + print(f"Consider increasing cache size or adjusting stale_after") + +Reset Metrics +------------- + +Clear collected metrics: + +.. 
code-block:: python + + my_function.metrics.reset() + + Cachier Cores ============= diff --git a/examples/metrics_example.py b/examples/metrics_example.py new file mode 100644 index 00000000..64cfe0ec --- /dev/null +++ b/examples/metrics_example.py @@ -0,0 +1,231 @@ +"""Demonstration of cachier's metrics and observability features.""" + +import time +from datetime import timedelta + +from cachier import cachier + + +def demo_basic_metrics_tracking(): + """Demonstrate basic metrics tracking.""" + print("=" * 60) + print("Example 1: Basic Metrics Tracking") + print("=" * 60) + + @cachier(backend="memory", enable_metrics=True) + def expensive_operation(x): + """Simulate an expensive computation.""" + time.sleep(0.1) # Simulate work + return x**2 + + expensive_operation.clear_cache() + + # First call - cache miss + print("\nFirst call (cache miss):") + result1 = expensive_operation(5) + print(f" Result: {result1}") + + stats = expensive_operation.metrics.get_stats() + print(f" Hits: {stats.hits}, Misses: {stats.misses}") + print(f" Hit rate: {stats.hit_rate:.1f}%") + print(f" Avg latency: {stats.avg_latency_ms:.2f}ms") + + # Second call - cache hit + print("\nSecond call (cache hit):") + result2 = expensive_operation(5) + print(f" Result: {result2}") + + stats = expensive_operation.metrics.get_stats() + print(f" Hits: {stats.hits}, Misses: {stats.misses}") + print(f" Hit rate: {stats.hit_rate:.1f}%") + print(f" Avg latency: {stats.avg_latency_ms:.2f}ms") + + # Third call with different argument - cache miss + print("\nThird call with different argument (cache miss):") + result3 = expensive_operation(10) + print(f" Result: {result3}") + + stats = expensive_operation.metrics.get_stats() + print(f" Hits: {stats.hits}, Misses: {stats.misses}") + print(f" Hit rate: {stats.hit_rate:.1f}%") + print(f" Avg latency: {stats.avg_latency_ms:.2f}ms") + print(f" Total calls: {stats.total_calls}") + + +def demo_stale_cache_tracking(): + """Demonstrate stale cache tracking.""" + 
print("\n" + "=" * 60) + print("Example 2: Stale Cache Tracking") + print("=" * 60) + + @cachier( + backend="memory", + enable_metrics=True, + stale_after=timedelta(seconds=1), + next_time=False, + ) + def time_sensitive_operation(x): + """Operation with stale_after configured.""" + return x * 2 + + time_sensitive_operation.clear_cache() + + # Initial call + print("\nInitial call:") + result = time_sensitive_operation(5) + print(f" Result: {result}") + + # Call while fresh + print("\nCall while fresh (within 1 second):") + result = time_sensitive_operation(5) + print(f" Result: {result}") + + # Wait for cache to become stale + print("\nWaiting for cache to become stale...") + time.sleep(1.5) + + # Call after stale + print("Call after cache is stale:") + result = time_sensitive_operation(5) + print(f" Result: {result}") + + stats = time_sensitive_operation.metrics.get_stats() + print("\nMetrics after stale access:") + print(f" Hits: {stats.hits}") + print(f" Stale hits: {stats.stale_hits}") + print(f" Recalculations: {stats.recalculations}") + + +def demo_metrics_sampling(): + """Demonstrate metrics sampling to reduce overhead.""" + print("\n" + "=" * 60) + print("Example 3: Metrics Sampling (50% sampling rate)") + print("=" * 60) + + @cachier( + backend="memory", + enable_metrics=True, + metrics_sampling_rate=0.5, # Only sample 50% of calls + ) + def sampled_operation(x): + """Operation with reduced metrics sampling.""" + return x + 1 + + sampled_operation.clear_cache() + + # Make many calls + print("\nMaking 100 calls with 10 unique arguments...") + for i in range(100): + sampled_operation(i % 10) + + stats = sampled_operation.metrics.get_stats() + print("\nMetrics (with 50% sampling):") + print(f" Total calls recorded: {stats.total_calls}") + print(f" Hits: {stats.hits}") + print(f" Misses: {stats.misses}") + print(f" Hit rate: {stats.hit_rate:.1f}%") + print(" Note: Total calls < 100 due to sampling; hit rate is approximately representative of overall 
behavior.") + + +def demo_comprehensive_metrics(): + """Demonstrate a comprehensive metrics snapshot.""" + print("\n" + "=" * 60) + print("Example 4: Comprehensive Metrics Snapshot") + print("=" * 60) + + @cachier(backend="memory", enable_metrics=True, entry_size_limit="1KB") + def comprehensive_operation(x): + """Operation to demonstrate all metrics.""" + if x > 1000: + # Return large data to trigger size limit rejection + return "x" * 2000 + return x * 2 + + comprehensive_operation.clear_cache() + + # Generate various metric events + comprehensive_operation(5) # Miss + recalculation + comprehensive_operation(5) # Hit + comprehensive_operation(10) # Miss + recalculation + comprehensive_operation(2000) # Size limit rejection + + stats = comprehensive_operation.metrics.get_stats() + print( + f"\nComplete metrics snapshot:\n" + f" Hits: {stats.hits}\n" + f" Misses: {stats.misses}\n" + f" Hit rate: {stats.hit_rate:.1f}%\n" + f" Total calls: {stats.total_calls}\n" + f" Avg latency: {stats.avg_latency_ms:.2f}ms\n" + f" Stale hits: {stats.stale_hits}\n" + f" Recalculations: {stats.recalculations}\n" + f" Wait timeouts: {stats.wait_timeouts}\n" + f" Size limit rejections: {stats.size_limit_rejections}\n" + f" Entry count: {stats.entry_count}\n" + f" Total size (bytes): {stats.total_size_bytes}" + ) + + +def demo_programmatic_monitoring(): + """Demonstrate programmatic cache health monitoring.""" + print("\n" + "=" * 60) + print("Example 5: Programmatic Monitoring") + print("=" * 60) + + @cachier(backend="memory", enable_metrics=True) + def monitored_operation(x): + """Operation being monitored.""" + return x**3 + + monitored_operation.clear_cache() + + def check_cache_health(func, threshold=80.0): + """Check if cache hit rate meets threshold.""" + stats = func.metrics.get_stats() + if stats.total_calls == 0: + return True, "No calls yet" + + if stats.hit_rate >= threshold: + return True, f"Hit rate {stats.hit_rate:.1f}% meets threshold" + else: + return ( + False, + f"Hit 
rate {stats.hit_rate:.1f}% below threshold {threshold}%", + ) + + # Simulate some usage + print("\nSimulating cache usage...") + for i in range(20): + monitored_operation(i % 5) + + # Check health + is_healthy, message = check_cache_health(monitored_operation, threshold=70.0) + print("\nCache health check:") + print(f" Status: {'OK HEALTHY' if is_healthy else 'UNHEALTHY'}") + print(f" {message}") + + stats = monitored_operation.metrics.get_stats() + print(f" Details: {stats.hits} hits, {stats.misses} misses") + + +def main(): + """Run all metrics demonstration examples.""" + demo_basic_metrics_tracking() + demo_stale_cache_tracking() + demo_metrics_sampling() + demo_comprehensive_metrics() + demo_programmatic_monitoring() + + print("\n" + "=" * 60) + print("Examples complete!") + print("=" * 60) + print("\nKey takeaways:") + print(" - Metrics are opt-in via enable_metrics=True") + print(" - Access metrics via function.metrics.get_stats()") + print(" - Sampling reduces overhead for high-traffic functions") + print(" - Metrics are thread-safe and backend-agnostic") + print(" - Use for production monitoring and optimization") + + +if __name__ == "__main__": + main() diff --git a/examples/prometheus_exporter_example.py b/examples/prometheus_exporter_example.py new file mode 100644 index 00000000..8a4ddad1 --- /dev/null +++ b/examples/prometheus_exporter_example.py @@ -0,0 +1,126 @@ +"""Prometheus Exporter Example for Cachier. + +This example demonstrates using the PrometheusExporter to export cache metrics +to Prometheus for monitoring and alerting. + +Usage with Prometheus +--------------------- + +To use this exporter with Prometheus: + +1. Start the exporter HTTP server: + >>> exporter.start() + +2. Configure Prometheus to scrape the metrics endpoint. + Add this to your prometheus.yml: + + scrape_configs: + - job_name: 'cachier' + static_configs: + - targets: ['localhost:9090'] + +3. Access metrics at http://localhost:9090/metrics + +4. 
Create dashboards in Grafana or set up alerts based on: + - cachier_cache_hit_rate (target: > 80%) + - cachier_cache_misses_total (alert on spikes) + - cachier_avg_latency_ms (monitor performance) + +Available Metrics +----------------- +- cachier_cache_hits_total: Total number of cache hits +- cachier_cache_misses_total: Total number of cache misses +- cachier_cache_hit_rate: Cache hit rate percentage +- cachier_avg_latency_ms: Average cache operation latency +- cachier_stale_hits_total: Total stale cache hits +- cachier_recalculations_total: Total cache recalculations +- cachier_entry_count: Current number of cache entries +- cachier_cache_size_bytes: Total cache size in bytes +- cachier_size_limit_rejections_total: Entries rejected due to size limit + +""" + +import time + +from cachier import cachier +from cachier.exporters import PrometheusExporter + + +def demo_basic_metrics(): + """Demonstrate basic metrics collection.""" + print("\n=== Basic Metrics Collection ===") + + @cachier(backend="memory", enable_metrics=True) + def compute(x): + time.sleep(0.1) # Simulate work + return x * 2 + + compute.clear_cache() + + # Generate some traffic + for i in range(5): + result = compute(i) + print(f" compute({i}) = {result}") + + # Access hits create cache hits + for i in range(3): + compute(i) + + stats = compute.metrics.get_stats() + print("\nMetrics:") + print(f" Hits: {stats.hits}") + print(f" Misses: {stats.misses}") + print(f" Hit Rate: {stats.hit_rate:.1f}%") + print(f" Avg Latency: {stats.avg_latency_ms:.2f}ms") + + compute.clear_cache() + + +def demo_prometheus_export(): + """Demonstrate exporting metrics to Prometheus.""" + print("\n=== Prometheus Export ===") + + @cachier(backend="memory", enable_metrics=True) + def calculate(x, y): + return x + y + + calculate.clear_cache() + + # Create exporter + exporter = PrometheusExporter(port=9090, use_prometheus_client=False) + exporter.register_function(calculate) + + # Generate some metrics + calculate(1, 2) + 
calculate(1, 2) # hit + calculate(3, 4) # miss + + # Show text format metrics + metrics_text = exporter._generate_text_metrics() + print("\nGenerated Prometheus metrics:") + print(metrics_text[:500] + "...") + + print("\nNote: In production, call exporter.start() to serve metrics") + print(" Metrics would be available at http://localhost:9090/metrics") + + calculate.clear_cache() + + +def main(): + """Run all demonstrations.""" + print("Cachier Prometheus Exporter Demo") + print("=" * 60) + + # Print usage instructions from module docstring + if __doc__: + print(__doc__) + + demo_basic_metrics() + demo_prometheus_export() + + print("\n" + "=" * 60) + print("✓ All demonstrations completed!") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 13759f81..d9dc25c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,7 @@ exclude = [ "build", "dist", ] + # Enable Pyflakes `E` and `F` codes by default. lint.select = [ "D", # see: https://pypi.org/project/pydocstyle @@ -128,6 +129,8 @@ lint.extend-select = [ ] lint.ignore = [ "C901", + "D203", + "D213", "E203", "S301", ] @@ -156,15 +159,12 @@ lint.per-file-ignores."tests/**" = [ lint.unfixable = [ "F401", ] - # --- flake8 --- -#[tool.ruff.pydocstyle] -## Use Google-style docstrings. -#convention = "google" #[tool.ruff.pycodestyle] #ignore-overlong-task-comments = true # Unlike Flake8, default to a complexity level of 10. 
lint.mccabe.max-complexity = 10 +lint.pydocstyle.convention = "numpy" [tool.docformatter] recursive = true diff --git a/src/cachier/__init__.py b/src/cachier/__init__.py index 922ab021..755dd3eb 100644 --- a/src/cachier/__init__.py +++ b/src/cachier/__init__.py @@ -8,6 +8,7 @@ set_global_params, ) from .core import cachier +from .metrics import CacheMetrics, MetricSnapshot from .util import parse_bytes __all__ = [ @@ -19,5 +20,7 @@ "parse_bytes", "enable_caching", "disable_caching", + "CacheMetrics", + "MetricSnapshot", "__version__", ] diff --git a/src/cachier/core.py b/src/cachier/core.py index 9031e1fa..a5e5c308 100644 --- a/src/cachier/core.py +++ b/src/cachier/core.py @@ -28,6 +28,7 @@ from .cores.redis import _RedisCore from .cores.s3 import _S3Core from .cores.sql import _SQLCore +from .metrics import CacheMetrics, MetricsContext from .util import parse_bytes MAX_WORKERS_ENVAR_NAME = "CACHIER_MAX_WORKERS" @@ -47,6 +48,26 @@ def __await__(self): return self._value +async def _background_recalc_async( + core: _BaseCore, + key: Any, + func: Callable[..., Any], + args: Any, + kwds: Any, +) -> None: + """Run async recomputation in background and clear processing flag. + + This helper ensures that the cache entry's "being calculated" state is + cleared only after the background recomputation and cache update + (performed by ``_function_thread_async``) have completed. 
+ + """ + try: + await _function_thread_async(core, key, func, args, kwds) + finally: + await core.amark_entry_not_calculated(key) + + def _max_workers(): return int(os.environ.get(MAX_WORKERS_ENVAR_NAME, DEFAULT_MAX_WORKERS)) @@ -121,10 +142,7 @@ def _convert_args_kwargs(func, _is_method: bool, args: tuple, kwds: dict) -> dic param = sig.parameters[param_name] if param.kind == inspect.Parameter.VAR_POSITIONAL: var_positional_name = param_name - elif param.kind in ( - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ): + elif param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): regular_params.append(param_name) # Map positional arguments to regular parameters @@ -138,7 +156,7 @@ def _convert_args_kwargs(func, _is_method: bool, args: tuple, kwds: dict) -> dic # Map as many args as possible to regular parameters num_regular = len(params_to_use) - args_as_kw = dict(zip(params_to_use, args_to_map[:num_regular], strict=False)) + args_as_kw = {params_to_use[index]: arg for index, arg in enumerate(args_to_map[:num_regular])} # Handle variadic positional arguments # Store them with indexed keys like __varargs_0__, __varargs_1__, etc. @@ -200,6 +218,8 @@ def cachier( cleanup_stale: Optional[bool] = None, cleanup_interval: Optional[timedelta] = None, entry_size_limit: Optional[Union[int, str]] = None, + enable_metrics: bool = False, + metrics_sampling_rate: float = 1.0, ): """Wrap as a persistent, stale-free memoization decorator. @@ -287,6 +307,12 @@ def cachier( Maximum serialized size of a cached value. Values exceeding the limit are returned but not cached. Human readable strings like ``"10MB"`` are allowed. + enable_metrics: bool, optional + Enable metrics collection for this cached function. When enabled, + cache hits, misses, latencies, and other performance metrics are tracked. Defaults to False. + metrics_sampling_rate: float, optional + Sampling rate for metrics collection (0.0 to 1.0). 
Lower values + reduce overhead at the cost of accuracy. Only used when enable_metrics is True. Defaults to 1.0 (100% sampling). """ # Check for deprecated parameters @@ -298,6 +324,12 @@ def cachier( backend = _update_with_defaults(backend, "backend") mongetter = _update_with_defaults(mongetter, "mongetter") size_limit_bytes = parse_bytes(_update_with_defaults(entry_size_limit, "entry_size_limit")) + + # Create metrics object if enabled + cache_metrics = None + if enable_metrics: + cache_metrics = CacheMetrics(sampling_rate=metrics_sampling_rate) + # Override the backend parameter if a mongetter is provided. if callable(mongetter): backend = "mongo" @@ -310,6 +342,7 @@ def cachier( separate_files=separate_files, wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=size_limit_bytes, + metrics=cache_metrics, ) elif backend == "mongo": core = _MongoCore( @@ -317,12 +350,14 @@ def cachier( mongetter=mongetter, wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=size_limit_bytes, + metrics=cache_metrics, ) elif backend == "memory": core = _MemoryCore( hash_func=hash_func, wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=size_limit_bytes, + metrics=cache_metrics, ) elif backend == "sql": core = _SQLCore( @@ -330,6 +365,7 @@ def cachier( sql_engine=sql_engine, wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=size_limit_bytes, + metrics=cache_metrics, ) elif backend == "redis": core = _RedisCore( @@ -337,6 +373,7 @@ def cachier( redis_client=redis_client, wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=size_limit_bytes, + metrics=cache_metrics, ) elif backend == "s3": core = _S3Core( @@ -350,6 +387,7 @@ def cachier( s3_endpoint_url=s3_endpoint_url, s3_config=s3_config, entry_size_limit=size_limit_bytes, + metrics=cache_metrics, ) else: raise ValueError("specified an invalid core: %s" % backend) @@ -451,57 +489,86 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds): if ignore_cache or not 
_global_params.caching_enabled: return func(args[0], **kwargs) if core.func_is_method else func(**kwargs) - key, entry = core.get_entry((), kwargs) - if overwrite_cache: - return _calc_entry(core, key, func, args, kwds, _print) - if entry is None or (not entry._completed and not entry._processing): - _print("No entry found. No current calc. Calling like a boss.") - return _calc_entry(core, key, func, args, kwds, _print) - _print("Entry found.") - if _allow_none or entry.value is not None: - _print("Cached result found.") - now = datetime.now() - max_allowed_age = _stale_after - nonneg_max_age = True - if max_age is not None: - if max_age < ZERO_TIMEDELTA: - _print("max_age is negative. Cached result considered stale.") - nonneg_max_age = False - else: - assert max_age is not None # noqa: S101 - max_allowed_age = min(_stale_after, max_age) - # note: if max_age < 0, we always consider a value stale - if nonneg_max_age and (now - entry.time <= max_allowed_age): - _print("And it is fresh!") - return entry.value - _print("But it is stale... :(") - if entry._processing: + + with MetricsContext(cache_metrics) as _mctx: + key, entry = core.get_entry((), kwargs) + if overwrite_cache: + _mctx.record_miss() + _mctx.record_recalculation() + return _calc_entry(core, key, func, args, kwds, _print) + if entry is None or (not entry._completed and not entry._processing): + _print("No entry found. No current calc. Calling like a boss.") + _mctx.record_miss() + _mctx.record_recalculation() + return _calc_entry(core, key, func, args, kwds, _print) + _print("Entry found.") + if _allow_none or entry.value is not None: + _print("Cached result found.") + now = datetime.now() + max_allowed_age = _stale_after + nonneg_max_age = True + if max_age is not None: + if max_age < ZERO_TIMEDELTA: + _print("max_age is negative. 
Cached result considered stale.") + nonneg_max_age = False + else: + assert max_age is not None # noqa: S101 + max_allowed_age = min(_stale_after, max_age) + # note: if max_age < 0, we always consider a value stale + if nonneg_max_age and (now - entry.time <= max_allowed_age): + _print("And it is fresh!") + _mctx.record_hit() + return entry.value + _print("But it is stale... :(") + _mctx.record_stale_hit() + _mctx.record_miss() + if entry._processing: + if _next_time: + _print("Returning stale.") + return entry.value # return stale val + _print("Already calc. Waiting on change.") + try: + return core.wait_on_entry_calc(key) + except RecalculationNeeded: + _mctx.record_wait_timeout() + _mctx.record_recalculation() + return _calc_entry(core, key, func, args, kwds, _print) if _next_time: - _print("Returning stale.") - return entry.value # return stale val - _print("Already calc. Waiting on change.") + _print("Async calc and return stale") + _mctx.record_recalculation() + core.mark_entry_being_calculated(key) + + def _wrapped_function_thread( + core_arg: _BaseCore, + key_arg: Any, + func_arg: Callable[..., Any], + args_arg: tuple[Any, ...], + kwds_arg: dict[str, Any], + ) -> None: + """Run background recalculation and clear processing flag when done.""" + try: + _function_thread(core_arg, key_arg, func_arg, args_arg, kwds_arg) + finally: + core_arg.mark_entry_not_calculated(key_arg) + + _get_executor().submit(_wrapped_function_thread, core, key, func, args, kwds) + return entry.value + _print("Calling decorated function and waiting") + _mctx.record_recalculation() + return _calc_entry(core, key, func, args, kwds, _print) + if entry._processing: + _print("No value but being calculated. 
Waiting.") try: return core.wait_on_entry_calc(key) except RecalculationNeeded: + _mctx.record_wait_timeout() + _mctx.record_miss() + _mctx.record_recalculation() return _calc_entry(core, key, func, args, kwds, _print) - if _next_time: - _print("Async calc and return stale") - core.mark_entry_being_calculated(key) - try: - _get_executor().submit(_function_thread, core, key, func, args, kwds) - finally: - core.mark_entry_not_calculated(key) - return entry.value - _print("Calling decorated function and waiting") + _print("No entry found. No current calc. Calling like a boss.") + _mctx.record_miss() + _mctx.record_recalculation() return _calc_entry(core, key, func, args, kwds, _print) - if entry._processing: - _print("No value but being calculated. Waiting.") - try: - return core.wait_on_entry_calc(key) - except RecalculationNeeded: - return _calc_entry(core, key, func, args, kwds, _print) - _print("No entry found. No current calc. Calling like a boss.") - return _calc_entry(core, key, func, args, kwds, _print) async def _call_async(*args, max_age: Optional[timedelta] = None, **kwds): # NOTE: For async functions, wait_for_calc_timeout is not honored. @@ -538,54 +605,65 @@ async def _call_async(*args, max_age: Optional[timedelta] = None, **kwds): if ignore_cache or not _global_params.caching_enabled: return await func(args[0], **kwargs) if core.func_is_method else await func(**kwargs) - key, entry = await core.aget_entry((), kwargs) - if overwrite_cache: - result = await _calc_entry_async(core, key, func, args, kwds, _print) - return result - if entry is None or (not entry._completed and not entry._processing): + + with MetricsContext(cache_metrics) as _mctx: + key, entry = await core.aget_entry((), kwargs) + if overwrite_cache: + _mctx.record_miss() + _mctx.record_recalculation() + return await _calc_entry_async(core, key, func, args, kwds, _print) + if entry is None or (not entry._completed and not entry._processing): + _print("No entry found. No current calc. 
Calling like a boss.") + _mctx.record_miss() + _mctx.record_recalculation() + return await _calc_entry_async(core, key, func, args, kwds, _print) + _print("Entry found.") + if _allow_none or entry.value is not None: + _print("Cached result found.") + now = datetime.now() + max_allowed_age = _stale_after + nonneg_max_age = True + if max_age is not None: + if max_age < ZERO_TIMEDELTA: + _print("max_age is negative. Cached result considered stale.") + nonneg_max_age = False + else: + assert max_age is not None # noqa: S101 + max_allowed_age = min(_stale_after, max_age) + # note: if max_age < 0, we always consider a value stale + if nonneg_max_age and (now - entry.time <= max_allowed_age): + _print("And it is fresh!") + _mctx.record_hit() + return entry.value + _print("But it is stale... :(") + _mctx.record_stale_hit() + _mctx.record_miss() + if _next_time: + _print("Async calc and return stale") + _mctx.record_recalculation() + # Mark entry as being calculated; background task will + # update cache and clear the flag when done. + await core.amark_entry_being_calculated(key) + # Use asyncio.create_task for background execution, + # ensuring that the processing flag is only cleared + # after recomputation completes. + asyncio.create_task(_background_recalc_async(core, key, func, args, kwds)) + return entry.value + _print("Calling decorated function and waiting") + _mctx.record_recalculation() + return await _calc_entry_async(core, key, func, args, kwds, _print) + if entry._processing: + msg = "No value but being calculated. Recalculating" + _print(f"{msg} (async - no wait).") + # For async, don't wait - just recalculate + # This avoids blocking the event loop + _mctx.record_miss() + _mctx.record_recalculation() + return await _calc_entry_async(core, key, func, args, kwds, _print) _print("No entry found. No current calc. 
Calling like a boss.") - result = await _calc_entry_async(core, key, func, args, kwds, _print) - return result - _print("Entry found.") - if _allow_none or entry.value is not None: - _print("Cached result found.") - now = datetime.now() - max_allowed_age = _stale_after - nonneg_max_age = True - if max_age is not None: - if max_age < ZERO_TIMEDELTA: - _print("max_age is negative. Cached result considered stale.") - nonneg_max_age = False - else: - assert max_age is not None # noqa: S101 - max_allowed_age = min(_stale_after, max_age) - # note: if max_age < 0, we always consider a value stale - if nonneg_max_age and (now - entry.time <= max_allowed_age): - _print("And it is fresh!") - return entry.value - _print("But it is stale... :(") - if _next_time: - _print("Async calc and return stale") - # Mark entry as being calculated then immediately unmark - # This matches sync behavior and ensures entry exists - # Background task will update cache when complete - await core.amark_entry_being_calculated(key) - # Use asyncio.create_task for background execution - asyncio.create_task(_function_thread_async(core, key, func, args, kwds)) - await core.amark_entry_not_calculated(key) - return entry.value - _print("Calling decorated function and waiting") - result = await _calc_entry_async(core, key, func, args, kwds, _print) - return result - if entry._processing: - msg = "No value but being calculated. Recalculating" - _print(f"{msg} (async - no wait).") - # For async, don't wait - just recalculate - # This avoids blocking the event loop - result = await _calc_entry_async(core, key, func, args, kwds, _print) - return result - _print("No entry found. No current calc. 
Calling like a boss.") - return await _calc_entry_async(core, key, func, args, kwds, _print) + _mctx.record_miss() + _mctx.record_recalculation() + return await _calc_entry_async(core, key, func, args, kwds, _print) # MAINTAINER NOTE: The main function wrapper is now a standard function # that passes *args and **kwargs to _call. This ensures that user @@ -631,13 +709,17 @@ def _cache_dpath(): """Return the path to the cache dir, if exists; None if not.""" return getattr(core, "cache_dir", None) - def _precache_value(*args, value_to_cache, **kwds): # noqa: D417 + def _precache_value(*args, value_to_cache, **kwds): """Add an initial value to the cache. - Arguments: - --------- + Parameters + ---------- + *args : Any + Positional arguments used to build the cache key. value_to_cache : any - entry to be written into the cache + Entry to be written into the cache. + **kwds : Any + Keyword arguments used to build the cache key. """ # merge args expanded as kwargs and the original kwds @@ -650,6 +732,7 @@ def _precache_value(*args, value_to_cache, **kwds): # noqa: D417 func_wrapper.aclear_being_calculated = _aclear_being_calculated func_wrapper.cache_dpath = _cache_dpath func_wrapper.precache_value = _precache_value + func_wrapper.metrics = cache_metrics # Expose metrics object return func_wrapper return _cachier_decorator diff --git a/src/cachier/cores/base.py b/src/cachier/cores/base.py index 547382da..632ccd28 100644 --- a/src/cachier/cores/base.py +++ b/src/cachier/cores/base.py @@ -13,12 +13,15 @@ import sys import threading from datetime import timedelta -from typing import Any, Callable, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple from pympler import asizeof # type: ignore -from .._types import HashFunc -from ..config import CacheEntry, _update_with_defaults +from cachier._types import HashFunc +from cachier.config import CacheEntry, _update_with_defaults + +if TYPE_CHECKING: + from cachier.metrics import CacheMetrics class 
RecalculationNeeded(Exception): @@ -43,11 +46,13 @@ def __init__( hash_func: Optional[HashFunc], wait_for_calc_timeout: Optional[int], entry_size_limit: Optional[int] = None, + metrics: Optional["CacheMetrics"] = None, ): self.hash_func = _update_with_defaults(hash_func, "hash_func") self.wait_for_calc_timeout = wait_for_calc_timeout self.lock = threading.RLock() self.entry_size_limit = entry_size_limit + self.metrics = metrics def set_func(self, func): """Set the function this core will use. @@ -124,17 +129,99 @@ def _should_store(self, value: Any) -> bool: if self.entry_size_limit is None: return True try: - return self._estimate_size(value) <= self.entry_size_limit + should_store = self._estimate_size(value) <= self.entry_size_limit except Exception: return True + if not should_store and self.metrics is not None: + self.metrics.record_size_limit_rejection() + return should_store + + def _update_size_metrics(self) -> None: + """Update cache size metrics if metrics are enabled. + + Subclasses should call this after cache modifications. + + """ + if self.metrics is None: + return + from contextlib import suppress + + # Get cache size - subclasses should override if they can provide this + # Suppress errors if subclass doesn't implement size tracking + with suppress(AttributeError, NotImplementedError): + entry_count = self._get_entry_count() + total_size = self._get_total_size() + self.metrics.update_size_metrics(entry_count, total_size) + + def _get_entry_count(self) -> int: + """Get the number of entries in the cache. + + Subclasses should override this to provide accurate counts. + + Returns + ------- + int + Number of entries in cache + + """ + return 0 + + def _get_total_size(self) -> int: + """Get the total size of the cache in bytes. + + Subclasses should override this to provide accurate sizes. 
+ + Returns + ------- + int + Total size in bytes + + """ + return 0 @abc.abstractmethod - def set_entry(self, key: str, func_res: Any) -> bool: + def _set_entry(self, key: str, func_res: Any) -> bool: """Map the given result to the given key in this core's cache.""" + async def _aset_entry(self, key: str, func_res: Any) -> bool: + """Async variant of :meth:`_set_entry`; defaults to the sync version.""" + return self._set_entry(key, func_res) + + def set_entry(self, key: str, func_res: Any) -> bool: + """Store an entry in the cache. + + Parameters + ---------- + key : str + Cache key for the entry. + func_res : Any + Value to store in the cache. + + Returns + ------- + bool + True if the entry was stored successfully, False otherwise. + + """ + return self._set_entry(key, func_res) + async def aset_entry(self, key: str, func_res: Any) -> bool: - """Async-compatible variant of :meth:`set_entry`.""" - return self.set_entry(key, func_res) + """Async variant of :meth:`set_entry`. + + Parameters + ---------- + key : str + Cache key for the entry. + func_res : Any + Value to store in the cache. + + Returns + ------- + bool + True if the entry was stored successfully, False otherwise. 
+ + """ + return await self._aset_entry(key, func_res) @abc.abstractmethod def mark_entry_being_calculated(self, key: str) -> None: diff --git a/src/cachier/cores/memory.py b/src/cachier/cores/memory.py index 6d546c2f..36fddf41 100644 --- a/src/cachier/cores/memory.py +++ b/src/cachier/cores/memory.py @@ -2,12 +2,15 @@ import threading from datetime import datetime, timedelta -from typing import Any, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from .._types import HashFunc from ..config import CacheEntry from .base import _BaseCore, _get_func_str +if TYPE_CHECKING: + from ..metrics import CacheMetrics + class _MemoryCore(_BaseCore): """The memory core class for cachier.""" @@ -17,8 +20,9 @@ def __init__( hash_func: Optional[HashFunc], wait_for_calc_timeout: Optional[int], entry_size_limit: Optional[int] = None, + metrics: Optional["CacheMetrics"] = None, ): - super().__init__(hash_func, wait_for_calc_timeout, entry_size_limit) + super().__init__(hash_func, wait_for_calc_timeout, entry_size_limit, metrics) self.cache: Dict[str, CacheEntry] = {} def _hash_func_key(self, key: str) -> str: @@ -36,7 +40,7 @@ async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: """Get an entry by key.""" return self.get_entry_by_key(key) - def set_entry(self, key: str, func_res: Any) -> bool: + def _set_entry(self, key: str, func_res: Any) -> bool: if not self._should_store(func_res): return False hash_key = self._hash_func_key(key) @@ -56,11 +60,13 @@ def set_entry(self, key: str, func_res: Any) -> bool: _condition=cond, _completed=True, ) + # Update size metrics after modifying cache + self._update_size_metrics() return True - async def aset_entry(self, key: str, func_res: Any) -> bool: + async def _aset_entry(self, key: str, func_res: Any) -> bool: """Set an entry.""" - return self.set_entry(key, func_res) + return self._set_entry(key, func_res) def mark_entry_being_calculated(self, key: str) -> None: with self.lock: 
@@ -117,6 +123,8 @@ def wait_on_entry_calc(self, key: str) -> Any: def clear_cache(self) -> None: with self.lock: self.cache.clear() + # Update size metrics after clearing + self._update_size_metrics() def clear_being_calculated(self) -> None: with self.lock: @@ -131,3 +139,19 @@ def delete_stale_entries(self, stale_after: timedelta) -> None: keys_to_delete = [k for k, v in self.cache.items() if now - v.time > stale_after] for key in keys_to_delete: del self.cache[key] + # Update size metrics after deletion + if keys_to_delete: + self._update_size_metrics() + + def _get_entry_count(self) -> int: + """Get the number of entries in the memory cache.""" + with self.lock: + return len(self.cache) + + def _get_total_size(self) -> int: + """Get the total size of cached values in bytes.""" + with self.lock: + total = 0 + for entry in self.cache.values(): + total += self._estimate_size(entry.value) + return total diff --git a/src/cachier/cores/mongo.py b/src/cachier/cores/mongo.py index 5f12e7db..2ace3aaa 100644 --- a/src/cachier/cores/mongo.py +++ b/src/cachier/cores/mongo.py @@ -13,7 +13,7 @@ import warnings # to warn if pymongo is missing from contextlib import suppress from datetime import datetime, timedelta -from typing import Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional, Tuple from .._types import HashFunc, Mongetter from ..config import CacheEntry @@ -25,6 +25,9 @@ from .base import RecalculationNeeded, _BaseCore, _get_func_str +if TYPE_CHECKING: + from ..metrics import CacheMetrics + MONGO_SLEEP_DURATION_IN_SEC = 1 @@ -41,6 +44,7 @@ def __init__( mongetter: Optional[Mongetter], wait_for_calc_timeout: Optional[int], entry_size_limit: Optional[int] = None, + metrics: Optional["CacheMetrics"] = None, ): if "pymongo" not in sys.modules: warnings.warn( @@ -53,6 +57,7 @@ def __init__( hash_func=hash_func, wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=entry_size_limit, + metrics=metrics, ) if mongetter is None: raise 
MissingMongetter("must specify ``mongetter`` when using the mongo core") @@ -143,7 +148,7 @@ async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: ) return key, entry - def set_entry(self, key: str, func_res: Any) -> bool: + def _set_entry(self, key: str, func_res: Any) -> bool: if not self._should_store(func_res): return False mongo_collection = self._ensure_collection() @@ -165,7 +170,7 @@ def set_entry(self, key: str, func_res: Any) -> bool: ) return True - async def aset_entry(self, key: str, func_res: Any) -> bool: + async def _aset_entry(self, key: str, func_res: Any) -> bool: if not self._should_store(func_res): return False mongo_collection = await self._ensure_collection_async() diff --git a/src/cachier/cores/pickle.py b/src/cachier/cores/pickle.py index 52f304c2..fda3b308 100644 --- a/src/cachier/cores/pickle.py +++ b/src/cachier/cores/pickle.py @@ -14,7 +14,7 @@ import time from contextlib import suppress from datetime import datetime, timedelta -from typing import IO, Any, Dict, Optional, Tuple, Union, cast +from typing import IO, TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast import portalocker # to lock on pickle cache IO from watchdog.events import PatternMatchingEventHandler @@ -26,6 +26,9 @@ # Alternative: https://github.com/WoLpH/portalocker from .base import _BaseCore +if TYPE_CHECKING: + from ..metrics import CacheMetrics + class _PickleCore(_BaseCore): """The pickle core class for cachier.""" @@ -87,8 +90,9 @@ def __init__( separate_files: Optional[bool], wait_for_calc_timeout: Optional[int], entry_size_limit: Optional[int] = None, + metrics: Optional["CacheMetrics"] = None, ): - super().__init__(hash_func, wait_for_calc_timeout, entry_size_limit) + super().__init__(hash_func, wait_for_calc_timeout, entry_size_limit, metrics) self._cache_dict: Dict[str, CacheEntry] = {} self.reload = _update_with_defaults(pickle_reload, "pickle_reload") self.cache_dir = os.path.expanduser(_update_with_defaults(cache_dir, 
"cache_dir")) @@ -244,7 +248,7 @@ async def aget_entry(self, args: tuple[Any, ...], kwds: dict[str, Any]) -> Tuple async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: return self.get_entry_by_key(key) - def set_entry(self, key: str, func_res: Any) -> bool: + def _set_entry(self, key: str, func_res: Any) -> bool: if not self._should_store(func_res): return False key_data = CacheEntry( @@ -264,8 +268,8 @@ def set_entry(self, key: str, func_res: Any) -> bool: self._save_cache(cache) return True - async def aset_entry(self, key: str, func_res: Any) -> bool: - return self.set_entry(key, func_res) + async def _aset_entry(self, key: str, func_res: Any) -> bool: + return self._set_entry(key, func_res) def mark_entry_being_calculated_separate_files(self, key: str) -> None: self._save_cache( diff --git a/src/cachier/cores/redis.py b/src/cachier/cores/redis.py index fdaac014..df6eea00 100644 --- a/src/cachier/cores/redis.py +++ b/src/cachier/cores/redis.py @@ -4,7 +4,7 @@ import time import warnings from datetime import datetime, timedelta -from typing import Any, Awaitable, Callable, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Union try: import redis @@ -17,6 +17,9 @@ from ..config import CacheEntry from .base import RecalculationNeeded, _BaseCore, _get_func_str +if TYPE_CHECKING: + from ..metrics import CacheMetrics + REDIS_SLEEP_DURATION_IN_SEC = 1 @@ -39,6 +42,7 @@ def __init__( wait_for_calc_timeout: Optional[int] = None, key_prefix: str = "cachier", entry_size_limit: Optional[int] = None, + metrics: Optional["CacheMetrics"] = None, ): if not REDIS_AVAILABLE: warnings.warn( @@ -48,7 +52,10 @@ def __init__( ) super().__init__( - hash_func=hash_func, wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=entry_size_limit + hash_func=hash_func, + wait_for_calc_timeout=wait_for_calc_timeout, + entry_size_limit=entry_size_limit, + metrics=metrics, ) if redis_client is None: raise 
MissingRedisClient("must specify ``redis_client`` when using the redis core") @@ -214,7 +221,7 @@ async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: warnings.warn(f"Redis get_entry_by_key failed: {e}", stacklevel=2) return key, None - def set_entry(self, key: str, func_res: Any) -> bool: + def _set_entry(self, key: str, func_res: Any) -> bool: """Map the given result to the given key in Redis.""" if not self._should_store(func_res): return False @@ -242,7 +249,7 @@ def set_entry(self, key: str, func_res: Any) -> bool: warnings.warn(f"Redis set_entry failed: {e}", stacklevel=2) return False - async def aset_entry(self, key: str, func_res: Any) -> bool: + async def _aset_entry(self, key: str, func_res: Any) -> bool: """Map the given result to the given key in Redis using async operations.""" if not self._should_store(func_res): return False diff --git a/src/cachier/cores/s3.py b/src/cachier/cores/s3.py index 239612ec..3dda917e 100644 --- a/src/cachier/cores/s3.py +++ b/src/cachier/cores/s3.py @@ -6,7 +6,7 @@ import time import warnings from datetime import datetime, timedelta -from typing import Any, Callable, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple try: import boto3 # type: ignore[import-untyped] @@ -20,6 +20,9 @@ from ..config import CacheEntry from .base import RecalculationNeeded, _BaseCore, _get_func_str +if TYPE_CHECKING: + from ..metrics import CacheMetrics + S3_SLEEP_DURATION_IN_SEC = 1 @@ -62,6 +65,8 @@ class _S3Core(_BaseCore): Optional ``botocore.config.Config`` object passed when creating the client. entry_size_limit : int, optional Maximum allowed size in bytes of a cached value. + metrics : CacheMetrics, optional + Metrics collector for tracking cache performance. 
""" @@ -77,6 +82,7 @@ def __init__( s3_endpoint_url: Optional[str] = None, s3_config: Optional[Any] = None, entry_size_limit: Optional[int] = None, + metrics: Optional["CacheMetrics"] = None, ): if not BOTO3_AVAILABLE: _safe_warn( @@ -88,6 +94,7 @@ def __init__( hash_func=hash_func, wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=entry_size_limit, + metrics=metrics, ) if not s3_bucket: @@ -199,7 +206,7 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: _safe_warn(f"S3 get_entry_by_key failed: {exc}") return key, None - def set_entry(self, key: str, func_res: Any) -> bool: + def _set_entry(self, key: str, func_res: Any) -> bool: """Store a function result in S3 under the given key. Parameters @@ -400,14 +407,14 @@ async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: """ return await asyncio.to_thread(self.get_entry_by_key, key) - async def aset_entry(self, key: str, func_res: Any) -> bool: + async def _aset_entry(self, key: str, func_res: Any) -> bool: """Async-compatible variant of :meth:`set_entry`. This method delegates to the sync implementation via ``asyncio.to_thread`` because boto3 is sync-only. """ - return await asyncio.to_thread(self.set_entry, key, func_res) + return await asyncio.to_thread(self._set_entry, key, func_res) async def amark_entry_being_calculated(self, key: str) -> None: """Async-compatible variant of :meth:`mark_entry_being_calculated`. 
diff --git a/src/cachier/cores/sql.py b/src/cachier/cores/sql.py index 1b581b0c..28e5d721 100644 --- a/src/cachier/cores/sql.py +++ b/src/cachier/cores/sql.py @@ -4,7 +4,7 @@ import threading from contextlib import suppress from datetime import datetime, timedelta -from typing import Any, Callable, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union, cast try: from sqlalchemy import ( @@ -30,10 +30,14 @@ except ImportError: SQLALCHEMY_AVAILABLE = False -from .._types import HashFunc -from ..config import CacheEntry +from cachier._types import HashFunc +from cachier.config import CacheEntry + from .base import RecalculationNeeded, _BaseCore, _get_func_str +if TYPE_CHECKING: + from ..metrics import CacheMetrics + if SQLALCHEMY_AVAILABLE: Base = declarative_base() @@ -69,6 +73,7 @@ def __init__( sql_engine: Optional[Union[str, "Engine", "AsyncEngine", Callable[[], "Engine"], Callable[[], "AsyncEngine"]]], wait_for_calc_timeout: Optional[int] = None, entry_size_limit: Optional[int] = None, + metrics: Optional["CacheMetrics"] = None, ): if not SQLALCHEMY_AVAILABLE: raise ImportError("SQLAlchemy is required for the SQL core. 
Install with `pip install SQLAlchemy`.") @@ -76,6 +81,7 @@ def __init__( hash_func=hash_func, wait_for_calc_timeout=wait_for_calc_timeout, entry_size_limit=entry_size_limit, + metrics=metrics, ) self._lock = threading.RLock() self._func_str = None @@ -200,7 +206,7 @@ async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: ) return key, entry - def set_entry(self, key: str, func_res: Any) -> bool: + def _set_entry(self, key: str, func_res: Any) -> bool: if not self._should_store(func_res): return False session_factory = self._get_sync_session() @@ -258,7 +264,7 @@ def set_entry(self, key: str, func_res: Any) -> bool: session.commit() return True - async def aset_entry(self, key: str, func_res: Any) -> bool: + async def _aset_entry(self, key: str, func_res: Any) -> bool: if not self._should_store(func_res): return False session_factory = await self._get_async_session() diff --git a/src/cachier/exporters/__init__.py b/src/cachier/exporters/__init__.py new file mode 100644 index 00000000..80e15f25 --- /dev/null +++ b/src/cachier/exporters/__init__.py @@ -0,0 +1,6 @@ +"""Metrics exporters for cachier.""" + +from .base import MetricsExporter +from .prometheus import PrometheusExporter + +__all__ = ["MetricsExporter", "PrometheusExporter"] diff --git a/src/cachier/exporters/base.py b/src/cachier/exporters/base.py new file mode 100644 index 00000000..c0c461db --- /dev/null +++ b/src/cachier/exporters/base.py @@ -0,0 +1,59 @@ +"""Base interface for metrics exporters.""" + +# This file is part of Cachier. +# https://github.com/python-cachier/cachier + +# Licensed under the MIT license: +# http://www.opensource.org/licenses/MIT-license + +import abc +from typing import Any, Callable + + +class MetricsExporter(metaclass=abc.ABCMeta): + """Abstract base class for metrics exporters. + + Exporters collect metrics from cached functions and export them to monitoring systems like Prometheus, StatsD, + CloudWatch, etc. 
+ + """ + + @abc.abstractmethod + def register_function(self, func: Callable) -> None: + """Register a cached function for metrics export. + + Parameters + ---------- + func : Callable + A function decorated with @cachier that has metrics enabled + + Raises + ------ + ValueError + If the function doesn't have metrics enabled + + """ + + def export_metrics(self, func_name: str, metrics: Any) -> None: # noqa: B027 + """Export metrics for a specific function. + + Default implementation is a no-op. Subclasses may override to push + metrics to a specific backend, but this is not required -- pull-based + exporters (e.g. Prometheus custom collectors) typically do not need it. + + Parameters + ---------- + func_name : str + Name of the function + metrics : MetricSnapshot + Metrics snapshot to export + + """ + + @abc.abstractmethod + def start(self) -> None: + """Start the exporter (e.g., start HTTP server for Prometheus).""" + + @abc.abstractmethod + def stop(self) -> None: + """Stop the exporter and clean up resources.""" diff --git a/src/cachier/exporters/prometheus.py b/src/cachier/exporters/prometheus.py new file mode 100644 index 00000000..205d2ddf --- /dev/null +++ b/src/cachier/exporters/prometheus.py @@ -0,0 +1,372 @@ +"""Prometheus exporter for cachier metrics.""" + +# This file is part of Cachier. 
+# https://github.com/python-cachier/cachier + +# Licensed under the MIT license: +# http://www.opensource.org/licenses/MIT-license + +import threading +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, cast + +from .base import MetricsExporter + +if TYPE_CHECKING: + from ..metrics import CacheMetrics, MetricSnapshot + + +class _MetricsEnabledCallable(Protocol): + """Callable wrapper that exposes cachier metrics.""" + + __module__: str + __name__: str + metrics: Optional["CacheMetrics"] + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Invoke the wrapped callable.""" + + +def _get_func_metrics(func: Callable[..., Any]) -> Optional["CacheMetrics"]: + """Return the metrics object for a registered function, if available.""" + metrics_func = cast(_MetricsEnabledCallable, func) + return metrics_func.metrics + + +try: + import prometheus_client # type: ignore[import-not-found] + from prometheus_client import CollectorRegistry # type: ignore[import-not-found] + from prometheus_client.core import CounterMetricFamily, GaugeMetricFamily # type: ignore[import-not-found] + + PROMETHEUS_CLIENT_AVAILABLE = True +except (ImportError, AttributeError): # pragma: no cover + PROMETHEUS_CLIENT_AVAILABLE = False + prometheus_client = None # type: ignore[assignment] + CollectorRegistry = None # type: ignore[assignment] + CounterMetricFamily = None # type: ignore[assignment] + GaugeMetricFamily = None # type: ignore[assignment] + + +class CachierCollector: + """Custom Prometheus collector that pulls metrics from registered functions.""" + + def __init__(self, exporter: "PrometheusExporter") -> None: + self.exporter = exporter + + def describe(self) -> list: + """Return an empty list; metrics are described at collect time.""" + return [] + + def collect(self) -> Any: + """Collect metrics from all registered functions.""" + # Snapshot all metrics in one lock acquisition for consistency + with self.exporter._lock: + snapshots: Dict[str, 
"MetricSnapshot"] = {} + for func_name, func in self.exporter._registered_functions.items(): + m = _get_func_metrics(func) + if m is not None: + snapshots[func_name] = m.get_stats() + + # Build metric families outside the lock using the snapshots + hits = CounterMetricFamily("cachier_cache_hits_total", "Total cache hits", labels=["function"]) + misses = CounterMetricFamily("cachier_cache_misses_total", "Total cache misses", labels=["function"]) + hit_rate = GaugeMetricFamily("cachier_cache_hit_rate", "Cache hit rate percentage", labels=["function"]) + stale_hits = CounterMetricFamily("cachier_stale_hits_total", "Total stale cache hits", labels=["function"]) + recalculations = CounterMetricFamily( + "cachier_recalculations_total", "Total cache recalculations", labels=["function"] + ) + wait_timeouts = CounterMetricFamily("cachier_wait_timeouts_total", "Total wait timeouts", labels=["function"]) + entry_count = GaugeMetricFamily("cachier_entry_count", "Current number of cache entries", labels=["function"]) + cache_size = GaugeMetricFamily("cachier_cache_size_bytes", "Total cache size in bytes", labels=["function"]) + + for func_name, stats in snapshots.items(): + hits.add_metric([func_name], stats.hits) + misses.add_metric([func_name], stats.misses) + hit_rate.add_metric([func_name], stats.hit_rate) + stale_hits.add_metric([func_name], stats.stale_hits) + recalculations.add_metric([func_name], stats.recalculations) + wait_timeouts.add_metric([func_name], stats.wait_timeouts) + entry_count.add_metric([func_name], stats.entry_count) + cache_size.add_metric([func_name], stats.total_size_bytes) + + # Yield metrics one by one as required by Prometheus collector protocol + yield hits + yield misses + yield hit_rate + yield stale_hits + yield recalculations + yield wait_timeouts + yield entry_count + yield cache_size + + +class PrometheusExporter(MetricsExporter): + """Export cachier metrics in Prometheus format. 
+ + This exporter provides a simple HTTP server that exposes metrics in + Prometheus text format. It can be used with prometheus_client or + as a standalone exporter. + + Parameters + ---------- + port : int, optional + Port for the HTTP server, by default 9090 + use_prometheus_client : bool, optional + Whether to use prometheus_client library if available, by default True + + Examples + -------- + >>> from cachier import cachier + >>> from cachier.exporters import PrometheusExporter + >>> + >>> @cachier(backend='memory', enable_metrics=True) + ... def my_func(x): + ... return x * 2 + >>> + >>> exporter = PrometheusExporter(port=9090) + >>> exporter.register_function(my_func) + >>> exporter.start() + + """ + + def __init__(self, port: int = 9090, use_prometheus_client: bool = True, host: str = "127.0.0.1"): + """Initialize Prometheus exporter. + + Parameters + ---------- + port : int + HTTP server port + use_prometheus_client : bool + Whether to use prometheus_client library + host : str + Host address to bind to (default: 127.0.0.1 for localhost only) + + """ + self.port = port + self.host = host + self.use_prometheus_client = use_prometheus_client + self._registered_functions: Dict[str, _MetricsEnabledCallable] = {} + self._lock = threading.Lock() + self._server: Optional[Any] = None + self._server_thread: Optional[threading.Thread] = None + + # Try to import prometheus_client if requested + self._prom_client = None + # Per-instance registry to avoid double-registration on the global + # REGISTRY when multiple PrometheusExporter instances are created. 
+ self._registry: Optional[Any] = None + if use_prometheus_client and PROMETHEUS_CLIENT_AVAILABLE: + self._prom_client = prometheus_client + self._init_prometheus_metrics() + self._setup_collector() + + def _setup_collector(self) -> None: + """Set up a custom collector to pull metrics from registered functions.""" + if not self._prom_client: # pragma: no cover + return + + # Use a per-instance registry so multiple exporters don't conflict + self._registry = CollectorRegistry() + self._registry.register(CachierCollector(self)) + + def _init_prometheus_metrics(self) -> None: + """Initialize Prometheus metrics using prometheus_client. + + Note: With custom collector, we don't need to pre-define metrics. + The collector will generate them dynamically at scrape time. + + """ + # Metrics are now handled by the custom collector in _setup_collector() + pass + + def register_function(self, func: Callable[..., Any]) -> None: + """Register a cached function for metrics export. + + Parameters + ---------- + func : Callable + A function decorated with @cachier that has metrics enabled + + Raises + ------ + ValueError + If the function doesn't have metrics enabled + + """ + metrics = _get_func_metrics(func) + if metrics is None: + raise ValueError( + f"Function {func.__name__} does not have metrics enabled. Use @cachier(enable_metrics=True)" + ) + + with self._lock: + func_name = f"{func.__module__}.{func.__name__}" + self._registered_functions[func_name] = cast(_MetricsEnabledCallable, func) + + def export_metrics(self, func_name: str, metrics: Any) -> None: + """Export metrics for a specific function to Prometheus. + + With custom collector mode, metrics are automatically pulled at scrape time. + This method is kept for backward compatibility but is a no-op when using + prometheus_client with custom collector. 
+ + Parameters + ---------- + func_name : str + Name of the function + metrics : MetricSnapshot + Metrics snapshot to export + + """ + # With custom collector, metrics are pulled automatically at scrape time + # No need to manually push metrics + pass + + def _generate_text_metrics(self) -> str: + """Generate Prometheus text format metrics. + + Returns + ------- + str + Metrics in Prometheus text format + + """ + # Snapshot all metrics in one lock acquisition for consistency + with self._lock: + snapshots: Dict[str, "MetricSnapshot"] = {} + for func_name, func in self._registered_functions.items(): + m = _get_func_metrics(func) + if m is not None: + snapshots[func_name] = m.get_stats() + + # (name, help, type, getter, fmt) + metric_defs = [ + ("cachier_cache_hits_total", "Total cache hits", "counter", lambda s: s.hits, "{}"), + ("cachier_cache_misses_total", "Total cache misses", "counter", lambda s: s.misses, "{}"), + ("cachier_cache_hit_rate", "Cache hit rate percentage", "gauge", lambda s: s.hit_rate, "{:.2f}"), + ( + "cachier_avg_latency_ms", + "Average cache operation latency in milliseconds", + "gauge", + lambda s: s.avg_latency_ms, + "{:.4f}", + ), + ("cachier_stale_hits_total", "Total stale cache hits", "counter", lambda s: s.stale_hits, "{}"), + ("cachier_recalculations_total", "Total cache recalculations", "counter", lambda s: s.recalculations, "{}"), + ("cachier_wait_timeouts_total", "Total wait timeouts", "counter", lambda s: s.wait_timeouts, "{}"), + ("cachier_entry_count", "Current cache entries", "gauge", lambda s: s.entry_count, "{}"), + ("cachier_cache_size_bytes", "Total cache size in bytes", "gauge", lambda s: s.total_size_bytes, "{}"), + ( + "cachier_size_limit_rejections_total", + "Entries rejected due to size limit", + "counter", + lambda s: s.size_limit_rejections, + "{}", + ), + ] + + lines: list[str] = [] + for name, help_text, metric_type, getter, fmt in metric_defs: + lines.append(f"# HELP {name} {help_text}") + lines.append(f"# TYPE 
{name} {metric_type}") + for func_name, stats in snapshots.items(): + value = fmt.format(getter(stats)) + lines.append(f'{name}{{function="{func_name}"}} {value}') + lines.append("") + + return "\n".join(lines) + + def start(self) -> None: + """Start the Prometheus exporter. + + If prometheus_client is available, starts the HTTP server using the per-instance registry. Otherwise, provides a + simple HTTP server for text format metrics. + + """ + if self._prom_client and self._registry is not None: + # Use a simple HTTP server that serves from our per-instance registry + # instead of prometheus_client's start_http_server which uses the + # global REGISTRY. + self._start_prometheus_server() + else: + # Provide simple HTTP server for text format + self._start_simple_server() + + def _start_prometheus_server(self) -> None: + """Start an HTTP server that serves metrics from the per-instance registry.""" + from http.server import BaseHTTPRequestHandler, HTTPServer + + from prometheus_client import exposition + + if self._registry is None: # pragma: no cover + raise RuntimeError("registry must be initialized before starting server") + registry = self._registry + + class MetricsHandler(BaseHTTPRequestHandler): + """HTTP handler that serves Prometheus metrics from a specific registry.""" + + def do_GET(self) -> None: + """Handle GET requests for /metrics endpoint.""" + if self.path == "/metrics": + output = exposition.generate_latest(registry) + self.send_response(200) + self.send_header("Content-Type", exposition.CONTENT_TYPE_LATEST) + self.end_headers() + self.wfile.write(output) + else: + self.send_response(404) + self.end_headers() + + def log_message(self, fmt: str, *args: Any) -> None: + """Suppress log messages.""" + + server = HTTPServer((self.host, self.port), MetricsHandler) + self._server = server + + def run_server() -> None: + server.serve_forever() + + self._server_thread = threading.Thread(target=run_server, daemon=True) + self._server_thread.start() + + def 
_start_simple_server(self) -> None: + """Start a simple HTTP server for Prometheus text format.""" + from http.server import BaseHTTPRequestHandler, HTTPServer + + exporter = self + + class MetricsHandler(BaseHTTPRequestHandler): + """HTTP handler that serves Prometheus text-format metrics.""" + + def do_GET(self) -> None: + """Handle GET requests for /metrics endpoint.""" + if self.path == "/metrics": + self.send_response(200) + self.send_header("Content-Type", "text/plain") + self.end_headers() + metrics_text = exporter._generate_text_metrics() + self.wfile.write(metrics_text.encode()) + else: + self.send_response(404) + self.end_headers() + + def log_message(self, fmt: str, *args: Any) -> None: + """Suppress log messages.""" + + server = HTTPServer((self.host, self.port), MetricsHandler) + self._server = server + + def run_server() -> None: + server.serve_forever() + + self._server_thread = threading.Thread(target=run_server, daemon=True) + self._server_thread.start() + + def stop(self) -> None: + """Stop the Prometheus exporter and clean up resources.""" + if self._server: + self._server.shutdown() + self._server.server_close() + self._server = None + if self._server_thread: + self._server_thread.join(timeout=5) + self._server_thread = None diff --git a/src/cachier/metrics.py b/src/cachier/metrics.py new file mode 100644 index 00000000..c6b42eac --- /dev/null +++ b/src/cachier/metrics.py @@ -0,0 +1,384 @@ +"""Cache metrics and observability framework for cachier.""" + +# This file is part of Cachier. +# https://github.com/python-cachier/cachier + +# Licensed under the MIT license: +# http://www.opensource.org/licenses/MIT-license + +import random +import threading +import time +from collections import deque +from dataclasses import dataclass +from datetime import timedelta +from typing import Deque, Optional + + +@dataclass +class MetricSnapshot: + """Snapshot of cache metrics at a point in time. 
+ + Attributes + ---------- + hits : int + Number of cache hits + misses : int + Number of cache misses. Note: stale cache hits are also counted + as misses, so stale_hits and misses may overlap. + hit_rate : float + Cache hit rate as percentage (0-100). Note: stale cache hits are + counted as misses when computing the hit rate, so the rate may + appear lower than expected when stale entries are served. + total_calls : int + Total number of cache accesses + avg_latency_ms : float + Average operation latency in milliseconds + stale_hits : int + Number of times stale cache entries were accessed + recalculations : int + Number of cache recalculations performed + wait_timeouts : int + Number of wait timeouts that occurred + entry_count : int + Current number of entries in cache + total_size_bytes : int + Total size of cache in bytes. Only populated for the memory + backend; all other backends report 0. + size_limit_rejections : int + Number of entries rejected due to size limit + + """ + + hits: int = 0 + misses: int = 0 + hit_rate: float = 0.0 + total_calls: int = 0 + avg_latency_ms: float = 0.0 + stale_hits: int = 0 + recalculations: int = 0 + wait_timeouts: int = 0 + entry_count: int = 0 + total_size_bytes: int = 0 + size_limit_rejections: int = 0 + + +@dataclass +class _TimestampedMetric: + """Internal metric with monotonic timestamp for time-windowed aggregation. + + Uses time.perf_counter() for monotonic timestamps that are immune to + system clock adjustments. + + Parameters + ---------- + timestamp : float + Monotonic timestamp when the metric was recorded (from time.perf_counter()) + value : float + The metric value + + """ + + timestamp: float + value: float + + +class CacheMetrics: + """Thread-safe metrics collector for cache operations. + + This class collects and aggregates cache performance metrics including + hit/miss rates, latencies, and size information. Metrics are collected + in a thread-safe manner and can be aggregated over time windows. 
+ + Parameters + ---------- + sampling_rate : float, optional + Sampling rate for metrics collection (0.0-1.0), by default 1.0 + Lower values reduce overhead at the cost of accuracy + window_sizes : list of timedelta, optional + Time windows to track for aggregated metrics, by default [1 minute, 1 hour, 1 day] + + Examples + -------- + >>> metrics = CacheMetrics(sampling_rate=0.1) + >>> metrics.record_hit() + >>> metrics.record_miss() + >>> stats = metrics.get_stats() + >>> print(f"Hit rate: {stats.hit_rate}%") + + """ + + def __init__(self, sampling_rate: float = 1.0, window_sizes: Optional[list[timedelta]] = None): + """Initialize cache metrics collector. + + Parameters + ---------- + sampling_rate : float + Sampling rate between 0.0 and 1.0 + window_sizes : list of timedelta, optional + Time windows for aggregated metrics + + """ + if not 0.0 <= sampling_rate <= 1.0: + raise ValueError("sampling_rate must be between 0.0 and 1.0") + + self._lock = threading.RLock() + self._sampling_rate = sampling_rate + + # Core counters + self._hits = 0 + self._misses = 0 + self._stale_hits = 0 + self._recalculations = 0 + self._wait_timeouts = 0 + self._size_limit_rejections = 0 + + # Latency tracking - time-windowed + if window_sizes is None: + window_sizes = [ + timedelta(minutes=1), + timedelta(hours=1), + timedelta(days=1), + ] + self._window_sizes = window_sizes + self._max_window = max(window_sizes) if window_sizes else timedelta(0) + + # Use deque with fixed size based on expected frequency + # Assuming ~1000 ops/sec max, keep 1 day of data = 86.4M points + # Limit to 100K points for memory efficiency + max_latency_points = 100000 + self._latencies: Deque[_TimestampedMetric] = deque(maxlen=max_latency_points) + + # Size tracking + self._entry_count = 0 + self._total_size_bytes = 0 + + self._random = random.Random() # noqa: S311 + + def _should_sample(self) -> bool: + """Determine if this metric should be sampled. 
+ + Returns + ------- + bool + True if metric should be recorded + + """ + if self._sampling_rate >= 1.0: + return True + return self._random.random() < self._sampling_rate + + def _record_counter(self, attr: str) -> None: + """Increment a named counter if sampling allows it. + + Parameters + ---------- + attr : str + Name of the instance attribute to increment (e.g. ``"_hits"``) + + """ + if self._should_sample(): + with self._lock: + self.__dict__[attr] += 1 + + def record_hit(self) -> None: + """Record a cache hit.""" + self._record_counter("_hits") + + def record_miss(self) -> None: + """Record a cache miss.""" + self._record_counter("_misses") + + def record_stale_hit(self) -> None: + """Record a stale cache hit.""" + self._record_counter("_stale_hits") + + def record_recalculation(self) -> None: + """Record a cache recalculation.""" + self._record_counter("_recalculations") + + def record_wait_timeout(self) -> None: + """Record a wait timeout.""" + self._record_counter("_wait_timeouts") + + def record_size_limit_rejection(self) -> None: + """Record an entry rejection due to size limit.""" + self._record_counter("_size_limit_rejections") + + def record_latency(self, latency_seconds: float) -> None: + """Record an operation latency. + + Parameters + ---------- + latency_seconds : float + Operation latency in seconds + + """ + if not self._should_sample(): + return + with self._lock: + # Use monotonic timestamp for immune-to-clock-adjustment windowing + timestamp = time.perf_counter() + self._latencies.append(_TimestampedMetric(timestamp=timestamp, value=latency_seconds)) + + def update_size_metrics(self, entry_count: int, total_size_bytes: int) -> None: + """Update cache size metrics. 
+ + Parameters + ---------- + entry_count : int + Current number of entries in cache + total_size_bytes : int + Total size of cache in bytes + + """ + with self._lock: + self._entry_count = entry_count + self._total_size_bytes = total_size_bytes + + def _calculate_avg_latency(self, window: Optional[timedelta] = None) -> float: + """Calculate average latency within a time window. + + Parameters + ---------- + window : timedelta, optional + Time window to consider. If None, uses all data. + + Returns + ------- + float + Average latency in milliseconds + + """ + # Use monotonic clock for cutoff calculation + now = time.perf_counter() + cutoff = 0.0 if not window else now - window.total_seconds() + + latencies = [metric.value for metric in self._latencies if metric.timestamp >= cutoff] + + if not latencies: + return 0.0 + + return (sum(latencies) / len(latencies)) * 1000 # Convert to ms + + def get_stats(self, window: Optional[timedelta] = None) -> MetricSnapshot: + """Get current cache statistics. + + Parameters + ---------- + window : timedelta, optional + Time window for windowed metrics (latency). + If None, returns all-time statistics. + + Returns + ------- + MetricSnapshot + Snapshot of current cache metrics + + """ + with self._lock: + total_calls = self._hits + self._misses + hit_rate = (self._hits / total_calls * 100) if total_calls > 0 else 0.0 + avg_latency = self._calculate_avg_latency(window) + + return MetricSnapshot( + hits=self._hits, + misses=self._misses, + hit_rate=hit_rate, + total_calls=total_calls, + avg_latency_ms=avg_latency, + stale_hits=self._stale_hits, + recalculations=self._recalculations, + wait_timeouts=self._wait_timeouts, + entry_count=self._entry_count, + total_size_bytes=self._total_size_bytes, + size_limit_rejections=self._size_limit_rejections, + ) + + def reset(self) -> None: + """Reset all metrics to zero. + + Thread-safe method to clear all collected metrics. 
+ + """ + with self._lock: + self._hits = 0 + self._misses = 0 + self._stale_hits = 0 + self._recalculations = 0 + self._wait_timeouts = 0 + self._size_limit_rejections = 0 + self._latencies.clear() + self._entry_count = 0 + self._total_size_bytes = 0 + + +class MetricsContext: + """Null-object context manager for cache operation instrumentation. + + Wraps an optional ``CacheMetrics`` instance so call-path code can invoke + ``record_*`` methods unconditionally without ``if metrics:`` guards. + Starts the latency timer on ``__enter__`` and records it automatically on + ``__exit__``, covering every return path including exceptions. + + Parameters + ---------- + metrics : CacheMetrics, optional + Metrics object to record to. When ``None`` all operations are no-ops. + + Examples + -------- + >>> metrics = CacheMetrics() + >>> with MetricsContext(metrics) as m: + ... m.record_miss() + ... # Do cache operation + ... m.record_recalculation() + + """ + + __slots__ = ("_m", "_start") + + def __init__(self, metrics: Optional[CacheMetrics]) -> None: + self._m = metrics + self._start: float = 0.0 + + def __enter__(self) -> "MetricsContext": + """Start timing the operation.""" + if self._m is not None: + self._start = time.perf_counter() + return self + + def __exit__(self, *_: object) -> None: + """Record the operation latency.""" + if self._m is not None: + self._m.record_latency(time.perf_counter() - self._start) + + def record_hit(self) -> None: + """Record a cache hit.""" + if self._m: + self._m.record_hit() + + def record_miss(self) -> None: + """Record a cache miss.""" + if self._m: + self._m.record_miss() + + def record_stale_hit(self) -> None: + """Record a stale cache hit.""" + if self._m: + self._m.record_stale_hit() + + def record_recalculation(self) -> None: + """Record a cache recalculation.""" + if self._m: + self._m.record_recalculation() + + def record_wait_timeout(self) -> None: + """Record a wait timeout.""" + if self._m: + self._m.record_wait_timeout() + + 
def record_size_limit_rejection(self) -> None: + """Record an entry rejection due to size limit.""" + if self._m: + self._m.record_size_limit_rejection() diff --git a/tests/sql_tests/test_async_sql_core.py b/tests/sql_tests/test_async_sql_core.py index 410f86c3..555d5708 100644 --- a/tests/sql_tests/test_async_sql_core.py +++ b/tests/sql_tests/test_async_sql_core.py @@ -253,7 +253,7 @@ def scalar_one_or_none(self): return DummyResult() - monkeypatch.setitem(_SQLCore.aset_entry.__globals__, "insert", fake_insert) + monkeypatch.setitem(_SQLCore._aset_entry.__globals__, "insert", fake_insert) monkeypatch.setattr(AsyncSession, "execute", fake_execute) core = _SQLCore(hash_func=None, sql_engine=async_sql_engine) diff --git a/tests/sql_tests/test_sql_core.py b/tests/sql_tests/test_sql_core.py index cd27e0eb..4601f7b2 100644 --- a/tests/sql_tests/test_sql_core.py +++ b/tests/sql_tests/test_sql_core.py @@ -62,6 +62,8 @@ def f(x, y): @pytest.mark.sql def test_sql_core_keywords(): + """Keyword arguments produce a cache hit the same as positional arguments.""" + @cachier(backend="sql", sql_engine=SQL_CONN_STR) def f(x, y): return random() + x + y @@ -70,14 +72,7 @@ def f(x, y): v1 = f(1, y=2) v2 = f(1, y=2) assert v1 == v2 - v3 = f(1, y=2, cachier__skip_cache=True) - assert v3 != v1 - v4 = f(1, y=2) - assert v4 == v1 - v5 = f(1, y=2, cachier__overwrite_cache=True) - assert v5 != v1 - v6 = f(1, y=2) - assert v6 == v5 + f.clear_cache() @pytest.mark.sql @@ -100,24 +95,6 @@ def f(x, y): assert v3 != v1 -@pytest.mark.sql -def test_sql_overwrite_and_skip_cache(): - @cachier(backend="sql", sql_engine=SQL_CONN_STR) - def f(x): - return random() + x - - f.clear_cache() - v1 = f(1) - v2 = f(1) - assert v1 == v2 - v3 = f(1, cachier__skip_cache=True) - assert v3 != v1 - v4 = f(1, cachier__overwrite_cache=True) - assert v4 != v1 - v5 = f(1) - assert v5 == v4 - - @pytest.mark.sql def test_sql_concurrency(): @cachier(backend="sql", sql_engine=SQL_CONN_STR) @@ -424,7 +401,7 @@ def 
scalar_one_or_none(self): return DummyResult() - monkeypatch.setitem(_SQLCore.set_entry.__globals__, "insert", fake_insert) + monkeypatch.setitem(_SQLCore._set_entry.__globals__, "insert", fake_insert) monkeypatch.setattr(Session, "execute", fake_execute) core = _SQLCore(hash_func=None, sql_engine=SQL_CONN_STR) diff --git a/tests/test_base_core.py b/tests/test_base_core.py index 19fd9f5f..589eac66 100644 --- a/tests/test_base_core.py +++ b/tests/test_base_core.py @@ -5,7 +5,7 @@ import pytest -from cachier.cores.base import _BaseCore +from cachier.cores.base import RecalculationNeeded, _BaseCore class ConcreteCachingCore(_BaseCore): @@ -26,7 +26,7 @@ def get_entry_by_key(self, key, reload=False): """Retrieve an entry by its key.""" return key, None - def set_entry(self, key, func_res): + def _set_entry(self, key, func_res): """Store an entry in the cache.""" self.last_set = (key, func_res) return True @@ -143,3 +143,26 @@ async def test_base_core_aset_entry_fallback(): assert result is True assert core.last_set == (key, 99) + + +def test_base_core_size_hooks_default_to_zero(): + """Base metric hooks should return zero when a backend does not override them.""" + core = ConcreteCachingCore(hash_func=None, wait_for_calc_timeout=None) + + assert core._get_entry_count() == 0 + assert core._get_total_size() == 0 + + +def test_check_calc_timeout_raises_recalculation_needed(): + """check_calc_timeout should raise when elapsed time reaches the configured timeout.""" + core = ConcreteCachingCore(hash_func=None, wait_for_calc_timeout=2) + + with pytest.raises(RecalculationNeeded): + core.check_calc_timeout(2) + + +def test_check_calc_timeout_does_not_raise_before_timeout(): + """check_calc_timeout should not raise before the configured timeout.""" + core = ConcreteCachingCore(hash_func=None, wait_for_calc_timeout=2) + + core.check_calc_timeout(1) diff --git a/tests/test_exporters.py b/tests/test_exporters.py new file mode 100644 index 00000000..67d651d4 --- /dev/null +++ 
b/tests/test_exporters.py @@ -0,0 +1,585 @@ +"""Tests for metrics exporters.""" + +import re + +import pytest + +from cachier import cachier +from cachier.exporters import MetricsExporter, PrometheusExporter + + +@pytest.mark.memory +def test_prometheus_exporter_registration(): + """Test registering a function with PrometheusExporter.""" + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + + exporter = PrometheusExporter(port=0) + + # Should succeed with metrics-enabled function + exporter.register_function(test_func) + assert test_func in exporter._registered_functions.values() + + test_func.clear_cache() + + +@pytest.mark.memory +def test_prometheus_exporter_requires_metrics(): + """Test that PrometheusExporter requires metrics to be enabled.""" + + @cachier(backend="memory") # metrics disabled by default + def test_func(x): + return x * 2 + + exporter = PrometheusExporter(port=0) + + # Should raise error for function without metrics + with pytest.raises(ValueError, match="does not have metrics enabled"): + exporter.register_function(test_func) + + test_func.clear_cache() + + +@pytest.mark.memory +def test_prometheus_exporter_text_format(): + """Test that PrometheusExporter generates valid text format.""" + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + + exporter = PrometheusExporter(port=9093, use_prometheus_client=False) + exporter.register_function(test_func) + + # Generate some metrics + test_func(5) + test_func(5) + + # Generate text format + metrics_text = exporter._generate_text_metrics() + + # Check for Prometheus format elements + assert "cachier_cache_hits_total" in metrics_text + assert "cachier_cache_misses_total" in metrics_text + assert "cachier_cache_hit_rate" in metrics_text + assert "# HELP" in metrics_text + assert "# TYPE" in metrics_text + + test_func.clear_cache() + + +@pytest.mark.memory +def 
test_prometheus_exporter_multiple_functions(): + """Test PrometheusExporter with multiple functions.""" + + @cachier(backend="memory", enable_metrics=True) + def func1(x): + return x * 2 + + @cachier(backend="memory", enable_metrics=True) + def func2(x): + return x * 3 + + func1.clear_cache() + func2.clear_cache() + + exporter = PrometheusExporter(port=9094, use_prometheus_client=False) + exporter.register_function(func1) + exporter.register_function(func2) + + # Generate some metrics + func1(5) + func2(10) + + metrics_text = exporter._generate_text_metrics() + + # Both functions should be in the output + assert "func1" in metrics_text + assert "func2" in metrics_text + + func1.clear_cache() + func2.clear_cache() + + +def test_metrics_exporter_interface(): + """Test PrometheusExporter implements MetricsExporter interface.""" + exporter = PrometheusExporter(port=9095) + + # Check that it has the required methods + assert hasattr(exporter, "register_function") + assert hasattr(exporter, "export_metrics") + assert hasattr(exporter, "start") + assert hasattr(exporter, "stop") + + # Check that it's an instance of the base class + assert isinstance(exporter, MetricsExporter) + + +@pytest.mark.memory +def test_prometheus_exporter_double_instantiation(): + """Test that two PrometheusExporter instances both work independently.""" + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + test_func(5) + + exporter1 = PrometheusExporter(port=9097, use_prometheus_client=False) + exporter1.register_function(test_func) + + exporter2 = PrometheusExporter(port=9098, use_prometheus_client=False) + exporter2.register_function(test_func) + + # Both should generate valid metrics + text1 = exporter1._generate_text_metrics() + text2 = exporter2._generate_text_metrics() + + assert "cachier_cache_hits_total" in text1 + assert "cachier_cache_hits_total" in text2 + + test_func.clear_cache() + + +@pytest.mark.memory +def 
test_prometheus_text_metrics_consistency(): + """Test that hits + misses == total_calls in generated text at one point in time.""" + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + + exporter = PrometheusExporter(port=9099, use_prometheus_client=False) + exporter.register_function(test_func) + + test_func(5) # miss + test_func(5) # hit + test_func(10) # miss + + # Get stats and text at same time + stats = test_func.metrics.get_stats() + metrics_text = exporter._generate_text_metrics() + + # Verify consistency: parse hits and misses from text + func_name = f"{test_func.__module__}.{test_func.__name__}" + hits_match = re.search( + rf'cachier_cache_hits_total\{{function="{re.escape(func_name)}"\}} (\d+)', + metrics_text, + ) + misses_match = re.search( + rf'cachier_cache_misses_total\{{function="{re.escape(func_name)}"\}} (\d+)', + metrics_text, + ) + + assert hits_match + assert misses_match + hits = int(hits_match.group(1)) + misses = int(misses_match.group(1)) + assert hits + misses == stats.total_calls + + test_func.clear_cache() + + +@pytest.mark.memory +def test_prometheus_export_metrics_noop(): + """Test that export_metrics is a no-op (backward-compat method).""" + exporter = PrometheusExporter(port=0, use_prometheus_client=False) + # Should not raise + exporter.export_metrics("some_func", None) + + +@pytest.mark.memory +def test_prometheus_text_metrics_skips_none_metrics(): + """Test that _generate_text_metrics skips functions whose metrics attr is None.""" + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + test_func(5) + + exporter = PrometheusExporter(port=0, use_prometheus_client=False) + exporter.register_function(test_func) + + # Inject a fake entry whose metrics resolve to None + class _NoMetrics: + __module__ = "test" + __name__ = "no_metrics" + metrics = None + + def __call__(self, *a, **kw): + pass + + 
exporter._registered_functions["test.no_metrics"] = _NoMetrics() + + # Should not raise; the None-metrics entry is silently skipped + text = exporter._generate_text_metrics() + assert "cachier_cache_hits_total" in text + assert "no_metrics" not in text + + test_func.clear_cache() + + +@pytest.mark.memory +def test_prometheus_start_stop_simple_server(): + """Test starting and stopping the simple HTTP server.""" + exporter = PrometheusExporter(port=0, use_prometheus_client=False) + exporter.start() + assert exporter._server is not None + exporter.stop() + assert exporter._server is None + + +@pytest.mark.memory +def test_prometheus_start_stop_prometheus_server(): + """Test starting and stopping the prometheus_client-backed HTTP server.""" + prometheus_client = pytest.importorskip("prometheus_client") # noqa: F841 + exporter = PrometheusExporter(port=0, use_prometheus_client=True) + assert exporter._registry is not None + exporter.start() + assert exporter._server is not None + exporter.stop() + assert exporter._server is None + + +@pytest.mark.memory +def test_prometheus_collector_collect(): + """Test that the CachierCollector.collect() yields metrics correctly.""" + pytest.importorskip("prometheus_client") + from prometheus_client import generate_latest + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + test_func(5) + test_func(5) + + exporter = PrometheusExporter(port=0, use_prometheus_client=True) + exporter.register_function(test_func) + + assert exporter._registry is not None + output = generate_latest(exporter._registry).decode() + assert "cachier_cache_hits_total" in output + assert "cachier_cache_misses_total" in output + + test_func.clear_cache() + + +@pytest.mark.memory +def test_prometheus_client_not_available(monkeypatch): + """Test PrometheusExporter falls back gracefully when prometheus_client is patched out.""" + 
monkeypatch.setattr("cachier.exporters.prometheus.PROMETHEUS_CLIENT_AVAILABLE", False) + monkeypatch.setattr("cachier.exporters.prometheus.prometheus_client", None) + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + test_func(5) + + exporter = PrometheusExporter(port=19093, use_prometheus_client=True) + assert exporter._prom_client is None + exporter.register_function(test_func) + text = exporter._generate_text_metrics() + assert "cachier_cache_hits_total" in text + + test_func.clear_cache() + + +@pytest.mark.memory +def test_prometheus_prom_client_available_paths(): + """Cover prometheus_client-available code paths via module-level patching. + + Exercises: __init__ branch (L157-160), _setup_collector (L168-169), + _init_prometheus_metrics (L179), CachierCollector.describe (L57), and + CachierCollector.collect() None-metrics skip (L66 False branch). + + """ + from unittest.mock import MagicMock, patch + + from cachier.exporters.prometheus import CachierCollector + + mock_registry = MagicMock() + + with ( + patch("cachier.exporters.prometheus.PROMETHEUS_CLIENT_AVAILABLE", True), + patch("cachier.exporters.prometheus.CollectorRegistry", lambda: mock_registry), + patch("cachier.exporters.prometheus.prometheus_client", MagicMock()), + ): + exporter = PrometheusExporter(port=0, use_prometheus_client=True) + assert exporter._prom_client is not None + assert exporter._registry is mock_registry + + # L57: CachierCollector.describe() -> [] + collector = CachierCollector(exporter) + assert collector.describe() == [] + + # L66 False branch: register a function whose metrics is None + class _NoMetrics: + __module__ = "test" + __name__ = "no_metrics" + metrics = None + + def __call__(self, *a, **kw): + pass + + exporter._registered_functions["test.no_metrics"] = _NoMetrics() + + with ( + patch("cachier.exporters.prometheus.CounterMetricFamily", lambda *a, **kw: MagicMock()), + 
patch("cachier.exporters.prometheus.GaugeMetricFamily", lambda *a, **kw: MagicMock()), + ): + results = list(collector.collect()) + # Yields 8 families even though snapshots is empty (no non-None metrics) + assert len(results) == 8 + + +def test_prometheus_module_import_with_prom_client(): + """Cover the try-block import lines (L37-40) via module reload with a mocked prometheus_client.""" + import importlib + import sys + from unittest.mock import MagicMock + + import cachier.exporters.prometheus as prom_mod + + mock_prom = MagicMock() + mock_prom_core = MagicMock() + + saved_prom = sys.modules.get("prometheus_client") + saved_core = sys.modules.get("prometheus_client.core") + + sys.modules["prometheus_client"] = mock_prom + sys.modules["prometheus_client.core"] = mock_prom_core + try: + importlib.reload(prom_mod) + assert prom_mod.PROMETHEUS_CLIENT_AVAILABLE is True + assert prom_mod.CollectorRegistry is mock_prom.CollectorRegistry + finally: + if saved_prom is None: + sys.modules.pop("prometheus_client", None) + else: + sys.modules["prometheus_client"] = saved_prom + if saved_core is None: + sys.modules.pop("prometheus_client.core", None) + else: + sys.modules["prometheus_client.core"] = saved_core + importlib.reload(prom_mod) # restore original state + + +@pytest.mark.memory +def test_prometheus_stop_when_not_started(): + """Test that stop() is a no-op when the server was never started.""" + exporter = PrometheusExporter(port=19094, use_prometheus_client=False) + exporter.stop() # Should not raise + + +@pytest.mark.memory +def test_prometheus_simple_server_404(): + """Test that simple HTTP server returns 404 for non-metrics paths.""" + import http.client + + exporter = PrometheusExporter(port=19095, use_prometheus_client=False) + exporter.start() + try: + conn = http.client.HTTPConnection("127.0.0.1", 19095) + conn.request("GET", "/notfound") + response = conn.getresponse() + assert response.status == 404 + conn.close() + finally: + exporter.stop() + + 
+@pytest.mark.memory +def test_prometheus_prometheus_server_404(): + """Test that prometheus_client-backed server returns 404 for non-metrics paths.""" + import http.client + + pytest.importorskip("prometheus_client") + + exporter = PrometheusExporter(port=19096, use_prometheus_client=True) + exporter.start() + try: + conn = http.client.HTTPConnection("127.0.0.1", 19096) + conn.request("GET", "/notfound") + response = conn.getresponse() + assert response.status == 404 + conn.close() + finally: + exporter.stop() + + +@pytest.mark.memory +def test_prometheus_collector_collect_empty(): + """Test CachierCollector.collect() when no functions have metrics.""" + pytest.importorskip("prometheus_client") + from prometheus_client import generate_latest + + exporter = PrometheusExporter(port=19097, use_prometheus_client=True) + assert exporter._registry is not None + # No functions registered — collect() should run without error and yield metric families + output = generate_latest(exporter._registry).decode() + # Output may be empty or contain only headers; no crash is the key assertion + assert isinstance(output, str) + + +@pytest.mark.memory +def test_prometheus_simple_server_metrics_endpoint(): + """Test that simple HTTP server returns metrics on /metrics.""" + import urllib.request + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + test_func(5) + + exporter = PrometheusExporter(port=19098, use_prometheus_client=False) + exporter.register_function(test_func) + exporter.start() + try: + response = urllib.request.urlopen("http://127.0.0.1:19098/metrics", timeout=5) + body = response.read().decode() + assert "cachier_cache_hits_total" in body + finally: + exporter.stop() + test_func.clear_cache() + + +@pytest.mark.memory +def test_prometheus_prometheus_server_metrics_endpoint(): + """Test that prometheus_client-backed server returns metrics on /metrics.""" + import urllib.request + + 
pytest.importorskip("prometheus_client") + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + test_func(5) + + exporter = PrometheusExporter(port=19099, use_prometheus_client=True) + exporter.register_function(test_func) + exporter.start() + try: + response = urllib.request.urlopen("http://127.0.0.1:19099/metrics") + body = response.read().decode() + assert "cachier_cache_hits_total" in body + finally: + exporter.stop() + test_func.clear_cache() + + +@pytest.mark.memory +def test_prometheus_collector_collect_mocked(): + """Test CachierCollector.collect() loop using mocked metric family types. + + Covers lines 81-99 without requiring prometheus_client to be installed. + + """ + from unittest.mock import MagicMock, patch + + from cachier.exporters.prometheus import CachierCollector + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + test_func(5) + test_func(5) + + exporter = PrometheusExporter(port=0, use_prometheus_client=False) + exporter.register_function(test_func) + + with ( + patch("cachier.exporters.prometheus.CounterMetricFamily", lambda *a, **kw: MagicMock()), + patch("cachier.exporters.prometheus.GaugeMetricFamily", lambda *a, **kw: MagicMock()), + ): + collector = CachierCollector(exporter) + results = list(collector.collect()) + # 5 counter families + 3 gauge families + assert len(results) == 8 + + test_func.clear_cache() + + +@pytest.mark.memory +def test_prometheus_start_prometheus_server_mocked(): + """Test _start_prometheus_server and its MetricsHandler without prometheus_client. + + Covers lines 285-329 (start() prom branch, MetricsHandler.do_GET, log_message). 
+ + """ + import sys + import urllib.request + from http.client import HTTPConnection + from unittest.mock import MagicMock, patch + + mock_exposition = MagicMock() + mock_exposition.generate_latest.return_value = b"# mocked metrics" + mock_exposition.CONTENT_TYPE_LATEST = "text/plain" + + prom_mock = MagicMock() + prom_mock.exposition = mock_exposition + + exporter = PrometheusExporter(port=0, use_prometheus_client=False) + # Manually inject prometheus state to trigger _start_prometheus_server path + exporter._prom_client = prom_mock + exporter._registry = MagicMock() + + with patch.dict(sys.modules, {"prometheus_client": prom_mock, "prometheus_client.exposition": mock_exposition}): + exporter.start() + actual_port = exporter._server.server_address[1] + assert exporter._server is not None + try: + response = urllib.request.urlopen(f"http://127.0.0.1:{actual_port}/metrics", timeout=5) + assert b"# mocked metrics" in response.read() + + conn = HTTPConnection("127.0.0.1", actual_port) + conn.request("GET", "/notfound") + resp = conn.getresponse() + assert resp.status == 404 + conn.close() + finally: + exporter.stop() + assert exporter._server is None + + +@pytest.mark.memory +def test_prometheus_collector_collect_skips_none_metrics(): + """Test CachierCollector.collect() skips functions where metrics is None.""" + pytest.importorskip("prometheus_client") + from prometheus_client import generate_latest + + exporter = PrometheusExporter(port=19200, use_prometheus_client=True) + + class _NoMetrics: + __module__ = "test" + __name__ = "no_metrics" + metrics = None + + def __call__(self, *a, **kw): + pass + + exporter._registered_functions["test.no_metrics"] = _NoMetrics() + + assert exporter._registry is not None + output = generate_latest(exporter._registry).decode() + assert isinstance(output, str) diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 00000000..0ceb4226 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,597 @@ +"""Tests 
for cache metrics and observability framework.""" + +import asyncio +import time +from datetime import timedelta +from threading import Thread + +import pytest + +from cachier import cachier +from cachier.metrics import CacheMetrics, MetricsContext, MetricSnapshot + + +@pytest.mark.memory +def test_metrics_enabled(): + """Test that metrics can be enabled for a cached function.""" + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + # Check metrics object is attached + assert hasattr(test_func, "metrics") + assert isinstance(test_func.metrics, CacheMetrics) + test_func.clear_cache() + + +@pytest.mark.memory +def test_metrics_disabled_by_default(): + """Test that metrics are disabled by default.""" + + @cachier(backend="memory") + def test_func(x): + return x * 2 + + # Metrics should be None when disabled + assert test_func.metrics is None + test_func.clear_cache() + + +@pytest.mark.memory +def test_metrics_hit_miss_tracking(): + """Test that cache hits and misses are correctly tracked.""" + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + + # First call should be a miss + result1 = test_func(5) + assert result1 == 10 + + stats = test_func.metrics.get_stats() + assert stats.hits == 0 + assert stats.misses == 1 + assert stats.total_calls == 1 + assert stats.hit_rate == 0.0 + + # Second call should be a hit + result2 = test_func(5) + assert result2 == 10 + + stats = test_func.metrics.get_stats() + assert stats.hits == 1 + assert stats.misses == 1 + assert stats.total_calls == 2 + assert stats.hit_rate == 50.0 + + # Third call with different arg should be a miss + result3 = test_func(10) + assert result3 == 20 + + stats = test_func.metrics.get_stats() + assert stats.hits == 1 + assert stats.misses == 2 + assert stats.total_calls == 3 + assert stats.hit_rate == pytest.approx(33.33, rel=0.1) + + test_func.clear_cache() + + +@pytest.mark.memory +def 
test_metrics_stale_hit_tracking(): + """Test that stale cache hits are tracked.""" + + @cachier( + backend="memory", + enable_metrics=True, + stale_after=timedelta(milliseconds=100), + next_time=False, + ) + def test_func(x): + return x * 2 + + test_func.clear_cache() + + # First call + result1 = test_func(5) + assert result1 == 10 + + # Second call while fresh + result2 = test_func(5) + assert result2 == 10 + + # Wait for cache to become stale + time.sleep(0.15) + + # Third call when stale - should trigger recalculation + result3 = test_func(5) + assert result3 == 10 + + stats = test_func.metrics.get_stats() + assert stats.stale_hits >= 1 + assert stats.recalculations >= 2 # Initial + stale recalculation + + test_func.clear_cache() + + +@pytest.mark.memory +def test_metrics_latency_tracking(): + """Test that operation latencies are tracked.""" + + @cachier(backend="memory", enable_metrics=True) + def slow_func(x): + time.sleep(0.05) # 50ms + return x * 2 + + slow_func.clear_cache() + + # First call (miss with computation) + slow_func(5) + + stats = slow_func.metrics.get_stats() + # Should have some latency recorded + assert stats.avg_latency_ms > 0 + + # Second call (hit, should be faster) + slow_func(5) + + stats = slow_func.metrics.get_stats() + # Average should still be positive + assert stats.avg_latency_ms > 0 + + slow_func.clear_cache() + + +@pytest.mark.memory +def test_metrics_recalculation_tracking(): + """Test that recalculations are tracked.""" + + @cachier(backend="memory", enable_metrics=True) + def test_func(x): + return x * 2 + + test_func.clear_cache() + + # First call + test_func(5) + stats = test_func.metrics.get_stats() + assert stats.recalculations == 1 + + # Cached call + test_func(5) + stats = test_func.metrics.get_stats() + assert stats.recalculations == 1 # No change + + # Force recalculation + test_func(5, cachier__overwrite_cache=True) + stats = test_func.metrics.get_stats() + assert stats.recalculations == 2 + + test_func.clear_cache() + 
@pytest.mark.memory
def test_metrics_sampling_rate():
    """A lower sampling rate should record fewer calls."""

    @cachier(backend="memory", enable_metrics=True, metrics_sampling_rate=1.0)
    def func_full_sampling(x):
        return x * 2

    @cachier(backend="memory", enable_metrics=True, metrics_sampling_rate=0.5)
    def func_partial_sampling(x):
        return x * 2

    func_full_sampling.clear_cache()
    func_partial_sampling.clear_cache()

    for call_idx in range(100):
        func_full_sampling(call_idx % 10)
        func_partial_sampling(call_idx % 10)

    full_stats = func_full_sampling.metrics.get_stats()
    partial_stats = func_partial_sampling.metrics.get_stats()

    # Full sampling tracks (nearly) every call; partial tracks fewer.
    assert full_stats.total_calls >= 90  # allow some variance
    assert partial_stats.total_calls < full_stats.total_calls

    func_full_sampling.clear_cache()
    func_partial_sampling.clear_cache()


@pytest.mark.memory
def test_metrics_thread_safety():
    """Concurrent callers must not corrupt the metrics counters."""

    @cachier(backend="memory", enable_metrics=True)
    def test_func(x):
        time.sleep(0.001)  # tiny delay to encourage interleaving
        return x * 2

    test_func.clear_cache()

    def hammer():
        for i in range(10):
            test_func(i % 5)

    workers = [Thread(target=hammer) for _ in range(5)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    snapshot = test_func.metrics.get_stats()
    # Every thread's calls must be accounted for, and the counters
    # must remain internally consistent.
    assert snapshot.total_calls > 0
    assert snapshot.hits + snapshot.misses == snapshot.total_calls

    test_func.clear_cache()


@pytest.mark.memory
def test_metrics_reset():
    """reset() must zero out all accumulated counters."""

    @cachier(backend="memory", enable_metrics=True)
    def test_func(x):
        return x * 2

    test_func.clear_cache()

    test_func(5)
    test_func(5)
    assert test_func.metrics.get_stats().total_calls > 0

    test_func.metrics.reset()

    cleared = test_func.metrics.get_stats()
    assert cleared.total_calls == 0
    assert cleared.hits == 0
    assert cleared.misses == 0

    test_func.clear_cache()


@pytest.mark.memory
def test_metrics_get_stats_snapshot():
    """get_stats must return a MetricSnapshot exposing every field."""

    @cachier(backend="memory", enable_metrics=True)
    def test_func(x):
        return x * 2

    test_func.clear_cache()

    test_func(5)
    test_func(5)

    snapshot = test_func.metrics.get_stats()
    assert isinstance(snapshot, MetricSnapshot)
    expected_fields = (
        "hits",
        "misses",
        "hit_rate",
        "total_calls",
        "avg_latency_ms",
        "stale_hits",
        "recalculations",
        "wait_timeouts",
        "entry_count",
        "total_size_bytes",
        "size_limit_rejections",
    )
    for field_name in expected_fields:
        assert hasattr(snapshot, field_name)

    test_func.clear_cache()


@pytest.mark.memory
def test_metrics_with_different_backends():
    """Each decorated function keeps its own independent metrics."""

    @cachier(backend="memory", enable_metrics=True)
    def memory_func(x):
        return x * 2

    @cachier(backend="pickle", enable_metrics=True)
    def pickle_func(x):
        return x * 3

    memory_func.clear_cache()
    pickle_func.clear_cache()

    memory_func(5)
    memory_func(5)
    pickle_func(5)
    pickle_func(5)

    mem_stats = memory_func.metrics.get_stats()
    pkl_stats = pickle_func.metrics.get_stats()

    # Both functions tracked their calls independently.
    assert mem_stats.total_calls == 2
    assert mem_stats.hits == 1
    assert pkl_stats.total_calls == 2
    assert pkl_stats.hits == 1

    memory_func.clear_cache()
    pickle_func.clear_cache()


def test_cache_metrics_invalid_sampling_rate():
    """Out-of-range sampling rates must raise ValueError."""
    with pytest.raises(ValueError, match="sampling_rate must be between"):
        CacheMetrics(sampling_rate=1.5)

    with pytest.raises(ValueError, match="sampling_rate must be between"):
        CacheMetrics(sampling_rate=-0.1)


@pytest.mark.memory
def test_metrics_size_limit_rejection():
    """Entries too large for entry_size_limit must count as rejections."""

    @cachier(backend="memory", enable_metrics=True, entry_size_limit="1KB")
    def test_func(n):
        # Produce a payload well past the 1KB limit.
        return "x" * (n * 1000)

    test_func.clear_cache()

    oversized = test_func(10)
    assert len(oversized) == 10000
    # The oversized value must have been refused by the cache.
    assert test_func.metrics.get_stats().size_limit_rejections >= 1

    test_func.clear_cache()


@pytest.mark.memory
def test_metrics_with_max_age():
    """A per-call max_age that marks the entry stale must be tracked."""

    @cachier(backend="memory", enable_metrics=True)
    def test_func(x):
        return x * 2

    test_func.clear_cache()

    test_func(5)
    # A negative max_age renders any cached entry stale immediately.
    test_func(5, max_age=timedelta(seconds=-1))

    snapshot = test_func.metrics.get_stats()
    assert snapshot.stale_hits >= 1
    assert snapshot.recalculations >= 2

    test_func.clear_cache()


@pytest.mark.memory
@pytest.mark.asyncio
async def test_metrics_async_hit_miss():
    """Async cached functions must track hits and misses as well."""

    @cachier(backend="memory", enable_metrics=True)
    async def async_func(x):
        await asyncio.sleep(0)
        return x * 2

    await async_func.clear_cache()

    assert await async_func(5) == 10
    after_miss = async_func.metrics.get_stats()
    assert after_miss.misses == 1
    assert after_miss.hits == 0

    assert await async_func(5) == 10
    after_hit = async_func.metrics.get_stats()
    assert after_hit.hits == 1
    assert after_hit.misses == 1
    assert after_hit.total_calls == 2
    assert after_hit.hit_rate == 50.0

    await async_func.clear_cache()


@pytest.mark.memory
@pytest.mark.asyncio
async def test_metrics_async_stale():
    """Stale-hit tracking must also work for async cached functions."""

    @cachier(
        backend="memory",
        enable_metrics=True,
        stale_after=timedelta(milliseconds=100),
    )
    async def async_func(x):
        await asyncio.sleep(0)
        return x * 2

    await async_func.clear_cache()

    await async_func(5)
    await asyncio.sleep(0.15)  # let the cached entry expire
    await async_func(5)

    snapshot = async_func.metrics.get_stats()
    assert snapshot.stale_hits >= 1
    assert snapshot.recalculations >= 2

    await async_func.clear_cache()


def test_metrics_zero_sampling_rate():
    """With sampling_rate=0.0 every record_* call must be a no-op."""
    metrics = CacheMetrics(sampling_rate=0.0)
    metrics.record_hit()
    metrics.record_miss()
    metrics.record_stale_hit()
    metrics.record_wait_timeout()
    metrics.record_size_limit_rejection()
    metrics.record_latency(0.1)

    snapshot = metrics.get_stats()
    assert snapshot.total_calls == 0
    assert snapshot.stale_hits == 0
    assert snapshot.wait_timeouts == 0
    assert snapshot.size_limit_rejections == 0
    assert snapshot.avg_latency_ms == 0.0


def test_metrics_get_stats_zero_window():
    """A zero-second window must behave like no window at all.

    timedelta(seconds=0) is falsy in Python, so the implementation
    treats it the same as None (all-time statistics), including all
    recorded data.
    """
    metrics = CacheMetrics()
    metrics.record_latency(0.05)
    snapshot = metrics.get_stats(window=timedelta(seconds=0))
    # Falsy window -> cutoff falls back to 0 -> every sample included.
    assert snapshot.avg_latency_ms == pytest.approx(50.0, rel=0.1)


def test_metrics_empty_window_sizes():
    """CacheMetrics must still count with an empty window_sizes list."""
    metrics = CacheMetrics(window_sizes=[])
    metrics.record_hit()
    assert metrics.get_stats().hits == 1


def test_metrics_wait_timeout_direct():
    """record_wait_timeout must bump the wait_timeouts counter."""
    metrics = CacheMetrics()
    metrics.record_wait_timeout()
    assert metrics.get_stats().wait_timeouts == 1


def test_should_sample_deterministic():
    """_should_sample compares the RNG draw against the sampling rate."""
    from unittest.mock import patch

    metrics = CacheMetrics(sampling_rate=0.5)
    # Draw below the rate -> sampled; draw above the rate -> skipped.
    with patch.object(metrics._random, "random", return_value=0.1):
        assert metrics._should_sample() is True
    with patch.object(metrics._random, "random", return_value=0.9):
        assert metrics._should_sample() is False


def test_metrics_context_manager():
    """MetricsContext must record elapsed latency on exit."""
    metrics = CacheMetrics()
    with MetricsContext(metrics):
        time.sleep(0.01)
    assert metrics.get_stats().avg_latency_ms > 0


def test_metrics_context_manager_none():
    """MetricsContext must tolerate metrics=None without raising."""
    with MetricsContext(None):
        pass  # entering and exiting must be a silent no-op


def test_metrics_context_record_wait_timeout():
    """MetricsContext.record_wait_timeout forwards to the metrics object."""
    metrics = CacheMetrics()
    MetricsContext(metrics).record_wait_timeout()
    assert metrics.get_stats().wait_timeouts == 1


def test_metrics_context_record_size_limit_rejection():
    """record_size_limit_rejection works with and without a metrics object."""
    metrics = CacheMetrics()
    MetricsContext(metrics).record_size_limit_rejection()
    assert metrics.get_stats().size_limit_rejections == 1

    # With metrics=None the call must simply be a no-op.
    MetricsContext(None).record_size_limit_rejection()


@pytest.mark.memory
def test_metrics_entry_count_and_size_memory():
    """The memory backend reports real entry counts and byte sizes.

    _MemoryCore overrides _get_entry_count and _get_total_size; both
    should return real values once entries have been written.
    """

    @cachier(backend="memory", enable_metrics=True)
    def test_func(x):
        return x * 2

    test_func.clear_cache()

    # A freshly cleared cache holds nothing.
    empty = test_func.metrics.get_stats()
    assert empty.entry_count == 0
    assert empty.total_size_bytes == 0

    # Two distinct arguments create two entries.
    test_func(1)
    test_func(2)

    populated = test_func.metrics.get_stats()
    assert populated.entry_count == 2
    assert populated.total_size_bytes > 0

    test_func.clear_cache()


@pytest.mark.pickle
def test_metrics_entry_count_and_size_base_default():
    """Backends without size hooks must report zero entries/bytes.

    The base-class _get_entry_count and _get_total_size return 0, and
    the pickle core does not override them, so the snapshot values must
    stay at the default.
    """

    @cachier(backend="pickle", enable_metrics=True)
    def test_func(x):
        return x * 2

    test_func.clear_cache()

    test_func(1)
    test_func(2)

    snapshot = test_func.metrics.get_stats()
    assert snapshot.entry_count == 0
    assert snapshot.total_size_bytes == 0

    test_func.clear_cache()