src/Tools/Compute_Oracle/am_ghc.ML
changeset 37872 d83659570337
parent 37871 c7ce7685e087
child 37873 66d90b2b87bc
equal deleted inserted replaced
37871:c7ce7685e087 37872:d83659570337
     1 (*  Title:      Tools/Compute_Oracle/am_ghc.ML
       
     2     Author:     Steven Obua
       
     3 *)
       
     4 
       
     5 structure AM_GHC : ABSTRACT_MACHINE = struct
       
     6 
       
     7 open AbstractMachine;
       
     8 
       
     9 type program = string * string * (int Inttab.table)
       
    10 
       
    11 fun count_patternvars PVar = 1
       
    12   | count_patternvars (PConst (_, ps)) =
       
    13       List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps
       
    14 
       
    15 fun update_arity arity code a = 
       
    16     (case Inttab.lookup arity code of
       
    17          NONE => Inttab.update_new (code, a) arity
       
    18        | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)
       
    19 
       
    20 (* We have to find out the maximal arity of each constant *)
       
    21 fun collect_pattern_arity PVar arity = arity
       
    22   | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args))
       
    23  
       
    24 local
       
    25 fun collect applevel (Var _) arity = arity
       
    26   | collect applevel (Const c) arity = update_arity arity c applevel
       
    27   | collect applevel (Abs m) arity = collect 0 m arity
       
    28   | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity)
       
    29 in
       
    30 fun collect_term_arity t arity = collect 0 t arity
       
    31 end
       
    32 
       
    33 fun nlift level n (Var m) = if m < level then Var m else Var (m+n) 
       
    34   | nlift level n (Const c) = Const c
       
    35   | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b)
       
    36   | nlift level n (Abs b) = Abs (nlift (level+1) n b)
       
    37 
       
    38 fun rep n x = if n = 0 then [] else x::(rep (n-1) x)
       
    39 
       
    40 fun adjust_rules rules =
       
    41     let
       
    42         val arity = fold (fn (p, t) => fn arity => collect_term_arity t (collect_pattern_arity p arity)) rules Inttab.empty
       
    43         fun arity_of c = the (Inttab.lookup arity c)
       
    44         fun adjust_pattern PVar = PVar
       
    45           | adjust_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else C
       
    46         fun adjust_rule (PVar, t) = raise Compile ("pattern may not be a variable")
       
    47           | adjust_rule (rule as (p as PConst (c, args),t)) = 
       
    48             let
       
    49                 val _ = if not (check_freevars (count_patternvars p) t) then raise Compile ("unbound variables on right hand side") else () 
       
    50                 val args = map adjust_pattern args              
       
    51                 val len = length args
       
    52                 val arity = arity_of c
       
    53                 fun lift level n (Var m) = if m < level then Var m else Var (m+n) 
       
    54                   | lift level n (Const c) = Const c
       
    55                   | lift level n (App (a,b)) = App (lift level n a, lift level n b)
       
    56                   | lift level n (Abs b) = Abs (lift (level+1) n b)
       
    57                 val lift = lift 0
       
    58                 fun adjust_term n t = if n=0 then t else adjust_term (n-1) (App (t, Var (n-1))) 
       
    59             in
       
    60                 if len = arity then
       
    61                     rule
       
    62                 else if arity >= len then  
       
    63                     (PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) (lift (arity-len) t))
       
    64                 else (raise Compile "internal error in adjust_rule")
       
    65             end
       
    66     in
       
    67         (arity, map adjust_rule rules)
       
    68     end             
       
    69 
       
    70 fun print_term arity_of n =
       
    71 let
       
    72     fun str x = string_of_int x
       
    73     fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s
       
    74                                                                                           
       
    75     fun print_apps d f [] = f
       
    76       | print_apps d f (a::args) = print_apps d ("app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args
       
    77     and print_call d (App (a, b)) args = print_call d a (b::args) 
       
    78       | print_call d (Const c) args = 
       
    79         (case arity_of c of 
       
    80              NONE => print_apps d ("Const "^(str c)) args 
       
    81            | SOME a =>
       
    82              let
       
    83                  val len = length args
       
    84              in
       
    85                  if a <= len then 
       
    86                      let
       
    87                          val s = "c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, a))))
       
    88                      in
       
    89                          print_apps d s (List.drop (args, a))
       
    90                      end
       
    91                  else 
       
    92                      let
       
    93                          fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n-1)))
       
    94                          fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t)
       
    95                          fun append_args [] t = t
       
    96                            | append_args (c::cs) t = append_args cs (App (t, c))
       
    97                      in
       
    98                          print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c)))))
       
    99                      end
       
   100              end)
       
   101       | print_call d t args = print_apps d (print_term d t) args
       
   102     and print_term d (Var x) = if x < d then "b"^(str (d-x-1)) else "x"^(str (n-(x-d)-1))
       
   103       | print_term d (Abs c) = "Abs (\\b"^(str d)^" -> "^(print_term (d + 1) c)^")"
       
   104       | print_term d t = print_call d t []
       
   105 in
       
   106     print_term 0 
       
   107 end
       
   108                                                 
       
   109 fun print_rule arity_of (p, t) = 
       
   110     let 
       
   111         fun str x = Int.toString x                  
       
   112         fun print_pattern top n PVar = (n+1, "x"^(str n))
       
   113           | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c))
       
   114           | print_pattern top n (PConst (c, args)) = 
       
   115             let
       
   116                 val (n,s) = print_pattern_list (n, (if top then "c" else "C")^(str c)) args
       
   117             in
       
   118                 (n, if top then s else "("^s^")")
       
   119             end
       
   120         and print_pattern_list r [] = r
       
   121           | print_pattern_list (n, p) (t::ts) = 
       
   122             let
       
   123                 val (n, t) = print_pattern false n t
       
   124             in
       
   125                 print_pattern_list (n, p^" "^t) ts
       
   126             end
       
   127         val (n, pattern) = print_pattern true 0 p
       
   128     in
       
   129         pattern^" = "^(print_term arity_of n t) 
       
   130     end
       
   131 
       
   132 fun group_rules rules =
       
   133     let
       
   134         fun add_rule (r as (PConst (c,_), _)) groups =
       
   135             let
       
   136                 val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
       
   137             in
       
   138                 Inttab.update (c, r::rs) groups
       
   139             end
       
   140           | add_rule _ _ = raise Compile "internal error group_rules"
       
   141     in
       
   142         fold_rev add_rule rules Inttab.empty
       
   143     end
       
   144 
       
   145 fun haskell_prog name rules = 
       
   146     let
       
   147         val buffer = Unsynchronized.ref ""
       
   148         fun write s = (buffer := (!buffer)^s)
       
   149         fun writeln s = (write s; write "\n")
       
   150         fun writelist [] = ()
       
   151           | writelist (s::ss) = (writeln s; writelist ss)
       
   152         fun str i = Int.toString i
       
   153         val (arity, rules) = adjust_rules rules
       
   154         val rules = group_rules rules
       
   155         val constants = Inttab.keys arity
       
   156         fun arity_of c = Inttab.lookup arity c
       
   157         fun rep_str s n = implode (rep n s)
       
   158         fun indexed s n = s^(str n)
       
   159         fun section n = if n = 0 then [] else (section (n-1))@[n-1]
       
   160         fun make_show c = 
       
   161             let
       
   162                 val args = section (the (arity_of c))
       
   163             in
       
   164                 "  show ("^(indexed "C" c)^(implode (map (indexed " a") args))^") = "
       
   165                 ^"\""^(indexed "C" c)^"\""^(implode (map (fn a => "++(show "^(indexed "a" a)^")") args))
       
   166             end
       
   167         fun default_case c = 
       
   168             let
       
   169                 val args = implode (map (indexed " x") (section (the (arity_of c))))
       
   170             in
       
   171                 (indexed "c" c)^args^" = "^(indexed "C" c)^args
       
   172             end
       
   173         val _ = writelist [        
       
   174                 "module "^name^" where",
       
   175                 "",
       
   176                 "data Term = Const Integer | App Term Term | Abs (Term -> Term)",
       
   177                 "         "^(implode (map (fn c => " | C"^(str c)^(rep_str " Term" (the (arity_of c)))) constants)),
       
   178                 "",
       
   179                 "instance Show Term where"]
       
   180         val _ = writelist (map make_show constants)
       
   181         val _ = writelist [
       
   182                 "  show (Const c) = \"c\"++(show c)",
       
   183                 "  show (App a b) = \"A\"++(show a)++(show b)",
       
   184                 "  show (Abs _) = \"L\"",
       
   185                 ""]
       
   186         val _ = writelist [
       
   187                 "app (Abs a) b = a b",
       
   188                 "app a b = App a b",
       
   189                 "",
       
   190                 "calc s c = writeFile s (show c)",
       
   191                 ""]
       
   192         fun list_group c = (writelist (case Inttab.lookup rules c of 
       
   193                                            NONE => [default_case c, ""] 
       
   194                                          | SOME (rs as ((PConst (_, []), _)::rs')) => 
       
   195                                            if not (null rs') then raise Compile "multiple declaration of constant"
       
   196                                            else (map (print_rule arity_of) rs) @ [""]
       
   197                                          | SOME rs => (map (print_rule arity_of) rs) @ [default_case c, ""]))
       
   198         val _ = map list_group constants
       
   199     in
       
   200         (arity, !buffer)
       
   201     end
       
   202 
       
   203 val guid_counter = Unsynchronized.ref 0
       
   204 fun get_guid () = 
       
   205     let
       
   206         val c = !guid_counter
       
   207         val _ = guid_counter := !guid_counter + 1
       
   208     in
       
   209         (LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c)
       
   210     end
       
   211 
       
   212 fun tmp_file s = Path.implode (Path.expand (File.tmp_path (Path.make [s])));
       
   213 fun wrap s = "\""^s^"\""
       
   214 
       
   215 fun writeTextFile name s = File.write (Path.explode name) s
       
   216     
       
   217 val ghc = Unsynchronized.ref (case getenv "GHC_PATH" of "" => "ghc" | s => s)
       
   218 
       
   219 fun fileExists name = ((OS.FileSys.fileSize name; true) handle OS.SysErr _ => false)
       
   220 
       
   221 fun compile cache_patterns const_arity eqs = 
       
   222     let
       
   223         val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else ()
       
   224         val eqs = map (fn (a,b,c) => (b,c)) eqs
       
   225         val guid = get_guid ()
       
   226         val module = "AMGHC_Prog_"^guid
       
   227         val (arity, source) = haskell_prog module eqs
       
   228         val module_file = tmp_file (module^".hs")
       
   229         val object_file = tmp_file (module^".o")
       
   230         val _ = writeTextFile module_file source
       
   231         val _ = bash ((!ghc)^" -c "^module_file)
       
   232         val _ = if not (fileExists object_file) then raise Compile ("Failure compiling haskell code (GHC_PATH = '"^(!ghc)^"')") else ()
       
   233     in
       
   234         (guid, module_file, arity)      
       
   235     end
       
   236 
       
   237 fun readResultFile name = File.read (Path.explode name) 
       
   238                                                                                                     
       
   239 fun parse_result arity_of result =
       
   240     let
       
   241         val result = String.explode result
       
   242         fun shift NONE x = SOME x
       
   243           | shift (SOME y) x = SOME (y*10 + x)
       
   244         fun parse_int' x (#"0"::rest) = parse_int' (shift x 0) rest
       
   245           | parse_int' x (#"1"::rest) = parse_int' (shift x 1) rest
       
   246           | parse_int' x (#"2"::rest) = parse_int' (shift x 2) rest
       
   247           | parse_int' x (#"3"::rest) = parse_int' (shift x 3) rest
       
   248           | parse_int' x (#"4"::rest) = parse_int' (shift x 4) rest
       
   249           | parse_int' x (#"5"::rest) = parse_int' (shift x 5) rest
       
   250           | parse_int' x (#"6"::rest) = parse_int' (shift x 6) rest
       
   251           | parse_int' x (#"7"::rest) = parse_int' (shift x 7) rest
       
   252           | parse_int' x (#"8"::rest) = parse_int' (shift x 8) rest
       
   253           | parse_int' x (#"9"::rest) = parse_int' (shift x 9) rest
       
   254           | parse_int' x rest = (x, rest)
       
   255         fun parse_int rest = parse_int' NONE rest
       
   256 
       
   257         fun parse (#"C"::rest) = 
       
   258             (case parse_int rest of 
       
   259                  (SOME c, rest) => 
       
   260                  let
       
   261                      val (args, rest) = parse_list (the (arity_of c)) rest
       
   262                      fun app_args [] t = t
       
   263                        | app_args (x::xs) t = app_args xs (App (t, x))
       
   264                  in
       
   265                      (app_args args (Const c), rest)
       
   266                  end                 
       
   267                | (NONE, rest) => raise Run "parse C")
       
   268           | parse (#"c"::rest) = 
       
   269             (case parse_int rest of
       
   270                  (SOME c, rest) => (Const c, rest)
       
   271                | _ => raise Run "parse c")
       
   272           | parse (#"A"::rest) = 
       
   273             let
       
   274                 val (a, rest) = parse rest
       
   275                 val (b, rest) = parse rest
       
   276             in
       
   277                 (App (a,b), rest)
       
   278             end
       
   279           | parse (#"L"::rest) = raise Run "there may be no abstraction in the result"
       
   280           | parse _ = raise Run "invalid result"
       
   281         and parse_list n rest = 
       
   282             if n = 0 then 
       
   283                 ([], rest) 
       
   284             else 
       
   285                 let 
       
   286                     val (x, rest) = parse rest
       
   287                     val (xs, rest) = parse_list (n-1) rest
       
   288                 in
       
   289                     (x::xs, rest)
       
   290                 end
       
   291         val (parsed, rest) = parse result
       
   292         fun is_blank (#" "::rest) = is_blank rest
       
   293           | is_blank (#"\n"::rest) = is_blank rest
       
   294           | is_blank [] = true
       
   295           | is_blank _ = false
       
   296     in
       
   297         if is_blank rest then parsed else raise Run "non-blank suffix in result file"   
       
   298     end
       
   299 
       
   300 fun run (guid, module_file, arity) t = 
       
   301     let
       
   302         val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
       
   303         fun arity_of c = Inttab.lookup arity c                   
       
   304         val callguid = get_guid()
       
   305         val module = "AMGHC_Prog_"^guid
       
   306         val call = module^"_Call_"^callguid
       
   307         val result_file = tmp_file (module^"_Result_"^callguid^".txt")
       
   308         val call_file = tmp_file (call^".hs")
       
   309         val term = print_term arity_of 0 t
       
   310         val call_source = "module "^call^" where\n\nimport "^module^"\n\ncall = "^module^".calc \""^result_file^"\" ("^term^")"
       
   311         val _ = writeTextFile call_file call_source
       
   312         val _ = bash ((!ghc)^" -e \""^call^".call\" "^module_file^" "^call_file)
       
   313         val result = readResultFile result_file handle IO.Io _ => raise Run ("Failure running haskell compiler (GHC_PATH = '"^(!ghc)^"')")
       
   314         val t' = parse_result arity_of result
       
   315         val _ = OS.FileSys.remove call_file
       
   316         val _ = OS.FileSys.remove result_file
       
   317     in
       
   318         t'
       
   319     end
       
   320 
       
   321         
       
   322 fun discard _ = ()
       
   323                           
       
   324 end
       
   325