diff --git a/stdlib/ast_nodes.krak b/stdlib/ast_nodes.krak index 2b620e7..3c0543c 100644 --- a/stdlib/ast_nodes.krak +++ b/stdlib/ast_nodes.krak @@ -139,11 +139,11 @@ obj import (Object) { return imported == other.imported && name == other.name && translation_unit == other.translation_unit && starred == other.starred } } -fun ast_identifier_ptr(name: *char, type: *type): *ast_node { - return ast_identifier_ptr(string(name), type) +fun ast_identifier_ptr(name: *char, type: *type, enclosing_scope: *ast_node): *ast_node { + return ast_identifier_ptr(string(name), type, enclosing_scope) } -fun ast_identifier_ptr(name: string, type: *type): *ast_node { - var to_ret.construct(name, type): identifier +fun ast_identifier_ptr(name: string, type: *type, enclosing_scope: *ast_node): *ast_node { + var to_ret.construct(name, type, enclosing_scope): identifier var ptr = new() ptr->copy_construct(&ast_node::identifier(to_ret)) return ptr @@ -158,16 +158,19 @@ obj identifier (Object) { var name: string var scope: map> var type: *type - fun construct(name_in: string, type_in: *type): *identifier { + var enclosing_scope: *ast_node + fun construct(name_in: string, type_in: *type, enclosing_scope: *ast_node): *identifier { name.copy_construct(&name_in) scope.construct() type = type_in + identifier::enclosing_scope = enclosing_scope return this } fun copy_construct(old: *identifier) { name.copy_construct(&old->name) scope.copy_construct(&old->scope) type = old->type + enclosing_scope = old->enclosing_scope } fun destruct() { name.destruct() @@ -178,7 +181,7 @@ obj identifier (Object) { copy_construct(&other) } fun operator==(other: ref identifier): bool { - return name == other.name && type == other.type + return name == other.name && type == other.type && enclosing_scope == other.enclosing_scope } } fun ast_type_def_ptr(name: string): *ast_node { @@ -280,11 +283,13 @@ obj function (Object) { var name: string var type: *type var parameters: vector<*ast_node> + var closed_variables: set<*ast_node> var body_statement: *ast_node var scope: map> fun construct(name_in: string, type_in: *type, parameters_in: vector<*ast_node>): *function { name.copy_construct(&name_in) parameters.copy_construct(¶meters_in) + closed_variables.construct() scope.construct() type = type_in body_statement = null() @@ -295,11 +300,13 @@ obj function (Object) { type = old->type body_statement = old->body_statement parameters.copy_construct(&old->parameters) + closed_variables.copy_construct(&old->closed_variables) scope.copy_construct(&old->scope) } fun destruct() { name.destruct() parameters.destruct() + closed_variables.destruct() scope.destruct() } fun operator=(other: ref function) { @@ -307,7 +314,7 @@ obj function (Object) { copy_construct(&other) } fun operator==(other: ref function): bool { - return name == name && type == other.type && parameters == other.parameters && body_statement == other.body_statement + return name == name && type == other.type && parameters == other.parameters && body_statement == other.body_statement && closed_variables == other.closed_variables } } fun ast_template_ptr(name: string, syntax_node: *tree, template_types: vector, template_type_replacements: map, is_function: bool): *ast_node { diff --git a/stdlib/ast_transformation.krak b/stdlib/ast_transformation.krak index 004c4e2..93b591c 100644 --- a/stdlib/ast_transformation.krak +++ b/stdlib/ast_transformation.krak @@ -169,10 +169,13 @@ obj ast_transformation (Object) { // transform parameters var parameters = vector<*ast_node>() get_nodes("typed_parameter", node).for_each(fun(child: *tree) { - parameters.add(ast_identifier_ptr(concat_symbol_tree(get_node("identifier", child)), transform_type(get_node("type", child), scope, template_replacements))) + // note the temporary null() which gets replaced below, as the dependency is circular + parameters.add(ast_identifier_ptr(concat_symbol_tree(get_node("identifier", child)), transform_type(get_node("type", child), scope, template_replacements), null())) }) // figure out function type and make function_node var function_node = ast_function_ptr(function_name, type_ptr(parameters.map(fun(parameter: *ast_node): *type return parameter->identifier.type;), return_type), parameters) + // fix up the enclosing_scope's + parameters.for_each(fun(n: *ast_node) n->identifier.enclosing_scope = function_node;) // add to scope add_to_scope(function_name, function_node, scope) add_to_scope("~enclosing_scope", scope, function_node) @@ -372,7 +375,7 @@ obj ast_transformation (Object) { var name = concat_symbol_tree(node) if (name == "this") { while (!is_type_def(scope)) scope = get_ast_scope(scope)->get(string("~enclosing_scope"))[0] - return ast_identifier_ptr("this", scope->type_def.self_type->clone_with_indirection(1)) + return ast_identifier_ptr("this", scope->type_def.self_type->clone_with_indirection(1), scope) } match (searching_for) { search_type::none() return identifier_lookup(name, scope) @@ -465,7 +468,7 @@ obj ast_transformation (Object) { ident_type = get_ast_type(expression) } if (!ident_type) error("declaration statement with no type or expression from which to inference type") - var identifier = ast_identifier_ptr(name, ident_type) + var identifier = ast_identifier_ptr(name, ident_type, scope) var declaration = ast_declaration_statement_ptr(identifier, expression) // ok, deal with the possible init position method call if (identifiers.size == 2) { @@ -541,10 +544,50 @@ obj ast_transformation (Object) { fun transform_lambda(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { var function_node = second_pass_function(node, scope, template_replacements, false) function_node->function.body_statement = transform_statement(get_node("statement", node), function_node, template_replacements) + function_node->function.closed_variables = find_closed_variables(function_node, function_node->function.body_statement) + println(string("Found ") + function_node->function.closed_variables.size() + " closed variables!") while (!is_translation_unit(scope)) scope = get_ast_scope(scope)->get(string("~enclosing_scope"))[0] scope->translation_unit.lambdas.add(function_node) return function_node } + fun find_closed_variables(func: *ast_node, node: *ast_node): set<*ast_node> { + match (*node) { + ast_node::identifier(backing) { + println("found an identifier") + println(backing.name) + if (!in_scope_chain(backing.enclosing_scope, func)) + return set(node); + } + ast_node::statement(backing) { + println("found an statement") + return find_closed_variables(func, backing.child) + } + ast_node::code_block(backing) { + println("found an code_block") + var to_ret = set<*ast_node>() + backing.children.for_each(fun(n: *ast_node) to_ret += find_closed_variables(func, n);) + return to_ret + } + ast_node::function_call(backing) { + println("found an function_call") + var to_ret = find_closed_variables(func, backing.func) + backing.parameters.for_each(fun(n: *ast_node) to_ret += find_closed_variables(func, n);) + return to_ret + } + ast_node::return_statement(backing) { + println("found an return_statement") + return find_closed_variables(func, backing.return_value) + } + } + return set<*ast_node>() + } + fun in_scope_chain(node: *ast_node, high_scope: *ast_node): bool { + if (node == high_scope) + return true + if (get_ast_scope(node)->contains_key(string("~enclosing_scope"))) + return in_scope_chain(get_ast_scope(node)->get(string("~enclosing_scope"))[0], high_scope) + return false + } fun transform_expression(node: *tree, scope: *ast_node, template_replacements: map): *ast_node return transform_expression(node, scope, search_type::none(), template_replacements) fun transform_expression(node: *tree, scope: *ast_node, searching_for: search_type, template_replacements: map): *ast_node { var func_name = string() diff --git a/stdlib/c_generator.krak b/stdlib/c_generator.krak index ee193f5..4f20c73 100644 --- a/stdlib/c_generator.krak +++ b/stdlib/c_generator.krak @@ -76,10 +76,15 @@ obj code_triple (Object) { obj c_generator (Object) { var id_counter: int var ast_name_map: map<*ast_node, string> + var closure_struct_map: map, string> var function_typedef_string: string + var closure_struct_definitions: string fun construct(): *c_generator { id_counter = 0 ast_name_map.construct() + closure_struct_map.construct() + function_typedef_string.construct() + closure_struct_definitions.construct() return this } fun copy_construct(old: *c_generator) { @@ -101,7 +106,8 @@ obj c_generator (Object) { var top_level_c_passthrough: string = "" var variable_extern_declarations: string = "" var structs: string = "\n/**Type Structs**/\n" - function_typedef_string.construct() + function_typedef_string = "\n/**Typedefs**/\n" + closure_struct_definitions = "\n/**Closure Struct Definitions**/\n" var function_prototypes: string = "\n/**Function Prototypes**/\n" var function_definitions: string = "\n/**Function Definitions**/\n" var variable_declarations: string = "\n/**Variable Declarations**/\n" @@ -116,6 +122,13 @@ obj c_generator (Object) { parameter_types = type_to_c(enclosing_object->type_def.self_type) + "*" parameters = type_to_c(enclosing_object->type_def.self_type) + "* this" } + if (backing.closed_variables.size()) { + println("HAS CLOSED VARIABLES") + if (parameter_types != "") { parameter_types += ", "; parameters += ", ";} + var closed_type_name = get_closure_struct_type(backing.closed_variables) + parameter_types += closed_type_name + "*" + parameters += closed_type_name + "* closure_data" + } // stack-stack thing // this could be a stack of strings too, maybe // start out with one stack on the stack @@ -199,7 +212,14 @@ obj c_generator (Object) { }) }) - return make_pair(prequal+plain_typedefs+top_level_c_passthrough+variable_extern_declarations+structs+function_typedef_string+function_prototypes+variable_declarations+function_definitions + "\n", linker_string) + return make_pair(prequal+plain_typedefs+top_level_c_passthrough+variable_extern_declarations+structs+closure_struct_definitions+function_typedef_string+function_prototypes+variable_declarations+function_definitions + "\n", linker_string) + } + fun get_closure_struct_type(closed_variables: set<*ast_node>): string { + if (!closure_struct_map.contains_key(closed_variables)) { + closure_struct_definitions += "typedef struct {} random_closure_type;\n" + closure_struct_map[closed_variables] = string("random_closure_type") + } + return closure_struct_map[closed_variables] } fun generate_if_comp(node: *ast_node, enclosing_object: *ast_node, defer_stack: *stack>>): code_triple { if (node->if_comp.wanted_generator == "__C__") @@ -290,7 +310,7 @@ obj c_generator (Object) { var to_ret = code_triple() // if we're returning an object, copy_construct a new one to return if (return_value_type->is_object() && return_value_type->indirection == 0 && has_method(return_value_type->type_def, "copy_construct", vector(return_value_type->clone_with_indirection(1)))) { - var temp_ident = ast_identifier_ptr(string("temporary_return")+get_id(), return_value_type) + var temp_ident = ast_identifier_ptr(string("temporary_return")+get_id(), return_value_type, null()) var declaration = ast_declaration_statement_ptr(temp_ident, null()) // have to pass false to the declaration generator, so can't do it through generate_statement to_ret.pre = generate_declaration_statement(declaration, enclosing_object, defer_stack, false).one_string() + ";\n" @@ -416,7 +436,7 @@ obj c_generator (Object) { var param_type = get_ast_type(param) if (param_type->is_object() && param_type->indirection == 0 && has_method(param_type->type_def, "copy_construct", vector(param_type->clone_with_indirection(1)))) { - var temp_ident = ast_identifier_ptr(string("temporary_param")+get_id(), param_type) + var temp_ident = ast_identifier_ptr(string("temporary_param")+get_id(), param_type, null()) var declaration = ast_declaration_statement_ptr(temp_ident, null()) // have to pass false to the declaration generator, so can't do it through generate_statement call_string.pre += generate_declaration_statement(declaration, enclosing_object, null>>>(), false).one_string() + ";\n" @@ -431,7 +451,7 @@ obj c_generator (Object) { }) if (func_return_type->is_object() && func_return_type->indirection == 0 && has_method(func_return_type->type_def, "destruct", vector<*type>())) { // kind of ugly combo here of - var temp_ident = ast_identifier_ptr(string("temporary_return")+get_id(), func_return_type) + var temp_ident = ast_identifier_ptr(string("temporary_return")+get_id(), func_return_type, null()) var declaration = ast_declaration_statement_ptr(temp_ident, null()) // have to pass false to the declaration generator, so can't do it through generate_statement call_string.pre += generate_declaration_statement(declaration, enclosing_object, null>>>(), false).one_string() + ";\n" diff --git a/stdlib/set.krak b/stdlib/set.krak index a73f74c..3d04ea9 100644 --- a/stdlib/set.krak +++ b/stdlib/set.krak @@ -63,10 +63,10 @@ obj set (Object, Serializable) { fun contains(item: T): bool { return data.find(item) != -1 } - fun operator+=(item: T) { + fun operator+=(item: ref T) { add(item) } - fun operator+=(items: set) { + fun operator+=(items: ref set) { add(items) } fun add(item: ref T) { diff --git a/tests/to_parse.krak b/tests/to_parse.krak index 27f1fa6..30f5684 100644 --- a/tests/to_parse.krak +++ b/tests/to_parse.krak @@ -1,11 +1,12 @@ import simple_print: * fun main(): int { - var v: fun(int):int - v = fun(data: int): int { + var data = 7 + var v: fun():int + v = fun(): int { println(data) return data } - println(v(7)) + println(v()) // println(print_and_return(7)) return 0 }