-- Tiny expression/statement interpreter built on LPeg.
-- Unknown globals (S, R, P, C, Ct, Cc, V, ...) are resolved through the `lpeg`
-- table, which must already be in scope when this chunk runs.
setmetatable(_ENV, { __index=lpeg })

-- Global variable store of the interpreted language: name -> value.
VARS = {}

-- Evaluate an 'expr' or 'term' node: { tag, operand, op, operand, op, operand, ... }.
-- Operators are applied strictly left to right; precedence between +,- and *,/
-- comes from the grammar nesting terms inside expressions.
function eval_expr(expr)
  local accum = eval(expr[2]) -- because 1 is "expr"
  for i = 3, #expr, 2 do
    local operator = expr[i]
    local num2 = eval(expr[i+1])
    if operator == '+' then
      accum = accum + num2
    elseif operator == '-' then
      accum = accum - num2
    elseif operator == '*' then
      accum = accum * num2
    elseif operator == '/' then
      accum = accum / num2
    end
  end
  return accum
end

-- Evaluate a 'bool' node: { "bool", lhs, operator, rhs }.
-- Returns true/false, or nil when the operator is not one of the six known ones.
function eval_bool(expr)
  local num1 = eval(expr[2])
  local operator = expr[3]
  local num2 = eval(expr[4])
  if operator == '<' then
    return num1 < num2
  elseif operator == '<=' then
    return num1 <= num2
  elseif operator == '>' then
    return num1 > num2
  elseif operator == '>=' then
    return num1 >= num2
  elseif operator == '==' then
    return num1 == num2
  elseif operator == '!=' then
    return num1 ~= num2
  end
end

-- Evaluate any AST node produced by the `stmt` grammar.
-- Numbers evaluate to themselves; tables are dispatched on their tag in ast[1].
-- 'list', 'while' and an 'if' whose condition is false return no value.
function eval(ast)
  if type(ast) == 'number' then
    return ast
  elseif ast[1] == 'expr' or ast[1] == 'term' then
    return eval_expr(ast)
  elseif ast[1] == 'array' then
    -- ast[2] holds the (possibly empty) list of element nodes
    local new = {}
    for _, el in ipairs(ast[2]) do
      table.insert(new, eval(el))
    end
    return new
  elseif ast[1] == 'ref' then
    return lookup(ast)
  elseif ast[1] == 'assign' then
    return assign(ast[2], eval(ast[3]))
  elseif ast[1] == 'list' then
    for i = 2, #ast do
      eval(ast[i])
    end
  elseif ast[1] == 'if' then
    if eval_bool(ast[2]) then
      return eval(ast[3])
    end
  elseif ast[1] == 'while' then
    while eval_bool(ast[2]) do
      eval(ast[3])
    end
  end
end

-- Store `value` at the location described by a 'ref' node
-- ({ "ref", name, index_expr, ... }) and return `value`.
-- Intermediate containers must already exist; indexing through a missing one
-- raises the usual Lua "attempt to index" error.
function assign(ref, value)
  local current = VARS
  for i = 2, #ref do
    local next_index = ref[i]
    if type(next_index) == 'table' then
      -- subscript expressions are AST nodes and must be evaluated first
      next_index = eval(next_index)
    end
    if i == #ref then
      -- last one, set the value
      current[next_index] = value
      return value
    else
      -- not the last, keep following the chain
      current = current[next_index]
    end
  end
end

-- Fetch the value designated by a 'ref' node by walking VARS through the
-- variable name and each evaluated subscript.
function lookup(ref)
  local current = VARS
  for i = 2, #ref do
    local next_index = ref[i]
    if type(next_index) == 'table' then
      next_index = eval(next_index)
    end
    current = current[next_index]
  end
  return current
end

-- Lexical elements. Every token swallows the whitespace that follows it.
spc = S(" \t\n")^0
digit = R('09')
-- Optional sign, then at least one digit, then an optional fractional part.
-- Requiring a digit keeps a lone "-" from being captured as a number
-- (tonumber("-") is nil).
number = C( P("-")^-1 * digit^1 * ( P('.') * digit^0 )^-1 ) / tonumber * spc
lparen = "(" * spc
rparen = ")" * spc
lbrack = "[" * spc
rbrack = "]" * spc
lcurly = "{" * spc
rcurly = "}" * spc
comma = "," * spc
expr_op = C( S('+-') ) * spc
term_op = C( S('*/') ) * spc
letter = R('AZ','az')
name = C( letter * (digit+letter+"_")^0 ) * spc
-- A keyword only counts when it is not the prefix of a longer identifier,
-- so names such as "iffy" or "while_count" remain usable as variables.
keywords = (P("if")+P("while")) * -(digit+letter+"_") * spc
name = name - keywords
-- Ordered choice: the two-character operators must be tried before the
-- single-character ones, otherwise "<=" would be consumed as "<".
boolean = C( P("<=") + ">=" + "!=" + "==" + S("<>") ) * spc

