src/Tools/Compute_Oracle/am_sml.ML
author obua
Thu, 20 Sep 2007 12:10:23 +0200
changeset 24654 329f1b4d9d16
parent 24584 01e83ffa6c54
child 25217 3224db6415ae
permissions -rw-r--r--
improved computing

(*  Title:      Tools/Compute_Oracle/am_sml.ML
    ID:         $Id$
    Author:     Steven Obua

    TODO: "parameterless rewrite cannot be used in pattern": in many cases such a rewrite CAN be
          used in a pattern, and those cases should be handled properly;
          right now, all of them raise an exception.
 
*)

(* Abstract machine that translates rewrite rules into the source of an ML
   structure, loads that source with use_text, and evaluates terms with it. *)
signature AM_SML = 
sig
  include ABSTRACT_MACHINE
  (* Called by the generated "export" function to hand a converted result
     back, tagged with the code string of the program that produced it. *)
  val save_result : (string * term) -> unit
  (* Called by the generated code while it is being loaded, to register its
     evaluation function. *)
  val set_compiled_rewriter : (term -> term) -> unit
  (* List.nth; used by the generated code to look up bound values. *)
  val list_nth : 'a list * int -> 'a
end

structure AM_SML : AM_SML = struct

open AbstractMachine;

(* A compiled program, as returned by compile. *)
type program =
       string              (* name of the generated ML structure *)
     * string              (* code string checked against results reported by the generated export *)
     * int Inttab.table    (* maximal application arity of each constant *)
     * int Inttab.table    (* maximal toplevel pattern arity of each function constant *)
     * term Inttab.table   (* inlined parameterless rewrites: constant -> replacement term *)
     * (term -> term)      (* the compiled rewriter registered by the generated code *)

(* Slot through which generated code reports a result: a pair of the
   program's code string and the converted term. *)
val saved_result : (string * term) option ref = ref NONE

fun save_result result = saved_result := SOME result
fun clear_result () = saved_result := NONE

(* Same as List.nth; exported so that generated code can index its list of bound values. *)
fun list_nth (xs, i) = List.nth (xs, i)

(*fun list_nth (l,n) = (writeln (makestring ("list_nth", (length l,n))); List.nth (l,n))*)

(* Evaluation function registered by the most recently loaded generated code. *)
val compiled_rewriter : (term -> term) option ref = ref NONE

fun set_compiled_rewriter rewriter = compiled_rewriter := SOME rewriter

(* True iff the term consists only of constants and applications
   (no variables and no abstractions). *)
fun importable t =
    (case t of
	 Const _ => true
       | App (f, x) => importable f andalso importable x
       | Var _ => false
       | Abs _ => false)

(* check_freevars free t holds iff every de Bruijn index occurring unbound
   in t lies in 0 .. free-1. Consequently check_freevars 0 t holds iff t
   is closed. *)
fun check_freevars free t =
    (case t of
	 Var x => x < free
       | Const _ => true
       | App (u, v) => check_freevars free u andalso check_freevars free v
       | Abs body => check_freevars (free + 1) body)

(* Number of variable leaves (PVar) in a pattern. *)
fun count_patternvars PVar = 1
  | count_patternvars (PConst (_, subpatterns)) =
      List.foldr (fn (sub, total) => count_patternvars sub + total) 0 subpatterns

(* Record arity a for constant code in the table, keeping the larger value
   when the constant already has an entry. *)
fun update_arity arity code a =
    (case Inttab.lookup arity code of
	 SOME (old : int) => if a <= old then arity else Inttab.update (code, a) arity
       | NONE => Inttab.update_new (code, a) arity)

(* We have to find out the maximal arity of each constant: record, for every
   constant occurring anywhere in the pattern, its number of arguments. *)
fun collect_pattern_arity p arity =
    (case p of
	 PVar => arity
       | PConst (c, args) =>
	 let
	     val with_head = update_arity arity c (length args)
	 in
	     fold collect_pattern_arity args with_head
	 end)

(* We also need to find out the maximal toplevel arity of each function
   constant: only the head constant of a rule pattern is recorded, nested
   constants are ignored. A variable pattern is an internal error. *)
fun collect_pattern_toplevel_arity p arity =
    (case p of
	 PConst (head, args) => update_arity arity head (length args)
       | PVar => raise Compile "internal error: collect_pattern_toplevel_arity")

local
(* collect nargs t arity: nargs is the number of arguments the subterm t is
   applied to. A constant is recorded with that count; the argument of an
   application and the body of an abstraction start again at 0. *)
fun collect nargs t arity =
    (case t of
	 Var _ => arity
       | Const c => update_arity arity c nargs
       | Abs body => collect 0 body arity
       | App (f, x) =>
	 let
	     val after_f = collect (nargs + 1) f arity
	 in
	     collect 0 x after_f
	 end)
in
(* Record the maximal application arity of every constant occurring in t. *)
fun collect_term_arity t arity = collect 0 t arity
end

(* Record application arities for both sides of a guard (left side first). *)
fun collect_guard_arity (Guard (lhs, rhs)) arity =
    let
	val after_lhs = collect_term_arity lhs arity
    in
	collect_term_arity rhs after_lhs
    end


(* rep n x is the list of n copies of x; a negative n is an internal error. *)
fun rep n x =
    if n < 0 then raise Compile "internal error: rep"
    else List.tabulate (n, fn _ => x)

(* Beta-normalisation of de Bruijn terms (index 0 refers to the innermost
   binder). A redex (Abs m) b is contracted by substituting the argument,
   lifted by one, for index 0 in m and then unlifting the result, which
   removes the binder. *)
fun beta (Const c) = Const c
  | beta (Var i) = Var i
  | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b)))
  | beta (App (a, b)) = 
    (* normalise the function part first; if it turns into an abstraction,
       the application has become a redex *)
    (case beta a of
	 Abs m => beta (App (Abs m, b))
       | a => App (a, beta b))
  | beta (Abs m) = Abs (beta m)
