Skip to content

Commit 584ceb3

Browse files
committed
Lazy scipy imports
For now, keep scipy optional, unless probability distributions are really needed. Fixes #483.
1 parent f32149d commit 584ceb3

1 file changed

Lines changed: 105 additions & 42 deletions

File tree

petab/v1/distributions.py

Lines changed: 105 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,6 @@
1515
from typing import Any
1616

1717
import numpy as np
18-
from scipy.stats import (
19-
cauchy,
20-
chi2,
21-
expon,
22-
gamma,
23-
laplace,
24-
norm,
25-
rayleigh,
26-
uniform,
27-
)
2818

2919
__all__ = [
3020
"Distribution",
@@ -342,6 +332,14 @@ def __init__(
342332
trunc: tuple[float, float] | None = None,
343333
log: bool | float = False,
344334
):
335+
try:
336+
from scipy.stats import norm
337+
except ImportError as e:
338+
raise ImportError(
339+
"scipy is required for this functionality. "
340+
"Install it with: pip install scipy"
341+
) from e
342+
self._dist = norm
345343
self._loc = loc
346344
self._scale = scale
347345
super().__init__(log=log, trunc=trunc)
@@ -353,13 +351,13 @@ def _sample(self, shape=None) -> np.ndarray | float:
353351
return np.random.normal(loc=self._loc, scale=self._scale, size=shape)
354352

355353
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
356-
return norm.pdf(x, loc=self._loc, scale=self._scale)
354+
return self._dist.pdf(x, loc=self._loc, scale=self._scale)
357355

358356
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
359-
return norm.cdf(x, loc=self._loc, scale=self._scale)
357+
return self._dist.cdf(x, loc=self._loc, scale=self._scale)
360358

361359
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
362-
return norm.ppf(q, loc=self._loc, scale=self._scale)
360+
return self._dist.ppf(q, loc=self._loc, scale=self._scale)
363361

364362
@property
365363
def loc(self) -> float:
@@ -396,6 +394,14 @@ def __init__(
396394
*,
397395
log: bool | float = False,
398396
):
397+
try:
398+
from scipy.stats import uniform
399+
except ImportError as e:
400+
raise ImportError(
401+
"scipy is required for this functionality. "
402+
"Install it with: pip install scipy"
403+
) from e
404+
self._dist = uniform
399405
self._low = low
400406
self._high = high
401407
super().__init__(log=log)
@@ -407,13 +413,13 @@ def _sample(self, shape=None) -> np.ndarray | float:
407413
return np.random.uniform(low=self._low, high=self._high, size=shape)
408414

409415
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
410-
return uniform.pdf(x, loc=self._low, scale=self._high - self._low)
416+
return self._dist.pdf(x, loc=self._low, scale=self._high - self._low)
411417

412418
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
413-
return uniform.cdf(x, loc=self._low, scale=self._high - self._low)
419+
return self._dist.cdf(x, loc=self._low, scale=self._high - self._low)
414420

415421
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
416-
return uniform.ppf(q, loc=self._low, scale=self._high - self._low)
422+
return self._dist.ppf(q, loc=self._low, scale=self._high - self._low)
417423

418424

419425
class LogUniform(Distribution):
@@ -434,6 +440,14 @@ def __init__(
434440
high: float,
435441
trunc: tuple[float, float] | None = None,
436442
):
443+
try:
444+
from scipy.stats import uniform
445+
except ImportError as e:
446+
raise ImportError(
447+
"scipy is required for this functionality. "
448+
"Install it with: pip install scipy"
449+
) from e
450+
self._dist = uniform
437451
self._logbase = np.exp(1)
438452
self._low = self._log(low)
439453
self._high = self._log(high)
@@ -446,13 +460,13 @@ def _sample(self, shape=None) -> np.ndarray | float:
446460
return np.random.uniform(low=self._low, high=self._high, size=shape)
447461

448462
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
449-
return uniform.pdf(x, loc=self._low, scale=self._high - self._low)
463+
return self._dist.pdf(x, loc=self._low, scale=self._high - self._low)
450464

451465
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
452-
return uniform.cdf(x, loc=self._low, scale=self._high - self._low)
466+
return self._dist.cdf(x, loc=self._low, scale=self._high - self._low)
453467

454468
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
455-
return uniform.ppf(q, loc=self._low, scale=self._high - self._low)
469+
return self._dist.ppf(q, loc=self._low, scale=self._high - self._low)
456470

457471

458472
class Laplace(Distribution):
@@ -479,6 +493,14 @@ def __init__(
479493
trunc: tuple[float, float] | None = None,
480494
log: bool | float = False,
481495
):
496+
try:
497+
from scipy.stats import laplace
498+
except ImportError as e:
499+
raise ImportError(
500+
"scipy is required for this functionality. "
501+
"Install it with: pip install scipy"
502+
) from e
503+
self._dist = laplace
482504
self._loc = loc
483505
self._scale = scale
484506
super().__init__(log=log, trunc=trunc)
@@ -490,13 +512,13 @@ def _sample(self, shape=None) -> np.ndarray | float:
490512
return np.random.laplace(loc=self._loc, scale=self._scale, size=shape)
491513

492514
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
493-
return laplace.pdf(x, loc=self._loc, scale=self._scale)
515+
return self._dist.pdf(x, loc=self._loc, scale=self._scale)
494516

495517
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
496-
return laplace.cdf(x, loc=self._loc, scale=self._scale)
518+
return self._dist.cdf(x, loc=self._loc, scale=self._scale)
497519

498520
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
499-
return laplace.ppf(q, loc=self._loc, scale=self._scale)
521+
return self._dist.ppf(q, loc=self._loc, scale=self._scale)
500522

501523
@property
502524
def loc(self) -> float:
@@ -536,6 +558,14 @@ def __init__(
536558
trunc: tuple[float, float] | None = None,
537559
log: bool | float = False,
538560
):
561+
try:
562+
from scipy.stats import cauchy
563+
except ImportError as e:
564+
raise ImportError(
565+
"scipy is required for this functionality. "
566+
"Install it with: pip install scipy"
567+
) from e
568+
self._dist = cauchy
539569
self._loc = loc
540570
self._scale = scale
541571
super().__init__(log=log, trunc=trunc)
@@ -544,16 +574,16 @@ def __repr__(self):
544574
return self._repr({"loc": self._loc, "scale": self._scale})
545575

546576
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
547-
return cauchy.pdf(x, loc=self._loc, scale=self._scale)
577+
return self._dist.pdf(x, loc=self._loc, scale=self._scale)
548578

549579
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
550-
return cauchy.cdf(x, loc=self._loc, scale=self._scale)
580+
return self._dist.cdf(x, loc=self._loc, scale=self._scale)
551581

552582
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
553-
return cauchy.ppf(q, loc=self._loc, scale=self._scale)
583+
return self._dist.ppf(q, loc=self._loc, scale=self._scale)
554584

555585
def _sample(self, shape=None) -> np.ndarray | float:
556-
return cauchy.rvs(loc=self._loc, scale=self._scale, size=shape)
586+
return self._dist.rvs(loc=self._loc, scale=self._scale, size=shape)
557587

558588
@property
559589
def loc(self) -> float:
@@ -592,6 +622,15 @@ def __init__(
592622
trunc: tuple[float, float] | None = None,
593623
log: bool | float = False,
594624
):
625+
try:
626+
from scipy.stats import chi2
627+
except ImportError as e:
628+
raise ImportError(
629+
"scipy is required for this functionality. "
630+
"Install it with: pip install scipy"
631+
) from e
632+
self._dist = chi2
633+
595634
if isinstance(dof, float):
596635
if not dof.is_integer() or dof < 1:
597636
raise ValueError(
@@ -606,16 +645,16 @@ def __repr__(self):
606645
return self._repr({"dof": self._dof})
607646

608647
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
609-
return chi2.pdf(x, df=self._dof)
648+
return self._dist.pdf(x, df=self._dof)
610649

611650
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
612-
return chi2.cdf(x, df=self._dof)
651+
return self._dist.cdf(x, df=self._dof)
613652

614653
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
615-
return chi2.ppf(q, df=self._dof)
654+
return self._dist.ppf(q, df=self._dof)
616655

617656
def _sample(self, shape=None) -> np.ndarray | float:
618-
return chi2.rvs(df=self._dof, size=shape)
657+
return self._dist.rvs(df=self._dof, size=shape)
619658

620659
@property
621660
def dof(self) -> int:
@@ -639,23 +678,31 @@ def __init__(
639678
scale: float,
640679
trunc: tuple[float, float] | None = None,
641680
):
681+
try:
682+
from scipy.stats import expon
683+
except ImportError as e:
684+
raise ImportError(
685+
"scipy is required for this functionality. "
686+
"Install it with: pip install scipy"
687+
) from e
688+
self._dist = expon
642689
self._scale = scale
643690
super().__init__(log=False, trunc=trunc)
644691

645692
def __repr__(self):
646693
return self._repr({"scale": self._scale})
647694

648695
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
649-
return expon.pdf(x, scale=self._scale)
696+
return self._dist.pdf(x, scale=self._scale)
650697

651698
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
652-
return expon.cdf(x, scale=self._scale)
699+
return self._dist.cdf(x, scale=self._scale)
653700

654701
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
655-
return expon.ppf(q, scale=self._scale)
702+
return self._dist.ppf(q, scale=self._scale)
656703

657704
def _sample(self, shape=None) -> np.ndarray | float:
658-
return expon.rvs(scale=self._scale, size=shape)
705+
return self._dist.rvs(scale=self._scale, size=shape)
659706

660707
@property
661708
def scale(self) -> float:
@@ -689,6 +736,14 @@ def __init__(
689736
trunc: tuple[float, float] | None = None,
690737
log: bool | float = False,
691738
):
739+
try:
740+
from scipy.stats import gamma
741+
except ImportError as e:
742+
raise ImportError(
743+
"scipy is required for this functionality. "
744+
"Install it with: pip install scipy"
745+
) from e
746+
self._dist = gamma
692747
self._shape = shape
693748
self._scale = scale
694749
super().__init__(log=log, trunc=trunc)
@@ -697,16 +752,16 @@ def __repr__(self):
697752
return self._repr({"shape": self._shape, "scale": self._scale})
698753

699754
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
700-
return gamma.pdf(x, a=self._shape, scale=self._scale)
755+
return self._dist.pdf(x, a=self._shape, scale=self._scale)
701756

702757
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
703-
return gamma.cdf(x, a=self._shape, scale=self._scale)
758+
return self._dist.cdf(x, a=self._shape, scale=self._scale)
704759

705760
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
706-
return gamma.ppf(q, a=self._shape, scale=self._scale)
761+
return self._dist.ppf(q, a=self._shape, scale=self._scale)
707762

708763
def _sample(self, shape=None) -> np.ndarray | float:
709-
return gamma.rvs(a=self._shape, scale=self._scale, size=shape)
764+
return self._dist.rvs(a=self._shape, scale=self._scale, size=shape)
710765

711766
@property
712767
def shape(self) -> float:
@@ -743,23 +798,31 @@ def __init__(
743798
trunc: tuple[float, float] | None = None,
744799
log: bool | float = False,
745800
):
801+
try:
802+
from scipy.stats import rayleigh
803+
except ImportError as e:
804+
raise ImportError(
805+
"scipy is required for this functionality. "
806+
"Install it with: pip install scipy"
807+
) from e
808+
self._dist = rayleigh
746809
self._scale = scale
747810
super().__init__(log=log, trunc=trunc)
748811

749812
def __repr__(self):
750813
return self._repr({"scale": self._scale})
751814

752815
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
753-
return rayleigh.pdf(x, scale=self._scale)
816+
return self._dist.pdf(x, scale=self._scale)
754817

755818
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
756-
return rayleigh.cdf(x, scale=self._scale)
819+
return self._dist.cdf(x, scale=self._scale)
757820

758821
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
759-
return rayleigh.ppf(q, scale=self._scale)
822+
return self._dist.ppf(q, scale=self._scale)
760823

761824
def _sample(self, shape=None) -> np.ndarray | float:
762-
return rayleigh.rvs(scale=self._scale, size=shape)
825+
return self._dist.rvs(scale=self._scale, size=shape)
763826

764827
@property
765828
def scale(self) -> float:

0 commit comments

Comments
 (0)