diff --git a/partial_eval.scm b/partial_eval.scm index 841d936..1fbfb83 100644 --- a/partial_eval.scm +++ b/partial_eval.scm @@ -141,7 +141,11 @@ ;;;;;;;;;;;;;;;;;; (empty_dict-list (array)) - (put-list (lambda (m k v) (cons (array k v) m))) + ;(put-list (lambda (m k v) (cons (array k v) m))) + (put-list (lambda (m k v) ((rec-lambda recurse (m k v len_m i) (cond ((= len_m i) (cons (array k v) m)) + ((= k (idx (idx m i) 0)) (concat (slice m 0 i) (array (array k v)) (slice m (+ i 1) len_m))) + (true (recurse m k v len_m (+ 1 i))))) + m k v (len m) 0))) (get-list (lambda (d k) ((rec-lambda recurse (k d len_d i) (cond ((= len_d i) false) ((= k (idx (idx d i) 0)) (idx d i)) (true (recurse k d len_d (+ 1 i))))) @@ -4594,9 +4598,13 @@ o)))) (is_prim_function_call (lambda (c s) (and (marked_array? c) (not (.marked_array_is_val c)) (<= 2 (len (.marked_array_values c))) (prim_comb? (idx (.marked_array_values c) 0)) (= s (.prim_comb_sym (idx (.marked_array_values c) 0)))))) - (is_markable (lambda (x) (and (marked_symbol? x) (not (.marked_symbol_is_val x))))) + (is_markable (lambda (x) (or (and (marked_symbol? x) (not (.marked_symbol_is_val x))) + (and (marked_array? x) (not (.marked_array_is_val x))) + ))) (is_markable_idx (lambda (c i) (and (marked_array? c) (< i (len (.marked_array_values c))) (is_markable (idx (.marked_array_values c) i))))) - (mark (lambda (x) (and (marked_symbol? x) (not (.marked_symbol_is_val x)) (.marked_symbol_value x)))) + (mark (lambda (x) (or (and (marked_symbol? x) (not (.marked_symbol_is_val x)) (.marked_symbol_value x)) + (and (marked_array? x) (not (.marked_array_is_val x)) (.hash x)) + ))) (mark_idx (lambda (c i) (and (marked_array? c) (< i (len (.marked_array_values c))) (mark (idx (.marked_array_values c) i))))) (combine-list (lambda (mf a b) (dlet ( (_ (true_print "going to combine " a " and " b)) @@ -4690,8 +4698,10 @@ ( _ (true_print " and params are " params)) ( (sub_implies sub_guarentees psub_data) (foldl (dlambda ((sub_implies sub_guarentees running_sub_data) i) (dlet ((psym (idx params (- i 1))) ((ttyp timpl assertions sub_sub_data) (infer_types (idx (.marked_array_values c) i) env_id implies guarentees)) - ) (array (combine-list (lambda (a b) (combine-list combine-type a b)) (put-list empty_dict-list psym timpl) sub_implies) - (combine-list combine-type (put-list empty_dict-list psym ttyp) sub_guarentees) + ) (array ;(combine-list (lambda (a b) (combine-list combine-type a b)) (put-list empty_dict-list psym timpl) sub_implies) + ;(combine-list combine-type (put-list empty_dict-list psym ttyp) sub_guarentees) + (put-list sub_implies psym timpl) + (put-list sub_guarentees psym ttyp) (concat running_sub_data (array (array ttyp timpl assertions sub_sub_data)))))) (array implies guarentees ;(array func_sub) @@ -4699,6 +4709,9 @@ ) (range 1 (len (.marked_array_values c))))) ( _ (true_print "based on inline (let) case " params " we have sub_implies " sub_implies " and sub_guarentees " sub_guarentees) ) ((ttyp timpl assertion inl_subdata) (infer_types (.comb_body func) (.comb_id func) sub_implies sub_guarentees)) + ; remove the implication if it's about something that only exists inside the inlined function (a parameter) + ; TODO: does this have to check for env_symbol? + (timpl (mif (and timpl (in_array (idx timpl 0) params)) false timpl)) ( _ (true_print "final result of inline " params " is type " ttyp " and impl " timpl)) ;(_ (true_print "exiting let")) ) (array ttyp timpl empty_dict-list (concat (array (array ttyp timpl assertion inl_subdata)) psub_data)))) @@ -4737,7 +4750,7 @@ ;(_ (true_print " doing infer-types for random call ")) (sub_results (map (lambda (x) (infer_types x env_id implies guarentees)) (.marked_array_values c))) ;(_ (true_print " done infer-types for random call ")) - ) (array false false empty_dict-list sub_results))) + ) (array (get-list-or guarentees (.hash c) false) false empty_dict-list sub_results))) ; fallthrough (true (array false false empty_dict-list type_data_nil)) @@ -4833,8 +4846,6 @@ (parameter_subs (map (lambda (i) (cached_infer_types_idx c (.marked_env_idx env) type_data i)) (range 1 (len func_param_values)))) (parameter_types (map just_type parameter_subs)) - (_ (mif (and (prim_comb? func_value) (= (.prim_comb_sym func_value) 'idx)) (true_print "ok, param of idx types are (" (true_str_strip (idx params 0)) ") " (idx parameter_types 0) " (" (true_str_strip (idx params 1)) ") " (idx parameter_types 1)))) - ;(_ (true_print "parameter types " parameter_types)) ;(_ (true_print "parameter subs " parameter_subs)) @@ -4993,7 +5004,24 @@ ) (_ (true_print "made eq_code")) ) (array nil eq_code nil ctx)))) - (dlet ((_ (true_print "missed better = " parameter_types))) (gen_cmp_impl false_val true_val false_val)))) + (dlet ((_ (true_print "missed better = " parameter_types))) (gen_cmp_impl false_val true_val false_val)))) + ; inline idx if we have the type+len of array and idx is a constant + ((and (prim_comb? func_value) (= (.prim_comb_sym func_value) 'idx) (= 2 num_params) + (idx parameter_types 0) (= 'arr (idx (idx parameter_types 0) 0)) (idx (idx parameter_types 0) 2) + (idx parameter_types 1) (= 'int (idx (idx parameter_types 1) 0))) + (val? (idx params 1)) (dlet ( + (_ (true_print "inlining idx IDX!!")) + ((param_codes err ctx _) (compile_params false ctx false)) + (array_len (idx (idx parameter_types 0) 2)) + (index (.val (idx params 1))) + (index (mif (< index 0) (+ index array_len) index)) + ((code err) (mif (and (>= index 0) (< index array_len)) + (array (concat (local.set '$prim_tmp_a (idx param_codes 0)) + (generate_dup (i64.load (* 8 index) (extract_ptr_code (local.get '$prim_tmp_a)))) + (generate_drop (local.get '$prim_tmp_a))) + nil) + (array nil (true_str "bad constant offset into typed array")))) + ) (array nil code err ctx)))