Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,14 @@ jobs:
# Wait for server to be ready
echo "Waiting for server to start..."
sleep 15

# Test server health
curl -s http://localhost:8000/health || echo "Server health check failed"
env:
OPTILLM_API_KEY: optillm
# Bound generation for the small test model, which does not reliably emit
# an EOS token and would otherwise ramble up to the 4096 default per call.
OPTILLM_MAX_TOKENS: "128"
HF_TOKEN: ${{ secrets.HF_TOKEN }}

- name: Run integration tests (server required)
Expand Down Expand Up @@ -217,6 +220,9 @@ jobs:
curl -s http://localhost:8000/health || echo "Server health check failed"
env:
OPTILLM_API_KEY: optillm
# Bound generation for the small test model, which does not reliably emit
# an EOS token and would otherwise ramble up to the 4096 default per call.
OPTILLM_MAX_TOKENS: "128"
HF_TOKEN: ${{ secrets.HF_TOKEN }}

- name: Run conversation logging tests
Expand Down
2 changes: 1 addition & 1 deletion optillm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Version information
__version__ = "0.3.19"
__version__ = "0.3.20"

import os as _os

Expand Down
69 changes: 60 additions & 9 deletions optillm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass
from dataclasses import dataclass, field
from collections import OrderedDict, defaultdict
import torch.nn.functional as F
import torch.nn as nn
Expand Down Expand Up @@ -89,6 +89,27 @@ def count_reasoning_tokens(text: str, tokenizer=None) -> int:
MLX_AVAILABLE = False
logger.debug("MLX framework not available - falling back to PyTorch")


# Hard ceiling of 4096 by default. Can be lowered via OPTILLM_MAX_TOKENS so a
# single local generation is bounded even when the request (or an approach's
# internal calls) sends no max_tokens -- important for small local models that
# do not reliably emit an EOS token (e.g. the dhara test model), which would
# otherwise ramble up to the full default on every call.
DEFAULT_MAX_NEW_TOKENS = 4096


def _default_max_new_tokens() -> int:
"""Default ``max_new_tokens`` for local generation (env-overridable)."""
raw = os.environ.get("OPTILLM_MAX_TOKENS")
if raw is None:
return DEFAULT_MAX_NEW_TOKENS
try:
return max(1, int(raw))
except (TypeError, ValueError):
logger.warning("Ignoring invalid OPTILLM_MAX_TOKENS=%r; using %d", raw, DEFAULT_MAX_NEW_TOKENS)
return DEFAULT_MAX_NEW_TOKENS


@dataclass
class ModelConfig:
base_model_id: str
Expand All @@ -98,7 +119,7 @@ class ModelConfig:
quantization_bits: int = 4
device_preference: Optional[str] = None
# Default generation parameters
max_new_tokens: int = 4096
max_new_tokens: int = field(default_factory=_default_max_new_tokens)
do_sample: bool = True
top_p: float = 0.9
top_k: int = 50
Expand Down Expand Up @@ -292,7 +313,7 @@ def suggest_mlx_alternative(model_id: str) -> str:
class MLXModelConfig:
"""Configuration for MLX models"""
model_id: str
max_new_tokens: int = 4096
max_new_tokens: int = field(default_factory=_default_max_new_tokens)
temperature: float = 0.7
top_p: float = 0.9
repetition_penalty: float = 1.0
Expand Down Expand Up @@ -1268,16 +1289,46 @@ def setup_tokenizer(self, tokenizer: AutoTokenizer) -> AutoTokenizer:

return tokenizer

def _resolve_eos_token_ids(self):
"""Resolve the effective end-of-sequence token id(s) for generation.

Prefer the model's own ``generation_config.eos_token_id``. Chat models
commonly set it to the chat-turn end token (e.g. ``<|im_end|>``), which
can differ from the tokenizer's ``eos_token_id`` (often the base-model
``<|end_of_text|>``). Passing only the tokenizer eos to ``generate`` there
means the model never stops on its real turn-end token and rambles up to
``max_new_tokens`` -- e.g. dhara-250m's ChatML ends at ``<|im_end|>`` but
its tokenizer eos is ``<|end_of_text|>``.

The tokenizer eos is merged in as a fallback so a model that only emits
the base eos still terminates. Returns an int, a list of ints, or None.
"""
ids: List[int] = []
gen_cfg = getattr(self.current_model, "generation_config", None)
gc_eos = getattr(gen_cfg, "eos_token_id", None) if gen_cfg is not None else None
if isinstance(gc_eos, int):
ids.append(gc_eos)
elif isinstance(gc_eos, (list, tuple)):
ids.extend(int(x) for x in gc_eos if isinstance(x, int))
tok_eos = self.tokenizer.eos_token_id
if isinstance(tok_eos, int):
ids.append(tok_eos)
seen = set()
resolved = [x for x in ids if not (x in seen or seen.add(x))]
if not resolved:
return None
return resolved[0] if len(resolved) == 1 else resolved

