Last active
September 6, 2024 08:31
-
-
Save vankesteren/6c141f7cabcd3eb47292d78cfca1804d to your computer and use it in GitHub Desktop.
Revisions
-
vankesteren revised this gist
Sep 6, 2024 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -5,7 +5,7 @@ using Plots, Random """ permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Number = 1e-3, max_iter::Int = 10_000, max_search::Number = 100, verbose::Bool = true) Permute y values to approximate a functional constraint (rule) between x and y. # Arguments - `x::Vector`: The vector of x values -
vankesteren revised this gist
May 30, 2024 . 1 changed file with 2 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -55,6 +55,8 @@ function permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Num # increment counter search_iter += 1 end # stopping conditions if search_iter >= max_search if verbose println("\nNo improvement found after $(iter-max_search) iterations.") -
vankesteren revised this gist
May 30, 2024 . 1 changed file with 4 additions and 2 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -55,10 +55,9 @@ function permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Num # increment counter search_iter += 1 end if search_iter >= max_search if verbose println("\nNo improvement found after $(iter-max_search) iterations.") end break end @@ -68,6 +67,9 @@ function permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Num end break end # increment iterations iter += 1 end return nothing end -
vankesteren revised this gist
May 30, 2024 . No changes.There are no files selected for viewing
-
vankesteren created this gist
May 30, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,120 @@ using StatsBase: sample, mean, cor using LinearAlgebra: norm using Plots, Random """ permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Number = 1e-3, max_iter::Int = 10_000, max_search::Number = 100, verbose::Bool = true) Permute y values to approximate a correlation between x and y of ρ. # Arguments - `x::Vector`: The vector of x values - `y::Vector`: The vector of y values - `rule::Function`: Function taking two numbers and outputting a single value - `score::Real`: The target score, i.e., the target value of `sum(rule.(x, y))` - `tol::Real`: The tolerance. If the loss `abs(current_score - score)` is below this value, stop the algorithm. - `max_iter::Int`: Maximum number of iterations. For large datasets, this may need to be increased. - `max_search::Int`: The number of iterations to search for improvements - `verbose::Bool`: Whether to print debug information """ function permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Number=1e-3, max_iter::Int=10_000, max_search::Int=100, verbose::Bool=true) N = length(x) if N != length(y) throw(ArgumentError("x and y should be the same length!")) end # current objective value current_rule = rule.(x, y) current_score = sum(current_rule) current_loss = abs(score - current_score) iter::Int = 0 search_iter::Int = 0 while iter < max_iter # get random index i = sample(1:N) # compute change in score delta_score = rule.(x, y[i]) .+ rule.(x[i], y) .- current_rule .- current_rule[i] # only change if loss improves new_loss, j = findmin(abs.(score .- (current_score .+ delta_score))) if new_loss < current_loss # Found option! make change y[i], y[j] = y[j], y[i] current_rule[i], current_rule[j] = rule(x[i], y[i]), rule(x[j], y[j]) current_score = sum(current_rule) current_loss = abs(score - current_score) if verbose println("Iter $iter | loss $current_loss | $i ↔ $j | score $current_score | search $search_iter") end search_iter = 0 else # increment counter search_iter += 1 end iter += 1 if search_iter >= max_search if verbose println("\nNo improvement found after $(iter-max_search-1) iterations.") end break end if current_loss < tol if verbose println("\nAchieved tolerance!") end break end end return nothing end function permutefun!(x::Vector, y::Vector, rule::Function; tol::Number=1e-3, max_iter::Int=10_000, max_search::Number=100, verbose::Bool=true) permutefun!(x, y, rule, length(x); tol, max_iter, max_search, verbose) end function marginplot(x, y, title) layout = @layout [ a _ b{0.8w,0.8h} c ] xlim = (minimum(x), maximum(x)) ylim = (minimum(y), maximum(y)) default(fillcolor=:grey, markercolor=:grey, legend=false) plt = plot(layout=layout, link=:none, size=(500, 500), margin=-10Plots.px, plot_title=title) scatter!(x, y, subplot=2, xlim=xlim, ylim=ylim) histogram!(x, nbins=30, subplot=1, orientation=:v, framestyle=:none, bottommargin=-20Plots.px, xlim=xlim) histogram!(y, nbins=30, subplot=3, orientation=:h, framestyle=:none, leftmargin=-40Plots.px, ylim=ylim) return plt end # Generate some data N = 300 Random.seed!(45) x = rand(N) .- 0.5 y = vcat(randn(Int(N / 2)) ./ 6 .- 0.25, randn(Int(N / 2)) ./ 8 .+ 0.35) p1 = marginplot(x, y, "Original data") # Enforce complex constraint x1, y1 = copy(x), copy(y) permutefun!(x1, y1, (xi, yi) -> (xi^4 < yi^2)) p2 = marginplot(x1, y1, "x⁴ < y²") # induce a certain correlation x2, y2 = copy(x), copy(y) xm, ym = mean(x), mean(y) permutefun!(x2, y2, (xi, yi) -> (xi - xm) * (yi - ym), 0.7 * norm(x .- xm) * norm(y .- ym)) cor(x2, y2) p3 = marginplot(x2, y2, "Correlation = .7") # make a hole in the data x3, y3 = copy(x), copy(y) permutefun!(x3, y3, (xi, yi) -> sqrt(xi^2 + yi^2) > 0.3) p4 = marginplot(x3, y3, "Hole of radius 0.3") plot(p1, p2, p3, p4, size=(2000, 500), layout=(1, 4), margin=10Plots.px) savefig("permutefun.pdf")