diff --git a/stdlib/function_value_lower.krak b/stdlib/function_value_lower.krak index fc8b8ae..95bf43f 100644 --- a/stdlib/function_value_lower.krak +++ b/stdlib/function_value_lower.krak @@ -16,11 +16,13 @@ import pass_common:* obj function_parent_block { var function: *ast_node var parent: *ast_node + var parent_block: *ast_node } -fun make_function_parent_block(function: *ast_node, parent: *ast_node): function_parent_block { +fun make_function_parent_block(function: *ast_node, parent: *ast_node, parent_block: *ast_node): function_parent_block { var result: function_parent_block result.function = function result.parent = parent + result.parent_block = parent_block return result } @@ -56,13 +58,14 @@ fun function_value_lower(name_ast_map: *map,*ast_node // as I'm not sure you can pass member functions anyway /*|| parent_chain->size() < 2 || !is_function_call(parent_chain->from_top(1)) || parent_chain->from_top(1)->function_call.func != parent))))*/ if (need_done) { - function_value_creation_points.add(make_function_parent_block(node, parent_chain->top())) + function_value_creation_points.add(make_function_parent_block(node, parent_chain->top(), + parent_chain->item_from_top_satisfying(fun(i: *ast_node): bool return is_code_block(i);))) } } ast_node::function_call(backing) { if (!get_ast_type(backing.func)->is_raw) - function_value_call_points.add(make_function_parent_block(backing.func, node)) + function_value_call_points.add(make_function_parent_block(backing.func, node, null())) } } } @@ -72,7 +75,7 @@ fun function_value_lower(name_ast_map: *map,*ast_node println(string("there are ") + all_types.size() + " all types in the program.") var void_ptr = type_ptr(base_type::void_return(), 1); // this most vexing parse actually causes a compiler segfault as it tries to call the result of type_ptr as a function.... - // AND IT STILL DOES EVEN WITH ALL MY CHEKCS + // AND IT STILL DOES EVEN WITH ALL MY CHECKS var lambda_type_to_struct_type_and_call_func = map>(); //freaking vexing parse moved var all_type_values = all_types.map(fun(t: *type): type return *t;) all_type_values.for_each(fun(t: type) { @@ -82,12 +85,22 @@ fun function_value_lower(name_ast_map: *map,*ast_node var new_type_def_name = t.to_string() + "_function_value_struct" var new_type_def = ast_type_def_ptr(new_type_def_name) + var func_ident = ast_identifier_ptr("func", cleaned, new_type_def) add_to_scope("func", func_ident, new_type_def) + + var func_closure_type = cleaned->clone() + func_closure_type->parameter_types.add(0, type_ptr(base_type::void_return(), 1)) + var func_closure_ident = ast_identifier_ptr("func_closure", func_closure_type, new_type_def) + add_to_scope("func_closure", func_closure_ident, new_type_def) + var data_ident = ast_identifier_ptr("data", void_ptr, new_type_def) add_to_scope("data", data_ident, new_type_def) + new_type_def->type_def.variables.add(ast_declaration_statement_ptr(func_ident, null())) + new_type_def->type_def.variables.add(ast_declaration_statement_ptr(func_closure_ident, null())) new_type_def->type_def.variables.add(ast_declaration_statement_ptr(data_ident, null())) + add_to_scope("~enclosing_scope", name_ast_map->values.first().second, new_type_def) add_to_scope(new_type_def_name, new_type_def, name_ast_map->values.first().second) name_ast_map->values.first().second->translation_unit.children.add(new_type_def) @@ -101,9 +114,13 @@ fun function_value_lower(name_ast_map: *map,*ast_node return ast_identifier_ptr("pass_through_param", t, null()) }) var lambda_call_function = ast_function_ptr(string("lambda_call"), lambda_call_type, lambda_call_parameters, false) - lambda_call_function->function.body_statement = ast_code_block_ptr() - lambda_call_function->function.body_statement->code_block.children.add(ast_return_statement_ptr(ast_function_call_ptr(access_expression(lambda_call_func_param, "func"), lambda_call_parameters.slice(1,-1)))) // create call body with if, etc + var if_statement = ast_if_statement_ptr(access_expression(lambda_call_func_param, "data")) + lambda_call_function->function.body_statement = ast_code_block_ptr(if_statement) + if_statement->if_statement.then_part = ast_code_block_ptr(ast_return_statement_ptr(ast_function_call_ptr(access_expression(lambda_call_func_param, "func_closure"), + vector(access_expression(lambda_call_func_param, "data")) + lambda_call_parameters.slice(1,-1)))) + if_statement->if_statement.else_part = ast_code_block_ptr(ast_return_statement_ptr(ast_function_call_ptr(access_expression(lambda_call_func_param, "func"), + lambda_call_parameters.slice(1,-1)))) lambda_type_to_struct_type_and_call_func[t] = make_pair(lambda_struct_type, lambda_call_function) name_ast_map->values.first().second->translation_unit.children.add(new_type_def) @@ -115,16 +132,20 @@ fun function_value_lower(name_ast_map: *map,*ast_node // create the closure type for each lambda var closure_id = 0 lambdas.for_each(fun(l: *ast_node) { + var closure_struct_type: *type if (l->function.closed_variables.size()) { var new_type_def_name = string("closure_struct_") + closure_id++ var new_type_def = ast_type_def_ptr(new_type_def_name) l->function.closed_variables.for_each(fun(v: *ast_node) { // TODO: need to clean this type if it's a lambda type or contains it - new_type_def->type_def.variables.add(ast_declaration_statement_ptr(ast_identifier_ptr(v->identifier.name, v->identifier.type->clone_with_ref(), new_type_def), null())) + var closed_ident = ast_identifier_ptr(v->identifier.name, v->identifier.type->clone_with_ref(), new_type_def) + new_type_def->type_def.variables.add(ast_declaration_statement_ptr(closed_ident, null())) + add_to_scope(v->identifier.name, closed_ident, new_type_def) }) add_to_scope("~enclosing_scope", name_ast_map->values.first().second, new_type_def) add_to_scope(new_type_def_name, new_type_def, name_ast_map->values.first().second) name_ast_map->values.first().second->translation_unit.children.add(new_type_def) + closure_struct_type = type_ptr(new_type_def)->clone_with_increased_indirection() } var return_type = lambda_type_to_struct_type_and_call_func[*l->function.type].first @@ -134,7 +155,20 @@ fun function_value_lower(name_ast_map: *map,*ast_node var ident = ast_identifier_ptr("to_ret", return_type, body) body->code_block.children.add(ast_declaration_statement_ptr(ident, null())) body->code_block.children.add(ast_assignment_statement_ptr(access_expression(ident, "func"), l)) - body->code_block.children.add(ast_assignment_statement_ptr(access_expression(ident, "data"), ast_value_ptr(string("0"), type_ptr(base_type::void_return(), 1)))) + body->code_block.children.add(ast_assignment_statement_ptr(access_expression(ident, "func_closure"), l)) + if (l->function.closed_variables.size()) { + var closure_param = ast_identifier_ptr("closure", closure_struct_type, body) + lambda_creation_funcs[l]->function.parameters.add(closure_param) + body->code_block.children.add(ast_assignment_statement_ptr(access_expression(ident, "data"), closure_param)) + l->function.closed_variables.for_each(fun(v: *ast_node) { + var closed_param = ast_identifier_ptr("closed_param", v->identifier.type->clone_with_increased_indirection(), l) + lambda_creation_funcs[l]->function.parameters.add(closed_param) + /*body->code_block.children.add(ast_assignment_statement_ptr(access_expression(closure_param, v->identifier.name), closed_param))*/ + body->code_block.children.add(ast_assignment_statement_ptr(closure_param, closed_param)) + }) + } else { + body->code_block.children.add(ast_assignment_statement_ptr(access_expression(ident, "data"), ast_value_ptr(string("0"), type_ptr(base_type::void_return(), 1)))) + } body->code_block.children.add(ast_return_statement_ptr(ident)) lambda_creation_funcs[l]->function.body_statement = body name_ast_map->values.first().second->translation_unit.children.add(lambda_creation_funcs[l]) @@ -148,7 +182,19 @@ fun function_value_lower(name_ast_map: *map,*ast_node p.parent->function_call.parameters.add(0, function_struct) }) function_value_creation_points.for_each(fun(p: function_parent_block) { - var func_call = ast_function_call_ptr(lambda_creation_funcs[p.function], vector<*ast_node>()) + var lambda_creation_params = vector<*ast_node>() + // add the declaration of the closure struct to the enclosing code block + if (p.function->function.closed_variables.size()) { + // pull closure type off lambda creation func parameter + var closure_type = get_ast_type(lambda_creation_funcs[p.function]->function.parameters[0])->clone_with_decreased_indirection() + var closure_struct_ident = ast_identifier_ptr("closure_struct", closure_type, p.parent_block) + p.parent_block->code_block.children.add(0,ast_declaration_statement_ptr(closure_struct_ident, null())) + lambda_creation_params.add(make_operator_call("&", vector(closure_struct_ident))) + p.function->function.closed_variables.for_each(fun(v: *ast_node) { + lambda_creation_params.add(make_operator_call("&", vector(v))) + }) + } + var func_call = ast_function_call_ptr(lambda_creation_funcs[p.function], lambda_creation_params) replace_with_in(p.function, func_call, p.parent) }) lambdas.for_each(fun(l: *ast_node) l->function.type = l->function.type->clone();) diff --git a/test_function_value.krak b/test_function_value.krak index e037026..e079ca4 100644 --- a/test_function_value.krak +++ b/test_function_value.krak @@ -1,6 +1,7 @@ fun main(argc: int, argv: **char): int { - var a = fun(i: int, x: int): int { return i+x; } + var y = 20 + var a = fun(i: int, x: int): int { return i+x+y; } return a(12, 11) }