77from torch import nn
88
99
10- class StaticLightningModule (pl .LightningModule ):
10+ class RegressionLightningModule (pl .LightningModule ):
1111 def __init__ (self , model : nn .Module , learning_rate : float ) -> None :
1212 """Initialize the LightningModule."""
1313 super ().__init__ ()
1414 self .model = model
1515 self .learning_rate = learning_rate
16- self .loss_function = self .cosine_distance
17-
18- def cosine_distance (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
19- """Returns the cosine distance loss function."""
20- x = torch .nn .functional .normalize (x , dim = 1 )
21- y = torch .nn .functional .normalize (y , dim = 1 )
22- return (1 - torch .sum (x * y , dim = 1 )).mean ()
16+ self .loss_function = nn .MSELoss ()
2317
2418 def forward (self , x : torch .Tensor ) -> torch .Tensor :
2519 """Simple forward pass."""
@@ -57,7 +51,24 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
5751 return {"optimizer" : optimizer , "lr_scheduler" : {"scheduler" : scheduler , "monitor" : "val_loss" }}
5852
5953
60- class ClassifierLightningModule (StaticLightningModule ):
54+ class SimilarityLightningModule (RegressionLightningModule ):
55+ def __init__ (self , model : nn .Module , learning_rate : float ) -> None :
56+ """Initialize the LightningModule."""
57+ super ().__init__ (model , learning_rate )
58+ self .model = model
59+ self .learning_rate = learning_rate
60+ self .loss_function = CosineLoss ()
61+
62+
63+ class CosineLoss (nn .Module ):
64+ def __call__ (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
65+ """Returns the cosine distance loss function."""
66+ x = torch .nn .functional .normalize (x , dim = 1 )
67+ y = torch .nn .functional .normalize (y , dim = 1 )
68+ return (1 - torch .sum (x * y , dim = 1 )).mean ()
69+
70+
71+ class ClassifierLightningModule (RegressionLightningModule ):
6172 def __init__ (self , model : nn .Module , learning_rate : float , class_weight : torch .Tensor | None = None ) -> None :
6273 """Initialize the LightningModule."""
6374 super ().__init__ (model , learning_rate )
@@ -77,7 +88,7 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: i
7788 return loss
7889
7990
80- class MultiLabelClassifierLightningModule (StaticLightningModule ):
91+ class MultiLabelClassifierLightningModule (RegressionLightningModule ):
8192 def __init__ (self , model : nn .Module , learning_rate : float , class_weight : torch .Tensor | None = None ) -> None :
8293 """Initialize the LightningModule."""
8394 super ().__init__ (model , learning_rate )
0 commit comments