(* subst x m t: replace Var x in m by t; under a binder the target index
   grows by one and t is lifted accordingly *)
and subst x (Const c) t = Const c
  | subst x (Var i) t = if i = x then t else Var i
  | subst x (App (a,b)) t = App (subst x a t, subst x b t)
  | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t))
(* lift level t: increment every index >= level (indices below level are bound inside t) *)
and lift level (Const c) = Const c
  | lift level (App (a,b)) = App (lift level a, lift level b)
  | lift level (Var i) = if i < level then Var i else Var (i+1)
  | lift level (Abs m) = Abs (lift (level + 1) m)
(* unlift level t: decrement every index >= level *)
and unlift level (Const c) = Const c
  | unlift level (App (a, b)) = App (unlift level a, unlift level b)
  | unlift level (Abs m) = Abs (unlift (level+1) m)
  | unlift level (Var i) = if i < level then Var i else Var (i-1)

(* nlift level n t: add n to every de Bruijn index in t that is >= level,
   leaving indices bound inside t (relative to level) untouched. *)
fun nlift level n t =
    (case t of
	 Var m => if m >= level then Var (m + n) else Var m
       | Const c => Const c
       | App (f, x) => App (nlift level n f, nlift level n x)
       | Abs body => Abs (nlift (level + 1) n body))

(* subst_const (c, t) term: replace every occurrence of constant c in term by t. *)
fun subst_const (ct as (c, t)) term =
    (case term of
	 Const c' => if c' = c then t else term
       | Var _ => term
       | App (f, x) => App (subst_const ct f, subst_const ct x)
       | Abs body => Abs (subst_const ct body))

(* Remove all rules that are just parameterless rewrites c = r. This is necessary because SML does not allow functions with no parameters.
   Instead, each such rewrite is inlined: r must be closed, unguarded and must not mention c itself, and r is substituted for c in the
   guards and right hand sides of all remaining rules.
   Returns (table mapping each inlined constant to its replacement term, remaining rules).
   Raises Compile if a parameterless rewrite violates these conditions, if c has any other rule that is not exactly the same rewrite
   (a differing parameterless rewrite or a rule with parameters), or if an inlined constant occurs in the pattern of a remaining rule. *)
fun inline_rules rules =
    let
	fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b
	  | term_contains_const c (Abs m) = term_contains_const c m
	  | term_contains_const c (Var i) = false
	  | term_contains_const c (Const c') = (c = c')
	(* first parameterless rule c = r, after validating it *)
	fun find_rewrite [] = NONE
	  | find_rewrite ((prems, PConst (c, []), r) :: _) = 
	    if check_freevars 0 r then 
		if term_contains_const c r then 
		    raise Compile "parameterless rewrite is caught in cycle"
		else if not (null prems) then
		    raise Compile "parameterless rewrite may not be guarded"
		else
		    SOME (c, r) 
	    else raise Compile "unbound variable on right hand side or guards of rule"
	  | find_rewrite (_ :: rules) = find_rewrite rules
	(* drop every rule for c; each of them must be exactly the unguarded rewrite c = r *)
	fun remove_rewrite (c,r) [] = []
	  | remove_rewrite (cr as (c,r)) ((rule as (prems', PConst (c', args), r'))::rules) = 
	    (if c = c' then 
		 if null args andalso r = r' andalso null (prems') then 
		     remove_rewrite cr rules 
		 else raise Compile "incompatible parameterless rewrites found"
	     else
		 rule :: (remove_rewrite cr rules))
	  | remove_rewrite cr (r::rs) = r::(remove_rewrite cr rs)
	fun pattern_contains_const c (PConst (c', args)) = (c = c' orelse exists (pattern_contains_const c) args)
	  | pattern_contains_const c (PVar) = false
	(* substitute r for c in the guards and the right hand side of a rule *)
	fun inline_rewrite (ct as (c, _)) (prems, p, r) = 
	    if pattern_contains_const c p then 
		raise Compile "parameterless rewrite cannot be used in pattern"
	    else (map (fn (Guard (a,b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r)
	(* repeat until no parameterless rule is left; replacement terms of
	   constants inlined earlier are rewritten with the new one as well *)
	fun inline inlined rules =
	    (case find_rewrite rules of 
		 NONE => (Inttab.make inlined, rules)
	       | SOME ct => 
		 let
		     val rules = map (inline_rewrite ct) (remove_rewrite ct rules)
		     val inlined =  ct :: (map (fn (c', r) => (c', subst_const ct r)) inlined)
		 in
		     inline inlined rules
		 end)
    in
	inline [] rules
    end


(*
   Calculate the arity and the toplevel arity of every constant, and adjust
   the rules so that every toplevel pattern constant is applied to its
   maximal arity. Missing pattern arguments are added as fresh pattern
   variables, and the right hand side and guards are applied to the
   corresponding variables. Every constant nested inside a pattern, at any
   depth, must already have its maximal arity; otherwise Compile is raised.
   Also beta reduce the adjusted right hand side and guards of each rule.

   Returns (arity, toplevel_arity, adjusted rules).
*)
fun adjust_rules rules = 
    let
	val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty
	val toplevel_arity = fold (fn (_, p, _) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty
	fun arity_of c = the (Inttab.lookup arity c)
	(* Check, recursively, that every constant in a pattern argument is
	   applied to exactly its maximal arity: the generated constructor C<c>
	   takes exactly that many arguments, so any other count would only
	   surface later as ill-typed generated code. *)
	fun adjust_pattern PVar = PVar
	  | adjust_pattern (PConst (c, args)) =
	    if length args <> arity_of c then
		raise Compile ("Constant inside pattern must have maximal arity")
	    else
		PConst (c, map adjust_pattern args)
	fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable")
	  | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters")
	  | adjust_rule (rule as (prems, p as PConst (c, args),t)) = 
	    let
		val patternvars_counted = count_patternvars p
		fun check_fv t = check_freevars patternvars_counted t
		val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else () 
		val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else () 
		val args = map adjust_pattern args
		val len = length args
		val arity = arity_of c
		val lift = nlift 0
		(* apply t to Var (n-1), ..., Var 0, i.e. to the n added pattern variables *)
		fun adjust_tm n t = if n=0 then t else adjust_tm (n-1) (App (t, Var (n-1)))
		(* shift the existing pattern variables past the n added ones, then apply *)
		fun adjust_term n t = adjust_tm n (lift n t)
		fun adjust_guard n (Guard (a,b)) = Guard (adjust_term n a, adjust_term n b)
	    in
		if len = arity then
		    rule
		else if arity >= len then
		    (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t)
		else (raise Compile "internal error in adjust_rule")
	    end
	fun beta_guard (Guard (a,b)) = Guard (beta a, beta b)
	fun beta_rule (prems, p, t) = ((map beta_guard prems, p, beta t) handle Match => raise Compile "beta_rule")
    in
	(arity, toplevel_arity, map (beta_rule o adjust_rule) rules)
    end

(*
   print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count
   returns a function that renders a term as ML source text for the
   generated code (at abstraction depth 0).

   - module: SOME m qualifies generated names as "m.", NONE leaves them unqualified.
   - arity_of / toplevel_arity_of: lookups into the arity tables; a constant
     without an arity entry is printed as the generic "Const c".
   - pattern_var_count: number of pattern variables x0 .. x(count-1) in scope;
     the last pattern_lazy_var_count of them are thunks and are printed
     forced, as "(xN ())".

   A constant c of arity a > 0 applied to at least a arguments becomes a
   call of c<c>. Its first toplevel_arity arguments are passed directly, the
   remaining ones up to a as thunks "(fn () => ...)", and any further
   arguments through "app". An under-applied constant is eta-expanded first.
   Abstractions become "Abs (fn b<d> => ...)".
*)
fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count =
let
    fun str x = string_of_int x
    (* parenthesise s if it contains blanks, so it can stand as an argument *)
    fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s
    val module_prefix = (case module of NONE => "" | SOME s => s^".")
    (* apply the printed head f to args, one "app" per argument *)
    fun print_apps d f [] = f
      | print_apps d f (a::args) = print_apps d (module_prefix^"app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args
    (* print_call d t args: print t applied to args, collecting the application spine first *)
    and print_call d (App (a, b)) args = print_call d a (b::args)
      | print_call d (Const c) args =
	(case arity_of c of
	     NONE => print_apps d (module_prefix^"Const "^(str c)) args
	   | SOME 0 =>
	     (* Nullary constructor. Surplus arguments (possible for input terms
		that apply c more often than any rule does) are applied via app
		instead of being silently dropped. *)
	     print_apps d (module_prefix^"C"^(str c)) args
	   | SOME a =>
	     let
		 val len = length args
	     in
		 if a <= len then
		     let
			 val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a)
			 val _ = if strict_a > a then raise Compile "strict" else ()
			 val s = module_prefix^"c"^(str c)^(concat (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, strict_a))))
			 val s = s^(concat (map (fn t => " (fn () => "^print_term d t^")") (List.drop (List.take (args, a), strict_a))))
		     in
			 print_apps d s (List.drop (args, a))
		     end
		 else
		     (* eta-expand to exactly a arguments and print the result *)
		     let
			 fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n - 1)))
			 fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t)
			 fun append_args [] t = t
			   | append_args (c::cs) t = append_args cs (App (t, c))
		     in
			 print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c)))))
		     end
	     end)
      | print_call d t args = print_apps d (print_term d t) args
    (* indices below the depth d refer to enclosing Abs binders b0, b1, ...;
       the others refer to pattern variables *)
    and print_term d (Var x) =
	if x < d then
	    "b"^(str (d-x-1))
	else
	    let
		val n = pattern_var_count - (x-d) - 1
		val x = "x"^(str n)
	    in
		if n < pattern_var_count - pattern_lazy_var_count then
		    x
		else
		    "("^x^" ())"
	    end
      | print_term d (Abs c) = module_prefix^"Abs (fn b"^(str d)^" => "^(print_term (d + 1) c)^")"
      | print_term d t = print_call d t []
