diff --git a/stdlib/ast_nodes.krak b/stdlib/ast_nodes.krak index 35ce1c0..4d253c8 100644 --- a/stdlib/ast_nodes.krak +++ b/stdlib/ast_nodes.krak @@ -1041,9 +1041,11 @@ fun ast_to_dot(root: *ast_node): string { node_name_map.set(node, escaped) return escaped } + var done_set = set<*ast_node>() var helper: fun(*ast_node):void = fun(node: *ast_node) { + done_set.add(node) get_ast_children(node).for_each(fun(child: *ast_node) { - if (!child) + if (!child || done_set.contains(child)) return; // where on earth does the null come from ret += string("\"") + get_name(node) + "\" -> \"" + get_name(child) + "\"\n"; helper(child) diff --git a/stdlib/ast_transformation.krak b/stdlib/ast_transformation.krak index 20acda5..d7d2e5d 100644 --- a/stdlib/ast_transformation.krak +++ b/stdlib/ast_transformation.krak @@ -89,7 +89,7 @@ obj ast_transformation (Object) { ast_to_syntax.set(function_node, child) } else if (child->data.name == "declaration_statement") { // second pass declaration can actually just call a normal transform (but maybe should be it's own method to do so because typedef has to do it too?)... - translation_unit->translation_unit.children.add(transform_declaration_statement(child, translation_unit)) + translation_unit->translation_unit.children.add(transform_declaration_statement(child, translation_unit, map())) } }) // work on the ones already started @@ -99,7 +99,7 @@ obj ast_transformation (Object) { var type_def_syntax = ast_to_syntax[node] type_def_syntax->children.for_each(fun(child: *tree) { if (child->data.name == "declaration_statement") { - var declaration_node = transform_declaration_statement(child, node) + var declaration_node = transform_declaration_statement(child, node, map()) node->type_def.variables.add(declaration_node) ast_to_syntax.set(declaration_node, child) } else if (child->data.name == "function") { @@ -123,14 +123,14 @@ obj ast_transformation (Object) { // make sure not a template? or the method not a template? // also same body problem as below node->type_def.methods.for_each(fun(method: *ast_node) { - method->function.body_statement = transform_statement(get_node("statement", ast_to_syntax[method]), method) + method->function.body_statement = transform_statement(get_node("statement", ast_to_syntax[method]), method, map()) }) } ast_node::function(backing) { // make sure not a template // huh, I guess I can't actually assign to the backing. // This is actually a little bit of a problem, maybe these should be pointers also. All the pointers! - node->function.body_statement = transform_statement(get_node("statement", ast_to_syntax[node]), node) + node->function.body_statement = transform_statement(get_node("statement", ast_to_syntax[node]), node, map()) } } }) @@ -225,37 +225,37 @@ fun transform_type(node: *tree, scope: *ast_node, template_replacements: return type_ptr(base_type::none(), indirection) } } -fun transform(node: *tree, scope: *ast_node): *ast_node return transform(node, scope, search_type::none()) -fun transform(node: *tree, scope: *ast_node, searching_for: search_type): *ast_node { +fun transform(node: *tree, scope: *ast_node, template_replacements: map): *ast_node return transform(node, scope, search_type::none(), template_replacements) +fun transform(node: *tree, scope: *ast_node, searching_for: search_type, template_replacements: map): *ast_node { var name = node->data.name if (name == "identifier" || name == "scoped_identifier") { return transform_identifier(node, scope, searching_for) } else if (name == "code_block") { - return transform_code_block(node, scope) + return transform_code_block(node, scope, template_replacements) } else if (name == "if_comp") { return transform_if_comp(node, scope) } else if (name == "simple_passthrough") { return transform_simple_passthrough(node, scope) } else if (name == "statement") { - return transform_statement(node, scope) + return transform_statement(node, scope, template_replacements) } else if (name == "declaration_statement") { - return transform_declaration_statement(node, scope) + return transform_declaration_statement(node, scope, template_replacements) } else if (name == "assignment_statement") { - return transform_assignment_statement(node, scope) + return transform_assignment_statement(node, scope, template_replacements) } else if (name == "if_statement") { - return transform_if_statement(node, scope) + return transform_if_statement(node, scope, template_replacements) } else if (name == "while_loop") { - return transform_while_loop(node, scope) + return transform_while_loop(node, scope, template_replacements) } else if (name == "for_loop") { - return transform_for_loop(node, scope) + return transform_for_loop(node, scope, template_replacements) } else if (name == "return_statement") { - return transform_return_statement(node, scope) + return transform_return_statement(node, scope, template_replacements) } else if (name == "continue_statement" || name == "break_statement") { return transform_branching_statement(node, scope) } else if (name == "defer_statement") { - return transform_defer_statement(node, scope) + return transform_defer_statement(node, scope, template_replacements) } else if (name == "function_call") { - return transform_function_call(node, scope) + return transform_function_call(node, scope, template_replacements) } else if (name == "boolean_expression" || name == "and_boolean_expression" || name == "bool_exp" || name == "expression" || name == "shiftand" || name == "term" @@ -263,7 +263,7 @@ fun transform(node: *tree, scope: *ast_node, searching_for: search_type) || name == "access_operation" ) { // for now, assume passthrough and just transform underneath - return transform_expression(node, scope, searching_for) + return transform_expression(node, scope, searching_for, template_replacements) } else if (name == "bool" || name == "string" || name == "character" || name == "number" ) { @@ -273,8 +273,8 @@ fun transform(node: *tree, scope: *ast_node, searching_for: search_type) print("FAILED TO TRANSFORM: "); print(name + ": "); println(concat_symbol_tree(node)) return null() } -fun transform_all(nodes: vector<*tree>, scope: *ast_node): vector<*ast_node> { - return nodes.map(fun(node: *tree): *ast_node return transform(node, scope);) +fun transform_all(nodes: vector<*tree>, scope: *ast_node, template_replacements: map): vector<*ast_node> { + return nodes.map(fun(node: *tree): *ast_node return transform(node, scope, template_replacements);) } fun transform_identifier(node: *tree, scope: *ast_node, searching_for: search_type): *ast_node { // first, we check for and generate this @@ -313,16 +313,16 @@ fun transform_value(node: *tree, scope: *ast_node): *ast_node { } return ast_value_ptr(value_str, value_type) } -fun transform_code_block(node: *tree, scope: *ast_node): *ast_node { +fun transform_code_block(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { var new_block = ast_code_block_ptr() add_to_scope("~enclosing_scope", scope, new_block) - new_block->code_block.children = transform_all(node->children, new_block) + new_block->code_block.children = transform_all(node->children, new_block, template_replacements) return new_block } fun transform_if_comp(node: *tree, scope: *ast_node): *ast_node { var new_if_comp = ast_if_comp_ptr() new_if_comp->if_comp.wanted_generator = concat_symbol_tree(get_node("identifier", node)) - new_if_comp->if_comp.statement = transform_statement(get_node("statement", node), scope) + new_if_comp->if_comp.statement = transform_statement(get_node("statement", node), scope, map()) return new_if_comp } fun transform_simple_passthrough(node: *tree, scope: *ast_node): *ast_node { @@ -331,8 +331,8 @@ fun transform_simple_passthrough(node: *tree, scope: *ast_node): *ast_no new_passthrough->simple_passthrough.passthrough_str = concat_symbol_tree(get_node("triple_quoted_string", node)).slice(3,-4) return new_passthrough } -fun transform_statement(node: *tree, scope: *ast_node): *ast_node return ast_statement_ptr(transform(node->children[0], scope)); -fun transform_declaration_statement(node: *tree, scope: *ast_node): *ast_node { +fun transform_statement(node: *tree, scope: *ast_node, template_replacements: map): *ast_node return ast_statement_ptr(transform(node->children[0], scope, template_replacements)); +fun transform_declaration_statement(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { // this might have an init position method call var identifiers = get_nodes("identifier", node) var name = concat_symbol_tree(identifiers[0]) @@ -341,9 +341,9 @@ fun transform_declaration_statement(node: *tree, scope: *ast_node): *ast var expression_syntax_node = get_node("boolean_expression", node) var ident_type = null() var expression = null() - if (type_syntax_node) ident_type = transform_type(type_syntax_node, scope, map()) + if (type_syntax_node) ident_type = transform_type(type_syntax_node, scope, template_replacements) if (expression_syntax_node) { - expression = transform(expression_syntax_node, scope) + expression = transform(expression_syntax_node, scope, template_replacements) if (!type_syntax_node) ident_type = get_ast_type(expression) } @@ -352,8 +352,8 @@ fun transform_declaration_statement(node: *tree, scope: *ast_node): *ast var declaration = ast_declaration_statement_ptr(identifier, expression) // ok, deal with the possible init position method call if (identifiers.size == 2) { - var method = transform(identifiers[1], ident_type->type_def) - var parameters = get_nodes("parameter", node).map(fun(child: *tree): *ast_node return transform(get_node("boolean_expression", child), scope);) + var method = transform(identifiers[1], ident_type->type_def, template_replacements) + var parameters = get_nodes("parameter", node).map(fun(child: *tree): *ast_node return transform(get_node("boolean_expression", child), scope, template_replacements);) declaration->declaration_statement.init_method_call = make_method_call(identifier, method, parameters) } add_to_scope(name, identifier, scope) @@ -378,9 +378,9 @@ fun make_operator_call(func: *char, params: vector<*ast_node>): *ast_node return fun make_operator_call(func: string, params: vector<*ast_node>): *ast_node { return ast_function_call_ptr(get_builtin_function(func, params.map(fun(p:*ast_node): *type return get_ast_type(p);)), params) } -fun transform_assignment_statement(node: *tree, scope: *ast_node): *ast_node { - var assign_to = transform(get_node("factor", node), scope) - var to_assign = transform(get_node("boolean_expression", node), scope) +fun transform_assignment_statement(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { + var assign_to = transform(get_node("factor", node), scope, template_replacements) + var to_assign = transform(get_node("boolean_expression", node), scope, template_replacements) if (get_node("\"\\+=\"", node)) to_assign = make_operator_call("+", vector(assign_to, to_assign)) else if (get_node("\"-=\"", node)) to_assign = make_operator_call("-", vector(assign_to, to_assign)) else if (get_node("\"\\*=\"", node)) to_assign = make_operator_call("*", vector(assign_to, to_assign)) @@ -388,63 +388,63 @@ fun transform_assignment_statement(node: *tree, scope: *ast_node): *ast_ var assignment = ast_assignment_statement_ptr(assign_to, to_assign) return assignment } -fun transform_if_statement(node: *tree, scope: *ast_node): *ast_node { - var if_statement = ast_if_statement_ptr(transform_expression(get_node("boolean_expression", node), scope)) +fun transform_if_statement(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { + var if_statement = ast_if_statement_ptr(transform_expression(get_node("boolean_expression", node), scope, template_replacements)) // one variable declarations might be in a code_block-less if statement add_to_scope("~enclosing_scope", scope, if_statement) - var statements = transform_all(get_nodes("statement", node), if_statement) + var statements = transform_all(get_nodes("statement", node), if_statement, template_replacements) if_statement->if_statement.then_part = statements[0] // we have an else if (statements.size == 2) if_statement->if_statement.else_part = statements[1] return if_statement } -fun transform_while_loop(node: *tree, scope: *ast_node): *ast_node { - var while_loop = ast_while_loop_ptr(transform_expression(get_node("boolean_expression", node), scope)) +fun transform_while_loop(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { + var while_loop = ast_while_loop_ptr(transform_expression(get_node("boolean_expression", node), scope, template_replacements)) add_to_scope("~enclosing_scope", scope, while_loop) - while_loop->while_loop.statement = transform(get_node("statement", node), while_loop) + while_loop->while_loop.statement = transform(get_node("statement", node), while_loop, template_replacements) return while_loop } -fun transform_for_loop(node: *tree, scope: *ast_node): *ast_node { +fun transform_for_loop(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { var for_loop = ast_for_loop_ptr() add_to_scope("~enclosing_scope", scope, for_loop) var statements = get_nodes("statement", node) - for_loop->for_loop.init = transform(statements[0], for_loop) - for_loop->for_loop.condition = transform(get_node("boolean_expression", node), for_loop) - for_loop->for_loop.update = transform(statements[1], for_loop) - for_loop->for_loop.body = transform(statements[2], for_loop) + for_loop->for_loop.init = transform(statements[0], for_loop, template_replacements) + for_loop->for_loop.condition = transform(get_node("boolean_expression", node), for_loop, template_replacements) + for_loop->for_loop.update = transform(statements[1], for_loop, template_replacements) + for_loop->for_loop.body = transform(statements[2], for_loop, template_replacements) return for_loop } -fun transform_return_statement(node: *tree, scope: *ast_node): *ast_node { - return ast_return_statement_ptr(transform(node->children[0], scope)) +fun transform_return_statement(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { + return ast_return_statement_ptr(transform(node->children[0], scope, template_replacements)) } fun transform_branching_statement(node: *tree, scope: *ast_node): *ast_node { if (node->data.name == "break_statement") return ast_branching_statement_ptr(branching_type::break_stmt()) return ast_branching_statement_ptr(branching_type::continue_stmt()) } -fun transform_defer_statement(node: *tree, scope: *ast_node): *ast_node { - return ast_defer_statement_ptr(transform(node->children[0], scope)) +fun transform_defer_statement(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { + return ast_defer_statement_ptr(transform(node->children[0], scope, template_replacements)) } -fun transform_function_call(node: *tree, scope: *ast_node): *ast_node { +fun transform_function_call(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { // don't bother with a full transform for parameters with their own function, just get the boolean expression and transform it - var parameters = get_nodes("parameter", node).map(fun(child: *tree): *ast_node return transform(get_node("boolean_expression", child), scope);) + var parameters = get_nodes("parameter", node).map(fun(child: *tree): *ast_node return transform(get_node("boolean_expression", child), scope, template_replacements);) var parameter_types = parameters.map(fun(param: *ast_node): *type return get_ast_type(param);) - var f = ast_function_call_ptr(transform(get_node("unarad", node), scope, search_type::function(parameter_types)), parameters) + var f = ast_function_call_ptr(transform(get_node("unarad", node), scope, search_type::function(parameter_types), template_replacements), parameters) print("function call function ") println(f->function_call.func) print("function call parameters ") f->function_call.parameters.for_each(fun(param: *ast_node) print(param);) return f } -fun transform_expression(node: *tree, scope: *ast_node): *ast_node return transform_expression(node, scope, search_type::none()) -fun transform_expression(node: *tree, scope: *ast_node, searching_for: search_type): *ast_node { +fun transform_expression(node: *tree, scope: *ast_node, template_replacements: map): *ast_node return transform_expression(node, scope, search_type::none(), template_replacements) +fun transform_expression(node: *tree, scope: *ast_node, searching_for: search_type, template_replacements: map): *ast_node { // figure out what the expression is, handle overloads, or you know // ignore everything and do a passthrough var func_name = string() var parameters = vector<*ast_node>() if (node->children.size == 1) - return transform(node->children[0], scope, searching_for) + return transform(node->children[0], scope, searching_for, template_replacements) else if (node->children.size == 2) { var template_inst = get_node("template_inst", node) if (template_inst) { @@ -455,7 +455,7 @@ fun transform_expression(node: *tree, scope: *ast_node, searching_for: s println("TE() } - search_type::function(type_vec) return find_or_instantiate_function_template(identifier, template_inst, scope, type_vec) + search_type::function(type_vec) return find_or_instantiate_function_template(identifier, template_inst, scope, type_vec, template_replacements) } println("NEVER EVER HAPPEN") } @@ -463,19 +463,19 @@ fun transform_expression(node: *tree, scope: *ast_node, searching_for: s if (check_if_post == "--" || check_if_post == "++") { // give the post-operators a special suffix so the c_generator knows to emit them post func_name = concat_symbol_tree(node->children[1]) + "p" - parameters = vector(transform(node->children[0], scope)) + parameters = vector(transform(node->children[0], scope, template_replacements)) } else { func_name = concat_symbol_tree(node->children[0]) - parameters = vector(transform(node->children[1], scope)) + parameters = vector(transform(node->children[1], scope, template_replacements)) } } else { func_name = concat_symbol_tree(node->children[1]) - var first_param = transform(node->children[0], scope) + var first_param = transform(node->children[0], scope, template_replacements) var second_param = null() if (func_name == "." || func_name == "->") { - second_param = transform(node->children[2], get_ast_type(first_param)->type_def, searching_for) + second_param = transform(node->children[2], get_ast_type(first_param)->type_def, searching_for, template_replacements) } else { - second_param = transform(node->children[2], scope) + second_param = transform(node->children[2], scope, template_replacements) } parameters = vector(first_param, second_param) } @@ -491,10 +491,10 @@ fun get_builtin_function(name: string, param_types: vector<*type>): *ast_node { return ast_function_ptr(name, type_ptr(param_types, param_types[0]->clone_with_decreased_indirection()), vector<*ast_node>()) return ast_function_ptr(name, type_ptr(param_types, param_types[0]), vector<*ast_node>()) } -fun find_or_instantiate_function_template(identifier: *tree, template_inst: *tree, scope: *ast_node, param_types: vector<*type>): *ast_node { +fun find_or_instantiate_function_template(identifier: *tree, template_inst: *tree, scope: *ast_node, param_types: vector<*type>, template_replacements: map): *ast_node { var name = concat_symbol_tree(identifier) var results = scope_lookup(name, scope) - var real_types = get_nodes("type", template_inst).map(fun(t: *tree): *type return transform_type(t, scope, map());) + var real_types = get_nodes("type", template_inst).map(fun(t: *tree): *type return transform_type(t, scope, template_replacements);) var real_types_deref = real_types.map(fun(t:*type):type return *t;) for (var i = 0; i < results.size; i++;) { if (is_function_template(results[i])) { @@ -524,10 +524,11 @@ fun find_or_instantiate_function_template(identifier: *tree, template_in println("MAP DONE") inst_func = second_pass_function(results[i]->function_template.syntax_node, results[i], template_type_replacements, false) - // and fully instantiate it - inst_func->function.body_statement = transform_statement(get_node("statement", results[i]->function_template.syntax_node), inst_func) // add to instantiated_map so we only instantiate with a paticular set of types once + // put in map first for recursive purposes results[i]->function_template.instantiated_map.set(real_types_deref, inst_func) + // and fully instantiate it + inst_func->function.body_statement = transform_statement(get_node("statement", results[i]->function_template.syntax_node), inst_func, template_type_replacements) } if (function_satisfies_params(inst_func, param_types)) diff --git a/tests/to_parse.krak b/tests/to_parse.krak index afcfdad..e1b4f20 100644 --- a/tests/to_parse.krak +++ b/tests/to_parse.krak @@ -39,6 +39,12 @@ fun return_something_p_1(it: Something): Something { */ fun id(in: *T): *T return in; fun id(in: T): T return in; +fun other_id(in: T): T { + var a: T + a = in + simple_println(id(in)) + return in; +} /* fun some_function(): int return 0; fun some_other_function(in: bool): float { @@ -52,6 +58,7 @@ fun main(): int { /*var b = id<*char>("Double down time")*/ /*simple_println(b)*/ simple_println(id("Double down time")) + simple_println(other_id<*char>("Triple down time"))