Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Last active September 6, 2024 08:31
Show Gist options
  • Select an option

  • Save vankesteren/6c141f7cabcd3eb47292d78cfca1804d to your computer and use it in GitHub Desktop.

Select an option

Save vankesteren/6c141f7cabcd3eb47292d78cfca1804d to your computer and use it in GitHub Desktop.

Revisions

  1. vankesteren revised this gist Sep 6, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion permutefun.jl
    Original file line number Diff line number Diff line change
    @@ -5,7 +5,7 @@ using Plots, Random
    """
    permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Number = 1e-3, max_iter::Int = 10_000, max_search::Number = 100, verbose::Bool = true)
    Permute y values to approximate a correlation between x and y of ρ.
    Permute y values to approximate a functional constraint (rule) between x and y.
    # Arguments
    - `x::Vector`: The vector of x values
  2. vankesteren revised this gist May 30, 2024. 1 changed file with 2 additions and 0 deletions.
    2 changes: 2 additions & 0 deletions permutefun.jl
    Original file line number Diff line number Diff line change
    @@ -55,6 +55,8 @@ function permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Num
    # increment counter
    search_iter += 1
    end

    # stopping conditions
    if search_iter >= max_search
    if verbose
    println("\nNo improvement found after $(iter-max_search) iterations.")
  3. vankesteren revised this gist May 30, 2024. 1 changed file with 4 additions and 2 deletions.
    6 changes: 4 additions & 2 deletions permutefun.jl
    Original file line number Diff line number Diff line change
    @@ -55,10 +55,9 @@ function permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Num
    # increment counter
    search_iter += 1
    end
    iter += 1
    if search_iter >= max_search
    if verbose
    println("\nNo improvement found after $(iter-max_search-1) iterations.")
    println("\nNo improvement found after $(iter-max_search) iterations.")
    end
    break
    end
    @@ -68,6 +67,9 @@ function permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Num
    end
    break
    end

    # increment iterations
    iter += 1
    end
    return nothing
    end
  4. vankesteren revised this gist May 30, 2024. No changes.
  5. vankesteren created this gist May 30, 2024.
    120 changes: 120 additions & 0 deletions permutefun.jl
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,120 @@
    using StatsBase: sample, mean, cor
    using LinearAlgebra: norm
    using Plots, Random

    """
    permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Number = 1e-3, max_iter::Int = 10_000, max_search::Number = 100, verbose::Bool = true)
    Permute y values to approximate a correlation between x and y of ρ.
    # Arguments
    - `x::Vector`: The vector of x values
    - `y::Vector`: The vector of y values
    - `rule::Function`: Function taking two numbers and outputting a single value
    - `score::Real`: The target score, i.e., the target value of `sum(rule.(x, y))`
    - `tol::Real`: The tolerance. If the loss `abs(current_score - score)` is below this value, stop the algorithm.
    - `max_iter::Int`: Maximum number of iterations. For large datasets, this may need to be increased.
    - `max_search::Int`: The number of iterations to search for improvements
    - `verbose::Bool`: Whether to print debug information
    """
    function permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Number=1e-3, max_iter::Int=10_000, max_search::Int=100, verbose::Bool=true)
    N = length(x)
    if N != length(y)
    throw(ArgumentError("x and y should be the same length!"))
    end

    # current objective value
    current_rule = rule.(x, y)
    current_score = sum(current_rule)
    current_loss = abs(score - current_score)

    iter::Int = 0
    search_iter::Int = 0
    while iter < max_iter
    # get random index
    i = sample(1:N)

    # compute change in score
    delta_score = rule.(x, y[i]) .+ rule.(x[i], y) .- current_rule .- current_rule[i]

    # only change if loss improves
    new_loss, j = findmin(abs.(score .- (current_score .+ delta_score)))

    if new_loss < current_loss
    # Found option! make change
    y[i], y[j] = y[j], y[i]
    current_rule[i], current_rule[j] = rule(x[i], y[i]), rule(x[j], y[j])
    current_score = sum(current_rule)
    current_loss = abs(score - current_score)
    if verbose
    println("Iter $iter | loss $current_loss | $i$j | score $current_score | search $search_iter")
    end
    search_iter = 0
    else
    # increment counter
    search_iter += 1
    end
    iter += 1
    if search_iter >= max_search
    if verbose
    println("\nNo improvement found after $(iter-max_search-1) iterations.")
    end
    break
    end
    if current_loss < tol
    if verbose
    println("\nAchieved tolerance!")
    end
    break
    end
    end
    return nothing
    end

    function permutefun!(x::Vector, y::Vector, rule::Function; tol::Number=1e-3, max_iter::Int=10_000, max_search::Number=100, verbose::Bool=true)
    permutefun!(x, y, rule, length(x); tol, max_iter, max_search, verbose)
    end


    function marginplot(x, y, title)
    layout = @layout [
    a _
    b{0.8w,0.8h} c
    ]
    xlim = (minimum(x), maximum(x))
    ylim = (minimum(y), maximum(y))
    default(fillcolor=:grey, markercolor=:grey, legend=false)
    plt = plot(layout=layout, link=:none, size=(500, 500), margin=-10Plots.px, plot_title=title)
    scatter!(x, y, subplot=2, xlim=xlim, ylim=ylim)
    histogram!(x, nbins=30, subplot=1, orientation=:v, framestyle=:none, bottommargin=-20Plots.px, xlim=xlim)
    histogram!(y, nbins=30, subplot=3, orientation=:h, framestyle=:none, leftmargin=-40Plots.px, ylim=ylim)
    return plt
    end

    # Generate some data
    N = 300
    Random.seed!(45)
    x = rand(N) .- 0.5
    y = vcat(randn(Int(N / 2)) ./ 6 .- 0.25, randn(Int(N / 2)) ./ 8 .+ 0.35)
    p1 = marginplot(x, y, "Original data")

    # Enforce complex constraint
    x1, y1 = copy(x), copy(y)
    permutefun!(x1, y1, (xi, yi) -> (xi^4 < yi^2))
    p2 = marginplot(x1, y1, "x⁴ < y²")

    # induce a certain correlation
    x2, y2 = copy(x), copy(y)
    xm, ym = mean(x), mean(y)
    permutefun!(x2, y2, (xi, yi) -> (xi - xm) * (yi - ym), 0.7 * norm(x .- xm) * norm(y .- ym))
    cor(x2, y2)
    p3 = marginplot(x2, y2, "Correlation = .7")

    # make a hole in the data
    x3, y3 = copy(x), copy(y)
    permutefun!(x3, y3, (xi, yi) -> sqrt(xi^2 + yi^2) > 0.3)
    p4 = marginplot(x3, y3, "Hole of radius 0.3")

    plot(p1, p2, p3, p4, size=(2000, 500), layout=(1, 4), margin=10Plots.px)
    savefig("permutefun.pdf")