Skip to content

Instantly share code, notes, and snippets.

@ml-edu
Forked from lmc2179/multinomial_sample.py
Created October 13, 2022 23:51
Show Gist options
  • Save ml-edu/522d0ee4b1d50fb772f5a0d6264a8118 to your computer and use it in GitHub Desktop.
Save ml-edu/522d0ee4b1d50fb772f5a0d6264a8118 to your computer and use it in GitHub Desktop.
A class that samples from a multinomial distribution in Python.
import math
import random

import partition_tree
class Multinomial_Sampler(object):
    """Draw event names at random according to a fixed probability vector.

    Each probability becomes a sub-interval of [0, 1] whose width equals that
    probability. ``sample`` draws ``random.random()`` and returns the name
    attached to the interval containing it, via a
    ``partition_tree.PartitionTree``.
    """

    def __init__(self, probabilities, event_names):
        """Build the sampler.

        Args:
            probabilities: Non-negative numbers, one per event, summing to 1
                within floating-point tolerance. Any iterable is accepted.
            event_names: Labels returned by ``sample``, aligned index-by-index
                with ``probabilities``. Any iterable is accepted.

        Raises:
            ValueError: If the two inputs differ in length, any probability is
                negative, or the probabilities do not sum to 1.
        """
        # Materialize both inputs so generators work and lengths can be compared.
        probabilities = list(probabilities)
        event_names = list(event_names)
        if len(probabilities) != len(event_names):
            raise ValueError(
                "got {} probabilities but {} event names".format(
                    len(probabilities), len(event_names)))
        intervals = self._build_intervals_from_probabilities(probabilities)
        self.tree = partition_tree.PartitionTree(intervals, event_names)

    def _build_intervals_from_probabilities(self, probabilities):
        """Return consecutive ``(left, right)`` intervals tiling [0, 1].

        Interval ``i`` has width ``probabilities[i]``. The first interval
        starts at 0.0 and the last ends at exactly 1.0.

        Raises:
            ValueError: If a probability is negative or the probabilities do
                not sum to 1 within floating-point tolerance.
        """
        probabilities = list(probabilities)
        if any(p < 0 for p in probabilities):
            raise ValueError("probabilities must be non-negative")
        # fsum + isclose: a plain sum with == rejects valid input such as
        # [0.1] * 10, whose naive sum is 0.9999999999999999.
        total = math.fsum(probabilities)
        if not math.isclose(total, 1.0):
            raise ValueError(
                "probabilities must sum to 1, got {!r}".format(total))

        bounds = [0.0]
        for p in probabilities:
            bounds.append(bounds[-1] + p)
        # Rounding in the running sum can leave the final edge just below 1.0,
        # so random.random() could fall outside every interval. Snap the
        # trailing boundary to exactly 1.0. The bounds never decrease, so every
        # bound equal to the last one belongs to the trailing run, which
        # includes any zero-probability events at the end.
        last = bounds[-1]
        bounds = [1.0 if b == last else b for b in bounds]
        return list(zip(bounds, bounds[1:]))

    def sample(self):
        """Return one event name, drawn with the configured probabilities."""
        random_0_1 = random.random()
        return self.tree.get_label(random_0_1)
class PartitionTreeNode(object):
    """One node of a ``PartitionTree``.

    A node whose ``interval`` is ``None`` is an empty placeholder. When
    ``PartitionTree`` stores an interval in it, the node gets two fresh empty
    children.

    Attributes:
        left: Subtree for intervals ending at or before ``interval[0]``.
        right: Subtree for intervals starting at or after ``interval[1]``.
        interval: The ``(low, high)`` tuple stored here, or ``None``.
    """

    def __init__(self, left=None, right=None, interval=None):
        self.left = left
        self.right = right
        self.interval = interval


