diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/sample_kg_dataset.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/sample_kg_dataset.py index 59d72e888..478eba748 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/sample_kg_dataset.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/sample_kg_dataset.py @@ -1,10 +1,10 @@ -from pyhealth.datasets import SampleBaseDataset +from pyhealth.datasets import SampleDataset -class SampleKGDataset(SampleBaseDataset): +class SampleKGDataset(SampleDataset): """Sample KG dataset class. - This class inherits from `SampleBaseDataset` and is specifically designed + This class inherits from `SampleDataset` and is specifically designed for KG datasets. Args: diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py index 559ecb7c4..58ecff1d5 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py @@ -4,18 +4,18 @@ import numpy as np import torch -from pyhealth.datasets import SampleBaseDataset +from pyhealth.datasets import SampleDataset def split( - dataset: SampleBaseDataset, + dataset: SampleDataset, ratios: Union[Tuple[float, float, float], List[float]], seed: Optional[int] = None, ): """Splits the dataset by its outermost indexed items Args: - dataset: a `SampleBaseDataset` object + dataset: a `SampleDataset` object ratios: a list/tuple of ratios for train / val / test seed: random seed for shuffling the dataset diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/complex.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/complex.py index 8fa2a443a..3b6e85920 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/complex.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/complex.py @@ -1,5 +1,5 @@ from.kg_base import KGEBaseModel -from pyhealth.datasets import SampleBaseDataset +from pyhealth.datasets import SampleDataset import torch @@ -13,7 +13,7 @@ class ComplEx(KGEBaseModel): def __init__( self, - dataset: SampleBaseDataset, + dataset: SampleDataset, e_dim: int = 600, r_dim: int = 600, ns: str = "adv", diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/distmult.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/distmult.py index e7563137c..7d0b7b325 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/distmult.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/distmult.py @@ -1,5 +1,5 @@ from.kg_base import KGEBaseModel -from pyhealth.datasets import SampleBaseDataset +from pyhealth.datasets import SampleDataset import torch @@ -12,7 +12,7 @@ class DistMult(KGEBaseModel): """ def __init__( self, - dataset: SampleBaseDataset, + dataset: SampleDataset, e_dim: int = 300, r_dim: int = 300, ns: str = "adv", diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/kg_base.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/kg_base.py index 2de13afe2..1be30c23d 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/kg_base.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/kg_base.py @@ -1,5 +1,5 @@ from abc import ABC -from pyhealth.datasets import SampleBaseDataset +from pyhealth.datasets import SampleDataset import torch import time @@ -32,7 +32,7 @@ def device(self): def __init__( self, - dataset: SampleBaseDataset, + dataset: SampleDataset, e_dim: int = 500, r_dim: int = 500, ns: str = "uniform", diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/rotate.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/rotate.py index df7143a6e..0a7f0bad9 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/rotate.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/rotate.py @@ -1,5 +1,5 @@ from.kg_base import KGEBaseModel -from pyhealth.datasets import SampleBaseDataset +from pyhealth.datasets import SampleDataset import torch @@ -13,7 +13,7 @@ class RotatE(KGEBaseModel): def __init__( self, - dataset: SampleBaseDataset, + dataset: SampleDataset, e_dim: int = 600, r_dim: int = 300, ns='adv', diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/transe.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/transe.py index fbb6e68f6..6cf9a4f89 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/transe.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/transe.py @@ -1,5 +1,5 @@ from.kg_base import KGEBaseModel -from pyhealth.datasets import SampleBaseDataset +from pyhealth.datasets import SampleDataset import torch @@ -13,7 +13,7 @@ class TransE(KGEBaseModel): def __init__( self, - dataset: SampleBaseDataset, + dataset: SampleDataset, e_dim: int = 300, r_dim: int = 300, ns: str = "adv",