Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion model2vec/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
importable(extra_dependency, _REQUIRED_EXTRA)

from model2vec.train.classifier import StaticModelForClassification
from model2vec.train.regression import StaticModelForRegression
from model2vec.train.similarity import StaticModelForSimilarity
from model2vec.train.utils import TipFilter

__all__ = ["StaticModelForClassification", "StaticModelForSimilarity"]
__all__ = ["StaticModelForClassification", "StaticModelForSimilarity", "StaticModelForRegression"]


logging.getLogger("lightning.pytorch.utilities.rank_zero").addFilter(TipFilter())
4 changes: 2 additions & 2 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from model2vec.inference import evaluate_single_or_multi_label
from model2vec.train.base import BaseFinetuneable
from model2vec.train.lightning_modules import ClassifierLightningModule, MultiLabelClassifierLightningModule
from model2vec.train.utils import _DEFAULT_RANDOM_SEED
from model2vec.train.utils import DEFAULT_RANDOM_SEED

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,7 +123,7 @@ def fit(
y_val: LabelType | None = None,
class_weight: Literal["balanced"] | dict[str, float] | torch.Tensor | None = None,
validation_steps: int | None = None,
random_seed: int = _DEFAULT_RANDOM_SEED,
random_seed: int = DEFAULT_RANDOM_SEED,
) -> StaticModelForClassification:
"""Fit a model.

Expand Down
31 changes: 21 additions & 10 deletions model2vec/train/lightning_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,13 @@
from torch import nn


class StaticLightningModule(pl.LightningModule):
class RegressionLightningModule(pl.LightningModule):
def __init__(self, model: nn.Module, learning_rate: float) -> None:
"""Initialize the LightningModule."""
super().__init__()
self.model = model
self.learning_rate = learning_rate
self.loss_function = self.cosine_distance

def cosine_distance(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Returns the cosine distance loss function."""
x = torch.nn.functional.normalize(x, dim=1)
y = torch.nn.functional.normalize(y, dim=1)
return (1 - torch.sum(x * y, dim=1)).mean()
self.loss_function = nn.MSELoss()

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Simple forward pass."""
Expand Down Expand Up @@ -57,7 +51,24 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}


class ClassifierLightningModule(StaticLightningModule):
class SimilarityLightningModule(RegressionLightningModule):
def __init__(self, model: nn.Module, learning_rate: float) -> None:
"""Initialize the LightningModule."""
super().__init__(model, learning_rate)
self.model = model
self.learning_rate = learning_rate
self.loss_function = CosineLoss()


class CosineLoss(nn.Module):
def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Returns the cosine distance loss function."""
x = torch.nn.functional.normalize(x, dim=1)
y = torch.nn.functional.normalize(y, dim=1)
return (1 - torch.sum(x * y, dim=1)).mean()


class ClassifierLightningModule(RegressionLightningModule):
def __init__(self, model: nn.Module, learning_rate: float, class_weight: torch.Tensor | None = None) -> None:
"""Initialize the LightningModule."""
super().__init__(model, learning_rate)
Expand All @@ -77,7 +88,7 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: i
return loss


class MultiLabelClassifierLightningModule(StaticLightningModule):
class MultiLabelClassifierLightningModule(RegressionLightningModule):
def __init__(self, model: nn.Module, learning_rate: float, class_weight: torch.Tensor | None = None) -> None:
"""Initialize the LightningModule."""
super().__init__(model, learning_rate)
Expand Down
12 changes: 12 additions & 0 deletions model2vec/train/regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

import logging

from model2vec.train.lightning_modules import RegressionLightningModule
from model2vec.train.similarity import StaticModelForSimilarity

logger = logging.getLogger(__name__)


class StaticModelForRegression(StaticModelForSimilarity):
_lightning_class = RegressionLightningModule
17 changes: 11 additions & 6 deletions model2vec/train/similarity.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from __future__ import annotations

import logging
from typing import TypeVar

import lightning as pl
import torch
from tokenizers import Tokenizer

from model2vec.train.base import BaseFinetuneable
from model2vec.train.lightning_modules import StaticLightningModule
from model2vec.train.utils import _DEFAULT_RANDOM_SEED
from model2vec.train.lightning_modules import SimilarityLightningModule
from model2vec.train.utils import DEFAULT_RANDOM_SEED

