src/Tools/Compute_Oracle/am_compiler.ML
changeset 37872 d83659570337
parent 37871 c7ce7685e087
child 37873 66d90b2b87bc
equal deleted inserted replaced
37871:c7ce7685e087 37872:d83659570337
     1 (*  Title:      Tools/Compute_Oracle/am_compiler.ML
       
     2     Author:     Steven Obua
       
     3 *)
       
     4 
       
     (* Interface of the compiling abstract machine: everything ABSTRACT_MACHINE
        offers, plus the entry points that dynamically compiled code refers to
        by qualified name. *)
     signature COMPILING_AM =
     sig
       include ABSTRACT_MACHINE

       (* List helpers re-exported under fixed names; the generated code calls
          list_nth when looking up a variable in a closure environment. *)
       val list_nth : 'a list * int -> 'a
       val list_map : ('a -> 'b) -> 'a list -> 'b list

       (* Callback through which freshly compiled code hands its rewriter
          back to this structure. *)
       val set_compiled_rewriter : (term -> term) -> unit
     end
       
    13 
       
    14 structure AM_Compiler : COMPILING_AM = struct
       
    15 
       
    (* Plain aliases for the Basis list operations, defined before
       AbstractMachine is opened and exported so that generated code can
       reach them through this structure. *)
    fun list_nth (xs, i) = List.nth (xs, i);
    fun list_map f xs = map f xs;
       
    18 
       
    19 open AbstractMachine;
       
    20 
       
    (* Mailbox for the rewriter produced by the most recent dynamic
       compilation; NONE whenever no result is pending. *)
    val compiled_rewriter =
      Unsynchronized.ref (NONE : (term -> term) Option.option)

    (* Called from the generated code to deposit its rewrite function. *)
    fun set_compiled_rewriter rewriter =
      compiled_rewriter := SOME rewriter

    (* A compiled program is simply the resulting term rewriter. *)
    type program = term -> term
       
    26 
       
    (* Number of variable positions (PVar leaves) occurring in a pattern. *)
    fun count_patternvars (PConst (_, args)) =
          List.foldl (fn (arg, total) => count_patternvars arg + total) 0 args
      | count_patternvars PVar = 1
       
    30 
       
    (* Renders one rewrite rule (p, t) as the source text of a single
       weak_reduce clause of the generated machine.

       The pattern p becomes the matched term: its variables are named
       x0, x1, ... in left-to-right order, constants are printed as cN and
       arguments are applied one by one via App.  The right-hand side t is
       wrapped in a Closure whose environment lists the pattern variables,
       highest-numbered first. *)
    fun print_rule (p, t) =
        let
            val num = Int.toString

            (* Threads the next free variable number through the pattern and
               returns it together with the rendered text. *)
            fun show_pattern next PVar = (next + 1, "x" ^ num next)
              | show_pattern next (PConst (c, args)) =
                  List.foldl
                    (fn (arg, (k, head)) =>
                        let
                            val (k', rendered) = show_pattern k arg
                        in
                            (k', "App (" ^ head ^ ", " ^ rendered ^ ")")
                        end)
                    (next, "c" ^ num c)
                    args

            val (nvars, raw_pattern) = show_pattern 0 p

            (* Compound patterns contain blanks and get an extra pair of
               parentheses. *)
            val lhs =
                if exists_string Symbol.is_ascii_blank raw_pattern
                then "(" ^ raw_pattern ^ ")"
                else raw_pattern

            (* Variables are emitted with their indices unchanged; a Computed
               wrapper is printed as the term it wraps. *)
            fun show_term (Var i) = "Var " ^ num i
              | show_term (Const c) = "c" ^ num c
              | show_term (App (f, x)) =
                  "App (" ^ show_term f ^ ", " ^ show_term x ^ ")"
              | show_term (Abs body) = "Abs (" ^ show_term body ^ ")"
              | show_term (Computed inner) = show_term inner

            (* "x<nvars-1>, ..., x0"; the empty string when the pattern binds
               no variables. *)
            val env =
                String.concatWith ", "
                  (List.tabulate (nvars, fn i => "x" ^ num (nvars - 1 - i)))

            val rhs = "Closure ([" ^ env ^ "], " ^ show_term t ^ ")"
        in
            "  | weak_reduce (false, stack, " ^ lhs ^ ") = Continue (false, stack, " ^ rhs ^ ")"
        end
       
    72 
       
    (* All constant identifiers occurring in a pattern, head first and with
       repetitions kept. *)
    fun constants_of (PConst (head, args)) = head :: maps constants_of args
      | constants_of PVar = []
       
    75 
       
    (* All constant identifiers occurring in a term, in left-to-right order
       and with repetitions kept; a Computed wrapper is looked through. *)
    fun constants_of_term (Const c) = [c]
      | constants_of_term (Var _) = []
      | constants_of_term (App (f, x)) = constants_of_term f @ constants_of_term x
      | constants_of_term (Abs body) = constants_of_term body
      | constants_of_term (Computed inner) = constants_of_term inner
       
    81     
       
    (* Generates, compiles and loads an ML structure implementing the rewrite
       rules in prog, and returns the resulting rewriter.

       sname - name of the structure the generated code refers back to
               (for list_nth, Run, the source term constructors and
               set_compiled_rewriter)
       name  - name of the generated structure
       prog  - list of (pattern, right-hand side) rules

       The generated structure registers its rewrite function through
       sname.set_compiled_rewriter while it is being evaluated; that value is
       picked up from compiled_rewriter afterwards.
       Raises Compile if no rewriter was registered by the evaluated code. *)
    fun load_rules sname name prog =
        let
            (* Generated text is collected as a reversed list of chunks and
               concatenated once at the end, instead of re-copying a growing
               string on every write. *)
            val chunks = Unsynchronized.ref ([] : string list)
            fun write s = (chunks := s :: !chunks)
            fun writeln s = (write s; write "\n")
            fun writelist lines = List.app writeln lines
            fun str i = Int.toString i

            val _ = writelist [
                    "structure "^name^" = struct",
                    "",
                    "datatype term = Dummy | App of term * term | Abs of term | Var of int | Const of int | Closure of term list * term"]

            (* Every constant mentioned by some rule becomes a nullary
               constructor cN appended to the datatype line above (the next
               writelist starts with "" and thereby terminates that line). *)
            val constants = distinct (op =) (maps (fn (p, r) => ((constants_of p)@(constants_of_term r))) prog)
            val _ = List.app (fn c => write (" | c"^(str c))) constants
            val _ = writelist [
                    "",
                    "datatype stack = SEmpty | SAppL of term * stack | SAppR of term * stack | SAbs of stack",
                    "",
                    "type state = bool * stack * term",
                    "",
                    "datatype loopstate = Continue of state | Stop of stack * term",
                    "",
                    "fun proj_C (Continue s) = s",
                    "  | proj_C _ = raise Match",
                    "",
                    "fun proj_S (Stop s) = s",
                    "  | proj_S _ = raise Match",
                    "",
                    "fun cont (Continue _) = true",
                    "  | cont _ = false",
                    "",
                    "fun do_reduction reduce p =",
                    "    let",
                    "       val s = Unsynchronized.ref (Continue p)",
                    "       val _ = while cont (!s) do (s := reduce (proj_C (!s)))",
                    "   in",
                    "       proj_S (!s)",
                    "   end",
                    ""]

            (* Weak reduction: fixed closure-handling clauses first, then one
               clause per user rule, then the fixed fall-through clauses. *)
            val _ = writelist [
                    "fun weak_reduce (false, stack, Closure (e, App (a, b))) = Continue (false, SAppL (Closure (e, b), stack), Closure (e, a))",
                    "  | weak_reduce (false, SAppL (b, stack), Closure (e, Abs m)) = Continue (false, stack, Closure (b::e, m))",
                    "  | weak_reduce (false, stack, c as Closure (e, Abs m)) = Continue (true, stack, c)",
                    "  | weak_reduce (false, stack, Closure (e, Var n)) = Continue (false, stack, case "^sname^".list_nth (e, n) of Dummy => Var n | r => r)",
                    "  | weak_reduce (false, stack, Closure (e, c)) = Continue (false, stack, c)"]
            val _ = writelist (map print_rule prog)
            val _ = writelist [
                    "  | weak_reduce (false, stack, clos) = Continue (true, stack, clos)",
                    "  | weak_reduce (true, SAppR (a, stack), b) = Continue (false, stack, App (a,b))",
                    "  | weak_reduce (true, s as (SAppL (b, stack)), a) = Continue (false, SAppR (a, stack), b)",
                    "  | weak_reduce (true, stack, c) = Stop (stack, c)",
                    "",
                    "fun strong_reduce (false, stack, Closure (e, Abs m)) =",
                    "    let",
                    "        val (stack', wnf) = do_reduction weak_reduce (false, SEmpty, Closure (Dummy::e, m))",
                    "    in",
                    "        case stack' of",
                    "            SEmpty => Continue (false, SAbs stack, wnf)",
                    "          | _ => raise ("^sname^".Run \"internal error in strong: weak failed\")",
                    "    end",
                    "  | strong_reduce (false, stack, clos as (App (u, v))) = Continue (false, SAppL (v, stack), u)",
                    "  | strong_reduce (false, stack, clos) = Continue (true, stack, clos)",
                    "  | strong_reduce (true, SAbs stack, m) = Continue (false, stack, Abs m)",
                    "  | strong_reduce (true, SAppL (b, stack), a) = Continue (false, SAppR (a, stack), b)",
                    "  | strong_reduce (true, SAppR (a, stack), b) = Continue (true, stack, App (a, b))",
                    "  | strong_reduce (true, stack, clos) = Stop (stack, clos)",
                    ""]

            (* Maps a numeric constant to its dedicated constructor when one
               was generated, and to a plain Const otherwise. *)
            val ic = "(case c of "^(implode (map (fn c => (str c)^" => c"^(str c)^" | ") constants))^" _ => Const c)"
            val _ = writelist [
                    "fun importTerm ("^sname^".Var x) = Var x",
                    "  | importTerm ("^sname^".Const c) =  "^ic,
                    "  | importTerm ("^sname^".App (a, b)) = App (importTerm a, importTerm b)",
                    "  | importTerm ("^sname^".Abs m) = Abs (importTerm m)",
                    (* A Computed wrapper is unwrapped, as print_term and
                       constants_of_term do; without this clause such an
                       input term would raise Match. *)
                    "  | importTerm ("^sname^".Computed t) = importTerm t",
                    ""]

            fun ec c = "  | exportTerm c"^(str c)^" = "^sname^".Const "^(str c)
            val _ = writelist [
                    "fun exportTerm (Var x) = "^sname^".Var x",
                    "  | exportTerm (Const c) = "^sname^".Const c",
                    "  | exportTerm (App (a,b)) = "^sname^".App (exportTerm a, exportTerm b)",
                    "  | exportTerm (Abs m) = "^sname^".Abs (exportTerm m)",
                    "  | exportTerm (Closure (closlist, clos)) = raise ("^sname^".Run \"internal error, cannot export Closure\")",
                    "  | exportTerm Dummy = raise ("^sname^".Run \"internal error, cannot export Dummy\")"]
            val _ = writelist (map ec constants)

            val _ = writelist [
                    "",
                    "fun rewrite t = ",
                    "    let",
                    "      val (stack, wnf) = do_reduction weak_reduce (false, SEmpty, Closure ([], importTerm t))",
                    "    in",
                    "      case stack of ",
                    "           SEmpty => (case do_reduction strong_reduce (false, SEmpty, wnf) of",
                    "                          (SEmpty, snf) => exportTerm snf",
                    "                        | _ => raise ("^sname^".Run \"internal error in rewrite: strong failed\"))",
                    "         | _ => (raise ("^sname^".Run \"internal error in rewrite: weak failed\"))",
                    "    end",
                    "",
                    "val _ = "^sname^".set_compiled_rewriter rewrite",
                    "",
                    "end;"]

            val generated = String.concat (rev (!chunks))
        in
            (* Clear the mailbox first so that a stale rewriter can never be
               mistaken for the result of this compilation. *)
            compiled_rewriter := NONE;
            use_text ML_Env.local_context (1, "") false generated;
            case !compiled_rewriter of
                NONE => raise (Compile "cannot communicate with compiled function")
              | SOME r => (compiled_rewriter := NONE; r)
        end
       
   193 
       
   (* Compiles a list of (guards, pattern, right-hand side) equations into a
      program.  cache_patterns and const_arity are accepted but not used
      here.

      Raises Compile when any equation carries guards, when a right-hand
      side fails check_freevars against the number of pattern variables, or
      when a pattern is a bare variable. *)
   fun compile cache_patterns const_arity eqs =
       let
           fun has_guards (guards, _, _) = not (null guards)
           val _ =
               if exists has_guards eqs
               then raise Compile ("cannot deal with guards")
               else ()

           val rules = map (fn (_, lhs, rhs) => (lhs, rhs)) eqs

           (* The free-variable check comes first, then the shape check, for
              each rule in list order. *)
           fun validate (lhs, rhs) =
               (if check_freevars (count_patternvars lhs) rhs then ()
                else raise Compile ("unbound variables in rule");
                case lhs of
                    PVar => raise (Compile "pattern is just a variable")
                  | _ => ())
           val _ = List.app validate rules
       in
           load_rules "AM_Compiler" "AM_compiled_code" rules
       end
       
   205 
       
   (* Running a program is just applying the compiled rewriter to the term. *)
   fun run prog input = prog input

   (* Nothing is released when a program is discarded. *)
   fun discard _ = ()
       
   209                                   
       
   210 end
       
   211