diff --git a/include/Type.h b/include/Type.h index 5ed2198..9c04217 100644 --- a/include/Type.h +++ b/include/Type.h @@ -41,6 +41,8 @@ class Type { void decreaseIndirection(); void modifyIndirection(int mod); Type withIncreasedIndirection(); + Type withReference(); + Type *withReferencePtr(); Type *withIncreasedIndirectionPtr(); Type withDecreasedIndirection(); diff --git a/src/ASTTransformation.cpp b/src/ASTTransformation.cpp index 5adebb3..e70812b 100644 --- a/src/ASTTransformation.cpp +++ b/src/ASTTransformation.cpp @@ -200,12 +200,12 @@ void ASTTransformation::secondPass(NodeTree* ast, NodeTree* par // Let's make an equality function prototype Type *thisADTType = adtDef->getDataRef()->valueType; - NodeTree* equalityFunc = new NodeTree("function", ASTData(function, Symbol("operator==", true), new Type(std::vector{thisADTType}, new Type(boolean)))); + NodeTree* equalityFunc = new NodeTree("function", ASTData(function, Symbol("operator==", true), new Type(std::vector{thisADTType->withReferencePtr()}, new Type(boolean)))); adtDef->addChild(equalityFunc); addToScope("operator==", equalityFunc, adtDef); addToScope("~enclosing_scope", adtDef, equalityFunc); - NodeTree* inequalityFunc = new NodeTree("function", ASTData(function, Symbol("operator!=", true), new Type(std::vector{thisADTType}, new Type(boolean)))); + NodeTree* inequalityFunc = new NodeTree("function", ASTData(function, Symbol("operator!=", true), new Type(std::vector{thisADTType->withReferencePtr()}, new Type(boolean)))); adtDef->addChild(inequalityFunc); addToScope("operator!=", inequalityFunc, adtDef); addToScope("~enclosing_scope", adtDef, inequalityFunc); @@ -215,6 +215,11 @@ void ASTTransformation::secondPass(NodeTree* ast, NodeTree* par addToScope("copy_construct", copy_constructFunc, adtDef); addToScope("~enclosing_scope", adtDef, copy_constructFunc); + NodeTree* assignmentFunc = new NodeTree("function", ASTData(function, Symbol("operator=", true), new Type(std::vector{thisADTType->withReferencePtr()}, new Type(void_type)))); + adtDef->addChild(assignmentFunc); + addToScope("operator=", assignmentFunc, adtDef); + addToScope("~enclosing_scope", adtDef, assignmentFunc); + NodeTree* destructFunc = new NodeTree("function", ASTData(function, Symbol("destruct", true), new Type(std::vector(), new Type(void_type)))); adtDef->addChild(destructFunc); addToScope("destruct", destructFunc, adtDef); diff --git a/src/CGenerator.cpp b/src/CGenerator.cpp index fad7d4f..f9ffa51 100644 --- a/src/CGenerator.cpp +++ b/src/CGenerator.cpp @@ -306,11 +306,9 @@ std::pair CGenerator::generateTranslationUnit(std::str nameDecoration += "_" + ValueTypeToCTypeDecoration(paramType); std::string fun_name = "fun_" + declarationData.symbol.getName() + "__" + CifyName(orig_fun_name + nameDecoration); std::string first_param; - if (orig_fun_name == "operator==" || orig_fun_name == "operator!=" || orig_fun_name == "copy_construct" || orig_fun_name == "destruct") { + if (orig_fun_name == "operator==" || orig_fun_name == "operator!=" || orig_fun_name == "copy_construct" || orig_fun_name == "operator=" + || orig_fun_name == "destruct") { first_param = ValueTypeToCType(declarationData.valueType->withIncreasedIndirectionPtr(), "this"); - //first_param = ValueTypeToCType(child->getDataRef()->valueType->parameterTypes[0]->withIncreasedIndirectionPtr(), "this"); - //if (orig_fun_name == "operator==" || orig_fun_name == "operator!=" || orig_fun_name == "copy_construct") - //first_param += ", "; } bool has_param = child->getDataRef()->valueType->parameterTypes.size(); std::string first_part = "\n" + ValueTypeToCType(child->getDataRef()->valueType->returnType, fun_name) + "(" + first_param + @@ -319,39 +317,107 @@ std::pair CGenerator::generateTranslationUnit(std::str functionDefinitions += first_part + "{ /*adt func*/\n"; if (orig_fun_name == "operator==") { functionDefinitions += " /* equality woop woop */\n"; - functionDefinitions += " if (this->flag != in.flag) return false;\n"; + functionDefinitions += " bool equal = true;\n"; + functionDefinitions += " if (this->flag != in->flag) equal = false;\n"; for (auto child : decChildren) { if (child->getName() != "function" && child->getDataRef()->valueType->typeDefinition != declaration) { std::string option_name = child->getDataRef()->symbol.getName(); - functionDefinitions += " if (this->flag == " + declarationData.symbol.getName() + "__" + option_name + ") {\n"; + functionDefinitions += " else if (this->flag == " + declarationData.symbol.getName() + "__" + option_name + ") {\n"; NodeTree* method = nullptr; if (method = getMethod(child->getDataRef()->valueType, "operator==", std::vector{*child->getDataRef()->valueType})) { - functionDefinitions += " return " + generateMethodIfExists(child->getDataRef()->valueType, "operator==", - "&this->" + option_name + ", " + (method->getDataRef()->valueType->parameterTypes[0]->is_reference ? "&" : "") + "in." + option_name, - std::vector{*child->getDataRef()->valueType}) + ";\n}\n"; + bool is_reference = method->getDataRef()->valueType->parameterTypes[0]->is_reference; + + auto itemTypeVector = std::vector{child->getDataRef()->valueType->withIncreasedIndirection()}; + bool need_temporary = !is_reference && getMethod(child->getDataRef()->valueType, "copy_construct", itemTypeVector); + if (need_temporary) { + functionDefinitions += " " + ValueTypeToCType(child->getDataRef()->valueType, "copy_constructTemporary") + ";\n"; + functionDefinitions += " " + generateMethodIfExists(child->getDataRef()->valueType, "copy_construct", + "©_constructTemporary, &in->" + option_name, itemTypeVector) + ";\n"; + } + + std::string otherValue = (is_reference ? "&" : "") + (need_temporary ? "copy_constructTemporary" : "in->" + option_name); + functionDefinitions += " equal = " + generateMethodIfExists(child->getDataRef()->valueType, "operator==", + "&this->" + option_name + ", " + otherValue, + std::vector{*child->getDataRef()->valueType}) + ";\n"; + // Remember, we don't destruct copy_constructTemporary because the function will do that + functionDefinitions += "}\n"; } else { - functionDefinitions += " return this->" + option_name + " == in." + option_name + ";\n}\n"; + functionDefinitions += " equal = this->" + option_name + " == in->" + option_name + ";\n}\n"; } } } - - functionDefinitions += " return true;\n"; + functionDefinitions += " return equal;\n"; } else if (orig_fun_name == "operator!=") { functionDefinitions += " /* inequality woop woop */\n"; std::string adtName = declarationData.symbol.getName(); - functionDefinitions += " return !fun_" + adtName + "__" + CifyName("operator==") + "_" + adtName+ "(this, in);\n"; + + functionDefinitions += " bool equal = !fun_" + adtName + "__" + CifyName("operator==") + "_" + adtName + "_space__div__star_ref_star__div__space__star_(this, in);\n"; + functionDefinitions += " return equal;\n"; + } else if (orig_fun_name == "operator=") { + functionDefinitions += "/* wopo assignment */\n"; + auto adtType = declaration->getDataRef()->valueType; + functionDefinitions += " " + generateMethodIfExists(adtType, "destruct", + "this", std::vector()) + ";\n"; + functionDefinitions += " " + generateMethodIfExists(adtType, "copy_construct", + "this, in", std::vector{adtType->withIncreasedIndirection()}) + ";\n"; } else if (orig_fun_name == "copy_construct") { functionDefinitions += " /* copy_construct woop woop */\n"; + functionDefinitions += " this->flag = in->flag;\n"; + std::string elsePrefix = ""; + for (auto child : decChildren) { + if (child->getName() != "function" && child->getDataRef()->valueType->typeDefinition != declaration) { + std::string option_name = child->getDataRef()->symbol.getName(); + functionDefinitions += " " + elsePrefix + " if (in->flag == " + declarationData.symbol.getName() + "__" + option_name + ") {\n"; + elsePrefix = "else"; + NodeTree* method = nullptr; + auto itemTypeVector = std::vector{child->getDataRef()->valueType->withIncreasedIndirection()}; + if (method = getMethod(child->getDataRef()->valueType, "copy_construct", itemTypeVector)) { + functionDefinitions += " " + generateMethodIfExists(child->getDataRef()->valueType, "copy_construct", + "&this->" + option_name + ", &in->" + option_name, itemTypeVector) + ";\n"; + } else { + functionDefinitions += "this->" + option_name + " = in->" + option_name + ";\n"; + } + functionDefinitions += " }\n"; + } + } } else if (orig_fun_name == "destruct") { functionDefinitions += " /* destruct woop woop */\n"; + std::string elsePrefix = ""; + for (auto child : decChildren) { + if (child->getName() != "function" && child->getDataRef()->valueType->typeDefinition != declaration) { + std::string option_name = child->getDataRef()->symbol.getName(); + functionDefinitions += " " + elsePrefix + " if (this->flag == " + declarationData.symbol.getName() + "__" + option_name + ") {\n"; + elsePrefix = "else"; + NodeTree* method = nullptr; + if (method = getMethod(child->getDataRef()->valueType, "destruct", std::vector())) { + functionDefinitions += " " + generateMethodIfExists(child->getDataRef()->valueType, "destruct", + "&this->" + option_name, std::vector()) + ";\n"; + } + functionDefinitions += " }\n"; + } + } } else { // ok, is a constructor function functionDefinitions += " /* constructor woop woop */\n"; functionDefinitions += " " + declarationData.symbol.getName() + " toRet;\n"; functionDefinitions += " toRet.flag = " + declarationData.symbol.getName() + "__" + orig_fun_name + ";\n"; - if (has_param) - functionDefinitions += " toRet." + orig_fun_name + " = in;\n"; + if (has_param) { + NodeTree* method = nullptr; + auto paramType = child->getDataRef()->valueType->parameterTypes[0]; + auto itemTypeVector = std::vector{paramType->withIncreasedIndirection()}; + functionDefinitions += "/*" + ValueTypeToCType(paramType, "") + "*/\n"; + if (method = getMethod(paramType, "copy_construct", itemTypeVector)) { + functionDefinitions += " " + generateMethodIfExists(paramType, "copy_construct", + "&toRet." + orig_fun_name + ", &in", itemTypeVector) + ";\n"; + } else { + functionDefinitions += " toRet." + orig_fun_name + " = in;\n"; + } + + if (method = getMethod(paramType, "destruct", std::vector())) { + functionDefinitions += " " + generateMethodIfExists(paramType, "destruct", + "&in", std::vector()) + ";\n"; + } } functionDefinitions += " return toRet;\n"; } functionDefinitions += "}\n"; @@ -836,7 +902,7 @@ CCodeTriple CGenerator::generate(NodeTree* from, NodeTree* enc //for (int i = 0; i < (functionDefChildren.size() > 0 ? functionDefChildren.size()-1 : 0); i++) //nameDecoration += "_" + ValueTypeToCTypeDecoration(functionDefChildren[i]->getData().valueType); // Note that we only add scoping to the object, as this specifies our member function too - /*HERE*/ return function_header + prefixIfNeeded(scopePrefix(unaliasedTypeDef), CifyName(unaliasedTypeDef->getDataRef()->symbol.getName())) +"__" + + return function_header + prefixIfNeeded(scopePrefix(unaliasedTypeDef), CifyName(unaliasedTypeDef->getDataRef()->symbol.getName())) +"__" + CifyName(functionName + nameDecoration) + "(" + (name == "." ? "&" : "") + generate(children[1], enclosingObject, true, enclosingFunction) + ","; //The comma lets the upper function call know we already started the param list //Note that we got here from a function call. We just pass up this special case and let them finish with the perentheses @@ -964,12 +1030,7 @@ CCodeTriple CGenerator::generate(NodeTree* from, NodeTree* enc case value: { // ok, we now check for it being a multiline string and escape all returns if it is (so that multiline strings work) - //if (data.symbol.getName()[0] == '"') { if (data.symbol.getName()[0] == '"' && strSlice(data.symbol.getName(), 0, 3) == "\"\"\"") { - //bool multiline_str = strSlice(data.symbol.getName(), 0, 3) == "\"\"\""; - //std::string innerString = multiline_str - //? strSlice(data.symbol.getName(), 3, -4) - //: strSlice(data.symbol.getName(), 1, -2); std::string innerString = strSlice(data.symbol.getName(), 3, -4); std::string newStr; for (auto character: innerString) @@ -1199,8 +1260,9 @@ std::string CGenerator::ValueTypeToCTypeThingHelper(Type *type, std::string decl return return_type; for (int i = 0; i < type->getIndirection(); i++) return_type += "*"; - if (type->is_reference) + if (type->is_reference) { return_type += " /*ref*/ *"; + } return return_type + declaration; } diff --git a/src/Type.cpp b/src/Type.cpp index a9d3029..e1750c1 100644 --- a/src/Type.cpp +++ b/src/Type.cpp @@ -240,6 +240,16 @@ Type Type::withIncreasedIndirection() { newOne->increaseIndirection(); return *newOne; } +Type Type::withReference() { + Type *newOne = clone(); + newOne->is_reference = true; + return *newOne; +} +Type *Type::withReferencePtr() { + Type *newOne = clone(); + newOne->is_reference = true; + return newOne; +} Type *Type::withIncreasedIndirectionPtr() { Type *newOne = clone(); newOne->increaseIndirection(); diff --git a/tests/test_adt.expected_results b/tests/test_adt.expected_results index 64e7fa5..f6fd947 100644 --- a/tests/test_adt.expected_results +++ b/tests/test_adt.expected_results @@ -1 +1,48 @@ option1 +no int +an int: 7 +equality true works! +equality false works! +matched an int:11 correctly! +matched no int correctly! +matched no_obj correctly +assignment to old variable +gonna make object in function 100 +constructed object 100 : 100 +copy constructed object 100 : 200 from 100 : 100 +destructed object 100 : 100 +copy constructed object 100 : 300 from 100 : 200 +copy constructed object 100 : 400 from 100 : 300 +destructed object 100 : 300 +copy constructed object 100 : 500 from 100 : 400 +destructed object 100 : 400 +destructed object 100 : 200 +done assignment to old variable +matched an_obj correctly 100 : 500 +int assignment to old var +destructed object 100 : 500 +done int assignment to old var +matched an_int correctly 1337 +test copy_construct for non ref equality +gonna make object in function 110 +constructed object 110 : 110 +copy constructed object 110 : 210 from 110 : 110 +destructed object 110 : 110 +copy constructed object 110 : 310 from 110 : 210 +copy constructed object 110 : 410 from 110 : 310 +destructed object 110 : 310 +gonna make object in function 110 +constructed object 110 : 110 +copy constructed object 110 : 210 from 110 : 110 +destructed object 110 : 110 +copy constructed object 110 : 310 from 110 : 210 +copy constructed object 110 : 410 from 110 : 310 +destructed object 110 : 310 +copy constructed object 110 : 510 from 110 : 410 +destructed object 110 : 510 +equality an_obj correctly +destructed object 110 : 410 +destructed object 110 : 210 +destructed object 110 : 410 +destructed object 110 : 210 +done test copy_construct for non ref equality diff --git a/tests/test_adt.krak b/tests/test_adt.krak index 0e93bc4..cb127af 100644 --- a/tests/test_adt.krak +++ b/tests/test_adt.krak @@ -11,30 +11,34 @@ adt maybe_int { } fun TestObj(num: int): TestObj { - print("gonna makke object in function ") + print("gonna make object in function ") println(num) var toRet.construct(num): TestObj return toRet } obj TestObj (Object) { var obj_num: int + var ref_num: int fun construct(num:int): *TestObj { obj_num = num + ref_num = num print("constructed object ") - println(obj_num) + print(obj_num);print(" : ");println(ref_num) } fun copy_construct(old: *TestObj) { - obj_num = old->obj_num + 100 + obj_num = old->obj_num + ref_num = old->ref_num + 100 print("copy constructed object ") - print(obj_num) + print(obj_num);print(" : ");print(ref_num) print(" from ") - println(old->obj_num) + print(old->obj_num);print(" : ");println(old->ref_num) } fun destruct() { print("destructed object ") - println(obj_num) + print(obj_num);print(" : ");println(ref_num) } - fun operator==(other: ref TestObj): bool { + /*fun operator==(other: ref TestObj): bool {*/ + fun operator==(other: TestObj): bool { return obj_num == other.obj_num; } } @@ -48,9 +52,7 @@ adt maybe_object { fun handle_possibility(it: maybe_int) { if (it == maybe_int::no_int()) { println("no int") - } - /*if (it == maybe_int::an_int) {*/ - else { + } else { print("an int: ") println(it.an_int) } @@ -67,7 +69,7 @@ fun can_pass(it: options): options { } fun main():int { - var it: options = can_pass(options::option0()) + var it: options = can_pass(options::option1()) if (it == options::option0()) { println("nope") } @@ -115,30 +117,46 @@ fun main():int { println(int_thiny) } } + println("assignment to old variable") obj_item = maybe_object::an_obj(TestObj(100)) + println("done assignment to old variable") match (obj_item) { maybe_object::no_obj() println("matched no_obj incorrectly") maybe_object::an_obj(obj_instance) { print("matched an_obj correctly ") - println(obj_instance.obj_num) + print(obj_instance.obj_num);print(" : ");println(obj_instance.ref_num) } maybe_object::an_int(int_thiny) { print("matched an_intj incorrectly ") println(int_thiny) } } + println("int assignment to old var") obj_item = maybe_object::an_int(1337) + println("done int assignment to old var") + /*println("test copying thing");*/ + /*var obj_item_new = maybe_object::an_obj(TestObj(1000))*/ + /*println("new object assingment")*/ + /*var obj_item_new_copy = obj_item_new;*/ + /*println("done new object assingment")*/ + /*println("done test copying thing");*/ match (obj_item) { maybe_object::no_obj() println("matched no_obj incorrectly") maybe_object::an_obj(obj_instance) { print("matched an_obj incorrectly ") - println(obj_instance.obj_num) + print(obj_instance.obj_num);print(" : ");println(obj_instance.ref_num) } maybe_object::an_int(int_thiny) { print("matched an_int correctly ") println(int_thiny) } } + println("test copy_construct for non ref equality"); + if (maybe_object::an_obj(TestObj(110)) == maybe_object::an_obj(TestObj(110))) + println("equality an_obj correctly ") + else + println("equality an_obj incorrectly ") + println("done test copy_construct for non ref equality"); return 0 } diff --git a/tests/test_grammer.krak b/tests/test_grammer.krak index 6192467..e1bdd81 100644 --- a/tests/test_grammer.krak +++ b/tests/test_grammer.krak @@ -119,8 +119,8 @@ fun main():int { var parse.construct(a): parser /*var result = parse.parse_input(string("a"), string("fun name"))*/ - var result = parse.parse_input(read_file(string("test_adt.krak")), string("fun name")) - /*var result = parse.parse_input(read_file(string("to_parse.krak")), string("fun name"))*/ + /*var result = parse.parse_input(read_file(string("test_adt.krak")), string("fun name"))*/ + var result = parse.parse_input(read_file(string("to_parse.krak")), string("fun name")) /*var result = parse.parse_input(string("inport a;"), string("fun name"))*/ /*var result = parse.parse_input(string("fun main():int { return 0; }"), string("fun name"))*/ /*var result = parse.parse_input(string("ad"), string("fun name"))*/