src/Pure/Tools/am_compiler.ML
changeset 16781 663235466562
child 16782 b214f21ae396
equal deleted inserted replaced
16780:aa284c1b72ad 16781:663235466562
       
     1 (*  Title:      Pure/Tools/am_compiler.ML
       
     2     ID:         $Id$
       
     3     Author:     Steven Obua
       
     4 *)
       
     5 
       
     6 signature COMPILING_AM = 
       
     7 sig
       
     8     include ABSTRACT_MACHINE
       
     9 
       
    10     datatype closure = CVar of int | CConst of int
       
    11 		     | CApp of closure * closure | CAbs of closure | Closure of (closure list) * closure
       
    12 
       
    13     val set_compiled_rewriter : (term -> closure) -> unit
       
    14 end
       
    15 
       
    16 structure AM_Compiler :> COMPILING_AM = struct
       
    17 
       
    18 datatype term = Var of int | Const of int | App of term * term | Abs of term
       
    19 
       
    20 datatype pattern = PVar | PConst of int * (pattern list)
       
    21 
       
    22 datatype closure = CVar of int | CConst of int
       
    23 		 | CApp of closure * closure | CAbs of closure | Closure of (closure list) * closure
       
    24 
       
    25 val compiled_rewriter = ref (NONE:(term -> closure)Option.option)
       
    26 
       
    27 fun set_compiled_rewriter r = (compiled_rewriter := SOME r)
       
    28 
       
    29 type program = (term -> term)
       
    30 
       
    31 datatype stack = SEmpty | SAppL of closure * stack | SAppR of closure * stack | SAbs of stack
       
    32 
       
    33 exception Compile of string;
       
    34 exception Run of string;
       
    35 
       
    36 fun clos_of_term (Var x) = CVar x
       
    37   | clos_of_term (Const c) = CConst c
       
    38   | clos_of_term (App (u, v)) = CApp (clos_of_term u, clos_of_term v)
       
    39   | clos_of_term (Abs u) = CAbs (clos_of_term u)
       
    40 
       
    41 fun term_of_clos (CVar x) = Var x
       
    42   | term_of_clos (CConst c) = Const c
       
    43   | term_of_clos (CApp (u, v)) = App (term_of_clos u, term_of_clos v)
       
    44   | term_of_clos (CAbs u) = Abs (term_of_clos u)
       
    45   | term_of_clos (Closure (e, u)) = raise (Run "internal error: closure in normalized term found")
       
    46 
       
    47 fun strip_closure args (CApp (a,b)) = strip_closure (b::args) a
       
    48   | strip_closure args x = (x, args)
       
    49 
       
    50 (* Returns true iff at most 0 .. (free-1) occur unbound. therefore check_freevars 0 t iff t is closed *)
       
    51 fun check_freevars free (Var x) = x < free
       
    52   | check_freevars free (Const c) = true
       
    53   | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v
       
    54   | check_freevars free (Abs m) = check_freevars (free+1) m
       
    55 
       
    56 fun count_patternvars PVar = 1
       
    57   | count_patternvars (PConst (_, ps)) = List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps
       
    58 
       
    59 fun print_rule (p, t) = 
       
    60     let	
       
    61 	fun str x = Int.toString x		    
       
    62 	fun print_pattern n PVar = (n+1, "x"^(str n))
       
    63 	  | print_pattern n (PConst (c, [])) = (n, "c"^(str c))
       
    64 	  | print_pattern n (PConst (c, args)) = 
       
    65 	    let
       
    66 		val h = print_pattern n (PConst (c,[]))
       
    67 	    in
       
    68 		print_pattern_list h args
       
    69 	    end
       
    70 	and print_pattern_list r [] = r
       
    71 	  | print_pattern_list (n, p) (t::ts) = 
       
    72 	    let
       
    73 		val (n, t) = print_pattern n t
       
    74 	    in
       
    75 		print_pattern_list (n, "App ("^p^", "^t^")") ts
       
    76 	    end
       
    77 
       
    78 	val (n, pattern) = print_pattern 0 p
       
    79 	val pattern = if List.exists Char.isSpace (String.explode pattern) then "("^pattern^")" else pattern
       
    80 	
       
    81 	fun print_term d (Var x) = (*if x < d then "Var "^(str x) else "x"^(str (n-(x-d)-1))*) "Var "^(str x)
       
    82 	  | print_term d (Const c) = "c"^(str c)
       
    83 	  | print_term d (App (a,b)) = "App ("^(print_term d a)^", "^(print_term d b)^")"
       
    84 	  | print_term d (Abs c) = "Abs ("^(print_term (d+1) c)^")"
       
    85 
       
    86 	fun listvars n = if n = 0 then "x0" else "x"^(str n)^", "^(listvars (n-1))
       
    87 
       
    88 	val term = print_term 0 t
       
    89 	val term = if n > 0 then "Closure (["^(listvars (n-1))^"], "^term^")" else "Closure ([], "^term^")"
       
    90 			   
       
    91     in
       
    92 	"lookup stack "^pattern^" = weak stack ("^term^")"
       
    93     end
       
    94 
       
    95 fun remove_dups (c::cs) = 
       
    96     let val cs = remove_dups cs in
       
    97 	if (List.exists (fn x => x=c) cs) then cs else c::cs
       
    98     end
       
    99   | remove_dups [] = []
       
   100 
       
   101 fun constants_of PVar = []
       
   102   | constants_of (PConst (c, ps)) = c::(List.concat (map constants_of ps))
       
   103 
       
   104 fun constants_of_term (Var _) = []
       
   105   | constants_of_term (Abs m) = constants_of_term m
       
   106   | constants_of_term (App (a,b)) = (constants_of_term a)@(constants_of_term b)
       
   107   | constants_of_term (Const c) = [c]
       
   108     
       
   109 fun load_rules sname name prog = 
       
   110     let
       
   111 	val buffer = ref ""
       
   112 	fun write s = (buffer := (!buffer)^s)
       
   113 	fun writeln s = (write s; write "\n")
       
   114 	fun writelist [] = ()
       
   115 	  | writelist (s::ss) = (writeln s; writelist ss)
       
   116 	fun str i = Int.toString i
       
   117 	val _ = writelist [
       
   118 		"structure "^name^" = struct",
       
   119 		"",
       
   120 		"datatype term = App of term * term | Abs of term | Var of int | Const of int | Closure of term list * term"]
       
   121 	val constants = remove_dups (List.concat (map (fn (p, r) => ((constants_of p)@(constants_of_term r))) prog))
       
   122 	val _ = map (fn x => write (" | c"^(str x))) constants
       
   123 	val _ = writelist [
       
   124 		"",
       
   125 		"datatype stack = SEmpty | SAppL of term * stack | SAppR of term * stack | SAbs of stack",
       
   126 		""]
       
   127 	val _ = (case prog of
       
   128 		    r::rs => (writeln ("fun "^(print_rule r)); 
       
   129 			      map (fn r => writeln("  | "^(print_rule r))) rs; 
       
   130 			      writeln ("  | lookup stack clos = weak_last stack clos"); ())								
       
   131 		  | [] => (writeln "fun lookup stack clos = weak_last stack clos"))
       
   132 	val _ = writelist [
       
   133 		"and weak stack (Closure (e, App (a, b))) = weak (SAppL (Closure (e, b), stack)) (Closure (e, a))",
       
   134 		"  | weak (SAppL (b, stack)) (Closure (e, Abs m)) =  weak stack (Closure (b::e, m))",
       
   135 		"  | weak stack (clos as Closure (_, Abs _)) = weak_last stack clos",
       
   136 		"  | weak stack (Closure (e, Var n)) = weak stack (List.nth (e, n) handle Subscript => (Var (n-(length e))))",
       
   137 		"  | weak stack (Closure (e, c)) = weak stack c",
       
   138 		"  | weak stack clos = lookup stack clos",
       
   139 		"and weak_last (SAppR (a, stack)) b = weak stack (App(a, b))",
       
   140 		"  | weak_last (SAppL (b, stack)) a = weak (SAppR (a, stack)) b",
       
   141 		"  | weak_last stack c = (stack, c)",
       
   142 		"",
       
   143 		"fun lift n (v as Var m) = if m < n then v else Var (m+1)",
       
   144 		"  | lift n (Abs t) = Abs (lift (n+1) t)",
       
   145 		"  | lift n (App (a,b)) = App (lift n a, lift n b)",
       
   146 		"  | lift n (Closure (e, a)) = Closure (lift_env n e, lift (n+(length e)) a)",
       
   147 		"  | lift n c = c",
       
   148 		"and lift_env n e = map (lift n) e",
       
   149 		"",
       
   150 		"fun strong stack (Closure (e, Abs m)) = ",
       
   151 		"    let",
       
   152 		"      val (stack', wnf) = weak SEmpty (Closure ((Var 0)::(lift_env 0 e), m))",
       
   153 		"    in",
       
   154 		"      case stack' of",
       
   155 		"           SEmpty => strong (SAbs stack) wnf",
       
   156 		"         | _ => raise ("^sname^".Run \"internal error in strong: weak failed\")",
       
   157 		"    end",
       
   158 		"  | strong stack (clos as (App (u, v))) = strong (SAppL (v, stack)) u",
       
   159 		"  | strong stack clos = strong_last stack clos",
       
   160 		"and strong_last (SAbs stack) m = strong stack (Abs m)",
       
   161 		"  | strong_last (SAppL (b, stack)) a = strong (SAppR (a, stack)) b",
       
   162 		"  | strong_last (SAppR (a, stack)) b = strong_last stack (App (a, b))",
       
   163 		"  | strong_last stack clos = (stack, clos)",
       
   164 		""]
       
   165 	
       
   166 	val ic = "(case c of "^(String.concat (map (fn c => (str c)^" => c"^(str c)^" | ") constants))^" _ => Const c)"						  	
       
   167 	val _ = writelist [
       
   168 		"fun importTerm ("^sname^".Var x) = Var x",
       
   169 		"  | importTerm ("^sname^".Const c) =  "^ic,
       
   170 		"  | importTerm ("^sname^".App (a, b)) = App (importTerm a, importTerm b)",
       
   171 		"  | importTerm ("^sname^".Abs m) = Abs (importTerm m)",
       
   172 		""]
       
   173 
       
   174 	fun ec c = "  | exportTerm c"^(str c)^" = "^sname^".CConst "^(str c)
       
   175 	val _ = writelist [
       
   176 		"fun exportTerm (Var x) = "^sname^".CVar x",
       
   177 		"  | exportTerm (Const c) = "^sname^".CConst c",
       
   178 		"  | exportTerm (App (a,b)) = "^sname^".CApp (exportTerm a, exportTerm b)",
       
   179 		"  | exportTerm (Abs m) = "^sname^".CAbs (exportTerm m)",
       
   180 		"  | exportTerm (Closure (closlist, clos)) = "^sname^".Closure (map exportTerm closlist, exportTerm clos)"]
       
   181 	val _ = writelist (map ec constants)
       
   182 		
       
   183 	val _ = writelist [
       
   184 		"",
       
   185 		"fun rewrite t = ",
       
   186 		"    let",
       
   187 		"      val (stack, wnf) = weak SEmpty (Closure ([], importTerm t))",
       
   188 		"    in",
       
   189 		"      case stack of ",
       
   190 		"           SEmpty => (case strong SEmpty wnf of",
       
   191 		"                          (SEmpty, snf) => exportTerm snf",
       
   192 		"                        | _ => raise ("^sname^".Run \"internal error in rewrite: strong failed\"))",
       
   193 		"         | _ => (raise ("^sname^".Run \"internal error in rewrite: weak failed\"))",
       
   194 		"    end",
       
   195 		"",
       
   196 		"val _ = "^sname^".set_compiled_rewriter rewrite",
       
   197 		"",
       
   198 		"end;"]
       
   199 
       
   200 	val _ = 
       
   201 	    let
       
   202 		val fout = TextIO.openOut "gen_code.ML"
       
   203 		val _ = TextIO.output (fout, !buffer)
       
   204 		val _  = TextIO.closeOut fout
       
   205 	    in
       
   206 		()
       
   207 	    end
       
   208 	    
       
   209 	val dummy = (fn s => ())		    
       
   210     in
       
   211 	compiled_rewriter := NONE;	
       
   212 	use_text (dummy, dummy) false (!buffer);
       
   213 	case !compiled_rewriter of 
       
   214 	    NONE => raise (Compile "cannot communicate with compiled function")
       
   215 	  | SOME r => (compiled_rewriter := NONE; fn t => term_of_clos (r t))
       
   216     end	
       
   217 
       
   218 fun compile eqs = 
       
   219     let
       
   220 	val _ = map (fn (p, r) => 
       
   221                   (check_freevars (count_patternvars p) r; 
       
   222                    case p of PVar => raise (Compile "pattern reduces to a variable") | _ => ())) eqs
       
   223     in
       
   224 	load_rules "AM_Compiler" "AM_compiled_code" eqs
       
   225     end	
       
   226 
       
   227 fun run prog t = (prog t)
       
   228 			 	  
       
   229 end
       
   230 
       
    (* Bind the name AbstractMachine to AM_Compiler, so code that refers to
       AbstractMachine uses this compiling implementation. *)
    231 structure AbstractMachine = AM_Compiler