diff --git a/model2vec/train/__init__.py b/model2vec/train/__init__.py index 50361cd..4d265e3 100644 --- a/model2vec/train/__init__.py +++ b/model2vec/train/__init__.py @@ -8,10 +8,11 @@ importable(extra_dependency, _REQUIRED_EXTRA) from model2vec.train.classifier import StaticModelForClassification +from model2vec.train.regression import StaticModelForRegression from model2vec.train.similarity import StaticModelForSimilarity from model2vec.train.utils import TipFilter -__all__ = ["StaticModelForClassification", "StaticModelForSimilarity"] +__all__ = ["StaticModelForClassification", "StaticModelForSimilarity", "StaticModelForRegression"] logging.getLogger("lightning.pytorch.utilities.rank_zero").addFilter(TipFilter()) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 12c74d8..83f0889 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -14,7 +14,7 @@ from model2vec.inference import evaluate_single_or_multi_label from model2vec.train.base import BaseFinetuneable from model2vec.train.lightning_modules import ClassifierLightningModule, MultiLabelClassifierLightningModule -from model2vec.train.utils import _DEFAULT_RANDOM_SEED +from model2vec.train.utils import DEFAULT_RANDOM_SEED logger = logging.getLogger(__name__) @@ -123,7 +123,7 @@ def fit( y_val: LabelType | None = None, class_weight: Literal["balanced"] | dict[str, float] | torch.Tensor | None = None, validation_steps: int | None = None, - random_seed: int = _DEFAULT_RANDOM_SEED, + random_seed: int = DEFAULT_RANDOM_SEED, ) -> StaticModelForClassification: """Fit a model. 
diff --git a/model2vec/train/lightning_modules.py b/model2vec/train/lightning_modules.py index cd0db93..b627193 100644 --- a/model2vec/train/lightning_modules.py +++ b/model2vec/train/lightning_modules.py @@ -7,19 +7,13 @@ from torch import nn -class StaticLightningModule(pl.LightningModule): +class RegressionLightningModule(pl.LightningModule): def __init__(self, model: nn.Module, learning_rate: float) -> None: """Initialize the LightningModule.""" super().__init__() self.model = model self.learning_rate = learning_rate - self.loss_function = self.cosine_distance - - def cosine_distance(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Returns the cosine distance loss function.""" - x = torch.nn.functional.normalize(x, dim=1) - y = torch.nn.functional.normalize(y, dim=1) - return (1 - torch.sum(x * y, dim=1)).mean() + self.loss_function = nn.MSELoss() def forward(self, x: torch.Tensor) -> torch.Tensor: """Simple forward pass.""" @@ -57,7 +51,24 @@ def configure_optimizers(self) -> OptimizerLRScheduler: return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}} -class ClassifierLightningModule(StaticLightningModule): +class SimilarityLightningModule(RegressionLightningModule): + def __init__(self, model: nn.Module, learning_rate: float) -> None: + """Initialize the LightningModule.""" + super().__init__(model, learning_rate) + self.model = model + self.learning_rate = learning_rate + self.loss_function = CosineLoss() + + +class CosineLoss(nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Return the mean cosine distance between the rows of x and y.""" + x = torch.nn.functional.normalize(x, dim=1) + y = torch.nn.functional.normalize(y, dim=1) + return (1 - torch.sum(x * y, dim=1)).mean() + + +class ClassifierLightningModule(RegressionLightningModule): def __init__(self, model: nn.Module, learning_rate: float, class_weight: torch.Tensor | None = None) -> None: """Initialize the LightningModule.""" 
super().__init__(model, learning_rate) @@ -77,7 +88,7 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: i return loss -class MultiLabelClassifierLightningModule(StaticLightningModule): +class MultiLabelClassifierLightningModule(RegressionLightningModule): def __init__(self, model: nn.Module, learning_rate: float, class_weight: torch.Tensor | None = None) -> None: """Initialize the LightningModule.""" super().__init__(model, learning_rate) diff --git a/model2vec/train/regression.py b/model2vec/train/regression.py new file mode 100644 index 0000000..5eef7eb --- /dev/null +++ b/model2vec/train/regression.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import logging + +from model2vec.train.lightning_modules import RegressionLightningModule +from model2vec.train.similarity import StaticModelForSimilarity + +logger = logging.getLogger(__name__) + + +class StaticModelForRegression(StaticModelForSimilarity): + _lightning_class = RegressionLightningModule diff --git a/model2vec/train/similarity.py b/model2vec/train/similarity.py index d61f55b..87b97d1 100644 --- a/model2vec/train/similarity.py +++ b/model2vec/train/similarity.py @@ -1,14 +1,15 @@ from __future__ import annotations import logging +from typing import TypeVar import lightning as pl import torch from tokenizers import Tokenizer from model2vec.train.base import BaseFinetuneable -from model2vec.train.lightning_modules import StaticLightningModule -from model2vec.train.utils import _DEFAULT_RANDOM_SEED +from model2vec.train.lightning_modules import SimilarityLightningModule +from model2vec.train.utils import DEFAULT_RANDOM_SEED logger = logging.getLogger(__name__) @@ -16,6 +17,7 @@ class StaticModelForSimilarity(BaseFinetuneable): val_metric = "val_loss" early_stopping_direction = "min" + _lightning_class = SimilarityLightningModule def __init__( self, @@ -48,7 +50,7 @@ def __init__( ) def fit( - self, + self: _T, X: list[str], y: torch.Tensor, learning_rate: float = 
1e-3, @@ -61,8 +63,8 @@ def fit( X_val: list[str] | None = None, y_val: torch.Tensor | None = None, validation_steps: int | None = None, - random_seed: int = _DEFAULT_RANDOM_SEED, - ) -> StaticModelForSimilarity: + random_seed: int = DEFAULT_RANDOM_SEED, + ) -> _T: """Fit a model. This function creates a Lightning Trainer object and fits the model to the data. @@ -100,7 +102,7 @@ def fit( self.out_dim = train_dataset.targets.shape[1] self._initialize() - c = StaticLightningModule(self, learning_rate=learning_rate) + c = self._lightning_class(self, learning_rate=learning_rate) self._train( module=c, @@ -115,3 +117,6 @@ def fit( ) return self + + +_T = TypeVar("_T", bound=StaticModelForSimilarity) diff --git a/model2vec/train/utils.py b/model2vec/train/utils.py index 91f5a85..450a2fe 100644 --- a/model2vec/train/utils.py +++ b/model2vec/train/utils.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) -_DEFAULT_RANDOM_SEED = 42 +DEFAULT_RANDOM_SEED = 42 _KNOWN_PAD_TOKENS = ("[PAD]", "") diff --git a/tests/conftest.py b/tests/conftest.py index de60201..2f45eef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ from model2vec.inference import StaticModelPipeline from model2vec.model import StaticModel -from model2vec.train import StaticModelForClassification, StaticModelForSimilarity +from model2vec.train import StaticModelForClassification, StaticModelForRegression, StaticModelForSimilarity _TOKENIZER_TYPES = ["wordpiece", "bpe", "unigram"] @@ -199,6 +199,21 @@ def mock_trained_similarity_pipeline() -> StaticModelForSimilarity: return model +@pytest.fixture(scope="session") +def mock_trained_regression_pipeline() -> StaticModelForRegression: + """Mock StaticModelForRegression.""" + tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer + torch.random.manual_seed(42) + vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) + model = StaticModelForRegression(vectors=vectors_torched, tokenizer=tokenizer, 
hidden_dim=12).to("cpu") + + X = ["dog", "cat"] + y = torch.randn(2, 32) + model.fit(X, y) + + return model + + @pytest.fixture(scope="session") def mock_inference_pipeline_projector( mock_trained_similarity_pipeline: StaticModelForSimilarity, diff --git a/tests/test_trainable.py b/tests/test_trainable.py index 3541ef8..5e700ea 100644 --- a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -12,6 +12,7 @@ from model2vec.train import StaticModelForClassification from model2vec.train.base import BaseFinetuneable from model2vec.train.dataset import TextDataset +from model2vec.train.regression import StaticModelForRegression from model2vec.train.similarity import StaticModelForSimilarity from model2vec.train.utils import get_probable_pad_token_id, logit, train_test_split @@ -195,6 +196,22 @@ def test_convert_to_pipeline_similarity(mock_trained_similarity_pipeline: Static assert np.allclose(p1, p2, rtol=1e-5, atol=1e-4) +def test_convert_to_pipeline_regression(mock_trained_regression_pipeline: StaticModelForRegression) -> None: + """Convert a trained regression model to a pipeline.""" + mock_trained_regression_pipeline.eval() + pipeline = mock_trained_regression_pipeline.to_pipeline() + encoded_pipeline = pipeline.model.encode(["dog cat", "dog"]) + encoded_model = ( + mock_trained_regression_pipeline(mock_trained_regression_pipeline.tokenize(["dog cat", "dog"]))[1] + .detach() + .numpy() + ) + assert np.allclose(encoded_pipeline, encoded_model) + p1 = pipeline.predict(["dog cat", "dog"]) + p2 = mock_trained_regression_pipeline.encode(["dog cat", "dog"]) + assert np.allclose(p1, p2, rtol=1e-5, atol=1e-4) + + def test_train_test_split() -> None: """Test the train test split function.""" a, b, c, d = train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5) diff --git a/uv.lock b/uv.lock index f5e06e4..000d536 100644 --- a/uv.lock +++ b/uv.lock @@ -9,8 +9,8 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-26T18:35:34.031819Z" -exclude-newer-span = "P1W" 
+exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. +exclude-newer-span = "P3D" [[package]] name = "aiohappyeyeballs"