src/Tools/Compute_Oracle/am_ghc.ML
author wenzelm
Tue Sep 29 16:24:36 2009 +0200 (2009-09-29)
changeset 32740 9dd0a2f83429
parent 30161 c26e515f1c29
child 32952 aeb1e44fbc19
permissions -rw-r--r--
explicit indication of Unsynchronized.ref;
haftmann@24584
     1
(*  Title:      Tools/Compute_Oracle/am_ghc.ML
    Author:     Steven Obua
*)
obua@23663
     4
obua@23663
     5
structure AM_GHC : ABSTRACT_MACHINE = struct

(* Abstract machine backend that translates the rewrite rules into a Haskell
   module, compiles it with "ghc -c", and evaluates terms by running "ghc -e"
   on a generated call module.  The Haskell side writes the (shown) normal
   form into a text file, which is parsed back into a term. *)

open AbstractMachine;

(* (guid, generated Haskell module file, maximal arity of every constant) *)
type program = string * string * (int Inttab.table)

(* Number of pattern variables occurring in a pattern. *)
fun count_patternvars PVar = 1
  | count_patternvars (PConst (_, ps)) =
      List.foldl (fn (p, count) => count_patternvars p + count) 0 ps

(* Record arity a for constant code, keeping the maximum seen so far. *)
fun update_arity arity code a =
    (case Inttab.lookup arity code of
       NONE => Inttab.update_new (code, a) arity
     | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)

(* We have to find out the maximal arity of each constant *)
fun collect_pattern_arity PVar arity = arity
  | collect_pattern_arity (PConst (c, args)) arity =
      fold collect_pattern_arity args (update_arity arity c (length args))

local
  (* applevel = number of arguments the current head position is applied to *)
  fun collect applevel (Var _) arity = arity
    | collect applevel (Const c) arity = update_arity arity c applevel
    | collect applevel (Abs m) arity = collect 0 m arity
    | collect applevel (App (a, b)) arity = collect 0 b (collect (applevel + 1) a arity)
in
  fun collect_term_arity t arity = collect 0 t arity
end

(* Shift every de Bruijn index >= level by n; indices below level are bound
   by enclosing Abs nodes and stay unchanged. *)
fun nlift level n (Var m) = if m < level then Var m else Var (m + n)
  | nlift level n (Const c) = Const c
  | nlift level n (App (a, b)) = App (nlift level n a, nlift level n b)
  | nlift level n (Abs b) = Abs (nlift (level + 1) n b)

(* n copies of x; empty for n <= 0 (a negative count used to loop forever). *)
fun rep n x = if n <= 0 then [] else x :: rep (n - 1) x

(* app_vars n t = App (... App (App (t, Var (n-1)), Var (n-2)) ..., Var 0),
   i.e. t applied to the n loose variables Var (n-1), ..., Var 0. *)
fun app_vars n t = if n <= 0 then t else app_vars (n - 1) (App (t, Var (n - 1)))

(* Compute the maximal arity of every constant (over patterns and right-hand
   sides) and pad each rule whose head constant has fewer arguments than its
   maximal arity: extra PVars are appended to the pattern and the right-hand
   side is lifted and applied to the corresponding new variables.
   Returns (arity table, adjusted rules).
   Raises Compile if a pattern is a variable, a right-hand side has unbound
   variables, or a constant nested inside a pattern is not fully applied
   (the generated Haskell constructor would have the wrong arity). *)
