From a4500fed36e59931ce80494e71b174ac3e5a46bb Mon Sep 17 00:00:00 2001 From: Nathan Braswell Date: Mon, 5 Feb 2024 01:26:57 -0500 Subject: [PATCH] Implement (but have not tested) all primitive functions in JIT) --- slj/src/lib.rs | 221 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 194 insertions(+), 27 deletions(-) diff --git a/slj/src/lib.rs b/slj/src/lib.rs index 6c428d5..521f2f6 100644 --- a/slj/src/lib.rs +++ b/slj/src/lib.rs @@ -24,6 +24,15 @@ const JIT_LEVEL: i64 = 5; extern "C" fn rust_print_form(x: Form) { println!("from jit print: {x}"); } +extern "C" fn rust_cons(car: Form, cdr: Form) -> Form { + Form::new_pair(car, cdr) +} +extern "C" fn rust_eq(a: Form, b: Form) -> Form { + Form::new_bool(a == b) +} +extern "C" fn rust_drop_rc_form(ptr: *mut CrcInner
) { + let _ = Crc::from_ptr(ptr); +} const TRACE_ID_OFFSET: usize = 32; const TRACE_OFFSET_MASK: usize = 0xFF_FF_FF_FF; // could be bigger @@ -89,6 +98,26 @@ impl JIT { print_sig.params.push(AbiParam::new(self.int)); let print_func = self.module.declare_function("rust_print_form", Linkage::Import, &print_sig).unwrap(); let local_print_func = self.module.declare_func_in_func(print_func, bcx.func); + + let mut cons_sig = self.module.make_signature(); + cons_sig.params.push(AbiParam::new(self.int)); + cons_sig.params.push(AbiParam::new(self.int)); + cons_sig.returns.push(AbiParam::new(self.int)); + let cons_func = self.module.declare_function("rust_cons", Linkage::Import, &cons_sig).unwrap(); + let local_cons_func = self.module.declare_func_in_func(cons_func, bcx.func); + + let mut eq_sig = self.module.make_signature(); + eq_sig.params.push(AbiParam::new(self.int)); + eq_sig.params.push(AbiParam::new(self.int)); + eq_sig.returns.push(AbiParam::new(self.int)); + let eq_func = self.module.declare_function("rust_eq", Linkage::Import, &eq_sig).unwrap(); + let local_eq_func = self.module.declare_func_in_func(eq_func, bcx.func); + + let mut drop_sig = self.module.make_signature(); + drop_sig.params.push(AbiParam::new(self.int)); + let drop_rc_form_func = self.module.declare_function("rust_drop_rc_form", Linkage::Import, &drop_sig).unwrap(); + let local_drop_rc_form_func = self.module.declare_func_in_func(drop_rc_form_func, bcx.func); + //let call = bcx.ins().call(local_callee, &[add, add]); //bcx.inst_results(call)[0] @@ -97,7 +126,7 @@ impl JIT { let ok = bcx.ins().icmp_imm(IntCC::Equal, t, tag as i64); bcx.ins().trapz(ok, TrapCode::User(0)); } - fn stack_pop_norc(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, local_print_func: FuncRef) -> Value { + fn stack_pop(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, local_print_func: FuncRef) -> Value { let len = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_LEN_OFFSET as i32); let new_len = bcx.ins().iadd_imm(len, -1); bcx.ins().store(MemFlags::trusted(), new_len, tmp_stack_ptr, CVEC_LEN_OFFSET as i32); @@ -108,17 +137,12 @@ impl JIT { r } - fn stack_pop_int(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, local_print_func: FuncRef) -> Value { - let loaded = stack_pop_norc(bcx, int, tmp_stack_ptr, local_print_func); - type_assert(bcx, loaded, TAG_INT, local_print_func, int); - loaded - } fn stack_pop_prim(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, local_print_func: FuncRef) -> Value { - let loaded = stack_pop_norc(bcx, int, tmp_stack_ptr, local_print_func); + let loaded = stack_pop(bcx, int, tmp_stack_ptr, local_print_func); type_assert(bcx, loaded, TAG_PRIM, local_print_func, int); loaded } - fn gen_stack_push_norc_nocap(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, x: Value, local_print_func: FuncRef) { + fn gen_stack_push_nocap(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, x: Value, local_print_func: FuncRef) { let len = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_LEN_OFFSET as i32); let new_len = bcx.ins().iadd_imm(len, 1); bcx.ins().store(MemFlags::trusted(), new_len, tmp_stack_ptr, CVEC_LEN_OFFSET as i32); @@ -128,6 +152,132 @@ impl JIT { bcx.ins().store(MemFlags::trusted(), x, item_ptr, 0); } + fn gen_eq(bcx: &mut FunctionBuilder, int: Type, a: Value, b: Value, local_eq_func: FuncRef, local_print_func: FuncRef) -> Value { + let at = bcx.ins().band_imm(a, TAG_RC as i64); + let bt = bcx.ins().band_imm(b, TAG_RC as i64); + let tt = bcx.ins().band(at, bt); + let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64); + + + let const_block = bcx.create_block(); + bcx.append_block_param(const_block, int); + bcx.append_block_param(const_block, int); + let dyn_block = bcx.create_block(); + bcx.append_block_param(dyn_block, int); + bcx.append_block_param(dyn_block, int); + let merge_block = bcx.create_block(); + bcx.append_block_param(merge_block, int); + bcx.ins().brif(rc, const_block, &[a,b], dyn_block, &[a,b]); + + bcx.switch_to_block(const_block); + let a = bcx.block_params(const_block)[0].clone(); + let b = bcx.block_params(const_block)[1].clone(); + let e1 = bcx.ins().icmp(IntCC::Equal, a, b); + let e = bcx.ins().sextend(int, e1); + let r = bcx.ins().bor_imm(e, TAG_BOOL_FALSE as i64); //onst TAG_BOOL_FALSE: usize = 0b010; onst TAG_BOOL_TRUE: usize = 0b011; + bcx.ins().jump(merge_block, &[r]); + + bcx.switch_to_block(dyn_block); + let a = bcx.block_params(dyn_block)[0].clone(); + let b = bcx.block_params(dyn_block)[1].clone(); + let call = bcx.ins().call(local_eq_func, &[a, b]); + let r = bcx.inst_results(call)[0].clone(); + bcx.ins().jump(merge_block, &[r]); + + bcx.switch_to_block(merge_block); + let merge_result = bcx.block_params(merge_block)[0].clone(); + merge_result + } + + fn increment(bcx: &mut FunctionBuilder, int: Type, x: Value, local_print_func: FuncRef) { + let tt = bcx.ins().band_imm(x, TAG_RC as i64); + let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64); + + let incr_block = bcx.create_block(); + bcx.append_block_param(incr_block, int); + let merge_block = bcx.create_block(); + bcx.ins().brif(rc, incr_block, &[x], merge_block, &[]); + bcx.switch_to_block(incr_block); + let x = bcx.block_params(incr_block)[0].clone(); + increment_unchecked(bcx, int, x, local_print_func); + bcx.ins().jump(merge_block, &[]); + + bcx.switch_to_block(merge_block); + } + fn increment_unchecked(bcx: &mut FunctionBuilder, int: Type, x: Value, local_print_func: FuncRef) { + let crc_inner_ptr = bcx.ins().band_imm(x, (!TAG_MASK) as i64); + let count = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, CRC_INNER_RC_OFFSET as i32); + let new_count = bcx.ins().iadd_imm(count, 1); + bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_DATA_OFFSET as i32); + } + fn decrement(bcx: &mut FunctionBuilder, int: Type, x: Value, local_drop_rc_form_func: FuncRef, local_print_func: FuncRef) { + let tt = bcx.ins().band_imm(x, TAG_RC as i64); + let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64); + + let incr_block = bcx.create_block(); + bcx.append_block_param(incr_block, int); + let merge_block = bcx.create_block(); + bcx.ins().brif(rc, incr_block, &[x], merge_block, &[]); + bcx.switch_to_block(incr_block); + let x = bcx.block_params(incr_block)[0].clone(); + decrement_unchecked(bcx, int, x, local_drop_rc_form_func, local_print_func); + bcx.ins().jump(merge_block, &[]); + + bcx.switch_to_block(merge_block); + } + fn decrement_unchecked(bcx: &mut FunctionBuilder, int: Type, x: Value, local_drop_rc_form_func: FuncRef, local_print_func: FuncRef) { + // if it's actually 0, we need to go back into rust to hit the free/drop + let crc_inner_ptr = bcx.ins().band_imm(x, (!TAG_MASK) as i64); + let count = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, CRC_INNER_RC_OFFSET as i32); + + let live_block = bcx.create_block(); + bcx.append_block_param(live_block, int); + bcx.append_block_param(live_block, int); + let dead_block = bcx.create_block(); + bcx.append_block_param(dead_block, int); + let merge_block = bcx.create_block(); + + let rc_1 = bcx.ins().icmp_imm(IntCC::Equal, count, 1 as i64); + bcx.ins().brif(rc_1, live_block, &[crc_inner_ptr,count], dead_block, &[crc_inner_ptr]); + + bcx.switch_to_block(live_block); + let crc_inner_ptr = bcx.block_params(live_block)[0].clone(); + let count = bcx.block_params(live_block)[1].clone(); + let new_count = bcx.ins().iadd_imm(count, -1); + bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_DATA_OFFSET as i32); + bcx.ins().jump(merge_block, &[]); + + bcx.switch_to_block(dead_block); + let call = bcx.ins().call(local_drop_rc_form_func, &[crc_inner_ptr]); + bcx.ins().jump(merge_block, &[]); + + bcx.switch_to_block(merge_block); + } + + fn gen_cr(bcx: &mut FunctionBuilder, int: Type, x: Value, is_cdr: bool, local_drop_rc_form_func: FuncRef, local_print_func: FuncRef) -> Value { + type_assert(bcx, x, TAG_PAIR, local_print_func, int); + // need to RC increment whatever we get + // and RC decrement this pair + let crc_inner_ptr = bcx.ins().band_imm(x, (!TAG_MASK) as i64); + let result = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, (CRC_INNER_DATA_OFFSET + if is_cdr { FORM_PAIR_CDR_OFFSET } else { FORM_PAIR_CAR_OFFSET }) as i32); + increment(bcx, int, result, local_print_func); + decrement_unchecked(bcx, int, x, local_print_func, local_print_func); + result + //const CRC_INNER_DATA_OFFSET: usize = 8*1; + //const FORM_PAIR_CAR_OFFSET: usize = 8*0; + //Ok(unsafe { &(*((self.data as usize & !TAG_MASK) as *mut CrcInner)).data.car }) + //Ok(unsafe { &(*((self.data as usize & !TAG_MASK) as *mut CrcInner)).data.cdr }) + //#[repr(C)] + //pub struct CrcInner { + // rc: Cell, + // data: T, + //} + //#[repr(C)] + //struct FormPair { + // car: Form, + // cdr: Form, + //} + } let mut offset = 0; for op in ops { match op { @@ -257,34 +407,43 @@ impl JIT { return Ok(Some((tmp_stack.pop().unwrap(), e, (*nc).clone()))); } */ - Op::InlinePrim(Prim::Add) => { + Op::InlinePrim(p) => { /* let b = tmp_stack.pop().unwrap(); let a = tmp_stack.pop().unwrap(); tmp_stack.pop().unwrap(); // pop the prim tmp_stack.push(Form::new_int(a.int()? + b.int()?)); - */ let cst = bcx.ins().iconst(self.int, 1337 << 3); let _ = bcx.ins().call(local_print_func, &[cst]); - - let b = stack_pop_int(&mut bcx, self.int, tmp_stack_ptr, local_print_func); - let _ = bcx.ins().call(local_print_func, &[cst]); - - let a = stack_pop_int(&mut bcx, self.int, tmp_stack_ptr, local_print_func); - let _ = bcx.ins().call(local_print_func, &[cst]); - + */ + let b = stack_pop(&mut bcx, self.int, tmp_stack_ptr, local_print_func); + let a = if p.two_params() { Some(stack_pop(&mut bcx, self.int, tmp_stack_ptr, local_print_func)) } else { None }; + let result = match p { + Prim::Car => { gen_cr(&mut bcx, self.int, b, false, local_drop_rc_form_func, local_print_func) }, + Prim::Cdr => { gen_cr(&mut bcx, self.int, b, true, local_drop_rc_form_func, local_print_func) }, + Prim::Add | Prim::Sub | Prim::Mul | Prim::Div | Prim::Mod => { + let a = a.unwrap(); + type_assert(&mut bcx, a, TAG_INT, local_print_func, self.int); + type_assert(&mut bcx, b, TAG_INT, local_print_func, self.int); + match p { + Prim::Add => { bcx.ins().iadd(a, b) } + Prim::Sub => { bcx.ins().isub(a, b) } + Prim::Mul => { bcx.ins().imul(a, b) } + Prim::Div => { bcx.ins().sdiv(a, b) } + Prim::Mod => { bcx.ins().srem(a, b) } + _ => unreachable!(), + } + }, + Prim::Eq => { gen_eq(&mut bcx, self.int, a.unwrap(), b, local_eq_func, local_print_func) }, + Prim::Cons => { let call = bcx.ins().call(local_cons_func, &[a.unwrap(), b]); bcx.inst_results(call)[0].clone() }, + }; let _ = stack_pop_prim(&mut bcx, self.int, tmp_stack_ptr, local_print_func); - let _ = bcx.ins().call(local_print_func, &[cst]); - - let result = bcx.ins().iadd(a, b); - let _ = bcx.ins().call(local_print_func, &[result]); - gen_stack_push_norc_nocap(&mut bcx, self.int, tmp_stack_ptr, result, local_print_func); // we just popped, so cap - // is fine, and ints dont - // rc + gen_stack_push_nocap(&mut bcx, self.int, tmp_stack_ptr, result, local_print_func); // we just popped, so cap is fine offset += 1; } _ => { // quit out of trace! + println!("compiling back to interp because we can't compile {op:?}"); assert!(offset <= TRACE_OFFSET_MASK); let cst = bcx.ins().iconst(self.int, (id.id << TRACE_ID_OFFSET) | offset as i64); bcx.ins().return_(&[cst]); @@ -466,6 +625,8 @@ pub struct Crc { ptr: NonNull>, phantom: PhantomData> } +const CRC_INNER_RC_OFFSET: usize = 8*0; +const CRC_INNER_DATA_OFFSET: usize = 8*1; #[repr(C)] pub struct CrcInner { rc: Cell, @@ -564,6 +725,8 @@ pub struct Form { data: *const Form, phantom: PhantomData } +const FORM_PAIR_CAR_OFFSET: usize = 8*0; +const FORM_PAIR_CDR_OFFSET: usize = 8*1; #[repr(C)] struct FormPair { car: Form, @@ -629,9 +792,12 @@ const TAG_NIL: usize = 0b001; const TAG_BOOL_FALSE: usize = 0b010; const TAG_BOOL_TRUE: usize = 0b011; const TAG_SYMBOL: usize = 0b100; -const TAG_PAIR: usize = 0b101; +const TAG_PRIM: usize = 0b101; + const TAG_CLOSURE: usize = 0b110; -const TAG_PRIM: usize = 0b111; +const TAG_PAIR: usize = 0b111; + +const TAG_RC: usize = 0b110; static SYMBOLS: Lazy>> = Lazy::new(Mutex::default); @@ -1424,7 +1590,8 @@ impl Ctx { } Op::Return => { println!("Return(op)"); - let (e, nc, resume_data) = ret_stack.pop().unwrap(); + let (ne, nc, resume_data) = ret_stack.pop().unwrap(); + e = ne; if let Some(resume_id) = resume_data { if self.traces.contains_key(&resume_id) { id = resume_id;