diff --git a/include/ASTTransformation.h b/include/ASTTransformation.h index 40bcd31..e9b5d7a 100644 --- a/include/ASTTransformation.h +++ b/include/ASTTransformation.h @@ -19,6 +19,11 @@ class ASTTransformation: public NodeTransformation { ASTTransformation(Importer* importerIn); ~ASTTransformation(); + NodeTree* getNode(std::string lookup, std::vector*> nodes); + NodeTree* getNode(std::string lookup, NodeTree* parent); + std::vector*> getNodes(std::string lookup, std::vector*> nodes); + std::vector*> getNodes(std::string lookup, NodeTree* parent); + //First pass defines all type_defs (objects and ailises) NodeTree* firstPass(std::string fileName, NodeTree* parseTree); std::set parseTraits(NodeTree* traitsNode); diff --git a/krakenGrammer.kgm b/krakenGrammer.kgm index e22add0..28fcc5a 100644 --- a/krakenGrammer.kgm +++ b/krakenGrammer.kgm @@ -104,7 +104,8 @@ number = integer | floating_literal ; access_operation = unarad "." identifier | unarad "->" identifier ; assignment_statement = factor WS "=" WS boolean_expression | factor WS "\+=" WS boolean_expression | factor WS "-=" WS boolean_expression | factor WS "\*=" WS boolean_expression | factor WS "/=" WS boolean_expression ; -declaration_statement = "var" WS identifier WS dec_type WS "=" WS boolean_expression | "var" WS identifier WS dec_type | "var" WS identifier WS "." WS identifier WS "\(" WS opt_parameter_list WS "\)" WS dec_type ; +# if it's being assigned to, we allow type inferencing +declaration_statement = "var" WS identifier WS "=" WS boolean_expression | "var" WS identifier WS dec_type WS "=" WS boolean_expression | "var" WS identifier WS dec_type | "var" WS identifier WS "." WS identifier WS "\(" WS opt_parameter_list WS "\)" WS dec_type ; hexadecimal = "0x(1|2|3|4|5|6|7|8|9|a|b|c|d|e|f)+" ; integer = numeric | hexadecimal ; floating_literal = numeric "." numeric ; diff --git a/src/ASTTransformation.cpp b/src/ASTTransformation.cpp index 30f8346..d63cba5 100644 --- a/src/ASTTransformation.cpp +++ b/src/ASTTransformation.cpp @@ -39,6 +39,33 @@ ASTTransformation::ASTTransformation(Importer *importerIn) { ASTTransformation::~ASTTransformation() { } +NodeTree* ASTTransformation::getNode(std::string lookup, NodeTree* parent) { + auto results = getNodes(lookup, parent); + if (results.size() > 1) + throw "too many results"; + if (results.size()) + return results[0]; + return nullptr; +} +NodeTree* ASTTransformation::getNode(std::string lookup, std::vector*> nodes) { + auto results = getNodes(lookup, nodes); + if (results.size() > 1) + throw "too many results"; + if (results.size()) + return results[0]; + return nullptr; +} +std::vector*> ASTTransformation::getNodes(std::string lookup, NodeTree* parent) { + return getNodes(lookup, parent->getChildren()); +} +std::vector*> ASTTransformation::getNodes(std::string lookup, std::vector*> nodes) { + std::vector*> results; + for (auto i : nodes) + if (i->getDataRef()->getName() == lookup) + results.push_back(i); + return results; +} + //First pass defines all type_defs (objects and ailises), and if_comp/simple_passthrough NodeTree* ASTTransformation::firstPass(std::string fileName, NodeTree* parseTree) { NodeTree* translationUnit = new NodeTree("translation_unit", ASTData(translation_unit, Symbol(fileName, false))); @@ -580,20 +607,24 @@ NodeTree* ASTTransformation::transform(NodeTree* from, NodeTree // NodeTree* newIdentifier = transform(children[1], scope); //Transform the identifier // newIdentifier->getDataRef()->valueType = Type(concatSymbolTree(children[0]));//set the type of the identifier std::string newIdentifierStr = concatSymbolTree(children[0]); - Type* identifierType; - if (children.size() > 1 && concatSymbolTree(children[1]) == ".") - identifierType = typeFromTypeNode(children.back(), scope, templateTypeReplacements); - else - identifierType = typeFromTypeNode(children[2], scope, templateTypeReplacements); + NodeTree* typeSyntaxNode = getNode("type", children); + Type* identifierType = typeSyntaxNode ? typeFromTypeNode(typeSyntaxNode, scope, templateTypeReplacements) : nullptr; + //if (children.size() > 1 && concatSymbolTree(children[1]) == ".") + //identifierType = typeFromTypeNode(children.back(), scope, templateTypeReplacements); + //else + //identifierType = typeFromTypeNode(children[2], scope, templateTypeReplacements); - std::cout << "Declaring an identifier " << newIdentifierStr << " to be of type " << identifierType->toString() << std::endl; - NodeTree* newIdentifier = new NodeTree("identifier", ASTData(identifier, Symbol(newIdentifierStr, true), identifierType)); - addToScope(newIdentifierStr, newIdentifier, scope); - addToScope("~enclosing_scope", scope, newNode); - addToScope("~enclosing_scope", newNode, newIdentifier); - newNode->addChild(newIdentifier); + if (identifierType) + std::cout << "Declaring an identifier " << newIdentifierStr << " to be of type " << identifierType->toString() << std::endl; + else + std::cout << "Declaring an identifier " << newIdentifierStr << " with type to be type-inferenced " << std::endl; if (children.size() > 1 && concatSymbolTree(children[1]) == ".") { + NodeTree* newIdentifier = new NodeTree("identifier", ASTData(identifier, Symbol(newIdentifierStr, true), identifierType)); + addToScope(newIdentifierStr, newIdentifier, scope); + addToScope("~enclosing_scope", scope, newNode); + addToScope("~enclosing_scope", newNode, newIdentifier); + newNode->addChild(newIdentifier); //A bit of a special case for declarations - if there's anything after just the normal 1 node declaration, it's either //an expression that is assigned to the declaration (int a = 4;) or a member call (Object a.constructAThing()) //This code is a simplified version of the code in function_call with respect to access_operation. @@ -615,9 +646,28 @@ NodeTree* ASTTransformation::transform(NodeTree* from, NodeTree return newNode; } - skipChildren.insert(0); //These, the type and the identifier, have been taken care of. - skipChildren.insert(2); - newNode->addChildren(transformChildren(children, skipChildren, scope, types, templateTypeReplacements)); + //skipChildren.insert(0); //These, the type and the identifier, have been taken care of. + //skipChildren.insert(2); + //auto transChildren = transformChildren(children, skipChildren, scope, types, templateTypeReplacements); + auto boolExp = getNode("boolean_expression", children); + NodeTree* toAssign = boolExp ? transform(boolExp, scope, types, templateTypeReplacements) : nullptr; + // for type inferencing + if (!identifierType) { + if (toAssign) + identifierType = toAssign->getDataRef()->valueType; + else + throw "have to inference but no expression"; + } + + NodeTree* newIdentifier = new NodeTree("identifier", ASTData(identifier, Symbol(newIdentifierStr, true), identifierType)); + addToScope(newIdentifierStr, newIdentifier, scope); + addToScope("~enclosing_scope", scope, newNode); + addToScope("~enclosing_scope", newNode, newIdentifier); + + newNode->addChild(newIdentifier); + if (toAssign) + newNode->addChild(toAssign); + //newNode->addChildren(transChildren); return newNode; } else if (name == "if_comp") { newNode = new NodeTree(name, ASTData(if_comp)); diff --git a/tests/test_typeInfr.expected_results b/tests/test_typeInfr.expected_results new file mode 100644 index 0000000..259c019 --- /dev/null +++ b/tests/test_typeInfr.expected_results @@ -0,0 +1,10 @@ +hello +9 +27 +12 +9.300000 +I do like type inference +11 +9 +80 +complicated diff --git a/tests/test_typeInfr.krak b/tests/test_typeInfr.krak new file mode 100644 index 0000000..aae8edc --- /dev/null +++ b/tests/test_typeInfr.krak @@ -0,0 +1,51 @@ +import io:* +import mem:* +import vector:* + +fun retMessage(): char* { + return "I do like type inference" +} +fun id(in: T): T { return in; } + +typedef CustomObj { + var data: int; +} + +typedef CustomObjTmplt { + var data: T; +} +fun inFun(in: T):T { + var src: CustomObjTmplt + src.data = in + var dst = src + return dst.data +} + +fun main():int { + var str = "hello" + var avar = 9 + var mul = 9 * 3 + var add = 9 + 3 + var flt = 9.3 + var msg = retMessage() + var fromTemplateFun = id(11); + var vec = new>()->construct() + vec->addEnd(avar) + + var src: CustomObj + src.data = 80 + var dst = src + var throughTemp = inFun("complicated"); + + println(str) + println(avar) + println(mul) + println(add) + println(flt) + println(msg) + println(fromTemplateFun) + println(vec->at(0)) + println(dst.data) + println(throughTemp) + return 0 +}