diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 48967ef..6f7dd25 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -53,6 +53,7 @@ jobs: @info "Testing ModelingToolkitTearing" Pkg.activate("lib/ModelingToolkitTearing") Pkg.develop(; path = ".") + Pkg.add(; name = "ModelingToolkitBase", rev = "as/reversible-tfs") Pkg.test() else @error "Unknown package" PKG diff --git a/.github/workflows/MTKDownstream.yml b/.github/workflows/MTKDownstream.yml index 5a4f305..9d3d4ad 100644 --- a/.github/workflows/MTKDownstream.yml +++ b/.github/workflows/MTKDownstream.yml @@ -52,6 +52,7 @@ jobs: with: repository: SciML/ModelingToolkit.jl path: downstream + ref: "as/reversible-tfs" - name: "Test ModelingToolkit/${{ matrix.group }}" env: GROUP: ${{ matrix.group }} diff --git a/lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl b/lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl index a31fa1f..b2577c3 100644 --- a/lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl +++ b/lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl @@ -58,39 +58,22 @@ include("diagnostics.jl") include("reassemble.jl") -struct UnhackSystemCacheKey end +abstract type InlineLinsolveTransformation <: MTKBase.ReversibleTransformations end -function MTKBase.should_invalidate_mutable_cache_entry(::Type{UnhackSystemCacheKey}, patch::NamedTuple) +function MTKBase.should_invalidate_mutable_cache_entry( + ::Type{InlineLinsolveTransformation}, patch::NamedTuple + ) return true end -function MTKBase.unhack_system(sys::System) - cached_sys = MTKBase.check_mutable_cache(sys, UnhackSystemCacheKey, System, nothing) +function MTKBase.reverse_transformation(sys, ::Type{InlineLinsolveTransformation}) + cached_sys = MTKBase.check_mutable_cache(sys, InlineLinsolveTransformation, System, nothing) if cached_sys isa System return cached_sys end - # Observed are copied by the masking operation + obseqs = observed(sys) eqs = copy(equations(sys)) - obs_mask = trues(length(obseqs)) - for (i, eq) in enumerate(obseqs) - obs_mask[i] = @match eq.rhs begin - BSImpl.Term(; f, args) => if f === change_origin - false - elseif f === SU.array_literal - result = true - for (si, ai) in zip(SU.stable_eachindex(eq.lhs), Iterators.drop(eachindex(args), 1)) - result &= isequal(eq.lhs[si], args[ai]) - result || break - end - !result - else - true - end - _ => true - end - end - obseqs = obseqs[obs_mask] # Map from ldiv operation to index of the equations where it is the RHS. A # positive index is into `obseqs`, a negative index is into `eqs`. The variable @@ -107,8 +90,7 @@ function MTKBase.unhack_system(sys::System) # Now, we want to turn all inlined linear SCCs into algebraic equations. If an element # of the SCC is a differential variable, we'll introduce the `toterm` as a new algebraic. # Otherwise, the observed equation is removed. - resize!(obs_mask, length(obseqs)) - fill!(obs_mask, true) + obs_mask = trues(length(obseqs)) additional_eqs = Equation[] additional_vars = Set{SymbolicT}() additional_subs = Dict{SymbolicT, SymbolicT}() @@ -171,7 +153,7 @@ function MTKBase.unhack_system(sys::System) @set! newsys.unknowns = dvs @set! newsys.schedule = sched - MTKBase.store_to_mutable_cache!(sys, UnhackSystemCacheKey, newsys) + MTKBase.store_to_mutable_cache!(sys, InlineLinsolveTransformation, newsys) return newsys end diff --git a/lib/ModelingToolkitTearing/src/reassemble.jl b/lib/ModelingToolkitTearing/src/reassemble.jl index e61ab5e..b2a3c12 100644 --- a/lib/ModelingToolkitTearing/src/reassemble.jl +++ b/lib/ModelingToolkitTearing/src/reassemble.jl @@ -1177,8 +1177,11 @@ function update_simplified_system!( dummy_sub::Dict{SymbolicT, SymbolicT}, var_sccs::Vector{Vector{Int}}, extra_unknowns::Vector{SymbolicT}, iv::Union{SymbolicT, Nothing}, D::Union{Differential, Shift, Nothing}; array_hack = true) - (; fullvars, structure) = state + (; fullvars, structure, sys) = state (; solvable_graph, var_to_diff, eq_to_diff, graph) = structure + + sys = MTKBase.remove_unhack_system_transformation(sys) + diff_to_var = invview(var_to_diff) # Since we solved the highest order derivative variable in discrete systems, # we make a list of the solved variables and avoid including them in the @@ -1202,7 +1205,6 @@ function update_simplified_system!( (var_to_diff[i] !== nothing && !isempty(đť‘‘neighbors(graph, var_to_diff[i])))) end - sys = state.sys obs_sub = dummy_sub for eq in neweqs MTKBase.isdiffeq(eq) || continue @@ -1247,7 +1249,10 @@ function update_simplified_system!( end @set! sys.unknowns = unknowns - obs = (@invokelatest tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack))::Vector{Equation} + if array_hack + tf = MTKBase.add_array_observed!(obs, unknowns) + sys = MTKBase.with_reversible_transformation(sys, tf) + end @set! sys.eqs = neweqs @set! sys.observed = obs @@ -1337,7 +1342,6 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState, (; simplify, array_hack, inline_linear_sccs, analytical_linear_scc_limit) = alg (; var_eq_matching, full_var_eq_matching, var_sccs) = tearing_result - extra_eqs_vars = get_extra_eqs_vars( state, var_eq_matching, full_var_eq_matching, fully_determined) neweqs = collect(equations(state)) @@ -1410,6 +1414,11 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState, extra_unknowns, iv, D; array_hack) end + if inline_linear_sccs + # We add this at the end so it is the first reversed transformation. This enables better + # caching. + sys = MTKBase.with_reversible_transformation(sys, InlineLinsolveTransformation) + end sys = SU.setmetadata(sys, InlineLinearSystemsMetadata, inline_blocks) @set! state.sys = sys @set! sys.tearing_state = state