# HG changeset patch # User obua # Date 1183995385 -7200 # Node ID 84b5c89b8b49f1ec31079f5d605f82beee330240 # Parent 91d06b04951f0d01a288aa9da066f456ab4f3404 new version of computing oracle diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/Compute_Oracle.thy --- a/src/Tools/Compute_Oracle/Compute_Oracle.thy Mon Jul 09 11:44:23 2007 +0200 +++ b/src/Tools/Compute_Oracle/Compute_Oracle.thy Mon Jul 09 17:36:25 2007 +0200 @@ -5,28 +5,10 @@ Steven Obua's evaluator. *) -theory Compute_Oracle -imports CPure -uses - "am_interpreter.ML" - "am_compiler.ML" - "am_util.ML" - "compute.ML" +theory Compute_Oracle imports CPure +uses "am.ML" "am_compiler.ML" "am_interpreter.ML" "am_ghc.ML" "am_sml.ML" "compute.ML" "linker.ML" begin -oracle compute_oracle ("Compute.computer * (int -> string) * cterm") = - {* Compute.oracle_fn *} - -ML {* -structure Compute = -struct - open Compute +setup {* Compute.setup; *} - fun rewrite_param r n ct = - compute_oracle (Thm.theory_of_cterm ct) (r, n, ct) - - fun rewrite r ct = rewrite_param r default_naming ct -end -*} - -end +end \ No newline at end of file diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/am.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Compute_Oracle/am.ML Mon Jul 09 17:36:25 2007 +0200 @@ -0,0 +1,46 @@ +signature ABSTRACT_MACHINE = +sig + +datatype term = Var of int | Const of int | App of term * term | Abs of term + +datatype pattern = PVar | PConst of int * (pattern list) + +datatype guard = Guard of term * term + +type program + +exception Compile of string; + +(* The de-Bruijn index 0 occurring on the right hand side refers to the LAST pattern variable, when traversing the pattern from left to right, + 1 to the second last, and so on. *) +val compile : (guard list * pattern * term) list -> program + +val discard : program -> unit + +exception Run of string; +val run : program -> term -> term + +end + +structure AbstractMachine : ABSTRACT_MACHINE = +struct + +datatype term = Var of int | Const of int | App of term * term | Abs of term + +datatype pattern = PVar | PConst of int * (pattern list) + +datatype guard = Guard of term * term + +type program = unit + +exception Compile of string; + +fun compile _ = raise Compile "abstract machine stub" + +fun discard _ = raise Compile "abstract machine stub" + +exception Run of string; + +fun run p t = raise Run "abstract machine stub" + +end diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/am_compiler.ML --- a/src/Tools/Compute_Oracle/am_compiler.ML Mon Jul 09 11:44:23 2007 +0200 +++ b/src/Tools/Compute_Oracle/am_compiler.ML Mon Jul 09 17:36:25 2007 +0200 @@ -1,4 +1,4 @@ -(* Title: Tools/Compute_Oracle/am_compiler.ML +(* Title: Pure/Tools/am_compiler.ML ID: $Id$ Author: Steven Obua *) @@ -7,10 +7,7 @@ sig include ABSTRACT_MACHINE - datatype closure = CVar of int | CConst of int - | CApp of closure * closure | CAbs of closure | Closure of (closure list) * closure - - val set_compiled_rewriter : (term -> closure) -> unit + val set_compiled_rewriter : (term -> term) -> unit val list_nth : 'a list * int -> 'a val list_map : ('a -> 'b) -> 'a list -> 'b list end @@ -20,39 +17,14 @@ val list_nth = List.nth; val list_map = map; -datatype term = Var of int | Const of int | App of term * term | Abs of term - -datatype pattern = PVar | PConst of int * (pattern list) +open AbstractMachine; -datatype closure = CVar of int | CConst of int - | CApp of closure * closure | CAbs of closure - | Closure of (closure list) * closure - -val compiled_rewriter = ref (NONE:(term -> closure)Option.option) +val compiled_rewriter = ref (NONE:(term -> term)Option.option) fun set_compiled_rewriter r = (compiled_rewriter := SOME r) type program = (term -> term) -datatype stack = SEmpty | SAppL of closure * stack | SAppR of closure * stack | SAbs of stack - -exception Compile of string; -exception Run of string; - -fun clos_of_term (Var x) = CVar x - | clos_of_term (Const c) = CConst c - | clos_of_term (App (u, v)) = CApp (clos_of_term u, clos_of_term v) - | clos_of_term (Abs u) = CAbs (clos_of_term u) - -fun term_of_clos (CVar x) = Var x - | term_of_clos (CConst c) = Const c - | term_of_clos (CApp (u, v)) = App (term_of_clos u, term_of_clos v) - | term_of_clos (CAbs u) = Abs (term_of_clos u) - | term_of_clos (Closure (e, u)) = - raise (Run "internal error: closure in normalized term found") - -fun strip_closure args (CApp (a,b)) = strip_closure (b::args) a - | strip_closure args x = (x, args) (*Returns true iff at most 0 .. (free-1) occur unbound. therefore check_freevars 0 t iff t is closed*) @@ -103,7 +75,7 @@ else "Closure ([], "^term^")" in - "lookup stack "^pattern^" = weak stack ("^term^")" + " | weak_reduce (false, stack, "^pattern^") = Continue (false, stack, "^term^")" end fun constants_of PVar = [] @@ -116,7 +88,6 @@ fun load_rules sname name prog = let - (* FIXME consider using more readable/efficient Buffer.empty |> fold Buffer.add etc. *) val buffer = ref "" fun write s = (buffer := (!buffer)^s) fun writeln s = (write s; write "\n") @@ -126,50 +97,62 @@ val _ = writelist [ "structure "^name^" = struct", "", - "datatype term = App of term * term | Abs of term | Var of int | Const of int | Closure of term list * term"] + "datatype term = Dummy | App of term * term | Abs of term | Var of int | Const of int | Closure of term list * term"] val constants = distinct (op =) (maps (fn (p, r) => ((constants_of p)@(constants_of_term r))) prog) val _ = map (fn x => write (" | c"^(str x))) constants val _ = writelist [ "", "datatype stack = SEmpty | SAppL of term * stack | SAppR of term * stack | SAbs of stack", - ""] - val _ = (case prog of - r::rs => (writeln ("fun "^(print_rule r)); - map (fn r => writeln(" | "^(print_rule r))) rs; - writeln (" | lookup stack clos = weak_last stack clos"); ()) - | [] => (writeln "fun lookup stack clos = weak_last stack clos")) - val _ = writelist [ - "and weak stack (Closure (e, App (a, b))) = weak (SAppL (Closure (e, b), stack)) (Closure (e, a))", - " | weak (SAppL (b, stack)) (Closure (e, Abs m)) = weak stack (Closure (b::e, m))", - " | weak stack (clos as Closure (_, Abs _)) = weak_last stack clos", - " | weak stack (Closure (e, Var n)) = weak stack ("^sname^".list_nth (e, n) handle _ => (Var (n-(length e))))", - " | weak stack (Closure (e, c)) = weak stack c", - " | weak stack clos = lookup stack clos", - "and weak_last (SAppR (a, stack)) b = weak stack (App(a, b))", - " | weak_last (SAppL (b, stack)) a = weak (SAppR (a, stack)) b", - " | weak_last stack c = (stack, c)", + "", + "type state = bool * stack * term", + "", + "datatype loopstate = Continue of state | Stop of stack * term", + "", + "fun proj_C (Continue s) = s", + " | proj_C _ = raise Match", + "", + "fun proj_S (Stop s) = s", + " | proj_S _ = raise Match", + "", + "fun cont (Continue _) = true", + " | cont _ = false", "", - "fun lift n (v as Var m) = if m < n then v else Var (m+1)", - " | lift n (Abs t) = Abs (lift (n+1) t)", - " | lift n (App (a,b)) = App (lift n a, lift n b)", - " | lift n (Closure (e, a)) = Closure (lift_env n e, lift (n+(length e)) a)", - " | lift n c = c", - "and lift_env n e = map (lift n) e", + "fun do_reduction reduce p =", + " let", + " val s = ref (Continue p)", + " val _ = while cont (!s) do (s := reduce (proj_C (!s)))", + " in", + " proj_S (!s)", + " end", + ""] + + val _ = writelist [ + "fun weak_reduce (false, stack, Closure (e, App (a, b))) = Continue (false, SAppL (Closure (e, b), stack), Closure (e, a))", + " | weak_reduce (false, SAppL (b, stack), Closure (e, Abs m)) = Continue (false, stack, Closure (b::e, m))", + " | weak_reduce (false, stack, c as Closure (e, Abs m)) = Continue (true, stack, c)", + " | weak_reduce (false, stack, Closure (e, Var n)) = Continue (false, stack, case "^sname^".list_nth (e, n) of Dummy => Var n | r => r)", + " | weak_reduce (false, stack, Closure (e, c)) = Continue (false, stack, c)"] + val _ = writelist (map print_rule prog) + val _ = writelist [ + " | weak_reduce (false, stack, clos) = Continue (true, stack, clos)", + " | weak_reduce (true, SAppR (a, stack), b) = Continue (false, stack, App (a,b))", + " | weak_reduce (true, s as (SAppL (b, stack)), a) = Continue (false, SAppR (a, stack), b)", + " | weak_reduce (true, stack, c) = Stop (stack, c)", "", - "fun strong stack (Closure (e, Abs m)) = ", + "fun strong_reduce (false, stack, Closure (e, Abs m)) =", " let", - " val (stack', wnf) = weak SEmpty (Closure ((Var 0)::(lift_env 0 e), m))", + " val (stack', wnf) = do_reduction weak_reduce (false, SEmpty, Closure (Dummy::e, m))", " in", - " case stack' of", - " SEmpty => strong (SAbs stack) wnf", - " | _ => raise ("^sname^".Run \"internal error in strong: weak failed\")", - " end", - " | strong stack (clos as (App (u, v))) = strong (SAppL (v, stack)) u", - " | strong stack clos = strong_last stack clos", - "and strong_last (SAbs stack) m = strong stack (Abs m)", - " | strong_last (SAppL (b, stack)) a = strong (SAppR (a, stack)) b", - " | strong_last (SAppR (a, stack)) b = strong_last stack (App (a, b))", - " | strong_last stack clos = (stack, clos)", + " case stack' of", + " SEmpty => Continue (false, SAbs stack, wnf)", + " | _ => raise ("^sname^".Run \"internal error in strong: weak failed\")", + " end", + " | strong_reduce (false, stack, clos as (App (u, v))) = Continue (false, SAppL (v, stack), u)", + " | strong_reduce (false, stack, clos) = Continue (true, stack, clos)", + " | strong_reduce (true, SAbs stack, m) = Continue (false, stack, Abs m)", + " | strong_reduce (true, SAppL (b, stack), a) = Continue (false, SAppR (a, stack), b)", + " | strong_reduce (true, SAppR (a, stack), b) = Continue (true, stack, App (a, b))", + " | strong_reduce (true, stack, clos) = Stop (stack, clos)", ""] val ic = "(case c of "^(implode (map (fn c => (str c)^" => c"^(str c)^" | ") constants))^" _ => Const c)" @@ -180,23 +163,24 @@ " | importTerm ("^sname^".Abs m) = Abs (importTerm m)", ""] - fun ec c = " | exportTerm c"^(str c)^" = "^sname^".CConst "^(str c) + fun ec c = " | exportTerm c"^(str c)^" = "^sname^".Const "^(str c) val _ = writelist [ - "fun exportTerm (Var x) = "^sname^".CVar x", - " | exportTerm (Const c) = "^sname^".CConst c", - " | exportTerm (App (a,b)) = "^sname^".CApp (exportTerm a, exportTerm b)", - " | exportTerm (Abs m) = "^sname^".CAbs (exportTerm m)", - " | exportTerm (Closure (closlist, clos)) = "^sname^".Closure ("^sname^".list_map exportTerm closlist, exportTerm clos)"] + "fun exportTerm (Var x) = "^sname^".Var x", + " | exportTerm (Const c) = "^sname^".Const c", + " | exportTerm (App (a,b)) = "^sname^".App (exportTerm a, exportTerm b)", + " | exportTerm (Abs m) = "^sname^".Abs (exportTerm m)", + " | exportTerm (Closure (closlist, clos)) = raise ("^sname^".Run \"internal error, cannot export Closure\")", + " | exportTerm Dummy = raise ("^sname^".Run \"internal error, cannot export Dummy\")"] val _ = writelist (map ec constants) val _ = writelist [ "", "fun rewrite t = ", " let", - " val (stack, wnf) = weak SEmpty (Closure ([], importTerm t))", + " val (stack, wnf) = do_reduction weak_reduce (false, SEmpty, Closure ([], importTerm t))", " in", " case stack of ", - " SEmpty => (case strong SEmpty wnf of", + " SEmpty => (case do_reduction strong_reduce (false, SEmpty, wnf) of", " (SEmpty, snf) => exportTerm snf", " | _ => raise ("^sname^".Run \"internal error in rewrite: strong failed\"))", " | _ => (raise ("^sname^".Run \"internal error in rewrite: weak failed\"))", @@ -206,33 +190,29 @@ "", "end;"] - val _ = - let - (*val fout = TextIO.openOut "gen_code.ML" - val _ = TextIO.output (fout, !buffer) - val _ = TextIO.closeOut fout*) - in - () - end in compiled_rewriter := NONE; use_text "" Output.ml_output false (!buffer); case !compiled_rewriter of NONE => raise (Compile "cannot communicate with compiled function") - | SOME r => (compiled_rewriter := NONE; fn t => term_of_clos (r t)) + | SOME r => (compiled_rewriter := NONE; r) end fun compile eqs = let + val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () + val eqs = map (fn (a,b,c) => (b,c)) eqs + fun check (p, r) = if check_freevars (count_patternvars p) r then () else raise Compile ("unbound variables in rule") val _ = map (fn (p, r) => - (check_freevars (count_patternvars p) r; - case p of PVar => raise (Compile "pattern reduces to a variable") | _ => ())) eqs + (check (p, r); + case p of PVar => raise (Compile "pattern is just a variable") | _ => ())) eqs in load_rules "AM_Compiler" "AM_compiled_code" eqs end fun run prog t = (prog t) + +fun discard p = () end -structure AbstractMachine = AM_Compiler diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/am_ghc.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Compute_Oracle/am_ghc.ML Mon Jul 09 17:36:25 2007 +0200 @@ -0,0 +1,334 @@ +(* Title: Pure/Tools/am_ghc.ML + ID: $Id$ + Author: Steven Obua +*) + +structure AM_GHC : ABSTRACT_MACHINE = struct + +open AbstractMachine; + +type program = string * string * (int Inttab.table) + + +(*Returns true iff at most 0 .. (free-1) occur unbound. therefore + check_freevars 0 t iff t is closed*) +fun check_freevars free (Var x) = x < free + | check_freevars free (Const c) = true + | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v + | check_freevars free (Abs m) = check_freevars (free+1) m + +fun count_patternvars PVar = 1 + | count_patternvars (PConst (_, ps)) = + List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps + +fun update_arity arity code a = + (case Inttab.lookup arity code of + NONE => Inttab.update_new (code, a) arity + | SOME a' => if a > a' then Inttab.update (code, a) arity else arity) + +(* We have to find out the maximal arity of each constant *) +fun collect_pattern_arity PVar arity = arity + | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args)) + +local +fun collect applevel (Var _) arity = arity + | collect applevel (Const c) arity = update_arity arity c applevel + | collect applevel (Abs m) arity = collect 0 m arity + | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity) +in +fun collect_term_arity t arity = collect 0 t arity +end + +fun nlift level n (Var m) = if m < level then Var m else Var (m+n) + | nlift level n (Const c) = Const c + | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b) + | nlift level n (Abs b) = Abs (nlift (level+1) n b) + +fun rep n x = if n = 0 then [] else x::(rep (n-1) x) + +fun adjust_rules rules = + let + val arity = fold (fn (p, t) => fn arity => collect_term_arity t (collect_pattern_arity p arity)) rules Inttab.empty + fun arity_of c = the (Inttab.lookup arity c) + fun adjust_pattern PVar = PVar + | adjust_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else C + fun adjust_rule (PVar, t) = raise Compile ("pattern may not be a variable") + | adjust_rule (rule as (p as PConst (c, args),t)) = + let + val _ = if not (check_freevars (count_patternvars p) t) then raise Compile ("unbound variables on right hand side") else () + val args = map adjust_pattern args + val len = length args + val arity = arity_of c + fun lift level n (Var m) = if m < level then Var m else Var (m+n) + | lift level n (Const c) = Const c + | lift level n (App (a,b)) = App (lift level n a, lift level n b) + | lift level n (Abs b) = Abs (lift (level+1) n b) + val lift = lift 0 + fun adjust_term n t = if n=0 then t else adjust_term (n-1) (App (t, Var (n-1))) + in + if len = arity then + rule + else if arity >= len then + (PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) (lift (arity-len) t)) + else (raise Compile "internal error in adjust_rule") + end + in + (arity, map adjust_rule rules) + end + +fun print_term arity_of n = +let + fun str x = string_of_int x + fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s + + fun print_apps d f [] = f + | print_apps d f (a::args) = print_apps d ("app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args + and print_call d (App (a, b)) args = print_call d a (b::args) + | print_call d (Const c) args = + (case arity_of c of + NONE => print_apps d ("Const "^(str c)) args + | SOME a => + let + val len = length args + in + if a <= len then + let + val s = "c"^(str c)^(concat (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, a)))) + in + print_apps d s (List.drop (args, a)) + end + else + let + fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n-1))) + fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t) + fun append_args [] t = t + | append_args (c::cs) t = append_args cs (App (t, c)) + in + print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c))))) + end + end) + | print_call d t args = print_apps d (print_term d t) args + and print_term d (Var x) = if x < d then "b"^(str (d-x-1)) else "x"^(str (n-(x-d)-1)) + | print_term d (Abs c) = "Abs (\\b"^(str d)^" -> "^(print_term (d + 1) c)^")" + | print_term d t = print_call d t [] +in + print_term 0 +end + +fun print_rule arity_of (p, t) = + let + fun str x = Int.toString x + fun print_pattern top n PVar = (n+1, "x"^(str n)) + | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)) + | print_pattern top n (PConst (c, args)) = + let + val (n,s) = print_pattern_list (n, (if top then "c" else "C")^(str c)) args + in + (n, if top then s else "("^s^")") + end + and print_pattern_list r [] = r + | print_pattern_list (n, p) (t::ts) = + let + val (n, t) = print_pattern false n t + in + print_pattern_list (n, p^" "^t) ts + end + val (n, pattern) = print_pattern true 0 p + in + pattern^" = "^(print_term arity_of n t) + end + +fun group_rules rules = + let + fun add_rule (r as (PConst (c,_), _)) groups = + let + val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs) + in + Inttab.update (c, r::rs) groups + end + | add_rule _ _ = raise Compile "internal error group_rules" + in + fold_rev add_rule rules Inttab.empty + end + +fun haskell_prog name rules = + let + val buffer = ref "" + fun write s = (buffer := (!buffer)^s) + fun writeln s = (write s; write "\n") + fun writelist [] = () + | writelist (s::ss) = (writeln s; writelist ss) + fun str i = Int.toString i + val (arity, rules) = adjust_rules rules + val rules = group_rules rules + val constants = Inttab.keys arity + fun arity_of c = Inttab.lookup arity c + fun rep_str s n = concat (rep n s) + fun indexed s n = s^(str n) + fun section n = if n = 0 then [] else (section (n-1))@[n-1] + fun make_show c = + let + val args = section (the (arity_of c)) + in + " show ("^(indexed "C" c)^(concat (map (indexed " a") args))^") = " + ^"\""^(indexed "C" c)^"\""^(concat (map (fn a => "++(show "^(indexed "a" a)^")") args)) + end + fun default_case c = + let + val args = concat (map (indexed " x") (section (the (arity_of c)))) + in + (indexed "c" c)^args^" = "^(indexed "C" c)^args + end + val _ = writelist [ + "module "^name^" where", + "", + "data Term = Const Integer | App Term Term | Abs (Term -> Term)", + " "^(concat (map (fn c => " | C"^(str c)^(rep_str " Term" (the (arity_of c)))) constants)), + "", + "instance Show Term where"] + val _ = writelist (map make_show constants) + val _ = writelist [ + " show (Const c) = \"c\"++(show c)", + " show (App a b) = \"A\"++(show a)++(show b)", + " show (Abs _) = \"L\"", + ""] + val _ = writelist [ + "app (Abs a) b = a b", + "app a b = App a b", + "", + "calc s c = writeFile s (show c)", + ""] + fun list_group c = (writelist (case Inttab.lookup rules c of + NONE => [default_case c, ""] + | SOME (rs as ((PConst (_, []), _)::rs')) => + if not (null rs') then raise Compile "multiple declaration of constant" + else (map (print_rule arity_of) rs) @ [""] + | SOME rs => (map (print_rule arity_of) rs) @ [default_case c, ""])) + val _ = map list_group constants + in + (arity, !buffer) + end + +val guid_counter = ref 0 +fun get_guid () = + let + val c = !guid_counter + val _ = guid_counter := !guid_counter + 1 + in + (LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c) + end + +fun tmp_file s = Path.implode (Path.expand (File.tmp_path (Path.make [s]))); +fun wrap s = "\""^s^"\"" + +fun writeTextFile name s = File.write (Path.explode name) s + +val ghc = ref (case getenv "GHC_PATH" of "" => "ghc" | s => s) + +fun fileExists name = ((OS.FileSys.fileSize name; true) handle OS.SysErr _ => false) + +fun compile eqs = + let + val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () + val eqs = map (fn (a,b,c) => (b,c)) eqs + val guid = get_guid () + val module = "AMGHC_Prog_"^guid + val (arity, source) = haskell_prog module eqs + val module_file = tmp_file (module^".hs") + val object_file = tmp_file (module^".o") + val _ = writeTextFile module_file source + val _ = system ((!ghc)^" -c "^module_file) + val _ = if not (fileExists object_file) then raise Compile ("Failure compiling haskell code (GHC_PATH = '"^(!ghc)^"')") else () + in + (guid, module_file, arity) + end + +fun readResultFile name = File.read (Path.explode name) + +fun parse_result arity_of result = + let + val result = String.explode result + fun shift NONE x = SOME x + | shift (SOME y) x = SOME (y*10 + x) + fun parse_int' x (#"0"::rest) = parse_int' (shift x 0) rest + | parse_int' x (#"1"::rest) = parse_int' (shift x 1) rest + | parse_int' x (#"2"::rest) = parse_int' (shift x 2) rest + | parse_int' x (#"3"::rest) = parse_int' (shift x 3) rest + | parse_int' x (#"4"::rest) = parse_int' (shift x 4) rest + | parse_int' x (#"5"::rest) = parse_int' (shift x 5) rest + | parse_int' x (#"6"::rest) = parse_int' (shift x 6) rest + | parse_int' x (#"7"::rest) = parse_int' (shift x 7) rest + | parse_int' x (#"8"::rest) = parse_int' (shift x 8) rest + | parse_int' x (#"9"::rest) = parse_int' (shift x 9) rest + | parse_int' x rest = (x, rest) + fun parse_int rest = parse_int' NONE rest + + fun parse (#"C"::rest) = + (case parse_int rest of + (SOME c, rest) => + let + val (args, rest) = parse_list (the (arity_of c)) rest + fun app_args [] t = t + | app_args (x::xs) t = app_args xs (App (t, x)) + in + (app_args args (Const c), rest) + end + | (NONE, rest) => raise Run "parse C") + | parse (#"c"::rest) = + (case parse_int rest of + (SOME c, rest) => (Const c, rest) + | _ => raise Run "parse c") + | parse (#"A"::rest) = + let + val (a, rest) = parse rest + val (b, rest) = parse rest + in + (App (a,b), rest) + end + | parse (#"L"::rest) = raise Run "there may be no abstraction in the result" + | parse _ = raise Run "invalid result" + and parse_list n rest = + if n = 0 then + ([], rest) + else + let + val (x, rest) = parse rest + val (xs, rest) = parse_list (n-1) rest + in + (x::xs, rest) + end + val (parsed, rest) = parse result + fun is_blank (#" "::rest) = is_blank rest + | is_blank (#"\n"::rest) = is_blank rest + | is_blank [] = true + | is_blank _ = false + in + if is_blank rest then parsed else raise Run "non-blank suffix in result file" + end + +fun run (guid, module_file, arity) t = + let + val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") + fun arity_of c = Inttab.lookup arity c + val callguid = get_guid() + val module = "AMGHC_Prog_"^guid + val call = module^"_Call_"^callguid + val result_file = tmp_file (module^"_Result_"^callguid^".txt") + val call_file = tmp_file (call^".hs") + val term = print_term arity_of 0 t + val call_source = "module "^call^" where\n\nimport "^module^"\n\ncall = "^module^".calc \""^result_file^"\" ("^term^")" + val _ = writeTextFile call_file call_source + val _ = system ((!ghc)^" -e \""^call^".call\" "^module_file^" "^call_file) + val result = readResultFile result_file handle IO.Io _ => raise Run ("Failure running haskell compiler (GHC_PATH = '"^(!ghc)^"')") + val t' = parse_result arity_of result + val _ = OS.FileSys.remove call_file + val _ = OS.FileSys.remove result_file + in + t' + end + + +fun discard _ = () + +end + diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/am_interpreter.ML --- a/src/Tools/Compute_Oracle/am_interpreter.ML Mon Jul 09 11:44:23 2007 +0200 +++ b/src/Tools/Compute_Oracle/am_interpreter.ML Mon Jul 09 17:36:25 2007 +0200 @@ -1,32 +1,13 @@ -(* Title: Tools/Compute_Oracle/am_interpreter.ML +(* Title: Pure/Tools/am_interpreter.ML ID: $Id$ Author: Steven Obua *) -signature ABSTRACT_MACHINE = -sig - -datatype term = Var of int | Const of int | App of term * term | Abs of term - -datatype pattern = PVar | PConst of int * (pattern list) - -type program - -exception Compile of string; -val compile : (pattern * term) list -> program - -exception Run of string; -val run : program -> term -> term - -end - structure AM_Interpreter : ABSTRACT_MACHINE = struct -datatype term = Var of int | Const of int | App of term * term | Abs of term +open AbstractMachine; -datatype pattern = PVar | PConst of int * (pattern list) - -datatype closure = CVar of int | CConst of int +datatype closure = CDummy | CVar of int | CConst of int | CApp of closure * closure | CAbs of closure | Closure of (closure list) * closure @@ -36,9 +17,6 @@ datatype stack = SEmpty | SAppL of closure * stack | SAppR of closure * stack | SAbs of stack -exception Compile of string; -exception Run of string; - fun clos_of_term (Var x) = CVar x | clos_of_term (Const c) = CConst c | clos_of_term (App (u, v)) = CApp (clos_of_term u, clos_of_term v) @@ -49,6 +27,7 @@ | term_of_clos (CApp (u, v)) = App (term_of_clos u, term_of_clos v) | term_of_clos (CAbs u) = Abs (term_of_clos u) | term_of_clos (Closure (e, u)) = raise (Run "internal error: closure in normalized term found") + | term_of_clos CDummy = raise (Run "internal error: dummy in normalized term found") fun strip_closure args (CApp (a,b)) = strip_closure (b::args) a | strip_closure args x = (x, args) @@ -78,24 +57,29 @@ | SOME args => pattern_match_list args ps cs) | pattern_match_list _ _ _ = NONE -(* Returns true iff at most 0 .. (free-1) occur unbound. therefore check_freevars 0 t iff t is closed *) -fun check_freevars free (Var x) = x < free - | check_freevars free (Const c) = true - | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v - | check_freevars free (Abs m) = check_freevars (free+1) m - fun count_patternvars PVar = 1 | count_patternvars (PConst (_, ps)) = List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps fun pattern_key (PConst (c, ps)) = (c, length ps) | pattern_key _ = raise (Compile "pattern reduces to variable") +(*Returns true iff at most 0 .. (free-1) occur unbound. therefore + check_freevars 0 t iff t is closed*) +fun check_freevars free (Var x) = x < free + | check_freevars free (Const c) = true + | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v + | check_freevars free (Abs m) = check_freevars (free+1) m + fun compile eqs = let - val eqs = map (fn (p, r) => (check_freevars (count_patternvars p) r; - (pattern_key p, (p, clos_of_term r)))) eqs + val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () + val eqs = map (fn (a,b,c) => (b,c)) eqs + fun check (p, r) = if check_freevars (count_patternvars p) r then () else raise Compile ("unbound variables in rule") + val eqs = map (fn (p, r) => (check (p,r); (pattern_key p, (p, clos_of_term r)))) eqs + fun merge (k, a) table = prog_struct.update (k, case prog_struct.lookup table k of NONE => [a] | SOME l => a::l) table + val p = fold merge eqs prog_struct.empty in - Program (prog_struct.make (map (fn (k, a) => (k, [a])) eqs)) + Program p end fun match_rules n [] clos = NONE @@ -112,51 +96,66 @@ | SOME rules => match_rules 0 rules clos) | _ => NONE -fun lift n (c as (CConst _)) = c - | lift n (v as CVar m) = if m < n then v else CVar (m+1) - | lift n (CAbs t) = CAbs (lift (n+1) t) - | lift n (CApp (a,b)) = CApp (lift n a, lift n b) - | lift n (Closure (e, a)) = Closure (lift_env n e, lift (n+(length e)) a) -and lift_env n e = map (lift n) e + +type state = bool * program * stack * closure + +datatype loopstate = Continue of state | Stop of stack * closure + +fun proj_C (Continue s) = s + | proj_C _ = raise Match + +fun proj_S (Stop s) = s + | proj_S _ = raise Match + +fun cont (Continue _) = true + | cont _ = false -fun weak prog stack (Closure (e, CApp (a, b))) = weak prog (SAppL (Closure (e, b), stack)) (Closure (e, a)) - | weak prog (SAppL (b, stack)) (Closure (e, CAbs m)) = weak prog stack (Closure (b::e, m)) - | weak prog stack (Closure (e, CVar n)) = weak prog stack (List.nth (e, n) handle Subscript => (CVar (n-(length e)))) - | weak prog stack (Closure (e, c as CConst _)) = weak prog stack c - | weak prog stack clos = - case match_closure prog clos of - NONE => weak_last prog stack clos - | SOME r => weak prog stack r -and weak_last prog (SAppR (a, stack)) b = weak prog stack (CApp (a,b)) - | weak_last prog (s as (SAppL (b, stack))) a = weak prog (SAppR (a, stack)) b - | weak_last prog stack c = (stack, c) +fun do_reduction reduce p = + let + val s = ref (Continue p) + val _ = while cont (!s) do (s := reduce (proj_C (!s))) + in + proj_S (!s) + end -fun strong prog stack (Closure (e, CAbs m)) = +fun weak_reduce (false, prog, stack, Closure (e, CApp (a, b))) = Continue (false, prog, SAppL (Closure (e, b), stack), Closure (e, a)) + | weak_reduce (false, prog, SAppL (b, stack), Closure (e, CAbs m)) = Continue (false, prog, stack, Closure (b::e, m)) + | weak_reduce (false, prog, stack, Closure (e, CVar n)) = Continue (false, prog, stack, case List.nth (e, n) of CDummy => CVar n | r => r) + | weak_reduce (false, prog, stack, Closure (e, c as CConst _)) = Continue (false, prog, stack, c) + | weak_reduce (false, prog, stack, clos) = + (case match_closure prog clos of + NONE => Continue (true, prog, stack, clos) + | SOME r => Continue (false, prog, stack, r)) + | weak_reduce (true, prog, SAppR (a, stack), b) = Continue (false, prog, stack, CApp (a,b)) + | weak_reduce (true, prog, s as (SAppL (b, stack)), a) = Continue (false, prog, SAppR (a, stack), b) + | weak_reduce (true, prog, stack, c) = Stop (stack, c) + +fun strong_reduce (false, prog, stack, Closure (e, CAbs m)) = let - val (stack', wnf) = weak prog SEmpty (Closure ((CVar 0)::(lift_env 0 e), m)) + val (stack', wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure (CDummy::e, m)) in case stack' of - SEmpty => strong prog (SAbs stack) wnf + SEmpty => Continue (false, prog, SAbs stack, wnf) | _ => raise (Run "internal error in strong: weak failed") end - | strong prog stack (clos as (CApp (u, v))) = strong prog (SAppL (v, stack)) u - | strong prog stack clos = strong_last prog stack clos -and strong_last prog (SAbs stack) m = strong prog stack (CAbs m) - | strong_last prog (SAppL (b, stack)) a = strong prog (SAppR (a, stack)) b - | strong_last prog (SAppR (a, stack)) b = strong_last prog stack (CApp (a, b)) - | strong_last prog stack clos = (stack, clos) + | strong_reduce (false, prog, stack, clos as (CApp (u, v))) = Continue (false, prog, SAppL (v, stack), u) + | strong_reduce (false, prog, stack, clos) = Continue (true, prog, stack, clos) + | strong_reduce (true, prog, SAbs stack, m) = Continue (false, prog, stack, CAbs m) + | strong_reduce (true, prog, SAppL (b, stack), a) = Continue (false, prog, SAppR (a, stack), b) + | strong_reduce (true, prog, SAppR (a, stack), b) = Continue (true, prog, stack, CApp (a, b)) + | strong_reduce (true, prog, stack, clos) = Stop (stack, clos) fun run prog t = let - val (stack, wnf) = weak prog SEmpty (Closure ([], clos_of_term t)) + val (stack, wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure ([], clos_of_term t)) in case stack of - SEmpty => (case strong prog SEmpty wnf of + SEmpty => (case do_reduction strong_reduce (false, prog, SEmpty, wnf) of (SEmpty, snf) => term_of_clos snf | _ => raise (Run "internal error in run: strong failed")) | _ => raise (Run "internal error in run: weak failed") end -end +fun discard p = () -structure AbstractMachine = AM_Interpreter +end diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/am_sml.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Compute_Oracle/am_sml.ML Mon Jul 09 17:36:25 2007 +0200 @@ -0,0 +1,530 @@ +(* Title: Pure/Tools/am_sml.ML + ID: $Id$ + Author: Steven Obua + + ToDO: "parameterless rewrite cannot be used in pattern": In a lot of cases it CAN be used, and these cases should be handled properly; + right now, all cases throw an exception. + +*) + +signature AM_SML = +sig + include ABSTRACT_MACHINE + val save_result : (string * term) -> unit + val set_compiled_rewriter : (term -> term) -> unit + val list_nth : 'a list * int -> 'a +end + +structure AM_SML : AM_SML = struct + +open AbstractMachine; + +type program = string * string * (int Inttab.table) * (int Inttab.table) * (term Inttab.table) * (term -> term) + +val saved_result = ref (NONE:(string*term)option) + +fun save_result r = (saved_result := SOME r) +fun clear_result () = (saved_result := NONE) + +val list_nth = List.nth + +(*fun list_nth (l,n) = (writeln (makestring ("list_nth", (length l,n))); List.nth (l,n))*) + +val compiled_rewriter = ref (NONE:(term -> term)Option.option) + +fun set_compiled_rewriter r = (compiled_rewriter := SOME r) + +fun importable (Var _) = false + | importable (Const _) = true + | importable (App (a, b)) = importable a andalso importable b + | importable (Abs _) = false + +(*Returns true iff at most 0 .. (free-1) occur unbound. therefore + check_freevars 0 t iff t is closed*) +fun check_freevars free (Var x) = x < free + | check_freevars free (Const c) = true + | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v + | check_freevars free (Abs m) = check_freevars (free+1) m + +fun count_patternvars PVar = 1 + | count_patternvars (PConst (_, ps)) = + List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps + +fun update_arity arity code a = + (case Inttab.lookup arity code of + NONE => Inttab.update_new (code, a) arity + | SOME a' => if a > a' then Inttab.update (code, a) arity else arity) + +(* We have to find out the maximal arity of each constant *) +fun collect_pattern_arity PVar arity = arity + | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args)) + +(* We also need to find out the maximal toplevel arity of each function constant *) +fun collect_pattern_toplevel_arity PVar arity = raise Compile "internal error: collect_pattern_toplevel_arity" + | collect_pattern_toplevel_arity (PConst (c, args)) arity = update_arity arity c (length args) + +local +fun collect applevel (Var _) arity = arity + | collect applevel (Const c) arity = update_arity arity c applevel + | collect applevel (Abs m) arity = collect 0 m arity + | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity) +in +fun collect_term_arity t arity = collect 0 t arity +end + +fun collect_guard_arity (Guard (a,b)) arity = collect_term_arity b (collect_term_arity a arity) + + +fun rep n x = if n < 0 then raise Compile "internal error: rep" else if n = 0 then [] else x::(rep (n-1) x) + +fun beta (Const c) = Const c + | beta (Var i) = Var i + | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b))) + | beta (App (a, b)) = + (case beta a of + Abs m => beta (App (Abs m, b)) + | a => App (a, beta b)) + | beta (Abs m) = Abs (beta m) +and subst x (Const c) t = Const c + | subst x (Var i) t = if i = x then t else Var i + | subst x (App (a,b)) t = App (subst x a t, subst x b t) + | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t)) +and lift level (Const c) = Const c + | lift level (App (a,b)) = App (lift level a, lift level b) + | lift level (Var i) = if i < level then Var i else Var (i+1) + | lift level (Abs m) = Abs (lift (level + 1) m) +and unlift level (Const c) = Const c + | unlift level (App (a, b)) = App (unlift level a, unlift level b) + | unlift level (Abs m) = Abs (unlift (level+1) m) + | unlift level (Var i) = if i < level then Var i else Var (i-1) + +fun nlift level n (Var m) = if m < level then Var m else Var (m+n) + | nlift level n (Const c) = Const c + | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b) + | nlift level n (Abs b) = Abs (nlift (level+1) n b) + +fun subst_const (c, t) (Const c') = if c = c' then t else Const c' + | subst_const _ (Var i) = Var i + | subst_const ct (App (a, b)) = App (subst_const ct a, subst_const ct b) + | subst_const ct (Abs m) = Abs (subst_const ct m) + +(* Remove all rules that are just parameterless rewrites. This is necessary because SML does not allow functions with no parameters. *) +fun inline_rules rules = + let + fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b + | term_contains_const c (Abs m) = term_contains_const c m + | term_contains_const c (Var i) = false + | term_contains_const c (Const c') = (c = c') + fun find_rewrite [] = NONE + | find_rewrite ((prems, PConst (c, []), r) :: _) = + if check_freevars 0 r then + if term_contains_const c r then + raise Compile "parameterless rewrite is caught in cycle" + else if not (null prems) then + raise Compile "parameterless rewrite may not be guarded" + else + SOME (c, r) + else raise Compile "unbound variable on right hand side or guards of rule" + | find_rewrite (_ :: rules) = find_rewrite rules + fun remove_rewrite (c,r) [] = [] + | remove_rewrite (cr as (c,r)) ((rule as (prems', PConst (c', args), r'))::rules) = + (if c = c' then + if null args andalso r = r' andalso null (prems') then + remove_rewrite cr rules + else raise Compile "incompatible parameterless rewrites found" + else + rule :: (remove_rewrite cr rules)) + | remove_rewrite cr (r::rs) = r::(remove_rewrite cr rs) + fun pattern_contains_const c (PConst (c', args)) = (c = c' orelse exists (pattern_contains_const c) args) + | pattern_contains_const c (PVar) = false + fun inline_rewrite (ct as (c, _)) (prems, p, r) = + if pattern_contains_const c p then + raise Compile "parameterless rewrite cannot be used in pattern" + else (map (fn (Guard (a,b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r) + fun inline inlined rules = + (case find_rewrite rules of + NONE => (Inttab.make inlined, rules) + | SOME ct => + let + val rules = map (inline_rewrite ct) (remove_rewrite ct rules) + val inlined = ct :: (map (fn (c', r) => (c', subst_const ct r)) inlined) + in + inline inlined rules + end) + in + inline [] rules + end + + +(* + Calculate the arity, the toplevel_arity, and adjust rules so that all toplevel pattern constants have maximal arity. + Also beta reduce the adjusted right hand side of a rule. +*) +fun adjust_rules rules = + let + val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty + val toplevel_arity = fold (fn (_, p, t) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty + fun arity_of c = the (Inttab.lookup arity c) + fun toplevel_arity_of c = the (Inttab.lookup toplevel_arity c) + fun adjust_pattern PVar = PVar + | adjust_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else C + fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable") + | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters") + | adjust_rule (rule as (prems, p as PConst (c, args),t)) = + let + val patternvars_counted = count_patternvars p + fun check_fv t = check_freevars patternvars_counted t + val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else () + val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else () + val args = map adjust_pattern args + val len = length args + val arity = arity_of c + val lift = nlift 0 + fun adjust_tm n t = if n=0 then t else adjust_tm (n-1) (App (t, Var (n-1))) + fun adjust_term n t = adjust_tm n (lift n t) + fun adjust_guard n (Guard (a,b)) = Guard (adjust_term n a, adjust_term n b) + in + if len = arity then + rule + else if arity >= len then + (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t) + else (raise Compile "internal error in adjust_rule") + end + fun beta_guard (Guard (a,b)) = Guard (beta a, beta b) + fun beta_rule (prems, p, t) = ((map beta_guard prems, p, beta t) handle Match => raise Compile "beta_rule") + in + (arity, toplevel_arity, map (beta_rule o adjust_rule) rules) + end + +fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count = +let + fun str x = string_of_int x + fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s + val module_prefix = (case module of NONE => "" | SOME s => s^".") + fun print_apps d f [] = f + | print_apps d f (a::args) = print_apps d (module_prefix^"app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args + and print_call d (App (a, b)) args = print_call d a (b::args) + | print_call d (Const c) args = + (case arity_of c of + NONE => print_apps d (module_prefix^"Const "^(str c)) args + | SOME 0 => module_prefix^"C"^(str c) + | SOME a => + let + val len = length args + in + if a <= len then + let + val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a) + val _ = if strict_a > a then raise Compile "strict" else () + val s = module_prefix^"c"^(str c)^(concat (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, strict_a)))) + val s = s^(concat (map (fn t => " (fn () => "^print_term d t^")") (List.drop (List.take (args, a), strict_a)))) + in + print_apps d s (List.drop (args, a)) + end + else + let + fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n - 1))) + fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t) + fun append_args [] t = t + | append_args (c::cs) t = append_args cs (App (t, c)) + in + print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c))))) + end + end) + | print_call d t args = print_apps d (print_term d t) args + and print_term d (Var x) = + if x < d then + "b"^(str (d-x-1)) + else + let + val n = pattern_var_count - (x-d) - 1 + val x = "x"^(str n) + in + if n < pattern_var_count - pattern_lazy_var_count then + x + else + "("^x^" ())" + end + | print_term d (Abs c) = module_prefix^"Abs (fn b"^(str d)^" => "^(print_term (d + 1) c)^")" + | print_term d t = print_call d t [] +in + print_term 0 +end + +fun section n = if n = 0 then [] else (section (n-1))@[n-1] + +fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) = + let + fun str x = Int.toString x + fun print_pattern top n PVar = (n+1, "x"^(str n)) + | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")) + | print_pattern top n (PConst (c, args)) = + let + val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "") + val (n, s) = print_pattern_list 0 top (n, f) args + in + (n, s) + end + and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")") + | print_pattern_list' counter top (n, p) (t::ts) = + let + val (n, t) = print_pattern false n t + in + print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts + end + and print_pattern_list counter top (n, p) (t::ts) = + let + val (n, t) = print_pattern false n t + in + print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts + end + val c = (case p of PConst (c, _) => c | _ => raise Match) + val (n, pattern) = print_pattern true 0 p + val lazy_vars = the (arity_of c) - the (toplevel_arity_of c) + fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm + fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")" + val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(concat (map (fn i => " a"^(str i)) (section (the (arity_of c))))) + fun print_guards t [] = print_tm t + | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(concat (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch + in + (if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards)) + end + +fun group_rules rules = + let + fun add_rule (r as (_, PConst (c,_), _)) groups = + let + val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs) + in + Inttab.update (c, r::rs) groups + end + | add_rule _ _ = raise Compile "internal error group_rules" + in + fold_rev add_rule rules Inttab.empty + end + +fun sml_prog name code rules = + let + val buffer = ref "" + fun write s = (buffer := (!buffer)^s) + fun writeln s = (write s; write "\n") + fun writelist [] = () + | writelist (s::ss) = (writeln s; writelist ss) + fun str i = Int.toString i + val (inlinetab, rules) = inline_rules rules + val (arity, toplevel_arity, rules) = adjust_rules rules + val rules = group_rules rules + val constants = Inttab.keys arity + fun arity_of c = Inttab.lookup arity c + fun toplevel_arity_of c = Inttab.lookup toplevel_arity c + fun rep_str s n = concat (rep n s) + fun indexed s n = s^(str n) + fun string_of_tuple [] = "" + | string_of_tuple (x::xs) = "("^x^(concat (map (fn s => ", "^s) xs))^")" + fun string_of_args [] = "" + | string_of_args (x::xs) = x^(concat (map (fn s => " "^s) xs)) + fun default_case gnum c = + let + val leftargs = concat (map (indexed " x") (section (the (arity_of c)))) + val rightargs = section (the (arity_of c)) + val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa) + val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs + val right = (indexed "C" c)^" "^(string_of_tuple xs) + val debug_lazy = "(print x"^(string_of_int (strict_args - 1))^";" + val right = if strict_args < the (arity_of c) then debug_lazy^"raise AM_SML.Run \"unresolved lazy call: "^(string_of_int c)^"\")" else right + in + (indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right + end + + fun eval_rules c = + let + val arity = the (arity_of c) + val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa) + fun eval_rule n = + let + val sc = string_of_int c + val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc) + fun arg i = + let + val x = indexed "x" i + val x = if i < n then "(eval bounds "^x^")" else x + val x = if i < strict_arity then x else "(fn () => "^x^")" + in + x + end + val right = "c"^sc^" "^(string_of_args (map arg (section arity))) + val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right + val right = if arity > 0 then right else "C"^sc + in + " | eval bounds ("^left^") = "^right + end + in + map eval_rule (rev (section (arity + 1))) + end + + fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else "" + val _ = writelist [ + "structure "^name^" = struct", + "", + "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)", + " "^(concat (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)), + ""] + fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")" + fun make_term_eq c = " | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^ + (case the (arity_of c) of + 0 => "true" + | n => + let + val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n) + val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs)) + in + eq^(concat eqs) + end) + val _ = writelist [ + "fun term_eq (Const c1) (Const c2) = (c1 = c2)", + " | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"] + val _ = writelist (map make_term_eq constants) + val _ = writelist [ + " | term_eq _ _ = false", + "" + ] + val _ = writelist [ + "fun app (Abs a) b = a b", + " | app a b = App (a, b)", + ""] + fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else []) + fun writefundecl [] = () + | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => " | "^s) xs))) + fun list_group c = (case Inttab.lookup rules c of + NONE => [defcase 0 c] + | SOME rs => + let + val rs = + fold + (fn r => + fn rs => + let + val (gnum, l, rs) = + (case rs of + [] => (0, [], []) + | (gnum, l)::rs => (gnum, l, rs)) + val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r + in + if gnum' = gnum then + (gnum, r::l)::rs + else + let + val args = concat (map (fn i => " a"^(str i)) (section (the (arity_of c)))) + fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args + val s = gnumc (gnum) ^ " = " ^ gnumc (gnum') + in + (gnum', [])::(gnum, s::r::l)::rs + end + end) + rs [] + val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs) + in + rev (map (fn z => rev (snd z)) rs) + end) + val _ = map (fn z => (map writefundecl z; writeln "")) (map list_group constants) + val _ = writelist [ + "fun convert (Const i) = AM_SML.Const i", + " | convert (App (a, b)) = AM_SML.App (convert a, convert b)", + " | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""] + fun make_convert c = + let + val args = map (indexed "a") (section (the (arity_of c))) + val leftargs = + case args of + [] => "" + | (x::xs) => "("^x^(concat (map (fn s => ", "^s) xs))^")" + val args = map (indexed "convert a") (section (the (arity_of c))) + val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c)) + in + " | convert (C"^(str c)^" "^leftargs^") = "^right + end + val _ = writelist (map make_convert constants) + val _ = writelist [ + "", + "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)", + " | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"] + val _ = map (writelist o eval_rules) constants + val _ = writelist [ + " | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)", + " | eval bounds (AbstractMachine.Const c) = Const c"] + val _ = writelist [ + "", + "fun export term = AM_SML.save_result (\""^code^"\", convert term)", + "", + "val _ = AM_SML.set_compiled_rewriter (fn t => convert (eval [] t))", + "", + "end"] + in + (arity, toplevel_arity, inlinetab, !buffer) + end + +val guid_counter = ref 0 +fun get_guid () = + let + val c = !guid_counter + val _ = guid_counter := !guid_counter + 1 + in + (LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c) + end + + +fun writeTextFile name s = File.write (Path.explode name) s + +fun use_source src = use_text "" Output.ml_output false src + +fun compile eqs = + let + val guid = get_guid () + val code = Real.toString (random ()) + val module = "AMSML_"^guid + val (arity, toplevel_arity, inlinetab, source) = sml_prog module code eqs +(* val _ = writeTextFile "Gencode.ML" source*) + val _ = compiled_rewriter := NONE + val _ = use_source source + in + case !compiled_rewriter of + NONE => raise Compile "broken link to compiled function" + | SOME f => (module, code, arity, toplevel_arity, inlinetab, f) + end + + +fun run' (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t = + let + val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") + fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t) + | inline (Var i) = Var i + | inline (App (a, b)) = App (inline a, inline b) + | inline (Abs m) = Abs (inline m) + val t = beta (inline t) + fun arity_of c = Inttab.lookup arity c + fun toplevel_arity_of c = Inttab.lookup toplevel_arity c + val term = print_term NONE arity_of toplevel_arity_of 0 0 t + val source = "local open "^module^" in val _ = export ("^term^") end" + val _ = writeTextFile "Gencode_call.ML" source + val _ = clear_result () + val _ = use_source source + in + case !saved_result of + NONE => raise Run "broken link to compiled code" + | SOME (code', t) => (clear_result (); if code' = code then t else raise Run "link to compiled code was hijacked") + end + +fun run (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t = + let + val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") + fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t) + | inline (Var i) = Var i + | inline (App (a, b)) = App (inline a, inline b) + | inline (Abs m) = Abs (inline m) + in + compiled_fun (beta (inline t)) + end + +fun discard p = () + +end diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/am_util.ML --- a/src/Tools/Compute_Oracle/am_util.ML Mon Jul 09 11:44:23 2007 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,172 +0,0 @@ -(* Title: Tools/Compute_Oracle/am_util.ML - ID: $Id$ - Author: Steven Obua -*) - -signature AM_UTIL = sig - - type naming (* = string -> int *) - - exception Parse of string - exception Tokenize - - (* takes a naming for the constants *) - val read_rule : naming -> string -> AbstractMachine.pattern * AbstractMachine.term - - (* takes a naming for the constants and one for the free variables *) - val read_term : naming -> naming -> string -> AbstractMachine.term - - val term_ord : AbstractMachine.term * AbstractMachine.term -> order - -end - -structure AM_Util : AM_UTIL = -struct - -fun term_ord (AbstractMachine.Var x, AbstractMachine.Var y) = int_ord (x,y) - | term_ord (AbstractMachine.Const c1, AbstractMachine.Const c2) = int_ord (c1, c2) - | term_ord (AbstractMachine.App a1, AbstractMachine.App a2) = - prod_ord term_ord term_ord (a1, a2) - | term_ord (AbstractMachine.Abs m1, AbstractMachine.Abs m2) = term_ord (m1, m2) - | term_ord (AbstractMachine.Const _, _) = LESS - | term_ord (AbstractMachine.Var _, AbstractMachine.Const _ ) = GREATER - | term_ord (AbstractMachine.Var _, _) = LESS - | term_ord (AbstractMachine.App _, AbstractMachine.Abs _) = LESS - | term_ord (AbstractMachine.App _, _) = GREATER - | term_ord (AbstractMachine.Abs _, _) = LESS - -type naming = string -> int - -datatype token = - TokenConst of string | TokenLeft | TokenRight | TokenVar of string | - TokenLambda | TokenDot | TokenNone | TokenEq - -exception Tokenize; - -fun tokenize s = - let - fun is_lower c = "a" <= c andalso c <= "z"; - val is_alphanum = Symbol.is_ascii_letter orf Symbol.is_ascii_digit; - fun tz TokenNone [] = [] - | tz x [] = [x] - | tz TokenNone (c::cs) = - if Symbol.is_ascii_blank c then tz TokenNone cs - else if is_lower c then (tz (TokenVar c) cs) - else if is_alphanum c then (tz (TokenConst c) cs) - else if c = "%" then (TokenLambda :: (tz TokenNone cs)) - else if c = "." then (TokenDot :: (tz TokenNone cs)) - else if c = "(" then (TokenLeft :: (tz TokenNone cs)) - else if c = ")" then (TokenRight :: (tz TokenNone cs)) - else if c = "=" then (TokenEq :: (tz TokenNone cs)) - else raise Tokenize - | tz (TokenConst s) (c::cs) = - if is_alphanum c then (tz (TokenConst (s ^ c)) cs) - else (TokenConst s)::(tz TokenNone (c::cs)) - | tz (TokenVar s) (c::cs) = - if is_alphanum c then (tz (TokenVar (s ^ c)) cs) - else (TokenVar s)::(tz TokenNone (c::cs)) - | tz _ _ = raise Tokenize - in tz TokenNone (explode s) end - -exception Parse of string; - -fun cons x xs = - if List.exists (fn y => x = y) xs then raise (Parse ("variable occurs twice: "^x)) - else (x::xs) - -fun parse_pattern f pvars ((TokenConst c)::ts) = - let - val (pvars, ts, plist) = parse_pattern_list f pvars ts - in - (pvars, ts, AbstractMachine.PConst (f c, plist)) - end - | parse_pattern _ _ _ = raise (Parse "parse_pattern: constant expected") -and parse_pattern_single f pvars ((TokenVar x)::ts) = (cons x pvars, ts, AbstractMachine.PVar) - | parse_pattern_single f pvars ((TokenConst c)::ts) = (pvars, ts, AbstractMachine.PConst (f c, [])) - | parse_pattern_single f pvars (TokenLeft::ts) = - let - val (pvars, ts, p) = parse_pattern f pvars ts - in - case ts of - TokenRight::ts => (pvars, ts, p) - | _ => raise (Parse "parse_pattern_single: closing bracket expected") - end - | parse_pattern_single _ _ _ = raise (Parse "parse_pattern_single: got stuck") -and parse_pattern_list f pvars (TokenEq::ts) = (pvars, TokenEq::ts, []) - | parse_pattern_list f pvars (TokenRight::ts) = (pvars, TokenRight::ts, []) - | parse_pattern_list f pvars ts = - let - val (pvars, ts, p) = parse_pattern_single f pvars ts - val (pvars, ts, ps) = parse_pattern_list f pvars ts - in - (pvars, ts, p::ps) - end - -fun app_terms x (t::ts) = app_terms (AbstractMachine.App (x, t)) ts - | app_terms x [] = x - -fun parse_term_single f vars ((TokenConst c)::ts) = (ts, AbstractMachine.Const (f c)) - | parse_term_single f vars ((TokenVar v)::ts) = (ts, AbstractMachine.Var (vars v)) - | parse_term_single f vars (TokenLeft::ts) = - let - val (ts, term) = parse_term f vars ts - in - case ts of - TokenRight::ts => (ts, term) - | _ => raise Parse ("parse_term_single: closing bracket expected") - end - | parse_term_single f vars (TokenLambda::(TokenVar x)::TokenDot::ts) = - let - val (ts, term) = parse_term f (fn s => if s=x then 0 else (vars s)+1) ts - in - (ts, AbstractMachine.Abs term) - end - | parse_term_single _ _ _ = raise Parse ("parse_term_single: got stuck") -and parse_term_list f vars [] = ([], []) - | parse_term_list f vars (TokenRight::ts) = (TokenRight::ts, []) - | parse_term_list f vars ts = - let - val (ts, term) = parse_term_single f vars ts - val (ts, terms) = parse_term_list f vars ts - in - (ts, term::terms) - end -and parse_term f vars ts = - let - val (ts, terms) = parse_term_list f vars ts - in - case terms of - [] => raise (Parse "parse_term: no term found") - | (t::terms) => (ts, app_terms t terms) - end - -fun read_rule f s = - let - val t = tokenize s - val (v, ts, pattern) = parse_pattern f [] t - fun vars [] (x:string) = raise (Parse "read_rule.vars: variable not found") - | vars (v::vs) x = if v = x then 0 else (vars vs x)+1 - in - case ts of - TokenEq::ts => - let - val (ts, term) = parse_term f (vars v) ts - in - case ts of - [] => (pattern, term) - | _ => raise (Parse "read_rule: still tokens left, end expected") - end - | _ => raise (Parse ("read_rule: = expected")) - end - -fun read_term f g s = - let - val t = tokenize s - val (ts, term) = parse_term f g t - in - case ts of - [] => term - | _ => raise (Parse ("read_term: still tokens left, end expected")) - end - -end diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/compute.ML --- a/src/Tools/Compute_Oracle/compute.ML Mon Jul 09 11:44:23 2007 +0200 +++ b/src/Tools/Compute_Oracle/compute.ML Mon Jul 09 17:36:25 2007 +0200 @@ -1,4 +1,4 @@ -(* Title: Tools/Compute_Oracle/compute.ML +(* Title: Pure/Tools/compute.ML ID: $Id$ Author: Steven Obua *) @@ -7,272 +7,356 @@ type computer - exception Make of string + datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML - val basic_make : theory -> thm list -> computer - val make : theory -> thm list -> computer + exception Make of string + val make : machine -> theory -> thm list -> computer + exception Compute of string val compute : computer -> (int -> string) -> cterm -> term val theory_of : computer -> theory + val hyps_of : computer -> term list + val shyps_of : computer -> sort list - val default_naming: int -> string - val oracle_fn: theory -> computer * (int -> string) * cterm -> term + val rewrite_param : computer -> (int -> string) -> cterm -> thm + val rewrite : computer -> cterm -> thm + + val discard : computer -> unit + + val setup : theory -> theory + end -structure Compute: COMPUTE = struct +structure Compute :> COMPUTE = struct + +datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML + +(* Terms are mapped to integer codes *) +structure Encode :> +sig + type encoding + val empty : encoding + val insert : term -> encoding -> int * encoding + val lookup_code : term -> encoding -> int option + val lookup_term : int -> encoding -> term option + val remove_code : int -> encoding -> encoding + val remove_term : term -> encoding -> encoding + val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a +end += +struct + +type encoding = int * (int Termtab.table) * (term Inttab.table) + +val empty = (0, Termtab.empty, Inttab.empty) + +fun insert t (e as (count, term2int, int2term)) = + (case Termtab.lookup term2int t of + NONE => (count, (count+1, Termtab.update_new (t, count) term2int, Inttab.update_new (count, t) int2term)) + | SOME code => (code, e)) + +fun lookup_code t (_, term2int, _) = Termtab.lookup term2int t + +fun lookup_term c (_, _, int2term) = Inttab.lookup int2term c + +fun remove_code c (e as (count, term2int, int2term)) = + (case lookup_term c e of NONE => e | SOME t => (count, Termtab.delete t term2int, Inttab.delete c int2term)) + +fun remove_term t (e as (count, term2int, int2term)) = + (case lookup_code t e of NONE => e | SOME c => (count, Termtab.delete t term2int, Inttab.delete c int2term)) + +fun fold f (_, term2int, _) = Termtab.fold f term2int + +end + exception Make of string; - -fun is_mono_typ (Type (_, list)) = forall is_mono_typ list - | is_mono_typ _ = false - -fun is_mono_term (Const (_, t)) = is_mono_typ t - | is_mono_term (Var (_, t)) = is_mono_typ t - | is_mono_term (Free (_, t)) = is_mono_typ t - | is_mono_term (Bound _) = true - | is_mono_term (a $ b) = is_mono_term a andalso is_mono_term b - | is_mono_term (Abs (_, ty, t)) = is_mono_typ ty andalso is_mono_term t - -structure AMTermTab = TableFun (type key = AbstractMachine.term val ord = AM_Util.term_ord) - -fun add x y = x + y : int; -fun inc x = x + 1; - -exception Mono of term; +exception Compute of string; -val remove_types = - let - fun remove_types_var table invtable ccount vcount ldepth t = - (case Termtab.lookup table t of - NONE => - let - val a = AbstractMachine.Var vcount - in - (Termtab.update (t, a) table, - AMTermTab.update (a, t) invtable, - ccount, - inc vcount, - AbstractMachine.Var (add vcount ldepth)) - end - | SOME (AbstractMachine.Var v) => - (table, invtable, ccount, vcount, AbstractMachine.Var (add v ldepth)) - | SOME _ => sys_error "remove_types_var: lookup should be a var") - - fun remove_types_const table invtable ccount vcount ldepth t = - (case Termtab.lookup table t of - NONE => - let - val a = AbstractMachine.Const ccount - in - (Termtab.update (t, a) table, - AMTermTab.update (a, t) invtable, - inc ccount, - vcount, - a) - end - | SOME (c as AbstractMachine.Const _) => - (table, invtable, ccount, vcount, c) - | SOME _ => sys_error "remove_types_const: lookup should be a const") +local + fun make_constant t ty encoding = + let + val (code, encoding) = Encode.insert t encoding + in + (encoding, AbstractMachine.Const code) + end +in - fun remove_types table invtable ccount vcount ldepth t = - case t of - Var (_, ty) => - if is_mono_typ ty then remove_types_var table invtable ccount vcount ldepth t - else raise (Mono t) - | Free (_, ty) => - if is_mono_typ ty then remove_types_var table invtable ccount vcount ldepth t - else raise (Mono t) - | Const (_, ty) => - if is_mono_typ ty then remove_types_const table invtable ccount vcount ldepth t - else raise (Mono t) - | Abs (_, ty, t') => - if is_mono_typ ty then - let - val (table, invtable, ccount, vcount, t') = - remove_types table invtable ccount vcount (inc ldepth) t' - in - (table, invtable, ccount, vcount, AbstractMachine.Abs t') - end - else - raise (Mono t) - | a $ b => - let - val (table, invtable, ccount, vcount, a) = - remove_types table invtable ccount vcount ldepth a - val (table, invtable, ccount, vcount, b) = - remove_types table invtable ccount vcount ldepth b - in - (table, invtable, ccount, vcount, AbstractMachine.App (a,b)) - end - | Bound b => (table, invtable, ccount, vcount, AbstractMachine.Var b) - in - fn (table, invtable, ccount, vcount) => remove_types table invtable ccount vcount 0 - end - -fun infer_types naming = +fun remove_types encoding t = + case t of + Var (_, ty) => make_constant t ty encoding + | Free (_, ty) => make_constant t ty encoding + | Const (_, ty) => make_constant t ty encoding + | Abs (_, ty, t') => + let val (encoding, t'') = remove_types encoding t' in + (encoding, AbstractMachine.Abs t'') + end + | a $ b => + let + val (encoding, a) = remove_types encoding a + val (encoding, b) = remove_types encoding b + in + (encoding, AbstractMachine.App (a,b)) + end + | Bound b => (encoding, AbstractMachine.Var b) +end + +local + fun type_of (Free (_, ty)) = ty + | type_of (Const (_, ty)) = ty + | type_of (Var (_, ty)) = ty + | type_of _ = sys_error "infer_types: type_of error" +in +fun infer_types naming encoding = let - fun infer_types invtable ldepth bounds ty (AbstractMachine.Var v) = - if v < ldepth then (Bound v, List.nth (bounds, v)) else - (case AMTermTab.lookup invtable (AbstractMachine.Var (v-ldepth)) of - SOME (t as Var (_, ty)) => (t, ty) - | SOME (t as Free (_, ty)) => (t, ty) - | _ => sys_error "infer_types: lookup should deliver Var or Free") - | infer_types invtable ldepth _ ty (c as AbstractMachine.Const _) = - (case AMTermTab.lookup invtable c of - SOME (c as Const (_, ty)) => (c, ty) - | _ => sys_error "infer_types: lookup should deliver Const") - | infer_types invtable ldepth bounds (n,ty) (AbstractMachine.App (a, b)) = - let - val (a, aty) = infer_types invtable ldepth bounds (n+1, ty) a - val (adom, arange) = + fun infer_types _ bounds _ (AbstractMachine.Var v) = (Bound v, List.nth (bounds, v)) + | infer_types _ bounds _ (AbstractMachine.Const code) = + let + val c = the (Encode.lookup_term code encoding) + in + (c, type_of c) + end + | infer_types level bounds _ (AbstractMachine.App (a, b)) = + let + val (a, aty) = infer_types level bounds NONE a + val (adom, arange) = case aty of Type ("fun", [dom, range]) => (dom, range) | _ => sys_error "infer_types: function type expected" - val (b, bty) = infer_types invtable ldepth bounds (0, adom) b - in - (a $ b, arange) - end - | infer_types invtable ldepth bounds (0, ty as Type ("fun", [dom, range])) - (AbstractMachine.Abs m) = + val (b, bty) = infer_types level bounds (SOME adom) b + in + (a $ b, arange) + end + | infer_types level bounds (SOME (ty as Type ("fun", [dom, range]))) (AbstractMachine.Abs m) = let - val (m, _) = infer_types invtable (ldepth+1) (dom::bounds) (0, range) m + val (m, _) = infer_types (level+1) (dom::bounds) (SOME range) m in - (Abs (naming ldepth, dom, m), ty) + (Abs (naming level, dom, m), ty) end - | infer_types invtable ldepth bounds ty (AbstractMachine.Abs m) = - sys_error "infer_types: cannot infer type of abstraction" + | infer_types _ _ NONE (AbstractMachine.Abs m) = sys_error "infer_types: cannot infer type of abstraction" - fun infer invtable ty term = + fun infer ty term = let - val (term', _) = infer_types invtable 0 [] (0, ty) term + val (term', _) = infer_types 0 [] (SOME ty) term in term' end in infer end +end -datatype computer = - Computer of theory_ref * - (AbstractMachine.term Termtab.table * term AMTermTab.table * int * AbstractMachine.program) +datatype prog = + ProgBarras of AM_Interpreter.program + | ProgBarrasC of AM_Compiler.program + | ProgHaskell of AM_GHC.program + | ProgSML of AM_SML.program -fun basic_make thy raw_ths = +structure Sorttab = TableFun(type key = sort val ord = Term.sort_ord) + +datatype computer = Computer of theory_ref * Encode.encoding * term list * unit Sorttab.table * prog + +datatype cthm = ComputeThm of term list * sort list * term + +fun thm2cthm th = let - val ths = map (Thm.transfer thy) raw_ths; + val {hyps, prop, tpairs, shyps, ...} = Thm.rep_thm th + val _ = if not (null tpairs) then raise Make "theorems may not contain tpairs" else () + in + ComputeThm (hyps, shyps, prop) + end - fun thm2rule table invtable ccount th = - let - val prop = Thm.plain_prop_of th - handle THM _ => raise (Make "theorems must be plain propositions") - val (a, b) = Logic.dest_equals prop - handle TERM _ => raise (Make "theorems must be meta-level equations") +fun make machine thy raw_ths = + let + fun transfer (x:thm) = Thm.transfer thy x + val ths = map (thm2cthm o Thm.strip_shyps o transfer) raw_ths - val (table, invtable, ccount, vcount, prop) = - remove_types (table, invtable, ccount, 0) (a$b) - handle Mono _ => raise (Make "no type variables allowed") - val (left, right) = - (case prop of AbstractMachine.App x => x | _ => - sys_error "make: remove_types should deliver application") + fun thm2rule (encoding, hyptable, shyptable) th = + let + val (ComputeThm (hyps, shyps, prop)) = th + val hyptable = fold (fn h => Termtab.update (h, ())) hyps hyptable + val shyptable = fold (fn sh => Sorttab.update (sh, ())) shyps shyptable + val (prems, prop) = (Logic.strip_imp_prems prop, Logic.strip_imp_concl prop) + val (a, b) = Logic.dest_equals prop + handle TERM _ => raise (Make "theorems must be meta-level equations (with optional guards)") + val a = Envir.eta_contract a + val b = Envir.eta_contract b + val prems = map Envir.eta_contract prems - fun make_pattern table invtable n vars (var as AbstractMachine.Var v) = + val (encoding, left) = remove_types encoding a + val (encoding, right) = remove_types encoding b + fun remove_types_of_guard encoding g = + (let + val (t1, t2) = Logic.dest_equals g + val (encoding, t1) = remove_types encoding t1 + val (encoding, t2) = remove_types encoding t2 + in + (encoding, AbstractMachine.Guard (t1, t2)) + end handle TERM _ => raise (Make "guards must be meta-level equations")) + val (encoding, prems) = fold_rev (fn p => fn (encoding, ps) => let val (e, p) = remove_types_of_guard encoding p in (e, p::ps) end) prems (encoding, []) + + fun make_pattern encoding n vars (var as AbstractMachine.Abs _) = + raise (Make "no lambda abstractions allowed in pattern") + | make_pattern encoding n vars (var as AbstractMachine.Var _) = + raise (Make "no bound variables allowed in pattern") + | make_pattern encoding n vars (AbstractMachine.Const code) = + (case the (Encode.lookup_term code encoding) of + Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar) + handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed")) + | _ => (n, vars, AbstractMachine.PConst (code, []))) + | make_pattern encoding n vars (AbstractMachine.App (a, b)) = let - val var' = the (AMTermTab.lookup invtable var) - val table = Termtab.delete var' table - val invtable = AMTermTab.delete var invtable - val vars = Inttab.update_new (v, n) vars handle Inttab.DUP _ => - raise (Make "no duplicate variable in pattern allowed") - in - (table, invtable, n+1, vars, AbstractMachine.PVar) - end - | make_pattern table invtable n vars (AbstractMachine.Abs _) = - raise (Make "no lambda abstractions allowed in pattern") - | make_pattern table invtable n vars (AbstractMachine.Const c) = - (table, invtable, n, vars, AbstractMachine.PConst (c, [])) - | make_pattern table invtable n vars (AbstractMachine.App (a, b)) = - let - val (table, invtable, n, vars, pa) = - make_pattern table invtable n vars a - val (table, invtable, n, vars, pb) = - make_pattern table invtable n vars b + val (n, vars, pa) = make_pattern encoding n vars a + val (n, vars, pb) = make_pattern encoding n vars b in case pa of AbstractMachine.PVar => raise (Make "patterns may not start with a variable") | AbstractMachine.PConst (c, args) => - (table, invtable, n, vars, AbstractMachine.PConst (c, args@[pb])) + (n, vars, AbstractMachine.PConst (c, args@[pb])) end - val (table, invtable, vcount, vars, pattern) = - make_pattern table invtable 0 Inttab.empty left + (* Principally, a check should be made here to see if the (meta-) hyps contain any of the variables of the rule. + As it is, all variables of the rule are schematic, and there are no schematic variables in meta-hyps, therefore + this check can be left out. *) + + val (vcount, vars, pattern) = make_pattern encoding 0 Inttab.empty left val _ = (case pattern of - AbstractMachine.PVar => + AbstractMachine.PVar => raise (Make "patterns may not start with a variable") - | _ => ()) - - (* at this point, there shouldn't be any variables - left in table or invtable, only constants *) + (* | AbstractMachine.PConst (_, []) => + (print th; raise (Make "no parameter rewrite found"))*) + | _ => ()) (* finally, provide a function for renaming the - pattern bound variables on the right hand side *) + pattern bound variables on the right hand side *) - fun rename ldepth vars (var as AbstractMachine.Var v) = - if v < ldepth then var - else (case Inttab.lookup vars (v - ldepth) of - NONE => raise (Make "new variable on right hand side") - | SOME n => AbstractMachine.Var ((vcount-n-1)+ldepth)) - | rename ldepth vars (c as AbstractMachine.Const _) = c - | rename ldepth vars (AbstractMachine.App (a, b)) = - AbstractMachine.App (rename ldepth vars a, rename ldepth vars b) - | rename ldepth vars (AbstractMachine.Abs m) = - AbstractMachine.Abs (rename (ldepth+1) vars m) - + fun rename level vars (var as AbstractMachine.Var _) = var + | rename level vars (c as AbstractMachine.Const code) = + (case Inttab.lookup vars code of + NONE => c + | SOME n => AbstractMachine.Var (vcount-n-1+level)) + | rename level vars (AbstractMachine.App (a, b)) = + AbstractMachine.App (rename level vars a, rename level vars b) + | rename level vars (AbstractMachine.Abs m) = + AbstractMachine.Abs (rename (level+1) vars m) + + fun rename_guard (AbstractMachine.Guard (a,b)) = + AbstractMachine.Guard (rename 0 vars a, rename 0 vars b) in - (table, invtable, ccount, (pattern, rename 0 vars right)) + ((encoding, hyptable, shyptable), (map rename_guard prems, pattern, rename 0 vars right)) end - val (table, invtable, ccount, rules) = - fold_rev (fn th => fn (table, invtable, ccount, rules) => + val ((encoding, hyptable, shyptable), rules) = + fold_rev (fn th => fn (encoding_hyptable, rules) => let - val (table, invtable, ccount, rule) = - thm2rule table invtable ccount th - in (table, invtable, ccount, rule::rules) end) - ths (Termtab.empty, AMTermTab.empty, 0, []) + val (encoding_hyptable, rule) = thm2rule encoding_hyptable th + in (encoding_hyptable, rule::rules) end) + ths ((Encode.empty, Termtab.empty, Sorttab.empty), []) - val prog = AbstractMachine.compile rules + val prog = + case machine of + BARRAS => ProgBarras (AM_Interpreter.compile rules) + | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile rules) + | HASKELL => ProgHaskell (AM_GHC.compile rules) + | SML => ProgSML (AM_SML.compile rules) - in Computer (Theory.self_ref thy, (table, invtable, ccount, prog)) end +(* val _ = print (Encode.fold (fn x => fn s => x::s) encoding [])*) + + fun has_witness s = not (null (Sign.witness_sorts thy [] [s])) -fun make thy ths = + val shyptable = fold Sorttab.delete (filter has_witness (Sorttab.keys (shyptable))) shyptable + + in Computer (Theory.self_ref thy, encoding, Termtab.keys hyptable, shyptable, prog) end + +(*fun timeit f = let - val (_, {mk_rews={mk=mk,mk_eq_True=emk, ...},...}) = rep_ss (simpset_of thy) - fun mk_eq_True th = (case emk th of NONE => [th] | SOME th' => [th, th']) + val t1 = Time.toMicroseconds (Time.now ()) + val x = f () + val t2 = Time.toMicroseconds (Time.now ()) + val _ = writeln ("### time = "^(Real.toString ((Real.fromLargeInt t2 - Real.fromLargeInt t1)/(1000000.0)))^"s") in - basic_make thy (maps mk (maps mk_eq_True ths)) - end + x + end*) + +fun report s f = f () (*writeln s; timeit f*) -fun compute (Computer r) naming ct = +fun compute (Computer (rthy, encoding, hyps, shyptable, prog)) naming ct = let + fun run (ProgBarras p) = AM_Interpreter.run p + | run (ProgBarrasC p) = AM_Compiler.run p + | run (ProgHaskell p) = AM_GHC.run p + | run (ProgSML p) = AM_SML.run p val {t=t, T=ty, thy=ctthy, ...} = rep_cterm ct - val (rthy, (table, invtable, ccount, prog)) = r val thy = Theory.merge (Theory.deref rthy, ctthy) - val (table, invtable, ccount, vcount, t) = remove_types (table, invtable, ccount, 0) t - val t = AbstractMachine.run prog t - val t = infer_types naming invtable ty t + val (encoding, t) = report "remove_types" (fn () => remove_types encoding t) + val t = report "run" (fn () => run prog t) + val t = report "infer_types" (fn () => infer_types naming encoding ty t) in t end -fun theory_of (Computer (rthy, _)) = Theory.deref rthy +fun discard (Computer (rthy, encoding, hyps, shyptable, prog)) = + (case prog of + ProgBarras p => AM_Interpreter.discard p + | ProgBarrasC p => AM_Compiler.discard p + | ProgHaskell p => AM_GHC.discard p + | ProgSML p => AM_SML.discard p) + +fun theory_of (Computer (rthy, _, _,_,_)) = Theory.deref rthy +fun hyps_of (Computer (_, _, hyps, _, _)) = hyps +fun shyps_of (Computer (_, _, _, shyptable, _)) = Sorttab.keys (shyptable) +fun shyptab_of (Computer (_, _, _, shyptable, _)) = shyptable fun default_naming i = "v_" ^ Int.toString i +exception Param of computer * (int -> string) * cterm; -fun oracle_fn thy (r, naming, ct) = +fun rewrite_param r n ct = + let + val thy = theory_of_cterm ct + val th = timeit (fn () => invoke_oracle_i thy "Compute_Oracle.compute" (thy, Param (r, n, ct))) + val hyps = map (fn h => assume (cterm_of thy h)) (hyps_of r) + in + fold (fn h => fn p => implies_elim p h) hyps th + end + +(*fun rewrite_param r n ct = + let + val hyps = hyps_of r + val shyps = shyps_of r + val thy = theory_of_cterm ct + val _ = Theory.assert_super (theory_of r) thy + val t' = timeit (fn () => compute r n ct) + val eq = Logic.mk_equals (term_of ct, t') + in + Thm.unchecked_oracle thy "Compute.compute" (eq, hyps, shyps) + end*) + +fun rewrite r ct = rewrite_param r default_naming ct + +(* theory setup *) + +fun compute_oracle (thy, Param (r, naming, ct)) = let val _ = Theory.assert_super (theory_of r) thy val t' = compute r naming ct + val eq = Logic.mk_equals (term_of ct, t') + val hyps = hyps_of r + val shyptab = shyptab_of r + fun delete s shyptab = Sorttab.delete s shyptab handle Sorttab.UNDEF _ => shyptab + fun delete_term t shyptab = fold delete (Sorts.insert_term t []) shyptab + val shyps = if Sorttab.is_empty shyptab then [] else Sorttab.keys (fold delete_term (eq::hyps) shyptab) + val _ = if not (null shyps) then raise Compute ("dangling sort hypotheses: "^(makestring shyps)) else () in - Logic.mk_equals (term_of ct, t') + fold_rev (fn hyp => fn p => Logic.mk_implies (hyp, p)) hyps eq end + | compute_oracle _ = raise Match + + +val setup = (fn thy => (writeln "install oracle"; Theory.add_oracle ("compute", compute_oracle) thy)) + +(*val _ = Context.add_setup (Theory.add_oracle ("compute", compute_oracle))*) end + diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/linker.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Compute_Oracle/linker.ML Mon Jul 09 17:36:25 2007 +0200 @@ -0,0 +1,392 @@ +(* Title: Tools/Compute_Oracle/Linker.ML + ID: $$ + Author: Steven Obua + + Linker.ML solves the problem that the computing oracle does not instantiate polymorphic rules. + By going through the PCompute interface, all possible instantiations are resolved by compiling new programs, if necessary. + The obvious disadvantage of this approach is that in the worst case for each new term to be rewritten, a new program may be compiled. +*) + +(* + Given constants/frees c_1::t_1, c_2::t_2, ...., c_n::t_n, + and constants/frees d_1::d_1, d_2::s_2, ..., d_m::s_m + + Find all substitutions S such that + a) the domain of S is tvars (t_1, ..., t_n) + b) there are indices i_1, ..., i_k, and j_1, ..., j_k with + 1. S (c_i_1::t_i_1) = d_j_1::s_j_1, ..., S (c_i_k::t_i_k) = d_j_k::s_j_k + 2. tvars (t_i_1, ..., t_i_k) = tvars (t_1, ..., t_n) +*) +signature LINKER = +sig + exception Link of string + + datatype constant = Constant of bool * string * typ + val constant_of : term -> constant + + type instances + type subst = Type.tyenv + + val empty : constant list -> instances + val typ_of_constant : constant -> typ + val add_instances : Type.tsig -> instances -> constant list -> subst list * instances + val substs_of : instances -> subst list + val is_polymorphic : constant -> bool + val distinct_constants : constant list -> constant list + val collect_consts : term list -> constant list +end + +structure Linker : LINKER = struct + +exception Link of string; + +type subst = Type.tyenv + +datatype constant = Constant of bool * string * typ +fun constant_of (Const (name, ty)) = Constant (false, name, ty) + | constant_of (Free (name, ty)) = Constant (true, name, ty) + | constant_of _ = raise Link "constant_of" + +fun bool_ord (x,y) = if x then (if y then EQUAL else GREATER) else (if y then LESS else EQUAL) +fun constant_ord (Constant (x1,x2,x3), Constant (y1,y2,y3)) = (prod_ord (prod_ord bool_ord fast_string_ord) Term.typ_ord) (((x1,x2),x3), ((y1,y2),y3)) +fun constant_modty_ord (Constant (x1,x2,_), Constant (y1,y2,_)) = (prod_ord bool_ord fast_string_ord) ((x1,x2), (y1,y2)) + + +structure Consttab = TableFun(type key = constant val ord = constant_ord); +structure ConsttabModTy = TableFun(type key = constant val ord = constant_modty_ord); + +fun typ_of_constant (Constant (_, _, ty)) = ty + +val empty_subst = (Vartab.empty : Type.tyenv) + +fun merge_subst (A:Type.tyenv) (B:Type.tyenv) = + SOME (Vartab.fold (fn (v, t) => + fn tab => + (case Vartab.lookup tab v of + NONE => Vartab.update (v, t) tab + | SOME t' => if t = t' then tab else raise Type.TYPE_MATCH)) A B) + handle Type.TYPE_MATCH => NONE + +fun subst_ord (A:Type.tyenv, B:Type.tyenv) = + (list_ord (prod_ord Term.fast_indexname_ord (prod_ord Term.sort_ord Term.typ_ord))) (Vartab.dest A, Vartab.dest B) + +structure Substtab = TableFun(type key = Type.tyenv val ord = subst_ord); + +val substtab_union = Substtab.fold Substtab.update +fun substtab_unions [] = Substtab.empty + | substtab_unions [c] = c + | substtab_unions (c::cs) = substtab_union c (substtab_unions cs) + +datatype instances = Instances of unit ConsttabModTy.table * Type.tyenv Consttab.table Consttab.table * constant list list * unit Substtab.table + +fun is_polymorphic (Constant (_, _, ty)) = not (null (typ_tvars ty)) + +fun distinct_constants cs = + Consttab.keys (fold (fn c => Consttab.update (c, ())) cs Consttab.empty) + +fun empty cs = + let + val cs = distinct_constants (filter is_polymorphic cs) + val old_cs = cs +(* fun collect_tvars ty tab = fold (fn v => fn tab => Typtab.update (TVar v, ()) tab) (typ_tvars ty) tab + val tvars_count = length (Typtab.keys (fold (fn c => fn tab => collect_tvars (typ_of_constant c) tab) cs Typtab.empty)) + fun tvars_of ty = collect_tvars ty Typtab.empty + val cs = map (fn c => (c, tvars_of (typ_of_constant c))) cs + + fun tyunion A B = + Typtab.fold + (fn (v,()) => fn tab => Typtab.update (v, case Typtab.lookup tab v of NONE => 1 | SOME n => n+1) tab) + A B + + fun is_essential A B = + Typtab.fold + (fn (v, ()) => fn essential => essential orelse (case Typtab.lookup B v of NONE => raise Link "is_essential" | SOME n => n=1)) + A false + + fun add_minimal (c', tvs') (tvs, cs) = + let + val tvs = tyunion tvs' tvs + val cs = (c', tvs')::cs + in + if forall (fn (c',tvs') => is_essential tvs' tvs) cs then + SOME (tvs, cs) + else + NONE + end + + fun is_spanning (tvs, _) = (length (Typtab.keys tvs) = tvars_count) + + fun generate_minimal_subsets subsets [] = subsets + | generate_minimal_subsets subsets (c::cs) = + let + val subsets' = map_filter (add_minimal c) subsets + in + generate_minimal_subsets (subsets@subsets') cs + end*) + + val minimal_subsets = [old_cs] (*map (fn (tvs, cs) => map fst cs) (filter is_spanning (generate_minimal_subsets [(Typtab.empty, [])] cs))*) + + val constants = Consttab.keys (fold (fold (fn c => Consttab.update (c, ()))) minimal_subsets Consttab.empty) + + in + Instances ( + fold (fn c => fn tab => ConsttabModTy.update (c, ()) tab) constants ConsttabModTy.empty, + Consttab.make (map (fn c => (c, Consttab.empty : Type.tyenv Consttab.table)) constants), + minimal_subsets, Substtab.empty) + end + +local +fun calc ctab substtab [] = substtab + | calc ctab substtab (c::cs) = + let + val csubsts = map snd (Consttab.dest (the (Consttab.lookup ctab c))) + fun merge_substs substtab subst = + Substtab.fold (fn (s,_) => + fn tab => + (case merge_subst subst s of NONE => tab | SOME s => Substtab.update (s, ()) tab)) + substtab Substtab.empty + val substtab = substtab_unions (map (merge_substs substtab) csubsts) + in + calc ctab substtab cs + end +in +fun calc_substs ctab (cs:constant list) = calc ctab (Substtab.update (empty_subst, ()) Substtab.empty) cs +end + +fun add_instances tsig (Instances (cfilter, ctab,minsets,substs)) cs = + let +(* val _ = writeln (makestring ("add_instances: ", length_cs, length cs, length (Consttab.keys ctab)))*) + fun calc_instantiations (constant as Constant (free, name, ty)) instantiations = + Consttab.fold (fn (constant' as Constant (free', name', ty'), insttab) => + fn instantiations => + if free <> free' orelse name <> name' then + instantiations + else case Consttab.lookup insttab constant of + SOME _ => instantiations + | NONE => ((constant', (constant, Type.typ_match tsig (ty', ty) empty_subst))::instantiations + handle TYPE_MATCH => instantiations)) + ctab instantiations + val instantiations = fold calc_instantiations cs [] + (*val _ = writeln ("instantiations = "^(makestring (length instantiations)))*) + fun update_ctab (constant', entry) ctab = + (case Consttab.lookup ctab constant' of + NONE => raise Link "internal error: update_ctab" + | SOME tab => Consttab.update (constant', Consttab.update entry tab) ctab) + val ctab = fold update_ctab instantiations ctab + val new_substs = fold (fn minset => fn substs => substtab_union (calc_substs ctab minset) substs) + minsets Substtab.empty + val (added_substs, substs) = + Substtab.fold (fn (ns, _) => + fn (added, substtab) => + (case Substtab.lookup substs ns of + NONE => (ns::added, Substtab.update (ns, ()) substtab) + | SOME () => (added, substtab))) + new_substs ([], substs) + in + (added_substs, Instances (cfilter, ctab, minsets, substs)) + end + + +fun substs_of (Instances (_,_,_,substs)) = Substtab.keys substs + +local + fun get_thm thmname = PureThy.get_thm (theory "Main") (Name thmname) + val eq_th = get_thm "HOL.eq_reflection" +in + fun eq_to_meta th = (eq_th OF [th] handle _ => th) +end + + +local + +fun collect (Var x) tab = tab + | collect (Bound _) tab = tab + | collect (a $ b) tab = collect b (collect a tab) + | collect (Abs (_, _, body)) tab = collect body tab + | collect t tab = Consttab.update (constant_of t, ()) tab + +in + fun collect_consts tms = Consttab.keys (fold collect tms Consttab.empty) +end + +end + +signature PCOMPUTE = +sig + + type pcomputer + + val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer + +(* val add_thms : pcomputer -> thm list -> bool*) + + val add_instances : pcomputer -> Linker.constant list -> bool + + val rewrite : pcomputer -> cterm list -> thm list + +end + +structure PCompute : PCOMPUTE = struct + +exception PCompute of string + +datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list + +datatype pcomputer = PComputer of Compute.machine * theory_ref * Compute.computer ref * theorem list ref + +(*fun collect_consts (Var x) = [] + | collect_consts (Bound _) = [] + | collect_consts (a $ b) = (collect_consts a)@(collect_consts b) + | collect_consts (Abs (_, _, body)) = collect_consts body + | collect_consts t = [Linker.constant_of t]*) + +fun collect_consts_of_thm th = + let + val th = prop_of th + val (prems, th) = (Logic.strip_imp_prems th, Logic.strip_imp_concl th) + val (left, right) = Logic.dest_equals th + in + (Linker.collect_consts [left], Linker.collect_consts (right::prems)) + end + +fun create_theorem th = +let + val (left, right) = collect_consts_of_thm th + val polycs = filter Linker.is_polymorphic left + val tytab = fold (fn p => fn tab => fold (fn n => fn tab => Typtab.update (TVar n, ()) tab) (typ_tvars (Linker.typ_of_constant p)) tab) polycs Typtab.empty + fun check_const (c::cs) cs' = + let + val tvars = typ_tvars (Linker.typ_of_constant c) + val wrong = fold (fn n => fn wrong => wrong orelse is_none (Typtab.lookup tytab (TVar n))) tvars false + in + if wrong then raise PCompute "right hand side of theorem contains type variables which do not occur on the left hand side" + else + if null (tvars) then + check_const cs (c::cs') + else + check_const cs cs' + end + | check_const [] cs' = cs' + val monocs = check_const right [] +in + if null (polycs) then + (monocs, MonoThm th) + else + (monocs, PolyThm (th, Linker.empty polycs, [])) +end + +fun create_computer machine thy ths = + let + fun add (MonoThm th) ths = th::ths + | add (PolyThm (_, _, ths')) ths = ths'@ths + val ths = fold_rev add ths [] + in + Compute.make machine thy ths + end + +fun conv_subst thy (subst : Type.tyenv) = + map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst) + +fun add_monos thy monocs ths = + let + val tsig = Sign.tsig_of thy + val changed = ref false + fun add monocs (th as (MonoThm _)) = ([], th) + | add monocs (PolyThm (th, instances, instanceths)) = + let + val (newsubsts, instances) = Linker.add_instances tsig instances monocs + val _ = if not (null newsubsts) then changed := true else () + val newths = map (fn subst => Thm.instantiate (conv_subst thy subst, []) th) newsubsts +(* val _ = if not (null newths) then (print ("added new theorems: ", newths); ()) else ()*) + val newmonos = fold (fn th => fn monos => (snd (collect_consts_of_thm th))@monos) newths [] + in + (newmonos, PolyThm (th, instances, instanceths@newths)) + end + fun step monocs ths = + fold_rev (fn th => + fn (newmonos, ths) => + let val (newmonos', th') = add monocs th in + (newmonos'@newmonos, th'::ths) + end) + ths ([], []) + fun loop monocs ths = + let val (monocs', ths') = step monocs ths in + if null (monocs') then + ths' + else + loop monocs' ths' + end + val result = loop monocs ths + in + (!changed, result) + end + +datatype cthm = ComputeThm of term list * sort list * term + +fun thm2cthm th = + let + val {hyps, prop, shyps, ...} = Thm.rep_thm th + in + ComputeThm (hyps, shyps, prop) + end + +val cthm_ord' = prod_ord (prod_ord (list_ord Term.term_ord) (list_ord Term.sort_ord)) Term.term_ord + +fun cthm_ord (ComputeThm (h1, sh1, p1), ComputeThm (h2, sh2, p2)) = cthm_ord' (((h1,sh1), p1), ((h2, sh2), p2)) + +structure CThmtab = TableFun (type key = cthm val ord = cthm_ord) + +fun remove_duplicates ths = + let + val counter = ref 0 + val tab = ref (CThmtab.empty : unit CThmtab.table) + val thstab = ref (Inttab.empty : thm Inttab.table) + fun update th = + let + val key = thm2cthm th + in + case CThmtab.lookup (!tab) key of + NONE => ((tab := CThmtab.update_new (key, ()) (!tab)); thstab := Inttab.update_new (!counter, th) (!thstab); counter := !counter + 1) + | _ => () + end + val _ = map update ths + in + map snd (Inttab.dest (!thstab)) + end + + +fun make machine thy ths cs = + let + val ths = remove_duplicates ths + val (monocs, ths) = fold_rev (fn th => + fn (monocs, ths) => + let val (m, t) = create_theorem th in + (m@monocs, t::ths) + end) + ths (cs, []) + val (_, ths) = add_monos thy monocs ths + in + PComputer (machine, Theory.self_ref thy, ref (create_computer machine thy ths), ref ths) + end + +fun add_instances (PComputer (machine, thyref, rcomputer, rths)) cs = + let + val thy = Theory.deref thyref + val (changed, ths) = add_monos thy cs (!rths) + in + if changed then + (rcomputer := create_computer machine thy ths; + rths := ths; + true) + else + false + end + +fun rewrite (pc as PComputer (_, _, rcomputer, _)) cts = + let + val _ = map (fn ct => add_instances pc (Linker.collect_consts [term_of ct])) cts + in + map (fn ct => Compute.rewrite (!rcomputer) ct) cts + end + +end \ No newline at end of file