src/Tools/Compute_Oracle/am_compiler.ML
changeset 23174 3913451b0418
child 23663 84b5c89b8b49
equal deleted inserted replaced
23173:51179ca0c429 23174:3913451b0418
       
     1 (*  Title:      Tools/Compute_Oracle/am_compiler.ML
       
     2     ID:         $Id$
       
     3     Author:     Steven Obua
       
     4 *)
       
     5 
       
     6 signature COMPILING_AM = 
       
     7 sig
       
     8   include ABSTRACT_MACHINE
       
     9 
       
    10   datatype closure = CVar of int | CConst of int
       
    11     | CApp of closure * closure | CAbs of closure | Closure of (closure list) * closure
       
    12 
       
    13   val set_compiled_rewriter : (term -> closure) -> unit
       
    14   val list_nth : 'a list * int -> 'a
       
    15   val list_map : ('a -> 'b) -> 'a list -> 'b list
       
    16 end
       
    17 
       
    18 structure AM_Compiler : COMPILING_AM = struct
       
    19 
       
    20 val list_nth = List.nth;
       
    21 val list_map = map;
       
    22 
       
    23 datatype term = Var of int | Const of int | App of term * term | Abs of term
       
    24 
       
    25 datatype pattern = PVar | PConst of int * (pattern list)
       
    26 
       
    27 datatype closure = CVar of int | CConst of int
       
    28 	         | CApp of closure * closure | CAbs of closure
       
    29                  | Closure of (closure list) * closure
       
    30 
       
    31 val compiled_rewriter = ref (NONE:(term -> closure)Option.option)
       
    32 
       
    33 fun set_compiled_rewriter r = (compiled_rewriter := SOME r)
       
    34 
       
    35 type program = (term -> term)
       
    36 
       
    37 datatype stack = SEmpty | SAppL of closure * stack | SAppR of closure * stack | SAbs of stack
       
    38 
       
    39 exception Compile of string;
       
    40 exception Run of string;
       
    41 
       
    42 fun clos_of_term (Var x) = CVar x
       
    43   | clos_of_term (Const c) = CConst c
       
    44   | clos_of_term (App (u, v)) = CApp (clos_of_term u, clos_of_term v)
       
    45   | clos_of_term (Abs u) = CAbs (clos_of_term u)
       
    46 
       
    47 fun term_of_clos (CVar x) = Var x
       
    48   | term_of_clos (CConst c) = Const c
       
    49   | term_of_clos (CApp (u, v)) = App (term_of_clos u, term_of_clos v)
       
    50   | term_of_clos (CAbs u) = Abs (term_of_clos u)
       
    51   | term_of_clos (Closure (e, u)) =
       
    52       raise (Run "internal error: closure in normalized term found")
       
    53 
       
    54 fun strip_closure args (CApp (a,b)) = strip_closure (b::args) a
       
    55   | strip_closure args x = (x, args)
       
    56 
       
    57 (*Returns true iff at most 0 .. (free-1) occur unbound. therefore
       
    58   check_freevars 0 t iff t is closed*)
       
    59 fun check_freevars free (Var x) = x < free
       
    60   | check_freevars free (Const c) = true
       
    61   | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v
       
    62   | check_freevars free (Abs m) = check_freevars (free+1) m
       
    63 
       
    64 fun count_patternvars PVar = 1
       
    65   | count_patternvars (PConst (_, ps)) =
       
    66       List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps
       
    67 
       
    68 fun print_rule (p, t) = 
       
    69     let	
       
    70 	fun str x = Int.toString x		    
       
    71 	fun print_pattern n PVar = (n+1, "x"^(str n))
       
    72 	  | print_pattern n (PConst (c, [])) = (n, "c"^(str c))
       
    73 	  | print_pattern n (PConst (c, args)) = 
       
    74 	    let
       
    75 		val h = print_pattern n (PConst (c,[]))
       
    76 	    in
       
    77 		print_pattern_list h args
       
    78 	    end
       
    79 	and print_pattern_list r [] = r
       
    80 	  | print_pattern_list (n, p) (t::ts) = 
       
    81 	    let
       
    82 		val (n, t) = print_pattern n t
       
    83 	    in
       
    84 		print_pattern_list (n, "App ("^p^", "^t^")") ts
       
    85 	    end
       
    86 
       
    87 	val (n, pattern) = print_pattern 0 p
       
    88 	val pattern =
       
    89             if exists_string Symbol.is_ascii_blank pattern then "(" ^ pattern ^")"
       
    90             else pattern
       
    91 	
       
    92 	fun print_term d (Var x) = (*if x < d then "Var "^(str x) else "x"^(str (n-(x-d)-1))*)
       
    93               "Var " ^ str x
       
    94 	  | print_term d (Const c) = "c" ^ str c
       
    95 	  | print_term d (App (a,b)) = "App (" ^ print_term d a ^ ", " ^ print_term d b ^ ")"
       
    96 	  | print_term d (Abs c) = "Abs (" ^ print_term (d + 1) c ^ ")"
       
    97 
       
    98 	fun listvars n = if n = 0 then "x0" else "x"^(str n)^", "^(listvars (n-1))
       
    99 
       
   100 	val term = print_term 0 t
       
   101 	val term =
       
   102             if n > 0 then "Closure (["^(listvars (n-1))^"], "^term^")"
       
   103             else "Closure ([], "^term^")"
       
   104 			   
       
   105     in
       
   106 	"lookup stack "^pattern^" = weak stack ("^term^")"
       
   107     end
       
   108 
       
   109 fun constants_of PVar = []
       
   110   | constants_of (PConst (c, ps)) = c :: maps constants_of ps
       
   111 
       
   112 fun constants_of_term (Var _) = []
       
   113   | constants_of_term (Abs m) = constants_of_term m
       
   114   | constants_of_term (App (a,b)) = (constants_of_term a)@(constants_of_term b)
       
   115   | constants_of_term (Const c) = [c]
       
   116     
       
   117 fun load_rules sname name prog = 
       
   118     let
       
   119         (* FIXME consider using more readable/efficient Buffer.empty |> fold Buffer.add etc. *)
       
   120 	val buffer = ref ""
       
   121 	fun write s = (buffer := (!buffer)^s)
       
   122 	fun writeln s = (write s; write "\n")
       
   123 	fun writelist [] = ()
       
   124 	  | writelist (s::ss) = (writeln s; writelist ss)
       
   125 	fun str i = Int.toString i
       
   126 	val _ = writelist [
       
   127 		"structure "^name^" = struct",
       
   128 		"",
       
   129 		"datatype term = App of term * term | Abs of term | Var of int | Const of int | Closure of term list * term"]
       
   130 	val constants = distinct (op =) (maps (fn (p, r) => ((constants_of p)@(constants_of_term r))) prog)
       
   131 	val _ = map (fn x => write (" | c"^(str x))) constants
       
   132 	val _ = writelist [
       
   133 		"",
       
   134 		"datatype stack = SEmpty | SAppL of term * stack | SAppR of term * stack | SAbs of stack",
       
   135 		""]
       
   136 	val _ = (case prog of
       
   137 		    r::rs => (writeln ("fun "^(print_rule r)); 
       
   138 			      map (fn r => writeln("  | "^(print_rule r))) rs; 
       
   139 			      writeln ("  | lookup stack clos = weak_last stack clos"); ())								
       
   140 		  | [] => (writeln "fun lookup stack clos = weak_last stack clos"))
       
   141 	val _ = writelist [
       
   142 		"and weak stack (Closure (e, App (a, b))) = weak (SAppL (Closure (e, b), stack)) (Closure (e, a))",
       
   143 		"  | weak (SAppL (b, stack)) (Closure (e, Abs m)) =  weak stack (Closure (b::e, m))",
       
   144 		"  | weak stack (clos as Closure (_, Abs _)) = weak_last stack clos",
       
   145 		"  | weak stack (Closure (e, Var n)) = weak stack ("^sname^".list_nth (e, n) handle _ => (Var (n-(length e))))",
       
   146 		"  | weak stack (Closure (e, c)) = weak stack c",
       
   147 		"  | weak stack clos = lookup stack clos",
       
   148 		"and weak_last (SAppR (a, stack)) b = weak stack (App(a, b))",
       
   149 		"  | weak_last (SAppL (b, stack)) a = weak (SAppR (a, stack)) b",
       
   150 		"  | weak_last stack c = (stack, c)",
       
   151 		"",
       
   152 		"fun lift n (v as Var m) = if m < n then v else Var (m+1)",
       
   153 		"  | lift n (Abs t) = Abs (lift (n+1) t)",
       
   154 		"  | lift n (App (a,b)) = App (lift n a, lift n b)",
       
   155 		"  | lift n (Closure (e, a)) = Closure (lift_env n e, lift (n+(length e)) a)",
       
   156 		"  | lift n c = c",
       
   157 		"and lift_env n e = map (lift n) e",
       
   158 		"",
       
   159 		"fun strong stack (Closure (e, Abs m)) = ",
       
   160 		"    let",
       
   161 		"      val (stack', wnf) = weak SEmpty (Closure ((Var 0)::(lift_env 0 e), m))",
       
   162 		"    in",
       
   163 		"      case stack' of",
       
   164 		"           SEmpty => strong (SAbs stack) wnf",
       
   165 		"         | _ => raise ("^sname^".Run \"internal error in strong: weak failed\")",
       
   166 		"    end",
       
   167 		"  | strong stack (clos as (App (u, v))) = strong (SAppL (v, stack)) u",
       
   168 		"  | strong stack clos = strong_last stack clos",
       
   169 		"and strong_last (SAbs stack) m = strong stack (Abs m)",
       
   170 		"  | strong_last (SAppL (b, stack)) a = strong (SAppR (a, stack)) b",
       
   171 		"  | strong_last (SAppR (a, stack)) b = strong_last stack (App (a, b))",
       
   172 		"  | strong_last stack clos = (stack, clos)",
       
   173 		""]
       
   174 	
       
   175 	val ic = "(case c of "^(implode (map (fn c => (str c)^" => c"^(str c)^" | ") constants))^" _ => Const c)"						  	
       
   176 	val _ = writelist [
       
   177 		"fun importTerm ("^sname^".Var x) = Var x",
       
   178 		"  | importTerm ("^sname^".Const c) =  "^ic,
       
   179 		"  | importTerm ("^sname^".App (a, b)) = App (importTerm a, importTerm b)",
       
   180 		"  | importTerm ("^sname^".Abs m) = Abs (importTerm m)",
       
   181 		""]
       
   182 
       
   183 	fun ec c = "  | exportTerm c"^(str c)^" = "^sname^".CConst "^(str c)
       
   184 	val _ = writelist [
       
   185 		"fun exportTerm (Var x) = "^sname^".CVar x",
       
   186 		"  | exportTerm (Const c) = "^sname^".CConst c",
       
   187 		"  | exportTerm (App (a,b)) = "^sname^".CApp (exportTerm a, exportTerm b)",
       
   188 		"  | exportTerm (Abs m) = "^sname^".CAbs (exportTerm m)",
       
   189 		"  | exportTerm (Closure (closlist, clos)) = "^sname^".Closure ("^sname^".list_map exportTerm closlist, exportTerm clos)"]
       
   190 	val _ = writelist (map ec constants)
       
   191 		
       
   192 	val _ = writelist [
       
   193 		"",
       
   194 		"fun rewrite t = ",
       
   195 		"    let",
       
   196 		"      val (stack, wnf) = weak SEmpty (Closure ([], importTerm t))",
       
   197 		"    in",
       
   198 		"      case stack of ",
       
   199 		"           SEmpty => (case strong SEmpty wnf of",
       
   200 		"                          (SEmpty, snf) => exportTerm snf",
       
   201 		"                        | _ => raise ("^sname^".Run \"internal error in rewrite: strong failed\"))",
       
   202 		"         | _ => (raise ("^sname^".Run \"internal error in rewrite: weak failed\"))",
       
   203 		"    end",
       
   204 		"",
       
   205 		"val _ = "^sname^".set_compiled_rewriter rewrite",
       
   206 		"",
       
   207 		"end;"]
       
   208 
       
   209 	val _ = 
       
   210 	    let
       
   211 		(*val fout = TextIO.openOut "gen_code.ML"
       
   212 		val _ = TextIO.output (fout, !buffer)
       
   213 		val _  = TextIO.closeOut fout*)
       
   214 	    in
       
   215 		()
       
   216 	    end
       
   217     in
       
   218 	compiled_rewriter := NONE;	
       
   219 	use_text "" Output.ml_output false (!buffer);
       
   220 	case !compiled_rewriter of 
       
   221 	    NONE => raise (Compile "cannot communicate with compiled function")
       
   222 	  | SOME r => (compiled_rewriter := NONE; fn t => term_of_clos (r t))
       
   223     end	
       
   224 
       
   225 fun compile eqs = 
       
   226     let
       
   227 	val _ = map (fn (p, r) => 
       
   228                   (check_freevars (count_patternvars p) r; 
       
   229                    case p of PVar => raise (Compile "pattern reduces to a variable") | _ => ())) eqs
       
   230     in
       
   231 	load_rules "AM_Compiler" "AM_compiled_code" eqs
       
   232     end	
       
   233 
       
   234 fun run prog t = (prog t)
       
   235 			 	  
       
   236 end
       
   237 
       
(* Make the compiler-based machine available under the name AbstractMachine. *)
structure AbstractMachine = AM_Compiler