class PartitionTree(object):
    """Map numbers to labels using a binary search tree of disjoint intervals.

    Intervals may touch at endpoints but must not otherwise overlap. Lookups
    treat each interval as closed on both ends.

    The tree is not rebalanced, so its shape follows insertion order. That
    order also decides which interval wins a lookup on a shared endpoint. For
    example, with ``[(0.0, 0.5), (0.5, 1.0)]`` the number 0.5 maps to the first
    interval's label.
    """

    def __init__(self, intervals, labels):
        """Build the tree.

        Args:
            intervals: ``(low, high)`` tuples, inserted in the order given.
            labels: Label for each interval. Pairs are formed with ``zip``, so
                extra items in the longer sequence are ignored.

        Raises:
            ValueError: If an interval overlaps one already inserted.
        """
        self.mapping = {}  # interval tuple -> label
        self.root = PartitionTreeNode()
        for interval, label in zip(intervals, labels):
            self._add_interval(interval, self.root)
            self.mapping[interval] = label

    def _add_interval(self, interval, node):
        """Insert ``interval`` into the subtree rooted at ``node``.

        This is a loop rather than recursion. Intervals inserted in sorted
        order (as the sampler builds them) form a chain whose depth equals
        the number of intervals, which would exceed Python's recursion limit
        for large inputs.

        Raises:
            ValueError: If ``interval`` overlaps an interval already present.
        """
        while node.interval is not None:
            if interval[1] <= node.interval[0]:
                node = node.left
            elif interval[0] >= node.interval[1]:
                node = node.right
            else:
                raise ValueError(
                    "interval {!r} overlaps existing interval {!r}".format(
                        interval, node.interval))
        node.interval = interval
        node.left = PartitionTreeNode()
        node.right = PartitionTreeNode()

    def get_label(self, number):
        """Return the label of the interval containing ``number``.

        Raises:
            ValueError: If ``number`` lies outside every interval.
        """
        interval = self._get_interval(number, self.root)
        return self.mapping[interval]

    def _get_interval(self, number, node):
        """Return the stored interval containing ``number`` (closed bounds).

        Iterative for the same recursion-depth reason as ``_add_interval``.

        Raises:
            ValueError: If the search reaches an empty node without a match.
        """
        while node.interval is not None:
            left_bound, right_bound = node.interval
            if number < left_bound:
                node = node.left
            elif number > right_bound:
                node = node.right
            else:
                return node.interval
        # Previously this path crashed with a TypeError from unpacking None.
        raise ValueError(
            "{!r} does not fall inside any interval".format(number))
import unittest
import partition_tree
import sampler
class PartitionTreeTrest(unittest.TestCase):
    """Unit tests for ``partition_tree.PartitionTree``.

    NOTE: "Trest" in the class name is a typo for "Test". The name is kept so
    anything that selects this test case by name keeps working.
    """

    def test_partition_tree(self):
        tree = partition_tree.PartitionTree([(0.0, 0.5), (0.5, 1.0)], ['A', 'B'])
        values = [0.0, 0.3, 0.5, 0.7, 1.0]
        # 0.5 is the shared endpoint; it resolves to 'A', the interval
        # inserted first.
        correct_labels = ['A', 'A', 'A', 'B', 'B']
        # assertEqual rather than a bare assert: bare asserts are stripped
        # under `python -O` and give no diff on failure.
        self.assertEqual([tree.get_label(v) for v in values], correct_labels)

    def test_overlapping_intervals_rejected(self):
        # Overlap detection has always raised an Exception subclass.
        with self.assertRaises(Exception):
            partition_tree.PartitionTree([(0.0, 0.6), (0.5, 1.0)], ['A', 'B'])
class MultinomialSampleTest(unittest.TestCase):
    """Statistical tests for ``sampler.Multinomial_Sampler``."""

    def test_biased_coin_flip(self):
        import math
        import random
        from collections import Counter

        true_heads, true_tails = 0.3, 0.7
        P = [true_heads, true_tails]
        event_names = ['Heads', 'Tails']
        s = sampler.Multinomial_Sampler(P, event_names)
        total_samples = 400000

        # The sampler draws from the module-level `random` generator. Seed it
        # so the run is reproducible, then restore its state so other tests
        # are unaffected.
        saved_state = random.getstate()
        random.seed(20221013)
        try:
            sample_counter = Counter(s.sample() for _ in range(total_samples))
        finally:
            random.setstate(saved_state)

        # Allow 5 binomial standard errors. The old fixed tolerance of 0.001
        # was only ~1.4 standard errors at this sample size, so the test failed
        # roughly one run in six.
        allowed_error = 5 * math.sqrt(true_heads * true_tails / total_samples)
        head_frequency = 1.0 * sample_counter['Heads'] / total_samples
        tail_frequency = 1.0 * sample_counter['Tails'] / total_samples
        self.assertAlmostEqual(head_frequency, true_heads, delta=allowed_error)
        self.assertAlmostEqual(tail_frequency, true_tails, delta=allowed_error)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment