diff --git a/slj/src/lib.rs b/slj/src/lib.rs index 0b5ffaa..baf5cde 100644 --- a/slj/src/lib.rs +++ b/slj/src/lib.rs @@ -18,6 +18,8 @@ use cranelift_module::{default_libcall_names, Linkage, Module, FuncId}; use once_cell::sync::Lazy; +const TRACE_LEVEL: i64 = 1; +const JIT_LEVEL: i64 = 5; extern "C" fn rust_add1(x: Form, y: Form) -> Form { println!("Add 1"); @@ -28,6 +30,8 @@ extern "C" fn rust_add2(x: isize, y: isize) -> isize { x + y } +const TRACE_ID_OFFSET: usize = 32; +const TRACE_OFFSET_MASK: usize = 0xFF_FF_FF_FF; // could be bigger // https://github.com/bytecodealliance/wasmtime/blob/main/cranelift/jit/examples/jit-minimal.rs pub struct JIT { module: JITModule, @@ -57,6 +61,70 @@ impl JIT { Self { module, ctx, func_ctx, int } } + fn comple_trace(&mut self, id: ID, ops: &Vec) -> (FuncId, extern "C" fn(&mut Form, &mut Cvec
, &mut Cvec<(Form, Crc, Option)>) -> usize) { + let mut sig_a = self.module.make_signature(); + sig_a.call_conv = isa::CallConv::Tail; + sig_a.params.push(AbiParam::new(self.int)); + sig_a.params.push(AbiParam::new(self.int)); + sig_a.params.push(AbiParam::new(self.int)); + sig_a.returns.push(AbiParam::new(self.int)); + let func_a = self.module.declare_function(&format!("{id}_inner"), Linkage::Local, &sig_a).unwrap(); + + let mut sig_b = self.module.make_signature(); + sig_b.params.push(AbiParam::new(self.int)); + sig_b.params.push(AbiParam::new(self.int)); + sig_b.params.push(AbiParam::new(self.int)); + sig_b.returns.push(AbiParam::new(self.int)); + let func_b = self.module.declare_function(&format!("{id}_outer"), Linkage::Local, &sig_b).unwrap(); + + self.ctx.func.signature = sig_a; + self.ctx.func.name = UserFuncName::user(0, func_a.as_u32()); + { + let mut bcx: FunctionBuilder = FunctionBuilder::new(&mut self.ctx.func, &mut self.func_ctx); + let block = bcx.create_block(); + bcx.switch_to_block(block); + bcx.append_block_params_for_function_params(block); + let param = bcx.block_params(block)[0]; + + let cst = bcx.ins().iconst(self.int, id.id << TRACE_ID_OFFSET); + bcx.ins().return_(&[cst]); + + bcx.seal_all_blocks(); + bcx.finalize(); + } + self.module.define_function(func_a, &mut self.ctx).unwrap(); + self.module.clear_context(&mut self.ctx); + + self.ctx.func.signature = sig_b; + self.ctx.func.name = UserFuncName::user(0, func_b.as_u32()); + { + let mut bcx: FunctionBuilder = FunctionBuilder::new(&mut self.ctx.func, &mut self.func_ctx); + let block = bcx.create_block(); + bcx.switch_to_block(block); + bcx.append_block_params_for_function_params(block); + let local_func = self.module.declare_func_in_func(func_a, &mut bcx.func); + let params = bcx.block_params(block).iter().cloned().collect::>(); + let call = bcx.ins().call(local_func, ¶ms); + let value = { + let results = bcx.inst_results(call); + assert_eq!(results.len(), 1); + results[0].clone() + }; + bcx.ins().return_(&[value]); + bcx.seal_all_blocks(); + bcx.finalize(); + } + + self.module.define_function(func_b, &mut self.ctx).unwrap(); + self.module.clear_context(&mut self.ctx); + + // perform linking + self.module.finalize_definitions().unwrap(); + + let code_b = self.module.get_finalized_function(func_b); + let ptr_b = unsafe { mem::transmute::<_, _>(code_b) }; + (func_a, ptr_b) + } // returns the id for the inner and the pointer for the outer pub fn compile_with_wrapper(&mut self) -> (FuncId, extern "C" fn (Form) -> Form) { let mut sig_a = self.module.make_signature(); @@ -857,14 +925,14 @@ impl fmt::Display for Trace { Ok(()) } } -#[derive(Debug)] struct Ctx { id_counter: i64, cont_count: BTreeMap, tracing: Option, traces: BTreeMap>, - compiled_traces: BTreeMap, &mut Cvec<(Form, Crc, Option)>) -> *mut CrcInner>, + compiled_traces: BTreeMap, &mut Cvec<(Form, Crc, Option)>) -> usize>, trace_resume_data: BTreeMap, + jit: JIT, } impl fmt::Display for Ctx { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -880,6 +948,7 @@ impl Ctx { traces: BTreeMap::new(), compiled_traces: BTreeMap::new(), trace_resume_data: BTreeMap::new(), + jit: JIT::new(), } } fn alloc_id(&mut self) -> ID { @@ -982,7 +1051,7 @@ impl Ctx { let entry = self.cont_count.entry(id).or_insert(0); println!("tracing call start for {id}, has been called {} times so far", *entry); *entry += 1; - if *entry > 1 && self.tracing.is_none() && self.traces.get(&id).is_none() { + if *entry > TRACE_LEVEL && self.tracing.is_none() && self.traces.get(&id).is_none() { self.tracing = Some(Trace::new(id, id)); } @@ -1077,8 +1146,8 @@ impl Ctx { } // returns f, e, c for interp fn execute_trace_if_exists(&mut self, - id: ID, - e: &Form, + mut id: ID, + mut e: Form, tmp_stack: &mut Cvec, ret_stack: &mut Cvec<(Form, Crc, Option)>) -> Result> { if self.trace_running() { @@ -1087,30 +1156,48 @@ impl Ctx { // in the future it should just tack on the opcodes while jugging the proper // bookkeeping stacks } - if let Some(f) = self.compiled_traces.get(&id) { - let mut e = e.clone(); - let trace_ret = f(&mut e, tmp_stack, ret_stack); - if trace_ret != std::ptr::null_mut() { - return Ok(Some((tmp_stack.pop().unwrap(), e, (*Crc::from_ptr_clone(trace_ret)).clone()))); - } else { - bail!("some sort of error in compiled trace. Got to figure out how to report this better"); + let mut offset = 0; + loop { + if offset == 0 { + if let Some(f) = self.compiled_traces.get(&id) { + println!("Calling JIT function {id}!"); + let trace_ret = f(&mut e, tmp_stack, ret_stack); + let new_id = ID { id: (trace_ret >> TRACE_ID_OFFSET) as i64 }; + offset = trace_ret & TRACE_OFFSET_MASK; + println!("\tresult of call is new_id {new_id} and offset {offset}"); + // if we've returned a new trace to start, do so, + // otherwise fall through to trace interpretation immediately + if new_id != id { + id = new_id; + continue; + } + } } - } else if let Some(mut trace) = self.traces.get(&id) { - println!("Starting trace playback"); - let mut e = e.clone(); - loop { + if let Some(trace) = self.traces.get(&id) { + + let entry = self.cont_count.entry(id).or_insert(0); + *entry += 1; + if *entry > JIT_LEVEL && !self.compiled_traces.contains_key(&id) { + println!("Compiling trace for {id}!"); + let (inner_id, outer_ptr) = self.jit.comple_trace(id, trace); + self.compiled_traces.insert(id, outer_ptr); + continue; + } + + println!("Starting trace playback"); println!("Running trace {trace:?}, \n\ttmp_stack:{tmp_stack:?}"); - for b in trace.iter() { + for b in trace.iter().skip(offset) { match b { Op::Guard { const_value, side_val, side_cont, side_id, tbk } => { println!("Guard(op) {const_value}"); if const_value != tmp_stack.last().unwrap() { - if let Some(new_trace) = self.traces.get(side_id) { + if self.traces.contains_key(side_id) { if side_val.is_some() { tmp_stack.pop().unwrap(); } println!("\tchaining trace to side trace"); - trace = new_trace; + id = *side_id; + offset = 0; break; // break out of this trace and let infinate loop spin } else { println!("\tending playback b/c failed guard"); @@ -1156,10 +1243,11 @@ impl Ctx { Op::Call { len, nc, nc_id, statik } => { println!("Call(op)"); if let Some(static_call_id) = statik { - if let Some(new_trace) = self.traces.get(static_call_id) { + if self.traces.contains_key(static_call_id) { ret_stack.push((e.clone(), (*nc).clone(), Some(*nc_id))); println!("\tchaining to call trace b/c Call with statik"); - trace = new_trace; + id = *static_call_id; + offset = 0; break; // break out of this trace and let infinate loop spin } } @@ -1170,10 +1258,11 @@ impl Ctx { let b = tmp_stack.pop().unwrap(); let a = if *len == 2 { None } else { assert!(*len == 3); Some(tmp_stack.pop().unwrap()) }; let result = eval_prim(p, b, a)?; - if let Some(new_trace) = self.traces.get(nc_id) { + if self.traces.contains_key(nc_id) { *tmp_stack.last_mut().unwrap() = result; // for the prim itself println!("\tchaining to ret trace b/c Call with dyamic but primitive and next traced"); - trace = new_trace; + id = *nc_id; + offset = 0; break; // break out of this trace and let infinate loop spin } else { println!("\tstopping playback to ret b/c Call with dyamic but primitive and next not-traced"); @@ -1187,10 +1276,11 @@ impl Ctx { bail!("arguments length doesn't match"); } ret_stack.push((e.clone(), (*nc).clone(), Some(*nc_id))); - if let Some(new_trace) = self.traces.get(call_id) { + if self.traces.contains_key(call_id) { println!("\tchaining to call trace b/c Call with dyamic but traced"); e = ie.clone(); - trace = new_trace; + id = *call_id; + offset = 0; break; // break out of this trace and let infinate loop spin } else { return Ok(Some((b.clone(), ie.clone(), Cont::Frame { syms: ps.clone(), id: *call_id, c: Crc::new(Cont::Eval { c: Crc::new(Cont::Ret { id: *call_id }) }) }))); @@ -1212,9 +1302,9 @@ impl Ctx { println!("Return(op)"); let (e, nc, resume_data) = ret_stack.pop().unwrap(); if let Some(resume_id) = resume_data { - if let Some(new_trace) = self.traces.get(&resume_id) { - println!("\tchaining to return trace b/c Return {resume_id} - {new_trace:?}"); - trace = new_trace; + if self.traces.contains_key(&resume_id) { + id = resume_id; + offset = 0; break; // break out of this trace and let infinate loop spin } } @@ -1224,9 +1314,9 @@ impl Ctx { } } } + } else { + return Ok(None); } - } else { - Ok(None) } } } @@ -1318,7 +1408,7 @@ pub fn eval(f: Form) -> Result { e = ne; if let Some(nc_id) = resume_data { tmp_stack.push(f); // ugly dance pt 1 - if let Some((fp, ep, cp)) = ctx.execute_trace_if_exists(nc_id, &e, &mut tmp_stack, &mut ret_stack)? { + if let Some((fp, ep, cp)) = ctx.execute_trace_if_exists(nc_id, e.clone(), &mut tmp_stack, &mut ret_stack)? { f = fp; e = ep; c = cp; @@ -1350,7 +1440,7 @@ pub fn eval(f: Form) -> Result { bail!("arguments length doesn't match"); } ret_stack.push((e.clone(), nc, resume_data)); - if let Some((fp, ep, cp)) = ctx.execute_trace_if_exists(*id, ie, &mut tmp_stack, &mut ret_stack)? { + if let Some((fp, ep, cp)) = ctx.execute_trace_if_exists(*id, ie.clone(), &mut tmp_stack, &mut ret_stack)? { f = fp; e = ep; c = cp; diff --git a/slj/src/main.rs b/slj/src/main.rs index 6499a3c..eb6e2c0 100644 --- a/slj/src/main.rs +++ b/slj/src/main.rs @@ -26,7 +26,7 @@ fn main() -> Result<()> { //let res = ptr_c(Form::new_int(1337)); //println!("sucessful 3 run with result {res}"); - return Ok(()); + //return Ok(()); fn alias(a: Crc, b: Crc) { println!("a: {}, b: {}", *a, *b);