def get_optimized_generation_config(self, generation_params: Optional[Dict[str, Any]] = None) -> Dict:
"""Get optimized generation config"""
config = {
"max_new_tokens": generation_params.get("max_new_tokens", 4096),
"max_new_tokens": generation_params.get("max_new_tokens", _default_max_new_tokens()),
"do_sample": generation_params.get("temperature", 1.0) > 0,
"temperature": generation_params.get("temperature", 1.0),
"top_p": generation_params.get("top_p", 0.95),
"num_return_sequences": generation_params.get("num_return_sequences", 1),
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self._resolve_eos_token_ids(),
"return_dict_in_generate": True,
"output_scores": generation_params.get("logprobs", False),
"use_cache": True
Expand Down Expand Up @@ -1571,13 +1622,13 @@ def process_batch(
if batch_prompts: # If there are any uncached prompts
# Configure generation parameters
base_params = {
"max_new_tokens": generation_params.get("max_new_tokens", 4096) if generation_params else self.model_config.max_new_tokens,
"max_new_tokens": generation_params.get("max_new_tokens", _default_max_new_tokens()) if generation_params else self.model_config.max_new_tokens,
"do_sample": generation_params.get("temperature", 1.0) > 0 if generation_params else self.model_config.do_sample,
"temperature": generation_params.get("temperature", 1.0) if generation_params else self.model_config.temperature,
"top_p": generation_params.get("top_p", 1.0) if generation_params else self.model_config.top_p,
"num_return_sequences": n,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self._resolve_eos_token_ids(),
}

# Add optional parameters if specified
Expand Down Expand Up @@ -1900,7 +1951,7 @@ def create(

# Use directly available parameters for entropy decoding
entropy_params = {
"max_new_tokens": max_tokens if max_tokens is not None else 4096,
"max_new_tokens": max_tokens if max_tokens is not None else _default_max_new_tokens(),
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
Expand Down Expand Up @@ -2046,7 +2097,7 @@ def create(
"temperature": temperature,
"top_p": top_p,
"num_return_sequences": n,
"max_new_tokens": max_tokens if max_tokens is not None else 4096,
"max_new_tokens": max_tokens if max_tokens is not None else _default_max_new_tokens(),
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"stop_sequences": [stop] if isinstance(stop, str) else stop,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "optillm"
version = "0.3.19"
version = "0.3.20"
description = "An optimizing inference proxy for LLMs."
readme = "README.md"
license = "Apache-2.0"
Expand Down
64 changes: 64 additions & 0 deletions tests/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,70 @@ def run_performance_comparison():
}


class TestGenerationConfigDefaults(unittest.TestCase):
"""Unit tests for env-configurable max_new_tokens and EOS resolution.

These exercise the guards that keep a small local model from rambling up to
the 4096-token default when it does not reliably emit an EOS token (e.g. a
ChatML model whose tokenizer EOS differs from the chat-turn end token). No
model is loaded.
"""

def tearDown(self):
os.environ.pop("OPTILLM_MAX_TOKENS", None)

def test_default_max_new_tokens_env_override(self):
from optillm.inference import _default_max_new_tokens, DEFAULT_MAX_NEW_TOKENS

os.environ.pop("OPTILLM_MAX_TOKENS", None)
self.assertEqual(_default_max_new_tokens(), DEFAULT_MAX_NEW_TOKENS)

os.environ["OPTILLM_MAX_TOKENS"] = "128"
self.assertEqual(_default_max_new_tokens(), 128)

# Invalid value falls back to the default rather than raising.
os.environ["OPTILLM_MAX_TOKENS"] = "not-a-number"
self.assertEqual(_default_max_new_tokens(), DEFAULT_MAX_NEW_TOKENS)

# Non-positive is clamped to a usable minimum.
os.environ["OPTILLM_MAX_TOKENS"] = "0"
self.assertEqual(_default_max_new_tokens(), 1)

def test_resolve_eos_prefers_generation_config(self):
from types import SimpleNamespace
from optillm.inference import InferencePipeline

# tokenizer EOS (<|end_of_text|>=1) differs from the chat end token
# (<|im_end|>=49154); both must be honoured, generation_config first.
fake = SimpleNamespace(
current_model=SimpleNamespace(generation_config=SimpleNamespace(eos_token_id=49154)),
tokenizer=SimpleNamespace(eos_token_id=1),
)
eos = InferencePipeline._resolve_eos_token_ids(fake)
self.assertEqual(eos, [49154, 1])

def test_resolve_eos_dedupes_list(self):
from types import SimpleNamespace
from optillm.inference import InferencePipeline

fake = SimpleNamespace(
current_model=SimpleNamespace(generation_config=SimpleNamespace(eos_token_id=[100, 200])),
tokenizer=SimpleNamespace(eos_token_id=200),
)
self.assertEqual(InferencePipeline._resolve_eos_token_ids(fake), [100, 200])

def test_resolve_eos_falls_back_to_tokenizer(self):
from types import SimpleNamespace
from optillm.inference import InferencePipeline

fake = SimpleNamespace(
current_model=SimpleNamespace(generation_config=SimpleNamespace(eos_token_id=None)),
tokenizer=SimpleNamespace(eos_token_id=7),
)
# A single id is returned as a plain int, not a list.
self.assertEqual(InferencePipeline._resolve_eos_token_ids(fake), 7)


if __name__ == "__main__":
# Run tests
unittest.main(verbosity=2)
Loading