src/HOL/Matrix/Compute_Oracle/am_ghc.ML
author wenzelm
Thu Jun 30 13:21:41 2011 +0200 (2011-06-30)
changeset 43602 8c89a1fb30f2
parent 41959 b460124855b8
child 43850 7f2cbc713344
permissions -rw-r--r--
standardized use of Path operations;
wenzelm@41959
     1
(*  Title:      HOL/Matrix/Compute_Oracle/am_ghc.ML
obua@23663
     2
    Author:     Steven Obua
obua@23663
     3
*)
obua@23663
     4
wenzelm@41952
     5
structure AM_GHC : ABSTRACT_MACHINE =
wenzelm@41952
     6
struct
obua@23663
     7
obua@23663
     8
open AbstractMachine;
obua@23663
     9
obua@23663
    10
(* Handle of a compiled program, as returned by compile and consumed by run:
   (guid of the generated module, path of the generated .hs file,
    maximal-arity table of all constants). *)
type program = string * string * int Inttab.table
obua@23663
    11
obua@23663
    12
(* Number of pattern variables (PVar leaves) occurring anywhere in a pattern. *)
fun count_patternvars PVar = 1
  | count_patternvars (PConst (_, subpatterns)) =
      fold (fn sub => fn total => total + count_patternvars sub) subpatterns 0
obua@23663
    15
obua@23663
    16
(* Record arity a for constant code, keeping the largest arity seen so far. *)
fun update_arity arity code a =
  (case Inttab.lookup arity code of
    SOME (known : int) => if a <= known then arity else Inttab.update (code, a) arity
  | NONE => Inttab.update_new (code, a) arity)
obua@23663
    20
obua@23663
    21
(* We have to find out the maximal arity of each constant *)
obua@23663
    22
(* Fold every constant occurring in a pattern (with its number of arguments)
   into the maximal-arity table. *)
fun collect_pattern_arity PVar arity = arity
  | collect_pattern_arity (PConst (c, args)) arity =
      let
        val arity' = update_arity arity c (length args)
      in
        fold collect_pattern_arity args arity'
      end
obua@23663
    24
 
obua@23663
    25
(* Fold every constant occurring in term t into the maximal-arity table.
   The arity recorded for a constant is the number of arguments it is applied
   to at that occurrence (the length of the App spine above it); arguments and
   abstraction bodies start a fresh count. *)
fun collect_term_arity t arity =
  let
    fun walk _ (Var _) tab = tab
      | walk depth (Const c) tab = update_arity tab c depth
      | walk _ (Abs body) tab = walk 0 body tab
      | walk depth (App (f, x)) tab =
          let
            val tab' = walk (depth + 1) f tab
          in
            walk 0 x tab'
          end
  in
    walk 0 t arity
  end
obua@23663
    33
obua@23663
    34
(* Shift by n every de Bruijn index that is loose with respect to the first
   `level` binders; indices below `level` are left alone. *)
fun nlift level n t =
  (case t of
    Var m => Var (if m < level then m else m + n)
  | Const _ => t
  | App (f, x) => App (nlift level n f, nlift level n x)
  | Abs body => Abs (nlift (level + 1) n body))
obua@23663
    38
obua@23663
    39
(* List of n copies of x.  The guard is n <= 0 (not n = 0) so that a negative
   count yields [] instead of recursing forever. *)
fun rep n x = if n <= 0 then [] else x :: rep (n - 1) x
obua@23663
    40
obua@23663
    41
(* Normalize rewrite rules so that every rule's head constant is applied to
   exactly its maximal arity.

   Returns (arity, rules') where arity maps each constant to the maximal number
   of arguments it is applied to anywhere (in patterns or right-hand sides).
   A rule whose head has fewer arguments is padded with fresh PVar arguments
   and its right-hand side is applied to the corresponding new variables
   (the existing loose variables are lifted out of the way with nlift).

   Raises Compile if a pattern is a bare variable, if the right-hand side has
   unbound variables, or if a constant nested inside a pattern (at any depth)
   is not applied to its maximal arity.  Nested constants become Haskell data
   constructors of that arity, so a partial application there cannot be
   matched; previously only direct arguments were checked, and deeper ones
   surfaced later as an opaque GHC compile failure. *)
fun adjust_rules rules =
  let
    val arity =
      fold (fn (p, t) => fn tab => collect_term_arity t (collect_pattern_arity p tab))
        rules Inttab.empty
    fun arity_of c = the (Inttab.lookup arity c)

    fun check_pattern PVar = ()
      | check_pattern (PConst (c, args)) =
          if length args <> arity_of c then
            raise Compile ("Constant inside pattern must have maximal arity")
          else List.app check_pattern args

    fun adjust_rule (PVar, _) = raise Compile ("pattern may not be a variable")
      | adjust_rule (rule as (p as PConst (c, args), t)) =
          let
            val _ =
              if not (check_freevars (count_patternvars p) t) then
                raise Compile ("unbound variables on right hand side")
              else ()
            val _ = List.app check_pattern args
            val missing = arity_of c - length args
            (* apply u to Var (k-1), ..., Var 0 *)
            fun eta_expand k u =
              if k = 0 then u else eta_expand (k - 1) (App (u, Var (k - 1)))
          in
            if missing = 0 then rule
            else if missing > 0 then
              (PConst (c, args @ rep missing PVar), eta_expand missing (nlift 0 missing t))
            else raise Compile "internal error in adjust_rule"
          end
  in
    (arity, map adjust_rule rules)
  end
obua@23663
    70
obua@23663
    71
(* Render a term as a Haskell expression (partially applied: print_term
   arity_of n t).

   n is the number of pattern variables in scope: a loose de Bruijn index x at
   binder depth d (x >= d) prints as pattern variable x<n-(x-d)-1>, while a
   bound index prints as b<d-x-1>.  A constant with known arity k is printed as
   a call c<k> applied to its first k arguments; if it has fewer arguments it
   is eta-expanded with Abs lambdas first.  Constants without known arity
   become "Const k".  Any surplus arguments are applied with the runtime
   helper "app". *)
fun print_term arity_of n =
  let
    fun paren s =
      if exists_string Symbol.is_ascii_blank s then "(" ^ s ^ ")" else s

    fun pr_apps _ head [] = head
      | pr_apps depth head (arg :: rest) =
          pr_apps depth ("app " ^ paren head ^ " " ^ paren (pr_term depth arg)) rest

    and pr_call depth (App (f, x)) pending = pr_call depth f (x :: pending)
      | pr_call depth (Const c) pending =
          (case arity_of c of
            NONE => pr_apps depth ("Const " ^ string_of_int c) pending
          | SOME ar =>
              let
                val given = length pending
              in
                if given >= ar then
                  let
                    val saturated = List.take (pending, ar)
                    val surplus = List.drop (pending, ar)
                    val head =
                      "c" ^ string_of_int c ^
                      implode (map (fn arg => " " ^ paren (pr_term depth arg)) saturated)
                  in
                    pr_apps depth head surplus
                  end
                else
                  let
                    val missing = ar - given
                    fun add_vars k u =
                      if k = 0 then u else add_vars (k - 1) (App (u, Var (k - 1)))
                    fun add_abs k u =
                      if k = 0 then u else add_abs (k - 1) (Abs u)
                    val applied = fold (fn arg => fn acc => App (acc, arg)) pending (Const c)
                  in
                    pr_term depth (add_abs missing (add_vars missing (nlift 0 missing applied)))
                  end
              end)
      | pr_call depth t pending = pr_apps depth (pr_term depth t) pending

    and pr_term depth (Var x) =
          if x < depth then "b" ^ string_of_int (depth - x - 1)
          else "x" ^ string_of_int (n - (x - depth) - 1)
      | pr_term depth (Abs body) =
          "Abs (\\b" ^ string_of_int depth ^ " -> " ^ pr_term (depth + 1) body ^ ")"
      | pr_term depth t = pr_call depth t []
  in
    pr_term 0
  end
wenzelm@32960
   109
                                                
obua@23663
   110
(* Render one rewrite rule as a Haskell equation "lhs = rhs".
   The top-level constant of the pattern prints as function name c<k>, nested
   constants as constructors C<k> (parenthesized when they have arguments), and
   pattern variables are numbered x0, x1, ... from left to right.  The number
   of pattern variables is passed on to print_term for the right-hand side. *)
fun print_rule arity_of (p, t) =
  let
    fun ctor top c = (if top then "c" else "C") ^ string_of_int c

    fun pat _ PVar next = ("x" ^ string_of_int next, next + 1)
      | pat top (PConst (c, [])) next = (ctor top c, next)
      | pat top (PConst (c, args)) next =
          let
            fun step arg (text, k) =
              let
                val (s, k') = pat false arg k
              in
                (text ^ " " ^ s, k')
              end
            val (text, next') = fold step args (ctor top c, next)
          in
            (if top then text else "(" ^ text ^ ")", next')
          end

    val (lhs, nvars) = pat true p 0
  in
    lhs ^ " = " ^ print_term arity_of nvars t
  end
obua@23663
   132
obua@23663
   133
(* Group rules by the constant at the head of their pattern, preserving the
   original relative order of the rules within each group. *)
fun group_rules rules =
  let
    fun insert (rule as (PConst (c, _), _)) groups =
          let
            val existing =
              (case Inttab.lookup groups c of
                SOME rs => rs
              | NONE => [])
          in
            Inttab.update (c, rule :: existing) groups
          end
      | insert _ _ = raise Compile "internal error group_rules"
  in
    fold_rev insert rules Inttab.empty
  end
obua@23663
   145
obua@23663
   146
(* Generate the Haskell source of module `name` for the given rewrite rules.

   Returns (arity, source) where arity is the maximal-arity table from
   adjust_rules.  The module declares a Term datatype with a constructor C<k>
   per constant k (one Term field per unit of arity), a Show instance used to
   serialize results, the helpers `app` and `calc`, and one function c<k> per
   constant holding its rules, followed by a default equation that rebuilds
   C<k> when no rule matches (omitted for a single nullary definition).
   Raises Compile on a nullary constant with more than one rule.

   The text is collected as a list of lines and joined once; the previous
   ref-based buffer re-copied the whole output on every write (quadratic). *)
fun haskell_prog name rules =
  let
    val str = string_of_int
    val (arity, rules) = adjust_rules rules
    val rules = group_rules rules
    val constants = Inttab.keys arity
    fun arity_of c = Inttab.lookup arity c
    fun rep_str s n = implode (rep n s)
    fun indexed s n = s ^ str n
    (* [0, 1, ..., n-1] *)
    fun section n = List.tabulate (n, fn i => i)

    fun make_show c =
      let
        val args = section (the (arity_of c))
      in
        "  show (" ^ indexed "C" c ^ implode (map (indexed " a") args) ^ ") = "
        ^ "\"" ^ indexed "C" c ^ "\"" ^ implode (map (fn a => "++(show " ^ indexed "a" a ^ ")") args)
      end

    fun default_case c =
      let
        val args = implode (map (indexed " x") (section (the (arity_of c))))
      in
        indexed "c" c ^ args ^ " = " ^ indexed "C" c ^ args
      end

    fun list_group c =
      (case Inttab.lookup rules c of
        NONE => [default_case c, ""]
      | SOME (rs as ((PConst (_, []), _) :: rs')) =>
          if not (null rs') then raise Compile "multiple declaration of constant"
          else map (print_rule arity_of) rs @ [""]
      | SOME rs => map (print_rule arity_of) rs @ [default_case c, ""])

    val header = [
      "module " ^ name ^ " where",
      "",
      "data Term = Const Integer | App Term Term | Abs (Term -> Term)",
      "         " ^ implode (map (fn c => " | C" ^ str c ^ rep_str " Term" (the (arity_of c))) constants),
      "",
      "instance Show Term where"]
    val show_cases = map make_show constants
    val show_fallback = [
      "  show (Const c) = \"c\"++(show c)",
      "  show (App a b) = \"A\"++(show a)++(show b)",
      "  show (Abs _) = \"L\"",
      ""]
    val helpers = [
      "app (Abs a) b = a b",
      "app a b = App a b",
      "",
      "calc s c = writeFile s (show c)",
      ""]
    val definitions = List.concat (map list_group constants)
    val lines = header @ show_cases @ show_fallback @ helpers @ definitions
  in
    (arity, implode (map (fn s => s ^ "\n") lines))
  end
obua@23663
   203
wenzelm@32740
   204
(* Process-wide counter used to make generated module names unique. *)
val guid_counter = Unsynchronized.ref 0

(* Fresh identifier: current time in microseconds followed by the counter
   value before it was incremented. *)
fun get_guid () =
  let
    val count = ! guid_counter
    val () = guid_counter := ! guid_counter + 1
    val stamp = string_of_int (Time.toMicroseconds (Time.now ()))
  in
    stamp ^ string_of_int count
  end
obua@23663
   212
wenzelm@43602
   213
(* Absolute path (as a string) of file s inside Isabelle's temporary directory. *)
fun tmp_file s = (Path.implode o Path.expand o File.tmp_path o Path.basic) s;
obua@23663
   214
(* Surround s with double quotes (no escaping of s itself). *)
fun wrap s = String.concat ["\"", s, "\""]
obua@23663
   215
obua@23663
   216
(* Write string s to the file with the given (string) path. *)
fun writeTextFile name s =
  let
    val path = Path.explode name
  in
    File.write path s
  end
obua@23663
   217
    
obua@23663
   218
(* True iff the file system can report a size for `name` (i.e. it exists and
   is accessible); any OS.SysErr counts as "does not exist". *)
fun fileExists name =
  let
    val size = SOME (OS.FileSys.fileSize name) handle OS.SysErr _ => NONE
  in
    Option.isSome size
  end
obua@23663
   219
obua@25520
   220
(* Compile guard-free rewrite rules (guards, patterns, rhs) into a Haskell
   module in the temporary directory and build it with "$ISABELLE_GHC -c".
   cache_patterns and const_arity are not used by this backend.
   Returns (guid, module_file, arity); raises Compile if any rule has guards or
   if no object file was produced. *)
fun compile cache_patterns const_arity eqs =
  let
    fun has_guards (guards, _, _) = not (null guards)
    val _ = if exists has_guards eqs then raise Compile ("cannot deal with guards") else ()
    val eqs = map (fn (_, lhs, rhs) => (lhs, rhs)) eqs
    val guid = get_guid ()
    val module = "AMGHC_Prog_" ^ guid
    val (arity, source) = haskell_prog module eqs
    val module_file = tmp_file (module ^ ".hs")
    val object_file = tmp_file (module ^ ".o")
    val () = writeTextFile module_file source
    val _ = bash ("exec \"$ISABELLE_GHC\" -c " ^ module_file)
  in
    if fileExists object_file then (guid, module_file, arity)
    else raise Compile ("Failure compiling haskell code (ISABELLE_GHC='" ^ getenv "ISABELLE_GHC" ^ "')")
  end
obua@23663
   238
obua@23663
   239
(* Entire contents of the file with the given (string) path. *)
fun readResultFile name = (File.read o Path.explode) name
wenzelm@32960
   240
                                                                                                    
obua@23663
   241
(* Parse the serialized result written by the generated Show instance.
   Grammar: "C<k>" followed by arity(k) argument terms (rebuilt as Const k
   applied to them), "c<k>" for a bare constant, "A" followed by two terms for
   an application.  Abstractions ("L") cannot be read back.  Only spaces and
   newlines may follow the term.  Raises Run on malformed input. *)
fun parse_result arity_of result =
  let
    (* Read a decimal natural number; NONE if no digit is present. *)
    fun read_nat acc (ch :: rest) =
          if Char.isDigit ch then
            let
              val d = Char.ord ch - Char.ord #"0"
              val acc' = (case acc of NONE => d | SOME v => v * 10 + d)
            in
              read_nat (SOME acc') rest
            end
          else (acc, ch :: rest)
      | read_nat acc [] = (acc, [])

    fun parse (#"C" :: rest) =
          (case read_nat NONE rest of
            (NONE, _) => raise Run "parse C"
          | (SOME c, rest1) =>
              let
                val (args, rest2) = parse_list (the (arity_of c)) rest1
              in
                (fold (fn arg => fn acc => App (acc, arg)) args (Const c), rest2)
              end)
      | parse (#"c" :: rest) =
          (case read_nat NONE rest of
            (SOME c, rest1) => (Const c, rest1)
          | _ => raise Run "parse c")
      | parse (#"A" :: rest) =
          let
            val (f, rest1) = parse rest
            val (x, rest2) = parse rest1
          in
            (App (f, x), rest2)
          end
      | parse (#"L" :: _) = raise Run "there may be no abstraction in the result"
      | parse _ = raise Run "invalid result"

    and parse_list k rest =
      if k = 0 then ([], rest)
      else
        let
          val (x, rest1) = parse rest
          val (xs, rest2) = parse_list (k - 1) rest1
        in
          (x :: xs, rest2)
        end

    val (parsed, rest) = parse (String.explode result)
    val only_blanks = List.all (fn ch => ch = #" " orelse ch = #"\n") rest
  in
    if only_blanks then parsed else raise Run "non-blank suffix in result file"
  end
obua@23663
   301
obua@23663
   302
(* Evaluate the closed term t with a program produced by compile.

   Writes a call module that imports the compiled program module and writes
   the shown result of evaluating t to a result file, runs it with
   "$ISABELLE_GHC -e", and parses the result back into a term.

   Raises Run for open terms or when no result file was produced;
   parse_result raises Run on malformed output.  The temporary call and
   result files are removed on every path, including failure.  Previously
   they leaked whenever reading or parsing the result raised. *)
fun run (guid, module_file, arity) t =
  let
    val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
    fun arity_of c = Inttab.lookup arity c
    val callguid = get_guid ()
    val module = "AMGHC_Prog_" ^ guid
    val call = module ^ "_Call_" ^ callguid
    val result_file = tmp_file (module ^ "_Result_" ^ callguid ^ ".txt")
    val call_file = tmp_file (call ^ ".hs")

    (* Best-effort removal: a file that was never written (e.g. GHC failed
       before producing the result) must not mask the exception in flight. *)
    fun remove_quietly file = OS.FileSys.remove file handle OS.SysErr _ => ()
    fun cleanup () = (remove_quietly call_file; remove_quietly result_file)

    fun evaluate () =
      let
        val term = print_term arity_of 0 t
        val call_source = "module "^call^" where\n\nimport "^module^"\n\ncall = "^module^".calc \""^result_file^"\" ("^term^")"
        val _ = writeTextFile call_file call_source
        val _ = bash ("exec \"$ISABELLE_GHC\" -e \""^call^".call\" "^module_file^" "^call_file)
        val result = readResultFile result_file handle IO.Io _ =>
          raise Run ("Failure running haskell compiler (ISABELLE_GHC='" ^ getenv "ISABELLE_GHC" ^ "')")
      in
        parse_result arity_of result
      end

    val t' = evaluate () handle exn => (cleanup (); raise exn)
    val () = cleanup ()
  in
    t'
  end
obua@23663
   323
wenzelm@32960
   324
        
obua@23663
   325
(* Release a compiled program; this backend keeps no per-program state to free. *)
val discard = fn _ => ()
wenzelm@32960
   326
                          
obua@23663
   327
end
obua@23663
   328