diff --git a/src/ASTTransformation.cpp b/src/ASTTransformation.cpp index de1c2f3..1eb2781 100644 --- a/src/ASTTransformation.cpp +++ b/src/ASTTransformation.cpp @@ -1026,6 +1026,24 @@ void ASTTransformation::unifyType(NodeTree *syntaxType, Type type, std:: if (children.size() == 1) { (*templateTypeMap)[concatSymbolTree(children.back())] = type; } else { + if (type.typeDefinition) { + // ok, what happens here is that we get the origional type from our type. This is + // the same as the type we have now but it still has extra data from when it was instantiated + // like the templateTypeReplacement map, which we'll use. + // We get the etc part from the template we're matching against and unify it with the + // actual types the type we're unifying with used by passing it's through the templateTypeReplacement + // to get the type it was instantiated with. + auto origionalType = type.typeDefinition->getDataRef()->valueType; + auto typeTemplateDefinition = origionalType->templateDefinition; + if (typeTemplateDefinition && concatSymbolTree(getNode("scoped_identifier", children)) == concatSymbolTree(getNode("identifier", typeTemplateDefinition->getChildren()))) { + std::vector*> uninTypeInstTypes = getNodes("type", getNode("template_inst", children)); + std::vector*> typeInstTypes = getNodes("template_param", getNode("template_dec", typeTemplateDefinition->getChildren())); + for (int i = 0; i < uninTypeInstTypes.size(); i++) + unifyType(uninTypeInstTypes[i], *origionalType->templateTypeReplacement[concatSymbolTree(typeInstTypes[i])], templateTypeMap); + + return; + } + } throw "the inference just isn't good enough"; } } @@ -1333,10 +1351,6 @@ Type* ASTTransformation::typeFromTypeNode(NodeTree* typeNode, NodeTree* templateDefinition = templateClassLookup(scope, concatSymbolTree(typeNode->getChildren()[0]), templateParamInstantiationTypes); - if (templateDefinition == NULL) - std::cout << "Template definition is null!" << std::endl; - else - std::cout << "Template definition is not null!" << std::endl; std::string fullyInstantiatedName = templateDefinition->getDataRef()->symbol.getName() + "<" + instTypeString + ">"; diff --git a/tests/test_templateFuncInfr.krak b/tests/test_templateFuncInfr.krak index fc45967..138eab1 100644 --- a/tests/test_templateFuncInfr.krak +++ b/tests/test_templateFuncInfr.krak @@ -9,7 +9,6 @@ fun main():int { var fromTemplateFun = id(11) var aVec.construct(): vector aVec.addEnd(12) - //var fromTemplateFun = id(11); println(fromTemplateFun) println(idVec(aVec)) //println(idVec(aVec))