package square;
 
import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.Random;
 
/**
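 * A suggested solution to the puzzle by Wouter Coekaerts from Square, which asks
 * for the square root of the secret number held by {@code SquareRoot} using only
 * its {@code answer} method. (The link to the puzzle statement is missing.)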
 * 
 * @author Eran Medan  
 */
public class Solution {

  /** Minimum duration of a full benchmark run, in nanoseconds (5 s). */
  private static final long MIN_BENCH_DURATION = 5000000000L;

  /** A "fast" benchmark runs for MIN_BENCH_DURATION divided by this factor. */
  private static final int FAST_BENCH_DIVISOR = 100;

  /** How long the search may run before forcing a GC and a short rest, in nanoseconds (10 s). */
  private static final long REST_INTERVAL_NANOS = 10000000000L;

  /** How long to sleep after forcing a GC, in milliseconds. */
  private static final long REST_MILLIS = 1000L;

  /** Half-width of the window around the converged guess whose square roots are submitted. */
  private static final int EXPLORE_RADIUS = 1000;

  // a copy of the test class, for benchmarking purposes
  public static class SquareRoot2 {
    public static final int BITS = SquareRoot.BITS;
    // this is our copy, so we can make it public :)
    public static BigInteger n = new BigInteger(BITS, new SecureRandom()).pow(2);

    public static void answer(BigInteger root) {
      if (n.divide(root).equals(root)) {
        System.out.println("Square root!");
      }
    }
  }

  /** Submits a guess to the real puzzle; this is the call whose duration the search measures. */
  private static final TestSubject SECRET_ANSWER = new TestSubject() {
    @Override
    public void test(BigInteger guess) {
      SquareRoot.answer(guess);
    }
  };

  /** Submits a guess to our lab copy, whose number we know; used only for calibration. */
  private static final TestSubject LAB_ANSWER = new TestSubject() {
    @Override
    public void test(BigInteger guess) {
      SquareRoot2.answer(guess);
    }
  };

  /**
   * Recovers the secret square root by timing {@code SquareRoot.answer}.
   *
   * <p>Assuming {@code SquareRoot.answer} behaves like its copy {@link SquareRoot2#answer}, a call
   * computes {@code n.divide(root)}. When {@code root >= n} the division returns right after a
   * comparison; when {@code root < n} a real division is performed. The duration of a call
   * therefore reveals on which side of {@code n} a guess lies. The method:
   * <ol>
   *   <li>calibrates a "slow" and a "fast" reference time on {@link SquareRoot2};</li>
   *   <li>binary-searches for the secret {@code n} itself;</li>
   *   <li>submits the integer square roots of values near the converged guess.</li>
   * </ol>
   *
   * @param args ignored
   */
  public static void main(String[] args) {
    // Take a number that was generated the same way the secret one was, for calibration.
    BigInteger ourNumber = SquareRoot2.n;

    // Offsets are drawn from [1, 99], never 0. An offset of 0 would make the "minus" sample equal
    // to n, and dividing equal values also short-circuits (OpenJDK MutableBigInteger), so the
    // "slow" reference would be measured as fast and the whole calibration would be wrong.
    Random random = new Random();
    BigInteger ourNumberMinusSome =
        ourNumber.subtract(BigInteger.valueOf(1 + random.nextInt(99)));
    BigInteger ourNumberPlusSome = ourNumber.add(BigInteger.valueOf(1 + random.nextInt(99)));

    System.out.println("Benchmarking");
    // Guess below n: the division has to be carried out in full (slow reference).
    double minusSomeTime = benchmarkSample(ourNumberMinusSome);
    // Guess above n: the division returns after comparing the operands (fast reference).
    double plusSomeTime = benchmarkSample(ourNumberPlusSome);

    // n is the square of a BITS-bit number, so it is at most (2^BITS - 1)^2.
    BigInteger upperBound =
        BigInteger.ONE.shiftLeft(SquareRoot.BITS).subtract(BigInteger.ONE).pow(2);
    BigInteger lowerBound = BigInteger.ZERO;

    long lastRestNanos = System.nanoTime();

    while (true) {
      // Periodically GC and pause to lessen the chance of benchmark glitches.
      if (System.nanoTime() - lastRestNanos > REST_INTERVAL_NANOS) {
        gcAndRest();
        lastRestNanos = System.nanoTime();
      }

      // The current guess is the midpoint of the remaining interval.
      BigInteger curGuess = upperBound.add(lowerBound).shiftRight(1);
      // Width of the remaining interval, printed as a primitive progress bar.
      BigInteger upperMinusLower = upperBound.subtract(lowerBound);
      System.out.println("upperMinusLower: " + upperMinusLower);
      if (upperMinusLower.compareTo(BigInteger.ONE) <= 0) {
        // Converged: try the roots around the guess to absorb small search errors.
        answerNear(curGuess, EXPLORE_RADIUS);
        break;
      }

      // Two short benchmarks averaged; empirically this gave fewer errors than one longer run.
      double time1 = benchmarkBigIntOperation(curGuess, SECRET_ANSWER, true);
      double time2 = benchmarkBigIntOperation(curGuess, SECRET_ANSWER, true);
      double averageTime = (time1 + time2) / 2;

      if (averageTime < plusSomeTime * 1.5) {
        // About as quick as the short-circuit case: the guess is at or above n.
        upperBound = curGuess;
      } else if (averageTime > minusSomeTime / 1.5) {
        // Slow like a real division: the guess is below n. Only trust it when the timing is
        // within 50% of the slow reference, to filter out noisy measurements.
        if (Math.abs(averageTime - minusSomeTime) < minusSomeTime * 0.5) {
          lowerBound = curGuess;
        } else {
          System.out.println("Not close enough, skipping");
        }
      } else {
        // Between both references: too ambiguous, so the same guess is measured again.
        System.out.println("too risky");
      }
    }
  }

  /**
   * Requests a garbage collection and sleeps for {@link #REST_MILLIS} milliseconds.
   *
   * <p>Called between measurements so that collector or system activity is less likely to land
   * inside a timed run.
   */
  private static void gcAndRest() {
    System.out.println("garbage collecting, and letting the system rest a little");
    System.gc();
    try {
      Thread.sleep(REST_MILLIS);
    } catch (InterruptedException e) {
      // Preserve the interrupt status instead of swallowing it.
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Submits to {@code SquareRoot.answer} every integer square root of a value within
   * {@code radius} of {@code center}.
   *
   * <p>The floor square root is monotonic. The roots of all values in
   * {@code [center - radius, center + radius]} therefore form the contiguous range
   * {@code [sqrt(center - radius), sqrt(center + radius)]}. Submitting that range covers the same
   * candidates as rooting each value separately, but needs only two {@link #sqrt} calls.
   * Zero is skipped because {@code answer} divides by its argument.
   *
   * @param center the converged guess for the secret number
   * @param radius how far below and above {@code center} to look
   */
  private static void answerNear(BigInteger center, int radius) {
    System.out.println("Starting to calculate square root");
    BigInteger delta = BigInteger.valueOf(radius);
    BigInteger lowestRoot = sqrt(center.subtract(delta)).max(BigInteger.ONE);
    BigInteger highestRoot = sqrt(center.add(delta));
    System.out.println("Calculated square root");
    for (BigInteger root = lowestRoot;
        root.compareTo(highestRoot) <= 0;
        root = root.add(BigInteger.ONE)) {
      SquareRoot.answer(root);
    }
  }

  /**
   * Benchmarks a number using our "lab" version of the SquareRoot class, with a full-length run.
   *
   * @param guess the number to pass to {@link SquareRoot2#answer}
   * @return the average duration of one call, in nanoseconds
   */
  private static double benchmarkSample(BigInteger guess) {
    return benchmarkBigIntOperation(guess, LAB_ANSWER, false);
  }

  /**
   * Finds the integer square root of a BigInteger, i.e. the largest {@code r} with
   * {@code r * r <= n}.
   *
   * <p>Uses Newton's iteration, starting from a power of two that is at least the root. It
   * converges in a few dozen steps even for 20,000-bit inputs, where bisection needs one step
   * per bit.
   *
   * @param n the number to find the square root for
   * @return the floor of the square root of {@code n}, or zero when {@code n <= 0}
   */
  public static BigInteger sqrt(BigInteger n) {
    if (n.signum() <= 0) {
      return BigInteger.ZERO;
    }
    // n < 2^bitLength, so 2^ceil(bitLength / 2) is strictly greater than sqrt(n).
    BigInteger x = BigInteger.ONE.shiftLeft((n.bitLength() + 1) / 2);
    while (true) {
      BigInteger next = x.add(n.divide(x)).shiftRight(1);
      // Starting above the root, the iterates strictly decrease until x == floor(sqrt(n)).
      if (next.compareTo(x) >= 0) {
        return x;
      }
      x = next;
    }
  }

  /**
   * A helper interface for {@link #benchmarkBigIntOperation}.
   *
   * <p>Wraps the piece of code to run under benchmark against a BigInteger.
   */
  public static interface TestSubject {
    /**
     * Runs the code under benchmark once.
     *
     * @param guess the number to run it against
     */
    public void test(BigInteger guess);
  }

  /**
   * Benchmarks {@code testSubject.test(b)} and returns its average duration per call.
   *
   * <p>A first pass calls the subject repeatedly for at least the minimum benchmark duration and
   * counts the calls. The same number of calls is then timed on an equal copy of {@code b}.
   *
   * <p>Credit: based on https://github.com/tbuktu/bigint/blob/master/src/main/java/DivBenchmark.java
   *
   * @param b the number used to test on, i.e. the argument of {@code testSubject.test}
   * @param testSubject the piece of code to benchmark
   * @param fast if true, run for 1/{@link #FAST_BENCH_DIVISOR} of the full duration (faster,
   *     less accurate)
   * @return the average time of one call, in nanoseconds
   */
  private static double benchmarkBigIntOperation(
      BigInteger b, TestSubject testSubject, boolean fast) {
    long minBenchDuration = fast ? MIN_BENCH_DURATION / FAST_BENCH_DIVISOR : MIN_BENCH_DURATION;

    // Warm-up pass that also determines how many calls to time.
    long numIterations = 0;
    long tStart = System.nanoTime();
    do {
      testSubject.test(b);
      numIterations++;
    } while (System.nanoTime() - tStart < minBenchDuration);

    // Time a distinct but equal instance (presumably mirroring DivBenchmark) rather than the
    // object used during warm-up.
    BigInteger copy = new BigInteger(b.toByteArray());
    tStart = System.nanoTime();
    for (long i = 0; i < numIterations; i++) {
      testSubject.test(copy);
    }
    long elapsed = System.nanoTime() - tStart;
    // Floating-point average: integer division would truncate to whole nanoseconds.
    return (double) elapsed / numIterations;
  }
}