Skip to content

Instantly share code, notes, and snippets.

@malte-j
Created December 4, 2023 10:35
Show Gist options
  • Save malte-j/5d846a92159f00f83a1d7db69adaf68a to your computer and use it in GitHub Desktop.
Save malte-j/5d846a92159f00f83a1d7db69adaf68a to your computer and use it in GitHub Desktop.

Revisions

  1. malte-j created this gist Dec 4, 2023.
    71 changes: 71 additions & 0 deletions ts.dart
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,71 @@
    import 'dart:math';

    class ThompsonSampling {
    final List<double> means;
    final List<double> variances;

    ThompsonSampling(List<double> initialMeans, List<double> initialVariances)
    : means = List.from(initialMeans),
    variances = List.from(initialVariances);

    void updateObservations(int armIndex, double newObservation) {
    // Update mean and variance based on new observation
    final double oldMean = means[armIndex];
    final double oldVariance = variances[armIndex];

    // Update mean and variance using online update formulas
    final double newMean = (oldMean + newObservation) / 2;
    final double newVariance =
    (oldVariance + pow(newObservation - oldMean, 2)) / 2;

    means[armIndex] = newMean;
    variances[armIndex] = newVariance;
    }

    int selectArm() {
    // Number of arms (options)
    final int numArms = means.length;

    // Perform Thompson Sampling for each arm
    final List<double> samples = List.generate(numArms, (index) {
    // Generate a random sample for each arm using the Normal distribution
    final double sample = Random().nextDouble();

    // Calculate the sampled value from the Normal distribution
    return means[index] + sqrt(variances[index]) * cos(2 * pi * sample);
    });

    // Choose the arm with the highest sampled value
    final int selectedArm = samples.indexOf(samples.reduce(max));

    return selectedArm;
    }
    }

    void main() {
    // Example usage
    final List<double> initialMeans = [
    1.0,
    1.0,
    1.0,
    ]; // Initial mean for each arm
    final List<double> initialVariances = [
    2.0,
    2.0,
    2.0,
    ]; // Initial variance for each arm

    // Create Thompson Sampling instance
    final ThompsonSampling thompsonSampling =
    ThompsonSampling(initialMeans, initialVariances);

    // // Simulate new observations (adjust to new data)
    thompsonSampling.updateObservations(0, 11.0);
    // thompsonSampling.updateObservations(1, 10.0);
    // thompsonSampling.updateObservations(2, 3.0);

    // Get the index of the selected arm using Thompson Sampling
    final int selectedArm = thompsonSampling.selectArm();

    print("Selected Arm: $selectedArm");
    }