From b505c021f29f6169973120162631a720d4a94659 Mon Sep 17 00:00:00 2001 From: Nathan Braswell Date: Sun, 19 Apr 2020 21:52:21 -0400 Subject: [PATCH] Implemented closures --- bf.kp | 9 +-- k_prime.krak | 152 ++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 148 insertions(+), 13 deletions(-) diff --git a/bf.kp b/bf.kp index 8dd5b3b..315ddb5 100644 --- a/bf.kp +++ b/bf.kp @@ -69,7 +69,8 @@ ;(def! other 12) ;(def! main (fn* [] (+ other 4))) -(def! fact (fn* [n] (if (<= n 1) 1 (* (fact (- n 1)) n)))) -(def! main (fn* [] (let* (to_ret (fact 5)) - (do (println to_ret) - to_ret)))) +;(def! fact (fn* [n] (if (<= n 1) 1 (* (fact (- n 1)) n)))) +;(def! main (fn* [] (let* (to_ret (fact 5)) (do (println to_ret) to_ret)))) + +(def! ret_with_call (fn* [n] (fn* [x] (+ n x)))) +(def! main (fn* [] ((ret_with_call 3) 5))) diff --git a/k_prime.krak b/k_prime.krak index 2ae5194..912e137 100644 --- a/k_prime.krak +++ b/k_prime.krak @@ -374,6 +374,20 @@ obj Env (Object) { return MalResult::Err(MalValue::String(str("'") + key + "' not found")) } } + fun to_string(): str { + var to_ret = str() + to_string(str("\t"), to_ret) + return to_ret + } + fun to_string(tabs: ref str, s: ref str) { + for (var i = 0; i < data.keys.size; i++;) { + /*s += tabs + data.keys[i] + ": " + data.values[i] + "\n"*/ + s += tabs + data.keys[i] + "\n" + } + if outer != null() { + outer->to_string(tabs + "\t", s) + } + } } obj MalBuiltinFunction (Object) { var fp: fun(vec): MalResult @@ -635,6 +649,9 @@ fun EVAL(env: *Env, ast: MalValue): MalResult { if (!is_symbol(l[1])) { return MalResult::Err(MalValue::String(str("def! not on symbol"))) } + if env->outer != null() { + return MalResult::Err(MalValue::String(str("def! not at top level"))) + } var value = EVAL(env, l[2]) if (is_err(value)) { return value @@ -1474,7 +1491,8 @@ fun main(argc: int, argv: **char): int { printf("\n"); return 0x2F; } - closure _println_closure = (closure){ _println_impl, NULL};""" //' + closure _println_closure = (closure){ _println_impl, NULL}; + """ //" var main_s = str("int main(int argc, char** argv) {\n") var main_body = str() var inner_main = compile(&top_decs, &top_defs, &main_s, &main_body, f.env, *f.body) @@ -1507,6 +1525,104 @@ fun new_tmp(): str { tmp_idx += 1 return str("x") + tmp_idx } +fun find_closed_vars(defined: set, env: *Env, ast: MalValue): set { + match (ast) { + MalValue::List(l) { + println("Find closed vars list") + if (l.size == 0) { + return set() + } else if (is_symbol(l[0], "def!")) { + println("Find closed vars in def!") + defined.add(get_symbol_text(l[1])) + /*return find_closed_vars(defined, env, l[2])*/ + var to_ret = find_closed_vars(defined, env, l[2]) + println("end Find closed vars in def!") + return to_ret + } else if (is_symbol(l[0], "let*")) { + var bindings = get_list_or_vec(l[1]) + var to_ret = set() + var new_env = new()->construct(env) + for (var i = 0; i < bindings.size; i+=2;) { + defined.add(get_symbol_text(bindings[i])) + new_env->set(get_symbol_text(bindings[i]), MalValue::Nil()) + to_ret += find_closed_vars(defined, new_env, bindings[i+1]) + } + return to_ret + find_closed_vars(defined, new_env, l[2]) + } else if is_symbol(l[0], "do") || is_symbol(l[0], "if") { + var to_ret = set() + for (var i = 1; i < l.size; i++;) { + to_ret += find_closed_vars(defined, env, l[i]) + } + return to_ret + } else if (is_symbol(l[0], "fn*")) { + println("Find closed vars fn*") + var f = EVAL(env, ast) + /*return find_closed_vars(defined, env, get_value(f))*/ + var to_ret = find_closed_vars(defined, env, get_value(f)) + println("end find closed vars fn*") + return to_ret + } else if (is_symbol(l[0], "quote")) { + return set() + } else if (is_symbol(l[0], "quasiquote")) { + return find_closed_vars(defined, env, quasiquote(l[1])) + } else if (is_symbol(l[0], "macroexpand")) { + error("macroexpand doesn't make sense while finding closed vars") + } else if (is_symbol(l[0], "try*")) { + error("finding closed vars in try* unimplemented") + } else { + var to_ret = set() + for (var i = 0; i < l.size; i++;) { + to_ret += find_closed_vars(defined, env, l[i]) + } + return to_ret + } + println("end list") + } + MalValue::List(l) { + error("Can't get clsoure_vars for " + pr_str(ast, true)) + } + MalValue::Vector(l) { + error("Can't get clsoure_vars for " + pr_str(ast, true)) + } + MalValue::Symbol(s) { + if !defined.contains(s) { + var scope = env->find(s) + if scope == null() { + error("Can't find " + s + " in env when trying to find closed_vars\n" + env->to_string()) + } + // don't do for top level vars + if scope->outer != null() { + return set(s) + } + } + return set() + } + MalValue::Int(i) { + return set() + } + MalValue::Nil() { + return set() + } + MalValue::True() { + return set() + } + MalValue::False() { + return set() + } + MalValue::Function(f) { + var new_env = new()->construct(env) + for (var i = 0; i < f.parameters.size; i++;) { + new_env->set(f.parameters[i], MalValue::Nil()) + } + println("Find closed vars going inside function:\n" + new_env->to_string()) + /*return find_closed_vars(defined.union(from_vector(f.parameters)), new_env, *f.body)*/ + var to_ret = find_closed_vars(defined.union(from_vector(f.parameters)), new_env, *f.body) + println("coming out of function") + return to_ret + } + } + error("Can't get clsoure_vars for " + pr_str(ast, true)) +} fun compile_value(top_decs: *str, top_defs: *str, main_init: *str, defs: *str, env: *Env, ast: MalValue): str { match (ast) { MalValue::List(l) { @@ -1585,22 +1701,36 @@ fun compile_value(top_decs: *str, top_defs: *str, main_init: *str, defs: *str, e return str("0x1F") } MalValue::Function(f) { - var parameters_str = str() - for (var i = 0; i < f.parameters.size; i++;) { - parameters_str += "size_t " + f.parameters[i] + " = args[" + i + "];\n" - } var fun_name = "fun_" + new_tmp() *top_decs += "size_t " + fun_name + "(size_t*, size_t, size_t*);\n" - var function = "size_t " + fun_name + "(size_t* _, size_t num, size_t* args) {\n" + var function = "size_t " + fun_name + "(size_t* closed_vars, size_t num, size_t* args) {\n" function += str("check_num_params(num, ") + f.parameters.size + ", \"lambda\");\n" - function += parameters_str - var inner_value = compile(top_decs, top_defs, main_init, &function, env, *f.body) + var new_env = new()->construct(env) + for (var i = 0; i < f.parameters.size; i++;) { + function += "size_t " + f.parameters[i] + " = args[" + i + "];\n" + new_env->set(f.parameters[i], MalValue::Nil()) + } + /*var closed_vars = find_closed_vars(from_vector(f.parameters), env, *f.body)*/ + var closed_vars = find_closed_vars(set(), new_env, ast) + for (var i = 0; i < closed_vars.data.size; i++;) { + function += "size_t " + closed_vars.data[i] + " = closed_vars[" + i + "];\n" + } + var inner_value = compile(top_decs, top_defs, main_init, &function, new_env, *f.body) function += "return " + inner_value + ";\n}\n" *top_defs += function *defs += "closure* " + fun_name + "_closure = malloc(sizeof(closure));\n" *defs += fun_name + "_closure->func = " + fun_name + ";\n" - *defs += fun_name + "_closure->data = NULL;\n" + + + if closed_vars.data.size > 0 { + *defs += fun_name + "_closure->data = malloc(sizeof(size_t)*" + closed_vars.data.size + ");\n" + for (var i = 0; i < closed_vars.data.size; i++;) { + *defs += fun_name + "_closure->data[" + i + "] = " + closed_vars.data[i] + ";\n" + } + } else { + *defs += fun_name + "_closure->data = NULL;\n" + } return "((((size_t)"+fun_name+"_closure)<<3)|0x6)" } } @@ -1626,6 +1756,9 @@ fun compile(top_decs: *str, top_defs: *str, main_init: *str, defs: *str, env: *E if (!is_symbol(l[1])) { error("def! not on symbol") } + if env->outer != null() { + error("def! not at top level") + } var to_set_name = get_symbol_text(l[1]) var to_set_value = compile(top_decs, top_defs, main_init, defs, env, l[2]) *defs += "size_t " + to_set_name + " = " + to_set_value + ";\n" @@ -1652,6 +1785,7 @@ fun compile(top_decs: *str, top_defs: *str, main_init: *str, defs: *str, env: *E } var to_set_value = compile(top_decs, top_defs, main_init, defs, new_env, bindings[i+1]) *defs += "size_t " + get_symbol_text(bindings[i]) + " = " + to_set_value + ";\n" + new_env->set(get_symbol_text(bindings[i]), MalValue::Nil()) } *defs += let_val + " = " + compile(top_decs, top_defs, main_init, defs, new_env, l[2]) + ";\n}\n" return let_val