in
    print_term 0
end

(* section n = [0, 1, ..., n-1]; built in linear time. Raises Size for a negative n. *)
fun section n = List.tabulate (n, fn i => i)
			 			
(*
   print_rule gnum arity_of toplevel_arity_of (guards, p, t) renders one
   (arity-adjusted) rule as a clause of the generated function c<c>, or of
   c<c>_<gnum> when gnum > 0. Each toplevel pattern argument is bound as
   "aK as (...)", so that a guarded rule whose guards fail can call
   c<c>_<gnum+1> with the same arguments a0 .. a(arity-1).
   Returns (gnum if the rule is unguarded, gnum+1 otherwise; clause text).
*)
fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) = 
    let
	fun str x = Int.toString x
	(* print_pattern top n p: pattern variables are named x<n>, x<n+1>, ...
	   from left to right; returns (next free number, text). top selects the
	   function head c<c>[_gnum] instead of the constructor C<c>. *)
	fun print_pattern top n PVar = (n+1, "x"^(str n))
	  | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else ""))
	  | print_pattern top n (PConst (c, args)) = 
	    let
		val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")
		val (n, s) = print_pattern_list 0 top (n, f) args
	    in
		(n, s)
	    end
	(* remaining arguments: " (aK as (...))" at toplevel, ", p" inside a
	   constructor tuple, which is closed by ")" at the end *)
	and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")")
	  | print_pattern_list' counter top (n, p) (t::ts) = 
	    let
		val (n, t) = print_pattern false n t
	    in
		print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts
	    end
	(* first argument of a non-empty argument list *)
	and print_pattern_list counter top (n, p) (t::ts) = 
	    let
		val (n, t) = print_pattern false n t
	    in
		print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts
	    end
	val c = (case p of PConst (c, _) => c | _ => raise Match)
	val (n, pattern) = print_pattern true 0 p
	(* number of trailing arguments beyond the toplevel arity; their pattern variables are thunks *)
	val lazy_vars = the (arity_of c) - the (toplevel_arity_of c)
	fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm
	(* a guard compares both sides with the generated term_eq *)
        fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")"
	(* fall-through call to the next group function with all arguments *)
	val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(concat (map (fn i => " a"^(str i)) (section (the (arity_of c)))))
	fun print_guards t [] = print_tm t
	  | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(concat (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch
    in
	(if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards))
    end

(* Partition the rules into a table keyed by the head constant of each
   rule's pattern. Within a group, rules keep their original relative order.
   A rule whose pattern is a variable is an internal error. *)
fun group_rules rules =
    let
	fun insert (rule as (_, PConst (head, _), _), table) =
	    Inttab.update (head, rule :: getOpt (Inttab.lookup table head, [])) table
	  | insert _ = raise Compile "internal error group_rules"
    in
	List.foldr insert Inttab.empty rules
    end

(*
   sml_prog name code rules generates the source text of an ML structure
   called name that implements the rewrite rules. Parameterless rewrites are
   inlined first (inline_rules), then the rules are arity-adjusted
   (adjust_rules) and grouped by head constant. The string code is embedded
   in the generated "export" function, which reports results through
   AM_SML.save_result.
   Returns (arity, toplevel_arity, inlinetab, source).
*)
fun sml_prog name code rules = 
    let
	(* the generated source text is accumulated in this buffer *)
	val buffer = ref ""
	fun write s = (buffer := (!buffer)^s)
	fun writeln s = (write s; write "\n")
	fun writelist [] = ()
	  | writelist (s::ss) = (writeln s; writelist ss)
	fun str i = Int.toString i
	val (inlinetab, rules) = inline_rules rules
	val (arity, toplevel_arity, rules) = adjust_rules rules
	val rules = group_rules rules
	val constants = Inttab.keys arity
	fun arity_of c = Inttab.lookup arity c
	fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
	fun rep_str s n = concat (rep n s)
	fun indexed s n = s^(str n)
        fun string_of_tuple [] = ""
	  | string_of_tuple (x::xs) = "("^x^(concat (map (fn s => ", "^s) xs))^")"
        fun string_of_args [] = ""
	  | string_of_args (x::xs) = x^(concat (map (fn s => " "^s) xs))
	(* Final fallback clause of c<c>[_gnum]: rebuild the constructor term
	   C<c> (x0, ...), forcing lazy arguments; if c has lazy arguments the
	   clause raises AM_SML.Run instead. *)
	fun default_case gnum c = 
	    let
		val leftargs = concat (map (indexed " x") (section (the (arity_of c))))
		val rightargs = section (the (arity_of c))
		val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa)
		val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs
		val right = (indexed "C" c)^" "^(string_of_tuple xs)
		val message = "(\"unresolved lazy call: "^(string_of_int c)^", \"^(makestring x"^(string_of_int (strict_args - 1))^"))"
		val right = if strict_args < the (arity_of c) then "raise AM_SML.Run "^message else right
	    in
		(indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right
	    end

	(* Clauses of the generated eval for Const c applied to n = arity, ..., 0
	   arguments. Arguments are evaluated in bounds; the ones beyond the
	   toplevel arity are wrapped as thunks; under-applied constants are
	   eta-expanded with Abs; a nullary constant evaluates to C<c>. *)
	fun eval_rules c = 
	    let
		val arity = the (arity_of c)
		val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa)
		fun eval_rule n = 
		    let
			val sc = string_of_int c
			val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc)
                        fun arg i = 
			    let
				val x = indexed "x" i
				val x = if i < n then "(eval bounds "^x^")" else x
				val x = if i < strict_arity then x else "(fn () => "^x^")"
			    in
				x
			    end
			val right = "c"^sc^" "^(string_of_args (map arg (section arity)))
			val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right
			val right = if arity > 0 then right else "C"^sc
		    in
			"  | eval bounds ("^left^") = "^right
		    end
	    in
		map eval_rule (rev (section (arity + 1)))
	    end

	(* datatype Term: generic Const/App/Abs plus one constructor C<c> per
	   constant, taking arity_of c arguments *)
	fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else ""
	val _ = writelist [
		"structure "^name^" = struct",
		"",
		"datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)",
		"         "^(concat (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)),
		""]
	(* structural equality on generated Terms, used by the rule guards *)
	fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")"
	fun make_term_eq c = "  | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^
                             (case the (arity_of c) of 
				  0 => "true"
				| n => 
				  let 
				      val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n)
				      val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs))
				  in
				      eq^(concat eqs)
				  end)
	val _ = writelist [
		"fun term_eq (Const c1) (Const c2) = (c1 = c2)",
		"  | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"]
	val _ = writelist (map make_term_eq constants)
	val _ = writelist [
		"  | term_eq _ _ = false",
                ""
		]
	(* generic application: reduce Abs, otherwise build App *)
	val _ = writelist [
		"fun app (Abs a) b = a b",
		"  | app a b = App (a, b)",
		""]
	fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else [])
	(* each rule function is emitted as an "and" declaration, continuing the "fun app" declaration above *)
	fun writefundecl [] = () 
	  | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => "  | "^s) xs)))
	(* Clause groups for constant c. Each guarded rule starts a new group
	   (function c<c>_<g+1>); the group it ends gets a final clause that
	   forwards all arguments to the new function. The last group ends with
	   the default case. *)
	fun list_group c = (case Inttab.lookup rules c of 
				NONE => [defcase 0 c]
			      | SOME rs => 
				let
				    val rs = 
					fold
					    (fn r => 
					     fn rs =>
						let 
						    val (gnum, l, rs) = 
							(case rs of 
							     [] => (0, [], []) 
							   | (gnum, l)::rs => (gnum, l, rs))
						    val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r 
						in 
						    if gnum' = gnum then 
							(gnum, r::l)::rs
						    else
							let
							    val args = concat (map (fn i => " a"^(str i)) (section (the (arity_of c))))
							    fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args
							    val s = gnumc (gnum) ^ " = " ^ gnumc (gnum') 
							in
							    (gnum', [])::(gnum, s::r::l)::rs
							end
						end)
					rs []
				    val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs)
				in
				    rev (map (fn z => rev (snd z)) rs)
				end)
	val _ = map (fn z => (map writefundecl z; writeln "")) (map list_group constants)
	(* convert generated Terms back to AM_SML terms; abstractions are rejected *)
	val _ = writelist [
		"fun convert (Const i) = AM_SML.Const i",
		"  | convert (App (a, b)) = AM_SML.App (convert a, convert b)",
                "  | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""]
	fun make_convert c = 
	    let
		val args = map (indexed "a") (section (the (arity_of c)))
		val leftargs = 
		    case args of
			[] => ""
		      | (x::xs) => "("^x^(concat (map (fn s => ", "^s) xs))^")"
		val args = map (indexed "convert a") (section (the (arity_of c)))
		val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c))
	    in
		"  | convert (C"^(str c)^" "^leftargs^") = "^right
	    end
	val _ = writelist (map make_convert constants)
	(* eval translates an AbstractMachine term into a generated Term;
	   bounds holds the values of the enclosing bound variables *)
	val _ = writelist [
		"",
		"fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)",
		"  | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"]
	val _ = map (writelist o eval_rules) constants
	val _ = writelist [
                "  | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)",
                "  | eval bounds (AbstractMachine.Const c) = Const c"]
	(* export reports a result tagged with code; loading the structure
	   registers the compiled rewriter *)
	val _ = writelist [
		"",
		"fun export term = AM_SML.save_result (\""^code^"\", convert term)",
		"",
                "val _ = AM_SML.set_compiled_rewriter (fn t => (convert (eval [] t)))",
                "",
		"end"]
    in
	(arity, toplevel_arity, inlinetab, !buffer)
    end

