src/Tools/Compute_Oracle/am_ghc.ML
changeset 32960 69916a850301
parent 32952 aeb1e44fbc19
child 35010 d6e492cea6e4
--- a/src/Tools/Compute_Oracle/am_ghc.ML	Sat Oct 17 01:05:59 2009 +0200
+++ b/src/Tools/Compute_Oracle/am_ghc.ML	Sat Oct 17 14:43:18 2009 +0200
@@ -14,7 +14,7 @@
 
 fun update_arity arity code a = 
     (case Inttab.lookup arity code of
-	 NONE => Inttab.update_new (code, a) arity
+         NONE => Inttab.update_new (code, a) arity
        | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)
 
 (* We have to find out the maximal arity of each constant *)
@@ -39,65 +39,65 @@
 
 fun adjust_rules rules =
     let
-	val arity = fold (fn (p, t) => fn arity => collect_term_arity t (collect_pattern_arity p arity)) rules Inttab.empty
-	fun arity_of c = the (Inttab.lookup arity c)
-	fun adjust_pattern PVar = PVar
-	  | adjust_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else C
-	fun adjust_rule (PVar, t) = raise Compile ("pattern may not be a variable")
-	  | adjust_rule (rule as (p as PConst (c, args),t)) = 
-	    let
-		val _ = if not (check_freevars (count_patternvars p) t) then raise Compile ("unbound variables on right hand side") else () 
-		val args = map adjust_pattern args	        
-		val len = length args
-		val arity = arity_of c
-		fun lift level n (Var m) = if m < level then Var m else Var (m+n) 
-		  | lift level n (Const c) = Const c
-		  | lift level n (App (a,b)) = App (lift level n a, lift level n b)
-		  | lift level n (Abs b) = Abs (lift (level+1) n b)
-		val lift = lift 0
-		fun adjust_term n t = if n=0 then t else adjust_term (n-1) (App (t, Var (n-1))) 
-	    in
-		if len = arity then
-		    rule
-		else if arity >= len then  
-		    (PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) (lift (arity-len) t))
-		else (raise Compile "internal error in adjust_rule")
-	    end
+        val arity = fold (fn (p, t) => fn arity => collect_term_arity t (collect_pattern_arity p arity)) rules Inttab.empty
+        fun arity_of c = the (Inttab.lookup arity c)
+        fun adjust_pattern PVar = PVar
+          | adjust_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else C
+        fun adjust_rule (PVar, t) = raise Compile ("pattern may not be a variable")
+          | adjust_rule (rule as (p as PConst (c, args),t)) = 
+            let
+                val _ = if not (check_freevars (count_patternvars p) t) then raise Compile ("unbound variables on right hand side") else () 
+                val args = map adjust_pattern args              
+                val len = length args
+                val arity = arity_of c
+                fun lift level n (Var m) = if m < level then Var m else Var (m+n) 
+                  | lift level n (Const c) = Const c
+                  | lift level n (App (a,b)) = App (lift level n a, lift level n b)
+                  | lift level n (Abs b) = Abs (lift (level+1) n b)
+                val lift = lift 0
+                fun adjust_term n t = if n=0 then t else adjust_term (n-1) (App (t, Var (n-1))) 
+            in
+                if len = arity then
+                    rule
+                else if arity >= len then  
+                    (PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) (lift (arity-len) t))
+                else (raise Compile "internal error in adjust_rule")
+            end
     in
-	(arity, map adjust_rule rules)
-    end		    
+        (arity, map adjust_rule rules)
+    end             
 
 fun print_term arity_of n =
 let
     fun str x = string_of_int x
     fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s
