-
-
Save ml-edu/522d0ee4b1d50fb772f5a0d6264a8118 to your computer and use it in GitHub Desktop.
Class which samples from a multinomial distribution in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| import partition_tree | |
class Multinomial_Sampler(object):
    """Draw named events at random according to fixed probabilities.

    The unit interval is cut into consecutive sub-intervals, one per event,
    each as wide as that event's probability.  ``sample`` draws a number with
    ``random.random()`` and asks the partition tree which event's
    sub-interval contains it.
    """

    # Largest accepted deviation of the probability total from 1.0.  An exact
    # ``== 1.0`` test rejects valid inputs such as ``[0.1] * 10`` because of
    # floating-point rounding.
    _SUM_TOLERANCE = 1e-9

    def __init__(self, probabilities, event_names):
        """Build the sampler.

        Args:
            probabilities: Non-negative numbers summing to 1 (within
                rounding error), one per event.
            event_names: Labels returned by ``sample``, in the same order
                as ``probabilities``.

        Raises:
            ValueError: If a probability is negative, the probabilities do
                not sum to 1, or there are fewer names than probabilities.
        """
        # Materialise both arguments so one-shot iterables can be measured
        # and then iterated.
        probabilities = list(probabilities)
        event_names = list(event_names)
        if len(event_names) < len(probabilities):
            # Without this check the unnamed intervals would silently be
            # left out of the tree and sampling would fail at random later.
            raise ValueError(
                "got %d probabilities but only %d event names"
                % (len(probabilities), len(event_names))
            )
        intervals = self._build_intervals_from_probabilities(probabilities)
        self.tree = partition_tree.PartitionTree(intervals, event_names)

    def _build_intervals_from_probabilities(self, probabilities):
        """Turn probabilities into adjacent ``(left, right)`` intervals.

        The first interval starts at 0.0, every interval starts where the
        previous one ends, and the last one ends at exactly 1.0.

        Args:
            probabilities: Non-negative numbers summing to 1 (within
                ``_SUM_TOLERANCE``).

        Returns:
            A list of ``(left, right)`` tuples, one per probability.

        Raises:
            ValueError: If a probability is negative or the total is not 1.
        """
        cumulative_edges = []
        running_total = 0.0
        for p in probabilities:
            if p < 0:
                raise ValueError("probabilities must be non-negative, got %r" % (p,))
            running_total += p
            cumulative_edges.append(running_total)

        # Written as ``not (... <= tol)`` so that a NaN total is rejected too.
        if not abs(running_total - 1.0) <= self._SUM_TOLERANCE:
            raise ValueError(
                "probabilities must sum to 1.0, got %r" % (running_total,)
            )

        # Dividing by the accumulated total pins the last edge to exactly 1.0
        # so no draw can fall past the final interval; when the total is
        # already 1.0 the division changes nothing.
        intervals = []
        left_side = 0.0
        for edge in cumulative_edges:
            right_side = edge / running_total
            intervals.append((left_side, right_side))
            left_side = right_side
        return intervals

    def sample(self):
        """Return the name of one randomly drawn event."""
        random_0_1 = random.random()
        return self.tree.get_label(random_0_1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class PartitionTreeNode(object):
    """One node of a ``PartitionTree``.

    A node whose ``interval`` is ``None`` is an empty placeholder marking a
    spot where an interval can still be inserted.
    """

    def __init__(self, left=None, right=None, interval=None):
        self.left = left          # subtree of intervals lying entirely below
        self.right = right        # subtree of intervals lying entirely above
        self.interval = interval  # (low, high) tuple, or None while empty


class PartitionTree(object):
    """Binary search tree over non-overlapping intervals with labels.

    Intervals are ``(low, high)`` tuples.  Two intervals may share an
    endpoint but must not otherwise overlap.  Bounds are treated as
    inclusive during lookup; a number equal to a shared endpoint resolves to
    whichever of the two intervals is met first on the way down from the
    root.
    """

    def __init__(self, intervals, labels):
        """Insert each interval in order and remember its label.

        Args:
            intervals: Iterable of hashable ``(low, high)`` tuples.
            labels: Iterable of labels, paired with ``intervals`` by
                position (pairing stops at the shorter of the two).

        Raises:
            ValueError: If two intervals overlap.
        """
        self.mapping = {}
        self.root = PartitionTreeNode()
        for interval, label in zip(intervals, labels):
            self._add_interval(interval, self.root)
            self.mapping[interval] = label

    def _add_interval(self, interval, node):
        """Insert ``interval`` into the subtree rooted at ``node``.

        Walks down iteratively rather than recursively: intervals that
        arrive in sorted order form a chain as long as the input, which
        would otherwise exhaust the recursion limit.

        Raises:
            ValueError: If ``interval`` overlaps one already stored.
        """
        while node.interval is not None:
            if interval[1] <= node.interval[0]:
                node = node.left
            elif interval[0] >= node.interval[1]:
                node = node.right
            else:
                raise ValueError(
                    "interval %r overlaps existing interval %r"
                    % (interval, node.interval)
                )
        # Reached an empty placeholder: fill it and give it fresh children.
        node.interval = interval
        node.left = PartitionTreeNode()
        node.right = PartitionTreeNode()

    def get_label(self, number):
        """Return the label of the interval containing ``number``.

        Raises:
            ValueError: If no stored interval contains ``number``.
        """
        interval = self._get_interval(number, self.root)
        return self.mapping[interval]

    def _get_interval(self, number, node):
        """Find the stored interval containing ``number`` below ``node``.

        Raises:
            ValueError: If the search runs off the tree, i.e. ``number`` is
                not inside any stored interval.
        """
        while node is not None and node.interval is not None:
            left_bound, right_bound = node.interval
            if number < left_bound:
                node = node.left
            elif number > right_bound:
                node = node.right
            else:
                return node.interval
        raise ValueError("%r is not covered by any interval" % (number,))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import unittest | |
| import partition_tree | |
| import sampler | |
class PartitionTreeTrest(unittest.TestCase):
    """Checks label lookup on a two-interval partition of [0, 1]."""

    def test_partition_tree(self):
        """Interior points and all three endpoints map to the expected label.

        The shared endpoint 0.5 is expected to resolve to 'A', the interval
        that was inserted first.
        """
        tree = partition_tree.PartitionTree(
            [(0.0, 0.5), (0.5, 1.0)], ['A', 'B']
        )
        values = [0.0, 0.3, 0.5, 0.7, 1.0]
        labels = [tree.get_label(v) for v in values]
        correct_labels = ['A', 'A', 'A', 'B', 'B']
        # assertEqual (unlike a bare ``assert``) still runs under ``-O`` and
        # reports which elements differ on failure.
        self.assertEqual(labels, correct_labels)
class MultinomialSampleTest(unittest.TestCase):
    """Statistical check that sampled frequencies match the probabilities."""

    def test_biased_coin_flip(self):
        """A 30/70 coin lands on each side at close to its stated rate."""
        import random
        from collections import Counter

        true_heads, true_tails = 0.3, 0.7
        P = [true_heads, true_tails]
        event_names = ['Heads', 'Tails']
        s = sampler.Multinomial_Sampler(P, event_names)

        # Seed the module-level generator so the run is reproducible, and put
        # its state back afterwards so other tests are unaffected.
        self.addCleanup(random.setstate, random.getstate())
        random.seed(12345)

        total_samples = 400000
        sample_counter = Counter(s.sample() for _ in range(total_samples))

        # The standard error of the observed frequency is sqrt(p*(1-p)/n),
        # about 0.00072 here.  A fixed tolerance of 0.001 is under 1.5
        # standard errors and fails by chance alone; five standard errors
        # makes a spurious failure vanishingly unlikely.
        standard_error = (true_heads * true_tails / total_samples) ** 0.5
        allowed_error = 5 * standard_error

        head_frequency = 1.0 * sample_counter['Heads'] / total_samples
        tail_frequency = 1.0 * sample_counter['Tails'] / total_samples
        self.assertAlmostEqual(head_frequency, true_heads, delta=allowed_error)
        self.assertAlmostEqual(tail_frequency, true_tails, delta=allowed_error)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment