diff --git a/stdlib/ast_nodes.krak b/stdlib/ast_nodes.krak index 812cf61..35475bc 100644 --- a/stdlib/ast_nodes.krak +++ b/stdlib/ast_nodes.krak @@ -363,8 +363,8 @@ obj statement (Object) { return child == other.child } } -fun ast_if_statement_ptr(): *ast_node { - var to_ret.construct(): if_statement +fun ast_if_statement_ptr(condition: *ast_node): *ast_node { + var to_ret.construct(condition): if_statement var ptr = new() ptr->copy_construct(&ast_node::if_statement(to_ret)) return ptr @@ -376,12 +376,22 @@ fun is_if_statement(node: *ast_node): bool { return false } obj if_statement (Object) { + var condition: *ast_node + // these are not a part of the constructor because they have to be trnasformed with this as its scope + var then_part: *ast_node + var else_part: *ast_node var scope: map> - fun construct(): *if_statement { + fun construct(condition_in: *ast_node): *if_statement { + condition = condition_in + then_part = null() + else_part = null() scope.construct() return this } fun copy_construct(old: *if_statement) { + condition = old->condition + then_part = old->then_part + else_part = old->else_part scope.copy_construct(&old->scope) } fun destruct() { @@ -392,7 +402,7 @@ obj if_statement (Object) { copy_construct(&other) } fun operator==(other: ref if_statement): bool { - return true + return condition == other.condition && then_part == other.then_part && else_part == other.else_part } } fun ast_match_statement_ptr(): *ast_node { @@ -877,7 +887,7 @@ fun get_ast_children(node: *ast_node): vector<*ast_node> { ast_node::function(backing) return backing.parameters + backing.body_statement ast_node::code_block(backing) return backing.children ast_node::statement(backing) return vector<*ast_node>(backing.child) - ast_node::if_statement(backing) return vector<*ast_node>() + ast_node::if_statement(backing) return vector(backing.condition, backing.then_part, backing.else_part) ast_node::match_statement(backing) return vector<*ast_node>() ast_node::case_statement(backing) return vector<*ast_node>() ast_node::while_loop(backing) return vector<*ast_node>() diff --git a/stdlib/ast_transformation.krak b/stdlib/ast_transformation.krak index b6d06a1..76742df 100644 --- a/stdlib/ast_transformation.krak +++ b/stdlib/ast_transformation.krak @@ -188,6 +188,8 @@ obj ast_transformation (Object) { return transform_declaration_statement(node, scope) } else if (name == "assignment_statement") { return transform_assignment_statement(node, scope) + } else if (name == "if_statement") { + return transform_if_statement(node, scope) } else if (name == "return_statement") { return transform_return_statement(node, scope) } else if (name == "function_call") { @@ -287,6 +289,17 @@ obj ast_transformation (Object) { var assignment = ast_assignment_statement_ptr(transform(get_node("factor", node), scope), transform(get_node("boolean_expression", node), scope)) return assignment } + fun transform_if_statement(node: *tree, scope: *ast_node): *ast_node { + var if_statement = ast_if_statement_ptr(transform_expression(get_node("boolean_expression", node), scope)) + // one variable declarations might be in a code_block-less if statement + add_to_scope("~enclosing_scope", scope, if_statement) + var statements = transform_all(get_nodes("statement", node), if_statement) + if_statement->if_statement.then_part = statements[0] + // we have an else + if (statements.size == 2) + if_statement->if_statement.else_part = statements[1] + return if_statement + } fun transform_return_statement(node: *tree, scope: *ast_node): *ast_node { return ast_return_statement_ptr(transform(node->children[0], scope)) } diff --git a/stdlib/c_generator.krak b/stdlib/c_generator.krak index a11e8f3..a2bd941 100644 --- a/stdlib/c_generator.krak +++ b/stdlib/c_generator.krak @@ -93,6 +93,12 @@ obj c_generator (Object) { fun generate_assignment_statement(node: *ast_node): string { return generate(node->assignment_statement.to) + " = " + generate(node->assignment_statement.from) } + fun generate_if_statement(node: *ast_node): string { + var if_str = string("if (") + generate(node->if_statement.condition) + ") {\n" + generate(node->if_statement.then_part) + "}" + if (node->if_statement.else_part) + if_str += string(" else {\n") + generate(node->if_statement.else_part) + "}" + return if_str + "\n" + } fun generate_identifier(node: *ast_node): string { return node->identifier.name } @@ -133,6 +139,7 @@ obj c_generator (Object) { ast_node::statement(backing) return generate_statement(node) ast_node::declaration_statement(backing) return generate_declaration_statement(node) ast_node::assignment_statement(backing) return generate_assignment_statement(node) + ast_node::if_statement(backing) return generate_if_statement(node) ast_node::function(backing) return generate_function(node) ast_node::function_call(backing) return generate_function_call(node) ast_node::code_block(backing) return generate_code_block(node) diff --git a/tests/to_parse.krak b/tests/to_parse.krak index 766042f..33a96a4 100644 --- a/tests/to_parse.krak +++ b/tests/to_parse.krak @@ -24,6 +24,8 @@ fun main(): int { simple_print(yet_another_declaration) simple_print("Hello World!\n") simple_print(1337) + if (1 + 2 && false) simple_print("its true!") + else simple_print("its false!") return 0 }