diff --git a/docs/source/attribute/knitro.md b/docs/source/attribute/knitro.md index 0329cee..b9c42cf 100644 --- a/docs/source/attribute/knitro.md +++ b/docs/source/attribute/knitro.md @@ -83,7 +83,7 @@ - ✅ - ✅ * - Domain - - ❌ + - ✅ - ✅ * - PrimalStart - ❌ diff --git a/include/pyoptinterface/knitro_model.hpp b/include/pyoptinterface/knitro_model.hpp index a723eef..906e52f 100644 --- a/include/pyoptinterface/knitro_model.hpp +++ b/include/pyoptinterface/knitro_model.hpp @@ -22,6 +22,7 @@ B(KN_new); \ B(KN_free); \ B(KN_update); \ + B(KN_solve); \ B(KN_get_param_id); \ B(KN_get_param_type); \ B(KN_set_int_param); \ @@ -29,9 +30,7 @@ B(KN_set_double_param); \ B(KN_get_int_param); \ B(KN_get_double_param); \ - B(KN_add_vars); \ B(KN_add_var); \ - B(KN_add_cons); \ B(KN_add_con); \ B(KN_set_var_lobnd); \ B(KN_set_var_upbnd); \ @@ -43,7 +42,6 @@ B(KN_get_var_name); \ B(KN_set_con_lobnd); \ B(KN_set_con_upbnd); \ - B(KN_set_con_eqbnd); \ B(KN_get_con_lobnd); \ B(KN_get_con_upbnd); \ B(KN_set_con_name); \ @@ -54,10 +52,8 @@ B(KN_add_obj_constant); \ B(KN_del_obj_constant); \ B(KN_add_obj_linear_struct); \ - B(KN_del_obj_linear_struct); \ B(KN_del_obj_linear_struct_all); \ B(KN_add_obj_quadratic_struct); \ - B(KN_del_obj_quadratic_struct); \ B(KN_del_obj_quadratic_struct_all); \ B(KN_chg_obj_linear_term); \ B(KN_add_con_constant); \ @@ -71,7 +67,6 @@ B(KN_set_cb_grad); \ B(KN_set_cb_hess); \ B(KN_del_obj_eval_callback_all); \ - B(KN_solve); \ B(KN_get_var_primal_value); \ B(KN_get_var_dual_value); \ B(KN_get_con_value); \ @@ -324,6 +319,21 @@ inline int knitro_var_type(VariableDomain domain) } } +inline VariableDomain knitro_variable_domain(int var_type) +{ + switch (var_type) + { + case KN_VARTYPE_CONTINUOUS: + return VariableDomain::Continuous; + case KN_VARTYPE_INTEGER: + return VariableDomain::Integer; + case KN_VARTYPE_BINARY: + return VariableDomain::Binary; + default: + throw std::runtime_error("Unknown variable type"); + } +} + inline int knitro_obj_goal(ObjectiveSense sense) { switch (sense) @@ -423,6 +433,7 @@ class KNITROModel : public OnesideLinearConstraintMixin, std::string get_variable_name(const VariableIndex &variable) const; void set_variable_name(const VariableIndex &variable, const std::string &name); void set_variable_domain(const VariableIndex &variable, VariableDomain domain); + VariableDomain get_variable_domain(const VariableIndex &variable) const; double get_variable_rc(const VariableIndex &variable) const; std::string pprint_variable(const VariableIndex &variable) const; diff --git a/lib/knitro_model.cpp b/lib/knitro_model.cpp index cb69f38..57e25eb 100644 --- a/lib/knitro_model.cpp +++ b/lib/knitro_model.cpp @@ -290,6 +290,13 @@ void KNITROModel::set_variable_domain(const VariableIndex &variable, VariableDom _mark_dirty(); } +VariableDomain KNITROModel::get_variable_domain(const VariableIndex &variable) const +{ + KNINT indexVar = _variable_index(variable); + int var_type = _get_value(knitro::KN_get_var_type, indexVar); + return knitro_variable_domain(var_type); +} + double KNITROModel::get_variable_rc(const VariableIndex &variable) const { _check_dirty(); diff --git a/lib/knitro_model_ext.cpp b/lib/knitro_model_ext.cpp index b6bda58..9e2c0f9 100644 --- a/lib/knitro_model_ext.cpp +++ b/lib/knitro_model_ext.cpp @@ -64,6 +64,7 @@ NB_MODULE(knitro_model_ext, m) BIND_F(get_variable_name) BIND_F(set_variable_name) BIND_F(set_variable_domain) + BIND_F(get_variable_domain) BIND_F(get_variable_rc) BIND_F(delete_variable) // clang-format on diff --git a/src/pyoptinterface/_src/knitro.py b/src/pyoptinterface/_src/knitro.py index f43e32f..ae4023a 100644 --- a/src/pyoptinterface/_src/knitro.py +++ b/src/pyoptinterface/_src/knitro.py @@ -106,6 +106,7 @@ def autoload_library(): VariableAttribute.LowerBound: lambda model, v: model.get_variable_lb(v), VariableAttribute.UpperBound: lambda model, v: model.get_variable_ub(v), VariableAttribute.Name: lambda model, v: model.get_variable_name(v), + VariableAttribute.Domain: lambda model, v: model.get_variable_domain(v), VariableAttribute.ReducedCost: lambda model, v: model.get_variable_rc(v), } diff --git a/tests/test_knitro.py b/tests/test_knitro.py index e708010..5d9d24d 100644 --- a/tests/test_knitro.py +++ b/tests/test_knitro.py @@ -484,13 +484,20 @@ def test_variable_attribute_primal_start(): def test_variable_attribute_domain(): - """Test setting variable domain.""" + """Test getting and setting variable domain.""" model = knitro.Model() x = model.add_variable(lb=0.0, ub=10.0) + # Default domain should be Continuous + domain = model.get_variable_attribute(x, poi.VariableAttribute.Domain) + assert domain == poi.VariableDomain.Continuous + + # Set to Integer and verify model.set_variable_attribute( x, poi.VariableAttribute.Domain, poi.VariableDomain.Integer ) + domain = model.get_variable_attribute(x, poi.VariableAttribute.Domain) + assert domain == poi.VariableDomain.Integer model.set_objective(x, poi.ObjectiveSense.Minimize) model.add_linear_constraint(x, poi.ConstraintSense.GreaterEqual, 2.5) @@ -498,6 +505,11 @@ def test_variable_attribute_domain(): assert model.get_value(x) == approx(3.0) + # Test Binary domain + y = model.add_variable(domain=poi.VariableDomain.Binary) + domain = model.get_variable_attribute(y, poi.VariableAttribute.Domain) + assert domain == poi.VariableDomain.Binary + def test_constraint_attribute_name(): """Test getting and setting constraint name."""