diff --git a/stdlib/ast_nodes.krak b/stdlib/ast_nodes.krak index 9994bd5..35ce1c0 100644 --- a/stdlib/ast_nodes.krak +++ b/stdlib/ast_nodes.krak @@ -321,6 +321,7 @@ obj function_template (Object) { var instantiated: vector<*ast_node> var template_types: vector var template_type_replacements: map + var instantiated_map: map, *ast_node> var scope: map> fun construct(name_in: string, syntax_node_in: *tree, template_types_in: vector, template_type_replacements_in: map): *function_template { name.copy_construct(&name_in) @@ -328,6 +329,7 @@ obj function_template (Object) { instantiated.construct() template_types.copy_construct(&template_types_in) template_type_replacements.copy_construct(&template_type_replacements_in) + instantiated_map.construct() scope.construct() return this } @@ -337,6 +339,7 @@ obj function_template (Object) { instantiated.copy_construct(&old->instantiated) template_types.copy_construct(&old->template_types) template_type_replacements.copy_construct(&old->template_type_replacements) + instantiated_map.copy_construct(&old->instantiated_map) scope.copy_construct(&old->scope) } fun destruct() { @@ -344,6 +347,7 @@ obj function_template (Object) { instantiated.destruct() template_types.destruct() template_type_replacements.destruct() + instantiated_map.destruct() scope.destruct() } fun operator=(other: ref function_template) { @@ -352,7 +356,8 @@ obj function_template (Object) { } fun operator==(other: ref function_template): bool { return name == name && syntax_node == other.syntax_node && instantiated == other.instantiated && - scope == other.scope && template_types == other.template_types && template_type_replacements == other.template_type_replacements + scope == other.scope && template_types == other.template_types && template_type_replacements == other.template_type_replacements && + instantiated_map == other.instantiated_map } } fun ast_code_block_ptr(): *ast_node { diff --git a/stdlib/ast_transformation.krak b/stdlib/ast_transformation.krak index c289165..20acda5 100644 --- a/stdlib/ast_transformation.krak +++ b/stdlib/ast_transformation.krak @@ -495,32 +495,43 @@ fun find_or_instantiate_function_template(identifier: *tree, template_in var name = concat_symbol_tree(identifier) var results = scope_lookup(name, scope) var real_types = get_nodes("type", template_inst).map(fun(t: *tree): *type return transform_type(t, scope, map());) + var real_types_deref = real_types.map(fun(t:*type):type return *t;) for (var i = 0; i < results.size; i++;) { if (is_function_template(results[i])) { var template_types = results[i]->function_template.template_types var template_type_replacements = results[i]->function_template.template_type_replacements if (template_types.size != real_types.size) continue - println("FOR FIND OR INSTATINTATE PREEEE") - template_type_replacements.for_each(fun(key: string, value: *type) println(string("MAP: ") + key + " : " + value->to_string());) - println("MAP DONE") - for (var j = 0; j < template_types.size; j++;) { - template_type_replacements[template_types[j]] = real_types[j] - println("Just made") - println(template_types[j]) - println("equal to") - println(real_types[j]->to_string()) + // check if already instantiated + var inst_func = null() + if (results[i]->function_template.instantiated_map.contains_key(real_types_deref)) { + println("USING CACHED TEMPLATE FUNCITON") + inst_func = results[i]->function_template.instantiated_map[real_types_deref] + } else { + println("FOR FIND OR INSTATINTATE PREEEE") + template_type_replacements.for_each(fun(key: string, value: *type) println(string("MAP: ") + key + " : " + value->to_string());) + println("MAP DONE") + for (var j = 0; j < template_types.size; j++;) { + template_type_replacements[template_types[j]] = real_types[j] + println("Just made") + println(template_types[j]) + println("equal to") + println(real_types[j]->to_string()) + } + + println("FOR FIND OR INSTATINTATE") + template_type_replacements.for_each(fun(key: string, value: *type) println(string("MAP: ") + key + " : " + value->to_string());) + println("MAP DONE") + + inst_func = second_pass_function(results[i]->function_template.syntax_node, results[i], template_type_replacements, false) + // and fully instantiate it + inst_func->function.body_statement = transform_statement(get_node("statement", results[i]->function_template.syntax_node), inst_func) + // add to instantiated_map so we only instantiate with a paticular set of types once + results[i]->function_template.instantiated_map.set(real_types_deref, inst_func) } - - println("FOR FIND OR INSTATINTATE") - template_type_replacements.for_each(fun(key: string, value: *type) println(string("MAP: ") + key + " : " + value->to_string());) - println("MAP DONE") - - var part_instantiated = second_pass_function(results[i]->function_template.syntax_node, results[i], template_type_replacements, false) - // and fully instantiate it - part_instantiated->function.body_statement = transform_statement(get_node("statement", results[i]->function_template.syntax_node), part_instantiated) - if (function_satisfies_params(part_instantiated, param_types)) - return part_instantiated + + if (function_satisfies_params(inst_func, param_types)) + return inst_func } } println("FREAK OUT MACHINE") diff --git a/stdlib/type.krak b/stdlib/type.krak index f2a2d33..1cc04ff 100644 --- a/stdlib/type.krak +++ b/stdlib/type.krak @@ -109,9 +109,13 @@ obj type (Object) { } fun operator!=(other: ref type):bool return !(*this == other); fun operator==(other: ref type):bool { - if ( (return_type && other.return_type && *return_type != *other.return_type) || (return_type && !other.return_type) || (!return_type && other.return_type) ) + if (parameter_types.size != other.parameter_types.size) return false - return base == other.base && parameter_types == other.parameter_types && indirection == other.indirection && type_def == other.type_def && traits == other.traits + for (var i = 0; i < parameter_types.size; i++;) + if (!deref_equality(parameter_types[i], other.parameter_types[i])) + return false + return base == other.base && deref_equality(return_type, other.return_type) && + indirection == other.indirection && deref_equality(type_def, other.type_def) && traits == other.traits } fun to_string(): string { var all_string = string("traits:[") diff --git a/stdlib/util.krak b/stdlib/util.krak index 6451f71..7e4bfd5 100644 --- a/stdlib/util.krak +++ b/stdlib/util.krak @@ -5,6 +5,12 @@ import serialize // maybe my favorite function fun do_nothing() {} +fun deref_equality(a: *T, b: *T): bool { + if ( (a && b && !(*a == *b)) || (a && !b) || (!a && b) ) + return false + return true +} + fun max(a: T, b: T): T { if (a > b) return a; diff --git a/tests/to_parse.krak b/tests/to_parse.krak index 32e7472..afcfdad 100644 --- a/tests/to_parse.krak +++ b/tests/to_parse.krak @@ -48,6 +48,7 @@ fun some_other_function(in: bool): float { fun main(): int { var a = id(7) simple_println(a) + var b = id(8) /*var b = id<*char>("Double down time")*/ /*simple_println(b)*/ simple_println(id("Double down time"))