Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "SourceCodeMcCormick"
uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960"
authors = ["Robert Gottlieb <Robert.x.gottlieb@uconn.edu>"]
version = "0.5.1"
version = "0.5.2"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -20,6 +21,7 @@ CUDA = "5"
DocStringExtensions = "0.8 - 0.9"
Graphs = "1"
IfElse = "0.1.0 - 0.1.1"
MultiFloats = "3.1"
PrecompileTools = "~1"
Reexport = "~1"
StaticArrays = "~1"
Expand Down
1 change: 1 addition & 0 deletions src/SourceCodeMcCormick.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using DocStringExtensions
using Graphs
using CUDA
using StaticArrays: @MVector
using MultiFloats
import Dates
import SymbolicUtils: BasicSymbolic, exprtype, SYM, TERM, ADD, MUL, POW, DIV

Expand Down
368 changes: 117 additions & 251 deletions src/kernel_writer/kernel_write.jl

Large diffs are not rendered by default.

18 changes: 13 additions & 5 deletions src/kernel_writer/math_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,7 @@ end

# Sigmoid function
# max threads: 640
@register_symbolic SCMC_sigmoid(x) # Register as symbolic so that we can use it later
function SCMC_sigmoid_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
idx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
stride = blockDim().x * gridDim().x
Expand Down Expand Up @@ -4670,21 +4671,21 @@ function SCMC_cos_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
kL = Base.ceil(-0.5 - x[idx,3]/(2.0*pi))
xL1 = x[idx,3] + 2.0*pi*kL
xU1 = x[idx,4] + 2.0*pi*kL
if (xL1 < -pi) || (xL1 > pi)
if (xL1 < -pi) || (xL1 > Float64(pi))
eps_min = NaN
eps_max = NaN
elseif xL1 <= 0.0
if xU1 <= 0.0
eps_min = x[idx,3]
eps_max = x[idx,4]
elseif xU1 >= pi
elseif xU1 >= Float64(pi)
eps_min = pi - 2.0*pi*kL
eps_max = -2.0*pi*kL
else
eps_min = (cos(xL1) <= cos(xU1)) ? x[idx,3] : x[idx,4]
eps_max = -2.0*pi*kL
end
elseif xU1 <= pi
elseif xU1 <= Float64(pi)
eps_min = x[idx,4]
eps_max = x[idx,3]
elseif xU1 >= 2.0*pi
Expand Down Expand Up @@ -5449,9 +5450,16 @@ function cos_newton_or_golden_section(x0::Float64, xL::Float64, xU::Float64, env
return xk
end

# Directly from IntervalArithmetic.jl
# Similar to IntervalArithmetic.jl, but not using `rem2pi`
function quadrant(x::Float64)
x_mod2pi = rem2pi(x, RoundNearest)
bigx = MultiFloats.Float64x2(x)
bigpi = MultiFloats._MF{Float64,2}((3.141592653589793, 1.2246467991473532e-16))
rem = Float64(floor(bigx/bigpi))
if iseven(rem)
x_mod2pi = Float64(bigx - rem*bigpi)
else
x_mod2pi = Float64(bigx - (rem+1)*bigpi)
end

x_mod2pi < -(pi/2.0) && return (Int32(2), x_mod2pi)
x_mod2pi < 0 && return (Int32(3), x_mod2pi)
Expand Down
12 changes: 6 additions & 6 deletions src/kernel_writer/string_math_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10291,21 +10291,21 @@ function SCMC_cos_kernel(OUT::String, v1::String, varlist::Vector{String}, spars
write(buffer, " kL = Base.ceil(-0.5 - $v1_lo/(2.0*pi))\n")
write(buffer, " xL1 = $v1_lo + 2.0*pi*kL\n")
write(buffer, " xU1 = $v1_hi + 2.0*pi*kL\n")
write(buffer, " if (xL1 < -pi) || (xL1 > pi)\n")
write(buffer, " if (xL1 < -pi) || (xL1 > Float64(pi))\n")
write(buffer, " eps_min = NaN\n")
write(buffer, " eps_max = NaN\n")
write(buffer, " elseif xL1 <= 0.0\n")
write(buffer, " if xU1 <= 0.0\n")
write(buffer, " eps_min = $v1_lo\n")
write(buffer, " eps_max = $v1_hi\n")
write(buffer, " elseif xU1 >= pi\n")
write(buffer, " elseif xU1 >= Float64(pi)\n")
write(buffer, " eps_min = pi - 2.0*pi*kL\n")
write(buffer, " eps_max = -2.0*pi*kL\n")
write(buffer, " else\n")
write(buffer, " eps_min = (cos(xL1) <= cos(xU1)) ? $v1_lo : $v1_hi\n")
write(buffer, " eps_max = -2.0*pi*kL\n")
write(buffer, " end\n")
write(buffer, " elseif xU1 <= pi\n")
write(buffer, " elseif xU1 <= Float64(pi)\n")
write(buffer, " eps_min = $v1_hi\n")
write(buffer, " eps_max = $v1_lo\n")
write(buffer, " elseif xU1 >= 2.0*pi\n")
Expand Down Expand Up @@ -10601,21 +10601,21 @@ function SCMC_cos_kernel(OUT::String, v1::String, varlist::Vector{String}, spars
write(buffer, " kL = Base.ceil(-0.5 - $v1_lo/(2.0*pi))\n")
write(buffer, " xL1 = $v1_lo + 2.0*pi*kL\n")
write(buffer, " xU1 = $v1_hi + 2.0*pi*kL\n")
write(buffer, " if (xL1 < -pi) || (xL1 > pi)\n")
write(buffer, " if (xL1 < -pi) || (xL1 > Float64(pi))\n")
write(buffer, " eps_min = NaN\n")
write(buffer, " eps_max = NaN\n")
write(buffer, " elseif xL1 <= 0.0\n")
write(buffer, " if xU1 <= 0.0\n")
write(buffer, " eps_min = $v1_lo\n")
write(buffer, " eps_max = $v1_hi\n")
write(buffer, " elseif xU1 >= pi\n")
write(buffer, " elseif xU1 >= Float64(pi)\n")
write(buffer, " eps_min = pi - 2.0*pi*kL\n")
write(buffer, " eps_max = -2.0*pi*kL\n")
write(buffer, " else\n")
write(buffer, " eps_min = (cos(xL1) <= cos(xU1)) ? $v1_lo : $v1_hi\n")
write(buffer, " eps_max = -2.0*pi*kL\n")
write(buffer, " end\n")
write(buffer, " elseif xU1 <= pi\n")
write(buffer, " elseif xU1 <= Float64(pi)\n")
write(buffer, " eps_min = $v1_hi\n")
write(buffer, " eps_max = $v1_lo\n")
write(buffer, " elseif xU1 >= 2.0*pi\n")
Expand Down
65 changes: 55 additions & 10 deletions src/transform/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,22 @@ julia> pull_vars(func)
z
```
"""
pull_vars(term::BasicSymbolic) = pull_vars(Num(term))
function pull_vars(term::Num)
pull_vars(term::BasicSymbolic; get_names::Bool=false) = pull_vars(Num(term), get_names=get_names)
function pull_vars(term::Num; get_names::Bool=false)
vars = Num[]
strings = String[]
if ~(typeof(term.val) <: Real)
vars, strings = _pull_vars(term.val, vars, strings)
vars = vars[sort_vars(strings)]
end
return vars
if get_names
return get_name.(vars)
else
return vars
end
end

function pull_vars(terms::Vector{Num})
function pull_vars(terms::Vector{Num}; get_names::Bool=false)
vars = Num[]
strings = String[]
for term in terms
Expand All @@ -258,20 +262,28 @@ function pull_vars(terms::Vector{Num})
if ~isempty(vars)
vars = vars[sort_vars(strings)]
end
return vars
if get_names
return get_name.(vars)
else
return vars
end
end

function pull_vars(eqn::Equation)
function pull_vars(eqn::Equation; get_names::Bool=false)
vars = Num[]
strings = String[]
if ~(typeof(eqn.rhs) <: Real)
vars, strings = _pull_vars(eqn.rhs, vars, strings)
vars = vars[sort_vars(strings)]
end
return vars
if get_names
return get_name.(vars)
else
return vars
end
end

function pull_vars(eqns::Vector{Equation})
function pull_vars(eqns::Vector{Equation}; get_names::Bool=false)
vars = Num[]
strings = String[]
for eqn in eqns
Expand All @@ -282,9 +294,13 @@ function pull_vars(eqns::Vector{Equation})
if ~isempty(vars)
vars = vars[sort_vars(strings)]
end
return vars
if get_names
return get_name.(vars)
else
return vars
end
end
function pull_vars(eqn::T) where T<:Real
function pull_vars(eqn::T; get_names::Bool=false) where T<:Real
return Num[]
end

Expand Down Expand Up @@ -536,6 +552,35 @@ function extract(eqs::Vector{Equation}, ID::Int=length(eqs))
return final_expr
end


"""
shorten(::Vector{Equation}, ::Int)

Given a set of symbolic equations, and a specific element index,
return a Vector{Equation} that only contains elements needed to
evaluate the chosen element.
```
"""
function shorten(eqs::Vector{Equation}, ID::Int)
indices = Int[]
function delve!(idx, indices, LHS, RHS)
if idx in indices
return nothing
else
for var in RHS[idx]
var_idx = findfirst(==(var), LHS)
if ~isnothing(var_idx) && (var_idx != idx)
delve!(var_idx, indices, LHS, RHS)
end
end
push!(indices, idx)
return nothing
end
end
delve!(ID, indices, Symbol.(getfield.(eqs, :lhs)), pull_vars.(getfield.(eqs, :rhs), get_names=true))
return eqs[indices]
end

"""
convex_evaluator(::Num)
convex_evaluator(::Equation)
Expand Down
26 changes: 25 additions & 1 deletion src/transform/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,30 @@ function eqn_edges(a::Vector{Equation})
end
return edgelist, vars
end
function eqn_edges(a::Vector{Equation}, vars::Vector{Symbol})
# Create the list of edges
edgelist = Edge{Int}[]

# Create a mapping dictionary
varid = Dict(vars .=> collect(1:length(vars)))

# Identify LHS variables
LHS_id = [varid[x] for x in Symbol.(getfield.(a, :lhs))]

# Identify RHS variables
RHS_id = [[varid[x] for x in pull_vars(RHS, get_names=true)] for RHS in getfield.(a, :rhs)]

# Create edges of RHS -> LHS
for i in eachindex(LHS_id)
for j in eachindex(RHS_id[i])
if RHS_id[i][j] == LHS_id[i]
continue
end
push!(edgelist, Edge(RHS_id[i][j], LHS_id[i]))
end
end
return edgelist
end

# A new topological sort that tries to minimize the number of temporary vectors
# that need to be preallocated
Expand All @@ -128,7 +152,7 @@ function topological_sort(g::SimpleDiGraph; order::Vector{Int64}=Int64[])
for j in g.badjlist[i][sortperm(-lengths)]
recursive_add(g, j, order)
end
if ~in(i, order)
if ~in(i, order) && ~isempty(g.badjlist[i])
push!(order, i)
end
end
Expand Down
Loading