(* Counter appended to the timestamp, presumably to keep identifiers
   distinct when the clock has not advanced between calls. *)
val guid_counter = ref 0

(* A fresh identifier string: the current time in microseconds followed by
   the counter value, which is then incremented. *)
fun get_guid () =
    let
	val count = !guid_counter
	val _ = guid_counter := count + 1
	val micros = Time.toMicroseconds (Time.now ())
    in
	LargeInt.toString micros ^ string_of_int count
    end


(* Write the string s to the file name (parsed with Path.explode). *)
fun writeTextFile name s =
    let
	val path = Path.explode name
    in
	File.write path s
    end

(* Compile and run ML source text with use_text, sending output through
   Output.ml_output. The boolean flag is passed as false (presumably
   "not verbose"; TODO confirm against use_text). *)
fun use_source src =
    let
	val flag = false
    in
	use_text "" Output.ml_output flag src
    end
    
(* Generate an ML structure for the rewrite rules eqs and load it with
   use_source. The generated code registers its rewriter through
   set_compiled_rewriter while being loaded. Raises Compile if no rewriter
   was registered. *)
fun compile eqs =
    let
	val module_name = "AMSML_" ^ get_guid ()
	val secret = Real.toString (random ())
	val (arity, toplevel_arity, inlinetab, source) = sml_prog module_name secret eqs
	(* debugging aid: val _ = writeTextFile "Gencode.ML" source *)
	val _ = compiled_rewriter := NONE
	val _ = use_source source
    in
	(case !compiled_rewriter of
	     SOME rewriter => (module_name, secret, arity, toplevel_arity, inlinetab, rewriter)
	   | NONE => raise Compile "broken link to compiled function")
    end