-											  
+                                                                                          
     fun print_apps d f [] = f
       | print_apps d f (a::args) = print_apps d ("app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args
     and print_call d (App (a, b)) args = print_call d a (b::args) 
       | print_call d (Const c) args = 
-	(case arity_of c of 
-	     NONE => print_apps d ("Const "^(str c)) args 
-	   | SOME a =>
-	     let
-		 val len = length args
-	     in
-		 if a <= len then 
-		     let
-			 val s = "c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, a))))
-		     in
-			 print_apps d s (List.drop (args, a))
-		     end
-		 else 
-		     let
-			 fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n-1)))
-			 fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t)
-			 fun append_args [] t = t
-			   | append_args (c::cs) t = append_args cs (App (t, c))
-		     in
-			 print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c)))))
-		     end
-	     end)
+        (case arity_of c of 
+             NONE => print_apps d ("Const "^(str c)) args 
+           | SOME a =>
+             let
+                 val len = length args
+             in
+                 if a <= len then 
+                     let
+                         val s = "c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, a))))
+                     in
+                         print_apps d s (List.drop (args, a))
+                     end
+                 else 
+                     let
+                         fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n-1)))
+                         fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t)
+                         fun append_args [] t = t
+                           | append_args (c::cs) t = append_args cs (App (t, c))
+                     in
+                         print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c)))))
+                     end
+             end)
       | print_call d t args = print_apps d (print_term d t) args
     and print_term d (Var x) = if x < d then "b"^(str (d-x-1)) else "x"^(str (n-(x-d)-1))
       | print_term d (Abs c) = "Abs (\\b"^(str d)^" -> "^(print_term (d + 1) c)^")"
@@ -105,108 +105,108 @@
 in
     print_term 0 
 end
-			 			
+                                                
 fun print_rule arity_of (p, t) = 
-    let	
-	fun str x = Int.toString x		    
-	fun print_pattern top n PVar = (n+1, "x"^(str n))
-	  | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c))
-	  | print_pattern top n (PConst (c, args)) = 
-	    let
-		val (n,s) = print_pattern_list (n, (if top then "c" else "C")^(str c)) args
-	    in
-		(n, if top then s else "("^s^")")
-	    end
-	and print_pattern_list r [] = r
-	  | print_pattern_list (n, p) (t::ts) = 
-	    let
-		val (n, t) = print_pattern false n t
-	    in
-		print_pattern_list (n, p^" "^t) ts
-	    end
-	val (n, pattern) = print_pattern true 0 p
+    let 
+        fun str x = Int.toString x                  
+        fun print_pattern top n PVar = (n+1, "x"^(str n))
+          | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c))
+          | print_pattern top n (PConst (c, args)) = 
+            let
+                val (n,s) = print_pattern_list (n, (if top then "c" else "C")^(str c)) args
+            in
+                (n, if top then s else "("^s^")")
+            end
+        and print_pattern_list r [] = r
+          | print_pattern_list (n, p) (t::ts) = 
+            let
+                val (n, t) = print_pattern false n t
+            in
+                print_pattern_list (n, p^" "^t) ts
+            end
+        val (n, pattern) = print_pattern true 0 p
     in
-	pattern^" = "^(print_term arity_of n t)	
+        pattern^" = "^(print_term arity_of n t) 
     end
 
 fun group_rules rules =
     let
-	fun add_rule (r as (PConst (c,_), _)) groups =
-	    let
-		val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
-	    in
-		Inttab.update (c, r::rs) groups
-	    end
-	  | add_rule _ _ = raise Compile "internal error group_rules"
+        fun add_rule (r as (PConst (c,_), _)) groups =
+            let
+                val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
+            in
+                Inttab.update (c, r::rs) groups
+            end
+          | add_rule _ _ = raise Compile "internal error group_rules"
     in
-	fold_rev add_rule rules Inttab.empty
+        fold_rev add_rule rules Inttab.empty
     end
 
 fun haskell_prog name rules = 
     let
