From 23fbfe9db37cfec48ef4594d742fb1abe7743d7f Mon Sep 17 00:00:00 2001 From: Nathan Braswell Date: Tue, 6 Feb 2024 02:13:19 -0500 Subject: [PATCH] add JIT support for Define, Drop, and Const --- slj/src/lib.rs | 213 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 148 insertions(+), 65 deletions(-) diff --git a/slj/src/lib.rs b/slj/src/lib.rs index ff346cf..af14e40 100644 --- a/slj/src/lib.rs +++ b/slj/src/lib.rs @@ -31,8 +31,10 @@ extern "C" fn rust_cons(car: Form, cdr: Form) -> Form { extern "C" fn rust_eq(a: Form, b: Form) -> Form { Form::new_bool(a == b) } -extern "C" fn rust_drop_rc_form(ptr: *mut CrcInner
) { - let _ = Crc::from_ptr(ptr); +extern "C" fn rust_cvec_form_grow(ptr: &mut Cvec) { + ptr.grow() +} +extern "C" fn rust_drop_rc_form(f: Form) { } const TRACE_ID_OFFSET: usize = 32; @@ -59,6 +61,10 @@ impl JIT { let isa = isa_builder.finish(settings::Flags::new(flag_builder)).unwrap(); let mut jb = JITBuilder::with_isa(isa, default_libcall_names()); jb.symbol("rust_print_form", rust_print_form as *const u8); + jb.symbol("rust_cons", rust_cons as *const u8); + jb.symbol("rust_eq", rust_eq as *const u8); + jb.symbol("rust_cvec_form_grow", rust_cvec_form_grow as *const u8); + jb.symbol("rust_drop_rc_form", rust_drop_rc_form as *const u8); let mut module = JITModule::new(jb); let int = module.target_config().pointer_type(); let mut ctx = module.make_context(); @@ -114,6 +120,12 @@ impl JIT { let eq_func = self.module.declare_function("rust_eq", Linkage::Import, &eq_sig).unwrap(); let local_eq_func = self.module.declare_func_in_func(eq_func, bcx.func); + let mut grow_cvec_form_sig = self.module.make_signature(); + grow_cvec_form_sig.params.push(AbiParam::new(self.int)); + let grow_cvec_form_func = self.module.declare_function("rust_cvec_form_grow", Linkage::Import, &grow_cvec_form_sig).unwrap(); + let local_cvec_form_grow_func = self.module.declare_func_in_func(grow_cvec_form_func, bcx.func); + + let mut drop_sig = self.module.make_signature(); drop_sig.params.push(AbiParam::new(self.int)); let drop_rc_form_func = self.module.declare_function("rust_drop_rc_form", Linkage::Import, &drop_sig).unwrap(); @@ -143,6 +155,23 @@ impl JIT { type_assert(bcx, loaded, TAG_PRIM, local_print_func, int); loaded } + fn gen_stack_push(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, x: Value, local_cvec_form_grow_func: FuncRef, local_print_func: FuncRef) { + let len = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_LEN_OFFSET as i32); + let new_len = bcx.ins().iadd_imm(len, 1); + let cap = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_CAP_OFFSET as i32); + + let grow = bcx.ins().icmp(IntCC::SignedGreaterThanOrEqual, new_len, cap); + + let grow_block = bcx.create_block(); + let merge_block = bcx.create_block(); + + bcx.ins().brif(grow, grow_block, &[], merge_block, &[]); + bcx.switch_to_block(grow_block); + bcx.ins().call(local_cvec_form_grow_func, &[tmp_stack_ptr]); + bcx.ins().jump(merge_block, &[]); + bcx.switch_to_block(merge_block); + gen_stack_push_nocap(bcx, int, tmp_stack_ptr, x, local_print_func); + } fn gen_stack_push_nocap(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, x: Value, local_print_func: FuncRef) { let len = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_LEN_OFFSET as i32); let new_len = bcx.ins().iadd_imm(len, 1); @@ -161,26 +190,18 @@ impl JIT { let const_block = bcx.create_block(); - bcx.append_block_param(const_block, int); - bcx.append_block_param(const_block, int); let dyn_block = bcx.create_block(); - bcx.append_block_param(dyn_block, int); - bcx.append_block_param(dyn_block, int); let merge_block = bcx.create_block(); bcx.append_block_param(merge_block, int); - bcx.ins().brif(rc, const_block, &[a,b], dyn_block, &[a,b]); + bcx.ins().brif(rc, dyn_block, &[], const_block, &[]); bcx.switch_to_block(const_block); - let a = bcx.block_params(const_block)[0].clone(); - let b = bcx.block_params(const_block)[1].clone(); let e1 = bcx.ins().icmp(IntCC::Equal, a, b); let e = bcx.ins().sextend(int, e1); let r = bcx.ins().bor_imm(e, TAG_BOOL_FALSE as i64); //onst TAG_BOOL_FALSE: usize = 0b010; onst TAG_BOOL_TRUE: usize = 0b011; bcx.ins().jump(merge_block, &[r]); bcx.switch_to_block(dyn_block); - let a = bcx.block_params(dyn_block)[0].clone(); - let b = bcx.block_params(dyn_block)[1].clone(); let call = bcx.ins().call(local_eq_func, &[a, b]); let r = bcx.inst_results(call)[0].clone(); bcx.ins().jump(merge_block, &[r]); @@ -195,11 +216,9 @@ impl JIT { let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64); let incr_block = bcx.create_block(); - bcx.append_block_param(incr_block, int); let merge_block = bcx.create_block(); - bcx.ins().brif(rc, incr_block, &[x], merge_block, &[]); + bcx.ins().brif(rc, incr_block, &[], merge_block, &[]); bcx.switch_to_block(incr_block); - let x = bcx.block_params(incr_block)[0].clone(); increment_unchecked(bcx, int, x, local_print_func); bcx.ins().jump(merge_block, &[]); @@ -209,18 +228,16 @@ impl JIT { let crc_inner_ptr = bcx.ins().band_imm(x, (!TAG_MASK) as i64); let count = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, CRC_INNER_RC_OFFSET as i32); let new_count = bcx.ins().iadd_imm(count, 1); - bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_DATA_OFFSET as i32); + bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_RC_OFFSET as i32); } fn decrement(bcx: &mut FunctionBuilder, int: Type, x: Value, local_drop_rc_form_func: FuncRef, local_print_func: FuncRef) { let tt = bcx.ins().band_imm(x, TAG_RC as i64); let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64); - let incr_block = bcx.create_block(); - bcx.append_block_param(incr_block, int); + let decr_block = bcx.create_block(); let merge_block = bcx.create_block(); - bcx.ins().brif(rc, incr_block, &[x], merge_block, &[]); - bcx.switch_to_block(incr_block); - let x = bcx.block_params(incr_block)[0].clone(); + bcx.ins().brif(rc, decr_block, &[], merge_block, &[]); + bcx.switch_to_block(decr_block); decrement_unchecked(bcx, int, x, local_drop_rc_form_func, local_print_func); bcx.ins().jump(merge_block, &[]); @@ -232,24 +249,19 @@ impl JIT { let count = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, CRC_INNER_RC_OFFSET as i32); let live_block = bcx.create_block(); - bcx.append_block_param(live_block, int); - bcx.append_block_param(live_block, int); let dead_block = bcx.create_block(); - bcx.append_block_param(dead_block, int); let merge_block = bcx.create_block(); let rc_1 = bcx.ins().icmp_imm(IntCC::Equal, count, 1 as i64); - bcx.ins().brif(rc_1, live_block, &[crc_inner_ptr,count], dead_block, &[crc_inner_ptr]); + bcx.ins().brif(rc_1, dead_block, &[], live_block, &[]); bcx.switch_to_block(live_block); - let crc_inner_ptr = bcx.block_params(live_block)[0].clone(); - let count = bcx.block_params(live_block)[1].clone(); let new_count = bcx.ins().iadd_imm(count, -1); - bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_DATA_OFFSET as i32); + bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_RC_OFFSET as i32); bcx.ins().jump(merge_block, &[]); bcx.switch_to_block(dead_block); - let call = bcx.ins().call(local_drop_rc_form_func, &[crc_inner_ptr]); + let call = bcx.ins().call(local_drop_rc_form_func, &[x]); bcx.ins().jump(merge_block, &[]); bcx.switch_to_block(merge_block); @@ -311,30 +323,6 @@ impl JIT { Op::Debug => { println!("Debug(op) {}", tmp_stack.last().unwrap()); } - Op::Define { sym } => { - let v = tmp_stack.pop().unwrap(); - println!("Define(op) {sym} = {}", v); - e = e.define(sym, v); - } - Op::Const ( con ) => { - println!("Const(op) {con}"); - tmp_stack.push(con.clone()); - } - Op::Drop => { - println!("Drop(op) {}", tmp_stack.last().unwrap()); - tmp_stack.pop().unwrap(); - } - Op::Lookup { sym } => { - println!("Lookup(op) {sym}"); - tmp_stack.push(e.lookup(sym)?.clone()); - } - Op::InlinePrim(prim) => { - println!("InlinePrim(op) {prim:?}"); - let b = tmp_stack.pop().unwrap(); - let a = if prim.two_params() { Some(tmp_stack.pop().unwrap()) } else { None }; - tmp_stack.pop().unwrap(); // pop the prim - tmp_stack.push(eval_prim(*prim, b, a)?); - } Op::Call { len, nc, nc_id, statik } => { println!("Call(op)"); if let Some(static_call_id) = statik { @@ -394,6 +382,9 @@ impl JIT { unimplemented!(); } Op::Return => { + // we should pop ret_stack, and if trace id isn't 0, we should try to jump + // to it from some sort of id/func table (also set e, and decrement crc when popping). + // If not, we need to return out to the interpreter. println!("Return(op)"); let (e, nc, resume_data) = ret_stack.pop().unwrap(); if let Some(resume_id) = resume_data { @@ -407,6 +398,10 @@ impl JIT { self.try_resume_trace(resume_data); return Ok(Some((tmp_stack.pop().unwrap(), e, (*nc).clone()))); } + Op::Lookup { sym } => { + println!("Lookup(op) {sym}"); + tmp_stack.push(e.lookup(sym)?.clone()); + } */ Op::InlinePrim(p) => { /* @@ -442,6 +437,52 @@ impl JIT { gen_stack_push_nocap(&mut bcx, self.int, tmp_stack_ptr, result, local_print_func); // we just popped, so cap is fine offset += 1; } + Op::Define { sym } => { + let s = bcx.ins().iconst(self.int, unsafe { *((&mut Form::new_symbol(sym)) as *mut Form as *mut usize) } as i64); + let v = stack_pop(&mut bcx, self.int, tmp_stack_ptr, local_print_func); + + let e = bcx.ins().load(self.int, MemFlags::trusted(), e_ptr, 0); + let kv = { + let call = bcx.ins().call(local_cons_func, &[s, v]); + bcx.inst_results(call)[0].clone() + }; + let new_e = { + let call = bcx.ins().call(local_cons_func, &[kv, e]); + bcx.inst_results(call)[0].clone() + }; + bcx.ins().store(MemFlags::trusted(), new_e, e_ptr, 0); + offset += 1; + + //pub fn define(&self, s: &str, v: Form) -> Form { + // Form::new_pair(Form::new_pair(Form::new_symbol(s), v), self.clone()) + //} + //let v = tmp_stack.pop().unwrap(); + //println!("Define(op) {sym} = {}", v); + //e = e.define(sym, v); + } + Op::Drop => { + let v = stack_pop(&mut bcx, self.int, tmp_stack_ptr, local_print_func); + decrement(&mut bcx, self.int, v, local_drop_rc_form_func, local_print_func); + offset += 1; + //println!("Drop(op) {}", tmp_stack.last().unwrap()); + //tmp_stack.pop().unwrap(); + } + Op::Const ( con ) => { + println!("doing a const"); + con.increment(); + println!("incr"); + let con_rc = con.is_rc(); + let con = bcx.ins().iconst(self.int, unsafe { *(con as *const Form as *const usize) } as i64); + if con_rc { + increment_unchecked(&mut bcx, self.int, con, local_print_func); + } + gen_stack_push(&mut bcx, self.int, tmp_stack_ptr, con, local_cvec_form_grow_func, local_print_func); + offset += 1; + /* + println!("Const(op) {con}"); + tmp_stack.push(con.clone()); + */ + } _ => { // quit out of trace! println!("compiling back to interp because we can't compile {op:?}"); @@ -452,10 +493,31 @@ impl JIT { } } } + //[ + // Define { sym: "n" }, + // Define { sym: "faft_h" }, + // Drop, + // Const(Form(Eq)), + // Lookup { sym: "n" }, + // Const(Form(1)), + // InlinePrim(Eq), + // Guard { const_value: Form(false), side_val: Some(Form(('debug 1))), side_cont: Eval {..}, side_id: ID { id: 3 }, }, + // Drop, + // Const(Form(Add)), + // Lookup { sym: "n" }, + // Lookup { sym: "faft_h" }, + // Lookup { sym: "faft_h" }, + // Const(Form(Sub)), + // Lookup { sym: "n" }, + // Const(Form(1)), + // InlinePrim(Sub), + // Call { len: 3, statik: None, nc: Call { n: 3, to_go: Form(nil), c: Ret { id: ID { id: 1 } } }, nc_id: ID { id: 4 } } + //], bcx.seal_all_blocks(); bcx.finalize(); } + println!("{:?}", self.ctx.func); self.module.define_function(func_a, &mut self.ctx).unwrap(); self.module.clear_context(&mut self.ctx); @@ -845,6 +907,9 @@ impl Form { Self { data: (((ds.as_ptr() as usize) << SYM_PTR_OFFSET) | (ds.len() << SYM_LEN_OFFSET) | TAG_SYMBOL) as *const Form, phantom: PhantomData } } + pub fn is_rc(&self) -> bool { + self.data as usize & TAG_RC == TAG_RC + } pub fn int(&self) -> Result { if self.data as usize & TAG_MASK == TAG_INT { Ok(self.data as isize >> 3) @@ -957,16 +1022,32 @@ impl Form { e = ne; } } + fn increment(&self) { + match self.data as usize & TAG_MASK { + TAG_INT | TAG_NIL | TAG_BOOL_FALSE | TAG_BOOL_TRUE | TAG_PRIM | TAG_SYMBOL => { /*println!("increment simple {self}");*/ }, + TAG_PAIR => { + //println!("increment pair"); + unsafe { (*((self.data as usize & !TAG_MASK) as *mut CrcInner)).increment(); } + }, + TAG_CLOSURE => { + //println!("increment pair"); + unsafe { (*((self.data as usize & !TAG_MASK) as *mut CrcInner)).increment(); } + }, + _ => unreachable!(), + } + } } impl Drop for Form { fn drop(&mut self) { match self.data as usize & TAG_MASK { - TAG_INT | TAG_NIL | TAG_BOOL_FALSE | TAG_BOOL_TRUE | TAG_PRIM | TAG_SYMBOL => { /*println!("dropping simple {self}"); */ }, // doing nothing for symbol is fine + TAG_INT | TAG_NIL | TAG_BOOL_FALSE | TAG_BOOL_TRUE | TAG_PRIM | TAG_SYMBOL => { /*println!("dropping simple {self}");*/ }, // doing nothing for symbol is fine // since it's deduplicated TAG_PAIR => { + //println!("dropping pair"); let _ = Crc::::from_ptr( (self.data as usize & !TAG_MASK) as *mut CrcInner ); }, TAG_CLOSURE => { + //println!("dropping closure"); let _ = Crc::::from_ptr( (self.data as usize & !TAG_MASK) as *mut CrcInner ); }, _ => unreachable!(), @@ -975,18 +1056,8 @@ impl Drop for Form { } impl Clone for Form { fn clone(&self) -> Self { - match self.data as usize & TAG_MASK { - TAG_INT | TAG_NIL | TAG_BOOL_FALSE | TAG_BOOL_TRUE | TAG_PRIM | TAG_SYMBOL => { Self { data: self.data, phantom: PhantomData } }, - TAG_PAIR => { - unsafe { (*((self.data as usize & !TAG_MASK) as *mut CrcInner)).increment(); } - Self { data: self.data, phantom: PhantomData } - }, - TAG_CLOSURE => { - unsafe { (*((self.data as usize & !TAG_MASK) as *mut CrcInner)).increment(); } - Self { data: self.data, phantom: PhantomData } - }, - _ => unreachable!(), - } + self.increment(); + Self { data: self.data, phantom: PhantomData } } } impl PartialEq for Form { @@ -1016,18 +1087,23 @@ impl fmt::Display for Form { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.data as usize & TAG_MASK { TAG_INT => { + //println!("printing int"); write!(f, "{}", self.data as isize >> 3)?; }, TAG_NIL => { + //println!("printing nil"); write!(f, "nil")?; }, TAG_BOOL_FALSE => { + //println!("printing false"); write!(f, "false")?; }, TAG_BOOL_TRUE => { + //println!("printing true"); write!(f, "true")?; }, TAG_PAIR => { + //println!("printinga pair"); write!(f, "({}", self.car().unwrap())?; let mut traverse = self.cdr().unwrap(); loop { @@ -1048,14 +1124,18 @@ impl fmt::Display for Form { } }, TAG_PRIM => { + //println!("printinga prim"); write!(f, "{:?}", self.prim().unwrap())?; }, TAG_SYMBOL => { + //println!("printinga symbol"); write!(f, "'{}", self.sym().unwrap())?; }, TAG_CLOSURE => { + //println!("printinga closure"); let Closure { params, e, body, id, } = self.closure().unwrap(); write!(f, "<{params} {e} {body} {id}>")?; + //println!("doneprinting closure"); }, _ => unreachable!(), } @@ -1452,10 +1532,13 @@ impl Ctx { if offset == 0 { if let Some(f) = self.compiled_traces.get(&id) { println!("Calling JIT function {id}, tmp_stack is {:?}!", tmp_stack); + println!("{}", unsafe { *((&mut e) as *mut Form as *mut usize) } as usize); let trace_ret = f(&mut e, tmp_stack, ret_stack); let new_id = ID { id: NonZeroI64::new((trace_ret >> TRACE_ID_OFFSET) as i64).unwrap() }; offset = trace_ret & TRACE_OFFSET_MASK; println!("\tresult of call is new_id {new_id} and offset {offset}, tmp_stack is {:?}!", tmp_stack); + println!("{}", unsafe { *((&mut e) as *mut Form as *mut usize) } as usize); + println!("\te is {e}!"); // if we've returned a new trace to start, do so, // otherwise fall through to trace interpretation immediately if new_id != id {