From 6bda747ecc9fdc3857ad8965c2febbb44d0db6bc Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Wed, 3 Jun 2026 16:35:52 -0700 Subject: [PATCH 1/2] moves auto water to an intervention --- schema/aind_behavior_dynamic_foraging.json | 24 ++--- .../AindBehaviorDynamicForaging.Generated.cs | 100 +++++++++--------- .../interventions/auto_water_intervention.py | 71 +++++++++++++ .../interventions/base_intervention.py | 19 ++++ .../interventions/bias_intervention.py | 12 ++- .../block_based_trial_generator.py | 68 ++++-------- .../test_bias_intervention.py | 42 ++++---- 7 files changed, 198 insertions(+), 138 deletions(-) create mode 100644 src/aind_behavior_dynamic_foraging/task_logic/interventions/auto_water_intervention.py create mode 100644 src/aind_behavior_dynamic_foraging/task_logic/interventions/base_intervention.py diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index 99564c52..1169d41f 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -326,7 +326,7 @@ "title": "AuditorySecondaryReinforcer", "type": "object" }, - "AutoWaterParameters": { + "AutoWaterInterventionParameters": { "properties": { "min_ignored_trials": { "default": 3, @@ -351,7 +351,7 @@ "type": "number" } }, - "title": "AutoWaterParameters", + "title": "AutoWaterInterventionParameters", "type": "object" }, "Axis": { @@ -505,7 +505,7 @@ }, "description": "Distribution describing block length." }, - "autowater_parameters": { + "auto_water_intervention_parameters": { "default": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, @@ -514,7 +514,7 @@ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { - "$ref": "#/$defs/AutoWaterParameters" + "$ref": "#/$defs/AutoWaterInterventionParameters" }, { "type": "null" @@ -890,7 +890,7 @@ }, "description": "Distribution describing block length." }, - "autowater_parameters": { + "auto_water_intervention_parameters": { "default": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, @@ -899,7 +899,7 @@ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { - "$ref": "#/$defs/AutoWaterParameters" + "$ref": "#/$defs/AutoWaterInterventionParameters" }, { "type": "null" @@ -1177,7 +1177,7 @@ }, "description": "Distribution describing block length." }, - "autowater_parameters": { + "auto_water_intervention_parameters": { "default": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, @@ -1186,7 +1186,7 @@ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { - "$ref": "#/$defs/AutoWaterParameters" + "$ref": "#/$defs/AutoWaterInterventionParameters" }, { "type": "null" @@ -1387,7 +1387,7 @@ }, "description": "Distribution describing block length." }, - "autowater_parameters": { + "auto_water_intervention_parameters": { "default": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, @@ -1396,7 +1396,7 @@ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { - "$ref": "#/$defs/AutoWaterParameters" + "$ref": "#/$defs/AutoWaterInterventionParameters" }, { "type": "null" @@ -3748,7 +3748,7 @@ }, "description": "Distribution describing block length." }, - "autowater_parameters": { + "auto_water_intervention_parameters": { "default": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, @@ -3757,7 +3757,7 @@ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { - "$ref": "#/$defs/AutoWaterParameters" + "$ref": "#/$defs/AutoWaterInterventionParameters" }, { "type": "null" diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index 7564a934..82d08f31 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -772,7 +772,7 @@ public override string ToString() [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] [Bonsai.CombinatorAttribute(MethodName="Generate")] - public partial class AutoWaterParameters + public partial class AutoWaterInterventionParameters { private int _minIgnoredTrials; @@ -781,14 +781,14 @@ public partial class AutoWaterParameters private double _rewardFraction; - public AutoWaterParameters() + public AutoWaterInterventionParameters() { _minIgnoredTrials = 3; _minUnrewardedTrials = 3; _rewardFraction = 0.8D; } - protected AutoWaterParameters(AutoWaterParameters other) + protected AutoWaterInterventionParameters(AutoWaterInterventionParameters other) { _minIgnoredTrials = other._minIgnoredTrials; _minUnrewardedTrials = other._minUnrewardedTrials; @@ -846,14 +846,14 @@ public double RewardFraction } } - public System.IObservable Generate() + public System.IObservable Generate() { - return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new AutoWaterParameters(this))); + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new AutoWaterInterventionParameters(this))); } - public System.IObservable Generate(System.IObservable source) + public System.IObservable Generate(System.IObservable source) { - return System.Reactive.Linq.Observable.Select(source, _ => new AutoWaterParameters(this)); + return System.Reactive.Linq.Observable.Select(source, _ => new AutoWaterInterventionParameters(this)); } protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) @@ -895,7 +895,7 @@ public partial class BaseCoupledTrialGeneratorSpec : TrialGeneratorSpec private AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution _blockLength; - private AutoWaterParameters _autowaterParameters; + private AutoWaterInterventionParameters _autoWaterInterventionParameters; private BiasInterventionParameters _biasInterventionParameters; @@ -910,7 +910,7 @@ public BaseCoupledTrialGeneratorSpec() _rewardConsumptionDuration = 3D; _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); - _autowaterParameters = new AutoWaterParameters(); + _autoWaterInterventionParameters = new AutoWaterInterventionParameters(); _biasInterventionParameters = new BiasInterventionParameters(); _isBaiting = false; _rewardProbabilityParameters = new RewardProbabilityParameters(); @@ -924,7 +924,7 @@ protected BaseCoupledTrialGeneratorSpec(BaseCoupledTrialGeneratorSpec other) : _rewardConsumptionDuration = other._rewardConsumptionDuration; _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; - _autowaterParameters = other._autowaterParameters; + _autoWaterInterventionParameters = other._autoWaterInterventionParameters; _biasInterventionParameters = other._biasInterventionParameters; _isBaiting = other._isBaiting; _rewardProbabilityParameters = other._rewardProbabilityParameters; @@ -1023,18 +1023,18 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] - [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] + [Newtonsoft.Json.JsonPropertyAttribute("auto_water_intervention_parameters")] [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + "gnored or unrewarded trial thresholds.")] - public AutoWaterParameters AutowaterParameters + public AutoWaterInterventionParameters AutoWaterInterventionParameters { get { - return _autowaterParameters; + return _autoWaterInterventionParameters; } set { - _autowaterParameters = value; + _autoWaterInterventionParameters = value; } } @@ -1113,7 +1113,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("RewardConsumptionDuration = " + _rewardConsumptionDuration + ", "); stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); - stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AutoWaterInterventionParameters = " + _autoWaterInterventionParameters + ", "); stringBuilder.Append("BiasInterventionParameters = " + _biasInterventionParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters); @@ -1543,7 +1543,7 @@ public partial class BlockBasedTrialGeneratorSpec : TrialGeneratorSpec private AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution _blockLength; - private AutoWaterParameters _autowaterParameters; + private AutoWaterInterventionParameters _autoWaterInterventionParameters; private BiasInterventionParameters _biasInterventionParameters; @@ -1556,7 +1556,7 @@ public BlockBasedTrialGeneratorSpec() _rewardConsumptionDuration = 3D; _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); - _autowaterParameters = new AutoWaterParameters(); + _autoWaterInterventionParameters = new AutoWaterInterventionParameters(); _biasInterventionParameters = new BiasInterventionParameters(); _isBaiting = false; } @@ -1569,7 +1569,7 @@ protected BlockBasedTrialGeneratorSpec(BlockBasedTrialGeneratorSpec other) : _rewardConsumptionDuration = other._rewardConsumptionDuration; _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; - _autowaterParameters = other._autowaterParameters; + _autoWaterInterventionParameters = other._autoWaterInterventionParameters; _biasInterventionParameters = other._biasInterventionParameters; _isBaiting = other._isBaiting; } @@ -1667,18 +1667,18 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] - [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] + [Newtonsoft.Json.JsonPropertyAttribute("auto_water_intervention_parameters")] [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + "gnored or unrewarded trial thresholds.")] - public AutoWaterParameters AutowaterParameters + public AutoWaterInterventionParameters AutoWaterInterventionParameters { get { - return _autowaterParameters; + return _autoWaterInterventionParameters; } set { - _autowaterParameters = value; + _autoWaterInterventionParameters = value; } } @@ -1739,7 +1739,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("RewardConsumptionDuration = " + _rewardConsumptionDuration + ", "); stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); - stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AutoWaterInterventionParameters = " + _autoWaterInterventionParameters + ", "); stringBuilder.Append("BiasInterventionParameters = " + _biasInterventionParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting); return true; @@ -2266,7 +2266,7 @@ public partial class CoupledTrialGeneratorSpec : TrialGeneratorSpec private AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution _blockLength; - private AutoWaterParameters _autowaterParameters; + private AutoWaterInterventionParameters _autoWaterInterventionParameters; private BiasInterventionParameters _biasInterventionParameters; @@ -2291,7 +2291,7 @@ public CoupledTrialGeneratorSpec() _rewardConsumptionDuration = 3D; _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); - _autowaterParameters = new AutoWaterParameters(); + _autoWaterInterventionParameters = new AutoWaterInterventionParameters(); _biasInterventionParameters = new BiasInterventionParameters(); _isBaiting = false; _rewardProbabilityParameters = new RewardProbabilityParameters(); @@ -2310,7 +2310,7 @@ protected CoupledTrialGeneratorSpec(CoupledTrialGeneratorSpec other) : _rewardConsumptionDuration = other._rewardConsumptionDuration; _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; - _autowaterParameters = other._autowaterParameters; + _autoWaterInterventionParameters = other._autoWaterInterventionParameters; _biasInterventionParameters = other._biasInterventionParameters; _isBaiting = other._isBaiting; _rewardProbabilityParameters = other._rewardProbabilityParameters; @@ -2414,18 +2414,18 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] - [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] + [Newtonsoft.Json.JsonPropertyAttribute("auto_water_intervention_parameters")] [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + "gnored or unrewarded trial thresholds.")] - public AutoWaterParameters AutowaterParameters + public AutoWaterInterventionParameters AutoWaterInterventionParameters { get { - return _autowaterParameters; + return _autoWaterInterventionParameters; } set { - _autowaterParameters = value; + _autoWaterInterventionParameters = value; } } @@ -2589,7 +2589,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("RewardConsumptionDuration = " + _rewardConsumptionDuration + ", "); stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); - stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AutoWaterInterventionParameters = " + _autoWaterInterventionParameters + ", "); stringBuilder.Append("BiasInterventionParameters = " + _biasInterventionParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", "); @@ -2751,7 +2751,7 @@ public partial class CoupledWarmupTrialGeneratorSpec : TrialGeneratorSpec private AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution _blockLength; - private AutoWaterParameters _autowaterParameters; + private AutoWaterInterventionParameters _autoWaterInterventionParameters; private BiasInterventionParameters _biasInterventionParameters; @@ -2768,7 +2768,7 @@ public CoupledWarmupTrialGeneratorSpec() _rewardConsumptionDuration = 3D; _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); - _autowaterParameters = new AutoWaterParameters(); + _autoWaterInterventionParameters = new AutoWaterInterventionParameters(); _biasInterventionParameters = new BiasInterventionParameters(); _isBaiting = true; _rewardProbabilityParameters = new RewardProbabilityParameters(); @@ -2783,7 +2783,7 @@ protected CoupledWarmupTrialGeneratorSpec(CoupledWarmupTrialGeneratorSpec other) _rewardConsumptionDuration = other._rewardConsumptionDuration; _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; - _autowaterParameters = other._autowaterParameters; + _autoWaterInterventionParameters = other._autoWaterInterventionParameters; _biasInterventionParameters = other._biasInterventionParameters; _isBaiting = other._isBaiting; _rewardProbabilityParameters = other._rewardProbabilityParameters; @@ -2883,18 +2883,18 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] - [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] + [Newtonsoft.Json.JsonPropertyAttribute("auto_water_intervention_parameters")] [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + "gnored or unrewarded trial thresholds.")] - public AutoWaterParameters AutowaterParameters + public AutoWaterInterventionParameters AutoWaterInterventionParameters { get { - return _autowaterParameters; + return _autoWaterInterventionParameters; } set { - _autowaterParameters = value; + _autoWaterInterventionParameters = value; } } @@ -2991,7 +2991,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("RewardConsumptionDuration = " + _rewardConsumptionDuration + ", "); stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); - stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AutoWaterInterventionParameters = " + _autoWaterInterventionParameters + ", "); stringBuilder.Append("BiasInterventionParameters = " + _biasInterventionParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", "); @@ -6885,7 +6885,7 @@ public partial class UncoupledTrialGeneratorSpec : TrialGeneratorSpec private AllenNeuralDynamics.AindBehaviorServices.Distributions.UniformDistribution _blockLength; - private AutoWaterParameters _autowaterParameters; + private AutoWaterInterventionParameters _autoWaterInterventionParameters; private BiasInterventionParameters _biasInterventionParameters; @@ -6904,7 +6904,7 @@ public UncoupledTrialGeneratorSpec() _rewardConsumptionDuration = 3D; _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.UniformDistribution(); - _autowaterParameters = new AutoWaterParameters(); + _autoWaterInterventionParameters = new AutoWaterInterventionParameters(); _biasInterventionParameters = new BiasInterventionParameters(); _isBaiting = false; _trialGenerationEndParameters = new UncoupledTrialGenerationEndConditions(); @@ -6920,7 +6920,7 @@ protected UncoupledTrialGeneratorSpec(UncoupledTrialGeneratorSpec other) : _rewardConsumptionDuration = other._rewardConsumptionDuration; _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; - _autowaterParameters = other._autowaterParameters; + _autoWaterInterventionParameters = other._autoWaterInterventionParameters; _biasInterventionParameters = other._biasInterventionParameters; _isBaiting = other._isBaiting; _trialGenerationEndParameters = other._trialGenerationEndParameters; @@ -7021,18 +7021,18 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.UniformDistributio /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] - [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] + [Newtonsoft.Json.JsonPropertyAttribute("auto_water_intervention_parameters")] [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + "gnored or unrewarded trial thresholds.")] - public AutoWaterParameters AutowaterParameters + public AutoWaterInterventionParameters AutoWaterInterventionParameters { get { - return _autowaterParameters; + return _autoWaterInterventionParameters; } set { - _autowaterParameters = value; + _autoWaterInterventionParameters = value; } } @@ -7146,7 +7146,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("RewardConsumptionDuration = " + _rewardConsumptionDuration + ", "); stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); - stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AutoWaterInterventionParameters = " + _autoWaterInterventionParameters + ", "); stringBuilder.Append("BiasInterventionParameters = " + _biasInterventionParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters + ", "); @@ -8293,9 +8293,9 @@ public System.IObservable Process(System.IObservable(source); } - public System.IObservable Process(System.IObservable source) + public System.IObservable Process(System.IObservable source) { - return Process(source); + return Process(source); } public System.IObservable Process(System.IObservable source) @@ -8536,7 +8536,7 @@ public System.IObservable Process(System.IObservable source) [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] - [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/interventions/auto_water_intervention.py b/src/aind_behavior_dynamic_foraging/task_logic/interventions/auto_water_intervention.py new file mode 100644 index 00000000..0da7e77e --- /dev/null +++ b/src/aind_behavior_dynamic_foraging/task_logic/interventions/auto_water_intervention.py @@ -0,0 +1,71 @@ +import logging +from typing import Optional + +from pydantic import BaseModel, Field + +from aind_behavior_dynamic_foraging.task_logic.interventions.base_intervention import BaseIntervention + +logger = logging.getLogger(__name__) + + +class AutoWaterInterventionParameters(BaseModel): + min_ignored_trials: int = Field( + default=3, ge=0, description="Minimum consecutive ignored trials before auto water is triggered." + ) + min_unrewarded_trials: int = Field( + default=3, ge=0, description="Minimum consecutive unrewarded trials before auto water is triggered." + ) + reward_fraction: float = Field( + default=0.8, + ge=0, + le=1, + description="Fraction of full reward volume delivered during auto water (0=none, 1=full).", + ) # TODO: Not implemented yet + + +class AutoWaterIntervention(BaseIntervention): + """Manages auto water interventions during a task.""" + + def __init__( + self, + auto_water_intervention_parameters: Optional[AutoWaterInterventionParameters] = None, + ): + + self.parameters = auto_water_intervention_parameters + + def are_intervention_conditions_met( + self, is_right_choice_history: list[bool | None], reward_history: list[bool] + ) -> bool: + """Checks whether autowater should be given. + + Returns: + True if autowater conditions are met, False otherwise. + """ + + if self.parameters is None: + logger.debug("Auto-water not configured.") + return False + + min_ignore = self.parameters.min_ignored_trials + min_unreward = self.parameters.min_unrewarded_trials + + is_ignored = [choice is None for choice in is_right_choice_history] + if len(is_ignored) > min_ignore and all(is_ignored[-min_ignore:]): + logger.debug("Past %s trials ignored." % min_ignore) + return True + + is_unrewarded = [not reward for reward in reward_history] + if len(is_unrewarded) > min_unreward and all(is_unrewarded[-min_unreward:]): + logger.debug("Past %s trials unrewarded." % min_unreward) + return True + + return False + + def determine_intervention(self, p_reward_right: float, p_reward_left: float) -> bool: + """Determine auto-water interventions to perform: give water on higher probability side + + Returns: + boolean indicating is_auto_response_right. True indicates auto-water given to right; False, left. + """ + + return True if p_reward_right > p_reward_left else False diff --git a/src/aind_behavior_dynamic_foraging/task_logic/interventions/base_intervention.py b/src/aind_behavior_dynamic_foraging/task_logic/interventions/base_intervention.py new file mode 100644 index 00000000..f3b470a4 --- /dev/null +++ b/src/aind_behavior_dynamic_foraging/task_logic/interventions/base_intervention.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + + +class BaseIntervention(ABC): + @abstractmethod + def are_intervention_conditions_met() -> bool: + """Abstract method to determine if intervention conditions are met. + + Returns: + True if intervention conditions are met, False otherwise. + """ + + pass + + @abstractmethod + def determine_intervention(): + """Abstract method to determine interventions if conditions are met.""" + + pass diff --git a/src/aind_behavior_dynamic_foraging/task_logic/interventions/bias_intervention.py b/src/aind_behavior_dynamic_foraging/task_logic/interventions/bias_intervention.py index 795d56c1..1f1ff68c 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/interventions/bias_intervention.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/interventions/bias_intervention.py @@ -3,6 +3,8 @@ from pydantic import BaseModel, Field +from aind_behavior_dynamic_foraging.task_logic.interventions.base_intervention import BaseIntervention + logger = logging.getLogger(__name__) @@ -25,7 +27,7 @@ class BiasInterventionParameters(BaseModel): ) -class BiasIntervention: +class BiasIntervention(BaseIntervention): """Manages bias correction interventions during a task. Tracks the animal's side bias and applies corrections — either automatic water @@ -60,13 +62,13 @@ def __init__( self.water_corrections = 0 self.total_lickspout_offset = 0 - def are_antibias_conditions_met(self, bias: float) -> bool: + def are_intervention_conditions_met(self, bias: float) -> bool: """Checks whether antibias conditions are met. Intervention is only considered once ``trials_in_bias_intervention`` exceeds ``parameters.intervention_interval``. If the bias is outside the threshold range at that point, returns True and leaves the counter unchanged (the caller - is expected to call ``determine_antibias_intervention``, which resets it). + is expected to call ``determine_intervention``, which resets it). If conditions are not met, increments ``trials_in_bias_intervention`` by 1. Returns: @@ -87,10 +89,10 @@ def are_antibias_conditions_met(self, bias: float) -> bool: self.trials_in_bias_intervention += 1 return False - def determine_antibias_intervention(self, bias: float) -> tuple[Optional[bool], float]: + def determine_intervention(self, bias: float) -> tuple[Optional[bool], float]: """Determine anitbias interventions to perform: give water or move lickspouts - Called after ``are_antibias_conditions_met`` returns True. Resets + Called after ``are_intervention_conditions_met`` returns True. Resets ``trials_in_bias_intervention`` to 0 regardless of which intervention is applied. Water corrections are attempted first, up to ``parameters.maximum_water_corrections`` diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index cbfcae97..2ae6d967 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -12,6 +12,10 @@ from aind_behavior_services.task.distributions_utils import draw_sample from pydantic import BaseModel, Field +from aind_behavior_dynamic_foraging.task_logic.interventions.auto_water_intervention import ( + AutoWaterIntervention, + AutoWaterInterventionParameters, +) from aind_behavior_dynamic_foraging.task_logic.interventions.bias_intervention import ( BiasIntervention, BiasInterventionParameters, @@ -30,21 +34,6 @@ class BlockBasedTrialMetadata(BaseModel): is_autowater: bool = Field(default=False, description="Flag indicating if autowater is given for trial.") -class AutoWaterParameters(BaseModel): - min_ignored_trials: int = Field( - default=3, ge=0, description="Minimum consecutive ignored trials before auto water is triggered." - ) - min_unrewarded_trials: int = Field( - default=3, ge=0, description="Minimum consecutive unrewarded trials before auto water is triggered." - ) - reward_fraction: float = Field( - default=0.8, - ge=0, - le=1, - description="Fraction of full reward volume delivered during auto water (0=none, 1=full).", - ) # TODO: Not implemented yet - - class Block(BaseModel): p_right_reward: float = Field(ge=0, le=1, description="Reward probability for right side during block.") p_left_reward: float = Field(ge=0, le=1, description="Reward probability for left side during block.") @@ -87,8 +76,8 @@ class BlockBasedTrialGeneratorSpec(BaseTrialGeneratorSpecModel): description="Distribution describing block length.", ) - autowater_parameters: Optional[AutoWaterParameters] = Field( - default=AutoWaterParameters(), + auto_water_intervention_parameters: Optional[AutoWaterInterventionParameters] = Field( + default=AutoWaterInterventionParameters(), validate_default=True, description="Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", ) @@ -135,9 +124,12 @@ def __init__(self, spec: BlockBasedTrialGeneratorSpec) -> None: self.is_right_baited: bool = False self.block: Block - self.bias: Optional[float] = None + # interventions + self.bias: float = np.nan self.bias_intervention = BiasIntervention(self.spec.bias_intervention_parameters) + self.auto_water_intervention = AutoWaterIntervention(self.spec.auto_water_intervention_parameters) + def update(self, outcome: TrialOutcome | str): """Updates generator state from the previous trial outcome. Records choice and reward history and manages baiting state. Args: @@ -197,16 +189,18 @@ def next(self) -> Trial | None: is_auto_response_right = None # determine autowater - if is_autowater := self._are_autowater_conditions_met(): - is_auto_response_right = True if self.block.p_right_reward > self.block.p_left_reward else False + if is_autowater := self.auto_water_intervention.are_intervention_conditions_met( + self.is_right_choice_history, self.reward_history + ): + is_auto_response_right = self.auto_water_intervention.determine_intervention( + self.block.p_right_reward, self.block.p_left_reward + ) logger.debug("Delivering autowater: is_auto_response_right = %s" % is_auto_response_right) # determine bias correction. Overrides autowater lickspout_offset_delta = 0 - if self.bias_intervention.are_antibias_conditions_met(self.bias): - is_auto_response_right, lickspout_offset_delta = self.bias_intervention.determine_antibias_intervention( - self.bias - ) + if self.bias_intervention.are_intervention_conditions_met(self.bias): + is_auto_response_right, lickspout_offset_delta = self.bias_intervention.determine_intervention(self.bias) logger.debug( "Performing bias intervention: is_auto_response_right = %s, lickspout_offset_delta = %s." % (is_auto_response_right, lickspout_offset_delta) @@ -233,32 +227,6 @@ def get_metrics(self) -> TrialMetrics: return TrialMetrics(bias=self.bias) - def _are_autowater_conditions_met(self) -> bool: - """Checks whether autowater should be given. - - Returns: - True if autowater conditions are met, False otherwise. - """ - - if self.spec.autowater_parameters is None: - logger.debug("Autowater not configured.") - return False - - min_ignore = self.spec.autowater_parameters.min_ignored_trials - min_unreward = self.spec.autowater_parameters.min_unrewarded_trials - - is_ignored = [choice is None for choice in self.is_right_choice_history] - if len(is_ignored) > min_ignore and all(is_ignored[-min_ignore:]): - logger.debug("Past %s trials ignored." % min_ignore) - return True - - is_unrewarded = [not reward for reward in self.reward_history] - if len(is_unrewarded) > min_unreward and all(is_unrewarded[-min_unreward:]): - logger.debug("Past %s trials unrewarded." % min_unreward) - return True - - return False - @abstractmethod def _are_end_conditions_met(self) -> bool: """Checks whether the session should end. diff --git a/tests/test_interventions/test_bias_intervention.py b/tests/test_interventions/test_bias_intervention.py index 5dc9ba92..314222af 100644 --- a/tests/test_interventions/test_bias_intervention.py +++ b/tests/test_interventions/test_bias_intervention.py @@ -13,20 +13,20 @@ class TestBiasIntervention(unittest.TestCase): def test_returns_false_when_antibias_disabled(self): """Antibias should never trigger when bias_intervention_parameters is None.""" bias_intervention = BiasIntervention(bias_intervention_parameters=None) - self.assertFalse(bias_intervention.are_antibias_conditions_met(0.9)) + self.assertFalse(bias_intervention.are_intervention_conditions_met(0.9)) def test_returns_false_before_intervention_interval(self): """Condition should not trigger before the intervention interval is exceeded.""" bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.trials_in_bias_intervention = 5 - self.assertFalse(bias_intervention.are_antibias_conditions_met(0.5)) + self.assertFalse(bias_intervention.are_intervention_conditions_met(0.5)) def test_returns_false_when_bias_within_thresholds(self): """No intervention when bias sits between lower and upper thresholds.""" bias_intervention = BiasIntervention(BiasInterventionParameters(bias_window_length=5)) bias_intervention.trials_in_bias_intervention = 15 - result = bias_intervention.are_antibias_conditions_met(0.4) + result = bias_intervention.are_intervention_conditions_met(0.4) self.assertFalse(result) @@ -35,7 +35,7 @@ def test_returns_true_when_bias_above_upper_threshold(self): bias_intervention = BiasIntervention(BiasInterventionParameters(bias_window_length=5)) bias_intervention.trials_in_bias_intervention = 15 - result = bias_intervention.are_antibias_conditions_met(0.9) + result = bias_intervention.are_intervention_conditions_met(0.9) self.assertTrue(result) @@ -43,7 +43,7 @@ def test_returns_true_when_bias_below_lower_threshold(self): """Intervention when bias is below threshold""" bias_intervention = BiasIntervention(BiasInterventionParameters(bias_window_length=5)) bias_intervention.trials_in_bias_intervention = 15 - result = bias_intervention.are_antibias_conditions_met(0.2) + result = bias_intervention.are_intervention_conditions_met(0.2) self.assertTrue(result) @@ -51,28 +51,28 @@ def test_gives_right_water_on_left_bias(self): """Negative bias (left bias) → give right water.""" bias_intervention = BiasIntervention(BiasInterventionParameters()) - is_right, delta = bias_intervention.determine_antibias_intervention(-0.9) + is_right, delta = bias_intervention.determine_intervention(-0.9) self.assertTrue(is_right) self.assertEqual(delta, 0.0) def test_gives_left_water_on_right_bias(self): """Positive bias (right bias) → give left water.""" bias_intervention = BiasIntervention(BiasInterventionParameters()) - is_right, delta = bias_intervention.determine_antibias_intervention(0.9) + is_right, delta = bias_intervention.determine_intervention(0.9) self.assertFalse(is_right) self.assertEqual(delta, 0.0) def test_water_corrections_counter_increments(self): bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.water_corrections = 2 - bias_intervention.determine_antibias_intervention(-0.9) + bias_intervention.determine_intervention(-0.9) self.assertEqual(bias_intervention.water_corrections, 3) def test_switches_to_lickspout_after_max_corrections_left_bias(self): """After exhausting water corrections, move lickspout right (combat left bias).""" bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.water_corrections = 5 - is_right, delta = bias_intervention.determine_antibias_intervention(-0.9) + is_right, delta = bias_intervention.determine_intervention(-0.9) self.assertIsNone(is_right) self.assertGreater(delta, 0) @@ -80,14 +80,14 @@ def test_switches_to_lickspout_after_max_corrections_right_bias(self): """After exhausting water corrections, move lickspout left (combat right bias).""" bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.water_corrections = 5 - is_right, delta = bias_intervention.determine_antibias_intervention(0.9) + is_right, delta = bias_intervention.determine_intervention(0.9) self.assertIsNone(is_right) self.assertLess(delta, 0) def test_water_corrections_reset_after_lickspout_move(self): bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.water_corrections = 5 - bias_intervention.determine_antibias_intervention(0.9) + bias_intervention.determine_intervention(0.9) self.assertEqual(bias_intervention.water_corrections, 0) # #### Test lickspout centering #### @@ -98,7 +98,7 @@ def test_no_centering_when_offset_is_zero(self): bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.total_lickspout_offset = 0 bias_intervention.water_corrections = 5 - _, delta = bias_intervention.determine_antibias_intervention(0.2) + _, delta = bias_intervention.determine_intervention(0.2) self.assertEqual(delta, 0.0) def test_centering_moves_toward_zero_from_positive_offset(self): @@ -106,7 +106,7 @@ def test_centering_moves_toward_zero_from_positive_offset(self): bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.total_lickspout_offset = 1 bias_intervention.water_corrections = 5 - _, delta = bias_intervention.determine_antibias_intervention(0.2) + _, delta = bias_intervention.determine_intervention(0.2) self.assertLess(delta, 0) def test_centering_moves_toward_zero_from_negative_offset(self): @@ -114,7 +114,7 @@ def test_centering_moves_toward_zero_from_negative_offset(self): bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.total_lickspout_offset = -1 bias_intervention.water_corrections = 5 - _, delta = bias_intervention.determine_antibias_intervention(0.2) + _, delta = bias_intervention.determine_intervention(0.2) self.assertGreater(delta, 0) def test_centering_step_capped_at_offset_magnitude(self): @@ -123,7 +123,7 @@ def test_centering_step_capped_at_offset_magnitude(self): bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.total_lickspout_offset = 0.01 bias_intervention.water_corrections = 5 - _, delta = bias_intervention.determine_antibias_intervention(0.2) + _, delta = bias_intervention.determine_intervention(0.2) self.assertLessEqual(abs(delta), 0.01) def test_total_lickspout_offset_updated_after_move(self): @@ -131,33 +131,33 @@ def test_total_lickspout_offset_updated_after_move(self): bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.total_lickspout_offset = 0 bias_intervention.water_corrections = 5 - _, delta = bias_intervention.determine_antibias_intervention(0.9) + _, delta = bias_intervention.determine_intervention(0.9) self.assertAlmostEqual(bias_intervention.total_lickspout_offset, delta) def test_trials_in_bias_intervention_increments_when_no_intervention(self): """Counter should increment each time conditions are checked but not met.""" bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.trials_in_bias_intervention = 0 - bias_intervention.are_antibias_conditions_met(0.5) + bias_intervention.are_intervention_conditions_met(0.5) self.assertEqual(bias_intervention.trials_in_bias_intervention, 1) def test_trials_in_bias_intervention_does_not_increment_when_triggered(self): """Counter should not increment when intervention conditions are met.""" bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.trials_in_bias_intervention = 15 - bias_intervention.are_antibias_conditions_met(0.9) + bias_intervention.are_intervention_conditions_met(0.9) self.assertNotEqual(bias_intervention.trials_in_bias_intervention, 16) def test_trials_in_bias_intervention_resets_after_determine_intervention(self): - """Counter should reset to 0 after determine_antibias_intervention is called.""" + """Counter should reset to 0 after determine_intervention is called.""" bias_intervention = BiasIntervention(BiasInterventionParameters()) bias_intervention.trials_in_bias_intervention = 15 - bias_intervention.determine_antibias_intervention(0.9) + bias_intervention.determine_intervention(0.9) self.assertEqual(bias_intervention.trials_in_bias_intervention, 0) def test_trials_in_bias_intervention_does_not_increment_when_disabled(self): """Counter should not change when bias intervention is not configured.""" bias_intervention = BiasIntervention(bias_intervention_parameters=None) bias_intervention.trials_in_bias_intervention = 0 - bias_intervention.are_antibias_conditions_met(0.9) + bias_intervention.are_intervention_conditions_met(0.9) self.assertEqual(bias_intervention.trials_in_bias_intervention, 0) From 4f04bc41ad31e2d98cb43e22fc8991b4f8576fc3 Mon Sep 17 00:00:00 2001 From: Micah Woodard Date: Wed, 3 Jun 2026 17:09:04 -0700 Subject: [PATCH 2/2] updates tests --- schema/coupled_baiting.json | 14 ++++----- schema/uncoupled.json | 16 +++++----- schema/uncoupled_baiting.json | 14 ++++----- .../interventions/auto_water_intervention.py | 4 +-- .../interventions/bias_intervention.py | 10 +++---- .../block_based_trial_generator.py | 12 ++++---- .../test_bias_intervention.py | 2 +- .../test_block_based_trial_generator.py | 30 +++++++++++-------- .../uncoupled/stages.py | 14 ++++----- 9 files changed, 60 insertions(+), 56 deletions(-) diff --git a/schema/coupled_baiting.json b/schema/coupled_baiting.json index cf4dc5fb..fb0f6d9c 100644 --- a/schema/coupled_baiting.json +++ b/schema/coupled_baiting.json @@ -54,7 +54,7 @@ "truncation_parameters": null, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -115,7 +115,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -210,7 +210,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -304,7 +304,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -398,7 +398,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -492,7 +492,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -594,7 +594,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 diff --git a/schema/uncoupled.json b/schema/uncoupled.json index 7c2a23fc..f5a9204d 100644 --- a/schema/uncoupled.json +++ b/schema/uncoupled.json @@ -54,7 +54,7 @@ "truncation_parameters": null, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.5 @@ -115,7 +115,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.5 @@ -210,7 +210,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 5, "min_unrewarded_trials": 5, "reward_fraction": 0.5 @@ -304,7 +304,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 7, "min_unrewarded_trials": 7, "reward_fraction": 0.5 @@ -395,7 +395,7 @@ "truncation_parameters": null, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 10, "min_unrewarded_trials": 10, "reward_fraction": 0.5 @@ -475,7 +475,7 @@ "truncation_parameters": null, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -496,7 +496,7 @@ "maximum_dominance_streak": 3.0 }, "lick_spout_retraction": false, - "autowater_parameters": null + "auto_water_intervention_parameters_parameters": null }, "version": "0.0.2-rc33", "stage_name": null @@ -556,7 +556,7 @@ "truncation_parameters": null, "scaling_parameters": null }, - "autowater_parameters": null, + "auto_water_intervention_parameters_parameters": null, "is_baiting": false, "trial_generation_end_parameters": { "ignore_window_length": 30, diff --git a/schema/uncoupled_baiting.json b/schema/uncoupled_baiting.json index 342fdfa4..3cb30619 100644 --- a/schema/uncoupled_baiting.json +++ b/schema/uncoupled_baiting.json @@ -54,7 +54,7 @@ "truncation_parameters": null, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -115,7 +115,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -210,7 +210,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -304,7 +304,7 @@ }, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -395,7 +395,7 @@ "truncation_parameters": null, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -475,7 +475,7 @@ "truncation_parameters": null, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 @@ -555,7 +555,7 @@ "truncation_parameters": null, "scaling_parameters": null }, - "autowater_parameters": { + "auto_water_intervention_parameters_parameters": { "min_ignored_trials": 3, "min_unrewarded_trials": 3, "reward_fraction": 0.8 diff --git a/src/aind_behavior_dynamic_foraging/task_logic/interventions/auto_water_intervention.py b/src/aind_behavior_dynamic_foraging/task_logic/interventions/auto_water_intervention.py index 0da7e77e..bf29f7db 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/interventions/auto_water_intervention.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/interventions/auto_water_intervention.py @@ -36,10 +36,10 @@ def __init__( def are_intervention_conditions_met( self, is_right_choice_history: list[bool | None], reward_history: list[bool] ) -> bool: - """Checks whether autowater should be given. + """Checks whether auto-water should be given. Returns: - True if autowater conditions are met, False otherwise. + True if auto-water conditions are met, False otherwise. """ if self.parameters is None: diff --git a/src/aind_behavior_dynamic_foraging/task_logic/interventions/bias_intervention.py b/src/aind_behavior_dynamic_foraging/task_logic/interventions/bias_intervention.py index 1f1ff68c..eaab1fa1 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/interventions/bias_intervention.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/interventions/bias_intervention.py @@ -63,7 +63,7 @@ def __init__( self.total_lickspout_offset = 0 def are_intervention_conditions_met(self, bias: float) -> bool: - """Checks whether antibias conditions are met. + """Checks whether bias interventions conditions are met. Intervention is only considered once ``trials_in_bias_intervention`` exceeds ``parameters.intervention_interval``. If the bias is outside the threshold @@ -72,7 +72,7 @@ def are_intervention_conditions_met(self, bias: float) -> bool: If conditions are not met, increments ``trials_in_bias_intervention`` by 1. Returns: - True if antibias conditions are met, False otherwise. + True if bias intervention conditions are met, False otherwise. """ if self.parameters is None: logger.debug("Bias intervention not configured.") @@ -109,14 +109,14 @@ def determine_intervention(self, bias: float) -> tuple[Optional[bool], float]: logger.debug("Bias intervention not configured.") return None, 0 - is_right_autowater = None + is_right_auto_water = None lickspout_offset_delta = 0 ab_delta = self.parameters.lickspout_offset_delta if abs(bias) >= self.parameters.threshold.upper: if self.water_corrections < self.parameters.maximum_water_corrections: logger.debug("Correcting bias with water.") # - bias values corresponds to left, so give right and vice versa - is_right_autowater = True if bias < 0 else False + is_right_auto_water = True if bias < 0 else False self.water_corrections += 1 else: logger.debug("Correcting bias with lickspout offset.") @@ -134,4 +134,4 @@ def determine_intervention(self, bias: float) -> tuple[Optional[bool], float]: self.total_lickspout_offset += lickspout_offset_delta self.trials_in_bias_intervention = 0 - return is_right_autowater, lickspout_offset_delta + return is_right_auto_water, lickspout_offset_delta diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 2ae6d967..e1cee0fd 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -31,7 +31,7 @@ class BlockBasedTrialMetadata(BaseModel): """Metadata for block based trial. These fields will NOT be used by the task engine.""" - is_autowater: bool = Field(default=False, description="Flag indicating if autowater is given for trial.") + is_auto_water: bool = Field(default=False, description="Flag indicating if auto-water is given for trial.") class Block(BaseModel): @@ -188,16 +188,16 @@ def next(self) -> Trial | None: is_auto_response_right = None - # determine autowater - if is_autowater := self.auto_water_intervention.are_intervention_conditions_met( + # determine auto-water + if is_auto_water := self.auto_water_intervention.are_intervention_conditions_met( self.is_right_choice_history, self.reward_history ): is_auto_response_right = self.auto_water_intervention.determine_intervention( self.block.p_right_reward, self.block.p_left_reward ) - logger.debug("Delivering autowater: is_auto_response_right = %s" % is_auto_response_right) + logger.debug("Delivering auto-water: is_auto_response_right = %s" % is_auto_response_right) - # determine bias correction. Overrides autowater + # determine bias correction. Overrides auto-water lickspout_offset_delta = 0 if self.bias_intervention.are_intervention_conditions_met(self.bias): is_auto_response_right, lickspout_offset_delta = self.bias_intervention.determine_intervention(self.bias) @@ -218,7 +218,7 @@ def next(self) -> Trial | None: metadata=Metadata( p_reward_left=self.block.p_left_reward, p_reward_right=self.block.p_right_reward, - extra=BlockBasedTrialMetadata(is_autowater=is_autowater), + extra=BlockBasedTrialMetadata(is_auto_water=is_auto_water), ), ) diff --git a/tests/test_interventions/test_bias_intervention.py b/tests/test_interventions/test_bias_intervention.py index 314222af..ef5d4862 100644 --- a/tests/test_interventions/test_bias_intervention.py +++ b/tests/test_interventions/test_bias_intervention.py @@ -10,7 +10,7 @@ class TestBiasIntervention(unittest.TestCase): - def test_returns_false_when_antibias_disabled(self): + def test_returns_false_when_bias_intervention_disabled(self): """Antibias should never trigger when bias_intervention_parameters is None.""" bias_intervention = BiasIntervention(bias_intervention_parameters=None) self.assertFalse(bias_intervention.are_intervention_conditions_met(0.9)) diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py index 77b2db61..2f3eb8e1 100644 --- a/tests/trial_generators/test_block_based_trial_generator.py +++ b/tests/trial_generators/test_block_based_trial_generator.py @@ -5,12 +5,14 @@ import numpy as np +from aind_behavior_dynamic_foraging.task_logic.interventions.auto_water_intervention import ( + AutoWaterInterventionParameters, +) from aind_behavior_dynamic_foraging.task_logic.interventions.bias_intervention import ( BiasInterventionParameters, BiasThreshold, ) from aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator import ( - AutoWaterParameters, Block, BlockBasedTrialGenerator, BlockBasedTrialGeneratorSpec, @@ -56,7 +58,7 @@ def test_next_returns_correct_reward_probs(self): self.assertEqual(trial.p_reward_right, self.generator.block.p_right_reward) -class TestAntiBiasBlockBasedTrialGenerator(unittest.TestCase): +class TestBiasInterventionBlockBasedTrialGenerator(unittest.TestCase): def _patch_bias(self, bias_value: float) -> Any: return patch( @@ -107,44 +109,46 @@ def test_bias_stored_on_generator_after_check(self): #### Test next #### - def test_next_gives_right_autowater_on_left_bias(self): + def test_next_gives_right_auto_water_on_left_bias(self): gen = self._make_generator(bias=-0.9) trial = gen.next() assert trial is not None self.assertTrue(trial.is_auto_response_right) - def test_next_gives_left_autowater_on_right_bias(self): + def test_next_gives_left_auto_water_on_right_bias(self): gen = self._make_generator(bias=0.9) trial = gen.next() assert trial is not None self.assertFalse(trial.is_auto_response_right) - def test_next_no_antibias_when_below_interval(self): - """No antibias effect when trials_in_bias_intervention has not exceeded interval.""" + def test_next_no_bias_intervention_when_below_interval(self): + """No bias intervention when trials_in_bias_intervention has not exceeded interval.""" gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5) trial = gen.next() assert trial is not None self.assertIsNone(trial.is_auto_response_right) - def test_next_antibias_overrides_autowater(self): - """When both autowater and antibias conditions are met, antibias takes precedence.""" + def test_next_bias_intervention_overrides_auto_water(self): + """When both auto-water and bias intervention conditions are met, bias intervention takes precedence.""" bip = BiasInterventionParameters( intervention_interval=10, threshold=BiasThreshold(upper=0.7, lower=0.3), maximum_water_corrections=5, bias_window_length=5, ) - aw = AutoWaterParameters(min_ignored_trials=1, min_unrewarded_trials=1, reward_fraction=0.8) - spec = ConcreteBlockBasedTrialGeneratorSpec(bias_intervention_parameters=bip, autowater_parameters=aw) + aw = AutoWaterInterventionParameters(min_ignored_trials=1, min_unrewarded_trials=1, reward_fraction=0.8) + spec = ConcreteBlockBasedTrialGeneratorSpec( + bias_intervention_parameters=bip, auto_water_intervention_parameters=aw + ) gen = spec.create_generator() gen.block = Block(p_left_reward=0.2, p_right_reward=0.8, left_length=10, right_length=10) gen.bias = -0.9 gen.bias_intervention.trials_in_bias_intervention = 15 - gen.is_right_choice_history = [None] # ignored trial → autowater would also fire + gen.is_right_choice_history = [None] # ignored trial → auto_water would also fire gen.reward_history = [False] trial = gen.next() - # Antibias (left bias → give right water) should win + # bias intervention (left bias → give right water) should win assert trial is not None self.assertTrue(trial.is_auto_response_right) @@ -155,7 +159,7 @@ def test_next_lickspout_delta_nonzero_after_corrections_exhausted(self): assert trial is not None self.assertEqual(trial.lickspout_offset_delta, 0.05) - def test_next_no_lickspout_delta_when_antibias_not_triggered(self): + def test_next_no_lickspout_delta_when_bias_intervention_not_triggered(self): gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5) trial = gen.next() assert trial is not None diff --git a/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/uncoupled/stages.py b/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/uncoupled/stages.py index 32b247ca..c2764a69 100644 --- a/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/uncoupled/stages.py +++ b/workspace/aind_behavior_dynamic_foraging_curricula/src/aind_behavior_dynamic_foraging_curricula/uncoupled/stages.py @@ -50,7 +50,7 @@ def make_s_stage_1_warmup(): trial_generator=TrialGeneratorCompositeSpec( generators=[ CoupledWarmupTrialGeneratorSpec( - autowater_parameters=AutoWaterParameters( + auto_water_intervention_parameters_parameters=AutoWaterParameters( reward_fraction=0.5, min_ignored_trials=3, min_unrewarded_trials=3 ), trial_generation_end_parameters=CoupledWarmupTrialGenerationEndConditions( @@ -76,7 +76,7 @@ def make_s_stage_1_warmup(): extend_block_on_no_response=True, ), CoupledTrialGeneratorSpec( - autowater_parameters=AutoWaterParameters( + auto_water_intervention_parameters_parameters=AutoWaterParameters( reward_fraction=0.5, min_ignored_trials=3, min_unrewarded_trials=3 ), trial_generation_end_parameters=CoupledTrialGenerationEndConditions( @@ -126,7 +126,7 @@ def make_s_stage_1(): reward_size=RewardSize(right_value_volume=2.0, left_value_volume=2.0), lick_spout_retraction=False, trial_generator=CoupledTrialGeneratorSpec( - autowater_parameters=AutoWaterParameters( + auto_water_intervention_parameters_parameters=AutoWaterParameters( reward_fraction=0.5, min_ignored_trials=5, min_unrewarded_trials=5 ), trial_generation_end_parameters=CoupledTrialGenerationEndConditions( @@ -174,7 +174,7 @@ def make_s_stage_2(): reward_size=RewardSize(right_value_volume=2.0, left_value_volume=2.0), lick_spout_retraction=False, trial_generator=CoupledTrialGeneratorSpec( - autowater_parameters=AutoWaterParameters( + auto_water_intervention_parameters_parameters=AutoWaterParameters( reward_fraction=0.5, min_ignored_trials=7, min_unrewarded_trials=7 ), trial_generation_end_parameters=CoupledTrialGenerationEndConditions( @@ -222,7 +222,7 @@ def make_s_stage_3(): reward_size=RewardSize(right_value_volume=2.0, left_value_volume=2.0), lick_spout_retraction=False, trial_generator=UncoupledTrialGeneratorSpec( - autowater_parameters=AutoWaterParameters( + auto_water_intervention_parameters_parameters=AutoWaterParameters( reward_fraction=0.5, min_ignored_trials=10, min_unrewarded_trials=10 ), trial_generation_end_parameters=UncoupledTrialGenerationEndConditions( @@ -259,7 +259,7 @@ def make_s_stage_final(): task_parameters=AindDynamicForagingTaskParameters( reward_size=RewardSize(right_value_volume=2.0, left_value_volume=2.0), lick_spout_retraction=False, - autowater_parameters=None, + auto_water_intervention_parameters_parameters=None, trial_generator=UncoupledTrialGeneratorSpec( trial_generation_end_parameters=UncoupledTrialGenerationEndConditions( max_trial=1000, @@ -296,7 +296,7 @@ def make_s_stage_graduated(): reward_size=RewardSize(right_value_volume=2.0, left_value_volume=2.0), lick_spout_retraction=False, trial_generator=UncoupledTrialGeneratorSpec( - autowater_parameters=None, + auto_water_intervention_parameters_parameters=None, trial_generation_end_parameters=UncoupledTrialGenerationEndConditions( max_trial=1000, max_time=4500,