(*
   Alternative to run that does not call compiled_fun directly. Instead it
   prints the closed term t as ML source, evaluates "export (...)" inside the
   generated module with use_source, and reads the result back from
   saved_result.
   Raises Run if t is not closed, if the generated code reported no result,
   or if the reported code does not match the program's code.
*)
fun run' (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t = 
    let
	val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
	(* replace inlined parameterless constants by their replacement terms *)
	fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t)
	  | inline (Var i) = Var i
	  | inline (App (a, b)) = App (inline a, inline b)
	  | inline (Abs m) = Abs (inline m)
	val t = beta (inline t)
	fun arity_of c = Inttab.lookup arity c
	fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
	val term = print_term NONE arity_of toplevel_arity_of 0 0 t
        val source = "local open "^module^" in val _ = export ("^term^") end"
	(* Debugging aid only: writing the call into the working directory on
	   every run was a leftover, so it is disabled like the dump in compile.
	val _ = writeTextFile "Gencode_call.ML" source *)
	val _ = clear_result ()
	val _ = use_source source
    in
	case !saved_result of 
	    NONE => raise Run "broken link to compiled code"
	  | SOME (code', t) => (clear_result (); if code' = code then t else raise Run "link to compiled code was hijacked")
    end

(* Evaluate the closed term t with a compiled program. Inline the
   parameterless constants recorded at compile time, beta-normalise, then
   apply the compiled rewriter. Raises Run if t has free variables. *)
fun run (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t =
    if not (check_freevars 0 t) then
	raise Run ("can only compute closed terms")
    else
	let
	    fun expand tm =
		(case tm of
		     Const c => (case Inttab.lookup inlinetab c of SOME r => r | NONE => tm)
		   | Var _ => tm
		   | App (f, x) => App (expand f, expand x)
		   | Abs body => Abs (expand body))
	in
	    compiled_fun (beta (expand t))
	end

(* Discarding a program does nothing here; its argument is ignored. *)
fun discard _ = ()
			 	  
end