Files
kraken/stdlib/function_value_lower.krak
2018-05-22 19:43:54 -04:00

308 lines
18 KiB
Plaintext

import symbol:*
import tree:*
import vec:*
import map:*
import util:*
import type:*
import str:*
import mem:*
import io:*
import ast_nodes:*
import ast_transformation:*
import hash_set:*
import os:*
import pass_common:*
// Records one AST location that function_value_lower needs to rewrite later,
// together with the surrounding nodes required to perform that rewrite.
obj function_parent_block {
// the function node itself (creation points) or the callee expression (call points)
var function: *ast_node
// direct parent of `function`; for call points this is the function_call node
var parent: *ast_node
// nearest enclosing code block on the parent chain; null for call points
var parent_block: *ast_node
// nearest enclosing function on the parent chain; null for call points
var parent_function: *ast_node
}
// Builds a function_parent_block by copying each argument into the field of the same name.
fun make_function_parent_block(function: *ast_node, parent: *ast_node, parent_block: *ast_node, parent_function: *ast_node): function_parent_block {
    var bundle: function_parent_block
    // the four stores are independent of one another
    bundle.parent_function = parent_function
    bundle.parent_block = parent_block
    bundle.parent = parent
    bundle.function = function
    return bundle
}
// Collects the identifier nodes reachable from `node` whose enclosing scope is
// not inside `func`, i.e. the variables that `func` closes over.
//   func: the function (lambda) whose closed-over variables are being computed
//   node: the subtree to inspect; a null node yields an empty set
// Returns the set of identifier nodes found. Node kinds without a case below
// contribute nothing.
// Calls error() if one of the compiler's temporary return variables would be
// closed over.
fun find_closed_variables(func: *ast_node, node: *ast_node): set<*ast_node> {
    if (!node) return set<*ast_node>()
    match (*node) {
        ast_node::identifier(backing) {
            // an identifier declared somewhere outside of func is a closed-over variable
            if (!in_scope_chain(backing.enclosing_scope, func)) {
                if (backing.name == "temporary_return_boomchaka" ||
                    backing.name == "temp_boom_return")
                    error("trying to close over temp return")
                else
                    return set(node);
            }
        }
        ast_node::code_block(backing) {
            var to_ret = set<*ast_node>()
            backing.children.for_each(fun(n: *ast_node) to_ret += find_closed_variables(func, n);)
            return to_ret
        }
        ast_node::function_call(backing) {
            // for member access ("." / "->") only the first parameter is inspected;
            // the remaining parameter(s) are skipped
            if (is_function(backing.func) && (backing.func->function.name == "." || backing.func->function.name == "->"))
                return find_closed_variables(func, backing.parameters.first())
            var to_ret = find_closed_variables(func, backing.func)
            backing.parameters.for_each(fun(n: *ast_node) to_ret += find_closed_variables(func, n);)
            return to_ret
        }
        ast_node::function(backing) {
            // if this is a lambda, we need to check all of the things it closes over
            // (uses the nested function's closed_variables, not its body)
            var to_ret = set<*ast_node>()
            backing.closed_variables.for_each(fun(n: *ast_node) to_ret += find_closed_variables(func, n);)
            return to_ret
        }
        ast_node::return_statement(backing) return find_closed_variables(func, backing.return_value)
        ast_node::if_statement(backing) return find_closed_variables(func, backing.condition) + find_closed_variables(func, backing.then_part) + find_closed_variables(func, backing.else_part)
        ast_node::match_statement(backing) {
            var to_ret = set<*ast_node>()
            backing.cases.for_each(fun(n: *ast_node) to_ret += find_closed_variables(func, n);)
            return to_ret
        }
        ast_node::case_statement(backing) return find_closed_variables(func, backing.statement)
        ast_node::while_loop(backing) return find_closed_variables(func, backing.condition) + find_closed_variables(func, backing.statement)
        ast_node::for_loop(backing) {
            return find_closed_variables(func, backing.init) + find_closed_variables(func, backing.condition) +
                   find_closed_variables(func, backing.update) + find_closed_variables(func, backing.body)
        }
        // (a second, identical return_statement case used to sit here; it was redundant and has been removed)
        ast_node::defer_statement(backing) return find_closed_variables(func, backing.statement)
        ast_node::assignment_statement(backing) return find_closed_variables(func, backing.to) + find_closed_variables(func, backing.from)
        ast_node::declaration_statement(backing) return find_closed_variables(func, backing.expression) + find_closed_variables(func, backing.init_method_call)
        ast_node::if_comp(backing) return find_closed_variables(func, backing.statement)
        ast_node::cast(backing) return find_closed_variables(func, backing.value)
    }
    // identifiers that are inside func, and every node kind not handled above
    return set<*ast_node>()
}
// Reports whether `high_scope` is `node` itself or one of its enclosing scopes.
// Follows the first "~enclosing_scope" entry of each scope upwards until either
// `high_scope` is reached (true) or a scope has no such entry (false).
fun in_scope_chain(node: *ast_node, high_scope: *ast_node): bool {
    var enclosing_key = str("~enclosing_scope")
    var current = node
    while (current != high_scope) {
        var scope = get_ast_scope(current)
        // ran out of enclosing scopes without meeting high_scope
        if (!scope->contains_key(enclosing_key))
            return false
        current = scope->get(enclosing_key)[0]
    }
    return true
}
// Lowers function values (lambdas and closures) across every translation unit in name_ast_map.
// The stages run in this order, each one timed with split():
//   1. compute closed_variables for every lambda
//   2. walk every AST, recording node types, uses of closed-over identifiers,
//      function-value creation points and function-value call points
//   3. for every non-raw function type, generate a "<type>_function_value_struct"
//      type def (fields func, func_closure, data) and a "lambda_call" dispatch function
//   4. for every lambda, generate an optional "closure_struct_<n>" type def and a
//      "<lambda name>_creation" function that fills in the function value struct
//   5. rewrite call points to go through lambda_call and creation points to call the
//      creation function
//   6. overwrite the recorded function types with their struct types, rewrite the
//      closed-over identifier uses to go through the closure parameter, and mark the
//      lambdas' types raw
// All generated definitions are attached to the first translation unit in name_ast_map.
// The ast_to_syntax parameter is not referenced in this function.
fun function_value_lower(name_ast_map: *map<str, pair<*tree<symbol>,*ast_node>>, ast_to_syntax: *map<*ast_node, *tree<symbol>>) {
var curr_time = get_time()
// shared across all run_on_tree calls below
var visited = hash_set<*ast_node>()
var lambdas = set<*ast_node>()
name_ast_map->for_each(fun(name: str, syntax_ast_pair: pair<*tree<symbol>,*ast_node>) {
lambdas.add(syntax_ast_pair.second->translation_unit.lambdas)
// do in order so that inner lambdas are done before outer ones, so enclosed
// variables can propagate outwards
syntax_ast_pair.second->translation_unit.lambdas.for_each(fun(n: *ast_node) {
n->function.closed_variables = find_closed_variables(n, n->function.body_statement)
})
})
var all_types = hash_set<*type>()
var function_value_creation_points = vec<function_parent_block>()
var function_value_call_points = vec<function_parent_block>()
// each entry is (identifier use, (its direct parent, the lambda that closes over it))
var closed_over_uses = vec<pair<*ast_node, pair<*ast_node, *ast_node>>>()
name_ast_map->for_each(fun(name: str, syntax_ast_pair: pair<*tree<symbol>,*ast_node>) {
var helper_before = fun(node: *ast_node, parent_chain: *stack<*ast_node>) {
// remember every type seen so function types can be replaced by struct types later
var t = get_ast_type(node)
if (t) all_types.add(t)
match(*node) {
ast_node::identifier(backing) {
// see if this identifier use is a closed variable in a closure
var enclosing_func = parent_chain->item_from_top_satisfying_or(fun(n: *ast_node): bool return is_function(n);, null<ast_node>())
if (enclosing_func && enclosing_func->function.closed_variables.contains(node)) {
closed_over_uses.add(make_pair(node, make_pair(parent_chain->top(), enclosing_func)))
}
}
ast_node::function(backing) {
var parent = parent_chain->top()
// need to use function value if
// it isn't a regular function definition (or lambda top reference)
var need_done = !is_translation_unit(parent) && !backing.type->is_raw
if (need_done) {
// record the function together with its parent, nearest code block and nearest function
function_value_creation_points.add(make_function_parent_block(node, parent_chain->top(),
parent_chain->item_from_top_satisfying(fun(i: *ast_node): bool return is_code_block(i);),
parent_chain->item_from_top_satisfying(fun(i: *ast_node): bool return is_function(i);)
))
}
}
ast_node::function_call(backing) {
// a call through a non-raw function type is a call on a function value;
// here `function` is the callee and `parent` is the call node itself
if (!get_ast_type(backing.func)->is_raw)
function_value_call_points.add(make_function_parent_block(backing.func, node, null<ast_node>(), null<ast_node>()))
}
}
}
run_on_tree(helper_before, empty_pass_second_half(), syntax_ast_pair.second, &visited)
})
curr_time = split(curr_time, "\tclosed_over_uses + function_value_call_points")
var void_ptr = type_ptr(base_type::void_return(), 1)
// maps a function type (by value) to its generated struct type and lambda_call function
var lambda_type_to_struct_type_and_call_func = map<type, pair<*type, *ast_node>>(); //freaking vexing parse moved
// extend all_types with the parameter and return types of every function type
// (presumably chaotic_closure repeats this until nothing new is added - TODO confirm)
all_types.chaotic_closure(fun(t: *type): set<*type> {
if (t->is_function())
return from_vector(t->parameter_types + t->return_type)
return set<*type>()
})
// copies of the types by value; taken before any of the pointed-to types are overwritten below
var all_type_values = all_types.map(fun(t: *type): type return *t;)
curr_time = split(curr_time, "\tall types/all type values")
all_type_values.for_each(fun(t: type) {
// only non-raw, non-pointer function types that have not been handled yet
if (t.is_function() && t.indirection == 0 && !t.is_raw && !lambda_type_to_struct_type_and_call_func.contains_key(t)) {
// raw version of the function type, used for the plain "func" field
var cleaned = t.clone()
cleaned->is_raw = true
var new_type_def_name = t.to_string() + "_function_value_struct"
var new_type_def = ast_type_def_ptr(new_type_def_name)
var func_ident = ast_identifier_ptr("func", cleaned, new_type_def)
add_to_scope("func", func_ident, new_type_def)
// same signature as func, with a void pointer inserted as the first parameter
var func_closure_type = cleaned->clone()
func_closure_type->parameter_types.add(0, type_ptr(base_type::void_return(), 1))
var func_closure_ident = ast_identifier_ptr("func_closure", func_closure_type, new_type_def)
add_to_scope("func_closure", func_closure_ident, new_type_def)
var data_ident = ast_identifier_ptr("data", void_ptr, new_type_def)
add_to_scope("data", data_ident, new_type_def)
new_type_def->type_def.variables.add(ast_declaration_statement_ptr(func_ident, null<ast_node>()))
new_type_def->type_def.variables.add(ast_declaration_statement_ptr(func_closure_ident, null<ast_node>()))
new_type_def->type_def.variables.add(ast_declaration_statement_ptr(data_ident, null<ast_node>()))
// the struct is registered in, and attached to, the first translation unit
add_to_scope("~enclosing_scope", name_ast_map->values.first().second, new_type_def)
add_to_scope(new_type_def_name, new_type_def, name_ast_map->values.first().second)
name_ast_map->values.first().second->translation_unit.children.add(new_type_def)
var lambda_struct_type = type_ptr(new_type_def)
// lambda_call takes the struct followed by the original parameters
var lambda_call_type = type_ptr(vec(lambda_struct_type) + t.parameter_types, t.return_type, 0, false, false, true)
// create parameters
var lambda_call_func_param = ast_identifier_ptr("func_struct", lambda_struct_type, null<ast_node>())
var lambda_call_parameters = vec(lambda_call_func_param) + cleaned->parameter_types.map(fun(t:*type): *ast_node {
return ast_identifier_ptr("pass_through_param", t, null<ast_node>())
})
var lambda_call_function = ast_function_ptr(str("lambda_call"), lambda_call_type, lambda_call_parameters, false)
// create call body with if, etc
// generated body: if (func_struct.data) return func_struct.func_closure(func_struct.data, <params>)
//                 else                  return func_struct.func(<params>)
// slice(1,-1) is presumably "everything after func_struct" - TODO confirm slice semantics
var if_statement = ast_if_statement_ptr(access_expression(lambda_call_func_param, "data"))
lambda_call_function->function.body_statement = ast_code_block_ptr(if_statement)
if_statement->if_statement.then_part = ast_code_block_ptr(ast_return_statement_ptr(ast_function_call_ptr(access_expression(lambda_call_func_param, "func_closure"),
vec(access_expression(lambda_call_func_param, "data")) + lambda_call_parameters.slice(1,-1))))
if_statement->if_statement.else_part = ast_code_block_ptr(ast_return_statement_ptr(ast_function_call_ptr(access_expression(lambda_call_func_param, "func"),
lambda_call_parameters.slice(1,-1))))
lambda_type_to_struct_type_and_call_func[t] = make_pair(lambda_struct_type, lambda_call_function)
// we have to add it for t and *cleaned since we might get either (we make the lambda's type raw later, so if used at creation point will be cleaned...)
// NOPE does this for other functions not lambdas super wrong
/*lambda_type_to_struct_type_and_call_func[*cleaned] = make_pair(lambda_struct_type, lambda_call_function)*/
// NOTE(review): new_type_def was already added to these children above, so it ends up
// in the list twice - confirm whether the duplicate is intended or harmless
name_ast_map->values.first().second->translation_unit.children.add(new_type_def)
name_ast_map->values.first().second->translation_unit.children.add(lambda_call_function)
}
})
curr_time = split(curr_time, "\tall type values forEach")
// maps each lambda to its generated "<name>_creation" function
var lambda_creation_funcs = map<*ast_node, *ast_node>()
// create the closure type for each lambda
var closure_id = 0
lambdas.for_each(fun(l: *ast_node) {
// only assigned (and only read) when the lambda has closed variables
var closure_struct_type: *type
if (l->function.closed_variables.size()) {
var new_type_def_name = str("closure_struct_") + closure_id++
var new_type_def = ast_type_def_ptr(new_type_def_name)
// one field per closed variable, named after the variable, whose type is the
// variable's type with one more level of indirection
l->function.closed_variables.for_each(fun(v: *ast_node) {
var closed_var_type = v->identifier.type
// a closed-over function value is stored using its struct type
if (lambda_type_to_struct_type_and_call_func.contains_key(*closed_var_type))
closed_var_type = lambda_type_to_struct_type_and_call_func[*closed_var_type].first
var closed_ident = ast_identifier_ptr(v->identifier.name, closed_var_type->clone_with_increased_indirection(), new_type_def)
new_type_def->type_def.variables.add(ast_declaration_statement_ptr(closed_ident, null<ast_node>()))
add_to_scope(v->identifier.name, closed_ident, new_type_def)
})
add_to_scope("~enclosing_scope", name_ast_map->values.first().second, new_type_def)
add_to_scope(new_type_def_name, new_type_def, name_ast_map->values.first().second)
name_ast_map->values.first().second->translation_unit.children.add(new_type_def)
// the closure is passed around with one more level of indirection than the struct type
closure_struct_type = type_ptr(new_type_def)->clone_with_increased_indirection()
}
// NOTE(review): looked up without a contains_key check - assumes every lambda's type
// got an entry in the map above; confirm
var return_type = lambda_type_to_struct_type_and_call_func[*l->function.type].first
var creation_type = type_ptr(vec<*type>(), return_type, 0, false, false, true)
lambda_creation_funcs[l] = ast_function_ptr(l->function.name + "_creation", creation_type, vec<*ast_node>(), false);
// creation function body: declare to_ret, set func and func_closure to the lambda,
// set data, return to_ret
var body = ast_code_block_ptr()
var ident = ast_identifier_ptr("to_ret", return_type, body)
body->code_block.children.add(ast_declaration_statement_ptr(ident, null<ast_node>()))
body->code_block.children.add(ast_assignment_statement_ptr(access_expression(ident, "func"), l))
body->code_block.children.add(ast_assignment_statement_ptr(access_expression(ident, "func_closure"), l))
if (l->function.closed_variables.size()) {
// the lambda itself receives the closure as a new first parameter
var closure_lambda_param = ast_identifier_ptr("closure_data_pass", closure_struct_type, l)
l->function.parameters.add(0, closure_lambda_param)
// the creation function takes the closure first, then one parameter per closed variable
var closure_param = ast_identifier_ptr("closure", closure_struct_type, body)
lambda_creation_funcs[l]->function.parameters.add(closure_param)
body->code_block.children.add(ast_assignment_statement_ptr(access_expression(ident, "data"), closure_param))
l->function.closed_variables.for_each(fun(v: *ast_node) {
// have to make sure to clean here as well
var closed_param_type = v->identifier.type
if (lambda_type_to_struct_type_and_call_func.contains_key(*closed_param_type))
closed_param_type = lambda_type_to_struct_type_and_call_func[*closed_param_type].first
var closed_param = ast_identifier_ptr("closed_param", closed_param_type->clone_with_increased_indirection(), l)
lambda_creation_funcs[l]->function.parameters.add(closed_param)
// store the passed-in parameter into the closure field named after the variable
body->code_block.children.add(ast_assignment_statement_ptr(access_expression(closure_param, v->identifier.name), closed_param))
})
} else {
// no closure: data is set to the value "0", which makes lambda_call take its else branch
body->code_block.children.add(ast_assignment_statement_ptr(access_expression(ident, "data"), ast_value_ptr(str("0"), type_ptr(base_type::void_return(), 1))))
}
body->code_block.children.add(ast_return_statement_ptr(ident))
lambda_creation_funcs[l]->function.body_statement = body
name_ast_map->values.first().second->translation_unit.children.add(lambda_creation_funcs[l])
})
curr_time = split(curr_time, "\tlambdas forEach")
// turn every call on a function value into lambda_call(<function value>, <original params>)
function_value_call_points.for_each(fun(p: function_parent_block) {
// parent is the function call
var function_struct = p.function
p.parent->function_call.func = lambda_type_to_struct_type_and_call_func[*get_ast_type(p.function)].second
p.parent->function_call.parameters.add(0, function_struct)
})
curr_time = split(curr_time, "\tfunction_value_call_points.forEach")
// replace every lambda expression with a call to its creation function
function_value_creation_points.for_each(fun(p: function_parent_block) {
var lambda_creation_params = vec<*ast_node>()
// add the declaration of the closure struct to the enclosing code block
if (p.function->function.closed_variables.size()) {
// pull closure type off lambda creation func parameter
var closure_type = get_ast_type(lambda_creation_funcs[p.function]->function.parameters[0])->clone_with_decreased_indirection()
var closure_struct_ident = ast_identifier_ptr("closure_struct", closure_type, p.parent_block)
// the declaration is inserted at the front of the enclosing block
p.parent_block->code_block.children.add(0,ast_declaration_statement_ptr(closure_struct_ident, null<ast_node>()))
lambda_creation_params.add(make_operator_call("&", vec(closure_struct_ident)))
p.function->function.closed_variables.for_each(fun(v: *ast_node) {
var addr_of = make_operator_call("&", vec(v))
// if the enclosing function closes over v too, this use of v must also be
// rewritten to go through the enclosing function's closure
if (p.parent_function->function.closed_variables.contains(v)) {
closed_over_uses.add(make_pair(v, make_pair(addr_of, p.parent_function)))
}
lambda_creation_params.add(addr_of)
})
}
var func_call = ast_function_call_ptr(lambda_creation_funcs[p.function], lambda_creation_params)
replace_with_in(p.function, func_call, p.parent)
})
curr_time = split(curr_time, "\tfunction_value_creation_points.forEach")
// give each lambda its own copy of its type before the recorded types are overwritten
// (presumably so the overwrite below and the later is_raw change do not affect each other - TODO confirm)
lambdas.for_each(fun(l: *ast_node) l->function.type = l->function.type->clone();)
// overwrite, in place, every recorded type that has a struct mapping with that struct type
all_types.for_each(fun(t: *type) {
if (lambda_type_to_struct_type_and_call_func.contains_key(*t))
*t = *lambda_type_to_struct_type_and_call_func[*t].first
})
curr_time = split(curr_time, "\tlambdas.for_each")
// replace each closed-over identifier use with *(<closure param>.<variable name>),
// where the closure param is the lambda's first parameter
closed_over_uses.for_each(fun(p: pair<*ast_node, pair<*ast_node, *ast_node>>) {
var variable = p.first
var parent = p.second.first
var lambda = p.second.second
var closure_param = lambda->function.parameters[0]
replace_with_in(variable, make_operator_call("*", vec(access_expression(closure_param, variable->identifier.name))), parent)
})
curr_time = split(curr_time, "\tclosed_over_uses")
// now we can make them raw
lambdas.for_each(fun(l: *ast_node) {
l->function.type->is_raw = true;
})
curr_time = split(curr_time, "\tlambdas is raw")
}