logger = logging.getLogger(__name__)


class StaticModelForSimilarity(BaseFinetuneable):
val_metric = "val_loss"
early_stopping_direction = "min"
_lightning_class = SimilarityLightningModule

def __init__(
self,
Expand Down Expand Up @@ -48,7 +50,7 @@ def __init__(
)

def fit(
self,
self: _T,
X: list[str],
y: torch.Tensor,
learning_rate: float = 1e-3,
Expand All @@ -61,8 +63,8 @@ def fit(
X_val: list[str] | None = None,
y_val: torch.Tensor | None = None,
validation_steps: int | None = None,
random_seed: int = _DEFAULT_RANDOM_SEED,
) -> StaticModelForSimilarity:
random_seed: int = DEFAULT_RANDOM_SEED,
) -> _T:
"""Fit a model.

This function creates a Lightning Trainer object and fits the model to the data.
Expand Down Expand Up @@ -100,7 +102,7 @@ def fit(
self.out_dim = train_dataset.targets.shape[1]
self._initialize()

c = StaticLightningModule(self, learning_rate=learning_rate)
c = self._lightning_class(self, learning_rate=learning_rate)

self._train(
module=c,
Expand All @@ -115,3 +117,6 @@ def fit(
)

return self


_T = TypeVar("_T", bound=StaticModelForSimilarity)
2 changes: 1 addition & 1 deletion model2vec/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

logger = logging.getLogger(__name__)

_DEFAULT_RANDOM_SEED = 42
DEFAULT_RANDOM_SEED = 42
_KNOWN_PAD_TOKENS = ("[PAD]", "<pad>")


Expand Down
17 changes: 16 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from model2vec.inference import StaticModelPipeline
from model2vec.model import StaticModel
from model2vec.train import StaticModelForClassification, StaticModelForSimilarity
from model2vec.train import StaticModelForClassification, StaticModelForRegression, StaticModelForSimilarity

_TOKENIZER_TYPES = ["wordpiece", "bpe", "unigram"]

Expand Down Expand Up @@ -199,6 +199,21 @@ def mock_trained_similarity_pipeline() -> StaticModelForSimilarity:
return model


@pytest.fixture(scope="session")
def mock_trained_regression_pipeline() -> StaticModelForRegression:
"""Mock StaticModelForSimilarity."""
tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
torch.random.manual_seed(42)
vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
model = StaticModelForRegression(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")

X = ["dog", "cat"]
y = torch.randn(2, 32)
model.fit(X, y)

return model


@pytest.fixture(scope="session")
def mock_inference_pipeline_projector(
mock_trained_similarity_pipeline: StaticModelForSimilarity,
Expand Down
17 changes: 17 additions & 0 deletions tests/test_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from model2vec.train import StaticModelForClassification
from model2vec.train.base import BaseFinetuneable
from model2vec.train.dataset import TextDataset
from model2vec.train.regression import StaticModelForRegression
from model2vec.train.similarity import StaticModelForSimilarity
from model2vec.train.utils import get_probable_pad_token_id, logit, train_test_split

Expand Down Expand Up @@ -195,6 +196,22 @@ def test_convert_to_pipeline_similarity(mock_trained_similarity_pipeline: Static
assert np.allclose(p1, p2, rtol=1e-5, atol=1e-4)


def test_convert_to_pipeline_regression(mock_trained_similarity_pipeline: StaticModelForRegression) -> None:
"""Convert a model to a pipeline."""
mock_trained_similarity_pipeline.eval()
pipeline = mock_trained_similarity_pipeline.to_pipeline()
encoded_pipeline = pipeline.model.encode(["dog cat", "dog"])
encoded_model = (
mock_trained_similarity_pipeline(mock_trained_similarity_pipeline.tokenize(["dog cat", "dog"]))[1]
.detach()
.numpy()
)
assert np.allclose(encoded_pipeline, encoded_model)
p1 = pipeline.predict(["dog cat", "dog"])
p2 = mock_trained_similarity_pipeline.encode(["dog cat", "dog"])
assert np.allclose(p1, p2, rtol=1e-5, atol=1e-4)


def test_train_test_split() -> None:
"""Test the train test split function."""
a, b, c, d = train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5)
Expand Down
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading