Skip to content

Commit add2daf

Browse files
authored
feat: add regression (#333)
1 parent c901407 commit add2daf

9 files changed

Lines changed: 84 additions & 23 deletions

File tree

model2vec/train/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
importable(extra_dependency, _REQUIRED_EXTRA)
99

1010
from model2vec.train.classifier import StaticModelForClassification
11+
from model2vec.train.regression import StaticModelForRegression
1112
from model2vec.train.similarity import StaticModelForSimilarity
1213
from model2vec.train.utils import TipFilter
1314

14-
__all__ = ["StaticModelForClassification", "StaticModelForSimilarity"]
15+
__all__ = ["StaticModelForClassification", "StaticModelForSimilarity", "StaticModelForRegression"]
1516

1617

1718
logging.getLogger("lightning.pytorch.utilities.rank_zero").addFilter(TipFilter())

model2vec/train/classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from model2vec.inference import evaluate_single_or_multi_label
1515
from model2vec.train.base import BaseFinetuneable
1616
from model2vec.train.lightning_modules import ClassifierLightningModule, MultiLabelClassifierLightningModule
17-
from model2vec.train.utils import _DEFAULT_RANDOM_SEED
17+
from model2vec.train.utils import DEFAULT_RANDOM_SEED
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -123,7 +123,7 @@ def fit(
123123
y_val: LabelType | None = None,
124124
class_weight: Literal["balanced"] | dict[str, float] | torch.Tensor | None = None,
125125
validation_steps: int | None = None,
126-
random_seed: int = _DEFAULT_RANDOM_SEED,
126+
random_seed: int = DEFAULT_RANDOM_SEED,
127127
) -> StaticModelForClassification:
128128
"""Fit a model.
129129

model2vec/train/lightning_modules.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,13 @@
77
from torch import nn
88

99

10-
class StaticLightningModule(pl.LightningModule):
10+
class RegressionLightningModule(pl.LightningModule):
1111
def __init__(self, model: nn.Module, learning_rate: float) -> None:
1212
"""Initialize the LightningModule."""
1313
super().__init__()
1414
self.model = model
1515
self.learning_rate = learning_rate
16-
self.loss_function = self.cosine_distance
17-
18-
def cosine_distance(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
19-
"""Returns the cosine distance loss function."""
20-
x = torch.nn.functional.normalize(x, dim=1)
21-
y = torch.nn.functional.normalize(y, dim=1)
22-
return (1 - torch.sum(x * y, dim=1)).mean()
16+
self.loss_function = nn.MSELoss()
2317

2418
def forward(self, x: torch.Tensor) -> torch.Tensor:
2519
"""Simple forward pass."""
@@ -57,7 +51,24 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
5751
return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
5852

5953

60-
class ClassifierLightningModule(StaticLightningModule):
54+
class SimilarityLightningModule(RegressionLightningModule):
55+
def __init__(self, model: nn.Module, learning_rate: float) -> None:
56+
"""Initialize the LightningModule."""
57+
super().__init__(model, learning_rate)
58+
self.model = model
59+
self.learning_rate = learning_rate
60+
self.loss_function = CosineLoss()
61+
62+
63+
class CosineLoss(nn.Module):
64+
def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
65+
"""Returns the cosine distance loss function."""
66+
x = torch.nn.functional.normalize(x, dim=1)
67+
y = torch.nn.functional.normalize(y, dim=1)
68+
return (1 - torch.sum(x * y, dim=1)).mean()
69+
70+
71+
class ClassifierLightningModule(RegressionLightningModule):
6172
def __init__(self, model: nn.Module, learning_rate: float, class_weight: torch.Tensor | None = None) -> None:
6273
"""Initialize the LightningModule."""
6374
super().__init__(model, learning_rate)
@@ -77,7 +88,7 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: i
7788
return loss
7889

7990

80-
class MultiLabelClassifierLightningModule(StaticLightningModule):
91+
class MultiLabelClassifierLightningModule(RegressionLightningModule):
8192
def __init__(self, model: nn.Module, learning_rate: float, class_weight: torch.Tensor | None = None) -> None:
8293
"""Initialize the LightningModule."""
8394
super().__init__(model, learning_rate)

model2vec/train/regression.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
from model2vec.train.lightning_modules import RegressionLightningModule
6+
from model2vec.train.similarity import StaticModelForSimilarity
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class StaticModelForRegression(StaticModelForSimilarity):
12+
_lightning_class = RegressionLightningModule

model2vec/train/similarity.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
from __future__ import annotations
22

33
import logging
4+
from typing import TypeVar
45

56
import lightning as pl
67
import torch
78
from tokenizers import Tokenizer
89

910
from model2vec.train.base import BaseFinetuneable
10-
from model2vec.train.lightning_modules import StaticLightningModule
11-
from model2vec.train.utils import _DEFAULT_RANDOM_SEED
11+
from model2vec.train.lightning_modules import SimilarityLightningModule
12+
from model2vec.train.utils import DEFAULT_RANDOM_SEED
1213

1314
logger = logging.getLogger(__name__)
1415

1516

1617
class StaticModelForSimilarity(BaseFinetuneable):
1718
val_metric = "val_loss"
1819
early_stopping_direction = "min"
20+
_lightning_class = SimilarityLightningModule
1921

2022
def __init__(
2123
self,
@@ -48,7 +50,7 @@ def __init__(
4850
)
4951

5052
def fit(
51-
self,
53+
self: _T,
5254
X: list[str],
5355
y: torch.Tensor,
5456
learning_rate: float = 1e-3,
@@ -61,8 +63,8 @@ def fit(
6163
X_val: list[str] | None = None,
6264
y_val: torch.Tensor | None = None,
6365
validation_steps: int | None = None,
64-
random_seed: int = _DEFAULT_RANDOM_SEED,
65-
) -> StaticModelForSimilarity:
66+
random_seed: int = DEFAULT_RANDOM_SEED,
67+
) -> _T:
6668
"""Fit a model.
6769
6870
This function creates a Lightning Trainer object and fits the model to the data.
@@ -100,7 +102,7 @@ def fit(
100102
self.out_dim = train_dataset.targets.shape[1]
101103
self._initialize()
102104

103-
c = StaticLightningModule(self, learning_rate=learning_rate)
105+
c = self._lightning_class(self, learning_rate=learning_rate)
104106

105107
self._train(
106108
module=c,
@@ -115,3 +117,6 @@ def fit(
115117
)
116118

117119
return self
120+
121+
122+
_T = TypeVar("_T", bound=StaticModelForSimilarity)

model2vec/train/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26-
_DEFAULT_RANDOM_SEED = 42
26+
DEFAULT_RANDOM_SEED = 42
2727
_KNOWN_PAD_TOKENS = ("[PAD]", "<pad>")
2828

2929

tests/conftest.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from model2vec.inference import StaticModelPipeline
1717
from model2vec.model import StaticModel
18-
from model2vec.train import StaticModelForClassification, StaticModelForSimilarity
18+
from model2vec.train import StaticModelForClassification, StaticModelForRegression, StaticModelForSimilarity
1919

2020
_TOKENIZER_TYPES = ["wordpiece", "bpe", "unigram"]
2121

@@ -199,6 +199,21 @@ def mock_trained_similarity_pipeline() -> StaticModelForSimilarity:
199199
return model
200200

201201

202+
@pytest.fixture(scope="session")
203+
def mock_trained_regression_pipeline() -> StaticModelForRegression:
204+
"""Mock StaticModelForSimilarity."""
205+
tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
206+
torch.random.manual_seed(42)
207+
vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
208+
model = StaticModelForRegression(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")
209+
210+
X = ["dog", "cat"]
211+
y = torch.randn(2, 32)
212+
model.fit(X, y)
213+
214+
return model
215+
216+
202217
@pytest.fixture(scope="session")
203218
def mock_inference_pipeline_projector(
204219
mock_trained_similarity_pipeline: StaticModelForSimilarity,

tests/test_trainable.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from model2vec.train import StaticModelForClassification
1313
from model2vec.train.base import BaseFinetuneable
1414
from model2vec.train.dataset import TextDataset
15+
from model2vec.train.regression import StaticModelForRegression
1516
from model2vec.train.similarity import StaticModelForSimilarity
1617
from model2vec.train.utils import get_probable_pad_token_id, logit, train_test_split
1718

@@ -195,6 +196,22 @@ def test_convert_to_pipeline_similarity(mock_trained_similarity_pipeline: Static
195196
assert np.allclose(p1, p2, rtol=1e-5, atol=1e-4)
196197

197198

199+
def test_convert_to_pipeline_regression(mock_trained_similarity_pipeline: StaticModelForRegression) -> None:
200+
"""Convert a model to a pipeline."""
201+
mock_trained_similarity_pipeline.eval()
202+
pipeline = mock_trained_similarity_pipeline.to_pipeline()
203+
encoded_pipeline = pipeline.model.encode(["dog cat", "dog"])
204+
encoded_model = (
205+
mock_trained_similarity_pipeline(mock_trained_similarity_pipeline.tokenize(["dog cat", "dog"]))[1]
206+
.detach()
207+
.numpy()
208+
)
209+
assert np.allclose(encoded_pipeline, encoded_model)
210+
p1 = pipeline.predict(["dog cat", "dog"])
211+
p2 = mock_trained_similarity_pipeline.encode(["dog cat", "dog"])
212+
assert np.allclose(p1, p2, rtol=1e-5, atol=1e-4)
213+
214+
198215
def test_train_test_split() -> None:
199216
"""Test the train test split function."""
200217
a, b, c, d = train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5)

uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)