diff --git a/include/ASTTransformation.h b/include/ASTTransformation.h index e9b5d7a..dd87356 100644 --- a/include/ASTTransformation.h +++ b/include/ASTTransformation.h @@ -47,7 +47,7 @@ class ASTTransformation: public NodeTransformation { NodeTree* doFunction(NodeTree* scope, std::string lookup, std::vector*> nodes, std::map templateTypeReplacements); NodeTree* functionLookup(NodeTree* scope, std::string lookup, std::vector types); - NodeTree* templateFunctionLookup(NodeTree* scope, std::string lookup, std::vector templateInstantiationTypes, std::vector types); + NodeTree* templateFunctionLookup(NodeTree* scope, std::string lookup, std::vector* templateInstantiationTypes, std::vector types); std::vector*> scopeLookup(NodeTree* scope, std::string lookup, bool includeModules = false); std::vector*> scopeLookup(NodeTree* scope, std::string lookup, bool includeModules, std::vector*> visited); @@ -55,6 +55,8 @@ class ASTTransformation: public NodeTransformation { NodeTree* addToScope(std::string name, NodeTree* toAdd, NodeTree* addTo); Type* typeFromTypeNode(NodeTree* typeNode, NodeTree* scope, std::map templateTypeReplacements); NodeTree* templateClassLookup(NodeTree* scope, std::string name, std::vector templateInstantiationTypes); + void unifyType(NodeTree *syntaxType, Type type, std::map* templateTypeMap); + void unifyTemplateFunction(NodeTree* templateFunction, std::vector types, std::vector* templateInstantiationTypes); NodeTree* findOrInstantiateFunctionTemplate(std::vector*> children, NodeTree* scope, std::vector types, std::map templateTypeReplacements); std::map makeTemplateFunctionTypeMap(NodeTree* templateNode, std::vector types); std::vector>> makeTemplateNameTraitPairs(NodeTree* templateNode); diff --git a/src/ASTTransformation.cpp b/src/ASTTransformation.cpp index d63cba5..de1c2f3 100644 --- a/src/ASTTransformation.cpp +++ b/src/ASTTransformation.cpp @@ -375,8 +375,10 @@ NodeTree* ASTTransformation::transform(NodeTree* from, NodeTree if (types.size()) { newNode = functionLookup(scope, lookupName, types); if (newNode == NULL) { - std::cerr << "scope lookup error! Could not find " << lookupName << " in identifier (functionLookup)" << std::endl; - throw "LOOKUP ERROR: " + lookupName; + std::cout << "scope lookup failed! Could not find " << lookupName << " in identifier (functionLookup)" << std::endl; + std::cout << "(maybe this is supposted to happen because the function is a template and we're infrencing)" << std::endl; + //throw "LOOKUP ERROR: " + lookupName; + return nullptr; } } else { auto possibleMatches = scopeLookup(scope, lookupName); @@ -520,8 +522,7 @@ NodeTree* ASTTransformation::transform(NodeTree* from, NodeTree if (name == "access_operation") { std::cout << "lhs is: " << lhs->getDataRef()->toString() << std::endl; rhs = transform(children[2], lhs->getDataRef()->valueType->typeDefinition, types, templateTypeReplacements); //If an access operation, then the right side will be in the lhs's type's scope - } - else + } else rhs = transform(children[2], scope, types, templateTypeReplacements); std::string functionCallName = concatSymbolTree(children[1]); @@ -536,12 +537,17 @@ NodeTree* ASTTransformation::transform(NodeTree* from, NodeTree } return newNode; //skipChildren.insert(1); - } else if (children.size() == 2) { - //Is template instantiation - return findOrInstantiateFunctionTemplate(children, scope, types, templateTypeReplacements); - } else { - return transform(children[0], scope, types, templateTypeReplacements); //Just a promoted child, so do it instead } + if (children.size() == 1) { + newNode = transform(children[0], scope, types, templateTypeReplacements); //Just a promoted child, so do it instead + if (newNode) + return newNode; + } + // So if children.size() != 1, or that returned null because the function lookup failed, + // we try to do a template instatiation. If it had 2 children, it's an instantion, if it has 1 + // maybe it's a template instantiation we're supposed to infer the types for. Either way, we let + // findorinstantiatefunctiontemplate take care of it. + return findOrInstantiateFunctionTemplate(children, scope, types, templateTypeReplacements); } else if (name == "factor") { //Do factor here, as it has all the weird unary operators //If this is an actual part of an expression, not just a premoted child //NO SUPPORT FOR CASTING YET @@ -917,11 +923,10 @@ NodeTree* ASTTransformation::functionLookup(NodeTree* scope, s for (int j = 0; j < types.size(); j++) { Type* tmpType = children[j]->getDataRef()->valueType; //Don't worry if types don't match if it's a template type - // std::cout << "Checking for segfaults, we have" << std::endl; - // std::cout << types[j].toString() << std::endl; - // std::cout << tmpType->toString() << std::endl; - // std::cout << "Done!" << std::endl; - if (types[j] != *tmpType && tmpType->baseType != template_type_type) { + //if (types[j] != *tmpType && tmpType->baseType != template_type_type) { + // WE DO WORRY NOW B/C template type infrence is ugly and we need this to fail + // for regular function lookups so that we know to retry with a template + if (types[j] != *tmpType) { typesMatch = false; std::cout << "Types do not match between two " << lookup << " " << types[j].toString(); std::cout << " vs " << children[j]->getDataRef()->valueType->toString() << std::endl; @@ -1005,8 +1010,41 @@ NodeTree* ASTTransformation::templateClassLookup(NodeTree* sco return *mostFittingTemplates.begin(); } +void ASTTransformation::unifyType(NodeTree *syntaxType, Type type, std::map* templateTypeMap) { + // Ok, 3 options for syntaxType here. + // 1) This a basic type. (int, or object, etc) + // then check to see if it's the same as our type + // 2) This is a template type type (i.e. T) + // match! set up templateTypeMap[T] -> type + // 3) This some sort of instantiated template + // a) instantiated with some other type (i.e. vector) + // this will be a bit of a pain + // b) instantiated with a template type type (i.e. vector) + // this will be a bit of a pain too + + auto children = syntaxType->getChildren(); + if (children.size() == 1) { + (*templateTypeMap)[concatSymbolTree(children.back())] = type; + } else { + throw "the inference just isn't good enough"; + } +} + +void ASTTransformation::unifyTemplateFunction(NodeTree* templateFunction, std::vector types, std::vector* templateInstantiationTypes) { + NodeTree* templateSyntaxTree = templateFunction->getDataRef()->valueType->templateDefinition; + std::vector*> templateParameters = getNodes("typed_parameter", templateSyntaxTree); + if (templateParameters.size() != types.size()) + return; + std::map templateTypeMap; + for (int i = 0; i < types.size(); i++) + unifyType(getNode("type", templateParameters[i]), types[i], &templateTypeMap); + for (auto instantiationParam : getNodes("template_param", getNode("template_dec", templateSyntaxTree))) + templateInstantiationTypes->push_back(templateTypeMap[concatSymbolTree(instantiationParam)].clone()); +} + //Lookup function for template functions. It has some extra concerns compared to function lookup, namely traits -NodeTree* ASTTransformation::templateFunctionLookup(NodeTree* scope, std::string lookup, std::vector templateInstantiationTypes, std::vector types) { +NodeTree* ASTTransformation::templateFunctionLookup(NodeTree* scope, std::string lookup, std::vector* templateInstantiationTypes, std::vector types) { + std::map*, std::vector> templateInstantiationTypesPerFunction; std::set*> mostFittingTemplates; int bestNumTraitsSatisfied = -1; auto possibleMatches = scopeLookup(scope, lookup); @@ -1019,10 +1057,14 @@ NodeTree* ASTTransformation::templateFunctionLookup(NodeTree* std::cout << "Not a template, skipping" << std::endl; continue; } - + // If template instantiation was explicit, use those types. Otherwise, unify to find them + if (templateInstantiationTypes->size()) + templateInstantiationTypesPerFunction[i] = *templateInstantiationTypes; + else + unifyTemplateFunction(i, types, &templateInstantiationTypesPerFunction[i]); auto nameTraitsPairs = makeTemplateNameTraitPairs(templateSyntaxTree->getChildren()[1]); //Check if sizes match between the placeholder and actual template types - if (nameTraitsPairs.size() != templateInstantiationTypes.size()) + if (nameTraitsPairs.size() != templateInstantiationTypesPerFunction[i].size()) continue; std::map typeMap; @@ -1030,23 +1072,23 @@ NodeTree* ASTTransformation::templateFunctionLookup(NodeTree* int typeIndex = 0; int currentTraitsSatisfied = 0; for (auto j : nameTraitsPairs) { - if (!subset(j.second, templateInstantiationTypes[typeIndex]->traits)) { + if (!subset(j.second, templateInstantiationTypesPerFunction[i][typeIndex]->traits)) { traitsEqual = false; - std::cout << "Traits not a subset for " << j.first << " and " << templateInstantiationTypes[typeIndex]->toString() << ": "; + std::cout << "Traits not a subset for " << j.first << " and " << templateInstantiationTypesPerFunction[i][typeIndex]->toString() << ": "; std::copy(j.second.begin(), j.second.end(), std::ostream_iterator(std::cout, " ")); std::cout << " vs "; - std::copy(templateInstantiationTypes[typeIndex]->traits.begin(), templateInstantiationTypes[typeIndex]->traits.end(), std::ostream_iterator(std::cout, " ")); + std::copy(templateInstantiationTypesPerFunction[i][typeIndex]->traits.begin(), templateInstantiationTypesPerFunction[i][typeIndex]->traits.end(), std::ostream_iterator(std::cout, " ")); std::cout << std::endl; break; } else { - std::cout << "Traits ARE a subset for " << j.first << " and " << templateInstantiationTypes[typeIndex]->toString() << ": "; + std::cout << "Traits ARE a subset for " << j.first << " and " << templateInstantiationTypesPerFunction[i][typeIndex]->toString() << ": "; std::copy(j.second.begin(), j.second.end(), std::ostream_iterator(std::cout, " ")); std::cout << " vs "; - std::copy(templateInstantiationTypes[typeIndex]->traits.begin(), templateInstantiationTypes[typeIndex]->traits.end(), std::ostream_iterator(std::cout, " ")); + std::copy(templateInstantiationTypesPerFunction[i][typeIndex]->traits.begin(), templateInstantiationTypesPerFunction[i][typeIndex]->traits.end(), std::ostream_iterator(std::cout, " ")); std::cout << std::endl; } //As we go, build up the typeMap for when we transform the parameters for parameter checking - typeMap[j.first] = templateInstantiationTypes[typeIndex]; + typeMap[j.first] = templateInstantiationTypesPerFunction[i][typeIndex]; currentTraitsSatisfied += j.second.size(); typeIndex++; } @@ -1087,6 +1129,11 @@ NodeTree* ASTTransformation::templateFunctionLookup(NodeTree* std::cerr << "Multiple template functions fit with equal number of traits satisfied for " << lookup << "!" << std::endl; throw "Multiple matching template functions"; } + // Assign our most fitting instantiation types to what we were passed in + // if it was empty + if (templateInstantiationTypes->size() == 0) + *templateInstantiationTypes = templateInstantiationTypesPerFunction[*mostFittingTemplates.begin()]; + std::cout << *mostFittingTemplates.begin() << std::endl; return *mostFittingTemplates.begin(); } @@ -1263,7 +1310,7 @@ Type* ASTTransformation::typeFromTypeNode(NodeTree* typeNode, NodeTree |T| fun(|vec| a) { return a.at(0); } + * fun example(a:vec):T { return a.at(0); } * etc *******************************************************************************/ if (instType->baseType == template_type_type) @@ -1351,24 +1398,34 @@ NodeTree* ASTTransformation::findOrInstantiateFunctionTemplate(std::vec //First look to see if we can find this already instantiated std::cout << "\n\nFinding or instantiating templated function\n\n" << std::endl; std::string functionName = concatSymbolTree(children[0]); - - auto unsliced = children[1]->getChildren(); - std::vector*> templateParamInstantiationNodes = slice(unsliced, 1 , -2, 2);//skip <, >, and commas - std::string instTypeString = ""; + std::string fullyInstantiatedName; std::vector templateActualTypes; - for (int i = 0; i < templateParamInstantiationNodes.size(); i++) { - Type* instType = typeFromTypeNode(templateParamInstantiationNodes[i],scope, templateTypeReplacements); - instTypeString += (instTypeString == "" ? instType->toString() : "," + instType->toString()); - templateActualTypes.push_back(instType); - } - std::cout << "Size: " << templateParamInstantiationNodes.size() << std::endl; - std::string fullyInstantiatedName = functionName + "<" + instTypeString + ">"; - std::cout << "Looking for " << fullyInstantiatedName << std::endl; + NodeTree* templateDefinition = NULL; + // Are we supposed to infer our instantiation, or not? If we have only one child we're inferring as we don't + // have the actual instantiation part. If do have the instantiation part, then we'll use that. + // Note that as a part o finferring the instantiation we already find the template, so we make that + // condtitional too (templateDefinition) + if (children.size() == 1) { + // templateFunctionLookup adds the actual types to templateActualTypes if it's currently empty + templateDefinition = templateFunctionLookup(scope, functionName, &templateActualTypes, types); + } else { + auto unsliced = children[1]->getChildren(); + std::vector*> templateParamInstantiationNodes = slice(unsliced, 1 , -2, 2);//skip <, >, and commas + std::string instTypeString = ""; + for (int i = 0; i < templateParamInstantiationNodes.size(); i++) { + Type* instType = typeFromTypeNode(templateParamInstantiationNodes[i],scope, templateTypeReplacements); + instTypeString += (instTypeString == "" ? instType->toString() : "," + instType->toString()); + templateActualTypes.push_back(instType); + } + std::cout << "Size: " << templateParamInstantiationNodes.size() << std::endl; + fullyInstantiatedName = functionName + "<" + instTypeString + ">"; + std::cout << "Looking for " << fullyInstantiatedName << std::endl; + } std::cout << "Types are : "; - for (auto i : types) - std::cout << " " << i.toString(); - std::cout << std::endl; + for (auto i : types) + std::cout << " " << i.toString(); + std::cout << std::endl; NodeTree* instantiatedFunction = functionLookup(scope, fullyInstantiatedName, types); //If it already exists, return it @@ -1386,7 +1443,11 @@ NodeTree* ASTTransformation::findOrInstantiateFunctionTemplate(std::vec //Otherwise, we're going to instantiate it //Find the template definitions - NodeTree* templateDefinition = templateFunctionLookup(scope, functionName, templateActualTypes, types); + // templateFunctionLookup adds the actual types to templateActualTypes if it's currently empty + // by here, it's not as either we had the instantiation already or we figured out out before + // and are not actually doing this call + if (!templateDefinition) + templateDefinition = templateFunctionLookup(scope, functionName, &templateActualTypes, types); if (templateDefinition == NULL) { std::cout << functionName << " search turned up null, returing null" << std::endl; return NULL; @@ -1402,17 +1463,12 @@ NodeTree* ASTTransformation::findOrInstantiateFunctionTemplate(std::vec std::cout << std::endl; instantiatedFunction = new NodeTree("function", ASTData(function, Symbol(fullyInstantiatedName, true), typeFromTypeNode(templateChildren[templateChildren.size()-2], scope, newTemplateTypeReplacement))); - //scope->getDataRef()->scope[fullyInstantiatedName].push_back(instantiatedFunction); - //instantiatedFunction->getDataRef()->scope["~enclosing_scope"].push_back(templateDefinition->getDataRef()->scope["~enclosing_scope"][0]); //Instantiated Template Function's scope is it's template's definition's scope addToScope("~enclosing_scope", templateDefinition->getDataRef()->scope["~enclosing_scope"][0], instantiatedFunction); // Arrrrrgh this has a hard time working because the functions will need to see their parameter once they are emitted as C. // HAHAHAHAHA DOESN'T MATTER ALL ONE C FILE NOW, swap back to old way auto templateTopScope = getUpperTranslationUnit(templateDefinition); - //templateTopScope->getDataRef()->scope[fullyInstantiatedName].push_back(instantiatedFunction); addToScope(fullyInstantiatedName, instantiatedFunction, templateTopScope); templateTopScope->addChild(instantiatedFunction); // Add this object the the highest scope's - //topScope->getDataRef()->scope[fullyInstantiatedName].push_back(instantiatedFunction); - //topScope->addChild(instantiatedFunction); //Add this object the the highest scope's std::set skipChildren; skipChildren.insert(0); diff --git a/tests/test_templateFuncInfr.expected_results b/tests/test_templateFuncInfr.expected_results new file mode 100644 index 0000000..c1f3abd --- /dev/null +++ b/tests/test_templateFuncInfr.expected_results @@ -0,0 +1,2 @@ +11 +12 diff --git a/tests/test_templateFuncInfr.krak b/tests/test_templateFuncInfr.krak new file mode 100644 index 0000000..fc45967 --- /dev/null +++ b/tests/test_templateFuncInfr.krak @@ -0,0 +1,17 @@ +import io:* +import mem:* +import vector:* + +fun id(in: T): T { return in; } +fun idVec(in: vector): T { return in.get(0); } + +fun main():int { + var fromTemplateFun = id(11) + var aVec.construct(): vector + aVec.addEnd(12) + //var fromTemplateFun = id(11); + println(fromTemplateFun) + println(idVec(aVec)) + //println(idVec(aVec)) + return 0 +} diff --git a/tests/test_typeInfr.krak b/tests/test_typeInfr.krak index aae8edc..e0be334 100644 --- a/tests/test_typeInfr.krak +++ b/tests/test_typeInfr.krak @@ -28,14 +28,14 @@ fun main():int { var add = 9 + 3 var flt = 9.3 var msg = retMessage() - var fromTemplateFun = id(11); + var fromTemplateFun = id(11); var vec = new>()->construct() vec->addEnd(avar) var src: CustomObj src.data = 80 var dst = src - var throughTemp = inFun("complicated"); + var throughTemp = inFun("complicated"); println(str) println(avar)