Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save leoossa/c255e21c7ed0e78c7175deb373050f2a to your computer and use it in GitHub Desktop.

Select an option

Save leoossa/c255e21c7ed0e78c7175deb373050f2a to your computer and use it in GitHub Desktop.

Revisions

  1. @sorrge sorrge renamed this gist May 30, 2015. 1 changed file with 0 additions and 0 deletions.
  2. @sorrge sorrge revised this gist May 30, 2015. 1 changed file with 0 additions and 42 deletions.
    42 changes: 0 additions & 42 deletions Vector in Idris
    Original file line number Diff line number Diff line change
    @@ -1,42 +0,0 @@
    data Vec : (n : Nat) -> (t : Type) -> Type where
    Nil : Vec Z t
    (::) : t -> Vec n t -> Vec (S n) t

    infixr 6 .++
    (.++) : Vec n t -> Vec m t -> Vec (n + m) t
    [] .++ b = b
    (.++) {n = S n} {m = m} (x :: xs) y = rewrite plusSuccRightSucc n m in xs .++ (x :: y)

    infixl 7 .*
    (.*) : Num t => Vec n t -> Vec n t -> t
    (.*) [] [] = 0
    (.*) (x :: xs) (y :: ys) = (x * y) + (xs .* ys)

    natVec : (n : Nat) -> Vec n Int
    natVec Z = Nil
    natVec (S n) = (cast n) :: natVec n

    main : IO ()
    main = do
    n <- fromIntegerNat . cast `map` getLine
    m <- fromIntegerNat . cast `map` getLine

    let a = natVec n
    let b = natVec m

    printLn $ a .* a

    --printLn $ a .* b
    --Error: cannot unify n with m

    printLn $ (a .++ b) .* (a .++ b)
    --printLn $ (a .++ b) .* (b .++ a)
    --Error: cannot unify (n + m) with (m + n)

    -- printLn $ (a .++ b) .* (rewrite plusCommutative n m in b .++ a) -- works, but gives a wrong result

    --printLn $ (a .++ b .++ a .++ a .++ b) .* (rewrite plusCommutative n m in (b .++ a .++ a .++ a .++ b))

    case n `decEq` m of
    Yes eq => printLn $ a .* (rewrite eq in b)
    _ => return ()
  3. @sorrge sorrge created this gist May 30, 2015.
    42 changes: 42 additions & 0 deletions Vector in Idris
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,42 @@
    data Vec : (n : Nat) -> (t : Type) -> Type where
    Nil : Vec Z t
    (::) : t -> Vec n t -> Vec (S n) t

    infixr 6 .++
    (.++) : Vec n t -> Vec m t -> Vec (n + m) t
    [] .++ b = b
    (.++) {n = S n} {m = m} (x :: xs) y = rewrite plusSuccRightSucc n m in xs .++ (x :: y)

    infixl 7 .*
    (.*) : Num t => Vec n t -> Vec n t -> t
    (.*) [] [] = 0
    (.*) (x :: xs) (y :: ys) = (x * y) + (xs .* ys)

    natVec : (n : Nat) -> Vec n Int
    natVec Z = Nil
    natVec (S n) = (cast n) :: natVec n

    main : IO ()
    main = do
    n <- fromIntegerNat . cast `map` getLine
    m <- fromIntegerNat . cast `map` getLine

    let a = natVec n
    let b = natVec m

    printLn $ a .* a

    --printLn $ a .* b
    --Error: cannot unify n with m

    printLn $ (a .++ b) .* (a .++ b)
    --printLn $ (a .++ b) .* (b .++ a)
    --Error: cannot unify (n + m) with (m + n)

    -- printLn $ (a .++ b) .* (rewrite plusCommutative n m in b .++ a) -- works, but gives a wrong result

    --printLn $ (a .++ b .++ a .++ a .++ b) .* (rewrite plusCommutative n m in (b .++ a .++ a .++ a .++ b))

    case n `decEq` m of
    Yes eq => printLn $ a .* (rewrite eq in b)
    _ => return ()
    233 changes: 233 additions & 0 deletions Vector whose size in an expression
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,233 @@
    #include <iostream>

    using namespace std;


    // TypeValue is a type and a value at the same time.
    // Each ID is a type, and there can be only one value of this type
    // This prevents (only in runtime) construction of two vectors with the same Length variable, but of different actual lengths
    template<int ID>
    class TypeValue
    {
    static int value;
    static bool initialized;

    public:
    TypeValue(int v)
    {
    if (initialized)
    throw "Tried to reinitialize a TypeValue";

    initialized = true;
    value = v;
    }

    static int Eval() { return value; }
    };

    template<int ID>
    bool TypeValue<ID>::initialized = false;

    template<int ID>
    int TypeValue<ID>::value;

    // the presense of this class in the context means that TypeExpr1 is proven to be equal to TypeExpr2
    template<class TypeExpr1, class TypeExpr2>
    class TypesAreEqual;

    // static TypeExpr equality check using a context
    template<class TV1, class TV2, class Context>
    struct TypeExprEq : false_type{};

    template<int ID, class Context>
    struct TypeExprEq<TypeValue<ID>, TypeValue<ID>, Context> : true_type{};

    template<class TV1, class TV2>
    struct TypeExprEq<TV1, TV2, TypesAreEqual<TV1, TV2>> : true_type{};

    template<class TV1, class TV2>
    struct TypeExprEq<TV1, TV2, TypesAreEqual<TV2, TV1>> : true_type{};

    template<class E1, class E2>
    class Plus
    {
    public:
    Plus(E1, E2) {} // The constructor takes values as the proof that they have been created and initialized

    static int Eval() { return E1::Eval() + E2::Eval(); }
    };

    // plusCommutative
    template<class TV1, class TV2, class TV3, class TV4, class Context>
    struct TypeExprEq<Plus<TV1, TV2>, Plus<TV3, TV4>, Context>
    {
    static const bool value = TypeExprEq<TV1, TV3, Context>::value && TypeExprEq<TV2, TV4, Context>::value ||
    TypeExprEq<TV1, TV4, Context>::value && TypeExprEq<TV2, TV3, Context>::value;
    };


    // runtime TypeExpr equality check
    template<class TypeExpr1, class TypeExpr2>
    class TypeExprEqRT
    {
    public:
    TypeExprEqRT(TypeExpr1, TypeExpr2) {}

    static bool Eval()
    {
    return TypeExpr1::Eval() == TypeExpr2::Eval();
    }

    typedef TypesAreEqual<TypeExpr1, TypeExpr2> ContextIfTrue;
    };


    template<int ID1, int ID2>
    TypeExprEqRT<TypeValue<ID1>, TypeValue<ID2>> operator==(TypeValue<ID1> tv1, TypeValue<ID2> tv2)
    {
    return TypeExprEqRT<TypeValue<ID1>, TypeValue<ID2>>(tv1, tv2);
    }

    template<class Condition, class Operation>
    void If(Condition, Operation oper)
    {
    if (Condition::Eval())
    oper.Eval<Condition::ContextIfTrue>();
    }


    // Utility function which passes the context
    template<class Func>
    class PrintLN
    {
    Func f;

    public:
    PrintLN(Func _f) : f(_f) {}

    template<class Context>
    void Eval()
    {
    cout << f.Eval<Context>() << endl;
    }
    };

    template<class Func>
    PrintLN<Func> printLn(Func f)
    {
    return PrintLN<Func>(f);
    }


    // The vector class. Length is a TypeExpr
    template<class ElemType, class Length>
    class Vec
    {
    Length length;
    ElemType *data;

    public:
    Vec(Length l) : length(l)
    {
    data = new ElemType[Length::Eval()];
    }

    Vec(const Vec<ElemType, Length>& v) : length(v.length)
    {
    data = new ElemType[Length::Eval()];
    for (int i = 0; i < Length::Eval(); ++i)
    data[i] = v.data[i];
    }

    virtual ~Vec()
    {
    delete[] data;
    }

    ElemType& operator[](int idx) { return data[idx]; }
    Length Len() { return length; }
    };


    // Vector operations
    template<class Length>
    Vec<unsigned, Length> NatVec(Length l)
    {
    Vec<unsigned, Length> res(l);
    for (int i = 0; i < Length::Eval(); ++i)
    res[i] = i;

    return res;
    }

    template<class L1, class L2, class T>
    class DotProduct
    {
    Vec<T, L1> v1;
    Vec<T, L2> v2;

    public:
    DotProduct(Vec<T, L1> _v1, Vec<T, L2> _v2) : v1(_v1), v2(_v2) {}

    template<class Context>
    T Eval()
    {
    static_assert(TypeExprEq<L1, L2, Context>::value, "Can't prove that vectors have the same length");
    T acc = {};
    for (int i = 0; i < L1::Eval(); ++i)
    acc += v1[i] * v2[i];

    return acc;
    }

    operator T()
    {
    return Eval<void>();
    }
    };

    template<class L1, class L2, class T>
    DotProduct<L1, L2, T> operator*(Vec<T, L1> v1, Vec<T, L2> v2)
    {
    return DotProduct<L1, L2, T>(v1, v2);
    }

    template<class L1, class L2, class T>
    Vec<T, Plus<L1, L2>> operator+(Vec<T, L1> v1, Vec<T, L2> v2)
    {
    Vec<T, Plus<L1, L2>> res(Plus<L1, L2>(v1.Len(), v2.Len()));
    for (int i = 0; i < L1::Eval(); ++i)
    res[i] = v1[i];

    for (int i = 0; i < L2::Eval(); ++i)
    res[i + L1::Eval()] = v2[i];

    return res;
    }


    int main(int argc, char* argv[])
    {
    int nn, mm;
    cin >> nn >> mm;
    TypeValue<1> n(nn);
    TypeValue<2> m(mm);

    auto a = NatVec(n);
    auto b = NatVec(m);

    cout << a * a << endl;
    // cout << a * b << endl; // Error

    cout << (a + b) * (a + b) << endl;
    cout << (a + b) * (b + a) << endl; // OK if plusCommutative is present
    cout << (a + b + a + a + b) * (b + a + a + a + b) << endl; // OK if plusCommutative is present
    // cout << (a + b + a + b + a) * (b + a + a + a + b) << endl; // Error: associativity of + is not implemented

    If(n == m, printLn(a * b));
    If(n == m, printLn((a + a) * (b + a)));
    // If(n == m, printLn((a + a) * (b + a + a))); // Error

    return 0;
    }