fun adjust_rules rules =
    let
      val arity = fold (fn (p, t) => fn arity => collect_term_arity t (collect_pattern_arity p arity))
                       rules Inttab.empty
      fun arity_of c = the (Inttab.lookup arity c)
      fun adjust_pattern PVar = PVar
        | adjust_pattern (PConst (c, args)) =
            if length args <> arity_of c
            then raise Compile ("Constant inside pattern must have maximal arity")
            else PConst (c, map adjust_pattern args)   (* check nested constants too *)
      fun adjust_rule (PVar, _) = raise Compile ("pattern may not be a variable")
        | adjust_rule (rule as (p as PConst (c, args), t)) =
            let
              val _ = if not (check_freevars (count_patternvars p) t)
                      then raise Compile ("unbound variables on right hand side") else ()
              val args = map adjust_pattern args
              val len = length args
              val arity = arity_of c
              val missing = arity - len
            in
              if missing = 0 then rule
              else if missing > 0 then
                (* the new pattern variables are the last ones, so they get the
                   smallest de Bruijn indices: lift the old ones out of the way *)
                (PConst (c, args @ rep missing PVar), app_vars missing (nlift 0 missing t))
              else raise Compile "internal error in adjust_rule"
            end
    in
      (arity, map adjust_rule rules)
    end

(* Render term t as a Haskell expression.
   n = number of pattern variables in scope, named x0 .. x(n-1).  At binder
   depth d, Var i with i < d is a lambda-bound variable b(d-i-1); otherwise it
   denotes pattern variable x(n-(i-d)-1).  Constants with a known arity are
   printed as calls of the Haskell function c<code> (eta-expanded when
   under-applied); constants without an arity become "Const <code>". *)
fun print_term arity_of n =
    let
      fun str x = string_of_int x
      fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^ ")" else s

      fun print_apps d f [] = f
        | print_apps d f (a :: args) =
            print_apps d ("app " ^ protect_blank f ^ " " ^ protect_blank (print_term d a)) args
      and print_call d (App (a, b)) args = print_call d a (b :: args)
        | print_call d (Const c) args =
            (case arity_of c of
               NONE => print_apps d ("Const " ^ str c) args
             | SOME a =>
                 let
                   val len = length args
                 in
                   if a <= len then
                     let
                       val s = "c" ^ str c ^
                               concat (map (fn t => " " ^ protect_blank (print_term d t)) (List.take (args, a)))
                     in
                       print_apps d s (List.drop (args, a))
                     end
                   else
                     let
                       val missing = a - len
                       fun mk_lambdas k t = if k = 0 then t else mk_lambdas (k - 1) (Abs t)
                       fun append_args [] t = t
                         | append_args (x :: xs) t = append_args xs (App (t, x))
                     in
                       print_term d (mk_lambdas missing
                                      (app_vars missing (nlift 0 missing (append_args args (Const c)))))
                     end
                 end)
        | print_call d t args = print_apps d (print_term d t) args
      and print_term d (Var x) = if x < d then "b" ^ str (d - x - 1) else "x" ^ str (n - (x - d) - 1)
        | print_term d (Abs c) = "Abs (\\b" ^ str d ^ " -> " ^ print_term (d + 1) c ^ ")"
        | print_term d t = print_call d t []
    in
      print_term 0
    end

(* Render one rewrite rule as a Haskell equation "c<code> pats = rhs".
   Pattern variables are numbered x0, x1, ... from left to right; nested
   constants use the data constructor C<code>. *)
fun print_rule arity_of (p, t) =
    let
      fun str x = Int.toString x
      fun head top c = (if top then "c" else "C") ^ str c
      fun print_pattern top n PVar = (n + 1, "x" ^ str n)
        | print_pattern top n (PConst (c, [])) = (n, head top c)
        | print_pattern top n (PConst (c, args)) =
            let
              val (n, s) = print_pattern_list (n, head top c) args
            in
              (n, if top then s else "(" ^ s ^ ")")
            end
      and print_pattern_list r [] = r
        | print_pattern_list (n, p) (t :: ts) =
            let
              val (n, t) = print_pattern false n t
            in
              print_pattern_list (n, p ^ " " ^ t) ts
            end
      val (n, pattern) = print_pattern true 0 p
    in
      pattern ^ " = " ^ print_term arity_of n t
    end

(* Group rules by head constant, preserving their original order. *)
fun group_rules rules =
    let
      fun add_rule (r as (PConst (c, _), _)) groups =
            let
              val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
            in
              Inttab.update (c, r :: rs) groups
            end
        | add_rule _ _ = raise Compile "internal error group_rules"
    in
      fold_rev add_rule rules Inttab.empty
    end

(* Generate the Haskell module "name" for the given rules.
   Returns (arity table, module source text).  Every emitted line is
   terminated by "\n".  The source is assembled as a list of lines and
   concatenated once (the former mutable buffer was quadratic). *)
fun haskell_prog name rules =
    let
      fun str i = Int.toString i
      val (arity, rules) = adjust_rules rules
      val rules = group_rules rules
      val constants = Inttab.keys arity
      fun arity_of c = Inttab.lookup arity c
      fun rep_str s n = concat (rep n s)
      fun indexed s n = s ^ str n
      fun section n = List.tabulate (n, fn i => i)   (* [0, ..., n-1] *)
      fun make_show c =
          let
            val args = section (the (arity_of c))
          in
            "  show (" ^ indexed "C" c ^ concat (map (indexed " a") args) ^ ") = "
            ^ "\"" ^ indexed "C" c ^ "\"" ^ concat (map (fn a => "++(show " ^ indexed "a" a ^ ")") args)
          end
      (* fallback equation: an irreducible call becomes the data constructor *)
      fun default_case c =
          let
            val args = concat (map (indexed " x") (section (the (arity_of c))))
          in
            indexed "c" c ^ args ^ " = " ^ indexed "C" c ^ args
          end
      val header = [
            "module " ^ name ^ " where",
            "",
            "data Term = Const Integer | App Term Term | Abs (Term -> Term)",
            "         " ^ concat (map (fn c => " | C" ^ str c ^ rep_str " Term" (the (arity_of c))) constants),
            "",
            "instance Show Term where"]
      val show_rest = [
            "  show (Const c) = \"c\"++(show c)",
            "  show (App a b) = \"A\"++(show a)++(show b)",
            "  show (Abs _) = \"L\"",
            ""]
      val runtime = [
            "app (Abs a) b = a b",
            "app a b = App a b",
            "",
            "calc s c = writeFile s (show c)",
            ""]
      fun group_lines c =
          (case Inttab.lookup rules c of
             NONE => [default_case c, ""]
           | SOME (rs as ((PConst (_, []), _) :: rs')) =>
               if not (null rs') then raise Compile "multiple declaration of constant"
               else map (print_rule arity_of) rs @ [""]
           | SOME rs => map (print_rule arity_of) rs @ [default_case c, ""])
      val lines =
        header @ map make_show constants @ show_rest @ runtime
        @ List.concat (map group_lines constants)
    in
      (arity, concat (map (fn s => s ^ "\n") lines))
    end

val guid_counter = Unsynchronized.ref 0

(* Unique-ish identifier: current time in microseconds followed by a
   per-session counter. *)
fun get_guid () =
    let
      val c = !guid_counter
      val _ = guid_counter := c + 1
    in
      LargeInt.toString (Time.toMicroseconds (Time.now ())) ^ string_of_int c
    end

fun tmp_file s = Path.implode (Path.expand (File.tmp_path (Path.make [s])));

(* Double-quote a file name for the shell command lines below, so paths
   containing blanks survive word splitting. *)
fun wrap s = "\""^s^"\""

fun writeTextFile name s = File.write (Path.explode name) s

val ghc = Unsynchronized.ref (case getenv "GHC_PATH" of "" => "ghc" | s => s)

fun fileExists name = ((OS.FileSys.fileSize name; true) handle OS.SysErr _ => false)

(* Best-effort deletion of a temporary file; a missing or undeletable file is
   ignored so cleanup never masks the real result or error. *)
fun remove_tmp_file name = OS.FileSys.remove name handle OS.SysErr _ => ()

(* Translate the equations into a Haskell module and compile it with GHC.
   Guards are not supported.  On failure the generated source is removed and
   Compile is raised. *)
fun compile cache_patterns const_arity eqs =
    let
      val _ = if exists (fn (guards, _, _) => not (null guards)) eqs
              then raise Compile ("cannot deal with guards") else ()
      val eqs = map (fn (_, p, t) => (p, t)) eqs
      val guid = get_guid ()
      val module = "AMGHC_Prog_"^guid
      val (arity, source) = haskell_prog module eqs
      val module_file = tmp_file (module^".hs")
      val object_file = tmp_file (module^".o")
      val _ = writeTextFile module_file source
      val _ = system ((!ghc)^" -c "^(wrap module_file))
      val _ =
        if fileExists object_file then ()
        else (remove_tmp_file module_file;
              raise Compile ("Failure compiling haskell code (GHC_PATH = '"^(!ghc)^"')"))
    in
      (guid, module_file, arity)
    end

fun readResultFile name = File.read (Path.explode name)

(* Parse the output of the generated "show" instance back into a term:
     C<code> arg_1 .. arg_k   saturated constructor (k = arity of code)
     c<code>                  Const code
     A a b                    App (a, b)
   A trailing run of spaces/newlines is accepted; anything else raises Run. *)
fun parse_result arity_of result =
    let
      val result = String.explode result

      (* read a (possibly empty) run of decimal digits; NONE if there is none *)
      fun parse_int' acc (ch :: rest) =
            if Char.isDigit ch
            then parse_int' (SOME (getOpt (acc, 0) * 10 + (Char.ord ch - Char.ord #"0"))) rest
            else (acc, ch :: rest)
        | parse_int' acc [] = (acc, [])
      fun parse_int rest = parse_int' NONE rest

      fun parse (#"C" :: rest) =
            (case parse_int rest of
               (SOME c, rest) =>
                 let
                   val (args, rest) = parse_list (the (arity_of c)) rest
                   fun app_args [] t = t
                     | app_args (x :: xs) t = app_args xs (App (t, x))
                 in
                   (app_args args (Const c), rest)
                 end
             | (NONE, _) => raise Run "parse C")
        | parse (#"c" :: rest) =
            (case parse_int rest of
               (SOME c, rest) => (Const c, rest)
             | _ => raise Run "parse c")
        | parse (#"A" :: rest) =
            let
              val (a, rest) = parse rest
              val (b, rest) = parse rest
            in
              (App (a, b), rest)
            end
        | parse (#"L" :: _) = raise Run "there may be no abstraction in the result"
        | parse _ = raise Run "invalid result"
      and parse_list n rest =
            if n = 0 then ([], rest)
            else
              let
                val (x, rest) = parse rest
                val (xs, rest) = parse_list (n - 1) rest
              in
                (x :: xs, rest)
              end

      val (parsed, rest) = parse result
      val is_blank = List.all (fn ch => ch = #" " orelse ch = #"\n")
    in
      if is_blank rest then parsed else raise Run "non-blank suffix in result file"
    end

(* Evaluate the closed term t with a compiled program: write a call module,
   run it via "ghc -e", and parse the result file.  The temporary call and
   result files are removed whether evaluation succeeds or raises. *)
fun run (guid, module_file, arity) t =
    let
      val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
      fun arity_of c = Inttab.lookup arity c
      val callguid = get_guid ()
      val module = "AMGHC_Prog_"^guid
      val call = module^"_Call_"^callguid
      val result_file = tmp_file (module^"_Result_"^callguid^".txt")
      val call_file = tmp_file (call^".hs")
      val term = print_term arity_of 0 t
      val call_source =
        "module "^call^" where\n\nimport "^module^"\n\ncall = "^module^".calc \""^result_file^"\" ("^term^")"
      fun evaluate () =
          let
            val _ = writeTextFile call_file call_source
            val _ = system ((!ghc)^" -e \""^call^".call\" "^(wrap module_file)^" "^(wrap call_file))
            val result = readResultFile result_file
              handle IO.Io _ => raise Run ("Failure running haskell compiler (GHC_PATH = '"^(!ghc)^"')")
          in
            parse_result arity_of result
          end
      fun cleanup () = List.app remove_tmp_file [call_file, result_file]
      (* "finally": clean up, then re-raise the original exception unchanged *)
      val t' = evaluate () handle exn => (cleanup (); raise exn)
      val _ = cleanup ()
    in
      t'
    end

fun discard _ = ()

end
obua@23663
   325