Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ jobs:
@info "Testing ModelingToolkitTearing"
Pkg.activate("lib/ModelingToolkitTearing")
Pkg.develop(; path = ".")
Pkg.add(; name = "ModelingToolkitBase", rev = "as/reversible-tfs")
Pkg.test()
else
@error "Unknown package" PKG
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/MTKDownstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
with:
repository: SciML/ModelingToolkit.jl
path: downstream
ref: "as/reversible-tfs"
- name: "Test ModelingToolkit/${{ matrix.group }}"
env:
GROUP: ${{ matrix.group }}
Expand Down
36 changes: 9 additions & 27 deletions lib/ModelingToolkitTearing/src/ModelingToolkitTearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,39 +58,22 @@ include("diagnostics.jl")

include("reassemble.jl")

struct UnhackSystemCacheKey end
abstract type InlineLinsolveTransformation <: MTKBase.ReversibleTransformations end

function MTKBase.should_invalidate_mutable_cache_entry(::Type{UnhackSystemCacheKey}, patch::NamedTuple)
function MTKBase.should_invalidate_mutable_cache_entry(
::Type{InlineLinsolveTransformation}, patch::NamedTuple
)
return true
end

function MTKBase.unhack_system(sys::System)
cached_sys = MTKBase.check_mutable_cache(sys, UnhackSystemCacheKey, System, nothing)
function MTKBase.reverse_transformation(sys, ::Type{InlineLinsolveTransformation})
cached_sys = MTKBase.check_mutable_cache(sys, InlineLinsolveTransformation, System, nothing)
if cached_sys isa System
return cached_sys
end
# Observed are copied by the masking operation

obseqs = observed(sys)
eqs = copy(equations(sys))
obs_mask = trues(length(obseqs))
for (i, eq) in enumerate(obseqs)
obs_mask[i] = @match eq.rhs begin
BSImpl.Term(; f, args) => if f === change_origin
false
elseif f === SU.array_literal
result = true
for (si, ai) in zip(SU.stable_eachindex(eq.lhs), Iterators.drop(eachindex(args), 1))
result &= isequal(eq.lhs[si], args[ai])
result || break
end
!result
else
true
end
_ => true
end
end
obseqs = obseqs[obs_mask]

# Map from ldiv operation to index of the equations where it is the RHS. A
# positive index is into `obseqs`, a negative index is into `eqs`. The variable
Expand All @@ -107,8 +90,7 @@ function MTKBase.unhack_system(sys::System)
# Now, we want to turn all inlined linear SCCs into algebraic equations. If an element
# of the SCC is a differential variable, we'll introduce the `toterm` as a new algebraic.
# Otherwise, the observed equation is removed.
resize!(obs_mask, length(obseqs))
fill!(obs_mask, true)
obs_mask = trues(length(obseqs))
additional_eqs = Equation[]
additional_vars = Set{SymbolicT}()
additional_subs = Dict{SymbolicT, SymbolicT}()
Expand Down Expand Up @@ -171,7 +153,7 @@ function MTKBase.unhack_system(sys::System)
@set! newsys.unknowns = dvs
@set! newsys.schedule = sched

MTKBase.store_to_mutable_cache!(sys, UnhackSystemCacheKey, newsys)
MTKBase.store_to_mutable_cache!(sys, InlineLinsolveTransformation, newsys)

return newsys
end
Expand Down
17 changes: 13 additions & 4 deletions lib/ModelingToolkitTearing/src/reassemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1177,8 +1177,11 @@ function update_simplified_system!(
dummy_sub::Dict{SymbolicT, SymbolicT}, var_sccs::Vector{Vector{Int}},
extra_unknowns::Vector{SymbolicT}, iv::Union{SymbolicT, Nothing},
D::Union{Differential, Shift, Nothing}; array_hack = true)
(; fullvars, structure) = state
(; fullvars, structure, sys) = state
(; solvable_graph, var_to_diff, eq_to_diff, graph) = structure

sys = MTKBase.remove_unhack_system_transformation(sys)

diff_to_var = invview(var_to_diff)
# Since we solved the highest order derivative variable in discrete systems,
# we make a list of the solved variables and avoid including them in the
Expand All @@ -1202,7 +1205,6 @@ function update_simplified_system!(
(var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
end

sys = state.sys
obs_sub = dummy_sub
for eq in neweqs
MTKBase.isdiffeq(eq) || continue
Expand Down Expand Up @@ -1247,7 +1249,10 @@ function update_simplified_system!(
end
@set! sys.unknowns = unknowns

obs = (@invokelatest tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack))::Vector{Equation}
if array_hack
tf = MTKBase.add_array_observed!(obs, unknowns)
sys = MTKBase.with_reversible_transformation(sys, tf)
end

@set! sys.eqs = neweqs
@set! sys.observed = obs
Expand Down Expand Up @@ -1337,7 +1342,6 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState,
(; simplify, array_hack, inline_linear_sccs, analytical_linear_scc_limit) = alg
(; var_eq_matching, full_var_eq_matching, var_sccs) = tearing_result


extra_eqs_vars = get_extra_eqs_vars(
state, var_eq_matching, full_var_eq_matching, fully_determined)
neweqs = collect(equations(state))
Expand Down Expand Up @@ -1410,6 +1414,11 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState,
extra_unknowns, iv, D; array_hack)
end

if inline_linear_sccs
# We add this at the end so it is the first reversed transformation. This enables better
# caching.
sys = MTKBase.with_reversible_transformation(sys, InlineLinsolveTransformation)
end
sys = SU.setmetadata(sys, InlineLinearSystemsMetadata, inline_blocks)
@set! state.sys = sys
@set! sys.tearing_state = state
Expand Down
Loading