-- Grammar. A successful match yields exactly one AST table.
stmt = spc * P{
  "LIST";
  LIST = V("STMT") + Ct( Cc("list") * lcurly * V("STMT") * ( ";" * spc * V("STMT") )^0 * rcurly ),
  STMT = Ct( Cc("assign") * V("REF") * "=" * spc * V("VAL") ) + V("EXPR") + V("IF") + V("WHILE"),
  EXPR = Ct( Cc("expr") * V("TERM") * ( expr_op * V("TERM") )^0 ),
  TERM = Ct( Cc("term") * V("FACT") * ( term_op * V("FACT") )^0 ),
  REF = Ct( Cc("ref") * name * (lbrack * V("EXPR") * rbrack)^0 ),
  FACT = number + lparen * V("EXPR") * rparen + V("REF"),
  ARRAY = Ct( Cc("array") * lbrack * Ct( V("VAL_LIST")^-1 ) * rbrack ),
  VAL_LIST = V("VAL") * (comma * V("VAL"))^0,
  VAL = V("EXPR") + V("ARRAY"),
  BOOL = Ct( Cc("bool") * V("EXPR") * boolean * V("EXPR") ),
  IF = Ct( C("if") * spc * lparen * V("BOOL") * rparen * V("LIST") ),
  WHILE = Ct( C("while") * spc * lparen * V("BOOL") * rparen * V("LIST") )
}

-- Self-test: parses and evaluates sample programs against the given grammar.
-- Raises an assertion error on the first failing check; cleans up the
-- variables it creates in VARS.
function test(stmt)
  stmt = stmt / eval

  -- arithmetic and precedence
  assert(stmt:match(" 1 + 2 ") == 3)
  assert(stmt:match("1+2+3+4+5") == 15)
  assert(stmt:match("2*3*4 + 5*6*7") == 234)
  assert(stmt:match(" 1 * 2 + 3") == 5)
  assert(stmt:match("( 2 +2) *6") == 24)
  assert(stmt:match("2 * -3") == -6)
  assert(stmt:match("-") == nil) -- a lone minus sign is not a number

  -- scalar variables
  stmt:match("a=3");
  assert(VARS.a == 3)
  assert(stmt:match("a") == 3)
  assert(stmt:match("a * 5") == 15);
  VARS.a=nil

  -- arrays
  stmt:match("a = [ 4, 5, 6 ]");
  assert(VARS.a[1] == 4)
  assert(VARS.a[2] == 5)
  assert(VARS.a[3] == 6)
  VARS.a=nil
  stmt:match("b = [ ]");
  assert(VARS.b[1] == nil)
  VARS.b=nil
  stmt:match("c = [[1,2], [3,4]]")
  assert(VARS.c[1][1] == 1)
  assert(VARS.c[1][2] == 2)
  assert(VARS.c[2][1] == 3)
  assert(VARS.c[2][2] == 4)
  assert(stmt:match("c[4/2][1]") == 3)
  stmt:match("c[3] = 5")
  assert(VARS.c[3] == 5)
  VARS.c=nil

  -- control flow
  stmt:match("if(1 < 0) b = 5");
  assert(VARS.b ~= 5)
  VARS.n=0; VARS.x=1
  stmt:match("while(n < 8) { x = x * 2; n = n + 1 }")
  assert(VARS.x == 256)
  VARS.n=nil; VARS.x=nil

  -- two-character comparison operators
  stmt:match("if(2 <= 2) b = 5")
  assert(VARS.b == 5)
  VARS.b=nil
  stmt:match("if(3 >= 4) b = 6")
  assert(VARS.b ~= 6)
  VARS.b=nil

  -- identifiers that merely start with a keyword
  stmt:match("iffy = 7")
  assert(VARS.iffy == 7)
  VARS.iffy=nil
end

-- Read-eval-print loop: parses each line of `file` (default: the current
-- input file), evaluates it and prints the result.
-- Unparseable lines and runtime errors are reported without ending the loop.
function repl(file)
  file = file or io.input()
  for line in file:lines() do
    local ast = stmt:match(line)
    if ast == nil then
      print("syntax error")
    else
      -- evaluation can fail (e.g. arithmetic on an undefined variable);
      -- keep the loop alive and show the message instead
      local ok, result = pcall(eval, ast)
      if ok then
        print(result)
      else
        print("error: " .. tostring(result))
      end
    end
  end
end