Skip to content

Instantly share code, notes, and snippets.

@RuolinZheng08
Created January 31, 2021 20:45
Show Gist options
  • Select an option

  • Save RuolinZheng08/86518fbf0d7ef2cc7c7cb63cb49fd32b to your computer and use it in GitHub Desktop.

Select an option

Save RuolinZheng08/86518fbf0d7ef2cc7c7cb63cb49fd32b to your computer and use it in GitHub Desktop.

Revisions

  1. RuolinZheng08 created this gist Jan 31, 2021.
    433 changes: 433 additions & 0 deletions permutation_test.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,433 @@
    {
    "cells": [
    {
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
    "from itertools import combinations\n",
    "from collections import Counter"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
    "def two_sample_permutation_test(arr, start, size):\n",
    "    target = arr[start : start + size]\n",
    "    arr_counter = Counter(arr)\n",
    "    target_diff = sum(target) - sum((arr_counter - Counter(target)).elements())\n",
    "    print('target: ', target_diff)\n",
    "    count = 0\n",
    "    for curr in combinations(arr, size):\n",
    "        curr_counter = Counter(curr)\n",
    "        complement = list((arr_counter - curr_counter).elements())\n",
    "        diff = sum(curr) - sum(complement)\n",
    "        if diff >= target_diff:\n",
    "            print(curr, diff)\n",
    "            count += 1\n",
    "    return count"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 60,
    "metadata": {},
    "outputs": [],
    "source": [
    "arr = [111, 56, 86, 92, 104, 118, 117, 111]\n",
    "two_sample_permutation_test(arr, 4, 4)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 72,
    "metadata": {
    "scrolled": false
    },
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "(4.5, 3, 7, 6) 13.0\n",
    "(4.5, 7, 6, 4.5) 16.0\n",
    "(3, 7, 6, 4.5) 13.0\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "3"
    ]
    },
    "execution_count": 72,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "two_sample_permutation_test([4.5, 0, 1, 2, 3, 7, 6, 4.5], 4, 4)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 4,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "target: 42\n",
    "(25, 31, 46) 42\n",
    "(25, 46, 31) 42\n",
    "(31, 46, 31) 54\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "3"
    ]
    },
    "execution_count": 4,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "two_sample_permutation_test([25, 31, 46, 10, 19, 31], 0, 3)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 76,
    "metadata": {},
    "outputs": [],
    "source": [
    "def two_sample_permutation_test_two_sided(arr, start, size):\n",
    "    target = arr[start : start + size]\n",
    "    arr_counter = Counter(arr)\n",
    "    target_diff = abs(sum(target) - sum((arr_counter - Counter(target)).elements()))\n",
    "    print('target: ', target_diff)\n",
    "    count = 0\n",
    "    for curr in combinations(arr, size):\n",
    "        curr_counter = Counter(curr)\n",
    "        complement = list((arr_counter - curr_counter).elements())\n",
    "        diff = abs(sum(curr) - sum(complement))\n",
    "        if diff >= target_diff:\n",
    "            print(curr, diff)\n",
    "            count += 1\n",
    "    return count"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 79,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "target: 3.299999999999999\n",
    "(3.4, 2.8, 1.9, 2.6) 3.299999999999999\n",
    "(3.4, 2.8, 2.6, 2.4) 4.299999999999999\n",
    "(3.4, 2.8, 2.6, 2.1) 3.6999999999999993\n",
    "(3.4, 2.8, 2.4, 2.1) 3.299999999999999\n",
    "(1.9, 2.6, 1.4, 1.5) 3.299999999999999\n",
    "(1.9, 1.4, 2.4, 1.5) 3.6999999999999993\n",
    "(1.9, 1.4, 2.1, 1.5) 4.299999999999999\n",
    "(1.4, 2.4, 2.1, 1.5) 3.299999999999999\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "8"
    ]
    },
    "execution_count": 79,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "two_sample_permutation_test_two_sided([3.4, 2.8, 1.9, 2.6, 1.4, 2.4, 2.1, 1.5], 4, 4)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 80,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "target: 12\n",
    "(8, 7, 6, 3) 12\n",
    "(8, 7, 6, 5) 16\n",
    "(8, 7, 6, 4) 14\n",
    "(8, 7, 5, 4) 12\n",
    "(6, 3, 2, 1) 12\n",
    "(3, 5, 2, 1) 14\n",
    "(3, 4, 2, 1) 16\n",
    "(5, 4, 2, 1) 12\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "8"
    ]
    },
    "execution_count": 80,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "two_sample_permutation_test_two_sided([8, 7, 6, 3, 5, 4, 2, 1], 4, 4)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
    "def matched_pair_permutation_test(arr):\n",
    "    count = 0\n",
    "    neg_sum = sum([elm for elm in arr if elm < 0])\n",
    "    if neg_sum < 0: \n",
    "        # at least one entry is negative, so the all-positive arrangement (empty subset, sum 0) counts too\n",
    "        print('[], 0')\n",
    "        count += 1\n",
    "    arr = sorted([-abs(elm) for elm in arr], reverse=True)\n",
    "    \n",
    "    for i in range(1, len(arr)):\n",
    "        for nums in combinations(arr, i):\n",
    "            curr_sum = sum(nums)\n",
    "            if curr_sum >= neg_sum:\n",
    "                print(nums, curr_sum)\n",
    "                count += 1\n",
    "    return count"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 91,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "[], 0\n",
    "(-1,) -1\n",
    "(-2,) -2\n",
    "(-3,) -3\n",
    "(-4,) -4\n",
    "(-1, -2) -3\n",
    "(-1, -3) -4\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "7"
    ]
    },
    "execution_count": 91,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "arr = [-1, 6, 4, 6, 2, -3, 5]\n",
    "matched_pair_permutation_test(arr)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 83,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "[], 0\n",
    "[-1.0] -1.0\n",
    "[-2.0] -2.0\n",
    "[-3.0] -3.0\n",
    "[-4.0] -4.0\n",
    "[-1.0, -2.0] -3.0\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "6"
    ]
    },
    "execution_count": 83,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "matched_pair_permutation_test([-1.0, 6.5, 4.0, 6.5, 2.0, -3.0, 5.0])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 41,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "[], 0\n",
    "[-0.25] -0.25\n",
    "[-0.33] -0.33\n",
    "[-0.25, -0.33] -0.5800000000000001\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "4"
    ]
    },
    "execution_count": 41,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "matched_pair_permutation_test([1.85, -0.25, 0.88, 1.46, 1.05, 1.67, 1.74, -0.33])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 42,
    "metadata": {
    "scrolled": false
    },
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "[], 0\n",
    "[-1] -1\n",
    "[-2] -2\n",
    "[-3] -3\n",
    "[-1, -2] -3\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "5"
    ]
    },
    "execution_count": 42,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "matched_pair_permutation_test([-1, -2, 3, 4, 5, 6, 7, 8])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 6,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "[], 0\n",
    "(-1,) -1\n",
    "(-1,) -1\n",
    "(-1,) -1\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "4"
    ]
    },
    "execution_count": 6,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "matched_pair_permutation_test([1, 6, 1, 7, -1, 2, 8])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "[], 0\n",
    "(-2,) -2\n",
    "(-2,) -2\n",
    "(-2,) -2\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "4"
    ]
    },
    "execution_count": 7,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "matched_pair_permutation_test([2, 5, 2, 6, -2, 4, 7])"
    ]
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.8.0"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 2
    }