Skip to content

Commit 7ce0bdd

Browse files
committed
Lazy scipy imports
For now, keep scipy optional, unless probability distributions are really needed. Fixes #483.
1 parent f32149d commit 7ce0bdd

1 file changed

Lines changed: 82 additions & 41 deletions

File tree

petab/v1/distributions.py

Lines changed: 82 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,10 @@
1515
from typing import Any
1616

1717
import numpy as np
18-
from scipy.stats import (
19-
cauchy,
20-
chi2,
21-
expon,
22-
gamma,
23-
laplace,
24-
norm,
25-
rayleigh,
26-
uniform,
18+
19+
_SCIPY_IMPORT_ERROR = (
20+
"scipy is required for this functionality. "
21+
"Install it with: pip install scipy"
2722
)
2823

2924
__all__ = [
@@ -342,6 +337,11 @@ def __init__(
342337
trunc: tuple[float, float] | None = None,
343338
log: bool | float = False,
344339
):
340+
try:
341+
from scipy.stats import norm
342+
except ImportError as e:
343+
raise ImportError(_SCIPY_IMPORT_ERROR) from e
344+
self._dist = norm
345345
self._loc = loc
346346
self._scale = scale
347347
super().__init__(log=log, trunc=trunc)
@@ -353,13 +353,13 @@ def _sample(self, shape=None) -> np.ndarray | float:
353353
return np.random.normal(loc=self._loc, scale=self._scale, size=shape)
354354

355355
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
356-
return norm.pdf(x, loc=self._loc, scale=self._scale)
356+
return self._dist.pdf(x, loc=self._loc, scale=self._scale)
357357

358358
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
359-
return norm.cdf(x, loc=self._loc, scale=self._scale)
359+
return self._dist.cdf(x, loc=self._loc, scale=self._scale)
360360

361361
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
362-
return norm.ppf(q, loc=self._loc, scale=self._scale)
362+
return self._dist.ppf(q, loc=self._loc, scale=self._scale)
363363

364364
@property
365365
def loc(self) -> float:
@@ -396,6 +396,11 @@ def __init__(
396396
*,
397397
log: bool | float = False,
398398
):
399+
try:
400+
from scipy.stats import uniform
401+
except ImportError as e:
402+
raise ImportError(_SCIPY_IMPORT_ERROR) from e
403+
self._dist = uniform
399404
self._low = low
400405
self._high = high
401406
super().__init__(log=log)
@@ -407,13 +412,13 @@ def _sample(self, shape=None) -> np.ndarray | float:
407412
return np.random.uniform(low=self._low, high=self._high, size=shape)
408413

409414
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
410-
return uniform.pdf(x, loc=self._low, scale=self._high - self._low)
415+
return self._dist.pdf(x, loc=self._low, scale=self._high - self._low)
411416

412417
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
413-
return uniform.cdf(x, loc=self._low, scale=self._high - self._low)
418+
return self._dist.cdf(x, loc=self._low, scale=self._high - self._low)
414419

415420
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
416-
return uniform.ppf(q, loc=self._low, scale=self._high - self._low)
421+
return self._dist.ppf(q, loc=self._low, scale=self._high - self._low)
417422

418423

419424
class LogUniform(Distribution):
@@ -434,6 +439,11 @@ def __init__(
434439
high: float,
435440
trunc: tuple[float, float] | None = None,
436441
):
442+
try:
443+
from scipy.stats import uniform
444+
except ImportError as e:
445+
raise ImportError(_SCIPY_IMPORT_ERROR) from e
446+
self._dist = uniform
437447
self._logbase = np.exp(1)
438448
self._low = self._log(low)
439449
self._high = self._log(high)
@@ -446,13 +456,13 @@ def _sample(self, shape=None) -> np.ndarray | float:
446456
return np.random.uniform(low=self._low, high=self._high, size=shape)
447457

448458
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
449-
return uniform.pdf(x, loc=self._low, scale=self._high - self._low)
459+
return self._dist.pdf(x, loc=self._low, scale=self._high - self._low)
450460

451461
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
452-
return uniform.cdf(x, loc=self._low, scale=self._high - self._low)
462+
return self._dist.cdf(x, loc=self._low, scale=self._high - self._low)
453463

454464
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
455-
return uniform.ppf(q, loc=self._low, scale=self._high - self._low)
465+
return self._dist.ppf(q, loc=self._low, scale=self._high - self._low)
456466

457467

458468
class Laplace(Distribution):
@@ -479,6 +489,11 @@ def __init__(
479489
trunc: tuple[float, float] | None = None,
480490
log: bool | float = False,
481491
):
492+
try:
493+
from scipy.stats import laplace
494+
except ImportError as e:
495+
raise ImportError(_SCIPY_IMPORT_ERROR) from e
496+
self._dist = laplace
482497
self._loc = loc
483498
self._scale = scale
484499
super().__init__(log=log, trunc=trunc)
@@ -490,13 +505,13 @@ def _sample(self, shape=None) -> np.ndarray | float:
490505
return np.random.laplace(loc=self._loc, scale=self._scale, size=shape)
491506

492507
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
493-
return laplace.pdf(x, loc=self._loc, scale=self._scale)
508+
return self._dist.pdf(x, loc=self._loc, scale=self._scale)
494509

495510
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
496-
return laplace.cdf(x, loc=self._loc, scale=self._scale)
511+
return self._dist.cdf(x, loc=self._loc, scale=self._scale)
497512

498513
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
499-
return laplace.ppf(q, loc=self._loc, scale=self._scale)
514+
return self._dist.ppf(q, loc=self._loc, scale=self._scale)
500515

501516
@property
502517
def loc(self) -> float:
@@ -536,6 +551,11 @@ def __init__(
536551
trunc: tuple[float, float] | None = None,
537552
log: bool | float = False,
538553
):
554+
try:
555+
from scipy.stats import cauchy
556+
except ImportError as e:
557+
raise ImportError(_SCIPY_IMPORT_ERROR) from e
558+
self._dist = cauchy
539559
self._loc = loc
540560
self._scale = scale
541561
super().__init__(log=log, trunc=trunc)
@@ -544,16 +564,16 @@ def __repr__(self):
544564
return self._repr({"loc": self._loc, "scale": self._scale})
545565

546566
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
547-
return cauchy.pdf(x, loc=self._loc, scale=self._scale)
567+
return self._dist.pdf(x, loc=self._loc, scale=self._scale)
548568

549569
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
550-
return cauchy.cdf(x, loc=self._loc, scale=self._scale)
570+
return self._dist.cdf(x, loc=self._loc, scale=self._scale)
551571

552572
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
553-
return cauchy.ppf(q, loc=self._loc, scale=self._scale)
573+
return self._dist.ppf(q, loc=self._loc, scale=self._scale)
554574

555575
def _sample(self, shape=None) -> np.ndarray | float:
556-
return cauchy.rvs(loc=self._loc, scale=self._scale, size=shape)
576+
return self._dist.rvs(loc=self._loc, scale=self._scale, size=shape)
557577

558578
@property
559579
def loc(self) -> float:
@@ -592,6 +612,12 @@ def __init__(
592612
trunc: tuple[float, float] | None = None,
593613
log: bool | float = False,
594614
):
615+
try:
616+
from scipy.stats import chi2
617+
except ImportError as e:
618+
raise ImportError(_SCIPY_IMPORT_ERROR) from e
619+
self._dist = chi2
620+
595621
if isinstance(dof, float):
596622
if not dof.is_integer() or dof < 1:
597623
raise ValueError(
@@ -606,16 +632,16 @@ def __repr__(self):
606632
return self._repr({"dof": self._dof})
607633

608634
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
609-
return chi2.pdf(x, df=self._dof)
635+
return self._dist.pdf(x, df=self._dof)
610636

611637
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
612-
return chi2.cdf(x, df=self._dof)
638+
return self._dist.cdf(x, df=self._dof)
613639

614640
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
615-
return chi2.ppf(q, df=self._dof)
641+
return self._dist.ppf(q, df=self._dof)
616642

617643
def _sample(self, shape=None) -> np.ndarray | float:
618-
return chi2.rvs(df=self._dof, size=shape)
644+
return self._dist.rvs(df=self._dof, size=shape)
619645

620646
@property
621647
def dof(self) -> int:
@@ -639,23 +665,28 @@ def __init__(
639665
scale: float,
640666
trunc: tuple[float, float] | None = None,
641667
):
668+
try:
669+
from scipy.stats import expon
670+
except ImportError as e:
671+
raise ImportError(_SCIPY_IMPORT_ERROR) from e
672+
self._dist = expon
642673
self._scale = scale
643674
super().__init__(log=False, trunc=trunc)
644675

645676
def __repr__(self):
646677
return self._repr({"scale": self._scale})
647678

648679
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
649-
return expon.pdf(x, scale=self._scale)
680+
return self._dist.pdf(x, scale=self._scale)
650681

651682
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
652-
return expon.cdf(x, scale=self._scale)
683+
return self._dist.cdf(x, scale=self._scale)
653684

654685
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
655-
return expon.ppf(q, scale=self._scale)
686+
return self._dist.ppf(q, scale=self._scale)
656687

657688
def _sample(self, shape=None) -> np.ndarray | float:
658-
return expon.rvs(scale=self._scale, size=shape)
689+
return self._dist.rvs(scale=self._scale, size=shape)
659690

660691
@property
661692
def scale(self) -> float:
@@ -689,6 +720,11 @@ def __init__(
689720
trunc: tuple[float, float] | None = None,
690721
log: bool | float = False,
691722
):
723+
try:
724+
from scipy.stats import gamma
725+
except ImportError as e:
726+
raise ImportError(_SCIPY_IMPORT_ERROR) from e
727+
self._dist = gamma
692728
self._shape = shape
693729
self._scale = scale
694730
super().__init__(log=log, trunc=trunc)
@@ -697,16 +733,16 @@ def __repr__(self):
697733
return self._repr({"shape": self._shape, "scale": self._scale})
698734

699735
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
700-
return gamma.pdf(x, a=self._shape, scale=self._scale)
736+
return self._dist.pdf(x, a=self._shape, scale=self._scale)
701737

702738
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
703-
return gamma.cdf(x, a=self._shape, scale=self._scale)
739+
return self._dist.cdf(x, a=self._shape, scale=self._scale)
704740

705741
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
706-
return gamma.ppf(q, a=self._shape, scale=self._scale)
742+
return self._dist.ppf(q, a=self._shape, scale=self._scale)
707743

708744
def _sample(self, shape=None) -> np.ndarray | float:
709-
return gamma.rvs(a=self._shape, scale=self._scale, size=shape)
745+
return self._dist.rvs(a=self._shape, scale=self._scale, size=shape)
710746

711747
@property
712748
def shape(self) -> float:
@@ -743,23 +779,28 @@ def __init__(
743779
trunc: tuple[float, float] | None = None,
744780
log: bool | float = False,
745781
):
782+
try:
783+
from scipy.stats import rayleigh
784+
except ImportError as e:
785+
raise ImportError(_SCIPY_IMPORT_ERROR) from e
786+
self._dist = rayleigh
746787
self._scale = scale
747788
super().__init__(log=log, trunc=trunc)
748789

749790
def __repr__(self):
750791
return self._repr({"scale": self._scale})
751792

752793
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
753-
return rayleigh.pdf(x, scale=self._scale)
794+
return self._dist.pdf(x, scale=self._scale)
754795

755796
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
756-
return rayleigh.cdf(x, scale=self._scale)
797+
return self._dist.cdf(x, scale=self._scale)
757798

758799
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
759-
return rayleigh.ppf(q, scale=self._scale)
800+
return self._dist.ppf(q, scale=self._scale)
760801

761802
def _sample(self, shape=None) -> np.ndarray | float:
762-
return rayleigh.rvs(scale=self._scale, size=shape)
803+
return self._dist.rvs(scale=self._scale, size=shape)
763804

764805
@property
765806
def scale(self) -> float:

0 commit comments

Comments
 (0)