-	val buffer = Unsynchronized.ref ""
-	fun write s = (buffer := (!buffer)^s)
-	fun writeln s = (write s; write "\n")
-	fun writelist [] = ()
-	  | writelist (s::ss) = (writeln s; writelist ss)
-	fun str i = Int.toString i
-	val (arity, rules) = adjust_rules rules
-	val rules = group_rules rules
-	val constants = Inttab.keys arity
-	fun arity_of c = Inttab.lookup arity c
-	fun rep_str s n = implode (rep n s)
-	fun indexed s n = s^(str n)
-	fun section n = if n = 0 then [] else (section (n-1))@[n-1]
-	fun make_show c = 
-	    let
-		val args = section (the (arity_of c))
-	    in
-		"  show ("^(indexed "C" c)^(implode (map (indexed " a") args))^") = "
-		^"\""^(indexed "C" c)^"\""^(implode (map (fn a => "++(show "^(indexed "a" a)^")") args))
-	    end
-	fun default_case c = 
-	    let
-		val args = implode (map (indexed " x") (section (the (arity_of c))))
-	    in
-		(indexed "c" c)^args^" = "^(indexed "C" c)^args
-	    end
-	val _ = writelist [        
-		"module "^name^" where",
-		"",
-		"data Term = Const Integer | App Term Term | Abs (Term -> Term)",
-		"         "^(implode (map (fn c => " | C"^(str c)^(rep_str " Term" (the (arity_of c)))) constants)),
-		"",
-		"instance Show Term where"]
-	val _ = writelist (map make_show constants)
-	val _ = writelist [
-		"  show (Const c) = \"c\"++(show c)",
-		"  show (App a b) = \"A\"++(show a)++(show b)",
-		"  show (Abs _) = \"L\"",
-		""]
-	val _ = writelist [
-		"app (Abs a) b = a b",
-		"app a b = App a b",
-		"",
-		"calc s c = writeFile s (show c)",
-		""]
-	fun list_group c = (writelist (case Inttab.lookup rules c of 
-					   NONE => [default_case c, ""] 
-					 | SOME (rs as ((PConst (_, []), _)::rs')) => 
-					   if not (null rs') then raise Compile "multiple declaration of constant"
-					   else (map (print_rule arity_of) rs) @ [""]
-					 | SOME rs => (map (print_rule arity_of) rs) @ [default_case c, ""]))
-	val _ = map list_group constants
+        val buffer = Unsynchronized.ref ""
+        fun write s = (buffer := (!buffer)^s)
+        fun writeln s = (write s; write "\n")
+        fun writelist [] = ()
+          | writelist (s::ss) = (writeln s; writelist ss)
+        fun str i = Int.toString i
+        val (arity, rules) = adjust_rules rules
+        val rules = group_rules rules
+        val constants = Inttab.keys arity
+        fun arity_of c = Inttab.lookup arity c
+        fun rep_str s n = implode (rep n s)
+        fun indexed s n = s^(str n)
+        fun section n = if n = 0 then [] else (section (n-1))@[n-1]
+        fun make_show c = 
+            let
+                val args = section (the (arity_of c))
+            in
+                "  show ("^(indexed "C" c)^(implode (map (indexed " a") args))^") = "
+                ^"\""^(indexed "C" c)^"\""^(implode (map (fn a => "++(show "^(indexed "a" a)^")") args))
+            end
+        fun default_case c = 
+            let
+                val args = implode (map (indexed " x") (section (the (arity_of c))))
+            in
+                (indexed "c" c)^args^" = "^(indexed "C" c)^args
+            end
+        val _ = writelist [        
+                "module "^name^" where",
+                "",
+                "data Term = Const Integer | App Term Term | Abs (Term -> Term)",
+                "         "^(implode (map (fn c => " | C"^(str c)^(rep_str " Term" (the (arity_of c)))) constants)),
+                "",
+                "instance Show Term where"]
+        val _ = writelist (map make_show constants)
+        val _ = writelist [
+                "  show (Const c) = \"c\"++(show c)",
+                "  show (App a b) = \"A\"++(show a)++(show b)",
+                "  show (Abs _) = \"L\"",
+                ""]
+        val _ = writelist [
+                "app (Abs a) b = a b",
+                "app a b = App a b",
+                "",
+                "calc s c = writeFile s (show c)",
+                ""]
+        fun list_group c = (writelist (case Inttab.lookup rules c of 
+                                           NONE => [default_case c, ""] 
+                                         | SOME (rs as ((PConst (_, []), _)::rs')) => 
+                                           if not (null rs') then raise Compile "multiple declaration of constant"
+                                           else (map (print_rule arity_of) rs) @ [""]
+                                         | SOME rs => (map (print_rule arity_of) rs) @ [default_case c, ""]))
+        val _ = map list_group constants
     in
-	(arity, !buffer)
+        (arity, !buffer)
     end
 
 val guid_counter = Unsynchronized.ref 0
 fun get_guid () = 
     let
-	val c = !guid_counter
-	val _ = guid_counter := !guid_counter + 1
+        val c = !guid_counter
+        val _ = guid_counter := !guid_counter + 1
     in
-	(LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c)
+        (LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c)
     end
 
 fun tmp_file s = Path.implode (Path.expand (File.tmp_path (Path.make [s])));
@@ -220,106 +220,106 @@
 
 fun compile cache_patterns const_arity eqs = 
     let
-	val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else ()
-	val eqs = map (fn (a,b,c) => (b,c)) eqs
-	val guid = get_guid ()
-	val module = "AMGHC_Prog_"^guid
-	val (arity, source) = haskell_prog module eqs
-	val module_file = tmp_file (module^".hs")
-	val object_file = tmp_file (module^".o")
-	val _ = writeTextFile module_file source
-	val _ = system ((!ghc)^" -c "^module_file)
-	val _ = if not (fileExists object_file) then raise Compile ("Failure compiling haskell code (GHC_PATH = '"^(!ghc)^"')") else ()
+        val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else ()
+        val eqs = map (fn (a,b,c) => (b,c)) eqs
+        val guid = get_guid ()
+        val module = "AMGHC_Prog_"^guid
+        val (arity, source) = haskell_prog module eqs
+        val module_file = tmp_file (module^".hs")
+        val object_file = tmp_file (module^".o")
+        val _ = writeTextFile module_file source
+        val _ = system ((!ghc)^" -c "^module_file)
+        val _ = if not (fileExists object_file) then raise Compile ("Failure compiling haskell code (GHC_PATH = '"^(!ghc)^"')") else ()
     in
-	(guid, module_file, arity)	
+        (guid, module_file, arity)      
     end
 
 fun readResultFile name = File.read (Path.explode name) 
-			  			  			      			    
+                                                                                                    
 fun parse_result arity_of result =
     let
-	val result = String.explode result
-	fun shift NONE x = SOME x
-	  | shift (SOME y) x = SOME (y*10 + x)
-	fun parse_int' x (#"0"::rest) = parse_int' (shift x 0) rest
-	  | parse_int' x (#"1"::rest) = parse_int' (shift x 1) rest
-	  | parse_int' x (#"2"::rest) = parse_int' (shift x 2) rest
-	  | parse_int' x (#"3"::rest) = parse_int' (shift x 3) rest
-	  | parse_int' x (#"4"::rest) = parse_int' (shift x 4) rest
-	  | parse_int' x (#"5"::rest) = parse_int' (shift x 5) rest
-	  | parse_int' x (#"6"::rest) = parse_int' (shift x 6) rest
-	  | parse_int' x (#"7"::rest) = parse_int' (shift x 7) rest
-	  | parse_int' x (#"8"::rest) = parse_int' (shift x 8) rest
-	  | parse_int' x (#"9"::rest) = parse_int' (shift x 9) rest
-	  | parse_int' x rest = (x, rest)
-	fun parse_int rest = parse_int' NONE rest
+        val result = String.explode result
+        fun shift NONE x = SOME x
+          | shift (SOME y) x = SOME (y*10 + x)
+        fun parse_int' x (#"0"::rest) = parse_int' (shift x 0) rest
+          | parse_int' x (#"1"::rest) = parse_int' (shift x 1) rest
+          | parse_int' x (#"2"::rest) = parse_int' (shift x 2) rest
+          | parse_int' x (#"3"::rest) = parse_int' (shift x 3) rest
+          | parse_int' x (#"4"::rest) = parse_int' (shift x 4) rest
+          | parse_int' x (#"5"::rest) = parse_int' (shift x 5) rest
+          | parse_int' x (#"6"::rest) = parse_int' (shift x 6) rest
+          | parse_int' x (#"7"::rest) = parse_int' (shift x 7) rest
+          | parse_int' x (#"8"::rest) = parse_int' (shift x 8) rest
+          | parse_int' x (#"9"::rest) = parse_int' (shift x 9) rest
+          | parse_int' x rest = (x, rest)
+        fun parse_int rest = parse_int' NONE rest
 
-	fun parse (#"C"::rest) = 
-	    (case parse_int rest of 
-		 (SOME c, rest) => 
-		 let
-		     val (args, rest) = parse_list (the (arity_of c)) rest
-		     fun app_args [] t = t
-		       | app_args (x::xs) t = app_args xs (App (t, x))
-		 in
-		     (app_args args (Const c), rest)
-		 end		     
-	       | (NONE, rest) => raise Run "parse C")
-	  | parse (#"c"::rest) = 
-	    (case parse_int rest of
-		 (SOME c, rest) => (Const c, rest)
-	       | _ => raise Run "parse c")
-	  | parse (#"A"::rest) = 
-	    let
-		val (a, rest) = parse rest
-		val (b, rest) = parse rest
-	    in
-		(App (a,b), rest)
-	    end
-	  | parse (#"L"::rest) = raise Run "there may be no abstraction in the result"
-	  | parse _ = raise Run "invalid result"
-	and parse_list n rest = 
-	    if n = 0 then 
-		([], rest) 
-	    else 
-		let 
-		    val (x, rest) = parse rest
-		    val (xs, rest) = parse_list (n-1) rest
-		in
-		    (x::xs, rest)
-		end
-	val (parsed, rest) = parse result
-	fun is_blank (#" "::rest) = is_blank rest
-	  | is_blank (#"\n"::rest) = is_blank rest
-	  | is_blank [] = true
-	  | is_blank _ = false
+        fun parse (#"C"::rest) = 
+            (case parse_int rest of 
+                 (SOME c, rest) => 
+                 let
+                     val (args, rest) = parse_list (the (arity_of c)) rest
+                     fun app_args [] t = t
+                       | app_args (x::xs) t = app_args xs (App (t, x))
+                 in
+                     (app_args args (Const c), rest)
+                 end                 
+               | (NONE, rest) => raise Run "parse C")
+          | parse (#"c"::rest) = 
+            (case parse_int rest of
+                 (SOME c, rest) => (Const c, rest)
+               | _ => raise Run "parse c")
+          | parse (#"A"::rest) = 
+            let
+                val (a, rest) = parse rest
+                val (b, rest) = parse rest
+            in
+                (App (a,b), rest)
+            end
+          | parse (#"L"::rest) = raise Run "there may be no abstraction in the result"
+          | parse _ = raise Run "invalid result"
+        and parse_list n rest = 
+            if n = 0 then 
+                ([], rest) 
+            else 
+                let 
+                    val (x, rest) = parse rest
+                    val (xs, rest) = parse_list (n-1) rest
+                in
+                    (x::xs, rest)
+                end
+        val (parsed, rest) = parse result
+        fun is_blank (#" "::rest) = is_blank rest
+          | is_blank (#"\n"::rest) = is_blank rest
+          | is_blank [] = true
+          | is_blank _ = false
     in
-	if is_blank rest then parsed else raise Run "non-blank suffix in result file"	
+        if is_blank rest then parsed else raise Run "non-blank suffix in result file"   
     end
 
 fun run (guid, module_file, arity) t = 
     let
-	val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
-	fun arity_of c = Inttab.lookup arity c			 
-	val callguid = get_guid()
-	val module = "AMGHC_Prog_"^guid
-	val call = module^"_Call_"^callguid
-	val result_file = tmp_file (module^"_Result_"^callguid^".txt")
-	val call_file = tmp_file (call^".hs")
-	val term = print_term arity_of 0 t
-	val call_source = "module "^call^" where\n\nimport "^module^"\n\ncall = "^module^".calc \""^result_file^"\" ("^term^")"
-	val _ = writeTextFile call_file call_source
-	val _ = system ((!ghc)^" -e \""^call^".call\" "^module_file^" "^call_file)
-	val result = readResultFile result_file handle IO.Io _ => raise Run ("Failure running haskell compiler (GHC_PATH = '"^(!ghc)^"')")
-	val t' = parse_result arity_of result
-	val _ = OS.FileSys.remove call_file
-	val _ = OS.FileSys.remove result_file
+        val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
+        fun arity_of c = Inttab.lookup arity c                   
+        val callguid = get_guid()
+        val module = "AMGHC_Prog_"^guid
+        val call = module^"_Call_"^callguid
+        val result_file = tmp_file (module^"_Result_"^callguid^".txt")
+        val call_file = tmp_file (call^".hs")
+        val term = print_term arity_of 0 t
+        val call_source = "module "^call^" where\n\nimport "^module^"\n\ncall = "^module^".calc \""^result_file^"\" ("^term^")"
+        val _ = writeTextFile call_file call_source
+        val _ = system ((!ghc)^" -e \""^call^".call\" "^module_file^" "^call_file)
+        val result = readResultFile result_file handle IO.Io _ => raise Run ("Failure running haskell compiler (GHC_PATH = '"^(!ghc)^"')")
+        val t' = parse_result arity_of result
+        val _ = OS.FileSys.remove call_file
+        val _ = OS.FileSys.remove result_file
     in
-	t'
+        t'
     end
 
-	
+        
 fun discard _ = ()
-		 	  
+                          
 end