diff --git a/stdlib/ast_transformation.krak b/stdlib/ast_transformation.krak index fcade1e..ecd2f8e 100644 --- a/stdlib/ast_transformation.krak +++ b/stdlib/ast_transformation.krak @@ -506,12 +506,37 @@ obj ast_transformation (Object) { return declaration } fun transform_assignment_statement(node: *tree, scope: *ast_node, template_replacements: map): *ast_node { - var assign_to = transform(get_node("factor", node), scope, template_replacements) var to_assign = transform(get_node("boolean_expression", node), scope, template_replacements) - if (get_node("\"\\+=\"", node)) to_assign = make_operator_call("+", vector(assign_to, to_assign)) - else if (get_node("\"-=\"", node)) to_assign = make_operator_call("-", vector(assign_to, to_assign)) - else if (get_node("\"\\*=\"", node)) to_assign = make_operator_call("*", vector(assign_to, to_assign)) - else if (get_node("\"/=\"", node)) to_assign = make_operator_call("/", vector(assign_to, to_assign)) + // for []= overloading + if (get_node("\"=\"", node)) { + println("Regular Assignment!") + var factor_part = get_node("factor", node) + if (factor_part->children.size == 1) { + println("Factor has only one child!") + var inner_unarad = get_node("unarad", factor_part) + if (get_node("\"[\"", inner_unarad)) { + println("Inner Unarad has [!") + var assign_to = transform(get_node("unarad", inner_unarad), scope, template_replacements) + var assign_idx = transform(get_node("expression", inner_unarad), scope, template_replacements) + var possible_bracket_assign = find_and_make_any_operator_overload_call(string("[]="), vector(assign_to, assign_idx, to_assign), scope, template_replacements) + if (possible_bracket_assign) { + println("Computed and returning []=!") + return possible_bracket_assign + } else println("Could not Compute and return []=!") + } else println("Inner Unarad does not have [!") + } else println("Factor not 1 child") + } else println("Not regular assignment") + var assign_to = transform(get_node("factor", node), scope, template_replacements) + if (get_node("\"=\"", node)) { + var possible_assign = find_and_make_any_operator_overload_call(string("="), vector(assign_to, to_assign), scope, template_replacements) + if (possible_assign) { + println("Computed and returning operator=!") + return possible_assign + } + } else if (get_node("\"\\+=\"", node)) to_assign = make_operator_call("+", vector(assign_to, to_assign)) + else if (get_node("\"-=\"", node)) to_assign = make_operator_call("-", vector(assign_to, to_assign)) + else if (get_node("\"\\*=\"", node)) to_assign = make_operator_call("*", vector(assign_to, to_assign)) + else if (get_node("\"/=\"", node)) to_assign = make_operator_call("/", vector(assign_to, to_assign)) var assignment = ast_assignment_statement_ptr(assign_to, to_assign) return assignment } @@ -688,6 +713,8 @@ obj ast_transformation (Object) { } } else { func_name = concat_symbol_tree(node->children[1]) + if (func_name == "[") + func_name += "]" var first_param = transform(node->children[0], scope, template_replacements) var second_param = null() if (func_name == "." || func_name == "->") { @@ -718,6 +745,13 @@ obj ast_transformation (Object) { } var parameter_types = parameters.map(fun(param: *ast_node): *type return get_ast_type(param);) // check for operator overloading + var possible_overload_call = find_and_make_any_operator_overload_call(func_name, parameters, scope, template_replacements) + if (possible_overload_call) + return possible_overload_call + return ast_function_call_ptr(get_builtin_function(func_name, parameter_types), parameters) + } + fun find_and_make_any_operator_overload_call(func_name: string, parameters: vector<*ast_node>, scope: *ast_node, template_replacements: map): *ast_node { + var parameter_types = parameters.map(fun(param: *ast_node): *type return get_ast_type(param);) var possible_overload = null() if (parameter_types[0]->is_object() && parameter_types[0]->indirection == 0) { possible_overload = function_lookup(string("operator")+func_name, parameter_types.first()->type_def, parameter_types.slice(1,-1)) @@ -731,7 +765,7 @@ obj ast_transformation (Object) { possible_overload = find_or_instantiate_template_function(string("operator")+func_name, null>(), scope, parameter_types, template_replacements, map()) if (possible_overload) return ast_function_call_ptr(possible_overload, parameters) - return ast_function_call_ptr(get_builtin_function(func_name, parameter_types), parameters) + return null() } fun find_or_instantiate_template_function(name: string, template_inst: *tree, scope: *ast_node, param_types: vector<*type>, template_replacements: map, replacements_base: map): *ast_node { // replacments base is for templated methods starting off with the replacements of their parent (possibly templated) object @@ -864,7 +898,7 @@ fun get_builtin_function(name: string, param_types: vector<*type>): *ast_node { return ast_function_ptr(name, type_ptr(param_types, type_ptr(base_type::boolean())), vector<*ast_node>()) if (name == "." || name == "->") return ast_function_ptr(name, type_ptr(param_types, param_types[1]), vector<*ast_node>()) - if (name == "[") + if (name == "[]") return ast_function_ptr(name, type_ptr(param_types, param_types[0]->clone_with_decreased_indirection()), vector<*ast_node>()) if (name == "&") return ast_function_ptr(name, type_ptr(param_types, param_types[0]->clone_with_increased_indirection()), vector<*ast_node>()) diff --git a/stdlib/c_generator.krak b/stdlib/c_generator.krak index fab791b..650cace 100644 --- a/stdlib/c_generator.krak +++ b/stdlib/c_generator.krak @@ -522,7 +522,7 @@ obj c_generator (Object) { // don't propegate enclosing function down right of access if (func_name == "." || func_name == "->") return code_triple("(") + generate(parameters[0], enclosing_object, enclosing_func, null>>>()) + func_name + generate(parameters[1], null(), null(), null>>>()) + string(")") - if (func_name == "[") + if (func_name == "[]") return code_triple("(") + generate(parameters[0], enclosing_object, enclosing_func, null>>>()) + "[" + generate(parameters[1], null(), null(), null>>>()) + string("])") // the post ones need to be post-ed specifically, and take the p off if (func_name == "++p" || func_name == "--p") diff --git a/tests/test_bracket_assign.expected_results b/tests/test_bracket_assign.expected_results index 486056a..36b192e 100644 --- a/tests/test_bracket_assign.expected_results +++ b/tests/test_bracket_assign.expected_results @@ -1 +1,3 @@ bracket assign: index: 4, rhs: 9 +just bracket: index: 5 +just =: index: 6 diff --git a/tests/test_bracket_assign.krak b/tests/test_bracket_assign.krak index 0fc34e2..b7b7840 100644 --- a/tests/test_bracket_assign.krak +++ b/tests/test_bracket_assign.krak @@ -7,10 +7,20 @@ obj BracketAssign { print(", rhs: ") println(rhs) } + fun operator[](index:int) { + print("just bracket: index: ") + println(index) + } + fun operator=(index:int) { + print("just =: index: ") + println(index) + } } fun main():int { var test:BracketAssign test[4] = 9 + test[5] + test = 6 return 0 }