Files
kraken/stdlib/pass_common.krak
2018-03-21 00:00:06 -04:00

516 lines
26 KiB
Plaintext

import ast_nodes:*
import mem:*
import util:*
import vector:*
import stack:*
import string:*
import hash_set:*
// Follows the first-child chain starting at `source` and returns the first
// node whose symbol is flagged as a terminal. Returns a null pointer when
// `source` is null, when a null first child is encountered, or when the chain
// ends at a childless non-terminal.
fun get_first_terminal(source: *tree<symbol>): *tree<symbol> {
    var current = source
    while (current) {
        if (current->data.terminal)
            return current
        if (current->children.size == 0)
            return null<tree<symbol>>()
        // only the first child is ever followed
        current = current->children.first()
    }
    return null<tree<symbol>>()
}
// Convenience overload: wraps the raw character message in a string and
// forwards to the string-taking overload.
fun error(source: *tree<symbol>, message: *char) {
    error(source, string(message))
}
// Reports `message` as an error, adding parse-tree context when it is available.
// source:  parse tree node the error relates to; may have no terminal beneath it
// message: text describing the problem
fun error(source: *tree<symbol>, message: string) {
var first = get_first_terminal(source)
// When a terminal is found, the report is prefixed with the concatenated text of
// `source` plus the terminal's source and position fields.
if (first)
error("***error |" + concat_symbol_tree(source) + "| *** " + first->data.source + ": " + first->data.position + " " + message)
// Reached when no terminal was found.
// NOTE(review): this line also runs after the call above if the one-argument
// error() returns - presumably it does not return; confirm.
error(message)
}
// Returns true when `method` belongs to the type_def `enclosing_object`:
// either it is listed directly among the type's methods, or one of the listed
// methods is a template whose instantiations contain it.
fun method_in_object(method: *ast_node, enclosing_object: *ast_node): bool {
    var candidates = enclosing_object->type_def.methods
    for (var idx = 0; idx < candidates.size; idx++;) {
        var candidate = candidates[idx]
        // direct member
        if (candidate == method)
            return true
        // instantiation of a templated member
        if (is_template(candidate) && candidate->template.instantiated.contains(method))
            return true
    }
    return false
}
// Decides whether the function_call `node` has the shape `obj.method(...)` or
// `obj->method(...)`: its callee must itself be a call to the "." or "->"
// function whose second parameter is a function, and that function must be
// enclosed by a type_def (directly, or one scope further out for templated
// methods) or by an adt_def.
fun is_dot_style_method_call(node: *ast_node): bool {
    var access_call = node->function_call.func
    if (!is_function_call(access_call))
        return false
    var access_op = access_call->function_call.func
    if (!is_function(access_op))
        return false
    if (!(access_op->function.name == "->" || access_op->function.name == "."))
        return false
    var callee = access_call->function_call.parameters[1]
    if (!is_function(callee))
        return false
    var enclosing = get_ast_scope(callee)->get(string("~enclosing_scope"))[0]
    // plain method of an object
    if (is_type_def(enclosing))
        return true
    // or if it's a templated method (yes, this has gotten uuuuugly)
    if (is_type_def(get_ast_scope(enclosing)->get(string("~enclosing_scope"))[0]))
        return true
    // or it's in an adt
    return is_adt_def(enclosing)
    // should get uglier when we have to figure out if it's just an inside lambda
}
// Checks whether the function-typed node `node` can be called with arguments
// of the given `param_types`.
// node:        node whose type (via get_ast_type) carries parameter_types and is_variadic
// param_types: types of the arguments at the call site
// Returns true when the arity is acceptable and every declared parameter type
// equals the matching argument type; the `false` passed to equality() makes
// the comparison ignore references.
fun function_satisfies_params(node: *ast_node, param_types: vector<*type>): bool {
    var func_type = get_ast_type(node)
    var func_param_types = func_type->parameter_types
    // a non-variadic function needs exactly as many arguments as parameters
    if (!func_type->is_variadic && func_param_types.size != param_types.size)
        return false
    // a variadic function still needs at least its declared parameters
    if (param_types.size < func_param_types.size)
        return false
    // note we iterate over the func_param_types which will stop short if function is variadic
    // just like we want
    for (var j = 0; j < func_param_types.size; j++;) {
        // don't care about references
        if (!func_param_types[j]->equality(param_types[j], false)) {
            // two types that print the same but compare unequal are reported as an error
            if (func_param_types[j]->to_string() == param_types[j]->to_string())
                error(string("types aren't equal, but their string rep is (and ref doesn't even matter): ") + func_param_types[j]->to_string() + " vs " + param_types[j]->to_string() )
            return false
        }
    }
    return true
}
// Convenience overload: converts the member name to a string and forwards to
// the string-taking overload.
fun access_expression(left: *ast_node, right: *char): *ast_node {
    return access_expression(left, string(right))
}
// Builds a member-access operator call for `left.right` / `left->right`.
// left:  expression whose type must be an object type
// right: name of the member, looked up in the type_def of `left`'s type
// Uses "->" when the type of `left` has non-zero indirection and "." otherwise.
// Calls error() when `left` is not an object or the member cannot be found.
fun access_expression(left: *ast_node, right: ref string): *ast_node {
    var left_type = get_ast_type(left)
    if (!left_type->is_object())
        error("Trying to generate an access expression and the left side is not an object")
    var member = identifier_lookup(right, left_type->type_def)
    if (!member)
        error("Trying to generate an access expression, can't find right: " + right)
    var access_op = "."
    if (left_type->indirection)
        access_op = "->"
    return make_operator_call(access_op, vector(left, member))
}
// Returns the first result of looking up `name` from `scope` that can be
// called with `param_types`, or null when there is none.
// A result is considered callable when it is a function node, or an
// identifier whose type is a function type; it must additionally pass
// function_satisfies_params.
fun function_lookup(name: string, scope: *ast_node, param_types: vector<*type>): *ast_node {
    var candidates = scope_lookup(name, scope)
    for (var idx = 0; idx < candidates.size; idx++;) {
        var candidate = candidates[idx]
        var callable = is_function(candidate) || (is_identifier(candidate) && get_ast_type(candidate)->is_function())
        if (callable && function_satisfies_params(candidate, param_types))
            return candidate
    }
    return null<ast_node>()
}
// Looks `name` up starting from `scope` and returns the first result, or a
// null pointer when the lookup yields nothing.
fun identifier_lookup(name: ref string, scope: *ast_node): *ast_node {
    var matches = scope_lookup(name, scope)
    if (matches.size)
        return matches[0]
    return null<ast_node>()
}
// Resolves a possibly qualified name (segments separated by "::") starting at `scope`.
// name:  the name to resolve, e.g. a single identifier or "a::b"
// scope: node whose scope the lookup starts from
// Returns every node the final segment resolves to, without duplicates; the
// vector is empty when any segment resolves to nothing.
fun scope_lookup(name: ref string, scope: *ast_node): vector<*ast_node> {
// println("*****Doing a name lookup for*****")
// println(name)
// `results` holds the scopes the next segment is searched in; it starts as just `scope`
var results = vector(scope)
name.split("::").for_each(fun(i: string) {
// println(string("based on split, looking up: ") + i)
var next_results = vector<*ast_node>()
results.for_each(fun(s: *ast_node) {
// print("looking in scope: ")
// println(s)
// every helper call starts with a fresh, empty visited set
scope_lookup_helper(i, s, set<*ast_node>()).for_each(fun (result: *ast_node) {
// keep each node only once
if (!next_results.contains(result))
next_results.add(result)
})
})
// what this segment resolved to becomes the search space for the next one
results = next_results
})
return results
}
// Collects every node bound to the single, unqualified `name` that is visible
// from `scope`. Three places are searched, in this order: the scope's own
// entries, its "~enclosing_scope" (recursively), and - when `scope` is a
// translation unit - the translation units of imports that either name the
// symbol explicitly or are starred.
// visited: scopes already searched; a scope found in it contributes nothing.
// Results may contain duplicates; the caller is expected to filter them.
fun scope_lookup_helper(name: ref string, scope: *ast_node, visited: set<*ast_node>): vector<*ast_node> {
    // need to do properly scoped lookups
    var found = vector<*ast_node>()
    // prevent re-checking the same one...
    if (visited.contains(scope))
        return found
    visited.add(scope)
    var scope_map = get_ast_scope(scope)
    if (scope_map->contains_key(name))
        found += scope_map->get(name)
    var enclosing_key = string("~enclosing_scope")
    if (scope_map->contains_key(enclosing_key))
        found += scope_lookup_helper(name, scope_map->get(enclosing_key)[0], visited)
    if (is_translation_unit(scope)) {
        scope->translation_unit.children.for_each(fun(child: *ast_node) {
            // follow an import when it lists the name or is an `import x:*`
            if (is_import(child) && (child->import.imported.contains(name) || child->import.starred))
                found += scope_lookup_helper(name, child->import.translation_unit, visited)
        })
    }
    return found
}
// Convenience overload: converts the method name to a string and forwards to
// the string-taking overload.
fun has_method(object: *ast_node, name: *char, parameter_types: vector<*type>): bool {
    return has_method(object, string(name), parameter_types)
}
// Reports whether function_lookup, started from the scope of `object`, finds
// something named `name` that accepts `parameter_types`.
fun has_method(object: *ast_node, name: string, parameter_types: vector<*type>): bool {
// function_lookup returns a pointer; `|| false` yields a boolean from it
// (a null result gives false)
var to_ret = function_lookup(name, object, parameter_types) || false
return to_ret
}
// Convenience overload: converts the member name to a string and forwards to
// the string-taking overload.
fun get_from_scope(node: *ast_node, member: *char): *ast_node {
    return get_from_scope(node, string(member))
}
// Returns the first node registered under `member` in the scope of `node`.
// NOTE(review): the presence check is commented out, so this assumes `member`
// is always in the scope - behaviour for a missing key depends on the map's
// get(); confirm with callers.
fun get_from_scope(node: *ast_node, member: string): *ast_node {
/*if (get_ast_scope(node)->contains(member))*/
return get_ast_scope(node)->get(member).first()
/*return null<ast_node>()*/
}
// Convenience overload: converts the method name to a string and forwards to
// the string-taking overload.
fun make_method_call(object_ident: *ast_node, name: *char, parameters: vector<*ast_node>): *ast_node {
    return make_method_call(object_ident, string(name), parameters)
}
// Builds a call to the method called `name` on `object_ident`.
// object_ident: expression the method is invoked on; the method is looked up
//               in the type_def of its type
// name:         method name
// parameters:   argument expressions; their types select the overload
// Calls error() when no method with that name accepts the argument types,
// instead of handing a null method to the node-taking overload.
fun make_method_call(object_ident: *ast_node, name: string, parameters: vector<*ast_node>): *ast_node {
    // note that this type_def is the adt_def if this is an adt type
    var method = function_lookup(name, get_ast_type(object_ident)->type_def, parameters.map(fun(param: *ast_node): *type return get_ast_type(param);))
    if (!method)
        error(string("make_method_call: no method named ") + name + " matches the given parameter types")
    return make_method_call(object_ident, method, parameters)
}
// Builds `object_ident.method(parameters)` (or `->` when the type of
// `object_ident` has non-zero indirection) as two nested function calls: the
// inner one applies the access operator to the object and the method, the
// outer one calls the result with `parameters`.
fun make_method_call(object_ident: *ast_node, method: *ast_node, parameters: vector<*ast_node>): *ast_node {
    var object_type = get_ast_type(object_ident)
    var access_op = "."
    if (object_type->indirection)
        access_op = "->"
    var accessor = get_builtin_function(string(access_op), vector(object_type, get_ast_type(method)))
    var method_access = ast_function_call_ptr(accessor, vector(object_ident, method))
    return ast_function_call_ptr(method_access, parameters)
}
// Convenience overload: converts the operator name to a string and forwards
// to the string-taking overload.
fun make_operator_call(func: *char, params: vector<*ast_node>): *ast_node {
    return make_operator_call(string(func), params)
}
// Builds a call to the builtin operator `func` applied to `params`; the
// builtin is selected from the types of the parameter expressions.
fun make_operator_call(func: string, params: vector<*ast_node>): *ast_node {
    var operand_types = params.map(fun(p:*ast_node): *type return get_ast_type(p);)
    var operator_func = get_builtin_function(func, operand_types)
    return ast_function_call_ptr(operator_func, params)
}
// Convenience overload: converts the name to a string and supplies a null
// syntax node for error reporting.
fun get_builtin_function(name: *char, param_types: vector<*type>): *ast_node {
    return get_builtin_function(string(name), param_types, null<tree<symbol>>())
}
// Convenience overload: supplies a null syntax node for error reporting.
fun get_builtin_function(name: string, param_types: vector<*type>): *ast_node {
    return get_builtin_function(name, param_types, null<tree<symbol>>())
}
// Creates the function node for the builtin operator `name` applied to
// operands of `param_types`.
// name:        operator spelling, e.g. "==", ".", "->", "[]", "&", "*"
// param_types: operand types; references are stripped before use
// syntax:      parse tree node used for error reporting (may be null)
// The return type of the created function depends on the operator:
//   comparison / logical operators -> boolean
//   "." and "->"                   -> the second operand's type
//   "[]" and unary "*"             -> the first operand's type with one less indirection
//   unary "&"                      -> the first operand's type with one more indirection
//   anything else                  -> the second operand's type when it has the
//                                     higher rank(), otherwise the first's
// Misuse of "->", ".", "[]" or unary "*" with respect to indirection is
// reported through error().
fun get_builtin_function(name: string, param_types: vector<*type>, syntax: *tree<symbol>): *ast_node {
    // none of the builtin functions should take in references
    param_types = param_types.map(fun(t: *type): *type return t->clone_without_ref();)
    if (name == "==" || name == "!=" || name == ">" || name == "<" || name == "<=" || name == ">=" || name == "&&" || name == "||" || name == "!")
        return ast_function_ptr(name, type_ptr(param_types, type_ptr(base_type::boolean()), 0, false, false, true), vector<*ast_node>(), false)
    if (name == "." || name == "->") {
        if (name == "->" && param_types[0]->indirection == 0)
            error(syntax, string("dereferencing not a pointer: ") + name)
        else if (name == "." && param_types[0]->indirection != 0)
            error(syntax, string("dot operator on a pointer: ") + name)
        else
            return ast_function_ptr(name, type_ptr(param_types, param_types[1], 0, false, false, true), vector<*ast_node>(), false)
    }
    if (name == "[]") {
        if (param_types[0]->indirection == 0)
            error(syntax, string("dereferencing not a pointer: ") + name)
        else
            return ast_function_ptr(name, type_ptr(param_types, param_types[0]->clone_with_decreased_indirection(), 0, false, false, true), vector<*ast_node>(), false)
    }
    // one operand: address-of
    if (name == "&" && param_types.size == 1)
        return ast_function_ptr(name, type_ptr(param_types, param_types[0]->clone_with_increased_indirection(), 0, false, false, true), vector<*ast_node>(), false)
    // one operand: dereference
    if (name == "*" && param_types.size == 1) {
        if (param_types[0]->indirection == 0)
            error(syntax, string("dereferencing not a pointer: ") + name)
        else
            return ast_function_ptr(name, type_ptr(param_types, param_types[0]->clone_with_decreased_indirection(), 0, false, false, true), vector<*ast_node>(), false)
    }
    // every other operator: result takes the higher-ranked of the first two operand types
    if (param_types.size > 1 && param_types[1]->rank() > param_types[0]->rank())
        return ast_function_ptr(name, type_ptr(param_types, param_types[1], 0, false, false, true), vector<*ast_node>(), false)
    return ast_function_ptr(name, type_ptr(param_types, param_types[0], 0, false, false, true), vector<*ast_node>(), false)
}
// Builds the expression comparing `lvalue` with `rvalue` for equality.
//   - left side is not an object: the builtin "==" operator call
//   - left side is a non-pointer object with an operator== accepting the
//     right side's type: a call to that method
//   - any other object: the literal value false
fun possible_object_equality(lvalue: *ast_node, rvalue: *ast_node): *ast_node {
    var ltype = get_ast_type(lvalue)
    var rtype = get_ast_type(rvalue)
    if (!ltype->is_object())
        return make_operator_call("==", vector(lvalue, rvalue))
    if (ltype->indirection == 0 && has_method(ltype->type_def, "operator==", vector(rtype)))
        return make_method_call(lvalue, "operator==", vector(rvalue))
    // return false if object but no operator== (right now don't try for templated)
    return ast_value_ptr(string("false"), type_ptr(base_type::boolean()))
}
// Flattens the parse tree under `node` into a single string: the node's own
// data (skipped when it is the placeholder "no_value") followed by the
// flattened text of each child, in order.
fun concat_symbol_tree(node: *tree<symbol>): string {
    var str.construct(): string
    if (node->data.data != "no_value")
        str += node->data.data
    for (var idx = 0; idx < node->children.size; idx++;) {
        str += concat_symbol_tree(node->children[idx])
    }
    return str
}
// Convenience overload: converts the lookup name to a string and forwards to
// the string-taking overload.
fun get_node(lookup: *char, parent: *tree<symbol>): *tree<symbol>
    return get_node(string(lookup), parent)
