integrated JIT compiler with simple generated do-nothing JIT traces. Modified JIT trace interface so that it can return a trace-id and offset to jump to, allowing the trace interpreter to handle hard cases, etc. (or in this case, all cases, as it just immediatly returns the current id and offset 0).

This commit is contained in:
2024-02-02 01:46:13 -05:00
parent 9cd46a31eb
commit 776fc7c921
2 changed files with 123 additions and 33 deletions

View File

@@ -18,6 +18,8 @@ use cranelift_module::{default_libcall_names, Linkage, Module, FuncId};
use once_cell::sync::Lazy;
const TRACE_LEVEL: i64 = 1;
const JIT_LEVEL: i64 = 5;
extern "C" fn rust_add1(x: Form, y: Form) -> Form {
println!("Add 1");
@@ -28,6 +30,8 @@ extern "C" fn rust_add2(x: isize, y: isize) -> isize {
x + y
}
const TRACE_ID_OFFSET: usize = 32;
const TRACE_OFFSET_MASK: usize = 0xFF_FF_FF_FF; // could be bigger
// https://github.com/bytecodealliance/wasmtime/blob/main/cranelift/jit/examples/jit-minimal.rs
pub struct JIT {
module: JITModule,
@@ -57,6 +61,70 @@ impl JIT {
Self { module, ctx, func_ctx, int }
}
fn comple_trace(&mut self, id: ID, ops: &Vec<Op>) -> (FuncId, extern "C" fn(&mut Form, &mut Cvec<Form>, &mut Cvec<(Form, Crc<Cont>, Option<ID>)>) -> usize) {
let mut sig_a = self.module.make_signature();
sig_a.call_conv = isa::CallConv::Tail;
sig_a.params.push(AbiParam::new(self.int));
sig_a.params.push(AbiParam::new(self.int));
sig_a.params.push(AbiParam::new(self.int));
sig_a.returns.push(AbiParam::new(self.int));
let func_a = self.module.declare_function(&format!("{id}_inner"), Linkage::Local, &sig_a).unwrap();
let mut sig_b = self.module.make_signature();
sig_b.params.push(AbiParam::new(self.int));
sig_b.params.push(AbiParam::new(self.int));
sig_b.params.push(AbiParam::new(self.int));
sig_b.returns.push(AbiParam::new(self.int));
let func_b = self.module.declare_function(&format!("{id}_outer"), Linkage::Local, &sig_b).unwrap();
self.ctx.func.signature = sig_a;
self.ctx.func.name = UserFuncName::user(0, func_a.as_u32());
{
let mut bcx: FunctionBuilder = FunctionBuilder::new(&mut self.ctx.func, &mut self.func_ctx);
let block = bcx.create_block();
bcx.switch_to_block(block);
bcx.append_block_params_for_function_params(block);
let param = bcx.block_params(block)[0];
let cst = bcx.ins().iconst(self.int, id.id << TRACE_ID_OFFSET);
bcx.ins().return_(&[cst]);
bcx.seal_all_blocks();
bcx.finalize();
}
self.module.define_function(func_a, &mut self.ctx).unwrap();
self.module.clear_context(&mut self.ctx);
self.ctx.func.signature = sig_b;
self.ctx.func.name = UserFuncName::user(0, func_b.as_u32());
{
let mut bcx: FunctionBuilder = FunctionBuilder::new(&mut self.ctx.func, &mut self.func_ctx);
let block = bcx.create_block();
bcx.switch_to_block(block);
bcx.append_block_params_for_function_params(block);
let local_func = self.module.declare_func_in_func(func_a, &mut bcx.func);
let params = bcx.block_params(block).iter().cloned().collect::<Vec<_>>();
let call = bcx.ins().call(local_func, &params);
let value = {
let results = bcx.inst_results(call);
assert_eq!(results.len(), 1);
results[0].clone()
};
bcx.ins().return_(&[value]);
bcx.seal_all_blocks();
bcx.finalize();
}
self.module.define_function(func_b, &mut self.ctx).unwrap();
self.module.clear_context(&mut self.ctx);
// perform linking
self.module.finalize_definitions().unwrap();
let code_b = self.module.get_finalized_function(func_b);
let ptr_b = unsafe { mem::transmute::<_, _>(code_b) };
(func_a, ptr_b)
}
// returns the id for the inner and the pointer for the outer
pub fn compile_with_wrapper(&mut self) -> (FuncId, extern "C" fn (Form) -> Form) {
let mut sig_a = self.module.make_signature();
@@ -857,14 +925,14 @@ impl fmt::Display for Trace {
Ok(())
}
}
#[derive(Debug)]
struct Ctx {
id_counter: i64,
cont_count: BTreeMap<ID, i64>,
tracing: Option<Trace>,
traces: BTreeMap<ID, Vec<Op>>,
compiled_traces: BTreeMap<ID, extern "C" fn(&mut Form, &mut Cvec<Form>, &mut Cvec<(Form, Crc<Cont>, Option<ID>)>) -> *mut CrcInner<Cont>>,
compiled_traces: BTreeMap<ID, extern "C" fn(&mut Form, &mut Cvec<Form>, &mut Cvec<(Form, Crc<Cont>, Option<ID>)>) -> usize>,
trace_resume_data: BTreeMap<ID, TraceBookkeeping>,
jit: JIT,
}
impl fmt::Display for Ctx {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -880,6 +948,7 @@ impl Ctx {
traces: BTreeMap::new(),
compiled_traces: BTreeMap::new(),
trace_resume_data: BTreeMap::new(),
jit: JIT::new(),
}
}
fn alloc_id(&mut self) -> ID {
@@ -982,7 +1051,7 @@ impl Ctx {
let entry = self.cont_count.entry(id).or_insert(0);
println!("tracing call start for {id}, has been called {} times so far", *entry);
*entry += 1;
if *entry > 1 && self.tracing.is_none() && self.traces.get(&id).is_none() {
if *entry > TRACE_LEVEL && self.tracing.is_none() && self.traces.get(&id).is_none() {
self.tracing = Some(Trace::new(id, id));
}
@@ -1077,8 +1146,8 @@ impl Ctx {
}
// returns f, e, c for interp
fn execute_trace_if_exists(&mut self,
id: ID,
e: &Form,
mut id: ID,
mut e: Form,
tmp_stack: &mut Cvec<Form>,
ret_stack: &mut Cvec<(Form, Crc<Cont>, Option<ID>)>) -> Result<Option<(Form, Form, Cont)>> {
if self.trace_running() {
@@ -1087,30 +1156,48 @@ impl Ctx {
// in the future it should just tack on the opcodes while jugging the proper
// bookkeeping stacks
}
if let Some(f) = self.compiled_traces.get(&id) {
let mut e = e.clone();
let trace_ret = f(&mut e, tmp_stack, ret_stack);
if trace_ret != std::ptr::null_mut() {
return Ok(Some((tmp_stack.pop().unwrap(), e, (*Crc::from_ptr_clone(trace_ret)).clone())));
} else {
bail!("some sort of error in compiled trace. Got to figure out how to report this better");
let mut offset = 0;
loop {
if offset == 0 {
if let Some(f) = self.compiled_traces.get(&id) {
println!("Calling JIT function {id}!");
let trace_ret = f(&mut e, tmp_stack, ret_stack);
let new_id = ID { id: (trace_ret >> TRACE_ID_OFFSET) as i64 };
offset = trace_ret & TRACE_OFFSET_MASK;
println!("\tresult of call is new_id {new_id} and offset {offset}");
// if we've returned a new trace to start, do so,
// otherwise fall through to trace interpretation immediately
if new_id != id {
id = new_id;
continue;
}
}
}
} else if let Some(mut trace) = self.traces.get(&id) {
println!("Starting trace playback");
let mut e = e.clone();
loop {
if let Some(trace) = self.traces.get(&id) {
let entry = self.cont_count.entry(id).or_insert(0);
*entry += 1;
if *entry > JIT_LEVEL && !self.compiled_traces.contains_key(&id) {
println!("Compiling trace for {id}!");
let (inner_id, outer_ptr) = self.jit.comple_trace(id, trace);
self.compiled_traces.insert(id, outer_ptr);
continue;
}
println!("Starting trace playback");
println!("Running trace {trace:?}, \n\ttmp_stack:{tmp_stack:?}");
for b in trace.iter() {
for b in trace.iter().skip(offset) {
match b {
Op::Guard { const_value, side_val, side_cont, side_id, tbk } => {
println!("Guard(op) {const_value}");
if const_value != tmp_stack.last().unwrap() {
if let Some(new_trace) = self.traces.get(side_id) {
if self.traces.contains_key(side_id) {
if side_val.is_some() {
tmp_stack.pop().unwrap();
}
println!("\tchaining trace to side trace");
trace = new_trace;
id = *side_id;
offset = 0;
break; // break out of this trace and let infinate loop spin
} else {
println!("\tending playback b/c failed guard");
@@ -1156,10 +1243,11 @@ impl Ctx {
Op::Call { len, nc, nc_id, statik } => {
println!("Call(op)");
if let Some(static_call_id) = statik {
if let Some(new_trace) = self.traces.get(static_call_id) {
if self.traces.contains_key(static_call_id) {
ret_stack.push((e.clone(), (*nc).clone(), Some(*nc_id)));
println!("\tchaining to call trace b/c Call with statik");
trace = new_trace;
id = *static_call_id;
offset = 0;
break; // break out of this trace and let infinate loop spin
}
}
@@ -1170,10 +1258,11 @@ impl Ctx {
let b = tmp_stack.pop().unwrap();
let a = if *len == 2 { None } else { assert!(*len == 3); Some(tmp_stack.pop().unwrap()) };
let result = eval_prim(p, b, a)?;
if let Some(new_trace) = self.traces.get(nc_id) {
if self.traces.contains_key(nc_id) {
*tmp_stack.last_mut().unwrap() = result; // for the prim itself
println!("\tchaining to ret trace b/c Call with dyamic but primitive and next traced");
trace = new_trace;
id = *nc_id;
offset = 0;
break; // break out of this trace and let infinate loop spin
} else {
println!("\tstopping playback to ret b/c Call with dyamic but primitive and next not-traced");
@@ -1187,10 +1276,11 @@ impl Ctx {
bail!("arguments length doesn't match");
}
ret_stack.push((e.clone(), (*nc).clone(), Some(*nc_id)));
if let Some(new_trace) = self.traces.get(call_id) {
if self.traces.contains_key(call_id) {
println!("\tchaining to call trace b/c Call with dyamic but traced");
e = ie.clone();
trace = new_trace;
id = *call_id;
offset = 0;
break; // break out of this trace and let infinate loop spin
} else {
return Ok(Some((b.clone(), ie.clone(), Cont::Frame { syms: ps.clone(), id: *call_id, c: Crc::new(Cont::Eval { c: Crc::new(Cont::Ret { id: *call_id }) }) })));
@@ -1212,9 +1302,9 @@ impl Ctx {
println!("Return(op)");
let (e, nc, resume_data) = ret_stack.pop().unwrap();
if let Some(resume_id) = resume_data {
if let Some(new_trace) = self.traces.get(&resume_id) {
println!("\tchaining to return trace b/c Return {resume_id} - {new_trace:?}");
trace = new_trace;
if self.traces.contains_key(&resume_id) {
id = resume_id;
offset = 0;
break; // break out of this trace and let infinate loop spin
}
}
@@ -1224,9 +1314,9 @@ impl Ctx {
}
}
}
} else {
return Ok(None);
}
} else {
Ok(None)
}
}
}
@@ -1318,7 +1408,7 @@ pub fn eval(f: Form) -> Result<Form> {
e = ne;
if let Some(nc_id) = resume_data {
tmp_stack.push(f); // ugly dance pt 1
if let Some((fp, ep, cp)) = ctx.execute_trace_if_exists(nc_id, &e, &mut tmp_stack, &mut ret_stack)? {
if let Some((fp, ep, cp)) = ctx.execute_trace_if_exists(nc_id, e.clone(), &mut tmp_stack, &mut ret_stack)? {
f = fp;
e = ep;
c = cp;
@@ -1350,7 +1440,7 @@ pub fn eval(f: Form) -> Result<Form> {
bail!("arguments length doesn't match");
}
ret_stack.push((e.clone(), nc, resume_data));
if let Some((fp, ep, cp)) = ctx.execute_trace_if_exists(*id, ie, &mut tmp_stack, &mut ret_stack)? {
if let Some((fp, ep, cp)) = ctx.execute_trace_if_exists(*id, ie.clone(), &mut tmp_stack, &mut ret_stack)? {
f = fp;
e = ep;
c = cp;

View File

@@ -26,7 +26,7 @@ fn main() -> Result<()> {
//let res = ptr_c(Form::new_int(1337));
//println!("sucessful 3 run with result {res}");
return Ok(());
//return Ok(());
fn alias(a: Crc<u64>, b: Crc<u64>) {
println!("a: {}, b: {}", *a, *b);