Files
kraken/stdlib/adt_lower.krak
2018-03-21 00:00:06 -04:00

198 lines
15 KiB
Plaintext

import symbol:*
import tree:*
import vector:*
import map:*
import util:*
import string:*
import mem:*
import io:*
import ast_nodes:*
import ast_transformation:*
import hash_set:*
import pass_common:*
// adt_lower: rewrites every ADT in the given ASTs into ordinary type_defs and
// rewrites the constructs that depend on ADTs, in two passes over each AST.
//
// Pass 1 (helper_before): each adt_def node is overwritten in place with a
//   type_def that has two members: `data` (whose type is a freshly created
//   "<name>_union" type_def holding one variable per non-empty option) and an
//   integer `flag` recording which option is active. Bodies are generated for
//   the ADT's option functions and for its regular functions (operator==,
//   operator!=, construct, copy_construct, operator=, destruct).
// Pass 2 (second_helper): match statements become a code block of if
//   statements testing `flag`, and `.` / `->` accesses naming an ADT option
//   are rerouted through the `data` member.
//
// name_ast_map:  for each entry, the pair's .second (the AST) is traversed.
// ast_to_syntax: NOTE(review): not read anywhere in this function as shown.
fun adt_lower(name_ast_map: *map<string, pair<*tree<symbol>,*ast_node>>, ast_to_syntax: *map<*ast_node, *tree<symbol>>) {
// For every lowered ADT node: its options, in the order they were listed.
// The position of an option in this vector is the value stored in `flag`.
var type_def_option_map = map<*ast_node, vector<*ast_node>>()
// One visited set per pass, handed to run_on_tree and shared across all ASTs in name_ast_map.
var visited1 = hash_set<*ast_node>()
var visited2 = hash_set<*ast_node>()
name_ast_map->for_each(fun(name: string, syntax_ast_pair: pair<*tree<symbol>,*ast_node>) {
var helper_before = fun(node: *ast_node, parent_chain: *stack<*ast_node>) {
match(*node) {
ast_node::adt_def(backing) {
/*println(backing.name + ": transforming!")*/
type_def_option_map[node] = vector<*ast_node>()
var replacement = ast_type_def_ptr(backing.name, false);
// we're going to be replacing adt_def in the same ptr, so this works
replacement->type_def.self_type = node->adt_def.self_type
// Second argument is `true` here and `false` for the replacement above;
// presumably it marks the type_def as a union - TODO confirm against ast_type_def_ptr.
var option_union = ast_type_def_ptr(backing.name + "_union", true);
node->adt_def.options.for_each(fun(opt: *ast_node) {
// The if below has no braces, so only the union-variable add is conditional:
// empty options get no storage in the union, but EVERY option is recorded in
// type_def_option_map so that indices stay aligned with the flag values.
if (!opt->identifier.type->is_empty_adt_option())
option_union->type_def.variables.add(ast_declaration_statement_ptr(opt, null<ast_node>()))
type_def_option_map[node].add(opt)
})
var option_union_type = type_ptr(option_union)
option_union->type_def.self_type = option_union_type
// `data` member: holds the union of option payloads.
var option_union_ident = ast_identifier_ptr(string("data"), option_union_type, replacement)
replacement->type_def.variables.add(ast_declaration_statement_ptr(option_union_ident, null<ast_node>()))
add_to_scope("data", option_union_ident, replacement)
// `flag` member: integer index of the active option.
var flag = ast_identifier_ptr(string("flag"), type_ptr(base_type::integer()), replacement)
replacement->type_def.variables.add(ast_declaration_statement_ptr(flag, null<ast_node>()))
add_to_scope("flag", flag, replacement)
// The union type_def is added relative to the ADT node via add_before_in
// (presumably inserted ahead of it in its parent - TODO confirm).
add_before_in(option_union, node, parent_chain)
var enclosing_scope = node->adt_def.scope[string("~enclosing_scope")][0]
// idx counts option functions and is also used to index the option list,
// so this relies on option_funcs being in the same order as options.
var idx = 0
node->adt_def.option_funcs.for_each(fun(func: *ast_node) {
// Generated body for an option function:
//   declare to_ret of the ADT type; to_ret.flag = idx;
//   if the function has a parameter, store it into to_ret.data.<option>;
//   return to_ret
var adt_type = replacement->type_def.self_type
var block = ast_code_block_ptr()
add_to_scope("~enclosing_scope", func, block)
func->function.body_statement = block
var to_ret = ast_identifier_ptr(string("to_ret"), adt_type, block)
block->code_block.children.add(ast_declaration_statement_ptr(to_ret, null<ast_node>()))
var value = ast_value_ptr(to_string(idx), type_ptr(base_type::integer()))
block->code_block.children.add(ast_assignment_statement_ptr(make_operator_call(".", vector(to_ret, flag)), value))
var opt = type_def_option_map[node][idx]
var lvalue = make_operator_call(".", vector(make_operator_call(".", vector(to_ret, option_union_ident)), opt))
if (func->function.parameters.size) {
// do copy_construct if it should
block->code_block.children.add(assign_or_copy_construct_statement(lvalue, func->function.parameters[0]))
}
block->code_block.children.add(ast_return_statement_ptr(to_ret))
// The option function is added relative to the ADT node and registered by
// name in the ADT's enclosing scope (not as a method of the replacement).
add_before_in(func, node, parent_chain)
add_to_scope(func->function.name, func, enclosing_scope)
add_to_scope("~enclosing_scope", enclosing_scope, func)
idx++
})
node->adt_def.regular_funcs.for_each(fun(func: *ast_node) {
// Each regular function gets a fresh body chosen by the function's name;
// any name other than the six handled below triggers error().
var block = ast_code_block_ptr()
add_to_scope("~enclosing_scope", func, block)
func->function.body_statement = block
var func_this = func->function.this_param
if (func->function.name == "operator==") {
// if (this->flag != other.flag) return false
// for each non-empty option i: if (this->flag == i) return <ours equals theirs>
// return true   (reached when the active option is empty, or flag matches no option)
var other = func->function.parameters[0]
var if_stmt = ast_if_statement_ptr(make_operator_call("!=", vector(make_operator_call("->", vector(func_this, flag)), make_operator_call(".", vector(other, flag)))))
if_stmt->if_statement.then_part = ast_return_statement_ptr(ast_value_ptr(string("false"), type_ptr(base_type::boolean())))
block->code_block.children.add(if_stmt)
for (var i = 0; i < type_def_option_map[node].size; i++;) {
// Empty options carry no payload, so there is nothing to compare.
if (get_ast_type(type_def_option_map[node][i])->is_empty_adt_option())
continue
var if_stmt_inner = ast_if_statement_ptr(make_operator_call("==", vector(make_operator_call("->", vector(func_this, flag)), ast_value_ptr(to_string(i), type_ptr(base_type::integer())))))
var option = type_def_option_map[node][i]
var our_option = make_operator_call(".", vector(make_operator_call("->", vector(func_this, option_union_ident)), option))
var their_option = make_operator_call(".", vector(make_operator_call(".", vector(other, option_union_ident)), option))
if_stmt_inner->if_statement.then_part = ast_return_statement_ptr(possible_object_equality(our_option, their_option))
block->code_block.children.add(if_stmt_inner)
}
block->code_block.children.add(ast_return_statement_ptr(ast_value_ptr(string("true"), type_ptr(base_type::boolean()))))
} else if (func->function.name == "operator!=") {
// return !(this operator== other)
var other = func->function.parameters[0]
block->code_block.children.add(ast_return_statement_ptr(make_operator_call("!", vector(make_method_call(func_this, "operator==", vector(other))))))
} else if (func->function.name == "construct") {
// this->flag = -1 (a value that matches no option index); return this
var value = ast_value_ptr(string("-1"), type_ptr(base_type::integer()))
block->code_block.children.add(ast_assignment_statement_ptr(make_operator_call("->", vector(func_this, flag)), value))
block->code_block.children.add(ast_return_statement_ptr(func_this))
} else if (func->function.name == "copy_construct") {
// this->flag = other->flag, then copy the payload of whichever non-empty
// option is active. `other` is accessed with "->" here (unlike operator==).
var other = func->function.parameters[0]
block->code_block.children.add(ast_assignment_statement_ptr(make_operator_call("->", vector(func_this, flag)), make_operator_call("->", vector(other, flag))))
for (var i = 0; i < type_def_option_map[node].size; i++;) {
if (get_ast_type(type_def_option_map[node][i])->is_empty_adt_option())
continue
var if_stmt_inner = ast_if_statement_ptr(make_operator_call("==", vector(make_operator_call("->", vector(func_this, flag)), ast_value_ptr(to_string(i), type_ptr(base_type::integer())))))
var option = type_def_option_map[node][i]
var our_option = make_operator_call(".", vector(make_operator_call("->", vector(func_this, option_union_ident)), option))
var their_option = make_operator_call(".", vector(make_operator_call("->", vector(other, option_union_ident)), option))
if_stmt_inner->if_statement.then_part = assign_or_copy_construct_statement(our_option, their_option)
block->code_block.children.add(if_stmt_inner)
}
} else if (func->function.name == "operator=") {
// Generated body: call destruct on this, then copy_construct with &other.
// NOTE(review): the generated code has no self-assignment guard - confirm that is acceptable.
var other = func->function.parameters[0]
block->code_block.children.add(make_method_call(func_this, "destruct", vector<*ast_node>()))
block->code_block.children.add(make_method_call(func_this, "copy_construct", vector(make_operator_call("&", vector(other)))))
} else if (func->function.name == "destruct") {
// For each non-empty option stored by value (indirection == 0) whose type is an
// object with a zero-argument destruct method: if (this->flag == i) destruct it.
for (var i = 0; i < type_def_option_map[node].size; i++;) {
var option = type_def_option_map[node][i]
var option_type = get_ast_type(option)
if (option_type->is_empty_adt_option())
continue
if (option_type->indirection == 0 && option_type->is_object() && has_method(option_type->type_def, "destruct", vector<*type>())) {
var if_stmt_inner = ast_if_statement_ptr(make_operator_call("==", vector(make_operator_call("->", vector(func_this, flag)), ast_value_ptr(to_string(i), type_ptr(base_type::integer())))))
var our_option = make_operator_call(".", vector(make_operator_call("->", vector(func_this, option_union_ident)), option))
if_stmt_inner->if_statement.then_part = make_method_call(our_option, "destruct", vector<*ast_node>())
block->code_block.children.add(if_stmt_inner)
}
}
} else error("impossible adt method")
// Regular functions become methods of the replacement type_def.
replacement->type_def.methods.add(func)
add_to_scope(func->function.name, func, replacement)
add_to_scope("~enclosing_scope", replacement, func)
})
add_to_scope("~enclosing_scope", enclosing_scope, option_union)
add_to_scope("~enclosing_scope", enclosing_scope, replacement)
// Overwrite the adt_def node's contents so existing pointers to it now see the type_def.
*node = *replacement
}
}
}
run_on_tree(helper_before, empty_pass_second_half(), syntax_ast_pair.second, &visited1)
})
name_ast_map->for_each(fun(name: string, syntax_ast_pair: pair<*tree<symbol>,*ast_node>) {
var second_helper = fun(node: *ast_node, parent_chain: *stack<*ast_node>) {
match(*node) {
ast_node::match_statement(backing) {
// Replace the match statement with a code block:
//   holder = &<matched value>
//   one `if (holder->flag == <option index>) { ... }` per case
var block = ast_code_block_ptr()
// The new block's enclosing scope is the nearest code block or function in the parent chain.
add_to_scope("~enclosing_scope", parent_chain->item_from_top_satisfying(fun(i: *ast_node): bool return is_code_block(i) || is_function(i);), block)
var value = backing.value
// holder has the matched value's type with one more level of indirection
// and is assigned the address of the value.
var holder = ast_identifier_ptr(string("holder"), get_ast_type(value)->clone_with_increased_indirection(), block)
block->code_block.children.add(ast_declaration_statement_ptr(holder, null<ast_node>()))
block->code_block.children.add(ast_assignment_statement_ptr(holder, make_operator_call("&", vector(value))))
backing.cases.for_each(fun(case_stmt: *ast_node) {
var option = case_stmt->case_statement.option
// NOTE(review): flag/data are looked up on the value's type_def before the
// "non-adt" check below runs - confirm get_from_scope is safe for non-ADT types.
var flag = get_from_scope(get_ast_type(value)->type_def, "flag")
var data = get_from_scope(get_ast_type(value)->type_def, "data")
// -7 is the "not found" initial value. NOTE(review): if no option matches, it is
// left in place and emitted into the condition below without any error here -
// confirm earlier passes guarantee the case option belongs to this ADT.
var option_num = -7
if (!type_def_option_map.contains_key(get_ast_type(value)->type_def))
error("trying to match on non-adt")
// Find the index of this case's option by pointer equality (no break, so the last match wins).
for (var i = 0; i < type_def_option_map[get_ast_type(value)->type_def].size; i++;)
if (type_def_option_map[get_ast_type(value)->type_def][i] == option)
option_num = i;
var condition = make_operator_call("==", vector(make_operator_call("->", vector(holder, flag)), ast_value_ptr(to_string(option_num), type_ptr(base_type::integer()))))
var if_stmt = ast_if_statement_ptr(condition)
var inner_block = ast_code_block_ptr()
add_to_scope("~enclosing_scope", block, inner_block)
var unpack_ident = case_stmt->case_statement.unpack_ident
if (unpack_ident) {
// The unpack identifier's type is changed to its clone_with_ref() form and it is
// declared with &(holder->data.<option>) as its initial value.
var get_option = make_operator_call(".", vector(make_operator_call("->", vector(holder, data)), option))
get_option = make_operator_call("&", vector(get_option))
unpack_ident->identifier.type = unpack_ident->identifier.type->clone_with_ref()
inner_block->code_block.children.add(ast_declaration_statement_ptr(unpack_ident, get_option))
}
inner_block->code_block.children.add(case_stmt->case_statement.statement)
if_stmt->if_statement.then_part = inner_block
block->code_block.children.add(if_stmt)
})
// Overwrite the match_statement node in place with the generated block.
*node = *block
}
ast_node::function_call(backing) {
// Rewrites `obj.x` / `obj->x` into `obj.data.x` / `obj->data.x` when x is an
// identifier that is not one of the object's own type_def variables.
// NOTE(review): this applies to any object type, and get_from_scope(..., "data")
// is called without checking the type came from an ADT - confirm that non-variable
// identifiers on the right of . / -> can only occur for ADT options.
if (is_function(backing.func) && (backing.func->function.name == "." || backing.func->function.name == "->")) {
var left_type = get_ast_type(backing.parameters[0])
if (left_type->is_object() && is_identifier(backing.parameters[1])) {
// A direct member (e.g. `flag` or `data` on a lowered ADT) is left untouched.
for (var i = 0; i < left_type->type_def->type_def.variables.size; i++;)
if (left_type->type_def->type_def.variables[i]->declaration_statement.identifier == backing.parameters[1])
return;
/*println(backing.parameters[1]->identifier.name + " getting, adt . or -> call!")*/
var object = node->function_call.parameters[0]
// The original operator (. or ->) is kept for reaching `data`; the outer call
// then becomes the builtin "." from data's type to the option's type.
node->function_call.parameters[0] = make_operator_call(backing.func->function.name, vector(object, get_from_scope(left_type->type_def, "data")))
node->function_call.func = get_builtin_function(".", vector(get_ast_type(get_from_scope(left_type->type_def, "data")), get_ast_type(backing.parameters[1])))
}
}
}
}
}
run_on_tree(second_helper, empty_pass_second_half(), syntax_ast_pair.second, &visited2)
})
}