// Returns the direct child of `parent` whose symbol name is `lookup`, or a
// null pointer when there is none. More than one match is reported through
// error().
fun get_node(lookup: string, parent: *tree<symbol>): *tree<symbol> {
    var matches = get_nodes(lookup, parent)
    if (matches.size > 1)
        error(parent, "get node too many results!")
    if (!matches.size)
        return null<tree<symbol>>()
    return matches[0]
}
// Convenience overload: converts the lookup name to a string and forwards to
// the string-taking overload.
fun get_nodes(lookup: *char, parent: *tree<symbol>): vector<*tree<symbol>>
    return get_nodes(string(lookup), parent)
// Returns every direct child of `parent` whose symbol name equals `lookup`.
fun get_nodes(lookup: string, parent: *tree<symbol>): vector<*tree<symbol>>
    return parent->children.filter(fun(child: *tree<symbol>):bool return child->data.name == lookup;)
// Convenience overload: converts the name to a string and forwards to the
// string-taking overload.
fun add_to_scope(name: *char, to_add: *ast_node, add_to: *ast_node)
    add_to_scope(string(name), to_add, add_to)
// Registers `to_add` under `name` in the scope of `add_to`. A name that is not
// yet present gets a new one-element entry; otherwise the node is appended to
// the existing entry.
fun add_to_scope(name: string, to_add: *ast_node, add_to: *ast_node) {
    var scope_map = get_ast_scope(add_to)
    if (!scope_map->contains_key(name))
        scope_map->set(name, vector(to_add))
    else
        (*scope_map)[name].add(to_add)
}
// Builds the statement that stores `rvalue` into `lvalue`: a call to
// copy_construct (given the address of `rvalue`) when the left side is a
// non-pointer object whose type has a copy_construct taking a pointer to that
// type, and a plain assignment statement otherwise.
// for now, needs source to already be in a variable for copy_constructing
fun assign_or_copy_construct_statement(lvalue: *ast_node, rvalue: *ast_node): *ast_node {
    var ltype = get_ast_type(lvalue)
    var can_copy_construct = ltype->indirection == 0 && (ltype->is_object() && has_method(ltype->type_def, "copy_construct", vector(ltype->clone_with_increased_indirection())))
    if (!can_copy_construct)
        return ast_assignment_statement_ptr(lvalue, rvalue)
    return make_method_call(lvalue, "copy_construct", vector(make_operator_call("&", vector(rvalue))))
}
// Returns a pointer to the `children` vector stored inside `node` when the
// node is a translation_unit or a code_block, and a null pointer for every
// other kind of node.
fun get_children_pointer(node: *ast_node): *vector<*ast_node> {
// stays null unless one of the two cases below matches
var bc = null<vector<*ast_node>>()
match(*node) {
// the address is taken through `node` itself rather than through `backing`
ast_node::translation_unit(backing) bc = &node->translation_unit.children
ast_node::code_block(backing) bc = &node->code_block.children
}
return bc
}
// Convenience overload: removes `orig` from the node on top of the stack `in`.
fun remove(orig: *ast_node, in: *stack<*ast_node>): *ast_node {
    return remove(orig, in->top())
}
// Detaches `orig` from the child list of `in` and returns the removed node.
// Calls error() when `in` exposes no child list (see get_children_pointer) or
// `orig` is not found in it.
fun remove(orig: *ast_node, in: *ast_node): *ast_node {
    var siblings = get_children_pointer(in)
    if (siblings) {
        var position = siblings->find(orig)
        if (position >= 0) {
            var removed = siblings->at(position)
            siblings->remove(position)
            return removed
        }
    }
    error(string("cannot remove inside ") + get_ast_name(in))
}
// Convenience overload: performs the replacement inside the node on top of
// the stack `in`.
fun replace_with_in(orig: *ast_node, new: *ast_node, in: *stack<*ast_node>) {
    replace_with_in(orig, new, in->top())
}
// Replaces the child `orig` of the node `in` with `new`.
// orig: the node currently referenced by `in`
// new:  the node to put in its place
// in:   the parent node that is modified
// For the node kinds handled in the match, the field that equals `orig` is
// overwritten. A node kind without a matching field falls through to the
// generic child list (translation units and code blocks, see
// get_children_pointer). When nothing could be replaced, error() is called.
fun replace_with_in(orig: *ast_node, new: *ast_node, in: *ast_node) {
match (*in) {
// NOTE(review): unlike the other cases, the return value is overwritten
// without comparing it to `orig` - presumably because it is the only child; confirm.
ast_node::return_statement(backing) { backing.return_value = new; return; }
ast_node::cast(backing) { if (backing.value == orig) { backing.value = new; return; } }
ast_node::assignment_statement(backing) {
if (backing.to == orig) {
backing.to = new
return
}
if (backing.from == orig) {
backing.from = new
return
}
}
ast_node::declaration_statement(backing) {
if (backing.identifier == orig) {
backing.identifier = new
return
}
if (backing.expression == orig) {
backing.expression = new
return
}
}
ast_node::if_statement(backing) {
if (backing.condition == orig) {
backing.condition = new
return
}
if (backing.then_part == orig) {
backing.then_part = new
return
}
if (backing.else_part == orig) {
backing.else_part = new
return
}
}
ast_node::for_loop(backing) {
if (backing.init == orig) {
backing.init = new
return
}
if (backing.condition == orig) {
backing.condition = new
return
}
if (backing.update == orig) {
backing.update = new
return
}
if (backing.body == orig) {
backing.body = new
return
}
}
ast_node::while_loop(backing) {
if (backing.condition == orig) {
backing.condition = new
return
}
if (backing.statement == orig) {
backing.statement = new
return
}
}
// NOTE(review): the `if` has no braces, so the `return` below runs whether or
// not the body matched `orig`; a non-matching `orig` is silently ignored here
// instead of reaching the error at the end - confirm this is intended.
ast_node::function(backing) {
if (backing.body_statement == orig)
backing.body_statement = new
return
}
ast_node::function_call(backing) {
if (backing.func == orig) {
backing.func = new
return
}
// only the first parameter equal to `orig` is replaced
for (var i = 0; i < backing.parameters.size; i++;) {
if (backing.parameters[i] == orig) {
backing.parameters[i] = new
return
}
}
}
}
// generic fallback: nodes that expose a child list
var bc = get_children_pointer(in)
if (bc) {
var i = bc->find(orig)
if (i >= 0) {
bc->set(i, new)
return
}
}
error(string("cannot replace_with_in inside ") + get_ast_name(in))
}
// Inserts every node of `to_add`, in order, in front of `before`, using the
// stack-taking single-node overload for each element.
fun add_before_in(to_add: ref vector<*ast_node>, before: *ast_node, in: *stack<*ast_node>) {
    to_add.for_each(fun(node: *ast_node) add_before_in(node, before, in);)
}
// Inserts every node of `to_add`, in order, in front of `before` inside the
// parent node `in`.
fun add_before_in(to_add: vector<*ast_node>, before: *ast_node, in: *ast_node) {
    to_add.for_each(fun(node: *ast_node) add_before_in(node, before, in);)
}
// Convenience overload: inserts into the node on top of the stack `in`.
fun add_before_in(to_add: *ast_node, before: *ast_node, in: *stack<*ast_node>) {
    add_before_in(to_add, before, in->top())
}
// Inserts `to_add` into the child list of `in` at the position currently
// held by `before`. Calls error() when `in` exposes no child list (see
// get_children_pointer) or `before` is not found in it.
fun add_before_in(to_add: *ast_node, before: *ast_node, in: *ast_node) {
    var siblings = get_children_pointer(in)
    if (siblings) {
        var position = siblings->find(before)
        if (position >= 0) {
            siblings->add(position, to_add)
            return
        }
    }
    error(string("cannot add_before_in to ") + get_ast_name(in))
}
// Convenience overload: inserts into the node on top of the stack `in`.
fun add_after_in(to_add: *ast_node, before: *ast_node, in: *stack<*ast_node>) {
    add_after_in(to_add, before, in->top())
}
// Inserts `to_add` into the child list of `in` at the position right after
// the reference node (the parameter is named `before`, but the new node goes
// behind it). Calls error() when `in` exposes no child list (see
// get_children_pointer) or the reference node is not found in it.
fun add_after_in(to_add: *ast_node, before: *ast_node, in: *ast_node) {
    var siblings = get_children_pointer(in)
    if (siblings) {
        var position = siblings->find(before)
        if (position >= 0) {
            siblings->add(position+1, to_add)
            return
        }
    }
    error(string("cannot add_after_in to ") + get_ast_name(in))
}
// Returns a do-nothing "before" callback with the three-argument signature
// taken by run_on_tree: it ignores its arguments and always returns true,
// which run_on_tree_helper reads as "do visit the children".
fun empty_pass_first_half(): fun(*ast_node, *stack<*ast_node>, *hash_set<*ast_node>): bool {
return fun(node: *ast_node, parent_chain: *stack<*ast_node>, visited: *hash_set<*ast_node>): bool { return true; }
}
// Returns a do-nothing "after" callback with the signature taken by
// run_on_tree: it ignores its arguments and has an empty body.
fun empty_pass_second_half(): fun(*ast_node, *stack<*ast_node>): void {
return fun(node: *ast_node, parent_chain: *stack<*ast_node>) {}
}
// Overload for "before" callbacks that take no visited set and return
// nothing: the callback is wrapped in one that ignores the visited set and
// always asks for the children to be traversed.
fun run_on_tree(func_before: fun(*ast_node,*stack<*ast_node>):void, func_after: fun(*ast_node,*stack<*ast_node>):void, tree: *ast_node, visited: *hash_set<*ast_node>) {
    run_on_tree(fun(current: *ast_node, chain: *stack<*ast_node>, seen: *hash_set<*ast_node>): bool {func_before(current, chain);return true;}, func_after, tree, visited)
}
// Runs a pass over the AST rooted at `tree`, starting with an empty parent
// chain.
// func_before: called on a node before its children; its result decides
//              whether the children are traversed
// func_after:  called on a node after its children
// visited:     set of nodes already processed, shared with the caller
fun run_on_tree(func_before: fun(*ast_node,*stack<*ast_node>,*hash_set<*ast_node>):bool, func_after: fun(*ast_node,*stack<*ast_node>):void, tree: *ast_node, visited: *hash_set<*ast_node>) {
    var ancestors = stack<*ast_node>()
    run_on_tree_helper(func_before, func_after, tree, &ancestors, visited)
}
// Recursive worker behind run_on_tree.
// func_before:  called on `node` first; returning false skips the children
// func_after:   called on `node` once the children are done
// node:         node to process; null is ignored
// parent_chain: stack of the nodes on the path from the root to `node`
// visited:      nodes already processed; every processed node is added to it
// Node kinds that have no case in the match below get both callbacks but no
// traversal of children.
fun run_on_tree_helper(func_before: fun(*ast_node,*stack<*ast_node>,*hash_set<*ast_node>):bool,
func_after: fun(*ast_node,*stack<*ast_node>):void,
node: *ast_node, parent_chain: *stack<*ast_node>, visited: *hash_set<*ast_node>) {
// So some nodes should be done regardless of whether or not we've visited them - these are the places where a more reasonable AST might use bindings, i.e. variables and functions.
if (!node || (!is_function(node) && !is_identifier(node) && visited->contains(node))) return;
visited->add(node)
var do_children = func_before(node, parent_chain, visited)
// the node is pushed only after func_before ran, so the callback sees the
// chain ending at the node's parent
parent_chain->push(node)
if (do_children) {
match(*node) {
ast_node::translation_unit(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::type_def(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::adt_def(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::function(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::template(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::code_block(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::if_statement(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::match_statement(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::case_statement(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::while_loop(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::for_loop(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::return_statement(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::defer_statement(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::assignment_statement(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::declaration_statement(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::if_comp(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
ast_node::compiler_intrinsic(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
// function calls visit the callee first and then each parameter, reading the
// fields directly instead of going through get_ast_children
ast_node::function_call(backing) {
run_on_tree_helper(func_before, func_after, backing.func, parent_chain, visited)
node->function_call.parameters.for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
}
ast_node::cast(backing) get_ast_children(node).for_each(fun(n: *ast_node) run_on_tree_helper(func_before, func_after, n, parent_chain, visited);)
}
}
// function may have messed with the parent chain
// pop until `node` itself has been removed; skipped when `node` is no longer on the chain
if (parent_chain->data.contains(node))
while(parent_chain->pop() != node){}
func_after(node, parent_chain)
}
// Classifies `n` by node kind: returns true for identifier, function,
// function_call, compiler_intrinsic, cast and value nodes, and false for
// every other kind listed below. Going by the name, true means the node may
// appear as a parameter.
fun is_legal_parameter_node_type(n: *ast_node): bool {
match(*n) {
ast_node::translation_unit() return false;
ast_node::import() return false;
ast_node::identifier() return true;
ast_node::type_def() return false;
ast_node::adt_def() return false;
ast_node::function() return true;
ast_node::template() return false;
ast_node::code_block() return false;
ast_node::if_statement() return false;
ast_node::match_statement() return false;
ast_node::case_statement() return false;
ast_node::while_loop() return false;
ast_node::for_loop() return false;
ast_node::return_statement() return false;
ast_node::branching_statement() return false;
ast_node::defer_statement() return false;
ast_node::assignment_statement() return false;
ast_node::declaration_statement() return false;
ast_node::if_comp() return false;
ast_node::simple_passthrough() return false;
ast_node::function_call() return true;
ast_node::compiler_intrinsic() return true;
ast_node::cast() return true;
ast_node::value() return true;
}
// reached only when none of the cases above matched
error("is_legal_parameter_node_type with no type")
}
// Builds the qualified name of `n` by prefixing its own name with the
// qualified name of its "~enclosing_scope", recursively; every level is
// introduced by "::". A null node yields the string "NULL".
fun get_fully_scoped_name(n: *ast_node): string {
    if (!n)
        return string("NULL")
    var scope_map = get_ast_scope(n)
    var enclosing_key = string("~enclosing_scope")
    var prefix = string()
    // nodes without a scope map or without an enclosing scope keep an empty prefix
    if (scope_map && scope_map->contains_key(enclosing_key))
        prefix = get_fully_scoped_name(scope_map->get(enclosing_key)[0])
    return prefix + "::" + get_ast_name(n)
}