add JIT support for Define, Drop, and Const

This commit is contained in:
2024-02-06 02:13:19 -05:00
parent 525b103f38
commit 23fbfe9db3

View File

@@ -31,8 +31,10 @@ extern "C" fn rust_cons(car: Form, cdr: Form) -> Form {
extern "C" fn rust_eq(a: Form, b: Form) -> Form {
Form::new_bool(a == b)
}
extern "C" fn rust_drop_rc_form(ptr: *mut CrcInner<Form>) {
let _ = Crc::from_ptr(ptr);
extern "C" fn rust_cvec_form_grow(ptr: &mut Cvec<Form>) {
ptr.grow()
}
extern "C" fn rust_drop_rc_form(f: Form) {
}
const TRACE_ID_OFFSET: usize = 32;
@@ -59,6 +61,10 @@ impl JIT {
let isa = isa_builder.finish(settings::Flags::new(flag_builder)).unwrap();
let mut jb = JITBuilder::with_isa(isa, default_libcall_names());
jb.symbol("rust_print_form", rust_print_form as *const u8);
jb.symbol("rust_cons", rust_cons as *const u8);
jb.symbol("rust_eq", rust_eq as *const u8);
jb.symbol("rust_cvec_form_grow", rust_cvec_form_grow as *const u8);
jb.symbol("rust_drop_rc_form", rust_drop_rc_form as *const u8);
let mut module = JITModule::new(jb);
let int = module.target_config().pointer_type();
let mut ctx = module.make_context();
@@ -114,6 +120,12 @@ impl JIT {
let eq_func = self.module.declare_function("rust_eq", Linkage::Import, &eq_sig).unwrap();
let local_eq_func = self.module.declare_func_in_func(eq_func, bcx.func);
let mut grow_cvec_form_sig = self.module.make_signature();
grow_cvec_form_sig.params.push(AbiParam::new(self.int));
let grow_cvec_form_func = self.module.declare_function("rust_cvec_form_grow", Linkage::Import, &grow_cvec_form_sig).unwrap();
let local_cvec_form_grow_func = self.module.declare_func_in_func(grow_cvec_form_func, bcx.func);
let mut drop_sig = self.module.make_signature();
drop_sig.params.push(AbiParam::new(self.int));
let drop_rc_form_func = self.module.declare_function("rust_drop_rc_form", Linkage::Import, &drop_sig).unwrap();
@@ -143,6 +155,23 @@ impl JIT {
type_assert(bcx, loaded, TAG_PRIM, local_print_func, int);
loaded
}
fn gen_stack_push(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, x: Value, local_cvec_form_grow_func: FuncRef, local_print_func: FuncRef) {
let len = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_LEN_OFFSET as i32);
let new_len = bcx.ins().iadd_imm(len, 1);
let cap = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_CAP_OFFSET as i32);
let grow = bcx.ins().icmp(IntCC::SignedGreaterThanOrEqual, new_len, cap);
let grow_block = bcx.create_block();
let merge_block = bcx.create_block();
bcx.ins().brif(grow, grow_block, &[], merge_block, &[]);
bcx.switch_to_block(grow_block);
bcx.ins().call(local_cvec_form_grow_func, &[tmp_stack_ptr]);
bcx.ins().jump(merge_block, &[]);
bcx.switch_to_block(merge_block);
gen_stack_push_nocap(bcx, int, tmp_stack_ptr, x, local_print_func);
}
fn gen_stack_push_nocap(bcx: &mut FunctionBuilder, int: Type, tmp_stack_ptr: Value, x: Value, local_print_func: FuncRef) {
let len = bcx.ins().load(int, MemFlags::trusted(), tmp_stack_ptr, CVEC_LEN_OFFSET as i32);
let new_len = bcx.ins().iadd_imm(len, 1);
@@ -161,26 +190,18 @@ impl JIT {
let const_block = bcx.create_block();
bcx.append_block_param(const_block, int);
bcx.append_block_param(const_block, int);
let dyn_block = bcx.create_block();
bcx.append_block_param(dyn_block, int);
bcx.append_block_param(dyn_block, int);
let merge_block = bcx.create_block();
bcx.append_block_param(merge_block, int);
bcx.ins().brif(rc, const_block, &[a,b], dyn_block, &[a,b]);
bcx.ins().brif(rc, dyn_block, &[], const_block, &[]);
bcx.switch_to_block(const_block);
let a = bcx.block_params(const_block)[0].clone();
let b = bcx.block_params(const_block)[1].clone();
let e1 = bcx.ins().icmp(IntCC::Equal, a, b);
let e = bcx.ins().sextend(int, e1);
let r = bcx.ins().bor_imm(e, TAG_BOOL_FALSE as i64); //onst TAG_BOOL_FALSE: usize = 0b010; onst TAG_BOOL_TRUE: usize = 0b011;
bcx.ins().jump(merge_block, &[r]);
bcx.switch_to_block(dyn_block);
let a = bcx.block_params(dyn_block)[0].clone();
let b = bcx.block_params(dyn_block)[1].clone();
let call = bcx.ins().call(local_eq_func, &[a, b]);
let r = bcx.inst_results(call)[0].clone();
bcx.ins().jump(merge_block, &[r]);
@@ -195,11 +216,9 @@ impl JIT {
let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64);
let incr_block = bcx.create_block();
bcx.append_block_param(incr_block, int);
let merge_block = bcx.create_block();
bcx.ins().brif(rc, incr_block, &[x], merge_block, &[]);
bcx.ins().brif(rc, incr_block, &[], merge_block, &[]);
bcx.switch_to_block(incr_block);
let x = bcx.block_params(incr_block)[0].clone();
increment_unchecked(bcx, int, x, local_print_func);
bcx.ins().jump(merge_block, &[]);
@@ -209,18 +228,16 @@ impl JIT {
let crc_inner_ptr = bcx.ins().band_imm(x, (!TAG_MASK) as i64);
let count = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, CRC_INNER_RC_OFFSET as i32);
let new_count = bcx.ins().iadd_imm(count, 1);
bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_DATA_OFFSET as i32);
bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_RC_OFFSET as i32);
}
fn decrement(bcx: &mut FunctionBuilder, int: Type, x: Value, local_drop_rc_form_func: FuncRef, local_print_func: FuncRef) {
let tt = bcx.ins().band_imm(x, TAG_RC as i64);
let rc = bcx.ins().icmp_imm(IntCC::Equal, tt, TAG_RC as i64);
let incr_block = bcx.create_block();
bcx.append_block_param(incr_block, int);
let decr_block = bcx.create_block();
let merge_block = bcx.create_block();
bcx.ins().brif(rc, incr_block, &[x], merge_block, &[]);
bcx.switch_to_block(incr_block);
let x = bcx.block_params(incr_block)[0].clone();
bcx.ins().brif(rc, decr_block, &[], merge_block, &[]);
bcx.switch_to_block(decr_block);
decrement_unchecked(bcx, int, x, local_drop_rc_form_func, local_print_func);
bcx.ins().jump(merge_block, &[]);
@@ -232,24 +249,19 @@ impl JIT {
let count = bcx.ins().load(int, MemFlags::trusted(), crc_inner_ptr, CRC_INNER_RC_OFFSET as i32);
let live_block = bcx.create_block();
bcx.append_block_param(live_block, int);
bcx.append_block_param(live_block, int);
let dead_block = bcx.create_block();
bcx.append_block_param(dead_block, int);
let merge_block = bcx.create_block();
let rc_1 = bcx.ins().icmp_imm(IntCC::Equal, count, 1 as i64);
bcx.ins().brif(rc_1, live_block, &[crc_inner_ptr,count], dead_block, &[crc_inner_ptr]);
bcx.ins().brif(rc_1, dead_block, &[], live_block, &[]);
bcx.switch_to_block(live_block);
let crc_inner_ptr = bcx.block_params(live_block)[0].clone();
let count = bcx.block_params(live_block)[1].clone();
let new_count = bcx.ins().iadd_imm(count, -1);
bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_DATA_OFFSET as i32);
bcx.ins().store(MemFlags::trusted(), new_count, crc_inner_ptr, CRC_INNER_RC_OFFSET as i32);
bcx.ins().jump(merge_block, &[]);
bcx.switch_to_block(dead_block);
let call = bcx.ins().call(local_drop_rc_form_func, &[crc_inner_ptr]);
let call = bcx.ins().call(local_drop_rc_form_func, &[x]);
bcx.ins().jump(merge_block, &[]);
bcx.switch_to_block(merge_block);
@@ -311,30 +323,6 @@ impl JIT {
Op::Debug => {
println!("Debug(op) {}", tmp_stack.last().unwrap());
}
Op::Define { sym } => {
let v = tmp_stack.pop().unwrap();
println!("Define(op) {sym} = {}", v);
e = e.define(sym, v);
}
Op::Const ( con ) => {
println!("Const(op) {con}");
tmp_stack.push(con.clone());
}
Op::Drop => {
println!("Drop(op) {}", tmp_stack.last().unwrap());
tmp_stack.pop().unwrap();
}
Op::Lookup { sym } => {
println!("Lookup(op) {sym}");
tmp_stack.push(e.lookup(sym)?.clone());
}
Op::InlinePrim(prim) => {
println!("InlinePrim(op) {prim:?}");
let b = tmp_stack.pop().unwrap();
let a = if prim.two_params() { Some(tmp_stack.pop().unwrap()) } else { None };
tmp_stack.pop().unwrap(); // pop the prim
tmp_stack.push(eval_prim(*prim, b, a)?);
}
Op::Call { len, nc, nc_id, statik } => {
println!("Call(op)");
if let Some(static_call_id) = statik {
@@ -394,6 +382,9 @@ impl JIT {
unimplemented!();
}
Op::Return => {
// we should pop ret_stack, and if trace id isn't 0, we should try to jump
// to it from some sort of id/func table (also set e, and decrement crc<cont> when popping).
// If not, we need to return out to the interpreter.
println!("Return(op)");
let (e, nc, resume_data) = ret_stack.pop().unwrap();
if let Some(resume_id) = resume_data {
@@ -407,6 +398,10 @@ impl JIT {
self.try_resume_trace(resume_data);
return Ok(Some((tmp_stack.pop().unwrap(), e, (*nc).clone())));
}
Op::Lookup { sym } => {
println!("Lookup(op) {sym}");
tmp_stack.push(e.lookup(sym)?.clone());
}
*/
Op::InlinePrim(p) => {
/*
@@ -442,6 +437,52 @@ impl JIT {
gen_stack_push_nocap(&mut bcx, self.int, tmp_stack_ptr, result, local_print_func); // we just popped, so cap is fine
offset += 1;
}
Op::Define { sym } => {
let s = bcx.ins().iconst(self.int, unsafe { *((&mut Form::new_symbol(sym)) as *mut Form as *mut usize) } as i64);
let v = stack_pop(&mut bcx, self.int, tmp_stack_ptr, local_print_func);
let e = bcx.ins().load(self.int, MemFlags::trusted(), e_ptr, 0);
let kv = {
let call = bcx.ins().call(local_cons_func, &[s, v]);
bcx.inst_results(call)[0].clone()
};
let new_e = {
let call = bcx.ins().call(local_cons_func, &[kv, e]);
bcx.inst_results(call)[0].clone()
};
bcx.ins().store(MemFlags::trusted(), new_e, e_ptr, 0);
offset += 1;
//pub fn define(&self, s: &str, v: Form) -> Form {
// Form::new_pair(Form::new_pair(Form::new_symbol(s), v), self.clone())
//}
//let v = tmp_stack.pop().unwrap();
//println!("Define(op) {sym} = {}", v);
//e = e.define(sym, v);
}
Op::Drop => {
let v = stack_pop(&mut bcx, self.int, tmp_stack_ptr, local_print_func);
decrement(&mut bcx, self.int, v, local_drop_rc_form_func, local_print_func);
offset += 1;
//println!("Drop(op) {}", tmp_stack.last().unwrap());
//tmp_stack.pop().unwrap();
}
Op::Const ( con ) => {
println!("doing a const");
con.increment();
println!("incr");
let con_rc = con.is_rc();
let con = bcx.ins().iconst(self.int, unsafe { *(con as *const Form as *const usize) } as i64);
if con_rc {
increment_unchecked(&mut bcx, self.int, con, local_print_func);
}
gen_stack_push(&mut bcx, self.int, tmp_stack_ptr, con, local_cvec_form_grow_func, local_print_func);
offset += 1;
/*
println!("Const(op) {con}");
tmp_stack.push(con.clone());
*/
}
_ => {
// quit out of trace!
println!("compiling back to interp because we can't compile {op:?}");
@@ -452,10 +493,31 @@ impl JIT {
}
}
}
//[
// Define { sym: "n" },
// Define { sym: "faft_h" },
// Drop,
// Const(Form(Eq)),
// Lookup { sym: "n" },
// Const(Form(1)),
// InlinePrim(Eq),
// Guard { const_value: Form(false), side_val: Some(Form(('debug 1))), side_cont: Eval {..}, side_id: ID { id: 3 }, },
// Drop,
// Const(Form(Add)),
// Lookup { sym: "n" },
// Lookup { sym: "faft_h" },
// Lookup { sym: "faft_h" },
// Const(Form(Sub)),
// Lookup { sym: "n" },
// Const(Form(1)),
// InlinePrim(Sub),
// Call { len: 3, statik: None, nc: Call { n: 3, to_go: Form(nil), c: Ret { id: ID { id: 1 } } }, nc_id: ID { id: 4 } }
//],
bcx.seal_all_blocks();
bcx.finalize();
}
println!("{:?}", self.ctx.func);
self.module.define_function(func_a, &mut self.ctx).unwrap();
self.module.clear_context(&mut self.ctx);
@@ -845,6 +907,9 @@ impl Form {
Self { data: (((ds.as_ptr() as usize) << SYM_PTR_OFFSET) | (ds.len() << SYM_LEN_OFFSET) | TAG_SYMBOL) as *const Form, phantom: PhantomData }
}
pub fn is_rc(&self) -> bool {
self.data as usize & TAG_RC == TAG_RC
}
pub fn int(&self) -> Result<isize> {
if self.data as usize & TAG_MASK == TAG_INT {
Ok(self.data as isize >> 3)
@@ -957,16 +1022,32 @@ impl Form {
e = ne;
}
}
fn increment(&self) {
match self.data as usize & TAG_MASK {
TAG_INT | TAG_NIL | TAG_BOOL_FALSE | TAG_BOOL_TRUE | TAG_PRIM | TAG_SYMBOL => { /*println!("increment simple {self}");*/ },
TAG_PAIR => {
//println!("increment pair");
unsafe { (*((self.data as usize & !TAG_MASK) as *mut CrcInner<FormPair>)).increment(); }
},
TAG_CLOSURE => {
//println!("increment pair");
unsafe { (*((self.data as usize & !TAG_MASK) as *mut CrcInner<Closure>)).increment(); }
},
_ => unreachable!(),
}
}
}
impl Drop for Form {
fn drop(&mut self) {
match self.data as usize & TAG_MASK {
TAG_INT | TAG_NIL | TAG_BOOL_FALSE | TAG_BOOL_TRUE | TAG_PRIM | TAG_SYMBOL => { /*println!("dropping simple {self}"); */ }, // doing nothing for symbol is fine
TAG_INT | TAG_NIL | TAG_BOOL_FALSE | TAG_BOOL_TRUE | TAG_PRIM | TAG_SYMBOL => { /*println!("dropping simple {self}");*/ }, // doing nothing for symbol is fine
// since it's deduplicated
TAG_PAIR => {
//println!("dropping pair");
let _ = Crc::<FormPair>::from_ptr( (self.data as usize & !TAG_MASK) as *mut CrcInner<FormPair> );
},
TAG_CLOSURE => {
//println!("dropping closure");
let _ = Crc::<Closure>::from_ptr( (self.data as usize & !TAG_MASK) as *mut CrcInner<Closure> );
},
_ => unreachable!(),
@@ -975,18 +1056,8 @@ impl Drop for Form {
}
impl Clone for Form {
fn clone(&self) -> Self {
match self.data as usize & TAG_MASK {
TAG_INT | TAG_NIL | TAG_BOOL_FALSE | TAG_BOOL_TRUE | TAG_PRIM | TAG_SYMBOL => { Self { data: self.data, phantom: PhantomData } },
TAG_PAIR => {
unsafe { (*((self.data as usize & !TAG_MASK) as *mut CrcInner<FormPair>)).increment(); }
Self { data: self.data, phantom: PhantomData }
},
TAG_CLOSURE => {
unsafe { (*((self.data as usize & !TAG_MASK) as *mut CrcInner<Closure>)).increment(); }
Self { data: self.data, phantom: PhantomData }
},
_ => unreachable!(),
}
self.increment();
Self { data: self.data, phantom: PhantomData }
}
}
impl PartialEq for Form {
@@ -1016,18 +1087,23 @@ impl fmt::Display for Form {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.data as usize & TAG_MASK {
TAG_INT => {
//println!("printing int");
write!(f, "{}", self.data as isize >> 3)?;
},
TAG_NIL => {
//println!("printing nil");
write!(f, "nil")?;
},
TAG_BOOL_FALSE => {
//println!("printing false");
write!(f, "false")?;
},
TAG_BOOL_TRUE => {
//println!("printing true");
write!(f, "true")?;
},
TAG_PAIR => {
//println!("printinga pair");
write!(f, "({}", self.car().unwrap())?;
let mut traverse = self.cdr().unwrap();
loop {
@@ -1048,14 +1124,18 @@ impl fmt::Display for Form {
}
},
TAG_PRIM => {
//println!("printinga prim");
write!(f, "{:?}", self.prim().unwrap())?;
},
TAG_SYMBOL => {
//println!("printinga symbol");
write!(f, "'{}", self.sym().unwrap())?;
},
TAG_CLOSURE => {
//println!("printinga closure");
let Closure { params, e, body, id, } = self.closure().unwrap();
write!(f, "<{params} {e} {body} {id}>")?;
//println!("doneprinting closure");
},
_ => unreachable!(),
}
@@ -1452,10 +1532,13 @@ impl Ctx {
if offset == 0 {
if let Some(f) = self.compiled_traces.get(&id) {
println!("Calling JIT function {id}, tmp_stack is {:?}!", tmp_stack);
println!("{}", unsafe { *((&mut e) as *mut Form as *mut usize) } as usize);
let trace_ret = f(&mut e, tmp_stack, ret_stack);
let new_id = ID { id: NonZeroI64::new((trace_ret >> TRACE_ID_OFFSET) as i64).unwrap() };
offset = trace_ret & TRACE_OFFSET_MASK;
println!("\tresult of call is new_id {new_id} and offset {offset}, tmp_stack is {:?}!", tmp_stack);
println!("{}", unsafe { *((&mut e) as *mut Form as *mut usize) } as usize);
println!("\te is {e}!");
// if we've returned a new trace to start, do so,
// otherwise fall through to trace interpretation immediately
if new_id != id {