diff --git a/stdlib/ast_nodes.krak b/stdlib/ast_nodes.krak index b978177..9994bd5 100644 --- a/stdlib/ast_nodes.krak +++ b/stdlib/ast_nodes.krak @@ -18,6 +18,7 @@ adt ast_node { type_def: type_def, adt_def: adt_def, function: function, + function_template: function_template, code_block: code_block, statement: statement, if_statement: if_statement, @@ -302,6 +303,58 @@ obj function (Object) { return name == name && type == other.type && parameters == other.parameters && body_statement == other.body_statement } } +fun ast_function_template_ptr(name: string, syntax_node: *tree, template_types: vector, template_type_replacements: map): *ast_node { + var to_ret.construct(name, syntax_node, template_types, template_type_replacements): function_template + var ptr = new() + ptr->copy_construct(&ast_node::function_template(to_ret)) + return ptr +} +fun is_function_template(node: *ast_node): bool { + match(*node) { + ast_node::function_template(backing) return true + } + return false +} +obj function_template (Object) { + var name: string + var syntax_node: *tree + var instantiated: vector<*ast_node> + var template_types: vector + var template_type_replacements: map + var scope: map> + fun construct(name_in: string, syntax_node_in: *tree, template_types_in: vector, template_type_replacements_in: map): *function_template { + name.copy_construct(&name_in) + syntax_node = syntax_node_in + instantiated.construct() + template_types.copy_construct(&template_types_in) + template_type_replacements.copy_construct(&template_type_replacements_in) + scope.construct() + return this + } + fun copy_construct(old: *function_template) { + name.copy_construct(&old->name) + syntax_node = old->syntax_node + instantiated.copy_construct(&old->instantiated) + template_types.copy_construct(&old->template_types) + template_type_replacements.copy_construct(&old->template_type_replacements) + scope.copy_construct(&old->scope) + } + fun destruct() { + name.destruct() + instantiated.destruct() + template_types.destruct() + template_type_replacements.destruct() + scope.destruct() + } + fun operator=(other: ref function_template) { + destruct() + copy_construct(&other) + } + fun operator==(other: ref function_template): bool { + return name == name && syntax_node == other.syntax_node && instantiated == other.instantiated && + scope == other.scope && template_types == other.template_types && template_type_replacements == other.template_type_replacements + } +} fun ast_code_block_ptr(): *ast_node { var to_ret.construct(): code_block var ptr = new() @@ -888,6 +941,7 @@ fun get_ast_children(node: *ast_node): vector<*ast_node> { ast_node::type_def(backing) return backing.variables + backing.methods ast_node::adt_def(backing) return vector<*ast_node>() ast_node::function(backing) return backing.parameters + backing.body_statement + ast_node::function_template(backing) return backing.instantiated ast_node::code_block(backing) return backing.children ast_node::statement(backing) return vector<*ast_node>(backing.child) ast_node::if_statement(backing) return vector(backing.condition, backing.then_part, backing.else_part) @@ -914,6 +968,7 @@ fun get_ast_name(node: *ast_node): string { ast_node::type_def(backing) return string("type_def: ") + backing.name ast_node::adt_def(backing) return string("adt_def: ") + backing.name ast_node::function(backing) return string("function: ") + backing.name + ": " + backing.type->to_string() + ast_node::function_template(backing) return string("function_template: ") + backing.name ast_node::code_block(backing) return string("code_block") ast_node::statement(backing) return string("statement") ast_node::if_statement(backing) return string("if_statement") @@ -940,6 +995,7 @@ fun get_ast_scope(node: *ast_node): *map> { ast_node::type_def() return &node->type_def.scope ast_node::adt_def() return &node->adt_def.scope ast_node::function() return &node->function.scope + ast_node::function_template() return &node->function_template.scope ast_node::code_block() return &node->code_block.scope ast_node::statement() return &node->statement.scope ast_node::if_statement() return &node->if_statement.scope diff --git a/stdlib/ast_transformation.krak b/stdlib/ast_transformation.krak index 28b3a7c..c0ea824 100644 --- a/stdlib/ast_transformation.krak +++ b/stdlib/ast_transformation.krak @@ -83,7 +83,8 @@ obj ast_transformation (Object) { // we go through the parse tree for getting functions, but we're going through the ast for the things we've already set up and using the ast_to_syntax map parse_tree->children.for_each(fun(child: *tree) { if (child->data.name == "function") { - var function_node = second_pass_function(child, translation_unit, map()) + // also handles templated function + var function_node = second_pass_function(child, translation_unit, map(), true) translation_unit->translation_unit.children.add(function_node) ast_to_syntax.set(function_node, child) } else if (child->data.name == "declaration_statement") { @@ -102,7 +103,8 @@ obj ast_transformation (Object) { node->type_def.variables.add(declaration_node) ast_to_syntax.set(declaration_node, child) } else if (child->data.name == "function") { - var function_node = second_pass_function(child, node, map()) + // again, also handles templates + var function_node = second_pass_function(child, node, map(), true) node->type_def.methods.add(function_node) ast_to_syntax.set(function_node, child) } @@ -112,29 +114,6 @@ obj ast_transformation (Object) { } }) } - fun second_pass_function(node: *tree, translation_unit: *ast_node, template_replacements: map): *ast_node { - var function_name = concat_symbol_tree(get_node("func_identifier", node)) - // check to see if it is a template - // figure out return type - var typed_return_node = get_node("typed_return", node) - // darn no ternary yet - var return_type = null() - if (typed_return_node) return_type = transform_type(get_node("type", typed_return_node), translation_unit, template_replacements) - else return_type = type_ptr(base_type::void_return()) - // transform parameters - var parameters = vector<*ast_node>() - get_nodes("typed_parameter", node).for_each(fun(child: *tree) { - parameters.add(ast_identifier_ptr(concat_symbol_tree(get_node("identifier", child)), transform_type(get_node("type", child), translation_unit, template_replacements))) - }) - // figure out function type and make function_node - var function_node = ast_function_ptr(function_name, type_ptr(parameters.map(fun(parameter: *ast_node): *type return parameter->identifier.type;), return_type), parameters) - // add to scope (translation_unit) - add_to_scope(function_name, function_node, translation_unit) - add_to_scope("~enclosing_scope", translation_unit, function_node) - // add parameters to scope of function - parameters.for_each(fun(parameter: *ast_node) add_to_scope(parameter->identifier.name, parameter, function_node);) - return function_node - } // The third pass finishes up by doing all function bodies (top level and methods in objects) fun third_pass(parse_tree: *tree, translation_unit: *ast_node) { println(string("Third Pass for ") + translation_unit->translation_unit.name) @@ -161,6 +140,44 @@ obj ast_transformation (Object) { println(string("Fourth Pass for ") + translation_unit->translation_unit.name) } } +fun second_pass_function(node: *tree, scope: *ast_node, template_replacements: map, do_raw_template: bool): *ast_node { + var function_name = concat_symbol_tree(get_node("func_identifier", node)) + var template_dec = get_node("template_dec", node) + if (do_raw_template && template_dec) { + var template_types = vector() + var template_type_replacements = map() + get_nodes("template_param", template_dec).for_each(fun(template_param: *tree) { + template_types.add(concat_symbol_tree(get_node("identifier", template_param))) + template_type_replacements.set(template_types.last(), type_ptr(vector())) + }) + template_type_replacements.for_each(fun(key: string, value: *type) println(string("MAP: ") + key + " : " + value->to_string());) + println("MAP DONE") + var function_template = ast_function_template_ptr(function_name, node, template_types, template_type_replacements) + add_to_scope(function_name, function_template, scope) + add_to_scope("~enclosing_scope", scope, function_template) + return function_template + } + // check to see if it is a template + // figure out return type + var typed_return_node = get_node("typed_return", node) + // darn no ternary yet + var return_type = null() + if (typed_return_node) return_type = transform_type(get_node("type", typed_return_node), scope, template_replacements) + else return_type = type_ptr(base_type::void_return()) + // transform parameters + var parameters = vector<*ast_node>() + get_nodes("typed_parameter", node).for_each(fun(child: *tree) { + parameters.add(ast_identifier_ptr(concat_symbol_tree(get_node("identifier", child)), transform_type(get_node("type", child), scope, template_replacements))) + }) + // figure out function type and make function_node + var function_node = ast_function_ptr(function_name, type_ptr(parameters.map(fun(parameter: *ast_node): *type return parameter->identifier.type;), return_type), parameters) + // add to scope + add_to_scope(function_name, function_node, scope) + add_to_scope("~enclosing_scope", scope, function_node) + // add parameters to scope of function + parameters.for_each(fun(parameter: *ast_node) add_to_scope(parameter->identifier.name, parameter, function_node);) + return function_node +} fun transform_type(node: *tree, scope: *ast_node, template_replacements: map): *type { // check for references and step down @@ -174,6 +191,12 @@ fun transform_type(node: *tree, scope: *ast_node, template_replacements: } var type_syntax_str = concat_symbol_tree(real_node) println(type_syntax_str + " *************************") + if (template_replacements.contains_key(type_syntax_str)) { + print("Is in template_replacements, returning: ") + var to_ret = template_replacements[type_syntax_str]->clone_with_indirection(indirection) + println(to_ret->to_string()) + return to_ret + } // should take into account indirection and references... if (type_syntax_str == "void") return type_ptr(base_type::void_return(), indirection) @@ -408,10 +431,10 @@ fun transform_function_call(node: *tree, scope: *ast_node): *ast_node { var parameters = get_nodes("parameter", node).map(fun(child: *tree): *ast_node return transform(get_node("boolean_expression", child), scope);) var parameter_types = parameters.map(fun(param: *ast_node): *type return get_ast_type(param);) var f = ast_function_call_ptr(transform(get_node("unarad", node), scope, search_type::function(parameter_types)), parameters) - /*print("function call function ")*/ - /*println(f->function_call.func)*/ - /*print("function call parameters ")*/ - /*f->function_call.parameters.for_each(fun(param: *ast_node) print(param);)*/ + print("function call function ") + println(f->function_call.func) + print("function call parameters ") + f->function_call.parameters.for_each(fun(param: *ast_node) print(param);) return f } fun transform_expression(node: *tree, scope: *ast_node): *ast_node return transform_expression(node, scope, search_type::none()) @@ -423,6 +446,11 @@ fun transform_expression(node: *tree, scope: *ast_node, searching_for: s if (node->children.size == 1) return transform(node->children[0], scope, searching_for) else if (node->children.size == 2) { + var template_inst = get_node("template_inst", node) + if (template_inst) { + var identifier = get_node("scoped_identifier", node) + return find_or_instantiate_function_template(identifier, template_inst, scope, searching_for) + } var check_if_post = concat_symbol_tree(node->children[1]) if (check_if_post == "--" || check_if_post == "++") { // give the post-operators a special suffix so the c_generator knows to emit them post @@ -437,10 +465,8 @@ fun transform_expression(node: *tree, scope: *ast_node, searching_for: s var first_param = transform(node->children[0], scope) var second_param = null() if (func_name == "." || func_name == "->") { - println("Gonna do the internal scope thing") second_param = transform(node->children[2], get_ast_type(first_param)->type_def, searching_for) } else { - println("Gonna do regular scope thing") second_param = transform(node->children[2], scope) } parameters = vector(first_param, second_param) @@ -457,6 +483,40 @@ fun get_builtin_function(name: string, param_types: vector<*type>): *ast_node { return ast_function_ptr(name, type_ptr(param_types, param_types[0]->clone_with_decreased_indirection()), vector<*ast_node>()) return ast_function_ptr(name, type_ptr(param_types, param_types[0]), vector<*ast_node>()) } +fun find_or_instantiate_function_template(identifier: *tree, template_inst: *tree, scope: *ast_node, searching_for: search_type): *ast_node { + var name = concat_symbol_tree(identifier) + var results = scope_lookup(name, scope) + var real_types = get_nodes("type", template_inst).map(fun(t: *tree): *type return transform_type(t, scope, map());) + for (var i = 0; i < results.size; i++;) { + if (is_function_template(results[i])) { + var template_types = results[i]->function_template.template_types + var template_type_replacements = results[i]->function_template.template_type_replacements + if (template_types.size != real_types.size) + continue + println("FOR FIND OR INSTATINTATE PREEEE") + template_type_replacements.for_each(fun(key: string, value: *type) println(string("MAP: ") + key + " : " + value->to_string());) + println("MAP DONE") + for (var j = 0; j < template_types.size; j++;) { + template_type_replacements[template_types[j]] = real_types[j] + println("Just made") + println(template_types[j]) + println("equal to") + println(real_types[j]->to_string()) + } + + println("FOR FIND OR INSTATINTATE") + template_type_replacements.for_each(fun(key: string, value: *type) println(string("MAP: ") + key + " : " + value->to_string());) + println("MAP DONE") + + var part_instantiated = second_pass_function(results[i]->function_template.syntax_node, results[i], template_type_replacements, false) + // and fully instantiate it + part_instantiated->function.body_statement = transform_statement(get_node("statement", results[i]->function_template.syntax_node), part_instantiated) + return part_instantiated + } + } + println("FREAK OUT MACHINE") + return null() +} fun function_lookup(name: string, scope: *ast_node, param_types: vector<*type>): *ast_node { println(string("doing function lookup for: ") + name) var param_string = string() diff --git a/stdlib/c_generator.krak b/stdlib/c_generator.krak index 19e76b4..09fa2cf 100644 --- a/stdlib/c_generator.krak +++ b/stdlib/c_generator.krak @@ -151,11 +151,17 @@ obj c_generator (Object) { ast_node::simple_passthrough(backing) top_level_c_passthrough += generate_simple_passthrough(child) ast_node::declaration_statement(backing) variable_declarations += generate_declaration_statement(child, null(), null>>>(), true).one_string() + ";\n" ast_node::function(backing) { - // make sure not a template - // or a passthrough // check for and add to parameters if a closure generate_function_definition(child, null()) } + ast_node::function_template(backing) { + backing.scope.for_each(fun(key: string, value: vector<*ast_node>) { + value.for_each(fun(node: *ast_node) { + if (is_function(node)) + generate_function_definition(node, null()) + }) + }) + } ast_node::type_def(backing) { type_poset.add_vertex(child) } diff --git a/stdlib/type.krak b/stdlib/type.krak index 48034e0..3b9caea 100644 --- a/stdlib/type.krak +++ b/stdlib/type.krak @@ -34,18 +34,33 @@ fun type_ptr(parameters: vector<*type>, return_type: *type, indirection: int): * return new()->construct(parameters, return_type, indirection) } +fun type_ptr(traits: vector): *type { + return new()->construct(traits) +} + obj type (Object) { var base: base_type var parameter_types: vector<*type> var return_type: *type var indirection: int var type_def: *ast_node + var traits: vector fun construct(): *type { base.copy_construct(&base_type::none()) parameter_types.construct() indirection = 0 return_type = null() type_def = null() + traits.construct() + return this + } + fun construct(traits_in: vector): *type { + base.copy_construct(&base_type::template_type()) + parameter_types.construct() + indirection = 0 + return_type = null() + type_def = null() + traits.copy_construct(&traits_in) return this } fun construct(base_in: base_type, indirection_in: int): *type { @@ -54,6 +69,7 @@ obj type (Object) { indirection = indirection_in return_type = null() type_def = null() + traits.construct() return this } fun construct(type_def_in: *ast_node): *type { @@ -62,6 +78,7 @@ obj type (Object) { indirection = 0 return_type = null() type_def = type_def_in + traits.construct() return this } fun construct(parameter_types_in: vector<*type>, return_type_in: *type, indirection_in: int): *type { @@ -70,6 +87,7 @@ obj type (Object) { return_type = return_type_in indirection = indirection_in type_def = null() + traits.construct() return this } fun copy_construct(old: *type) { @@ -78,6 +96,7 @@ obj type (Object) { return_type = old->return_type indirection = old->indirection type_def = old->type_def + traits.copy_construct(&old->traits) } fun operator=(other: ref type) { destruct() @@ -86,29 +105,32 @@ obj type (Object) { fun destruct() { base.destruct() parameter_types.destruct() + traits.destruct() } fun operator!=(other: ref type):bool return !(*this == other); fun operator==(other: ref type):bool { if ( (return_type && other.return_type && *return_type != *other.return_type) || (return_type && !other.return_type) || (!return_type && other.return_type) ) return false - return base == other.base && parameter_types == other.parameter_types && indirection == other.indirection && type_def == other.type_def + return base == other.base && parameter_types == other.parameter_types && indirection == other.indirection && type_def == other.type_def && traits == other.traits } fun to_string(): string { - var indirection_str = string() - for (var i = 0; i < indirection; i++;) indirection_str += "*" + var all_string = string("traits:[") + for (var i = 0; i < traits.size; i++;) all_string += traits[i] + all_string += "] " + for (var i = 0; i < indirection; i++;) all_string += "*" match (base) { - base_type::none() return indirection_str + string("none") - base_type::object() return indirection_str + type_def->type_def.name - base_type::template() return indirection_str + string("template") - base_type::template_type() return indirection_str + string("template_type") - base_type::void_return() return indirection_str + string("void_return") - base_type::boolean() return indirection_str + string("boolean") - base_type::character() return indirection_str + string("character") - base_type::integer() return indirection_str + string("integer") - base_type::floating() return indirection_str + string("floating") - base_type::double_precision() return indirection_str + string("double_precision") + base_type::none() return all_string + string("none") + base_type::object() return all_string + type_def->type_def.name + base_type::template() return all_string + string("template") + base_type::template_type() return all_string + string("template_type") + base_type::void_return() return all_string + string("void_return") + base_type::boolean() return all_string + string("boolean") + base_type::character() return all_string + string("character") + base_type::integer() return all_string + string("integer") + base_type::floating() return all_string + string("floating") + base_type::double_precision() return all_string + string("double_precision") base_type::function() { - var temp = indirection_str + string("fun(") + var temp = all_string + string("fun(") parameter_types.for_each(fun(parameter_type: *type) temp += parameter_type->to_string() + ", ";) return temp + ")" + return_type->to_string() } diff --git a/tests/to_parse.krak b/tests/to_parse.krak index 4c18f68..4fdaa5f 100644 --- a/tests/to_parse.krak +++ b/tests/to_parse.krak @@ -1,5 +1,5 @@ import to_import: simple_print, simple_println, a, b, string_id - +/* fun something(param: int): Something { var to_ret.construct_with_param(param): Something return to_ret @@ -36,6 +36,8 @@ fun return_something_p_1(it: Something): Something { it.member += 11 return it } +*/ +fun id(in: T): T return in; /* fun some_function(): int return 0; fun some_other_function(in: bool): float { @@ -43,16 +45,21 @@ fun some_other_function(in: bool): float { } */ fun main(): int { - var test_methods = something(77) - var test_methods_param.construct_with_param(10090): Something - simple_println("Constructing an object and printint its member, copy-constructing it, and printing that out, then letting both be destructed") - simple_println(test_methods.member) - simple_println(test_methods_param.member) - var second_obj = test_methods - second_obj.member += 5 - simple_println(second_obj.member) - /*var some = return_something_p_1(second_obj)*/ - simple_println(return_something_p_1(second_obj).member) + var a = id(7) + simple_println(a) + + + + /*var test_methods = something(77)*/ + /*var test_methods_param.construct_with_param(10090): Something*/ + /*simple_println("Constructing an object and printint its member, copy-constructing it, and printing that out, then letting both be destructed")*/ + /*simple_println(test_methods.member)*/ + /*simple_println(test_methods_param.member)*/ + /*var second_obj = test_methods*/ + /*second_obj.member += 5*/ + /*simple_println(second_obj.member)*/ + /*[>var some = return_something_p_1(second_obj)<]*/ + /*simple_println(return_something_p_1(second_obj).member)*/ return 0 /* var a_declaration:int