From 1e76bf2772d0733071be1ca49442ea99d51e1576 Mon Sep 17 00:00:00 2001 From: Nathan Braswell Date: Fri, 26 Jun 2015 13:29:37 -0400 Subject: [PATCH] Closures work\! --- include/ASTTransformation.h | 5 ++- include/CGenerator.h | 5 +++ src/ASTTransformation.cpp | 65 +++++++++++++++--------------- src/CGenerator.cpp | 32 +++++++++++++-- src/util.cpp | 2 +- stdlib/regex.krak | 28 ++++--------- tests/test_lambda.expected_results | 5 +++ tests/test_lambda.krak | 33 +++++++++------ 8 files changed, 102 insertions(+), 73 deletions(-) diff --git a/include/ASTTransformation.h b/include/ASTTransformation.h index 6288db1..41e0c0c 100644 --- a/include/ASTTransformation.h +++ b/include/ASTTransformation.h @@ -42,8 +42,6 @@ class ASTTransformation: public NodeTransformation { virtual NodeTree* transform(NodeTree* from); NodeTree* transform(NodeTree* from, NodeTree* scope, std::vector types, bool limitToFunction, std::map templateTypeReplacements); std::vector*> transformChildren(std::vector*> children, std::set skipChildren, NodeTree* scope, std::vector types, bool limitToFunction, std::map templateTypeReplacements); - std::vector mapNodesToTypes(std::vector*> nodes); - std::vector mapNodesToTypePointers(std::vector*> nodes); std::string concatSymbolTree(NodeTree* root); NodeTree* doFunction(NodeTree* scope, std::string lookup, std::vector*> nodes, std::map templateTypeReplacements); @@ -73,4 +71,7 @@ class ASTTransformation: public NodeTransformation { int lambdaID = 0; }; +std::vector mapNodesToTypes(std::vector*> nodes); +std::vector mapNodesToTypePointers(std::vector*> nodes); + #endif diff --git a/include/CGenerator.h b/include/CGenerator.h index 1705187..c18b57b 100644 --- a/include/CGenerator.h +++ b/include/CGenerator.h @@ -12,6 +12,8 @@ #include "NodeTree.h" #include "ASTData.h" #include "Type.h" +// for mapNodesToTypes +#include "ASTTransformation.h" #include "util.h" #include "Poset.h" @@ -31,6 +33,8 @@ class CGenerator { std::pair generateTranslationUnit(std::string name, std::map*> ASTs); CCodeTriple generate(NodeTree* from, NodeTree* enclosingObject = NULL, bool justFuncName = false, NodeTree* enclosingFunction = NULL); std::string generateAliasChains(std::map*> ASTs, NodeTree* definition); + + std::string closureStructType(std::set*> closedVariables); std::string ValueTypeToCType(Type *type, std::string, ClosureTypeSpecialType closureSpecial = ClosureTypeRegularNone); std::string ValueTypeToCTypeDecoration(Type *type, ClosureTypeSpecialType closureSpecial = ClosureTypeRegularNone); std::string ValueTypeToCTypeThingHelper(Type *type, std::string ptrStr, ClosureTypeSpecialType closureSpecial); @@ -52,6 +56,7 @@ class CGenerator { std::string linkerString; std::string functionTypedefString; std::map> functionTypedefMap; + std::map*>, std::string> closureStructMap; std::vector*>> distructDoubleStack; std::stack loopDistructStackDepth; std::vector*>> deferDoubleStack; diff --git a/src/ASTTransformation.cpp b/src/ASTTransformation.cpp index cd00a3b..fa99973 100644 --- a/src/ASTTransformation.cpp +++ b/src/ASTTransformation.cpp @@ -794,26 +794,6 @@ std::vector*> ASTTransformation::transformChildren(std::vector return transformedChildren; } -//Extract types from already transformed nodes -std::vector ASTTransformation::mapNodesToTypePointers(std::vector*> nodes) { - std::vector types; - for (auto i : nodes) { - std::cout << i->getDataRef()->toString() << std::endl; - types.push_back((i->getDataRef()->valueType)); - } - return types; -} - -//Extract types from already transformed nodes -std::vector ASTTransformation::mapNodesToTypes(std::vector*> nodes) { - std::vector types; - for (auto i : nodes) { - std::cout << i->getDataRef()->toString() << std::endl; - types.push_back(*(i->getDataRef()->valueType)); - } - return types; -} - //Simple way to extract strings from syntax trees. Used often for identifiers, strings, types std::string ASTTransformation::concatSymbolTree(NodeTree* root) { std::string concatString; @@ -929,24 +909,24 @@ bool ASTTransformation::inScopeChain(NodeTree* node, NodeTree* // used to calculate the closedvariables for closures std::set*> ASTTransformation::findVariablesToClose(NodeTree* func, NodeTree* stat) { std::set*> closed; - for (auto child: stat->getChildren()) { //enum ASTType {undef, translation_unit, interpreter_directive, import, identifier, type_def, //function, code_block, typed_parameter, expression, boolean_expression, statement, //if_statement, while_loop, for_loop, return_statement, break_statement, continue_statement, defer_statement, //assignment_statement, declaration_statement, if_comp, simple_passthrough, passthrough_params, //in_passthrough_params, out_passthrough_params, opt_string, param_assign, function_call, value}; - if (child->getDataRef()->type == function || child->getDataRef()->type == translation_unit - || child->getDataRef()->type == type_def || child->getDataRef()->type == value - ) - continue; - if (child->getDataRef()->type == function_call && (child->getDataRef()->symbol.getName() == "." || child->getDataRef()->symbol.getName() == "->")) { - // only search on the left side of access operators like . and -> - auto recClosed = findVariablesToClose(func, child->getChildren().front()); - closed.insert(recClosed.begin(), recClosed.end()); - continue; - } - if (child->getDataRef()->type == identifier && !inScopeChain(child, func)) - closed.insert(child); + if (stat->getDataRef()->type == function || stat->getDataRef()->type == translation_unit + || stat->getDataRef()->type == type_def || stat->getDataRef()->type == value + ) + return closed; + if (stat->getDataRef()->type == function_call && (stat->getDataRef()->symbol.getName() == "." || stat->getDataRef()->symbol.getName() == "->")) { + // only search on the left side of access operators like . and -> + auto recClosed = findVariablesToClose(func, stat->getChildren()[1]); + closed.insert(recClosed.begin(), recClosed.end()); + return closed; + } + if (stat->getDataRef()->type == identifier && !inScopeChain(stat, func)) + closed.insert(stat); + for (auto child: stat->getChildren()) { auto recClosed = findVariablesToClose(func, child); closed.insert(recClosed.begin(), recClosed.end()); } @@ -1671,3 +1651,22 @@ NodeTree* ASTTransformation::addToScope(std::string name, NodeTree mapNodesToTypePointers(std::vector*> nodes) { + std::vector types; + for (auto i : nodes) { + std::cout << i->getDataRef()->toString() << std::endl; + types.push_back((i->getDataRef()->valueType)); + } + return types; +} + +//Extract types from already transformed nodes +std::vector mapNodesToTypes(std::vector*> nodes) { + std::vector types; + for (auto i : nodes) { + std::cout << i->getDataRef()->toString() << std::endl; + types.push_back(*(i->getDataRef()->valueType)); + } + return types; +} diff --git a/src/CGenerator.cpp b/src/CGenerator.cpp index 05acefd..bf4bf12 100644 --- a/src/CGenerator.cpp +++ b/src/CGenerator.cpp @@ -200,7 +200,7 @@ std::pair CGenerator::generateTranslationUnit(std::str else { std::string nameDecoration, parameters; if (declarationData.closedVariables.size()) - parameters += "struct closed *closed_varibles"; + parameters += closureStructType(declarationData.closedVariables) + "*"; for (int j = 0; j < decChildren.size()-1; j++) { if (j > 0 || declarationData.closedVariables.size() ) parameters += ", "; @@ -319,7 +319,7 @@ CCodeTriple CGenerator::generate(NodeTree* from, NodeTree* enc std::string nameDecoration, parameters; if (data.closedVariables.size()) - parameters += "struct closed *closed_varibles"; + parameters += closureStructType(data.closedVariables) + " *closed_varibles"; for (int j = 0; j < children.size()-1; j++) { if (j > 0 || data.closedVariables.size()) parameters += ", "; @@ -336,8 +336,16 @@ CCodeTriple CGenerator::generate(NodeTree* from, NodeTree* enc funcName += CifyName(data.symbol.getName() + nameDecoration); if (from->getDataRef()->closedVariables.size()) { std::string tmpStruct = "closureStruct" + getID(); - output.preValue += "struct specialClosure " + tmpStruct + ";\n"; - output += "("+ ValueTypeToCType(data.valueType, "") +"){" + funcName + ", &" + tmpStruct + "}"; + output.preValue += closureStructType(data.closedVariables) + " " + tmpStruct + " = {"; + bool notFirst = false; + for (auto var : data.closedVariables) { + if (notFirst) + output.preValue += ", "; + notFirst = true; + output.preValue += "." + scopePrefix(var) + var->getDataRef()->symbol.getName() + " = &" + scopePrefix(var) + var->getDataRef()->symbol.getName(); + } + output.preValue += "};\n"; + output += "("+ ValueTypeToCType(data.valueType, "") +"){(void*)" + funcName + ", &" + tmpStruct + "}"; } else { output += "("+ ValueTypeToCType(data.valueType, "") +"){" + funcName + ", NULL}"; } @@ -802,6 +810,22 @@ std::string CGenerator::emitDestructors(std::vector*> identifi return destructorString; } +std::string CGenerator::closureStructType(std::set*> closedVariables) { + auto it = closureStructMap.find(closedVariables); + if (it != closureStructMap.end()) + return it->second; + std::string typedefString = "typedef struct { "; + // note the increased indirection b/c we're using references to what we closed over + for (auto var : closedVariables) { + auto tmp = var->getDataRef()->valueType->withIncreasedIndirection(); + typedefString += ValueTypeToCType(&tmp, scopePrefix(var) + var->getDataRef()->symbol.getName()) + ";"; + } + std::string structName = "closureStructType" + getID(); + typedefString += " } " + structName + ";\n"; + functionTypedefString += typedefString; + closureStructMap[closedVariables] = structName; + return structName; +} std::string CGenerator::ValueTypeToCType(Type *type, std::string declaration, ClosureTypeSpecialType closureSpecial) { return ValueTypeToCTypeThingHelper(type, " " + declaration, closureSpecial); } std::string CGenerator::ValueTypeToCTypeDecoration(Type *type, ClosureTypeSpecialType closureSpecial) { return CifyName(ValueTypeToCTypeThingHelper(type, "", closureSpecial)); } diff --git a/src/util.cpp b/src/util.cpp index 689af72..6f81d2d 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -17,7 +17,7 @@ std::string replaceExEscape(std::string first, std::string search, std::string r if (pos > 0) { int numBackslashes = 0; int countBack = 1; - while (pos-countBack >= 0 && first[pos-countBack] == '\\') { + while ((int)pos-countBack >= 0 && first[pos-countBack] == '\\') { numBackslashes++; countBack++; } diff --git a/stdlib/regex.krak b/stdlib/regex.krak index ac5a27a..8159ce2 100644 --- a/stdlib/regex.krak +++ b/stdlib/regex.krak @@ -32,18 +32,7 @@ obj regexState(Object) { next_states.destruct() } fun match(input: char): vector::vector { - return next_states.filter(fun(it:regexState*, input:char):bool { return it->character == input; }, input) - - io::print("in match for: "); io::println(character) - io::println("pre") - for (var i = 0; i < next_states.size; i++;) - io::println(next_states[i]->character) - var nx = next_states.filter(fun(it:regexState*, input:char):bool { return it->character == input; }, input) - io::println("next") - for (var i = 0; i < nx.size; i++;) - io::println(nx[i]->character) - //return next_states.filter(fun(it:regexState*, input:char):bool { return it->character == input; }, input) - return nx + return next_states.filter(fun(it:regexState*):bool { return it->character == input; }) } fun is_end():bool { return next_states.any_true(fun(state: regexState*):bool { return state->character == 1; }) @@ -60,7 +49,8 @@ obj regex(Object) { var beginningAndEnd = compile(regexStringIn) // init our begin, and the end state as the next state of each end begin = beginningAndEnd.first - beginningAndEnd.second.do(fun(it: regexState*, end: regexState*): void { it->next_states.add(end); }, mem::new()->construct(conversions::to_char(1))) + var end = mem::new()->construct(conversions::to_char(1)) + beginningAndEnd.second.do(fun(it: regexState*): void { it->next_states.add(end); }) return this } @@ -118,11 +108,11 @@ obj regex(Object) { i = perenEnd-1 if (alternating) { - previous_end.do(fun(it: regexState*, innerBegin: vector::vector):void { it->next_states.add_all(innerBegin); }, innerBeginEnd.first->next_states) + previous_end.do(fun(it: regexState*):void { it->next_states.add_all(innerBeginEnd.first->next_states); } ) current_begin.add_all(innerBeginEnd.first->next_states) current_end.add_all(innerBeginEnd.second) } else { - current_end.do(fun(it: regexState*, innerBegin: vector::vector):void { it->next_states.add_all(innerBegin); }, innerBeginEnd.first->next_states) + current_end.do(fun(it: regexState*):void { it->next_states.add_all(innerBeginEnd.first->next_states); } ) previous_begin = current_begin previous_end = current_end current_begin = innerBeginEnd.first->next_states @@ -136,11 +126,11 @@ obj regex(Object) { } else { var next = mem::new()->construct(regex_string[i]) if (alternating) { - previous_end.do(fun(it: regexState*, next: regexState*):void { it->next_states.add(next); }, next) + previous_end.do(fun(it: regexState*):void { it->next_states.add(next); }) current_begin.add(next) current_end.add(next) } else { - current_end.do(fun(it: regexState*, next: regexState*):void { it->next_states.add(next); }, next) + current_end.do(fun(it: regexState*):void { it->next_states.add(next); }) previous_begin = current_begin previous_end = current_end current_begin = vector::vector(next) @@ -154,7 +144,6 @@ obj regex(Object) { return beginAndEnd } - fun long_match(to_match: char*): int { return long_match(string::string(to_match)); } fun long_match(to_match: string::string): int { var next = vector::vector(begin) @@ -165,7 +154,7 @@ obj regex(Object) { if (next.any_true(fun(state: regexState*):bool { return state->is_end(); })) longest = i //next = next.flatten_map(fun(state: regexState*): vector::vector { return state->match(to_match[i]); }) - next = next.flatten_map(fun(state: regexState*, c:char): vector::vector { return state->match(c); }, to_match[i]) + next = next.flatten_map(fun(state: regexState*): vector::vector { return state->match(to_match[i]); }) } if (next.any_true(fun(state: regexState*):bool { return state->is_end(); })) return to_match.length() @@ -173,4 +162,3 @@ obj regex(Object) { } } - diff --git a/tests/test_lambda.expected_results b/tests/test_lambda.expected_results index d8d6539..98281df 100644 --- a/tests/test_lambda.expected_results +++ b/tests/test_lambda.expected_results @@ -11,3 +11,8 @@ 7 8 9 +4 +closures now +1337 +13371010 +80 diff --git a/tests/test_lambda.krak b/tests/test_lambda.krak index 18bb01a..5abdfd5 100644 --- a/tests/test_lambda.krak +++ b/tests/test_lambda.krak @@ -1,4 +1,5 @@ import io:* +import vector:* fun runLambda(func: fun():int):void { println(func()) @@ -6,27 +7,33 @@ fun runLambda(func: fun():int):void { fun somethingElse():int { return 4; } -//fun callLambda(func: fun(int):void):void { - //func(10) -//} +fun callLambda(func: fun(int):void):void { + func(10) +} -//fun itr(it: T, func: fun(T):T):T { - //println(it) - //return func(it); -//} +fun itr(it: T, func: fun(T):T):T { + println(it) + return func(it); +} fun main():int { - //var func = fun():void { println("8"); } - //func() - //runLambda(fun():int { return 9;}) - //callLambda(fun(a:int):void { println(a);}) - //var j = 0 - //while (j < 10) j = itr(j, fun(a:int):int { return a+1; }) + var func = fun():void { println("8"); } + func() + runLambda(fun():int { return 9;}) + callLambda(fun(a:int):void { println(a);}) + var j = 0 + while (j < 10) j = itr(j, fun(a:int):int { return a+1; }) runLambda(somethingElse) println("closures now") var a = 1337 runLambda(fun():int { return a;}) + runLambda(fun():int { print(a); print(j); return j;}) + + var v = vector(80) + var idx = 0 + runLambda(fun():int { return v.get(idx);}) + return 0 }