Forked from sorrge/Vector whose size in an expression in C++
Created
May 2, 2025 15:19
-
-
Save leoossa/c255e21c7ed0e78c7175deb373050f2a to your computer and use it in GitHub Desktop.
Revisions
-
sorrge renamed this gist
May 30, 2015 . 1 changed file with 0 additions and 0 deletions.There are no files selected for viewing
File renamed without changes. -
sorrge revised this gist
May 30, 2015 . 1 changed file with 0 additions and 42 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,42 +0,0 @@ -
sorrge created this gist
May 30, 2015 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,42 @@ data Vec : (n : Nat) -> (t : Type) -> Type where Nil : Vec Z t (::) : t -> Vec n t -> Vec (S n) t infixr 6 .++ (.++) : Vec n t -> Vec m t -> Vec (n + m) t [] .++ b = b (.++) {n = S n} {m = m} (x :: xs) y = rewrite plusSuccRightSucc n m in xs .++ (x :: y) infixl 7 .* (.*) : Num t => Vec n t -> Vec n t -> t (.*) [] [] = 0 (.*) (x :: xs) (y :: ys) = (x * y) + (xs .* ys) natVec : (n : Nat) -> Vec n Int natVec Z = Nil natVec (S n) = (cast n) :: natVec n main : IO () main = do n <- fromIntegerNat . cast `map` getLine m <- fromIntegerNat . cast `map` getLine let a = natVec n let b = natVec m printLn $ a .* a --printLn $ a .* b --Error: cannot unify n with m printLn $ (a .++ b) .* (a .++ b) --printLn $ (a .++ b) .* (b .++ a) --Error: cannot unify (n + m) with (m + n) -- printLn $ (a .++ b) .* (rewrite plusCommutative n m in b .++ a) -- works, but gives a wrong result --printLn $ (a .++ b .++ a .++ a .++ b) .* (rewrite plusCommutative n m in (b .++ a .++ a .++ a .++ b)) case n `decEq` m of Yes eq => printLn $ a .* (rewrite eq in b) _ => return () This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,233 @@ #include <iostream> using namespace std; // TypeValue is a type and a value at the same time. // Each ID is a type, and there can be only one value of this type // This prevents (only in runtime) construction of two vectors with the same Length variable, but of different actual lengths template<int ID> class TypeValue { static int value; static bool initialized; public: TypeValue(int v) { if (initialized) throw "Tried to reinitialize a TypeValue"; initialized = true; value = v; } static int Eval() { return value; } }; template<int ID> bool TypeValue<ID>::initialized = false; template<int ID> int TypeValue<ID>::value; // the presense of this class in the context means that TypeExpr1 is proven to be equal to TypeExpr2 template<class TypeExpr1, class TypeExpr2> class TypesAreEqual; // static TypeExpr equality check using a context template<class TV1, class TV2, class Context> struct TypeExprEq : false_type{}; template<int ID, class Context> struct TypeExprEq<TypeValue<ID>, TypeValue<ID>, Context> : true_type{}; template<class TV1, class TV2> struct TypeExprEq<TV1, TV2, TypesAreEqual<TV1, TV2>> : true_type{}; template<class TV1, class TV2> struct TypeExprEq<TV1, TV2, TypesAreEqual<TV2, TV1>> : true_type{}; template<class E1, class E2> class Plus { public: Plus(E1, E2) {} // The constructor takes values as the proof that they have been created and initialized static int Eval() { return E1::Eval() + E2::Eval(); } }; // plusCommutative template<class TV1, class TV2, class TV3, class TV4, class Context> struct TypeExprEq<Plus<TV1, TV2>, Plus<TV3, TV4>, Context> { static const bool value = TypeExprEq<TV1, TV3, Context>::value && TypeExprEq<TV2, TV4, Context>::value || TypeExprEq<TV1, TV4, Context>::value && TypeExprEq<TV2, TV3, Context>::value; }; // runtime TypeExpr equality check template<class TypeExpr1, class TypeExpr2> class TypeExprEqRT { public: TypeExprEqRT(TypeExpr1, TypeExpr2) {} static bool Eval() { return TypeExpr1::Eval() == TypeExpr2::Eval(); } typedef TypesAreEqual<TypeExpr1, TypeExpr2> ContextIfTrue; }; template<int ID1, int ID2> TypeExprEqRT<TypeValue<ID1>, TypeValue<ID2>> operator==(TypeValue<ID1> tv1, TypeValue<ID2> tv2) { return TypeExprEqRT<TypeValue<ID1>, TypeValue<ID2>>(tv1, tv2); } template<class Condition, class Operation> void If(Condition, Operation oper) { if (Condition::Eval()) oper.Eval<Condition::ContextIfTrue>(); } // Utility function which passes the context template<class Func> class PrintLN { Func f; public: PrintLN(Func _f) : f(_f) {} template<class Context> void Eval() { cout << f.Eval<Context>() << endl; } }; template<class Func> PrintLN<Func> printLn(Func f) { return PrintLN<Func>(f); } // The vector class. Length is a TypeExpr template<class ElemType, class Length> class Vec { Length length; ElemType *data; public: Vec(Length l) : length(l) { data = new ElemType[Length::Eval()]; } Vec(const Vec<ElemType, Length>& v) : length(v.length) { data = new ElemType[Length::Eval()]; for (int i = 0; i < Length::Eval(); ++i) data[i] = v.data[i]; } virtual ~Vec() { delete[] data; } ElemType& operator[](int idx) { return data[idx]; } Length Len() { return length; } }; // Vector operations template<class Length> Vec<unsigned, Length> NatVec(Length l) { Vec<unsigned, Length> res(l); for (int i = 0; i < Length::Eval(); ++i) res[i] = i; return res; } template<class L1, class L2, class T> class DotProduct { Vec<T, L1> v1; Vec<T, L2> v2; public: DotProduct(Vec<T, L1> _v1, Vec<T, L2> _v2) : v1(_v1), v2(_v2) {} template<class Context> T Eval() { static_assert(TypeExprEq<L1, L2, Context>::value, "Can't prove that vectors have the same length"); T acc = {}; for (int i = 0; i < L1::Eval(); ++i) acc += v1[i] * v2[i]; return acc; } operator T() { return Eval<void>(); } }; template<class L1, class L2, class T> DotProduct<L1, L2, T> operator*(Vec<T, L1> v1, Vec<T, L2> v2) { return DotProduct<L1, L2, T>(v1, v2); } template<class L1, class L2, class T> Vec<T, Plus<L1, L2>> operator+(Vec<T, L1> v1, Vec<T, L2> v2) { Vec<T, Plus<L1, L2>> res(Plus<L1, L2>(v1.Len(), v2.Len())); for (int i = 0; i < L1::Eval(); ++i) res[i] = v1[i]; for (int i = 0; i < L2::Eval(); ++i) res[i + L1::Eval()] = v2[i]; return res; } int main(int argc, char* argv[]) { int nn, mm; cin >> nn >> mm; TypeValue<1> n(nn); TypeValue<2> m(mm); auto a = NatVec(n); auto b = NatVec(m); cout << a * a << endl; // cout << a * b << endl; // Error cout << (a + b) * (a + b) << endl; cout << (a + b) * (b + a) << endl; // OK if plusCommutative is present cout << (a + b + a + a + b) * (b + a + a + a + b) << endl; // OK if plusCommutative is present // cout << (a + b + a + b + a) * (b + a + a + a + b) << endl; // Error: associativity of + is not implemented If(n == m, printLn(a * b)); If(n == m, printLn((a + a) * (b + a))); // If(n == m, printLn((a + a) * (b + a + a))); // Error return 0; }