Implement (but have not tested) all primitive functions in JIT)

This commit is contained in:
2024-02-05 01:26:57 -05:00
parent 99d4fa5021
commit a4500fed36

View File

@@ -24,6 +24,15 @@ const JIT_LEVEL: i64 = 5;
extern "C" fn rust_print_form(x: Form) {
println!("from jit print: {x}");
}
extern "C" fn rust_cons(car: Form, cdr: Form) -> Form {
Form::new_pair(car, cdr)
}
extern "C" fn rust_eq(a: Form, b: Form) -> Form {
Form::new_bool(a == b)
}
extern "C" fn rust_drop_rc_form(ptr: *mut CrcInner<Form>) {
let _ = Crc::from_ptr(ptr);
}
const TRACE_ID_OFFSET: usize = 32;
const TRACE_OFFSET_MASK: usize = 0xFF_FF_FF_FF; // could be bigger
@@ -89,6 +98,26 @@ impl JIT {
print_sig.params.push(AbiParam::new(self.int));
let print_func = self.module.declare_function("rust_print_form", Linkage::Import, &print_sig).unwrap();
let local_print_func = self.module.declare_func_in_func(print_func, bcx.func);
let mut cons_sig = self.module.make_signature();
cons_sig.params.push(AbiParam::new(self.int));
cons_sig.params.push(AbiParam::new(self.int));
cons_sig.returns.push(AbiParam::new(self.int));
let cons_func = self.module.declare_function("rust_cons", Linkage::Import, &cons_sig).unwrap();
let local_cons_func = self.module.declare_func_in_func(cons_func, bcx.func);
let mut eq_sig = self.module.make_signature();
eq_sig.params.push(AbiParam::new(self.int));
eq_sig.params.push(AbiParam::new(self.int));
eq_sig.returns.push(AbiParam::new(self.int));
let eq_func = self.module.declare_function("rust_eq", Linkage::Import, &eq_sig).unwrap();
let local_eq_func = self.module.declare_func_in_func(eq_func, bcx.func);
let mut drop_sig = self.module.make_signature();
drop_sig.params.push(AbiParam::new(self.int));
let drop_rc_form_func = self.module.declare_function("rust_drop_rc_form", Linkage::Import, &drop_sig).unwrap();
let local_drop_rc_form_func = self.module.declare_func_in_func(drop_rc_form_func, bcx.func);
//let call = bcx.ins().call(local_callee, &[add, add]);
//bcx.inst_results(call)[0]
@@ -97,7 +126,7 @@ impl JIT {
let ok = bcx.ins().icmp_imm(IntCC::Equal, t, tag as i64);
bcx.ins().trapz(ok, TrapCode::User(0));
}
fn stack_pop_norc(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, local_print_func: FuncRef) -> Value {
fn stack_pop(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, local_print_func: FuncRef) -> Value {
let len = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_LEN_OFFSET as i32);
let new_len = bcx.ins().iadd_imm(len, -1);
bcx.ins().store(MemFlags::trusted(), new_len, tmp_stack_ptr, CVEC_LEN_OFFSET as i32);
@@ -108,17 +137,12 @@ impl JIT {
r
}
fn stack_pop_int(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, local_print_func: FuncRef) -> Value {
let loaded = stack_pop_norc(bcx, int, tmp_stack_ptr, local_print_func);
type_assert(bcx, loaded, TAG_INT, local_print_func, int);
loaded
}
fn stack_pop_prim(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, local_print_func: FuncRef) -> Value {
let loaded = stack_pop_norc(bcx, int, tmp_stack_ptr, local_print_func);
let loaded = stack_pop(bcx, int, tmp_stack_ptr, local_print_func);
type_assert(bcx, loaded, TAG_PRIM, local_print_func, int);
loaded
}
fn gen_stack_push_norc_nocap(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, x: Value, local_print_func: FuncRef) {
fn gen_stack_push_nocap(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, x: Value, local_print_func: FuncRef) {
let len = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_LEN_OFFSET as i32);
let new_len = bcx.ins().iadd_imm(len, 1);
bcx.ins().store(MemFlags::trusted(), new_len, tmp_stack_ptr, CVEC_LEN_OFFSET as i32);
@@ -128,6 +152,132 @@ impl JIT {
bcx.ins().store(MemFlags::trusted(), x, item_ptr, 0);
}
fn gen_eq(bcx: &mut FunctionBuilder, int: Type, a: Value, b: Value, local_eq_func: FuncRef, local_print_func: FuncRef) -> Value {
let at = bcx.ins().band_imm(a, TAG_RC as i64);
let bt = bcx.ins().band_imm(b, TAG_RC as i64);
let tt = bcx.ins().band(at, bt);
let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64);
let const_block = bcx.create_block();
bcx.append_block_param(const_block, int);
bcx.append_block_param(const_block, int);
let dyn_block = bcx.create_block();
bcx.append_block_param(dyn_block, int);
bcx.append_block_param(dyn_block, int);
let merge_block = bcx.create_block();
bcx.append_block_param(merge_block, int);
bcx.ins().brif(rc, const_block, &[a,b], dyn_block, &[a,b]);
bcx.switch_to_block(const_block);
let a = bcx.block_params(const_block)[0].clone();
let b = bcx.block_params(const_block)[1].clone();
let e1 = bcx.ins().icmp(IntCC::Equal, a, b);
let e = bcx.ins().sextend(int, e1);
let r = bcx.ins().bor_imm(e, TAG_BOOL_FALSE as i64); //onst TAG_BOOL_FALSE: usize = 0b010; onst TAG_BOOL_TRUE: usize = 0b011;
bcx.ins().jump(merge_block, &[r]);
bcx.switch_to_block(dyn_block);
let a = bcx.block_params(dyn_block)[0].clone();
let b = bcx.block_params(dyn_block)[1].clone();
let call = bcx.ins().call(local_eq_func, &[a, b]);
let r = bcx.inst_results(call)[0].clone();
bcx.ins().jump(merge_block, &[r]);
bcx.switch_to_block(merge_block);
let merge_result = bcx.block_params(merge_block)[0].clone();
merge_result
}
fn increment(bcx: &mut FunctionBuilder, int: Type, x: Value, local_print_func: FuncRef) {
let tt = bcx.ins().band_imm(x, TAG_RC as i64);
let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64);
let incr_block = bcx.create_block();
bcx.append_block_param(incr_block, int);
let merge_block = bcx.create_block();
bcx.ins().brif(rc, incr_block, &[x], merge_block, &[]);
bcx.switch_to_block(incr_block);
let x = bcx.block_params(incr_block)[0].clone();
increment_unchecked(bcx, int, x, local_print_func);
bcx.ins().jump(merge_block, &[]);
bcx.switch_to_block(merge_block);
}
fn increment_unchecked(bcx: &mut FunctionBuilder, int: Type, x: Value, local_print_func: FuncRef) {
let crc_inner_ptr = bcx.ins().band_imm(x, (!TAG_MASK) as i64);
let count = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, CRC_INNER_RC_OFFSET as i32);
let new_count = bcx.ins().iadd_imm(count, 1);
bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_DATA_OFFSET as i32);
}
fn decrement(bcx: &mut FunctionBuilder, int: Type, x: Value, local_drop_rc_form_func: FuncRef, local_print_func: FuncRef) {
let tt = bcx.ins().band_imm(x, TAG_RC as i64);
let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64);
let incr_block = bcx.create_block();
bcx.append_block_param(incr_block, int);
let merge_block = bcx.create_block();
bcx.ins().brif(rc, incr_block, &[x], merge_block, &[]);
bcx.switch_to_block(incr_block);
let x = bcx.block_params(incr_block)[0].clone();
decrement_unchecked(bcx, int, x, local_drop_rc_form_func, local_print_func);
bcx.ins().jump(merge_block, &[]);
bcx.switch_to_block(merge_block);
}
fn decrement_unchecked(bcx: &mut FunctionBuilder, int: Type, x: Value, local_drop_rc_form_func: FuncRef, local_print_func: FuncRef) {
// if it's actually 0, we need to go back into rust to hit the free/drop
let crc_inner_ptr = bcx.ins().band_imm(x, (!TAG_MASK) as i64);
let count = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, CRC_INNER_RC_OFFSET as i32);
let live_block = bcx.create_block();
bcx.append_block_param(live_block, int);
bcx.append_block_param(live_block, int);
let dead_block = bcx.create_block();
bcx.append_block_param(dead_block, int);
let merge_block = bcx.create_block();
let rc_1 = bcx.ins().icmp_imm(IntCC::Equal, count, 1 as i64);
bcx.ins().brif(rc_1, live_block, &[crc_inner_ptr,count], dead_block, &[crc_inner_ptr]);
bcx.switch_to_block(live_block);
let crc_inner_ptr = bcx.block_params(live_block)[0].clone();
let count = bcx.block_params(live_block)[1].clone();
let new_count = bcx.ins().iadd_imm(count, -1);
bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_DATA_OFFSET as i32);
bcx.ins().jump(merge_block, &[]);
bcx.switch_to_block(dead_block);
let call = bcx.ins().call(local_drop_rc_form_func, &[crc_inner_ptr]);
bcx.ins().jump(merge_block, &[]);
bcx.switch_to_block(merge_block);
}
fn gen_cr(bcx: &mut FunctionBuilder, int: Type, x: Value, is_cdr: bool, local_drop_rc_form_func: FuncRef, local_print_func: FuncRef) -> Value {
type_assert(bcx, x, TAG_PAIR, local_print_func, int);
// need to RC increment whatever we get
// and RC decrement this pair
let crc_inner_ptr = bcx.ins().band_imm(x, (!TAG_MASK) as i64);
let result = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, (CRC_INNER_DATA_OFFSET + if is_cdr { FORM_PAIR_CDR_OFFSET } else { FORM_PAIR_CAR_OFFSET }) as i32);
increment(bcx, int, result, local_print_func);
decrement_unchecked(bcx, int, x, local_print_func, local_print_func);
result
//const CRC_INNER_DATA_OFFSET: usize = 8*1;
//const FORM_PAIR_CAR_OFFSET: usize = 8*0;
//Ok(unsafe { &(*((self.data as usize & !TAG_MASK) as *mut CrcInner<FormPair>)).data.car })
//Ok(unsafe { &(*((self.data as usize & !TAG_MASK) as *mut CrcInner<FormPair>)).data.cdr })
//#[repr(C)]
//pub struct CrcInner<T> {
// rc: Cell<usize>,
// data: T,
//}
//#[repr(C)]
//struct FormPair {
// car: Form,
// cdr: Form,
//}
}
let mut offset = 0;
for op in ops {
match op {
@@ -257,34 +407,43 @@ impl JIT {
return Ok(Some((tmp_stack.pop().unwrap(), e, (*nc).clone())));
}
*/
Op::InlinePrim(Prim::Add) => {
Op::InlinePrim(p) => {
/*
let b = tmp_stack.pop().unwrap();
let a = tmp_stack.pop().unwrap();
tmp_stack.pop().unwrap(); // pop the prim
tmp_stack.push(Form::new_int(a.int()? + b.int()?));
*/
let cst = bcx.ins().iconst(self.int, 1337 << 3);
let _ = bcx.ins().call(local_print_func, &[cst]);
let b = stack_pop_int(&mut bcx, self.int, tmp_stack_ptr, local_print_func);
let _ = bcx.ins().call(local_print_func, &[cst]);
let a = stack_pop_int(&mut bcx, self.int, tmp_stack_ptr, local_print_func);
let _ = bcx.ins().call(local_print_func, &[cst]);
*/
let b = stack_pop(&mut bcx, self.int, tmp_stack_ptr, local_print_func);
let a = if p.two_params() { Some(stack_pop(&mut bcx, self.int, tmp_stack_ptr, local_print_func)) } else { None };
let result = match p {
Prim::Car => { gen_cr(&mut bcx, self.int, b, false, local_drop_rc_form_func, local_print_func) },
Prim::Cdr => { gen_cr(&mut bcx, self.int, b, true, local_drop_rc_form_func, local_print_func) },
Prim::Add | Prim::Sub | Prim::Mul | Prim::Div | Prim::Mod => {
let a = a.unwrap();
type_assert(&mut bcx, a, TAG_INT, local_print_func, self.int);
type_assert(&mut bcx, b, TAG_INT, local_print_func, self.int);
match p {
Prim::Add => { bcx.ins().iadd(a, b) }
Prim::Sub => { bcx.ins().isub(a, b) }
Prim::Mul => { bcx.ins().imul(a, b) }
Prim::Div => { bcx.ins().sdiv(a, b) }
Prim::Mod => { bcx.ins().srem(a, b) }
_ => unreachable!(),
}
},
Prim::Eq => { gen_eq(&mut bcx, self.int, a.unwrap(), b, local_eq_func, local_print_func) },
Prim::Cons => { let call = bcx.ins().call(local_cons_func, &[a.unwrap(), b]); bcx.inst_results(call)[0].clone() },
};
let _ = stack_pop_prim(&mut bcx, self.int, tmp_stack_ptr, local_print_func);
let _ = bcx.ins().call(local_print_func, &[cst]);
let result = bcx.ins().iadd(a, b);
let _ = bcx.ins().call(local_print_func, &[result]);
gen_stack_push_norc_nocap(&mut bcx, self.int, tmp_stack_ptr, result, local_print_func); // we just popped, so cap
// is fine, and ints dont
// rc
gen_stack_push_nocap(&mut bcx, self.int, tmp_stack_ptr, result, local_print_func); // we just popped, so cap is fine
offset += 1;
}
_ => {
// quit out of trace!
println!("compiling back to interp because we can't compile {op:?}");
assert!(offset <= TRACE_OFFSET_MASK);
let cst = bcx.ins().iconst(self.int, (id.id << TRACE_ID_OFFSET) | offset as i64);
bcx.ins().return_(&[cst]);
@@ -466,6 +625,8 @@ pub struct Crc<T> {
ptr: NonNull<CrcInner<T>>,
phantom: PhantomData<CrcInner<T>>
}
const CRC_INNER_RC_OFFSET: usize = 8*0;
const CRC_INNER_DATA_OFFSET: usize = 8*1;
#[repr(C)]
pub struct CrcInner<T> {
rc: Cell<usize>,
@@ -564,6 +725,8 @@ pub struct Form {
data: *const Form,
phantom: PhantomData<Form>
}
const FORM_PAIR_CAR_OFFSET: usize = 8*0;
const FORM_PAIR_CDR_OFFSET: usize = 8*1;
#[repr(C)]
struct FormPair {
car: Form,
@@ -629,9 +792,12 @@ const TAG_NIL: usize = 0b001;
const TAG_BOOL_FALSE: usize = 0b010;
const TAG_BOOL_TRUE: usize = 0b011;
const TAG_SYMBOL: usize = 0b100;
const TAG_PAIR: usize = 0b101;
const TAG_PRIM: usize = 0b101;
const TAG_CLOSURE: usize = 0b110;
const TAG_PRIM: usize = 0b111;
const TAG_PAIR: usize = 0b111;
const TAG_RC: usize = 0b110;
static SYMBOLS: Lazy<Mutex<BTreeMap<String,&'static str>>> = Lazy::new(Mutex::default);
@@ -1424,7 +1590,8 @@ impl Ctx {
}
Op::Return => {
println!("Return(op)");
let (e, nc, resume_data) = ret_stack.pop().unwrap();
let (ne, nc, resume_data) = ret_stack.pop().unwrap();
e = ne;
if let Some(resume_id) = resume_data {
if self.traces.contains_key(&resume_id) {
id = resume_id;