# HG changeset patch # User wenzelm # Date 1279719876 -7200 # Node ID d836595703372f2077e14d1dff2f3000abe1e8b7 # Parent c7ce7685e087d53ab879c2302557d967034f3fb1 moved src/Tools/Compute_Oracle to src/HOL/Matrix/Compute_Oracle -- it actually depends on HOL anyway; diff -r c7ce7685e087 -r d83659570337 src/HOL/IsaMakefile --- a/src/HOL/IsaMakefile Wed Jul 21 15:31:38 2010 +0200 +++ b/src/HOL/IsaMakefile Wed Jul 21 15:44:36 2010 +0200 @@ -1036,21 +1036,17 @@ HOL-Matrix: HOL $(LOG)/HOL-Matrix.gz -$(LOG)/HOL-Matrix.gz: $(OUT)/HOL \ - $(SRC)/Tools/Compute_Oracle/Compute_Oracle.thy \ - $(SRC)/Tools/Compute_Oracle/am_compiler.ML \ - $(SRC)/Tools/Compute_Oracle/am_interpreter.ML \ - $(SRC)/Tools/Compute_Oracle/am.ML \ - $(SRC)/Tools/Compute_Oracle/linker.ML \ - $(SRC)/Tools/Compute_Oracle/am_ghc.ML \ - $(SRC)/Tools/Compute_Oracle/am_sml.ML \ - $(SRC)/Tools/Compute_Oracle/compute.ML Matrix/ComputeFloat.thy \ - Matrix/ComputeHOL.thy Matrix/ComputeNumeral.thy Tools/float_arith.ML \ - Matrix/Matrix.thy Matrix/SparseMatrix.thy Matrix/LP.thy \ - Matrix/document/root.tex Matrix/ROOT.ML Matrix/Cplex.thy \ +$(LOG)/HOL-Matrix.gz: $(OUT)/HOL Matrix/ComputeFloat.thy \ + Matrix/ComputeHOL.thy Matrix/ComputeNumeral.thy \ + Matrix/Compute_Oracle/Compute_Oracle.thy Matrix/Compute_Oracle/am.ML \ + Matrix/Compute_Oracle/am_compiler.ML Matrix/Compute_Oracle/am_ghc.ML \ + Matrix/Compute_Oracle/am_interpreter.ML \ + Matrix/Compute_Oracle/am_sml.ML Matrix/Compute_Oracle/compute.ML \ + Matrix/Compute_Oracle/linker.ML Matrix/Cplex.thy \ Matrix/CplexMatrixConverter.ML Matrix/Cplex_tools.ML \ - Matrix/FloatSparseMatrixBuilder.ML Matrix/fspmlp.ML \ - Matrix/matrixlp.ML + Matrix/FloatSparseMatrixBuilder.ML Matrix/LP.thy Matrix/Matrix.thy \ + Matrix/ROOT.ML Matrix/SparseMatrix.thy Matrix/document/root.tex \ + Matrix/fspmlp.ML Matrix/matrixlp.ML Tools/float_arith.ML @$(ISABELLE_TOOL) usedir -g true $(OUT)/HOL Matrix diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/ComputeHOL.thy --- a/src/HOL/Matrix/ComputeHOL.thy Wed Jul 21 15:31:38 2010 +0200 +++ b/src/HOL/Matrix/ComputeHOL.thy Wed Jul 21 15:44:36 2010 +0200 @@ -1,5 +1,5 @@ theory ComputeHOL -imports Complex_Main "~~/src/Tools/Compute_Oracle/Compute_Oracle" +imports Complex_Main "Compute_Oracle/Compute_Oracle" begin lemma Trueprop_eq_eq: "Trueprop X == (X == True)" by (simp add: atomize_eq) diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/Compute_Oracle/Compute_Oracle.thy --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Matrix/Compute_Oracle/Compute_Oracle.thy Wed Jul 21 15:44:36 2010 +0200 @@ -0,0 +1,11 @@ +(* Title: Tools/Compute_Oracle/Compute_Oracle.thy + Author: Steven Obua, TU Munich + +Steven Obua's evaluator. +*) + +theory Compute_Oracle imports HOL +uses "am.ML" "am_compiler.ML" "am_interpreter.ML" "am_ghc.ML" "am_sml.ML" "report.ML" "compute.ML" "linker.ML" +begin + +end \ No newline at end of file diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/Compute_Oracle/am.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Matrix/Compute_Oracle/am.ML Wed Jul 21 15:44:36 2010 +0200 @@ -0,0 +1,75 @@ +signature ABSTRACT_MACHINE = +sig + +datatype term = Var of int | Const of int | App of term * term | Abs of term | Computed of term + +datatype pattern = PVar | PConst of int * (pattern list) + +datatype guard = Guard of term * term + +type program + +exception Compile of string; + +(* The de-Bruijn index 0 occurring on the right hand side refers to the LAST pattern variable, when traversing the pattern from left to right, + 1 to the second last, and so on. *) +val compile : pattern list -> (int -> int option) -> (guard list * pattern * term) list -> program + +val discard : program -> unit + +exception Run of string; +val run : program -> term -> term + +(* Utilities *) + +val check_freevars : int -> term -> bool +val forall_consts : (int -> bool) -> term -> bool +val closed : term -> bool +val erase_Computed : term -> term + +end + +structure AbstractMachine : ABSTRACT_MACHINE = +struct + +datatype term = Var of int | Const of int | App of term * term | Abs of term | Computed of term + +datatype pattern = PVar | PConst of int * (pattern list) + +datatype guard = Guard of term * term + +type program = unit + +exception Compile of string; + +fun erase_Computed (Computed t) = erase_Computed t + | erase_Computed (App (t1, t2)) = App (erase_Computed t1, erase_Computed t2) + | erase_Computed (Abs t) = Abs (erase_Computed t) + | erase_Computed t = t + +(*Returns true iff at most 0 .. (free-1) occur unbound. therefore + check_freevars 0 t iff t is closed*) +fun check_freevars free (Var x) = x < free + | check_freevars free (Const c) = true + | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v + | check_freevars free (Abs m) = check_freevars (free+1) m + | check_freevars free (Computed t) = check_freevars free t + +fun forall_consts pred (Const c) = pred c + | forall_consts pred (Var x) = true + | forall_consts pred (App (u,v)) = forall_consts pred u + andalso forall_consts pred v + | forall_consts pred (Abs m) = forall_consts pred m + | forall_consts pred (Computed t) = forall_consts pred t + +fun closed t = check_freevars 0 t + +fun compile _ = raise Compile "abstract machine stub" + +fun discard _ = raise Compile "abstract machine stub" + +exception Run of string; + +fun run p t = raise Run "abstract machine stub" + +end diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/Compute_Oracle/am_compiler.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Matrix/Compute_Oracle/am_compiler.ML Wed Jul 21 15:44:36 2010 +0200 @@ -0,0 +1,211 @@ +(* Title: Tools/Compute_Oracle/am_compiler.ML + Author: Steven Obua +*) + +signature COMPILING_AM = +sig + include ABSTRACT_MACHINE + + val set_compiled_rewriter : (term -> term) -> unit + val list_nth : 'a list * int -> 'a + val list_map : ('a -> 'b) -> 'a list -> 'b list +end + +structure AM_Compiler : COMPILING_AM = struct + +val list_nth = List.nth; +val list_map = map; + +open AbstractMachine; + +val compiled_rewriter = Unsynchronized.ref (NONE:(term -> term)Option.option) + +fun set_compiled_rewriter r = (compiled_rewriter := SOME r) + +type program = (term -> term) + +fun count_patternvars PVar = 1 + | count_patternvars (PConst (_, ps)) = + List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps + +fun print_rule (p, t) = + let + fun str x = Int.toString x + fun print_pattern n PVar = (n+1, "x"^(str n)) + | print_pattern n (PConst (c, [])) = (n, "c"^(str c)) + | print_pattern n (PConst (c, args)) = + let + val h = print_pattern n (PConst (c,[])) + in + print_pattern_list h args + end + and print_pattern_list r [] = r + | print_pattern_list (n, p) (t::ts) = + let + val (n, t) = print_pattern n t + in + print_pattern_list (n, "App ("^p^", "^t^")") ts + end + + val (n, pattern) = print_pattern 0 p + val pattern = + if exists_string Symbol.is_ascii_blank pattern then "(" ^ pattern ^")" + else pattern + + fun print_term d (Var x) = (*if x < d then "Var "^(str x) else "x"^(str (n-(x-d)-1))*) + "Var " ^ str x + | print_term d (Const c) = "c" ^ str c + | print_term d (App (a,b)) = "App (" ^ print_term d a ^ ", " ^ print_term d b ^ ")" + | print_term d (Abs c) = "Abs (" ^ print_term (d + 1) c ^ ")" + | print_term d (Computed c) = print_term d c + + fun listvars n = if n = 0 then "x0" else "x"^(str n)^", "^(listvars (n-1)) + + val term = print_term 0 t + val term = + if n > 0 then "Closure (["^(listvars (n-1))^"], "^term^")" + else "Closure ([], "^term^")" + + in + " | weak_reduce (false, stack, "^pattern^") = Continue (false, stack, "^term^")" + end + +fun constants_of PVar = [] + | constants_of (PConst (c, ps)) = c :: maps constants_of ps + +fun constants_of_term (Var _) = [] + | constants_of_term (Abs m) = constants_of_term m + | constants_of_term (App (a,b)) = (constants_of_term a)@(constants_of_term b) + | constants_of_term (Const c) = [c] + | constants_of_term (Computed c) = constants_of_term c + +fun load_rules sname name prog = + let + val buffer = Unsynchronized.ref "" + fun write s = (buffer := (!buffer)^s) + fun writeln s = (write s; write "\n") + fun writelist [] = () + | writelist (s::ss) = (writeln s; writelist ss) + fun str i = Int.toString i + val _ = writelist [ + "structure "^name^" = struct", + "", + "datatype term = Dummy | App of term * term | Abs of term | Var of int | Const of int | Closure of term list * term"] + val constants = distinct (op =) (maps (fn (p, r) => ((constants_of p)@(constants_of_term r))) prog) + val _ = map (fn x => write (" | c"^(str x))) constants + val _ = writelist [ + "", + "datatype stack = SEmpty | SAppL of term * stack | SAppR of term * stack | SAbs of stack", + "", + "type state = bool * stack * term", + "", + "datatype loopstate = Continue of state | Stop of stack * term", + "", + "fun proj_C (Continue s) = s", + " | proj_C _ = raise Match", + "", + "fun proj_S (Stop s) = s", + " | proj_S _ = raise Match", + "", + "fun cont (Continue _) = true", + " | cont _ = false", + "", + "fun do_reduction reduce p =", + " let", + " val s = Unsynchronized.ref (Continue p)", + " val _ = while cont (!s) do (s := reduce (proj_C (!s)))", + " in", + " proj_S (!s)", + " end", + ""] + + val _ = writelist [ + "fun weak_reduce (false, stack, Closure (e, App (a, b))) = Continue (false, SAppL (Closure (e, b), stack), Closure (e, a))", + " | weak_reduce (false, SAppL (b, stack), Closure (e, Abs m)) = Continue (false, stack, Closure (b::e, m))", + " | weak_reduce (false, stack, c as Closure (e, Abs m)) = Continue (true, stack, c)", + " | weak_reduce (false, stack, Closure (e, Var n)) = Continue (false, stack, case "^sname^".list_nth (e, n) of Dummy => Var n | r => r)", + " | weak_reduce (false, stack, Closure (e, c)) = Continue (false, stack, c)"] + val _ = writelist (map print_rule prog) + val _ = writelist [ + " | weak_reduce (false, stack, clos) = Continue (true, stack, clos)", + " | weak_reduce (true, SAppR (a, stack), b) = Continue (false, stack, App (a,b))", + " | weak_reduce (true, s as (SAppL (b, stack)), a) = Continue (false, SAppR (a, stack), b)", + " | weak_reduce (true, stack, c) = Stop (stack, c)", + "", + "fun strong_reduce (false, stack, Closure (e, Abs m)) =", + " let", + " val (stack', wnf) = do_reduction weak_reduce (false, SEmpty, Closure (Dummy::e, m))", + " in", + " case stack' of", + " SEmpty => Continue (false, SAbs stack, wnf)", + " | _ => raise ("^sname^".Run \"internal error in strong: weak failed\")", + " end", + " | strong_reduce (false, stack, clos as (App (u, v))) = Continue (false, SAppL (v, stack), u)", + " | strong_reduce (false, stack, clos) = Continue (true, stack, clos)", + " | strong_reduce (true, SAbs stack, m) = Continue (false, stack, Abs m)", + " | strong_reduce (true, SAppL (b, stack), a) = Continue (false, SAppR (a, stack), b)", + " | strong_reduce (true, SAppR (a, stack), b) = Continue (true, stack, App (a, b))", + " | strong_reduce (true, stack, clos) = Stop (stack, clos)", + ""] + + val ic = "(case c of "^(implode (map (fn c => (str c)^" => c"^(str c)^" | ") constants))^" _ => Const c)" + val _ = writelist [ + "fun importTerm ("^sname^".Var x) = Var x", + " | importTerm ("^sname^".Const c) = "^ic, + " | importTerm ("^sname^".App (a, b)) = App (importTerm a, importTerm b)", + " | importTerm ("^sname^".Abs m) = Abs (importTerm m)", + ""] + + fun ec c = " | exportTerm c"^(str c)^" = "^sname^".Const "^(str c) + val _ = writelist [ + "fun exportTerm (Var x) = "^sname^".Var x", + " | exportTerm (Const c) = "^sname^".Const c", + " | exportTerm (App (a,b)) = "^sname^".App (exportTerm a, exportTerm b)", + " | exportTerm (Abs m) = "^sname^".Abs (exportTerm m)", + " | exportTerm (Closure (closlist, clos)) = raise ("^sname^".Run \"internal error, cannot export Closure\")", + " | exportTerm Dummy = raise ("^sname^".Run \"internal error, cannot export Dummy\")"] + val _ = writelist (map ec constants) + + val _ = writelist [ + "", + "fun rewrite t = ", + " let", + " val (stack, wnf) = do_reduction weak_reduce (false, SEmpty, Closure ([], importTerm t))", + " in", + " case stack of ", + " SEmpty => (case do_reduction strong_reduce (false, SEmpty, wnf) of", + " (SEmpty, snf) => exportTerm snf", + " | _ => raise ("^sname^".Run \"internal error in rewrite: strong failed\"))", + " | _ => (raise ("^sname^".Run \"internal error in rewrite: weak failed\"))", + " end", + "", + "val _ = "^sname^".set_compiled_rewriter rewrite", + "", + "end;"] + + in + compiled_rewriter := NONE; + use_text ML_Env.local_context (1, "") false (!buffer); + case !compiled_rewriter of + NONE => raise (Compile "cannot communicate with compiled function") + | SOME r => (compiled_rewriter := NONE; r) + end + +fun compile cache_patterns const_arity eqs = + let + val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () + val eqs = map (fn (a,b,c) => (b,c)) eqs + fun check (p, r) = if check_freevars (count_patternvars p) r then () else raise Compile ("unbound variables in rule") + val _ = map (fn (p, r) => + (check (p, r); + case p of PVar => raise (Compile "pattern is just a variable") | _ => ())) eqs + in + load_rules "AM_Compiler" "AM_compiled_code" eqs + end + +fun run prog t = (prog t) + +fun discard p = () + +end + diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/Compute_Oracle/am_ghc.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Matrix/Compute_Oracle/am_ghc.ML Wed Jul 21 15:44:36 2010 +0200 @@ -0,0 +1,325 @@ +(* Title: Tools/Compute_Oracle/am_ghc.ML + Author: Steven Obua +*) + +structure AM_GHC : ABSTRACT_MACHINE = struct + +open AbstractMachine; + +type program = string * string * (int Inttab.table) + +fun count_patternvars PVar = 1 + | count_patternvars (PConst (_, ps)) = + List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps + +fun update_arity arity code a = + (case Inttab.lookup arity code of + NONE => Inttab.update_new (code, a) arity + | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity) + +(* We have to find out the maximal arity of each constant *) +fun collect_pattern_arity PVar arity = arity + | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args)) + +local +fun collect applevel (Var _) arity = arity + | collect applevel (Const c) arity = update_arity arity c applevel + | collect applevel (Abs m) arity = collect 0 m arity + | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity) +in +fun collect_term_arity t arity = collect 0 t arity +end + +fun nlift level n (Var m) = if m < level then Var m else Var (m+n) + | nlift level n (Const c) = Const c + | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b) + | nlift level n (Abs b) = Abs (nlift (level+1) n b) + +fun rep n x = if n = 0 then [] else x::(rep (n-1) x) + +fun adjust_rules rules = + let + val arity = fold (fn (p, t) => fn arity => collect_term_arity t (collect_pattern_arity p arity)) rules Inttab.empty + fun arity_of c = the (Inttab.lookup arity c) + fun adjust_pattern PVar = PVar + | adjust_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else C + fun adjust_rule (PVar, t) = raise Compile ("pattern may not be a variable") + | adjust_rule (rule as (p as PConst (c, args),t)) = + let + val _ = if not (check_freevars (count_patternvars p) t) then raise Compile ("unbound variables on right hand side") else () + val args = map adjust_pattern args + val len = length args + val arity = arity_of c + fun lift level n (Var m) = if m < level then Var m else Var (m+n) + | lift level n (Const c) = Const c + | lift level n (App (a,b)) = App (lift level n a, lift level n b) + | lift level n (Abs b) = Abs (lift (level+1) n b) + val lift = lift 0 + fun adjust_term n t = if n=0 then t else adjust_term (n-1) (App (t, Var (n-1))) + in + if len = arity then + rule + else if arity >= len then + (PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) (lift (arity-len) t)) + else (raise Compile "internal error in adjust_rule") + end + in + (arity, map adjust_rule rules) + end + +fun print_term arity_of n = +let + fun str x = string_of_int x + fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s + + fun print_apps d f [] = f + | print_apps d f (a::args) = print_apps d ("app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args + and print_call d (App (a, b)) args = print_call d a (b::args) + | print_call d (Const c) args = + (case arity_of c of + NONE => print_apps d ("Const "^(str c)) args + | SOME a => + let + val len = length args + in + if a <= len then + let + val s = "c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, a)))) + in + print_apps d s (List.drop (args, a)) + end + else + let + fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n-1))) + fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t) + fun append_args [] t = t + | append_args (c::cs) t = append_args cs (App (t, c)) + in + print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c))))) + end + end) + | print_call d t args = print_apps d (print_term d t) args + and print_term d (Var x) = if x < d then "b"^(str (d-x-1)) else "x"^(str (n-(x-d)-1)) + | print_term d (Abs c) = "Abs (\\b"^(str d)^" -> "^(print_term (d + 1) c)^")" + | print_term d t = print_call d t [] +in + print_term 0 +end + +fun print_rule arity_of (p, t) = + let + fun str x = Int.toString x + fun print_pattern top n PVar = (n+1, "x"^(str n)) + | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)) + | print_pattern top n (PConst (c, args)) = + let + val (n,s) = print_pattern_list (n, (if top then "c" else "C")^(str c)) args + in + (n, if top then s else "("^s^")") + end + and print_pattern_list r [] = r + | print_pattern_list (n, p) (t::ts) = + let + val (n, t) = print_pattern false n t + in + print_pattern_list (n, p^" "^t) ts + end + val (n, pattern) = print_pattern true 0 p + in + pattern^" = "^(print_term arity_of n t) + end + +fun group_rules rules = + let + fun add_rule (r as (PConst (c,_), _)) groups = + let + val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs) + in + Inttab.update (c, r::rs) groups + end + | add_rule _ _ = raise Compile "internal error group_rules" + in + fold_rev add_rule rules Inttab.empty + end + +fun haskell_prog name rules = + let + val buffer = Unsynchronized.ref "" + fun write s = (buffer := (!buffer)^s) + fun writeln s = (write s; write "\n") + fun writelist [] = () + | writelist (s::ss) = (writeln s; writelist ss) + fun str i = Int.toString i + val (arity, rules) = adjust_rules rules + val rules = group_rules rules + val constants = Inttab.keys arity + fun arity_of c = Inttab.lookup arity c + fun rep_str s n = implode (rep n s) + fun indexed s n = s^(str n) + fun section n = if n = 0 then [] else (section (n-1))@[n-1] + fun make_show c = + let + val args = section (the (arity_of c)) + in + " show ("^(indexed "C" c)^(implode (map (indexed " a") args))^") = " + ^"\""^(indexed "C" c)^"\""^(implode (map (fn a => "++(show "^(indexed "a" a)^")") args)) + end + fun default_case c = + let + val args = implode (map (indexed " x") (section (the (arity_of c)))) + in + (indexed "c" c)^args^" = "^(indexed "C" c)^args + end + val _ = writelist [ + "module "^name^" where", + "", + "data Term = Const Integer | App Term Term | Abs (Term -> Term)", + " "^(implode (map (fn c => " | C"^(str c)^(rep_str " Term" (the (arity_of c)))) constants)), + "", + "instance Show Term where"] + val _ = writelist (map make_show constants) + val _ = writelist [ + " show (Const c) = \"c\"++(show c)", + " show (App a b) = \"A\"++(show a)++(show b)", + " show (Abs _) = \"L\"", + ""] + val _ = writelist [ + "app (Abs a) b = a b", + "app a b = App a b", + "", + "calc s c = writeFile s (show c)", + ""] + fun list_group c = (writelist (case Inttab.lookup rules c of + NONE => [default_case c, ""] + | SOME (rs as ((PConst (_, []), _)::rs')) => + if not (null rs') then raise Compile "multiple declaration of constant" + else (map (print_rule arity_of) rs) @ [""] + | SOME rs => (map (print_rule arity_of) rs) @ [default_case c, ""])) + val _ = map list_group constants + in + (arity, !buffer) + end + +val guid_counter = Unsynchronized.ref 0 +fun get_guid () = + let + val c = !guid_counter + val _ = guid_counter := !guid_counter + 1 + in + (LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c) + end + +fun tmp_file s = Path.implode (Path.expand (File.tmp_path (Path.make [s]))); +fun wrap s = "\""^s^"\"" + +fun writeTextFile name s = File.write (Path.explode name) s + +val ghc = Unsynchronized.ref (case getenv "GHC_PATH" of "" => "ghc" | s => s) + +fun fileExists name = ((OS.FileSys.fileSize name; true) handle OS.SysErr _ => false) + +fun compile cache_patterns const_arity eqs = + let + val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () + val eqs = map (fn (a,b,c) => (b,c)) eqs + val guid = get_guid () + val module = "AMGHC_Prog_"^guid + val (arity, source) = haskell_prog module eqs + val module_file = tmp_file (module^".hs") + val object_file = tmp_file (module^".o") + val _ = writeTextFile module_file source + val _ = bash ((!ghc)^" -c "^module_file) + val _ = if not (fileExists object_file) then raise Compile ("Failure compiling haskell code (GHC_PATH = '"^(!ghc)^"')") else () + in + (guid, module_file, arity) + end + +fun readResultFile name = File.read (Path.explode name) + +fun parse_result arity_of result = + let + val result = String.explode result + fun shift NONE x = SOME x + | shift (SOME y) x = SOME (y*10 + x) + fun parse_int' x (#"0"::rest) = parse_int' (shift x 0) rest + | parse_int' x (#"1"::rest) = parse_int' (shift x 1) rest + | parse_int' x (#"2"::rest) = parse_int' (shift x 2) rest + | parse_int' x (#"3"::rest) = parse_int' (shift x 3) rest + | parse_int' x (#"4"::rest) = parse_int' (shift x 4) rest + | parse_int' x (#"5"::rest) = parse_int' (shift x 5) rest + | parse_int' x (#"6"::rest) = parse_int' (shift x 6) rest + | parse_int' x (#"7"::rest) = parse_int' (shift x 7) rest + | parse_int' x (#"8"::rest) = parse_int' (shift x 8) rest + | parse_int' x (#"9"::rest) = parse_int' (shift x 9) rest + | parse_int' x rest = (x, rest) + fun parse_int rest = parse_int' NONE rest + + fun parse (#"C"::rest) = + (case parse_int rest of + (SOME c, rest) => + let + val (args, rest) = parse_list (the (arity_of c)) rest + fun app_args [] t = t + | app_args (x::xs) t = app_args xs (App (t, x)) + in + (app_args args (Const c), rest) + end + | (NONE, rest) => raise Run "parse C") + | parse (#"c"::rest) = + (case parse_int rest of + (SOME c, rest) => (Const c, rest) + | _ => raise Run "parse c") + | parse (#"A"::rest) = + let + val (a, rest) = parse rest + val (b, rest) = parse rest + in + (App (a,b), rest) + end + | parse (#"L"::rest) = raise Run "there may be no abstraction in the result" + | parse _ = raise Run "invalid result" + and parse_list n rest = + if n = 0 then + ([], rest) + else + let + val (x, rest) = parse rest + val (xs, rest) = parse_list (n-1) rest + in + (x::xs, rest) + end + val (parsed, rest) = parse result + fun is_blank (#" "::rest) = is_blank rest + | is_blank (#"\n"::rest) = is_blank rest + | is_blank [] = true + | is_blank _ = false + in + if is_blank rest then parsed else raise Run "non-blank suffix in result file" + end + +fun run (guid, module_file, arity) t = + let + val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") + fun arity_of c = Inttab.lookup arity c + val callguid = get_guid() + val module = "AMGHC_Prog_"^guid + val call = module^"_Call_"^callguid + val result_file = tmp_file (module^"_Result_"^callguid^".txt") + val call_file = tmp_file (call^".hs") + val term = print_term arity_of 0 t + val call_source = "module "^call^" where\n\nimport "^module^"\n\ncall = "^module^".calc \""^result_file^"\" ("^term^")" + val _ = writeTextFile call_file call_source + val _ = bash ((!ghc)^" -e \""^call^".call\" "^module_file^" "^call_file) + val result = readResultFile result_file handle IO.Io _ => raise Run ("Failure running haskell compiler (GHC_PATH = '"^(!ghc)^"')") + val t' = parse_result arity_of result + val _ = OS.FileSys.remove call_file + val _ = OS.FileSys.remove result_file + in + t' + end + + +fun discard _ = () + +end + diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/Compute_Oracle/am_interpreter.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Matrix/Compute_Oracle/am_interpreter.ML Wed Jul 21 15:44:36 2010 +0200 @@ -0,0 +1,213 @@ +(* Title: Tools/Compute_Oracle/am_interpreter.ML + Author: Steven Obua +*) + +signature AM_BARRAS = +sig + include ABSTRACT_MACHINE + val max_reductions : int option Unsynchronized.ref +end + +structure AM_Interpreter : AM_BARRAS = struct + +open AbstractMachine; + +datatype closure = CDummy | CVar of int | CConst of int + | CApp of closure * closure | CAbs of closure + | Closure of (closure list) * closure + +structure prog_struct = Table(type key = int*int val ord = prod_ord int_ord int_ord); + +datatype program = Program of ((pattern * closure * (closure*closure) list) list) prog_struct.table + +datatype stack = SEmpty | SAppL of closure * stack | SAppR of closure * stack | SAbs of stack + +fun clos_of_term (Var x) = CVar x + | clos_of_term (Const c) = CConst c + | clos_of_term (App (u, v)) = CApp (clos_of_term u, clos_of_term v) + | clos_of_term (Abs u) = CAbs (clos_of_term u) + | clos_of_term (Computed t) = clos_of_term t + +fun term_of_clos (CVar x) = Var x + | term_of_clos (CConst c) = Const c + | term_of_clos (CApp (u, v)) = App (term_of_clos u, term_of_clos v) + | term_of_clos (CAbs u) = Abs (term_of_clos u) + | term_of_clos (Closure (e, u)) = raise (Run "internal error: closure in normalized term found") + | term_of_clos CDummy = raise (Run "internal error: dummy in normalized term found") + +fun resolve_closure closures (CVar x) = (case List.nth (closures, x) of CDummy => CVar x | r => r) + | resolve_closure closures (CConst c) = CConst c + | resolve_closure closures (CApp (u, v)) = CApp (resolve_closure closures u, resolve_closure closures v) + | resolve_closure closures (CAbs u) = CAbs (resolve_closure (CDummy::closures) u) + | resolve_closure closures (CDummy) = raise (Run "internal error: resolve_closure applied to CDummy") + | resolve_closure closures (Closure (e, u)) = resolve_closure e u + +fun resolve_closure' c = resolve_closure [] c + +fun resolve_stack tm SEmpty = tm + | resolve_stack tm (SAppL (c, s)) = resolve_stack (CApp (tm, resolve_closure' c)) s + | resolve_stack tm (SAppR (c, s)) = resolve_stack (CApp (resolve_closure' c, tm)) s + | resolve_stack tm (SAbs s) = resolve_stack (CAbs tm) s + +fun resolve (stack, closure) = + let + val _ = writeln "start resolving" + val t = resolve_stack (resolve_closure' closure) stack + val _ = writeln "finished resolving" + in + t + end + +fun strip_closure args (CApp (a,b)) = strip_closure (b::args) a + | strip_closure args x = (x, args) + +fun len_head_of_closure n (CApp (a,b)) = len_head_of_closure (n+1) a + | len_head_of_closure n x = (n, x) + + +(* earlier occurrence of PVar corresponds to higher de Bruijn index *) +fun pattern_match args PVar clos = SOME (clos::args) + | pattern_match args (PConst (c, patterns)) clos = + let + val (f, closargs) = strip_closure [] clos + in + case f of + CConst d => + if c = d then + pattern_match_list args patterns closargs + else + NONE + | _ => NONE + end +and pattern_match_list args [] [] = SOME args + | pattern_match_list args (p::ps) (c::cs) = + (case pattern_match args p c of + NONE => NONE + | SOME args => pattern_match_list args ps cs) + | pattern_match_list _ _ _ = NONE + +fun count_patternvars PVar = 1 + | count_patternvars (PConst (_, ps)) = List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps + +fun pattern_key (PConst (c, ps)) = (c, length ps) + | pattern_key _ = raise (Compile "pattern reduces to variable") + +(*Returns true iff at most 0 .. (free-1) occur unbound. therefore + check_freevars 0 t iff t is closed*) +fun check_freevars free (Var x) = x < free + | check_freevars free (Const c) = true + | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v + | check_freevars free (Abs m) = check_freevars (free+1) m + | check_freevars free (Computed t) = check_freevars free t + +fun compile cache_patterns const_arity eqs = + let + fun check p r = if check_freevars p r then () else raise Compile ("unbound variables in rule") + fun check_guard p (Guard (a,b)) = (check p a; check p b) + fun clos_of_guard (Guard (a,b)) = (clos_of_term a, clos_of_term b) + val eqs = map (fn (guards, p, r) => let val pcount = count_patternvars p val _ = map (check_guard pcount) (guards) val _ = check pcount r in + (pattern_key p, (p, clos_of_term r, map clos_of_guard guards)) end) eqs + fun merge (k, a) table = prog_struct.update (k, case prog_struct.lookup table k of NONE => [a] | SOME l => a::l) table + val p = fold merge eqs prog_struct.empty + in + Program p + end + + +type state = bool * program * stack * closure + +datatype loopstate = Continue of state | Stop of stack * closure + +fun proj_C (Continue s) = s + | proj_C _ = raise Match + +exception InterruptedExecution of stack * closure + +fun proj_S (Stop s) = s + | proj_S (Continue (_,_,s,c)) = (s,c) + +fun cont (Continue _) = true + | cont _ = false + +val max_reductions = Unsynchronized.ref (NONE : int option) + +fun do_reduction reduce p = + let + val s = Unsynchronized.ref (Continue p) + val counter = Unsynchronized.ref 0 + val _ = case !max_reductions of + NONE => while cont (!s) do (s := reduce (proj_C (!s))) + | SOME m => while cont (!s) andalso (!counter < m) do (s := reduce (proj_C (!s)); counter := (!counter) + 1) + in + case !max_reductions of + SOME m => if !counter >= m then raise InterruptedExecution (proj_S (!s)) else proj_S (!s) + | NONE => proj_S (!s) + end + +fun match_rules prog n [] clos = NONE + | match_rules prog n ((p,eq,guards)::rs) clos = + case pattern_match [] p clos of + NONE => match_rules prog (n+1) rs clos + | SOME args => if forall (guard_checks prog args) guards then SOME (Closure (args, eq)) else match_rules prog (n+1) rs clos +and guard_checks prog args (a,b) = (simp prog (Closure (args, a)) = simp prog (Closure (args, b))) +and match_closure (p as (Program prog)) clos = + case len_head_of_closure 0 clos of + (len, CConst c) => + (case prog_struct.lookup prog (c, len) of + NONE => NONE + | SOME rules => match_rules p 0 rules clos) + | _ => NONE + +and weak_reduce (false, prog, stack, Closure (e, CApp (a, b))) = Continue (false, prog, SAppL (Closure (e, b), stack), Closure (e, a)) + | weak_reduce (false, prog, SAppL (b, stack), Closure (e, CAbs m)) = Continue (false, prog, stack, Closure (b::e, m)) + | weak_reduce (false, prog, stack, Closure (e, CVar n)) = Continue (false, prog, stack, case List.nth (e, n) of CDummy => CVar n | r => r) + | weak_reduce (false, prog, stack, Closure (e, c as CConst _)) = Continue (false, prog, stack, c) + | weak_reduce (false, prog, stack, clos) = + (case match_closure prog clos of + NONE => Continue (true, prog, stack, clos) + | SOME r => Continue (false, prog, stack, r)) + | weak_reduce (true, prog, SAppR (a, stack), b) = Continue (false, prog, stack, CApp (a,b)) + | weak_reduce (true, prog, s as (SAppL (b, stack)), a) = Continue (false, prog, SAppR (a, stack), b) + | weak_reduce (true, prog, stack, c) = Stop (stack, c) + +and strong_reduce (false, prog, stack, Closure (e, CAbs m)) = + (let + val (stack', wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure (CDummy::e, m)) + in + case stack' of + SEmpty => Continue (false, prog, SAbs stack, wnf) + | _ => raise (Run "internal error in strong: weak failed") + end handle InterruptedExecution state => raise InterruptedExecution (stack, resolve state)) + | strong_reduce (false, prog, stack, clos as (CApp (u, v))) = Continue (false, prog, SAppL (v, stack), u) + | strong_reduce (false, prog, stack, clos) = Continue (true, prog, stack, clos) + | strong_reduce (true, prog, SAbs stack, m) = Continue (false, prog, stack, CAbs m) + | strong_reduce (true, prog, SAppL (b, stack), a) = Continue (false, prog, SAppR (a, stack), b) + | strong_reduce (true, prog, SAppR (a, stack), b) = Continue (true, prog, stack, CApp (a, b)) + | strong_reduce (true, prog, stack, clos) = Stop (stack, clos) + +and simp prog t = + (let + val (stack, wnf) = do_reduction weak_reduce (false, prog, SEmpty, t) + in + case stack of + SEmpty => (case do_reduction strong_reduce (false, prog, SEmpty, wnf) of + (SEmpty, snf) => snf + | _ => raise (Run "internal error in run: strong failed")) + | _ => raise (Run "internal error in run: weak failed") + end handle InterruptedExecution state => resolve state) + + +fun run prog t = + (let + val (stack, wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure ([], clos_of_term t)) + in + case stack of + SEmpty => (case do_reduction strong_reduce (false, prog, SEmpty, wnf) of + (SEmpty, snf) => term_of_clos snf + | _ => raise (Run "internal error in run: strong failed")) + | _ => raise (Run "internal error in run: weak failed") + end handle InterruptedExecution state => term_of_clos (resolve state)) + +fun discard p = () + +end diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/Compute_Oracle/am_sml.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Matrix/Compute_Oracle/am_sml.ML Wed Jul 21 15:44:36 2010 +0200 @@ -0,0 +1,548 @@ +(* Title: Tools/Compute_Oracle/am_sml.ML + Author: Steven Obua + +TODO: "parameterless rewrite cannot be used in pattern": In a lot of +cases it CAN be used, and these cases should be handled +properly; right now, all cases raise an exception. +*) + +signature AM_SML = +sig + include ABSTRACT_MACHINE + val save_result : (string * term) -> unit + val set_compiled_rewriter : (term -> term) -> unit + val list_nth : 'a list * int -> 'a + val dump_output : (string option) Unsynchronized.ref +end + +structure AM_SML : AM_SML = struct + +open AbstractMachine; + +val dump_output = Unsynchronized.ref (NONE: string option) + +type program = string * string * (int Inttab.table) * (int Inttab.table) * (term Inttab.table) * (term -> term) + +val saved_result = Unsynchronized.ref (NONE:(string*term)option) + +fun save_result r = (saved_result := SOME r) +fun clear_result () = (saved_result := NONE) + +val list_nth = List.nth + +(*fun list_nth (l,n) = (writeln (makestring ("list_nth", (length l,n))); List.nth (l,n))*) + +val compiled_rewriter = Unsynchronized.ref (NONE:(term -> term)Option.option) + +fun set_compiled_rewriter r = (compiled_rewriter := SOME r) + +fun count_patternvars PVar = 1 + | count_patternvars (PConst (_, ps)) = + List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps + +fun update_arity arity code a = + (case Inttab.lookup arity code of + NONE => Inttab.update_new (code, a) arity + | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity) + +(* We have to find out the maximal arity of each constant *) +fun collect_pattern_arity PVar arity = arity + | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args)) + +(* We also need to find out the maximal toplevel arity of each function constant *) +fun collect_pattern_toplevel_arity PVar arity = raise Compile "internal error: collect_pattern_toplevel_arity" + | collect_pattern_toplevel_arity (PConst (c, args)) arity = update_arity arity c (length args) + +local +fun collect applevel (Var _) arity = arity + | collect applevel (Const c) arity = update_arity arity c applevel + | collect applevel (Abs m) arity = collect 0 m arity + | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity) +in +fun collect_term_arity t arity = collect 0 t arity +end + +fun collect_guard_arity (Guard (a,b)) arity = collect_term_arity b (collect_term_arity a arity) + + +fun rep n x = if n < 0 then raise Compile "internal error: rep" else if n = 0 then [] else x::(rep (n-1) x) + +fun beta (Const c) = Const c + | beta (Var i) = Var i + | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b))) + | beta (App (a, b)) = + (case beta a of + Abs m => beta (App (Abs m, b)) + | a => App (a, beta b)) + | beta (Abs m) = Abs (beta m) + | beta (Computed t) = Computed t +and subst x (Const c) t = Const c + | subst x (Var i) t = if i = x then t else Var i + | subst x (App (a,b)) t = App (subst x a t, subst x b t) + | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t)) +and lift level (Const c) = Const c + | lift level (App (a,b)) = App (lift level a, lift level b) + | lift level (Var i) = if i < level then Var i else Var (i+1) + | lift level (Abs m) = Abs (lift (level + 1) m) +and unlift level (Const c) = Const c + | unlift level (App (a, b)) = App (unlift level a, unlift level b) + | unlift level (Abs m) = Abs (unlift (level+1) m) + | unlift level (Var i) = if i < level then Var i else Var (i-1) + +fun nlift level n (Var m) = if m < level then Var m else Var (m+n) + | nlift level n (Const c) = Const c + | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b) + | nlift level n (Abs b) = Abs (nlift (level+1) n b) + +fun subst_const (c, t) (Const c') = if c = c' then t else Const c' + | subst_const _ (Var i) = Var i + | subst_const ct (App (a, b)) = App (subst_const ct a, subst_const ct b) + | subst_const ct (Abs m) = Abs (subst_const ct m) + +(* Remove all rules that are just parameterless rewrites. This is necessary because SML does not allow functions with no parameters. *) +fun inline_rules rules = + let + fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b + | term_contains_const c (Abs m) = term_contains_const c m + | term_contains_const c (Var i) = false + | term_contains_const c (Const c') = (c = c') + fun find_rewrite [] = NONE + | find_rewrite ((prems, PConst (c, []), r) :: _) = + if check_freevars 0 r then + if term_contains_const c r then + raise Compile "parameterless rewrite is caught in cycle" + else if not (null prems) then + raise Compile "parameterless rewrite may not be guarded" + else + SOME (c, r) + else raise Compile "unbound variable on right hand side or guards of rule" + | find_rewrite (_ :: rules) = find_rewrite rules + fun remove_rewrite (c,r) [] = [] + | remove_rewrite (cr as (c,r)) ((rule as (prems', PConst (c', args), r'))::rules) = + (if c = c' then + if null args andalso r = r' andalso null (prems') then + remove_rewrite cr rules + else raise Compile "incompatible parameterless rewrites found" + else + rule :: (remove_rewrite cr rules)) + | remove_rewrite cr (r::rs) = r::(remove_rewrite cr rs) + fun pattern_contains_const c (PConst (c', args)) = (c = c' orelse exists (pattern_contains_const c) args) + | pattern_contains_const c (PVar) = false + fun inline_rewrite (ct as (c, _)) (prems, p, r) = + if pattern_contains_const c p then + raise Compile "parameterless rewrite cannot be used in pattern" + else (map (fn (Guard (a,b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r) + fun inline inlined rules = + (case find_rewrite rules of + NONE => (Inttab.make inlined, rules) + | SOME ct => + let + val rules = map (inline_rewrite ct) (remove_rewrite ct rules) + val inlined = ct :: (map (fn (c', r) => (c', subst_const ct r)) inlined) + in + inline inlined rules + end) + in + inline [] rules + end + + +(* + Calculate the arity, the toplevel_arity, and adjust rules so that all toplevel pattern constants have maximal arity. + Also beta reduce the adjusted right hand side of a rule. +*) +fun adjust_rules rules = + let + val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty + val toplevel_arity = fold (fn (_, p, t) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty + fun arity_of c = the (Inttab.lookup arity c) + fun toplevel_arity_of c = the (Inttab.lookup toplevel_arity c) + fun test_pattern PVar = () + | test_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else (map test_pattern args; ()) + fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable") + | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters") + | adjust_rule (rule as (prems, p as PConst (c, args),t)) = + let + val patternvars_counted = count_patternvars p + fun check_fv t = check_freevars patternvars_counted t + val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else () + val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else () + val _ = map test_pattern args + val len = length args + val arity = arity_of c + val lift = nlift 0 + fun addapps_tm n t = if n=0 then t else addapps_tm (n-1) (App (t, Var (n-1))) + fun adjust_term n t = addapps_tm n (lift n t) + fun adjust_guard n (Guard (a,b)) = Guard (lift n a, lift n b) + in + if len = arity then + rule + else if arity >= len then + (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t) + else (raise Compile "internal error in adjust_rule") + end + fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule") + in + (arity, toplevel_arity, map (beta_rule o adjust_rule) rules) + end + +fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count = +let + fun str x = string_of_int x + fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s + val module_prefix = (case module of NONE => "" | SOME s => s^".") + fun print_apps d f [] = f + | print_apps d f (a::args) = print_apps d (module_prefix^"app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args + and print_call d (App (a, b)) args = print_call d a (b::args) + | print_call d (Const c) args = + (case arity_of c of + NONE => print_apps d (module_prefix^"Const "^(str c)) args + | SOME 0 => module_prefix^"C"^(str c) + | SOME a => + let + val len = length args + in + if a <= len then + let + val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a) + val _ = if strict_a > a then raise Compile "strict" else () + val s = module_prefix^"c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, strict_a)))) + val s = s^(implode (map (fn t => " (fn () => "^print_term d t^")") (List.drop (List.take (args, a), strict_a)))) + in + print_apps d s (List.drop (args, a)) + end + else + let + fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n - 1))) + fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t) + fun append_args [] t = t + | append_args (c::cs) t = append_args cs (App (t, c)) + in + print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c))))) + end + end) + | print_call d t args = print_apps d (print_term d t) args + and print_term d (Var x) = + if x < d then + "b"^(str (d-x-1)) + else + let + val n = pattern_var_count - (x-d) - 1 + val x = "x"^(str n) + in + if n < pattern_var_count - pattern_lazy_var_count then + x + else + "("^x^" ())" + end + | print_term d (Abs c) = module_prefix^"Abs (fn b"^(str d)^" => "^(print_term (d + 1) c)^")" + | print_term d t = print_call d t [] +in + print_term 0 +end + +fun section n = if n = 0 then [] else (section (n-1))@[n-1] + +fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) = + let + fun str x = Int.toString x + fun print_pattern top n PVar = (n+1, "x"^(str n)) + | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")) + | print_pattern top n (PConst (c, args)) = + let + val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "") + val (n, s) = print_pattern_list 0 top (n, f) args + in + (n, s) + end + and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")") + | print_pattern_list' counter top (n, p) (t::ts) = + let + val (n, t) = print_pattern false n t + in + print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts + end + and print_pattern_list counter top (n, p) (t::ts) = + let + val (n, t) = print_pattern false n t + in + print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts + end + val c = (case p of PConst (c, _) => c | _ => raise Match) + val (n, pattern) = print_pattern true 0 p + val lazy_vars = the (arity_of c) - the (toplevel_arity_of c) + fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm + fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")" + val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(implode (map (fn i => " a"^(str i)) (section (the (arity_of c))))) + fun print_guards t [] = print_tm t + | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(implode (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch + in + (if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards)) + end + +fun group_rules rules = + let + fun add_rule (r as (_, PConst (c,_), _)) groups = + let + val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs) + in + Inttab.update (c, r::rs) groups + end + | add_rule _ _ = raise Compile "internal error group_rules" + in + fold_rev add_rule rules Inttab.empty + end + +fun sml_prog name code rules = + let + val buffer = Unsynchronized.ref "" + fun write s = (buffer := (!buffer)^s) + fun writeln s = (write s; write "\n") + fun writelist [] = () + | writelist (s::ss) = (writeln s; writelist ss) + fun str i = Int.toString i + val (inlinetab, rules) = inline_rules rules + val (arity, toplevel_arity, rules) = adjust_rules rules + val rules = group_rules rules + val constants = Inttab.keys arity + fun arity_of c = Inttab.lookup arity c + fun toplevel_arity_of c = Inttab.lookup toplevel_arity c + fun rep_str s n = implode (rep n s) + fun indexed s n = s^(str n) + fun string_of_tuple [] = "" + | string_of_tuple (x::xs) = "("^x^(implode (map (fn s => ", "^s) xs))^")" + fun string_of_args [] = "" + | string_of_args (x::xs) = x^(implode (map (fn s => " "^s) xs)) + fun default_case gnum c = + let + val leftargs = implode (map (indexed " x") (section (the (arity_of c)))) + val rightargs = section (the (arity_of c)) + val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa) + val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs + val right = (indexed "C" c)^" "^(string_of_tuple xs) + val message = "(\"unresolved lazy call: " ^ string_of_int c ^ "\")" + val right = if strict_args < the (arity_of c) then "raise AM_SML.Run "^message else right + in + (indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right + end + + fun eval_rules c = + let + val arity = the (arity_of c) + val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa) + fun eval_rule n = + let + val sc = string_of_int c + val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc) + fun arg i = + let + val x = indexed "x" i + val x = if i < n then "(eval bounds "^x^")" else x + val x = if i < strict_arity then x else "(fn () => "^x^")" + in + x + end + val right = "c"^sc^" "^(string_of_args (map arg (section arity))) + val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right + val right = if arity > 0 then right else "C"^sc + in + " | eval bounds ("^left^") = "^right + end + in + map eval_rule (rev (section (arity + 1))) + end + + fun convert_computed_rules (c: int) : string list = + let + val arity = the (arity_of c) + fun eval_rule () = + let + val sc = string_of_int c + val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section arity) ("AbstractMachine.Const "^sc) + fun arg i = "(convert_computed "^(indexed "x" i)^")" + val right = "C"^sc^" "^(string_of_tuple (map arg (section arity))) + val right = if arity > 0 then right else "C"^sc + in + " | convert_computed ("^left^") = "^right + end + in + [eval_rule ()] + end + + fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else "" + val _ = writelist [ + "structure "^name^" = struct", + "", + "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)", + " "^(implode (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)), + ""] + fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")" + fun make_term_eq c = " | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^ + (case the (arity_of c) of + 0 => "true" + | n => + let + val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n) + val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs)) + in + eq^(implode eqs) + end) + val _ = writelist [ + "fun term_eq (Const c1) (Const c2) = (c1 = c2)", + " | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"] + val _ = writelist (map make_term_eq constants) + val _ = writelist [ + " | term_eq _ _ = false", + "" + ] + val _ = writelist [ + "fun app (Abs a) b = a b", + " | app a b = App (a, b)", + ""] + fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else []) + fun writefundecl [] = () + | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => " | "^s) xs))) + fun list_group c = (case Inttab.lookup rules c of + NONE => [defcase 0 c] + | SOME rs => + let + val rs = + fold + (fn r => + fn rs => + let + val (gnum, l, rs) = + (case rs of + [] => (0, [], []) + | (gnum, l)::rs => (gnum, l, rs)) + val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r + in + if gnum' = gnum then + (gnum, r::l)::rs + else + let + val args = implode (map (fn i => " a"^(str i)) (section (the (arity_of c)))) + fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args + val s = gnumc (gnum) ^ " = " ^ gnumc (gnum') + in + (gnum', [])::(gnum, s::r::l)::rs + end + end) + rs [] + val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs) + in + rev (map (fn z => rev (snd z)) rs) + end) + val _ = map (fn z => (map writefundecl z; writeln "")) (map list_group constants) + val _ = writelist [ + "fun convert (Const i) = AM_SML.Const i", + " | convert (App (a, b)) = AM_SML.App (convert a, convert b)", + " | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""] + fun make_convert c = + let + val args = map (indexed "a") (section (the (arity_of c))) + val leftargs = + case args of + [] => "" + | (x::xs) => "("^x^(implode (map (fn s => ", "^s) xs))^")" + val args = map (indexed "convert a") (section (the (arity_of c))) + val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c)) + in + " | convert (C"^(str c)^" "^leftargs^") = "^right + end + val _ = writelist (map make_convert constants) + val _ = writelist [ + "", + "fun convert_computed (AbstractMachine.Abs b) = raise AM_SML.Run \"no abstraction in convert_computed allowed\"", + " | convert_computed (AbstractMachine.Var i) = raise AM_SML.Run \"no bound variables in convert_computed allowed\""] + val _ = map (writelist o convert_computed_rules) constants + val _ = writelist [ + " | convert_computed (AbstractMachine.Const c) = Const c", + " | convert_computed (AbstractMachine.App (a, b)) = App (convert_computed a, convert_computed b)", + " | convert_computed (AbstractMachine.Computed a) = raise AM_SML.Run \"no nesting in convert_computed allowed\""] + val _ = writelist [ + "", + "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)", + " | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"] + val _ = map (writelist o eval_rules) constants + val _ = writelist [ + " | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)", + " | eval bounds (AbstractMachine.Const c) = Const c", + " | eval bounds (AbstractMachine.Computed t) = convert_computed t"] + val _ = writelist [ + "", + "fun export term = AM_SML.save_result (\""^code^"\", convert term)", + "", + "val _ = AM_SML.set_compiled_rewriter (fn t => (convert (eval [] t)))", + "", + "end"] + in + (arity, toplevel_arity, inlinetab, !buffer) + end + +val guid_counter = Unsynchronized.ref 0 +fun get_guid () = + let + val c = !guid_counter + val _ = guid_counter := !guid_counter + 1 + in + (LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c) + end + + +fun writeTextFile name s = File.write (Path.explode name) s + +fun use_source src = use_text ML_Env.local_context (1, "") false src + +fun compile cache_patterns const_arity eqs = + let + val guid = get_guid () + val code = Real.toString (random ()) + val module = "AMSML_"^guid + val (arity, toplevel_arity, inlinetab, source) = sml_prog module code eqs + val _ = case !dump_output of NONE => () | SOME p => writeTextFile p source + val _ = compiled_rewriter := NONE + val _ = use_source source + in + case !compiled_rewriter of + NONE => raise Compile "broken link to compiled function" + | SOME f => (module, code, arity, toplevel_arity, inlinetab, f) + end + + +fun run' (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t = + let + val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") + fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t) + | inline (Var i) = Var i + | inline (App (a, b)) = App (inline a, inline b) + | inline (Abs m) = Abs (inline m) + val t = beta (inline t) + fun arity_of c = Inttab.lookup arity c + fun toplevel_arity_of c = Inttab.lookup toplevel_arity c + val term = print_term NONE arity_of toplevel_arity_of 0 0 t + val source = "local open "^module^" in val _ = export ("^term^") end" + val _ = writeTextFile "Gencode_call.ML" source + val _ = clear_result () + val _ = use_source source + in + case !saved_result of + NONE => raise Run "broken link to compiled code" + | SOME (code', t) => (clear_result (); if code' = code then t else raise Run "link to compiled code was hijacked") + end + +fun run (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t = + let + val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") + fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t) + | inline (Var i) = Var i + | inline (App (a, b)) = App (inline a, inline b) + | inline (Abs m) = Abs (inline m) + | inline (Computed t) = Computed t + in + compiled_fun (beta (inline t)) + end + +fun discard p = () + +end diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/Compute_Oracle/compute.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Matrix/Compute_Oracle/compute.ML Wed Jul 21 15:44:36 2010 +0200 @@ -0,0 +1,683 @@ +(* Title: Tools/Compute_Oracle/compute.ML + Author: Steven Obua +*) + +signature COMPUTE = sig + + type computer + type theorem + type naming = int -> string + + datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML + + (* Functions designated with a ! in front of them actually update the computer parameter *) + + exception Make of string + val make : machine -> theory -> thm list -> computer + val make_with_cache : machine -> theory -> term list -> thm list -> computer + val theory_of : computer -> theory + val hyps_of : computer -> term list + val shyps_of : computer -> sort list + (* ! *) val update : computer -> thm list -> unit + (* ! *) val update_with_cache : computer -> term list -> thm list -> unit + (* ! *) val discard : computer -> unit + + (* ! *) val set_naming : computer -> naming -> unit + val naming_of : computer -> naming + + exception Compute of string + val simplify : computer -> theorem -> thm + val rewrite : computer -> cterm -> thm + + val make_theorem : computer -> thm -> string list -> theorem + (* ! *) val instantiate : computer -> (string * cterm) list -> theorem -> theorem + (* ! *) val evaluate_prem : computer -> int -> theorem -> theorem + (* ! *) val modus_ponens : computer -> int -> thm -> theorem -> theorem + +end + +structure Compute :> COMPUTE = struct + +open Report; + +datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML + +(* Terms are mapped to integer codes *) +structure Encode :> +sig + type encoding + val empty : encoding + val insert : term -> encoding -> int * encoding + val lookup_code : term -> encoding -> int option + val lookup_term : int -> encoding -> term option + val remove_code : int -> encoding -> encoding + val remove_term : term -> encoding -> encoding + val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a +end += +struct + +type encoding = int * (int Termtab.table) * (term Inttab.table) + +val empty = (0, Termtab.empty, Inttab.empty) + +fun insert t (e as (count, term2int, int2term)) = + (case Termtab.lookup term2int t of + NONE => (count, (count+1, Termtab.update_new (t, count) term2int, Inttab.update_new (count, t) int2term)) + | SOME code => (code, e)) + +fun lookup_code t (_, term2int, _) = Termtab.lookup term2int t + +fun lookup_term c (_, _, int2term) = Inttab.lookup int2term c + +fun remove_code c (e as (count, term2int, int2term)) = + (case lookup_term c e of NONE => e | SOME t => (count, Termtab.delete t term2int, Inttab.delete c int2term)) + +fun remove_term t (e as (count, term2int, int2term)) = + (case lookup_code t e of NONE => e | SOME c => (count, Termtab.delete t term2int, Inttab.delete c int2term)) + +fun fold f (_, term2int, _) = Termtab.fold f term2int + +end + +exception Make of string; +exception Compute of string; + +local + fun make_constant t ty encoding = + let + val (code, encoding) = Encode.insert t encoding + in + (encoding, AbstractMachine.Const code) + end +in + +fun remove_types encoding t = + case t of + Var (_, ty) => make_constant t ty encoding + | Free (_, ty) => make_constant t ty encoding + | Const (_, ty) => make_constant t ty encoding + | Abs (_, ty, t') => + let val (encoding, t'') = remove_types encoding t' in + (encoding, AbstractMachine.Abs t'') + end + | a $ b => + let + val (encoding, a) = remove_types encoding a + val (encoding, b) = remove_types encoding b + in + (encoding, AbstractMachine.App (a,b)) + end + | Bound b => (encoding, AbstractMachine.Var b) +end + +local + fun type_of (Free (_, ty)) = ty + | type_of (Const (_, ty)) = ty + | type_of (Var (_, ty)) = ty + | type_of _ = sys_error "infer_types: type_of error" +in +fun infer_types naming encoding = + let + fun infer_types _ bounds _ (AbstractMachine.Var v) = (Bound v, List.nth (bounds, v)) + | infer_types _ bounds _ (AbstractMachine.Const code) = + let + val c = the (Encode.lookup_term code encoding) + in + (c, type_of c) + end + | infer_types level bounds _ (AbstractMachine.App (a, b)) = + let + val (a, aty) = infer_types level bounds NONE a + val (adom, arange) = + case aty of + Type ("fun", [dom, range]) => (dom, range) + | _ => sys_error "infer_types: function type expected" + val (b, bty) = infer_types level bounds (SOME adom) b + in + (a $ b, arange) + end + | infer_types level bounds (SOME (ty as Type ("fun", [dom, range]))) (AbstractMachine.Abs m) = + let + val (m, _) = infer_types (level+1) (dom::bounds) (SOME range) m + in + (Abs (naming level, dom, m), ty) + end + | infer_types _ _ NONE (AbstractMachine.Abs m) = sys_error "infer_types: cannot infer type of abstraction" + + fun infer ty term = + let + val (term', _) = infer_types 0 [] (SOME ty) term + in + term' + end + in + infer + end +end + +datatype prog = + ProgBarras of AM_Interpreter.program + | ProgBarrasC of AM_Compiler.program + | ProgHaskell of AM_GHC.program + | ProgSML of AM_SML.program + +fun machine_of_prog (ProgBarras _) = BARRAS + | machine_of_prog (ProgBarrasC _) = BARRAS_COMPILED + | machine_of_prog (ProgHaskell _) = HASKELL + | machine_of_prog (ProgSML _) = SML + +type naming = int -> string + +fun default_naming i = "v_" ^ Int.toString i + +datatype computer = Computer of + (theory_ref * Encode.encoding * term list * unit Sorttab.table * prog * unit Unsynchronized.ref * naming) + option Unsynchronized.ref + +fun theory_of (Computer (Unsynchronized.ref (SOME (rthy,_,_,_,_,_,_)))) = Theory.deref rthy +fun hyps_of (Computer (Unsynchronized.ref (SOME (_,_,hyps,_,_,_,_)))) = hyps +fun shyps_of (Computer (Unsynchronized.ref (SOME (_,_,_,shyptable,_,_,_)))) = Sorttab.keys (shyptable) +fun shyptab_of (Computer (Unsynchronized.ref (SOME (_,_,_,shyptable,_,_,_)))) = shyptable +fun stamp_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,_,stamp,_)))) = stamp +fun prog_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,prog,_,_)))) = prog +fun encoding_of (Computer (Unsynchronized.ref (SOME (_,encoding,_,_,_,_,_)))) = encoding +fun set_encoding (Computer (r as Unsynchronized.ref (SOME (p1,encoding,p2,p3,p4,p5,p6)))) encoding' = + (r := SOME (p1,encoding',p2,p3,p4,p5,p6)) +fun naming_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,_,_,n)))) = n +fun set_naming (Computer (r as Unsynchronized.ref (SOME (p1,p2,p3,p4,p5,p6,naming)))) naming'= + (r := SOME (p1,p2,p3,p4,p5,p6,naming')) + +fun ref_of (Computer r) = r + +datatype cthm = ComputeThm of term list * sort list * term + +fun thm2cthm th = + let + val {hyps, prop, tpairs, shyps, ...} = Thm.rep_thm th + val _ = if not (null tpairs) then raise Make "theorems may not contain tpairs" else () + in + ComputeThm (hyps, shyps, prop) + end + +fun make_internal machine thy stamp encoding cache_pattern_terms raw_ths = + let + fun transfer (x:thm) = Thm.transfer thy x + val ths = map (thm2cthm o Thm.strip_shyps o transfer) raw_ths + + fun make_pattern encoding n vars (var as AbstractMachine.Abs _) = + raise (Make "no lambda abstractions allowed in pattern") + | make_pattern encoding n vars (var as AbstractMachine.Var _) = + raise (Make "no bound variables allowed in pattern") + | make_pattern encoding n vars (AbstractMachine.Const code) = + (case the (Encode.lookup_term code encoding) of + Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar) + handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed")) + | _ => (n, vars, AbstractMachine.PConst (code, []))) + | make_pattern encoding n vars (AbstractMachine.App (a, b)) = + let + val (n, vars, pa) = make_pattern encoding n vars a + val (n, vars, pb) = make_pattern encoding n vars b + in + case pa of + AbstractMachine.PVar => + raise (Make "patterns may not start with a variable") + | AbstractMachine.PConst (c, args) => + (n, vars, AbstractMachine.PConst (c, args@[pb])) + end + + fun thm2rule (encoding, hyptable, shyptable) th = + let + val (ComputeThm (hyps, shyps, prop)) = th + val hyptable = fold (fn h => Termtab.update (h, ())) hyps hyptable + val shyptable = fold (fn sh => Sorttab.update (sh, ())) shyps shyptable + val (prems, prop) = (Logic.strip_imp_prems prop, Logic.strip_imp_concl prop) + val (a, b) = Logic.dest_equals prop + handle TERM _ => raise (Make "theorems must be meta-level equations (with optional guards)") + val a = Envir.eta_contract a + val b = Envir.eta_contract b + val prems = map Envir.eta_contract prems + + val (encoding, left) = remove_types encoding a + val (encoding, right) = remove_types encoding b + fun remove_types_of_guard encoding g = + (let + val (t1, t2) = Logic.dest_equals g + val (encoding, t1) = remove_types encoding t1 + val (encoding, t2) = remove_types encoding t2 + in + (encoding, AbstractMachine.Guard (t1, t2)) + end handle TERM _ => raise (Make "guards must be meta-level equations")) + val (encoding, prems) = fold_rev (fn p => fn (encoding, ps) => let val (e, p) = remove_types_of_guard encoding p in (e, p::ps) end) prems (encoding, []) + + (* Principally, a check should be made here to see if the (meta-) hyps contain any of the variables of the rule. + As it is, all variables of the rule are schematic, and there are no schematic variables in meta-hyps, therefore + this check can be left out. *) + + val (vcount, vars, pattern) = make_pattern encoding 0 Inttab.empty left + val _ = (case pattern of + AbstractMachine.PVar => + raise (Make "patterns may not start with a variable") + (* | AbstractMachine.PConst (_, []) => + (print th; raise (Make "no parameter rewrite found"))*) + | _ => ()) + + (* finally, provide a function for renaming the + pattern bound variables on the right hand side *) + + fun rename level vars (var as AbstractMachine.Var _) = var + | rename level vars (c as AbstractMachine.Const code) = + (case Inttab.lookup vars code of + NONE => c + | SOME n => AbstractMachine.Var (vcount-n-1+level)) + | rename level vars (AbstractMachine.App (a, b)) = + AbstractMachine.App (rename level vars a, rename level vars b) + | rename level vars (AbstractMachine.Abs m) = + AbstractMachine.Abs (rename (level+1) vars m) + + fun rename_guard (AbstractMachine.Guard (a,b)) = + AbstractMachine.Guard (rename 0 vars a, rename 0 vars b) + in + ((encoding, hyptable, shyptable), (map rename_guard prems, pattern, rename 0 vars right)) + end + + val ((encoding, hyptable, shyptable), rules) = + fold_rev (fn th => fn (encoding_hyptable, rules) => + let + val (encoding_hyptable, rule) = thm2rule encoding_hyptable th + in (encoding_hyptable, rule::rules) end) + ths ((encoding, Termtab.empty, Sorttab.empty), []) + + fun make_cache_pattern t (encoding, cache_patterns) = + let + val (encoding, a) = remove_types encoding t + val (_,_,p) = make_pattern encoding 0 Inttab.empty a + in + (encoding, p::cache_patterns) + end + + val (encoding, cache_patterns) = fold_rev make_cache_pattern cache_pattern_terms (encoding, []) + + fun arity (Type ("fun", [a,b])) = 1 + arity b + | arity _ = 0 + + fun make_arity (Const (s, _), i) tab = + (Inttab.update (i, arity (Sign.the_const_type thy s)) tab handle TYPE _ => tab) + | make_arity _ tab = tab + + val const_arity_tab = Encode.fold make_arity encoding Inttab.empty + fun const_arity x = Inttab.lookup const_arity_tab x + + val prog = + case machine of + BARRAS => ProgBarras (AM_Interpreter.compile cache_patterns const_arity rules) + | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile cache_patterns const_arity rules) + | HASKELL => ProgHaskell (AM_GHC.compile cache_patterns const_arity rules) + | SML => ProgSML (AM_SML.compile cache_patterns const_arity rules) + + fun has_witness s = not (null (Sign.witness_sorts thy [] [s])) + + val shyptable = fold Sorttab.delete (filter has_witness (Sorttab.keys (shyptable))) shyptable + + in (Theory.check_thy thy, encoding, Termtab.keys hyptable, shyptable, prog, stamp, default_naming) end + +fun make_with_cache machine thy cache_patterns raw_thms = + Computer (Unsynchronized.ref (SOME (make_internal machine thy (Unsynchronized.ref ()) Encode.empty cache_patterns raw_thms))) + +fun make machine thy raw_thms = make_with_cache machine thy [] raw_thms + +fun update_with_cache computer cache_patterns raw_thms = + let + val c = make_internal (machine_of_prog (prog_of computer)) (theory_of computer) (stamp_of computer) + (encoding_of computer) cache_patterns raw_thms + val _ = (ref_of computer) := SOME c + in + () + end + +fun update computer raw_thms = update_with_cache computer [] raw_thms + +fun discard computer = + let + val _ = + case prog_of computer of + ProgBarras p => AM_Interpreter.discard p + | ProgBarrasC p => AM_Compiler.discard p + | ProgHaskell p => AM_GHC.discard p + | ProgSML p => AM_SML.discard p + val _ = (ref_of computer) := NONE + in + () + end + +fun runprog (ProgBarras p) = AM_Interpreter.run p + | runprog (ProgBarrasC p) = AM_Compiler.run p + | runprog (ProgHaskell p) = AM_GHC.run p + | runprog (ProgSML p) = AM_SML.run p + +(* ------------------------------------------------------------------------------------- *) +(* An oracle for exporting theorems; must only be accessible from inside this structure! *) +(* ------------------------------------------------------------------------------------- *) + +fun merge_hyps hyps1 hyps2 = +let + fun add hyps tab = fold (fn h => fn tab => Termtab.update (h, ()) tab) hyps tab +in + Termtab.keys (add hyps2 (add hyps1 Termtab.empty)) +end + +fun add_shyps shyps tab = fold (fn h => fn tab => Sorttab.update (h, ()) tab) shyps tab + +fun merge_shyps shyps1 shyps2 = Sorttab.keys (add_shyps shyps2 (add_shyps shyps1 Sorttab.empty)) + +val (_, export_oracle) = Context.>>> (Context.map_theory_result + (Thm.add_oracle (Binding.name "compute", fn (thy, hyps, shyps, prop) => + let + val shyptab = add_shyps shyps Sorttab.empty + fun delete s shyptab = Sorttab.delete s shyptab handle Sorttab.UNDEF _ => shyptab + fun delete_term t shyptab = fold delete (Sorts.insert_term t []) shyptab + fun has_witness s = not (null (Sign.witness_sorts thy [] [s])) + val shyptab = fold Sorttab.delete (filter has_witness (Sorttab.keys (shyptab))) shyptab + val shyps = if Sorttab.is_empty shyptab then [] else Sorttab.keys (fold delete_term (prop::hyps) shyptab) + val _ = + if not (null shyps) then + raise Compute ("dangling sort hypotheses: " ^ + commas (map (Syntax.string_of_sort_global thy) shyps)) + else () + in + Thm.cterm_of thy (fold_rev (fn hyp => fn p => Logic.mk_implies (hyp, p)) hyps prop) + end))); + +fun export_thm thy hyps shyps prop = + let + val th = export_oracle (thy, hyps, shyps, prop) + val hyps = map (fn h => Thm.assume (cterm_of thy h)) hyps + in + fold (fn h => fn p => Thm.implies_elim p h) hyps th + end + +(* --------- Rewrite ----------- *) + +fun rewrite computer ct = + let + val thy = Thm.theory_of_cterm ct + val {t=t',T=ty,...} = rep_cterm ct + val _ = Theory.assert_super (theory_of computer) thy + val naming = naming_of computer + val (encoding, t) = remove_types (encoding_of computer) t' + (*val _ = if (!print_encoding) then writeln (makestring ("encoding: ",Encode.fold (fn x => fn s => x::s) encoding [])) else ()*) + val t = runprog (prog_of computer) t + val t = infer_types naming encoding ty t + val eq = Logic.mk_equals (t', t) + in + export_thm thy (hyps_of computer) (Sorttab.keys (shyptab_of computer)) eq + end + +(* --------- Simplify ------------ *) + +datatype prem = EqPrem of AbstractMachine.term * AbstractMachine.term * Term.typ * int + | Prem of AbstractMachine.term +datatype theorem = Theorem of theory_ref * unit Unsynchronized.ref * (int * typ) Symtab.table * (AbstractMachine.term option) Inttab.table + * prem list * AbstractMachine.term * term list * sort list + + +exception ParamSimplify of computer * theorem + +fun make_theorem computer th vars = +let + val _ = Theory.assert_super (theory_of computer) (theory_of_thm th) + + val (ComputeThm (hyps, shyps, prop)) = thm2cthm th + + val encoding = encoding_of computer + + (* variables in the theorem are identified upfront *) + fun collect_vars (Abs (_, _, t)) tab = collect_vars t tab + | collect_vars (a $ b) tab = collect_vars b (collect_vars a tab) + | collect_vars (Const _) tab = tab + | collect_vars (Free _) tab = tab + | collect_vars (Var ((s, i), ty)) tab = + if List.find (fn x => x=s) vars = NONE then + tab + else + (case Symtab.lookup tab s of + SOME ((s',i'),ty') => + if s' <> s orelse i' <> i orelse ty <> ty' then + raise Compute ("make_theorem: variable name '"^s^"' is not unique") + else + tab + | NONE => Symtab.update (s, ((s, i), ty)) tab) + val vartab = collect_vars prop Symtab.empty + fun encodevar (s, t as (_, ty)) (encoding, tab) = + let + val (x, encoding) = Encode.insert (Var t) encoding + in + (encoding, Symtab.update (s, (x, ty)) tab) + end + val (encoding, vartab) = Symtab.fold encodevar vartab (encoding, Symtab.empty) + val varsubst = Inttab.make (map (fn (s, (x, _)) => (x, NONE)) (Symtab.dest vartab)) + + (* make the premises and the conclusion *) + fun mk_prem encoding t = + (let + val (a, b) = Logic.dest_equals t + val ty = type_of a + val (encoding, a) = remove_types encoding a + val (encoding, b) = remove_types encoding b + val (eq, encoding) = Encode.insert (Const ("==", ty --> ty --> @{typ "prop"})) encoding + in + (encoding, EqPrem (a, b, ty, eq)) + end handle TERM _ => let val (encoding, t) = remove_types encoding t in (encoding, Prem t) end) + val (encoding, prems) = + (fold_rev (fn t => fn (encoding, l) => + case mk_prem encoding t of + (encoding, t) => (encoding, t::l)) (Logic.strip_imp_prems prop) (encoding, [])) + val (encoding, concl) = remove_types encoding (Logic.strip_imp_concl prop) + val _ = set_encoding computer encoding +in + Theorem (Theory.check_thy (theory_of_thm th), stamp_of computer, vartab, varsubst, + prems, concl, hyps, shyps) +end + +fun theory_of_theorem (Theorem (rthy,_,_,_,_,_,_,_)) = Theory.deref rthy +fun update_theory thy (Theorem (_,p0,p1,p2,p3,p4,p5,p6)) = + Theorem (Theory.check_thy thy,p0,p1,p2,p3,p4,p5,p6) +fun stamp_of_theorem (Theorem (_,s, _, _, _, _, _, _)) = s +fun vartab_of_theorem (Theorem (_,_,vt,_,_,_,_,_)) = vt +fun varsubst_of_theorem (Theorem (_,_,_,vs,_,_,_,_)) = vs +fun update_varsubst vs (Theorem (p0,p1,p2,_,p3,p4,p5,p6)) = Theorem (p0,p1,p2,vs,p3,p4,p5,p6) +fun prems_of_theorem (Theorem (_,_,_,_,prems,_,_,_)) = prems +fun update_prems prems (Theorem (p0,p1,p2,p3,_,p4,p5,p6)) = Theorem (p0,p1,p2,p3,prems,p4,p5,p6) +fun concl_of_theorem (Theorem (_,_,_,_,_,concl,_,_)) = concl +fun hyps_of_theorem (Theorem (_,_,_,_,_,_,hyps,_)) = hyps +fun update_hyps hyps (Theorem (p0,p1,p2,p3,p4,p5,_,p6)) = Theorem (p0,p1,p2,p3,p4,p5,hyps,p6) +fun shyps_of_theorem (Theorem (_,_,_,_,_,_,_,shyps)) = shyps +fun update_shyps shyps (Theorem (p0,p1,p2,p3,p4,p5,p6,_)) = Theorem (p0,p1,p2,p3,p4,p5,p6,shyps) + +fun check_compatible computer th s = + if stamp_of computer <> stamp_of_theorem th then + raise Compute (s^": computer and theorem are incompatible") + else () + +fun instantiate computer insts th = +let + val _ = check_compatible computer th + + val thy = theory_of computer + + val vartab = vartab_of_theorem th + + fun rewrite computer t = + let + val naming = naming_of computer + val (encoding, t) = remove_types (encoding_of computer) t + val t = runprog (prog_of computer) t + val _ = set_encoding computer encoding + in + t + end + + fun assert_varfree vs t = + if AbstractMachine.forall_consts (fn x => Inttab.lookup vs x = NONE) t then + () + else + raise Compute "instantiate: assert_varfree failed" + + fun assert_closed t = + if AbstractMachine.closed t then + () + else + raise Compute "instantiate: not a closed term" + + fun compute_inst (s, ct) vs = + let + val _ = Theory.assert_super (theory_of_cterm ct) thy + val ty = typ_of (ctyp_of_term ct) + in + (case Symtab.lookup vartab s of + NONE => raise Compute ("instantiate: variable '"^s^"' not found in theorem") + | SOME (x, ty') => + (case Inttab.lookup vs x of + SOME (SOME _) => raise Compute ("instantiate: variable '"^s^"' has already been instantiated") + | SOME NONE => + if ty <> ty' then + raise Compute ("instantiate: wrong type for variable '"^s^"'") + else + let + val t = rewrite computer (term_of ct) + val _ = assert_varfree vs t + val _ = assert_closed t + in + Inttab.update (x, SOME t) vs + end + | NONE => raise Compute "instantiate: internal error")) + end + + val vs = fold compute_inst insts (varsubst_of_theorem th) +in + update_varsubst vs th +end + +fun match_aterms subst = + let + exception no_match + open AbstractMachine + fun match subst (b as (Const c)) a = + if a = b then subst + else + (case Inttab.lookup subst c of + SOME (SOME a') => if a=a' then subst else raise no_match + | SOME NONE => if AbstractMachine.closed a then + Inttab.update (c, SOME a) subst + else raise no_match + | NONE => raise no_match) + | match subst (b as (Var _)) a = if a=b then subst else raise no_match + | match subst (App (u, v)) (App (u', v')) = match (match subst u u') v v' + | match subst (Abs u) (Abs u') = match subst u u' + | match subst _ _ = raise no_match + in + fn b => fn a => (SOME (match subst b a) handle no_match => NONE) + end + +fun apply_subst vars_allowed subst = + let + open AbstractMachine + fun app (t as (Const c)) = + (case Inttab.lookup subst c of + NONE => t + | SOME (SOME t) => Computed t + | SOME NONE => if vars_allowed then t else raise Compute "apply_subst: no vars allowed") + | app (t as (Var _)) = t + | app (App (u, v)) = App (app u, app v) + | app (Abs m) = Abs (app m) + in + app + end + +fun splicein n l L = List.take (L, n) @ l @ List.drop (L, n+1) + +fun evaluate_prem computer prem_no th = +let + val _ = check_compatible computer th + val prems = prems_of_theorem th + val varsubst = varsubst_of_theorem th + fun run vars_allowed t = + runprog (prog_of computer) (apply_subst vars_allowed varsubst t) +in + case List.nth (prems, prem_no) of + Prem _ => raise Compute "evaluate_prem: no equality premise" + | EqPrem (a, b, ty, _) => + let + val a' = run false a + val b' = run true b + in + case match_aterms varsubst b' a' of + NONE => + let + fun mk s = Syntax.string_of_term_global Pure.thy + (infer_types (naming_of computer) (encoding_of computer) ty s) + val left = "computed left side: "^(mk a') + val right = "computed right side: "^(mk b') + in + raise Compute ("evaluate_prem: cannot assign computed left to right hand side\n"^left^"\n"^right^"\n") + end + | SOME varsubst => + update_prems (splicein prem_no [] prems) (update_varsubst varsubst th) + end +end + +fun prem2term (Prem t) = t + | prem2term (EqPrem (a,b,_,eq)) = + AbstractMachine.App (AbstractMachine.App (AbstractMachine.Const eq, a), b) + +fun modus_ponens computer prem_no th' th = +let + val _ = check_compatible computer th + val thy = + let + val thy1 = theory_of_theorem th + val thy2 = theory_of_thm th' + in + if Theory.subthy (thy1, thy2) then thy2 + else if Theory.subthy (thy2, thy1) then thy1 else + raise Compute "modus_ponens: theorems are not compatible with each other" + end + val th' = make_theorem computer th' [] + val varsubst = varsubst_of_theorem th + fun run vars_allowed t = + runprog (prog_of computer) (apply_subst vars_allowed varsubst t) + val prems = prems_of_theorem th + val prem = run true (prem2term (List.nth (prems, prem_no))) + val concl = run false (concl_of_theorem th') +in + case match_aterms varsubst prem concl of + NONE => raise Compute "modus_ponens: conclusion does not match premise" + | SOME varsubst => + let + val th = update_varsubst varsubst th + val th = update_prems (splicein prem_no (prems_of_theorem th') prems) th + val th = update_hyps (merge_hyps (hyps_of_theorem th) (hyps_of_theorem th')) th + val th = update_shyps (merge_shyps (shyps_of_theorem th) (shyps_of_theorem th')) th + in + update_theory thy th + end +end + +fun simplify computer th = +let + val _ = check_compatible computer th + val varsubst = varsubst_of_theorem th + val encoding = encoding_of computer + val naming = naming_of computer + fun infer t = infer_types naming encoding @{typ "prop"} t + fun run t = infer (runprog (prog_of computer) (apply_subst true varsubst t)) + fun runprem p = run (prem2term p) + val prop = Logic.list_implies (map runprem (prems_of_theorem th), run (concl_of_theorem th)) + val hyps = merge_hyps (hyps_of computer) (hyps_of_theorem th) + val shyps = merge_shyps (shyps_of_theorem th) (Sorttab.keys (shyptab_of computer)) +in + export_thm (theory_of_theorem th) hyps shyps prop +end + +end + diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/Compute_Oracle/linker.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Matrix/Compute_Oracle/linker.ML Wed Jul 21 15:44:36 2010 +0200 @@ -0,0 +1,472 @@ +(* Title: Tools/Compute_Oracle/linker.ML + Author: Steven Obua + +This module solves the problem that the computing oracle does not +instantiate polymorphic rules. By going through the PCompute +interface, all possible instantiations are resolved by compiling new +programs, if necessary. The obvious disadvantage of this approach is +that in the worst case for each new term to be rewritten, a new +program may be compiled. +*) + +(* + Given constants/frees c_1::t_1, c_2::t_2, ...., c_n::t_n, + and constants/frees d_1::d_1, d_2::s_2, ..., d_m::s_m + + Find all substitutions S such that + a) the domain of S is tvars (t_1, ..., t_n) + b) there are indices i_1, ..., i_k, and j_1, ..., j_k with + 1. S (c_i_1::t_i_1) = d_j_1::s_j_1, ..., S (c_i_k::t_i_k) = d_j_k::s_j_k + 2. tvars (t_i_1, ..., t_i_k) = tvars (t_1, ..., t_n) +*) +signature LINKER = +sig + exception Link of string + + datatype constant = Constant of bool * string * typ + val constant_of : term -> constant + + type instances + type subst = Type.tyenv + + val empty : constant list -> instances + val typ_of_constant : constant -> typ + val add_instances : theory -> instances -> constant list -> subst list * instances + val substs_of : instances -> subst list + val is_polymorphic : constant -> bool + val distinct_constants : constant list -> constant list + val collect_consts : term list -> constant list +end + +structure Linker : LINKER = struct + +exception Link of string; + +type subst = Type.tyenv + +datatype constant = Constant of bool * string * typ +fun constant_of (Const (name, ty)) = Constant (false, name, ty) + | constant_of (Free (name, ty)) = Constant (true, name, ty) + | constant_of _ = raise Link "constant_of" + +fun bool_ord (x,y) = if x then (if y then EQUAL else GREATER) else (if y then LESS else EQUAL) +fun constant_ord (Constant (x1,x2,x3), Constant (y1,y2,y3)) = (prod_ord (prod_ord bool_ord fast_string_ord) Term_Ord.typ_ord) (((x1,x2),x3), ((y1,y2),y3)) +fun constant_modty_ord (Constant (x1,x2,_), Constant (y1,y2,_)) = (prod_ord bool_ord fast_string_ord) ((x1,x2), (y1,y2)) + + +structure Consttab = Table(type key = constant val ord = constant_ord); +structure ConsttabModTy = Table(type key = constant val ord = constant_modty_ord); + +fun typ_of_constant (Constant (_, _, ty)) = ty + +val empty_subst = (Vartab.empty : Type.tyenv) + +fun merge_subst (A:Type.tyenv) (B:Type.tyenv) = + SOME (Vartab.fold (fn (v, t) => + fn tab => + (case Vartab.lookup tab v of + NONE => Vartab.update (v, t) tab + | SOME t' => if t = t' then tab else raise Type.TYPE_MATCH)) A B) + handle Type.TYPE_MATCH => NONE + +fun subst_ord (A:Type.tyenv, B:Type.tyenv) = + (list_ord (prod_ord Term_Ord.fast_indexname_ord (prod_ord Term_Ord.sort_ord Term_Ord.typ_ord))) (Vartab.dest A, Vartab.dest B) + +structure Substtab = Table(type key = Type.tyenv val ord = subst_ord); + +fun substtab_union c = Substtab.fold Substtab.update c +fun substtab_unions [] = Substtab.empty + | substtab_unions [c] = c + | substtab_unions (c::cs) = substtab_union c (substtab_unions cs) + +datatype instances = Instances of unit ConsttabModTy.table * Type.tyenv Consttab.table Consttab.table * constant list list * unit Substtab.table + +fun is_polymorphic (Constant (_, _, ty)) = not (null (Term.add_tvarsT ty [])) + +fun distinct_constants cs = + Consttab.keys (fold (fn c => Consttab.update (c, ())) cs Consttab.empty) + +fun empty cs = + let + val cs = distinct_constants (filter is_polymorphic cs) + val old_cs = cs +(* fun collect_tvars ty tab = fold (fn v => fn tab => Typtab.update (TVar v, ()) tab) (OldTerm.typ_tvars ty) tab + val tvars_count = length (Typtab.keys (fold (fn c => fn tab => collect_tvars (typ_of_constant c) tab) cs Typtab.empty)) + fun tvars_of ty = collect_tvars ty Typtab.empty + val cs = map (fn c => (c, tvars_of (typ_of_constant c))) cs + + fun tyunion A B = + Typtab.fold + (fn (v,()) => fn tab => Typtab.update (v, case Typtab.lookup tab v of NONE => 1 | SOME n => n+1) tab) + A B + + fun is_essential A B = + Typtab.fold + (fn (v, ()) => fn essential => essential orelse (case Typtab.lookup B v of NONE => raise Link "is_essential" | SOME n => n=1)) + A false + + fun add_minimal (c', tvs') (tvs, cs) = + let + val tvs = tyunion tvs' tvs + val cs = (c', tvs')::cs + in + if forall (fn (c',tvs') => is_essential tvs' tvs) cs then + SOME (tvs, cs) + else + NONE + end + + fun is_spanning (tvs, _) = (length (Typtab.keys tvs) = tvars_count) + + fun generate_minimal_subsets subsets [] = subsets + | generate_minimal_subsets subsets (c::cs) = + let + val subsets' = map_filter (add_minimal c) subsets + in + generate_minimal_subsets (subsets@subsets') cs + end*) + + val minimal_subsets = [old_cs] (*map (fn (tvs, cs) => map fst cs) (filter is_spanning (generate_minimal_subsets [(Typtab.empty, [])] cs))*) + + val constants = Consttab.keys (fold (fold (fn c => Consttab.update (c, ()))) minimal_subsets Consttab.empty) + + in + Instances ( + fold (fn c => fn tab => ConsttabModTy.update (c, ()) tab) constants ConsttabModTy.empty, + Consttab.make (map (fn c => (c, Consttab.empty : Type.tyenv Consttab.table)) constants), + minimal_subsets, Substtab.empty) + end + +local +fun calc ctab substtab [] = substtab + | calc ctab substtab (c::cs) = + let + val csubsts = map snd (Consttab.dest (the (Consttab.lookup ctab c))) + fun merge_substs substtab subst = + Substtab.fold (fn (s,_) => + fn tab => + (case merge_subst subst s of NONE => tab | SOME s => Substtab.update (s, ()) tab)) + substtab Substtab.empty + val substtab = substtab_unions (map (merge_substs substtab) csubsts) + in + calc ctab substtab cs + end +in +fun calc_substs ctab (cs:constant list) = calc ctab (Substtab.update (empty_subst, ()) Substtab.empty) cs +end + +fun add_instances thy (Instances (cfilter, ctab,minsets,substs)) cs = + let +(* val _ = writeln (makestring ("add_instances: ", length_cs, length cs, length (Consttab.keys ctab)))*) + fun calc_instantiations (constant as Constant (free, name, ty)) instantiations = + Consttab.fold (fn (constant' as Constant (free', name', ty'), insttab) => + fn instantiations => + if free <> free' orelse name <> name' then + instantiations + else case Consttab.lookup insttab constant of + SOME _ => instantiations + | NONE => ((constant', (constant, Sign.typ_match thy (ty', ty) empty_subst))::instantiations + handle TYPE_MATCH => instantiations)) + ctab instantiations + val instantiations = fold calc_instantiations cs [] + (*val _ = writeln ("instantiations = "^(makestring (length instantiations)))*) + fun update_ctab (constant', entry) ctab = + (case Consttab.lookup ctab constant' of + NONE => raise Link "internal error: update_ctab" + | SOME tab => Consttab.update (constant', Consttab.update entry tab) ctab) + val ctab = fold update_ctab instantiations ctab + val new_substs = fold (fn minset => fn substs => substtab_union (calc_substs ctab minset) substs) + minsets Substtab.empty + val (added_substs, substs) = + Substtab.fold (fn (ns, _) => + fn (added, substtab) => + (case Substtab.lookup substs ns of + NONE => (ns::added, Substtab.update (ns, ()) substtab) + | SOME () => (added, substtab))) + new_substs ([], substs) + in + (added_substs, Instances (cfilter, ctab, minsets, substs)) + end + +fun substs_of (Instances (_,_,_,substs)) = Substtab.keys substs + +fun eq_to_meta th = (@{thm HOL.eq_reflection} OF [th] handle THM _ => th) + + +local + +fun collect (Var x) tab = tab + | collect (Bound _) tab = tab + | collect (a $ b) tab = collect b (collect a tab) + | collect (Abs (_, _, body)) tab = collect body tab + | collect t tab = Consttab.update (constant_of t, ()) tab + +in + fun collect_consts tms = Consttab.keys (fold collect tms Consttab.empty) +end + +end + +signature PCOMPUTE = +sig + type pcomputer + + val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer + val make_with_cache : Compute.machine -> theory -> term list -> thm list -> Linker.constant list -> pcomputer + + val add_instances : pcomputer -> Linker.constant list -> bool + val add_instances' : pcomputer -> term list -> bool + + val rewrite : pcomputer -> cterm list -> thm list + val simplify : pcomputer -> Compute.theorem -> thm + + val make_theorem : pcomputer -> thm -> string list -> Compute.theorem + val instantiate : pcomputer -> (string * cterm) list -> Compute.theorem -> Compute.theorem + val evaluate_prem : pcomputer -> int -> Compute.theorem -> Compute.theorem + val modus_ponens : pcomputer -> int -> thm -> Compute.theorem -> Compute.theorem + +end + +structure PCompute : PCOMPUTE = struct + +exception PCompute of string + +datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list +datatype pattern = MonoPattern of term | PolyPattern of term * Linker.instances * term list + +datatype pcomputer = + PComputer of theory_ref * Compute.computer * theorem list Unsynchronized.ref * + pattern list Unsynchronized.ref + +(*fun collect_consts (Var x) = [] + | collect_consts (Bound _) = [] + | collect_consts (a $ b) = (collect_consts a)@(collect_consts b) + | collect_consts (Abs (_, _, body)) = collect_consts body + | collect_consts t = [Linker.constant_of t]*) + +fun computer_of (PComputer (_,computer,_,_)) = computer + +fun collect_consts_of_thm th = + let + val th = prop_of th + val (prems, th) = (Logic.strip_imp_prems th, Logic.strip_imp_concl th) + val (left, right) = Logic.dest_equals th + in + (Linker.collect_consts [left], Linker.collect_consts (right::prems)) + end + +fun create_theorem th = +let + val (left, right) = collect_consts_of_thm th + val polycs = filter Linker.is_polymorphic left + val tytab = fold (fn p => fn tab => fold (fn n => fn tab => Typtab.update (TVar n, ()) tab) (OldTerm.typ_tvars (Linker.typ_of_constant p)) tab) polycs Typtab.empty + fun check_const (c::cs) cs' = + let + val tvars = OldTerm.typ_tvars (Linker.typ_of_constant c) + val wrong = fold (fn n => fn wrong => wrong orelse is_none (Typtab.lookup tytab (TVar n))) tvars false + in + if wrong then raise PCompute "right hand side of theorem contains type variables which do not occur on the left hand side" + else + if null (tvars) then + check_const cs (c::cs') + else + check_const cs cs' + end + | check_const [] cs' = cs' + val monocs = check_const right [] +in + if null (polycs) then + (monocs, MonoThm th) + else + (monocs, PolyThm (th, Linker.empty polycs, [])) +end + +fun create_pattern pat = +let + val cs = Linker.collect_consts [pat] + val polycs = filter Linker.is_polymorphic cs +in + if null (polycs) then + MonoPattern pat + else + PolyPattern (pat, Linker.empty polycs, []) +end + +fun create_computer machine thy pats ths = + let + fun add (MonoThm th) ths = th::ths + | add (PolyThm (_, _, ths')) ths = ths'@ths + fun addpat (MonoPattern p) pats = p::pats + | addpat (PolyPattern (_, _, ps)) pats = ps@pats + val ths = fold_rev add ths [] + val pats = fold_rev addpat pats [] + in + Compute.make_with_cache machine thy pats ths + end + +fun update_computer computer pats ths = + let + fun add (MonoThm th) ths = th::ths + | add (PolyThm (_, _, ths')) ths = ths'@ths + fun addpat (MonoPattern p) pats = p::pats + | addpat (PolyPattern (_, _, ps)) pats = ps@pats + val ths = fold_rev add ths [] + val pats = fold_rev addpat pats [] + in + Compute.update_with_cache computer pats ths + end + +fun conv_subst thy (subst : Type.tyenv) = + map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst) + +fun add_monos thy monocs pats ths = + let + val changed = Unsynchronized.ref false + fun add monocs (th as (MonoThm _)) = ([], th) + | add monocs (PolyThm (th, instances, instanceths)) = + let + val (newsubsts, instances) = Linker.add_instances thy instances monocs + val _ = if not (null newsubsts) then changed := true else () + val newths = map (fn subst => Thm.instantiate (conv_subst thy subst, []) th) newsubsts +(* val _ = if not (null newths) then (print ("added new theorems: ", newths); ()) else ()*) + val newmonos = fold (fn th => fn monos => (snd (collect_consts_of_thm th))@monos) newths [] + in + (newmonos, PolyThm (th, instances, instanceths@newths)) + end + fun addpats monocs (pat as (MonoPattern _)) = pat + | addpats monocs (PolyPattern (p, instances, instancepats)) = + let + val (newsubsts, instances) = Linker.add_instances thy instances monocs + val _ = if not (null newsubsts) then changed := true else () + val newpats = map (fn subst => Envir.subst_term_types subst p) newsubsts + in + PolyPattern (p, instances, instancepats@newpats) + end + fun step monocs ths = + fold_rev (fn th => + fn (newmonos, ths) => + let + val (newmonos', th') = add monocs th + in + (newmonos'@newmonos, th'::ths) + end) + ths ([], []) + fun loop monocs pats ths = + let + val (monocs', ths') = step monocs ths + val pats' = map (addpats monocs) pats + in + if null (monocs') then + (pats', ths') + else + loop monocs' pats' ths' + end + val result = loop monocs pats ths + in + (!changed, result) + end + +datatype cthm = ComputeThm of term list * sort list * term + +fun thm2cthm th = + let + val {hyps, prop, shyps, ...} = Thm.rep_thm th + in + ComputeThm (hyps, shyps, prop) + end + +val cthm_ord' = prod_ord (prod_ord (list_ord Term_Ord.term_ord) (list_ord Term_Ord.sort_ord)) Term_Ord.term_ord + +fun cthm_ord (ComputeThm (h1, sh1, p1), ComputeThm (h2, sh2, p2)) = cthm_ord' (((h1,sh1), p1), ((h2, sh2), p2)) + +structure CThmtab = Table(type key = cthm val ord = cthm_ord) + +fun remove_duplicates ths = + let + val counter = Unsynchronized.ref 0 + val tab = Unsynchronized.ref (CThmtab.empty : unit CThmtab.table) + val thstab = Unsynchronized.ref (Inttab.empty : thm Inttab.table) + fun update th = + let + val key = thm2cthm th + in + case CThmtab.lookup (!tab) key of + NONE => ((tab := CThmtab.update_new (key, ()) (!tab)); thstab := Inttab.update_new (!counter, th) (!thstab); counter := !counter + 1) + | _ => () + end + val _ = map update ths + in + map snd (Inttab.dest (!thstab)) + end + +fun make_with_cache machine thy pats ths cs = + let + val ths = remove_duplicates ths + val (monocs, ths) = fold_rev (fn th => + fn (monocs, ths) => + let val (m, t) = create_theorem th in + (m@monocs, t::ths) + end) + ths (cs, []) + val pats = map create_pattern pats + val (_, (pats, ths)) = add_monos thy monocs pats ths + val computer = create_computer machine thy pats ths + in + PComputer (Theory.check_thy thy, computer, Unsynchronized.ref ths, Unsynchronized.ref pats) + end + +fun make machine thy ths cs = make_with_cache machine thy [] ths cs + +fun add_instances (PComputer (thyref, computer, rths, rpats)) cs = + let + val thy = Theory.deref thyref + val (changed, (pats, ths)) = add_monos thy cs (!rpats) (!rths) + in + if changed then + (update_computer computer pats ths; + rths := ths; + rpats := pats; + true) + else + false + + end + +fun add_instances' pc ts = add_instances pc (Linker.collect_consts ts) + +fun rewrite pc cts = + let + val _ = add_instances' pc (map term_of cts) + val computer = (computer_of pc) + in + map (fn ct => Compute.rewrite computer ct) cts + end + +fun simplify pc th = Compute.simplify (computer_of pc) th + +fun make_theorem pc th vars = + let + val _ = add_instances' pc [prop_of th] + + in + Compute.make_theorem (computer_of pc) th vars + end + +fun instantiate pc insts th = + let + val _ = add_instances' pc (map (term_of o snd) insts) + in + Compute.instantiate (computer_of pc) insts th + end + +fun evaluate_prem pc prem_no th = Compute.evaluate_prem (computer_of pc) prem_no th + +fun modus_ponens pc prem_no th' th = + let + val _ = add_instances' pc [prop_of th'] + in + Compute.modus_ponens (computer_of pc) prem_no th' th + end + + +end diff -r c7ce7685e087 -r d83659570337 src/HOL/Matrix/Compute_Oracle/report.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Matrix/Compute_Oracle/report.ML Wed Jul 21 15:44:36 2010 +0200 @@ -0,0 +1,33 @@ +structure Report = +struct + +local + + val report_depth = Unsynchronized.ref 0 + fun space n = if n <= 0 then "" else (space (n-1))^" " + fun report_space () = space (!report_depth) + +in + +fun timeit f = + let + val t1 = start_timing () + val x = f () + val t2 = #message (end_timing t1) + val _ = writeln ((report_space ()) ^ "--> "^t2) + in + x + end + +fun report s f = +let + val _ = writeln ((report_space ())^s) + val _ = report_depth := !report_depth + 1 + val x = timeit f + val _ = report_depth := !report_depth - 1 +in + x +end + +end +end \ No newline at end of file diff -r c7ce7685e087 -r d83659570337 src/Tools/Compute_Oracle/Compute_Oracle.thy --- a/src/Tools/Compute_Oracle/Compute_Oracle.thy Wed Jul 21 15:31:38 2010 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,11 +0,0 @@ -(* Title: Tools/Compute_Oracle/Compute_Oracle.thy - Author: Steven Obua, TU Munich - -Steven Obua's evaluator. -*) - -theory Compute_Oracle imports HOL -uses "am.ML" "am_compiler.ML" "am_interpreter.ML" "am_ghc.ML" "am_sml.ML" "report.ML" "compute.ML" "linker.ML" -begin - -end \ No newline at end of file diff -r c7ce7685e087 -r d83659570337 src/Tools/Compute_Oracle/am.ML --- a/src/Tools/Compute_Oracle/am.ML Wed Jul 21 15:31:38 2010 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,75 +0,0 @@ -signature ABSTRACT_MACHINE = -sig - -datatype term = Var of int | Const of int | App of term * term | Abs of term | Computed of term - -datatype pattern = PVar | PConst of int * (pattern list) - -datatype guard = Guard of term * term - -type program - -exception Compile of string; - -(* The de-Bruijn index 0 occurring on the right hand side refers to the LAST pattern variable, when traversing the pattern from left to right, - 1 to the second last, and so on. *) -val compile : pattern list -> (int -> int option) -> (guard list * pattern * term) list -> program - -val discard : program -> unit - -exception Run of string; -val run : program -> term -> term - -(* Utilities *) - -val check_freevars : int -> term -> bool -val forall_consts : (int -> bool) -> term -> bool -val closed : term -> bool -val erase_Computed : term -> term - -end - -structure AbstractMachine : ABSTRACT_MACHINE = -struct - -datatype term = Var of int | Const of int | App of term * term | Abs of term | Computed of term - -datatype pattern = PVar | PConst of int * (pattern list) - -datatype guard = Guard of term * term - -type program = unit - -exception Compile of string; - -fun erase_Computed (Computed t) = erase_Computed t - | erase_Computed (App (t1, t2)) = App (erase_Computed t1, erase_Computed t2) - | erase_Computed (Abs t) = Abs (erase_Computed t) - | erase_Computed t = t - -(*Returns true iff at most 0 .. (free-1) occur unbound. therefore - check_freevars 0 t iff t is closed*) -fun check_freevars free (Var x) = x < free - | check_freevars free (Const c) = true - | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v - | check_freevars free (Abs m) = check_freevars (free+1) m - | check_freevars free (Computed t) = check_freevars free t - -fun forall_consts pred (Const c) = pred c - | forall_consts pred (Var x) = true - | forall_consts pred (App (u,v)) = forall_consts pred u - andalso forall_consts pred v - | forall_consts pred (Abs m) = forall_consts pred m - | forall_consts pred (Computed t) = forall_consts pred t - -fun closed t = check_freevars 0 t - -fun compile _ = raise Compile "abstract machine stub" - -fun discard _ = raise Compile "abstract machine stub" - -exception Run of string; - -fun run p t = raise Run "abstract machine stub" - -end diff -r c7ce7685e087 -r d83659570337 src/Tools/Compute_Oracle/am_compiler.ML --- a/src/Tools/Compute_Oracle/am_compiler.ML Wed Jul 21 15:31:38 2010 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,211 +0,0 @@ -(* Title: Tools/Compute_Oracle/am_compiler.ML - Author: Steven Obua -*) - -signature COMPILING_AM = -sig - include ABSTRACT_MACHINE - - val set_compiled_rewriter : (term -> term) -> unit - val list_nth : 'a list * int -> 'a - val list_map : ('a -> 'b) -> 'a list -> 'b list -end - -structure AM_Compiler : COMPILING_AM = struct - -val list_nth = List.nth; -val list_map = map; - -open AbstractMachine; - -val compiled_rewriter = Unsynchronized.ref (NONE:(term -> term)Option.option) - -fun set_compiled_rewriter r = (compiled_rewriter := SOME r) - -type program = (term -> term) - -fun count_patternvars PVar = 1 - | count_patternvars (PConst (_, ps)) = - List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps - -fun print_rule (p, t) = - let - fun str x = Int.toString x - fun print_pattern n PVar = (n+1, "x"^(str n)) - | print_pattern n (PConst (c, [])) = (n, "c"^(str c)) - | print_pattern n (PConst (c, args)) = - let - val h = print_pattern n (PConst (c,[])) - in - print_pattern_list h args - end - and print_pattern_list r [] = r - | print_pattern_list (n, p) (t::ts) = - let - val (n, t) = print_pattern n t - in - print_pattern_list (n, "App ("^p^", "^t^")") ts - end - - val (n, pattern) = print_pattern 0 p - val pattern = - if exists_string Symbol.is_ascii_blank pattern then "(" ^ pattern ^")" - else pattern - - fun print_term d (Var x) = (*if x < d then "Var "^(str x) else "x"^(str (n-(x-d)-1))*) - "Var " ^ str x - | print_term d (Const c) = "c" ^ str c - | print_term d (App (a,b)) = "App (" ^ print_term d a ^ ", " ^ print_term d b ^ ")" - | print_term d (Abs c) = "Abs (" ^ print_term (d + 1) c ^ ")" - | print_term d (Computed c) = print_term d c - - fun listvars n = if n = 0 then "x0" else "x"^(str n)^", "^(listvars (n-1)) - - val term = print_term 0 t - val term = - if n > 0 then "Closure (["^(listvars (n-1))^"], "^term^")" - else "Closure ([], "^term^")" - - in - " | weak_reduce (false, stack, "^pattern^") = Continue (false, stack, "^term^")" - end - -fun constants_of PVar = [] - | constants_of (PConst (c, ps)) = c :: maps constants_of ps - -fun constants_of_term (Var _) = [] - | constants_of_term (Abs m) = constants_of_term m - | constants_of_term (App (a,b)) = (constants_of_term a)@(constants_of_term b) - | constants_of_term (Const c) = [c] - | constants_of_term (Computed c) = constants_of_term c - -fun load_rules sname name prog = - let - val buffer = Unsynchronized.ref "" - fun write s = (buffer := (!buffer)^s) - fun writeln s = (write s; write "\n") - fun writelist [] = () - | writelist (s::ss) = (writeln s; writelist ss) - fun str i = Int.toString i - val _ = writelist [ - "structure "^name^" = struct", - "", - "datatype term = Dummy | App of term * term | Abs of term | Var of int | Const of int | Closure of term list * term"] - val constants = distinct (op =) (maps (fn (p, r) => ((constants_of p)@(constants_of_term r))) prog) - val _ = map (fn x => write (" | c"^(str x))) constants - val _ = writelist [ - "", - "datatype stack = SEmpty | SAppL of term * stack | SAppR of term * stack | SAbs of stack", - "", - "type state = bool * stack * term", - "", - "datatype loopstate = Continue of state | Stop of stack * term", - "", - "fun proj_C (Continue s) = s", - " | proj_C _ = raise Match", - "", - "fun proj_S (Stop s) = s", - " | proj_S _ = raise Match", - "", - "fun cont (Continue _) = true", - " | cont _ = false", - "", - "fun do_reduction reduce p =", - " let", - " val s = Unsynchronized.ref (Continue p)", - " val _ = while cont (!s) do (s := reduce (proj_C (!s)))", - " in", - " proj_S (!s)", - " end", - ""] - - val _ = writelist [ - "fun weak_reduce (false, stack, Closure (e, App (a, b))) = Continue (false, SAppL (Closure (e, b), stack), Closure (e, a))", - " | weak_reduce (false, SAppL (b, stack), Closure (e, Abs m)) = Continue (false, stack, Closure (b::e, m))", - " | weak_reduce (false, stack, c as Closure (e, Abs m)) = Continue (true, stack, c)", - " | weak_reduce (false, stack, Closure (e, Var n)) = Continue (false, stack, case "^sname^".list_nth (e, n) of Dummy => Var n | r => r)", - " | weak_reduce (false, stack, Closure (e, c)) = Continue (false, stack, c)"] - val _ = writelist (map print_rule prog) - val _ = writelist [ - " | weak_reduce (false, stack, clos) = Continue (true, stack, clos)", - " | weak_reduce (true, SAppR (a, stack), b) = Continue (false, stack, App (a,b))", - " | weak_reduce (true, s as (SAppL (b, stack)), a) = Continue (false, SAppR (a, stack), b)", - " | weak_reduce (true, stack, c) = Stop (stack, c)", - "", - "fun strong_reduce (false, stack, Closure (e, Abs m)) =", - " let", - " val (stack', wnf) = do_reduction weak_reduce (false, SEmpty, Closure (Dummy::e, m))", - " in", - " case stack' of", - " SEmpty => Continue (false, SAbs stack, wnf)", - " | _ => raise ("^sname^".Run \"internal error in strong: weak failed\")", - " end", - " | strong_reduce (false, stack, clos as (App (u, v))) = Continue (false, SAppL (v, stack), u)", - " | strong_reduce (false, stack, clos) = Continue (true, stack, clos)", - " | strong_reduce (true, SAbs stack, m) = Continue (false, stack, Abs m)", - " | strong_reduce (true, SAppL (b, stack), a) = Continue (false, SAppR (a, stack), b)", - " | strong_reduce (true, SAppR (a, stack), b) = Continue (true, stack, App (a, b))", - " | strong_reduce (true, stack, clos) = Stop (stack, clos)", - ""] - - val ic = "(case c of "^(implode (map (fn c => (str c)^" => c"^(str c)^" | ") constants))^" _ => Const c)" - val _ = writelist [ - "fun importTerm ("^sname^".Var x) = Var x", - " | importTerm ("^sname^".Const c) = "^ic, - " | importTerm ("^sname^".App (a, b)) = App (importTerm a, importTerm b)", - " | importTerm ("^sname^".Abs m) = Abs (importTerm m)", - ""] - - fun ec c = " | exportTerm c"^(str c)^" = "^sname^".Const "^(str c) - val _ = writelist [ - "fun exportTerm (Var x) = "^sname^".Var x", - " | exportTerm (Const c) = "^sname^".Const c", - " | exportTerm (App (a,b)) = "^sname^".App (exportTerm a, exportTerm b)", - " | exportTerm (Abs m) = "^sname^".Abs (exportTerm m)", - " | exportTerm (Closure (closlist, clos)) = raise ("^sname^".Run \"internal error, cannot export Closure\")", - " | exportTerm Dummy = raise ("^sname^".Run \"internal error, cannot export Dummy\")"] - val _ = writelist (map ec constants) - - val _ = writelist [ - "", - "fun rewrite t = ", - " let", - " val (stack, wnf) = do_reduction weak_reduce (false, SEmpty, Closure ([], importTerm t))", - " in", - " case stack of ", - " SEmpty => (case do_reduction strong_reduce (false, SEmpty, wnf) of", - " (SEmpty, snf) => exportTerm snf", - " | _ => raise ("^sname^".Run \"internal error in rewrite: strong failed\"))", - " | _ => (raise ("^sname^".Run \"internal error in rewrite: weak failed\"))", - " end", - "", - "val _ = "^sname^".set_compiled_rewriter rewrite", - "", - "end;"] - - in - compiled_rewriter := NONE; - use_text ML_Env.local_context (1, "") false (!buffer); - case !compiled_rewriter of - NONE => raise (Compile "cannot communicate with compiled function") - | SOME r => (compiled_rewriter := NONE; r) - end - -fun compile cache_patterns const_arity eqs = - let - val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () - val eqs = map (fn (a,b,c) => (b,c)) eqs - fun check (p, r) = if check_freevars (count_patternvars p) r then () else raise Compile ("unbound variables in rule") - val _ = map (fn (p, r) => - (check (p, r); - case p of PVar => raise (Compile "pattern is just a variable") | _ => ())) eqs - in - load_rules "AM_Compiler" "AM_compiled_code" eqs - end - -fun run prog t = (prog t) - -fun discard p = () - -end - diff -r c7ce7685e087 -r d83659570337 src/Tools/Compute_Oracle/am_ghc.ML --- a/src/Tools/Compute_Oracle/am_ghc.ML Wed Jul 21 15:31:38 2010 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,325 +0,0 @@ -(* Title: Tools/Compute_Oracle/am_ghc.ML - Author: Steven Obua -*) - -structure AM_GHC : ABSTRACT_MACHINE = struct - -open AbstractMachine; - -type program = string * string * (int Inttab.table) - -fun count_patternvars PVar = 1 - | count_patternvars (PConst (_, ps)) = - List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps - -fun update_arity arity code a = - (case Inttab.lookup arity code of - NONE => Inttab.update_new (code, a) arity - | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity) - -(* We have to find out the maximal arity of each constant *) -fun collect_pattern_arity PVar arity = arity - | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args)) - -local -fun collect applevel (Var _) arity = arity - | collect applevel (Const c) arity = update_arity arity c applevel - | collect applevel (Abs m) arity = collect 0 m arity - | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity) -in -fun collect_term_arity t arity = collect 0 t arity -end - -fun nlift level n (Var m) = if m < level then Var m else Var (m+n) - | nlift level n (Const c) = Const c - | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b) - | nlift level n (Abs b) = Abs (nlift (level+1) n b) - -fun rep n x = if n = 0 then [] else x::(rep (n-1) x) - -fun adjust_rules rules = - let - val arity = fold (fn (p, t) => fn arity => collect_term_arity t (collect_pattern_arity p arity)) rules Inttab.empty - fun arity_of c = the (Inttab.lookup arity c) - fun adjust_pattern PVar = PVar - | adjust_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else C - fun adjust_rule (PVar, t) = raise Compile ("pattern may not be a variable") - | adjust_rule (rule as (p as PConst (c, args),t)) = - let - val _ = if not (check_freevars (count_patternvars p) t) then raise Compile ("unbound variables on right hand side") else () - val args = map adjust_pattern args - val len = length args - val arity = arity_of c - fun lift level n (Var m) = if m < level then Var m else Var (m+n) - | lift level n (Const c) = Const c - | lift level n (App (a,b)) = App (lift level n a, lift level n b) - | lift level n (Abs b) = Abs (lift (level+1) n b) - val lift = lift 0 - fun adjust_term n t = if n=0 then t else adjust_term (n-1) (App (t, Var (n-1))) - in - if len = arity then - rule - else if arity >= len then - (PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) (lift (arity-len) t)) - else (raise Compile "internal error in adjust_rule") - end - in - (arity, map adjust_rule rules) - end - -fun print_term arity_of n = -let - fun str x = string_of_int x - fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s - - fun print_apps d f [] = f - | print_apps d f (a::args) = print_apps d ("app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args - and print_call d (App (a, b)) args = print_call d a (b::args) - | print_call d (Const c) args = - (case arity_of c of - NONE => print_apps d ("Const "^(str c)) args - | SOME a => - let - val len = length args - in - if a <= len then - let - val s = "c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, a)))) - in - print_apps d s (List.drop (args, a)) - end - else - let - fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n-1))) - fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t) - fun append_args [] t = t - | append_args (c::cs) t = append_args cs (App (t, c)) - in - print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c))))) - end - end) - | print_call d t args = print_apps d (print_term d t) args - and print_term d (Var x) = if x < d then "b"^(str (d-x-1)) else "x"^(str (n-(x-d)-1)) - | print_term d (Abs c) = "Abs (\\b"^(str d)^" -> "^(print_term (d + 1) c)^")" - | print_term d t = print_call d t [] -in - print_term 0 -end - -fun print_rule arity_of (p, t) = - let - fun str x = Int.toString x - fun print_pattern top n PVar = (n+1, "x"^(str n)) - | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)) - | print_pattern top n (PConst (c, args)) = - let - val (n,s) = print_pattern_list (n, (if top then "c" else "C")^(str c)) args - in - (n, if top then s else "("^s^")") - end - and print_pattern_list r [] = r - | print_pattern_list (n, p) (t::ts) = - let - val (n, t) = print_pattern false n t - in - print_pattern_list (n, p^" "^t) ts - end - val (n, pattern) = print_pattern true 0 p - in - pattern^" = "^(print_term arity_of n t) - end - -fun group_rules rules = - let - fun add_rule (r as (PConst (c,_), _)) groups = - let - val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs) - in - Inttab.update (c, r::rs) groups - end - | add_rule _ _ = raise Compile "internal error group_rules" - in - fold_rev add_rule rules Inttab.empty - end - -fun haskell_prog name rules = - let - val buffer = Unsynchronized.ref "" - fun write s = (buffer := (!buffer)^s) - fun writeln s = (write s; write "\n") - fun writelist [] = () - | writelist (s::ss) = (writeln s; writelist ss) - fun str i = Int.toString i - val (arity, rules) = adjust_rules rules - val rules = group_rules rules - val constants = Inttab.keys arity - fun arity_of c = Inttab.lookup arity c - fun rep_str s n = implode (rep n s) - fun indexed s n = s^(str n) - fun section n = if n = 0 then [] else (section (n-1))@[n-1] - fun make_show c = - let - val args = section (the (arity_of c)) - in - " show ("^(indexed "C" c)^(implode (map (indexed " a") args))^") = " - ^"\""^(indexed "C" c)^"\""^(implode (map (fn a => "++(show "^(indexed "a" a)^")") args)) - end - fun default_case c = - let - val args = implode (map (indexed " x") (section (the (arity_of c)))) - in - (indexed "c" c)^args^" = "^(indexed "C" c)^args - end - val _ = writelist [ - "module "^name^" where", - "", - "data Term = Const Integer | App Term Term | Abs (Term -> Term)", - " "^(implode (map (fn c => " | C"^(str c)^(rep_str " Term" (the (arity_of c)))) constants)), - "", - "instance Show Term where"] - val _ = writelist (map make_show constants) - val _ = writelist [ - " show (Const c) = \"c\"++(show c)", - " show (App a b) = \"A\"++(show a)++(show b)", - " show (Abs _) = \"L\"", - ""] - val _ = writelist [ - "app (Abs a) b = a b", - "app a b = App a b", - "", - "calc s c = writeFile s (show c)", - ""] - fun list_group c = (writelist (case Inttab.lookup rules c of - NONE => [default_case c, ""] - | SOME (rs as ((PConst (_, []), _)::rs')) => - if not (null rs') then raise Compile "multiple declaration of constant" - else (map (print_rule arity_of) rs) @ [""] - | SOME rs => (map (print_rule arity_of) rs) @ [default_case c, ""])) - val _ = map list_group constants - in - (arity, !buffer) - end - -val guid_counter = Unsynchronized.ref 0 -fun get_guid () = - let - val c = !guid_counter - val _ = guid_counter := !guid_counter + 1 - in - (LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c) - end - -fun tmp_file s = Path.implode (Path.expand (File.tmp_path (Path.make [s]))); -fun wrap s = "\""^s^"\"" - -fun writeTextFile name s = File.write (Path.explode name) s - -val ghc = Unsynchronized.ref (case getenv "GHC_PATH" of "" => "ghc" | s => s) - -fun fileExists name = ((OS.FileSys.fileSize name; true) handle OS.SysErr _ => false) - -fun compile cache_patterns const_arity eqs = - let - val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () - val eqs = map (fn (a,b,c) => (b,c)) eqs - val guid = get_guid () - val module = "AMGHC_Prog_"^guid - val (arity, source) = haskell_prog module eqs - val module_file = tmp_file (module^".hs") - val object_file = tmp_file (module^".o") - val _ = writeTextFile module_file source - val _ = bash ((!ghc)^" -c "^module_file) - val _ = if not (fileExists object_file) then raise Compile ("Failure compiling haskell code (GHC_PATH = '"^(!ghc)^"')") else () - in - (guid, module_file, arity) - end - -fun readResultFile name = File.read (Path.explode name) - -fun parse_result arity_of result = - let - val result = String.explode result - fun shift NONE x = SOME x - | shift (SOME y) x = SOME (y*10 + x) - fun parse_int' x (#"0"::rest) = parse_int' (shift x 0) rest - | parse_int' x (#"1"::rest) = parse_int' (shift x 1) rest - | parse_int' x (#"2"::rest) = parse_int' (shift x 2) rest - | parse_int' x (#"3"::rest) = parse_int' (shift x 3) rest - | parse_int' x (#"4"::rest) = parse_int' (shift x 4) rest - | parse_int' x (#"5"::rest) = parse_int' (shift x 5) rest - | parse_int' x (#"6"::rest) = parse_int' (shift x 6) rest - | parse_int' x (#"7"::rest) = parse_int' (shift x 7) rest - | parse_int' x (#"8"::rest) = parse_int' (shift x 8) rest - | parse_int' x (#"9"::rest) = parse_int' (shift x 9) rest - | parse_int' x rest = (x, rest) - fun parse_int rest = parse_int' NONE rest - - fun parse (#"C"::rest) = - (case parse_int rest of - (SOME c, rest) => - let - val (args, rest) = parse_list (the (arity_of c)) rest - fun app_args [] t = t - | app_args (x::xs) t = app_args xs (App (t, x)) - in - (app_args args (Const c), rest) - end - | (NONE, rest) => raise Run "parse C") - | parse (#"c"::rest) = - (case parse_int rest of - (SOME c, rest) => (Const c, rest) - | _ => raise Run "parse c") - | parse (#"A"::rest) = - let - val (a, rest) = parse rest - val (b, rest) = parse rest - in - (App (a,b), rest) - end - | parse (#"L"::rest) = raise Run "there may be no abstraction in the result" - | parse _ = raise Run "invalid result" - and parse_list n rest = - if n = 0 then - ([], rest) - else - let - val (x, rest) = parse rest - val (xs, rest) = parse_list (n-1) rest - in - (x::xs, rest) - end - val (parsed, rest) = parse result - fun is_blank (#" "::rest) = is_blank rest - | is_blank (#"\n"::rest) = is_blank rest - | is_blank [] = true - | is_blank _ = false - in - if is_blank rest then parsed else raise Run "non-blank suffix in result file" - end - -fun run (guid, module_file, arity) t = - let - val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") - fun arity_of c = Inttab.lookup arity c - val callguid = get_guid() - val module = "AMGHC_Prog_"^guid - val call = module^"_Call_"^callguid - val result_file = tmp_file (module^"_Result_"^callguid^".txt") - val call_file = tmp_file (call^".hs") - val term = print_term arity_of 0 t - val call_source = "module "^call^" where\n\nimport "^module^"\n\ncall = "^module^".calc \""^result_file^"\" ("^term^")" - val _ = writeTextFile call_file call_source - val _ = bash ((!ghc)^" -e \""^call^".call\" "^module_file^" "^call_file) - val result = readResultFile result_file handle IO.Io _ => raise Run ("Failure running haskell compiler (GHC_PATH = '"^(!ghc)^"')") - val t' = parse_result arity_of result - val _ = OS.FileSys.remove call_file - val _ = OS.FileSys.remove result_file - in - t' - end - - -fun discard _ = () - -end - diff -r c7ce7685e087 -r d83659570337 src/Tools/Compute_Oracle/am_interpreter.ML --- a/src/Tools/Compute_Oracle/am_interpreter.ML Wed Jul 21 15:31:38 2010 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,213 +0,0 @@ -(* Title: Tools/Compute_Oracle/am_interpreter.ML - Author: Steven Obua -*) - -signature AM_BARRAS = -sig - include ABSTRACT_MACHINE - val max_reductions : int option Unsynchronized.ref -end - -structure AM_Interpreter : AM_BARRAS = struct - -open AbstractMachine; - -datatype closure = CDummy | CVar of int | CConst of int - | CApp of closure * closure | CAbs of closure - | Closure of (closure list) * closure - -structure prog_struct = Table(type key = int*int val ord = prod_ord int_ord int_ord); - -datatype program = Program of ((pattern * closure * (closure*closure) list) list) prog_struct.table - -datatype stack = SEmpty | SAppL of closure * stack | SAppR of closure * stack | SAbs of stack - -fun clos_of_term (Var x) = CVar x - | clos_of_term (Const c) = CConst c - | clos_of_term (App (u, v)) = CApp (clos_of_term u, clos_of_term v) - | clos_of_term (Abs u) = CAbs (clos_of_term u) - | clos_of_term (Computed t) = clos_of_term t - -fun term_of_clos (CVar x) = Var x - | term_of_clos (CConst c) = Const c - | term_of_clos (CApp (u, v)) = App (term_of_clos u, term_of_clos v) - | term_of_clos (CAbs u) = Abs (term_of_clos u) - | term_of_clos (Closure (e, u)) = raise (Run "internal error: closure in normalized term found") - | term_of_clos CDummy = raise (Run "internal error: dummy in normalized term found") - -fun resolve_closure closures (CVar x) = (case List.nth (closures, x) of CDummy => CVar x | r => r) - | resolve_closure closures (CConst c) = CConst c - | resolve_closure closures (CApp (u, v)) = CApp (resolve_closure closures u, resolve_closure closures v) - | resolve_closure closures (CAbs u) = CAbs (resolve_closure (CDummy::closures) u) - | resolve_closure closures (CDummy) = raise (Run "internal error: resolve_closure applied to CDummy") - | resolve_closure closures (Closure (e, u)) = resolve_closure e u - -fun resolve_closure' c = resolve_closure [] c - -fun resolve_stack tm SEmpty = tm - | resolve_stack tm (SAppL (c, s)) = resolve_stack (CApp (tm, resolve_closure' c)) s - | resolve_stack tm (SAppR (c, s)) = resolve_stack (CApp (resolve_closure' c, tm)) s - | resolve_stack tm (SAbs s) = resolve_stack (CAbs tm) s - -fun resolve (stack, closure) = - let - val _ = writeln "start resolving" - val t = resolve_stack (resolve_closure' closure) stack - val _ = writeln "finished resolving" - in - t - end - -fun strip_closure args (CApp (a,b)) = strip_closure (b::args) a - | strip_closure args x = (x, args) - -fun len_head_of_closure n (CApp (a,b)) = len_head_of_closure (n+1) a - | len_head_of_closure n x = (n, x) - - -(* earlier occurrence of PVar corresponds to higher de Bruijn index *) -fun pattern_match args PVar clos = SOME (clos::args) - | pattern_match args (PConst (c, patterns)) clos = - let - val (f, closargs) = strip_closure [] clos - in - case f of - CConst d => - if c = d then - pattern_match_list args patterns closargs - else - NONE - | _ => NONE - end -and pattern_match_list args [] [] = SOME args - | pattern_match_list args (p::ps) (c::cs) = - (case pattern_match args p c of - NONE => NONE - | SOME args => pattern_match_list args ps cs) - | pattern_match_list _ _ _ = NONE - -fun count_patternvars PVar = 1 - | count_patternvars (PConst (_, ps)) = List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps - -fun pattern_key (PConst (c, ps)) = (c, length ps) - | pattern_key _ = raise (Compile "pattern reduces to variable") - -(*Returns true iff at most 0 .. (free-1) occur unbound. therefore - check_freevars 0 t iff t is closed*) -fun check_freevars free (Var x) = x < free - | check_freevars free (Const c) = true - | check_freevars free (App (u, v)) = check_freevars free u andalso check_freevars free v - | check_freevars free (Abs m) = check_freevars (free+1) m - | check_freevars free (Computed t) = check_freevars free t - -fun compile cache_patterns const_arity eqs = - let - fun check p r = if check_freevars p r then () else raise Compile ("unbound variables in rule") - fun check_guard p (Guard (a,b)) = (check p a; check p b) - fun clos_of_guard (Guard (a,b)) = (clos_of_term a, clos_of_term b) - val eqs = map (fn (guards, p, r) => let val pcount = count_patternvars p val _ = map (check_guard pcount) (guards) val _ = check pcount r in - (pattern_key p, (p, clos_of_term r, map clos_of_guard guards)) end) eqs - fun merge (k, a) table = prog_struct.update (k, case prog_struct.lookup table k of NONE => [a] | SOME l => a::l) table - val p = fold merge eqs prog_struct.empty - in - Program p - end - - -type state = bool * program * stack * closure - -datatype loopstate = Continue of state | Stop of stack * closure - -fun proj_C (Continue s) = s - | proj_C _ = raise Match - -exception InterruptedExecution of stack * closure - -fun proj_S (Stop s) = s - | proj_S (Continue (_,_,s,c)) = (s,c) - -fun cont (Continue _) = true - | cont _ = false - -val max_reductions = Unsynchronized.ref (NONE : int option) - -fun do_reduction reduce p = - let - val s = Unsynchronized.ref (Continue p) - val counter = Unsynchronized.ref 0 - val _ = case !max_reductions of - NONE => while cont (!s) do (s := reduce (proj_C (!s))) - | SOME m => while cont (!s) andalso (!counter < m) do (s := reduce (proj_C (!s)); counter := (!counter) + 1) - in - case !max_reductions of - SOME m => if !counter >= m then raise InterruptedExecution (proj_S (!s)) else proj_S (!s) - | NONE => proj_S (!s) - end - -fun match_rules prog n [] clos = NONE - | match_rules prog n ((p,eq,guards)::rs) clos = - case pattern_match [] p clos of - NONE => match_rules prog (n+1) rs clos - | SOME args => if forall (guard_checks prog args) guards then SOME (Closure (args, eq)) else match_rules prog (n+1) rs clos -and guard_checks prog args (a,b) = (simp prog (Closure (args, a)) = simp prog (Closure (args, b))) -and match_closure (p as (Program prog)) clos = - case len_head_of_closure 0 clos of - (len, CConst c) => - (case prog_struct.lookup prog (c, len) of - NONE => NONE - | SOME rules => match_rules p 0 rules clos) - | _ => NONE - -and weak_reduce (false, prog, stack, Closure (e, CApp (a, b))) = Continue (false, prog, SAppL (Closure (e, b), stack), Closure (e, a)) - | weak_reduce (false, prog, SAppL (b, stack), Closure (e, CAbs m)) = Continue (false, prog, stack, Closure (b::e, m)) - | weak_reduce (false, prog, stack, Closure (e, CVar n)) = Continue (false, prog, stack, case List.nth (e, n) of CDummy => CVar n | r => r) - | weak_reduce (false, prog, stack, Closure (e, c as CConst _)) = Continue (false, prog, stack, c) - | weak_reduce (false, prog, stack, clos) = - (case match_closure prog clos of - NONE => Continue (true, prog, stack, clos) - | SOME r => Continue (false, prog, stack, r)) - | weak_reduce (true, prog, SAppR (a, stack), b) = Continue (false, prog, stack, CApp (a,b)) - | weak_reduce (true, prog, s as (SAppL (b, stack)), a) = Continue (false, prog, SAppR (a, stack), b) - | weak_reduce (true, prog, stack, c) = Stop (stack, c) - -and strong_reduce (false, prog, stack, Closure (e, CAbs m)) = - (let - val (stack', wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure (CDummy::e, m)) - in - case stack' of - SEmpty => Continue (false, prog, SAbs stack, wnf) - | _ => raise (Run "internal error in strong: weak failed") - end handle InterruptedExecution state => raise InterruptedExecution (stack, resolve state)) - | strong_reduce (false, prog, stack, clos as (CApp (u, v))) = Continue (false, prog, SAppL (v, stack), u) - | strong_reduce (false, prog, stack, clos) = Continue (true, prog, stack, clos) - | strong_reduce (true, prog, SAbs stack, m) = Continue (false, prog, stack, CAbs m) - | strong_reduce (true, prog, SAppL (b, stack), a) = Continue (false, prog, SAppR (a, stack), b) - | strong_reduce (true, prog, SAppR (a, stack), b) = Continue (true, prog, stack, CApp (a, b)) - | strong_reduce (true, prog, stack, clos) = Stop (stack, clos) - -and simp prog t = - (let - val (stack, wnf) = do_reduction weak_reduce (false, prog, SEmpty, t) - in - case stack of - SEmpty => (case do_reduction strong_reduce (false, prog, SEmpty, wnf) of - (SEmpty, snf) => snf - | _ => raise (Run "internal error in run: strong failed")) - | _ => raise (Run "internal error in run: weak failed") - end handle InterruptedExecution state => resolve state) - - -fun run prog t = - (let - val (stack, wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure ([], clos_of_term t)) - in - case stack of - SEmpty => (case do_reduction strong_reduce (false, prog, SEmpty, wnf) of - (SEmpty, snf) => term_of_clos snf - | _ => raise (Run "internal error in run: strong failed")) - | _ => raise (Run "internal error in run: weak failed") - end handle InterruptedExecution state => term_of_clos (resolve state)) - -fun discard p = () - -end diff -r c7ce7685e087 -r d83659570337 src/Tools/Compute_Oracle/am_sml.ML --- a/src/Tools/Compute_Oracle/am_sml.ML Wed Jul 21 15:31:38 2010 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,548 +0,0 @@ -(* Title: Tools/Compute_Oracle/am_sml.ML - Author: Steven Obua - -TODO: "parameterless rewrite cannot be used in pattern": In a lot of -cases it CAN be used, and these cases should be handled -properly; right now, all cases raise an exception. -*) - -signature AM_SML = -sig - include ABSTRACT_MACHINE - val save_result : (string * term) -> unit - val set_compiled_rewriter : (term -> term) -> unit - val list_nth : 'a list * int -> 'a - val dump_output : (string option) Unsynchronized.ref -end - -structure AM_SML : AM_SML = struct - -open AbstractMachine; - -val dump_output = Unsynchronized.ref (NONE: string option) - -type program = string * string * (int Inttab.table) * (int Inttab.table) * (term Inttab.table) * (term -> term) - -val saved_result = Unsynchronized.ref (NONE:(string*term)option) - -fun save_result r = (saved_result := SOME r) -fun clear_result () = (saved_result := NONE) - -val list_nth = List.nth - -(*fun list_nth (l,n) = (writeln (makestring ("list_nth", (length l,n))); List.nth (l,n))*) - -val compiled_rewriter = Unsynchronized.ref (NONE:(term -> term)Option.option) - -fun set_compiled_rewriter r = (compiled_rewriter := SOME r) - -fun count_patternvars PVar = 1 - | count_patternvars (PConst (_, ps)) = - List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps - -fun update_arity arity code a = - (case Inttab.lookup arity code of - NONE => Inttab.update_new (code, a) arity - | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity) - -(* We have to find out the maximal arity of each constant *) -fun collect_pattern_arity PVar arity = arity - | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args)) - -(* We also need to find out the maximal toplevel arity of each function constant *) -fun collect_pattern_toplevel_arity PVar arity = raise Compile "internal error: collect_pattern_toplevel_arity" - | collect_pattern_toplevel_arity (PConst (c, args)) arity = update_arity arity c (length args) - -local -fun collect applevel (Var _) arity = arity - | collect applevel (Const c) arity = update_arity arity c applevel - | collect applevel (Abs m) arity = collect 0 m arity - | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity) -in -fun collect_term_arity t arity = collect 0 t arity -end - -fun collect_guard_arity (Guard (a,b)) arity = collect_term_arity b (collect_term_arity a arity) - - -fun rep n x = if n < 0 then raise Compile "internal error: rep" else if n = 0 then [] else x::(rep (n-1) x) - -fun beta (Const c) = Const c - | beta (Var i) = Var i - | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b))) - | beta (App (a, b)) = - (case beta a of - Abs m => beta (App (Abs m, b)) - | a => App (a, beta b)) - | beta (Abs m) = Abs (beta m) - | beta (Computed t) = Computed t -and subst x (Const c) t = Const c - | subst x (Var i) t = if i = x then t else Var i - | subst x (App (a,b)) t = App (subst x a t, subst x b t) - | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t)) -and lift level (Const c) = Const c - | lift level (App (a,b)) = App (lift level a, lift level b) - | lift level (Var i) = if i < level then Var i else Var (i+1) - | lift level (Abs m) = Abs (lift (level + 1) m) -and unlift level (Const c) = Const c - | unlift level (App (a, b)) = App (unlift level a, unlift level b) - | unlift level (Abs m) = Abs (unlift (level+1) m) - | unlift level (Var i) = if i < level then Var i else Var (i-1) - -fun nlift level n (Var m) = if m < level then Var m else Var (m+n) - | nlift level n (Const c) = Const c - | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b) - | nlift level n (Abs b) = Abs (nlift (level+1) n b) - -fun subst_const (c, t) (Const c') = if c = c' then t else Const c' - | subst_const _ (Var i) = Var i - | subst_const ct (App (a, b)) = App (subst_const ct a, subst_const ct b) - | subst_const ct (Abs m) = Abs (subst_const ct m) - -(* Remove all rules that are just parameterless rewrites. This is necessary because SML does not allow functions with no parameters. *) -fun inline_rules rules = - let - fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b - | term_contains_const c (Abs m) = term_contains_const c m - | term_contains_const c (Var i) = false - | term_contains_const c (Const c') = (c = c') - fun find_rewrite [] = NONE - | find_rewrite ((prems, PConst (c, []), r) :: _) = - if check_freevars 0 r then - if term_contains_const c r then - raise Compile "parameterless rewrite is caught in cycle" - else if not (null prems) then - raise Compile "parameterless rewrite may not be guarded" - else - SOME (c, r) - else raise Compile "unbound variable on right hand side or guards of rule" - | find_rewrite (_ :: rules) = find_rewrite rules - fun remove_rewrite (c,r) [] = [] - | remove_rewrite (cr as (c,r)) ((rule as (prems', PConst (c', args), r'))::rules) = - (if c = c' then - if null args andalso r = r' andalso null (prems') then - remove_rewrite cr rules - else raise Compile "incompatible parameterless rewrites found" - else - rule :: (remove_rewrite cr rules)) - | remove_rewrite cr (r::rs) = r::(remove_rewrite cr rs) - fun pattern_contains_const c (PConst (c', args)) = (c = c' orelse exists (pattern_contains_const c) args) - | pattern_contains_const c (PVar) = false - fun inline_rewrite (ct as (c, _)) (prems, p, r) = - if pattern_contains_const c p then - raise Compile "parameterless rewrite cannot be used in pattern" - else (map (fn (Guard (a,b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r) - fun inline inlined rules = - (case find_rewrite rules of - NONE => (Inttab.make inlined, rules) - | SOME ct => - let - val rules = map (inline_rewrite ct) (remove_rewrite ct rules) - val inlined = ct :: (map (fn (c', r) => (c', subst_const ct r)) inlined) - in - inline inlined rules - end) - in - inline [] rules - end - - -(* - Calculate the arity, the toplevel_arity, and adjust rules so that all toplevel pattern constants have maximal arity. - Also beta reduce the adjusted right hand side of a rule. -*) -fun adjust_rules rules = - let - val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty - val toplevel_arity = fold (fn (_, p, t) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty - fun arity_of c = the (Inttab.lookup arity c) - fun toplevel_arity_of c = the (Inttab.lookup toplevel_arity c) - fun test_pattern PVar = () - | test_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else (map test_pattern args; ()) - fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable") - | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters") - | adjust_rule (rule as (prems, p as PConst (c, args),t)) = - let - val patternvars_counted = count_patternvars p - fun check_fv t = check_freevars patternvars_counted t - val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else () - val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else () - val _ = map test_pattern args - val len = length args - val arity = arity_of c - val lift = nlift 0 - fun addapps_tm n t = if n=0 then t else addapps_tm (n-1) (App (t, Var (n-1))) - fun adjust_term n t = addapps_tm n (lift n t) - fun adjust_guard n (Guard (a,b)) = Guard (lift n a, lift n b) - in - if len = arity then - rule - else if arity >= len then - (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t) - else (raise Compile "internal error in adjust_rule") - end - fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule") - in - (arity, toplevel_arity, map (beta_rule o adjust_rule) rules) - end - -fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count = -let - fun str x = string_of_int x - fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s - val module_prefix = (case module of NONE => "" | SOME s => s^".") - fun print_apps d f [] = f - | print_apps d f (a::args) = print_apps d (module_prefix^"app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args - and print_call d (App (a, b)) args = print_call d a (b::args) - | print_call d (Const c) args = - (case arity_of c of - NONE => print_apps d (module_prefix^"Const "^(str c)) args - | SOME 0 => module_prefix^"C"^(str c) - | SOME a => - let - val len = length args - in - if a <= len then - let - val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a) - val _ = if strict_a > a then raise Compile "strict" else () - val s = module_prefix^"c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, strict_a)))) - val s = s^(implode (map (fn t => " (fn () => "^print_term d t^")") (List.drop (List.take (args, a), strict_a)))) - in - print_apps d s (List.drop (args, a)) - end - else - let - fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n - 1))) - fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t) - fun append_args [] t = t - | append_args (c::cs) t = append_args cs (App (t, c)) - in - print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c))))) - end - end) - | print_call d t args = print_apps d (print_term d t) args - and print_term d (Var x) = - if x < d then - "b"^(str (d-x-1)) - else - let - val n = pattern_var_count - (x-d) - 1 - val x = "x"^(str n) - in - if n < pattern_var_count - pattern_lazy_var_count then - x - else - "("^x^" ())" - end - | print_term d (Abs c) = module_prefix^"Abs (fn b"^(str d)^" => "^(print_term (d + 1) c)^")" - | print_term d t = print_call d t [] -in - print_term 0 -end - -fun section n = if n = 0 then [] else (section (n-1))@[n-1] - -fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) = - let - fun str x = Int.toString x - fun print_pattern top n PVar = (n+1, "x"^(str n)) - | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")) - | print_pattern top n (PConst (c, args)) = - let - val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "") - val (n, s) = print_pattern_list 0 top (n, f) args - in - (n, s) - end - and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")") - | print_pattern_list' counter top (n, p) (t::ts) = - let - val (n, t) = print_pattern false n t - in - print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts - end - and print_pattern_list counter top (n, p) (t::ts) = - let - val (n, t) = print_pattern false n t - in - print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts - end - val c = (case p of PConst (c, _) => c | _ => raise Match) - val (n, pattern) = print_pattern true 0 p - val lazy_vars = the (arity_of c) - the (toplevel_arity_of c) - fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm - fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")" - val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(implode (map (fn i => " a"^(str i)) (section (the (arity_of c))))) - fun print_guards t [] = print_tm t - | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(implode (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch - in - (if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards)) - end - -fun group_rules rules = - let - fun add_rule (r as (_, PConst (c,_), _)) groups = - let - val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs) - in - Inttab.update (c, r::rs) groups - end - | add_rule _ _ = raise Compile "internal error group_rules" - in - fold_rev add_rule rules Inttab.empty - end - -fun sml_prog name code rules = - let - val buffer = Unsynchronized.ref "" - fun write s = (buffer := (!buffer)^s) - fun writeln s = (write s; write "\n") - fun writelist [] = () - | writelist (s::ss) = (writeln s; writelist ss) - fun str i = Int.toString i - val (inlinetab, rules) = inline_rules rules - val (arity, toplevel_arity, rules) = adjust_rules rules - val rules = group_rules rules - val constants = Inttab.keys arity - fun arity_of c = Inttab.lookup arity c - fun toplevel_arity_of c = Inttab.lookup toplevel_arity c - fun rep_str s n = implode (rep n s) - fun indexed s n = s^(str n) - fun string_of_tuple [] = "" - | string_of_tuple (x::xs) = "("^x^(implode (map (fn s => ", "^s) xs))^")" - fun string_of_args [] = "" - | string_of_args (x::xs) = x^(implode (map (fn s => " "^s) xs)) - fun default_case gnum c = - let - val leftargs = implode (map (indexed " x") (section (the (arity_of c)))) - val rightargs = section (the (arity_of c)) - val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa) - val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs - val right = (indexed "C" c)^" "^(string_of_tuple xs) - val message = "(\"unresolved lazy call: " ^ string_of_int c ^ "\")" - val right = if strict_args < the (arity_of c) then "raise AM_SML.Run "^message else right - in - (indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right - end - - fun eval_rules c = - let - val arity = the (arity_of c) - val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa) - fun eval_rule n = - let - val sc = string_of_int c - val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc) - fun arg i = - let - val x = indexed "x" i - val x = if i < n then "(eval bounds "^x^")" else x - val x = if i < strict_arity then x else "(fn () => "^x^")" - in - x - end - val right = "c"^sc^" "^(string_of_args (map arg (section arity))) - val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right - val right = if arity > 0 then right else "C"^sc - in - " | eval bounds ("^left^") = "^right - end - in - map eval_rule (rev (section (arity + 1))) - end - - fun convert_computed_rules (c: int) : string list = - let - val arity = the (arity_of c) - fun eval_rule () = - let - val sc = string_of_int c - val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section arity) ("AbstractMachine.Const "^sc) - fun arg i = "(convert_computed "^(indexed "x" i)^")" - val right = "C"^sc^" "^(string_of_tuple (map arg (section arity))) - val right = if arity > 0 then right else "C"^sc - in - " | convert_computed ("^left^") = "^right - end - in - [eval_rule ()] - end - - fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else "" - val _ = writelist [ - "structure "^name^" = struct", - "", - "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)", - " "^(implode (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)), - ""] - fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")" - fun make_term_eq c = " | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^ - (case the (arity_of c) of - 0 => "true" - | n => - let - val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n) - val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs)) - in - eq^(implode eqs) - end) - val _ = writelist [ - "fun term_eq (Const c1) (Const c2) = (c1 = c2)", - " | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"] - val _ = writelist (map make_term_eq constants) - val _ = writelist [ - " | term_eq _ _ = false", - "" - ] - val _ = writelist [ - "fun app (Abs a) b = a b", - " | app a b = App (a, b)", - ""] - fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else []) - fun writefundecl [] = () - | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => " | "^s) xs))) - fun list_group c = (case Inttab.lookup rules c of - NONE => [defcase 0 c] - | SOME rs => - let - val rs = - fold - (fn r => - fn rs => - let - val (gnum, l, rs) = - (case rs of - [] => (0, [], []) - | (gnum, l)::rs => (gnum, l, rs)) - val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r - in - if gnum' = gnum then - (gnum, r::l)::rs - else - let - val args = implode (map (fn i => " a"^(str i)) (section (the (arity_of c)))) - fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args - val s = gnumc (gnum) ^ " = " ^ gnumc (gnum') - in - (gnum', [])::(gnum, s::r::l)::rs - end - end) - rs [] - val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs) - in - rev (map (fn z => rev (snd z)) rs) - end) - val _ = map (fn z => (map writefundecl z; writeln "")) (map list_group constants) - val _ = writelist [ - "fun convert (Const i) = AM_SML.Const i", - " | convert (App (a, b)) = AM_SML.App (convert a, convert b)", - " | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""] - fun make_convert c = - let - val args = map (indexed "a") (section (the (arity_of c))) - val leftargs = - case args of - [] => "" - | (x::xs) => "("^x^(implode (map (fn s => ", "^s) xs))^")" - val args = map (indexed "convert a") (section (the (arity_of c))) - val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c)) - in - " | convert (C"^(str c)^" "^leftargs^") = "^right - end - val _ = writelist (map make_convert constants) - val _ = writelist [ - "", - "fun convert_computed (AbstractMachine.Abs b) = raise AM_SML.Run \"no abstraction in convert_computed allowed\"", - " | convert_computed (AbstractMachine.Var i) = raise AM_SML.Run \"no bound variables in convert_computed allowed\""] - val _ = map (writelist o convert_computed_rules) constants - val _ = writelist [ - " | convert_computed (AbstractMachine.Const c) = Const c", - " | convert_computed (AbstractMachine.App (a, b)) = App (convert_computed a, convert_computed b)", - " | convert_computed (AbstractMachine.Computed a) = raise AM_SML.Run \"no nesting in convert_computed allowed\""] - val _ = writelist [ - "", - "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)", - " | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"] - val _ = map (writelist o eval_rules) constants - val _ = writelist [ - " | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)", - " | eval bounds (AbstractMachine.Const c) = Const c", - " | eval bounds (AbstractMachine.Computed t) = convert_computed t"] - val _ = writelist [ - "", - "fun export term = AM_SML.save_result (\""^code^"\", convert term)", - "", - "val _ = AM_SML.set_compiled_rewriter (fn t => (convert (eval [] t)))", - "", - "end"] - in - (arity, toplevel_arity, inlinetab, !buffer) - end - -val guid_counter = Unsynchronized.ref 0 -fun get_guid () = - let - val c = !guid_counter - val _ = guid_counter := !guid_counter + 1 - in - (LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c) - end - - -fun writeTextFile name s = File.write (Path.explode name) s - -fun use_source src = use_text ML_Env.local_context (1, "") false src - -fun compile cache_patterns const_arity eqs = - let - val guid = get_guid () - val code = Real.toString (random ()) - val module = "AMSML_"^guid - val (arity, toplevel_arity, inlinetab, source) = sml_prog module code eqs - val _ = case !dump_output of NONE => () | SOME p => writeTextFile p source - val _ = compiled_rewriter := NONE - val _ = use_source source - in - case !compiled_rewriter of - NONE => raise Compile "broken link to compiled function" - | SOME f => (module, code, arity, toplevel_arity, inlinetab, f) - end - - -fun run' (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t = - let - val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") - fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t) - | inline (Var i) = Var i - | inline (App (a, b)) = App (inline a, inline b) - | inline (Abs m) = Abs (inline m) - val t = beta (inline t) - fun arity_of c = Inttab.lookup arity c - fun toplevel_arity_of c = Inttab.lookup toplevel_arity c - val term = print_term NONE arity_of toplevel_arity_of 0 0 t - val source = "local open "^module^" in val _ = export ("^term^") end" - val _ = writeTextFile "Gencode_call.ML" source - val _ = clear_result () - val _ = use_source source - in - case !saved_result of - NONE => raise Run "broken link to compiled code" - | SOME (code', t) => (clear_result (); if code' = code then t else raise Run "link to compiled code was hijacked") - end - -fun run (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t = - let - val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") - fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t) - | inline (Var i) = Var i - | inline (App (a, b)) = App (inline a, inline b) - | inline (Abs m) = Abs (inline m) - | inline (Computed t) = Computed t - in - compiled_fun (beta (inline t)) - end - -fun discard p = () - -end diff -r c7ce7685e087 -r d83659570337 src/Tools/Compute_Oracle/compute.ML --- a/src/Tools/Compute_Oracle/compute.ML Wed Jul 21 15:31:38 2010 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,683 +0,0 @@ -(* Title: Tools/Compute_Oracle/compute.ML - Author: Steven Obua -*) - -signature COMPUTE = sig - - type computer - type theorem - type naming = int -> string - - datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML - - (* Functions designated with a ! in front of them actually update the computer parameter *) - - exception Make of string - val make : machine -> theory -> thm list -> computer - val make_with_cache : machine -> theory -> term list -> thm list -> computer - val theory_of : computer -> theory - val hyps_of : computer -> term list - val shyps_of : computer -> sort list - (* ! *) val update : computer -> thm list -> unit - (* ! *) val update_with_cache : computer -> term list -> thm list -> unit - (* ! *) val discard : computer -> unit - - (* ! *) val set_naming : computer -> naming -> unit - val naming_of : computer -> naming - - exception Compute of string - val simplify : computer -> theorem -> thm - val rewrite : computer -> cterm -> thm - - val make_theorem : computer -> thm -> string list -> theorem - (* ! *) val instantiate : computer -> (string * cterm) list -> theorem -> theorem - (* ! *) val evaluate_prem : computer -> int -> theorem -> theorem - (* ! *) val modus_ponens : computer -> int -> thm -> theorem -> theorem - -end - -structure Compute :> COMPUTE = struct - -open Report; - -datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML - -(* Terms are mapped to integer codes *) -structure Encode :> -sig - type encoding - val empty : encoding - val insert : term -> encoding -> int * encoding - val lookup_code : term -> encoding -> int option - val lookup_term : int -> encoding -> term option - val remove_code : int -> encoding -> encoding - val remove_term : term -> encoding -> encoding - val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a -end -= -struct - -type encoding = int * (int Termtab.table) * (term Inttab.table) - -val empty = (0, Termtab.empty, Inttab.empty) - -fun insert t (e as (count, term2int, int2term)) = - (case Termtab.lookup term2int t of - NONE => (count, (count+1, Termtab.update_new (t, count) term2int, Inttab.update_new (count, t) int2term)) - | SOME code => (code, e)) - -fun lookup_code t (_, term2int, _) = Termtab.lookup term2int t - -fun lookup_term c (_, _, int2term) = Inttab.lookup int2term c - -fun remove_code c (e as (count, term2int, int2term)) = - (case lookup_term c e of NONE => e | SOME t => (count, Termtab.delete t term2int, Inttab.delete c int2term)) - -fun remove_term t (e as (count, term2int, int2term)) = - (case lookup_code t e of NONE => e | SOME c => (count, Termtab.delete t term2int, Inttab.delete c int2term)) - -fun fold f (_, term2int, _) = Termtab.fold f term2int - -end - -exception Make of string; -exception Compute of string; - -local - fun make_constant t ty encoding = - let - val (code, encoding) = Encode.insert t encoding - in - (encoding, AbstractMachine.Const code) - end -in - -fun remove_types encoding t = - case t of - Var (_, ty) => make_constant t ty encoding - | Free (_, ty) => make_constant t ty encoding - | Const (_, ty) => make_constant t ty encoding - | Abs (_, ty, t') => - let val (encoding, t'') = remove_types encoding t' in - (encoding, AbstractMachine.Abs t'') - end - | a $ b => - let - val (encoding, a) = remove_types encoding a - val (encoding, b) = remove_types encoding b - in - (encoding, AbstractMachine.App (a,b)) - end - | Bound b => (encoding, AbstractMachine.Var b) -end - -local - fun type_of (Free (_, ty)) = ty - | type_of (Const (_, ty)) = ty - | type_of (Var (_, ty)) = ty - | type_of _ = sys_error "infer_types: type_of error" -in -fun infer_types naming encoding = - let - fun infer_types _ bounds _ (AbstractMachine.Var v) = (Bound v, List.nth (bounds, v)) - | infer_types _ bounds _ (AbstractMachine.Const code) = - let - val c = the (Encode.lookup_term code encoding) - in - (c, type_of c) - end - | infer_types level bounds _ (AbstractMachine.App (a, b)) = - let - val (a, aty) = infer_types level bounds NONE a - val (adom, arange) = - case aty of - Type ("fun", [dom, range]) => (dom, range) - | _ => sys_error "infer_types: function type expected" - val (b, bty) = infer_types level bounds (SOME adom) b - in - (a $ b, arange) - end - | infer_types level bounds (SOME (ty as Type ("fun", [dom, range]))) (AbstractMachine.Abs m) = - let - val (m, _) = infer_types (level+1) (dom::bounds) (SOME range) m - in - (Abs (naming level, dom, m), ty) - end - | infer_types _ _ NONE (AbstractMachine.Abs m) = sys_error "infer_types: cannot infer type of abstraction" - - fun infer ty term = - let - val (term', _) = infer_types 0 [] (SOME ty) term - in - term' - end - in - infer - end -end - -datatype prog = - ProgBarras of AM_Interpreter.program - | ProgBarrasC of AM_Compiler.program - | ProgHaskell of AM_GHC.program - | ProgSML of AM_SML.program - -fun machine_of_prog (ProgBarras _) = BARRAS - | machine_of_prog (ProgBarrasC _) = BARRAS_COMPILED - | machine_of_prog (ProgHaskell _) = HASKELL - | machine_of_prog (ProgSML _) = SML - -type naming = int -> string - -fun default_naming i = "v_" ^ Int.toString i - -datatype computer = Computer of - (theory_ref * Encode.encoding * term list * unit Sorttab.table * prog * unit Unsynchronized.ref * naming) - option Unsynchronized.ref - -fun theory_of (Computer (Unsynchronized.ref (SOME (rthy,_,_,_,_,_,_)))) = Theory.deref rthy -fun hyps_of (Computer (Unsynchronized.ref (SOME (_,_,hyps,_,_,_,_)))) = hyps -fun shyps_of (Computer (Unsynchronized.ref (SOME (_,_,_,shyptable,_,_,_)))) = Sorttab.keys (shyptable) -fun shyptab_of (Computer (Unsynchronized.ref (SOME (_,_,_,shyptable,_,_,_)))) = shyptable -fun stamp_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,_,stamp,_)))) = stamp -fun prog_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,prog,_,_)))) = prog -fun encoding_of (Computer (Unsynchronized.ref (SOME (_,encoding,_,_,_,_,_)))) = encoding -fun set_encoding (Computer (r as Unsynchronized.ref (SOME (p1,encoding,p2,p3,p4,p5,p6)))) encoding' = - (r := SOME (p1,encoding',p2,p3,p4,p5,p6)) -fun naming_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,_,_,n)))) = n -fun set_naming (Computer (r as Unsynchronized.ref (SOME (p1,p2,p3,p4,p5,p6,naming)))) naming'= - (r := SOME (p1,p2,p3,p4,p5,p6,naming')) - -fun ref_of (Computer r) = r - -datatype cthm = ComputeThm of term list * sort list * term - -fun thm2cthm th = - let - val {hyps, prop, tpairs, shyps, ...} = Thm.rep_thm th - val _ = if not (null tpairs) then raise Make "theorems may not contain tpairs" else () - in - ComputeThm (hyps, shyps, prop) - end - -fun make_internal machine thy stamp encoding cache_pattern_terms raw_ths = - let - fun transfer (x:thm) = Thm.transfer thy x - val ths = map (thm2cthm o Thm.strip_shyps o transfer) raw_ths - - fun make_pattern encoding n vars (var as AbstractMachine.Abs _) = - raise (Make "no lambda abstractions allowed in pattern") - | make_pattern encoding n vars (var as AbstractMachine.Var _) = - raise (Make "no bound variables allowed in pattern") - | make_pattern encoding n vars (AbstractMachine.Const code) = - (case the (Encode.lookup_term code encoding) of - Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar) - handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed")) - | _ => (n, vars, AbstractMachine.PConst (code, []))) - | make_pattern encoding n vars (AbstractMachine.App (a, b)) = - let - val (n, vars, pa) = make_pattern encoding n vars a - val (n, vars, pb) = make_pattern encoding n vars b - in - case pa of - AbstractMachine.PVar => - raise (Make "patterns may not start with a variable") - | AbstractMachine.PConst (c, args) => - (n, vars, AbstractMachine.PConst (c, args@[pb])) - end - - fun thm2rule (encoding, hyptable, shyptable) th = - let - val (ComputeThm (hyps, shyps, prop)) = th - val hyptable = fold (fn h => Termtab.update (h, ())) hyps hyptable - val shyptable = fold (fn sh => Sorttab.update (sh, ())) shyps shyptable - val (prems, prop) = (Logic.strip_imp_prems prop, Logic.strip_imp_concl prop) - val (a, b) = Logic.dest_equals prop - handle TERM _ => raise (Make "theorems must be meta-level equations (with optional guards)") - val a = Envir.eta_contract a - val b = Envir.eta_contract b - val prems = map Envir.eta_contract prems - - val (encoding, left) = remove_types encoding a - val (encoding, right) = remove_types encoding b - fun remove_types_of_guard encoding g = - (let - val (t1, t2) = Logic.dest_equals g - val (encoding, t1) = remove_types encoding t1 - val (encoding, t2) = remove_types encoding t2 - in - (encoding, AbstractMachine.Guard (t1, t2)) - end handle TERM _ => raise (Make "guards must be meta-level equations")) - val (encoding, prems) = fold_rev (fn p => fn (encoding, ps) => let val (e, p) = remove_types_of_guard encoding p in (e, p::ps) end) prems (encoding, []) - - (* Principally, a check should be made here to see if the (meta-) hyps contain any of the variables of the rule. - As it is, all variables of the rule are schematic, and there are no schematic variables in meta-hyps, therefore - this check can be left out. *) - - val (vcount, vars, pattern) = make_pattern encoding 0 Inttab.empty left - val _ = (case pattern of - AbstractMachine.PVar => - raise (Make "patterns may not start with a variable") - (* | AbstractMachine.PConst (_, []) => - (print th; raise (Make "no parameter rewrite found"))*) - | _ => ()) - - (* finally, provide a function for renaming the - pattern bound variables on the right hand side *) - - fun rename level vars (var as AbstractMachine.Var _) = var - | rename level vars (c as AbstractMachine.Const code) = - (case Inttab.lookup vars code of - NONE => c - | SOME n => AbstractMachine.Var (vcount-n-1+level)) - | rename level vars (AbstractMachine.App (a, b)) = - AbstractMachine.App (rename level vars a, rename level vars b) - | rename level vars (AbstractMachine.Abs m) = - AbstractMachine.Abs (rename (level+1) vars m) - - fun rename_guard (AbstractMachine.Guard (a,b)) = - AbstractMachine.Guard (rename 0 vars a, rename 0 vars b) - in - ((encoding, hyptable, shyptable), (map rename_guard prems, pattern, rename 0 vars right)) - end - - val ((encoding, hyptable, shyptable), rules) = - fold_rev (fn th => fn (encoding_hyptable, rules) => - let - val (encoding_hyptable, rule) = thm2rule encoding_hyptable th - in (encoding_hyptable, rule::rules) end) - ths ((encoding, Termtab.empty, Sorttab.empty), []) - - fun make_cache_pattern t (encoding, cache_patterns) = - let - val (encoding, a) = remove_types encoding t - val (_,_,p) = make_pattern encoding 0 Inttab.empty a - in - (encoding, p::cache_patterns) - end - - val (encoding, cache_patterns) = fold_rev make_cache_pattern cache_pattern_terms (encoding, []) - - fun arity (Type ("fun", [a,b])) = 1 + arity b - | arity _ = 0 - - fun make_arity (Const (s, _), i) tab = - (Inttab.update (i, arity (Sign.the_const_type thy s)) tab handle TYPE _ => tab) - | make_arity _ tab = tab - - val const_arity_tab = Encode.fold make_arity encoding Inttab.empty - fun const_arity x = Inttab.lookup const_arity_tab x - - val prog = - case machine of - BARRAS => ProgBarras (AM_Interpreter.compile cache_patterns const_arity rules) - | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile cache_patterns const_arity rules) - | HASKELL => ProgHaskell (AM_GHC.compile cache_patterns const_arity rules) - | SML => ProgSML (AM_SML.compile cache_patterns const_arity rules) - - fun has_witness s = not (null (Sign.witness_sorts thy [] [s])) - - val shyptable = fold Sorttab.delete (filter has_witness (Sorttab.keys (shyptable))) shyptable - - in (Theory.check_thy thy, encoding, Termtab.keys hyptable, shyptable, prog, stamp, default_naming) end - -fun make_with_cache machine thy cache_patterns raw_thms = - Computer (Unsynchronized.ref (SOME (make_internal machine thy (Unsynchronized.ref ()) Encode.empty cache_patterns raw_thms))) - -fun make machine thy raw_thms = make_with_cache machine thy [] raw_thms - -fun update_with_cache computer cache_patterns raw_thms = - let - val c = make_internal (machine_of_prog (prog_of computer)) (theory_of computer) (stamp_of computer) - (encoding_of computer) cache_patterns raw_thms - val _ = (ref_of computer) := SOME c - in - () - end - -fun update computer raw_thms = update_with_cache computer [] raw_thms - -fun discard computer = - let - val _ = - case prog_of computer of - ProgBarras p => AM_Interpreter.discard p - | ProgBarrasC p => AM_Compiler.discard p - | ProgHaskell p => AM_GHC.discard p - | ProgSML p => AM_SML.discard p - val _ = (ref_of computer) := NONE - in - () - end - -fun runprog (ProgBarras p) = AM_Interpreter.run p - | runprog (ProgBarrasC p) = AM_Compiler.run p - | runprog (ProgHaskell p) = AM_GHC.run p - | runprog (ProgSML p) = AM_SML.run p - -(* ------------------------------------------------------------------------------------- *) -(* An oracle for exporting theorems; must only be accessible from inside this structure! *) -(* ------------------------------------------------------------------------------------- *) - -fun merge_hyps hyps1 hyps2 = -let - fun add hyps tab = fold (fn h => fn tab => Termtab.update (h, ()) tab) hyps tab -in - Termtab.keys (add hyps2 (add hyps1 Termtab.empty)) -end - -fun add_shyps shyps tab = fold (fn h => fn tab => Sorttab.update (h, ()) tab) shyps tab - -fun merge_shyps shyps1 shyps2 = Sorttab.keys (add_shyps shyps2 (add_shyps shyps1 Sorttab.empty)) - -val (_, export_oracle) = Context.>>> (Context.map_theory_result - (Thm.add_oracle (Binding.name "compute", fn (thy, hyps, shyps, prop) => - let - val shyptab = add_shyps shyps Sorttab.empty - fun delete s shyptab = Sorttab.delete s shyptab handle Sorttab.UNDEF _ => shyptab - fun delete_term t shyptab = fold delete (Sorts.insert_term t []) shyptab - fun has_witness s = not (null (Sign.witness_sorts thy [] [s])) - val shyptab = fold Sorttab.delete (filter has_witness (Sorttab.keys (shyptab))) shyptab - val shyps = if Sorttab.is_empty shyptab then [] else Sorttab.keys (fold delete_term (prop::hyps) shyptab) - val _ = - if not (null shyps) then - raise Compute ("dangling sort hypotheses: " ^ - commas (map (Syntax.string_of_sort_global thy) shyps)) - else () - in - Thm.cterm_of thy (fold_rev (fn hyp => fn p => Logic.mk_implies (hyp, p)) hyps prop) - end))); - -fun export_thm thy hyps shyps prop = - let - val th = export_oracle (thy, hyps, shyps, prop) - val hyps = map (fn h => Thm.assume (cterm_of thy h)) hyps - in - fold (fn h => fn p => Thm.implies_elim p h) hyps th - end - -(* --------- Rewrite ----------- *) - -fun rewrite computer ct = - let - val thy = Thm.theory_of_cterm ct - val {t=t',T=ty,...} = rep_cterm ct - val _ = Theory.assert_super (theory_of computer) thy - val naming = naming_of computer - val (encoding, t) = remove_types (encoding_of computer) t' - (*val _ = if (!print_encoding) then writeln (makestring ("encoding: ",Encode.fold (fn x => fn s => x::s) encoding [])) else ()*) - val t = runprog (prog_of computer) t - val t = infer_types naming encoding ty t - val eq = Logic.mk_equals (t', t) - in - export_thm thy (hyps_of computer) (Sorttab.keys (shyptab_of computer)) eq - end - -(* --------- Simplify ------------ *) - -datatype prem = EqPrem of AbstractMachine.term * AbstractMachine.term * Term.typ * int - | Prem of AbstractMachine.term -datatype theorem = Theorem of theory_ref * unit Unsynchronized.ref * (int * typ) Symtab.table * (AbstractMachine.term option) Inttab.table - * prem list * AbstractMachine.term * term list * sort list - - -exception ParamSimplify of computer * theorem - -fun make_theorem computer th vars = -let - val _ = Theory.assert_super (theory_of computer) (theory_of_thm th) - - val (ComputeThm (hyps, shyps, prop)) = thm2cthm th - - val encoding = encoding_of computer - - (* variables in the theorem are identified upfront *) - fun collect_vars (Abs (_, _, t)) tab = collect_vars t tab - | collect_vars (a $ b) tab = collect_vars b (collect_vars a tab) - | collect_vars (Const _) tab = tab - | collect_vars (Free _) tab = tab - | collect_vars (Var ((s, i), ty)) tab = - if List.find (fn x => x=s) vars = NONE then - tab - else - (case Symtab.lookup tab s of - SOME ((s',i'),ty') => - if s' <> s orelse i' <> i orelse ty <> ty' then - raise Compute ("make_theorem: variable name '"^s^"' is not unique") - else - tab - | NONE => Symtab.update (s, ((s, i), ty)) tab) - val vartab = collect_vars prop Symtab.empty - fun encodevar (s, t as (_, ty)) (encoding, tab) = - let - val (x, encoding) = Encode.insert (Var t) encoding - in - (encoding, Symtab.update (s, (x, ty)) tab) - end - val (encoding, vartab) = Symtab.fold encodevar vartab (encoding, Symtab.empty) - val varsubst = Inttab.make (map (fn (s, (x, _)) => (x, NONE)) (Symtab.dest vartab)) - - (* make the premises and the conclusion *) - fun mk_prem encoding t = - (let - val (a, b) = Logic.dest_equals t - val ty = type_of a - val (encoding, a) = remove_types encoding a - val (encoding, b) = remove_types encoding b - val (eq, encoding) = Encode.insert (Const ("==", ty --> ty --> @{typ "prop"})) encoding - in - (encoding, EqPrem (a, b, ty, eq)) - end handle TERM _ => let val (encoding, t) = remove_types encoding t in (encoding, Prem t) end) - val (encoding, prems) = - (fold_rev (fn t => fn (encoding, l) => - case mk_prem encoding t of - (encoding, t) => (encoding, t::l)) (Logic.strip_imp_prems prop) (encoding, [])) - val (encoding, concl) = remove_types encoding (Logic.strip_imp_concl prop) - val _ = set_encoding computer encoding -in - Theorem (Theory.check_thy (theory_of_thm th), stamp_of computer, vartab, varsubst, - prems, concl, hyps, shyps) -end - -fun theory_of_theorem (Theorem (rthy,_,_,_,_,_,_,_)) = Theory.deref rthy -fun update_theory thy (Theorem (_,p0,p1,p2,p3,p4,p5,p6)) = - Theorem (Theory.check_thy thy,p0,p1,p2,p3,p4,p5,p6) -fun stamp_of_theorem (Theorem (_,s, _, _, _, _, _, _)) = s -fun vartab_of_theorem (Theorem (_,_,vt,_,_,_,_,_)) = vt -fun varsubst_of_theorem (Theorem (_,_,_,vs,_,_,_,_)) = vs -fun update_varsubst vs (Theorem (p0,p1,p2,_,p3,p4,p5,p6)) = Theorem (p0,p1,p2,vs,p3,p4,p5,p6) -fun prems_of_theorem (Theorem (_,_,_,_,prems,_,_,_)) = prems -fun update_prems prems (Theorem (p0,p1,p2,p3,_,p4,p5,p6)) = Theorem (p0,p1,p2,p3,prems,p4,p5,p6) -fun concl_of_theorem (Theorem (_,_,_,_,_,concl,_,_)) = concl -fun hyps_of_theorem (Theorem (_,_,_,_,_,_,hyps,_)) = hyps -fun update_hyps hyps (Theorem (p0,p1,p2,p3,p4,p5,_,p6)) = Theorem (p0,p1,p2,p3,p4,p5,hyps,p6) -fun shyps_of_theorem (Theorem (_,_,_,_,_,_,_,shyps)) = shyps -fun update_shyps shyps (Theorem (p0,p1,p2,p3,p4,p5,p6,_)) = Theorem (p0,p1,p2,p3,p4,p5,p6,shyps) - -fun check_compatible computer th s = - if stamp_of computer <> stamp_of_theorem th then - raise Compute (s^": computer and theorem are incompatible") - else () - -fun instantiate computer insts th = -let - val _ = check_compatible computer th - - val thy = theory_of computer - - val vartab = vartab_of_theorem th - - fun rewrite computer t = - let - val naming = naming_of computer - val (encoding, t) = remove_types (encoding_of computer) t - val t = runprog (prog_of computer) t - val _ = set_encoding computer encoding - in - t - end - - fun assert_varfree vs t = - if AbstractMachine.forall_consts (fn x => Inttab.lookup vs x = NONE) t then - () - else - raise Compute "instantiate: assert_varfree failed" - - fun assert_closed t = - if AbstractMachine.closed t then - () - else - raise Compute "instantiate: not a closed term" - - fun compute_inst (s, ct) vs = - let - val _ = Theory.assert_super (theory_of_cterm ct) thy - val ty = typ_of (ctyp_of_term ct) - in - (case Symtab.lookup vartab s of - NONE => raise Compute ("instantiate: variable '"^s^"' not found in theorem") - | SOME (x, ty') => - (case Inttab.lookup vs x of - SOME (SOME _) => raise Compute ("instantiate: variable '"^s^"' has already been instantiated") - | SOME NONE => - if ty <> ty' then - raise Compute ("instantiate: wrong type for variable '"^s^"'") - else - let - val t = rewrite computer (term_of ct) - val _ = assert_varfree vs t - val _ = assert_closed t - in - Inttab.update (x, SOME t) vs - end - | NONE => raise Compute "instantiate: internal error")) - end - - val vs = fold compute_inst insts (varsubst_of_theorem th) -in - update_varsubst vs th -end - -fun match_aterms subst = - let - exception no_match - open AbstractMachine - fun match subst (b as (Const c)) a = - if a = b then subst - else - (case Inttab.lookup subst c of - SOME (SOME a') => if a=a' then subst else raise no_match - | SOME NONE => if AbstractMachine.closed a then - Inttab.update (c, SOME a) subst - else raise no_match - | NONE => raise no_match) - | match subst (b as (Var _)) a = if a=b then subst else raise no_match - | match subst (App (u, v)) (App (u', v')) = match (match subst u u') v v' - | match subst (Abs u) (Abs u') = match subst u u' - | match subst _ _ = raise no_match - in - fn b => fn a => (SOME (match subst b a) handle no_match => NONE) - end - -fun apply_subst vars_allowed subst = - let - open AbstractMachine - fun app (t as (Const c)) = - (case Inttab.lookup subst c of - NONE => t - | SOME (SOME t) => Computed t - | SOME NONE => if vars_allowed then t else raise Compute "apply_subst: no vars allowed") - | app (t as (Var _)) = t - | app (App (u, v)) = App (app u, app v) - | app (Abs m) = Abs (app m) - in - app - end - -fun splicein n l L = List.take (L, n) @ l @ List.drop (L, n+1) - -fun evaluate_prem computer prem_no th = -let - val _ = check_compatible computer th - val prems = prems_of_theorem th - val varsubst = varsubst_of_theorem th - fun run vars_allowed t = - runprog (prog_of computer) (apply_subst vars_allowed varsubst t) -in - case List.nth (prems, prem_no) of - Prem _ => raise Compute "evaluate_prem: no equality premise" - | EqPrem (a, b, ty, _) => - let - val a' = run false a - val b' = run true b - in - case match_aterms varsubst b' a' of - NONE => - let - fun mk s = Syntax.string_of_term_global Pure.thy - (infer_types (naming_of computer) (encoding_of computer) ty s) - val left = "computed left side: "^(mk a') - val right = "computed right side: "^(mk b') - in - raise Compute ("evaluate_prem: cannot assign computed left to right hand side\n"^left^"\n"^right^"\n") - end - | SOME varsubst => - update_prems (splicein prem_no [] prems) (update_varsubst varsubst th) - end -end - -fun prem2term (Prem t) = t - | prem2term (EqPrem (a,b,_,eq)) = - AbstractMachine.App (AbstractMachine.App (AbstractMachine.Const eq, a), b) - -fun modus_ponens computer prem_no th' th = -let - val _ = check_compatible computer th - val thy = - let - val thy1 = theory_of_theorem th - val thy2 = theory_of_thm th' - in - if Theory.subthy (thy1, thy2) then thy2 - else if Theory.subthy (thy2, thy1) then thy1 else - raise Compute "modus_ponens: theorems are not compatible with each other" - end - val th' = make_theorem computer th' [] - val varsubst = varsubst_of_theorem th - fun run vars_allowed t = - runprog (prog_of computer) (apply_subst vars_allowed varsubst t) - val prems = prems_of_theorem th - val prem = run true (prem2term (List.nth (prems, prem_no))) - val concl = run false (concl_of_theorem th') -in - case match_aterms varsubst prem concl of - NONE => raise Compute "modus_ponens: conclusion does not match premise" - | SOME varsubst => - let - val th = update_varsubst varsubst th - val th = update_prems (splicein prem_no (prems_of_theorem th') prems) th - val th = update_hyps (merge_hyps (hyps_of_theorem th) (hyps_of_theorem th')) th - val th = update_shyps (merge_shyps (shyps_of_theorem th) (shyps_of_theorem th')) th - in - update_theory thy th - end -end - -fun simplify computer th = -let - val _ = check_compatible computer th - val varsubst = varsubst_of_theorem th - val encoding = encoding_of computer - val naming = naming_of computer - fun infer t = infer_types naming encoding @{typ "prop"} t - fun run t = infer (runprog (prog_of computer) (apply_subst true varsubst t)) - fun runprem p = run (prem2term p) - val prop = Logic.list_implies (map runprem (prems_of_theorem th), run (concl_of_theorem th)) - val hyps = merge_hyps (hyps_of computer) (hyps_of_theorem th) - val shyps = merge_shyps (shyps_of_theorem th) (Sorttab.keys (shyptab_of computer)) -in - export_thm (theory_of_theorem th) hyps shyps prop -end - -end - diff -r c7ce7685e087 -r d83659570337 src/Tools/Compute_Oracle/linker.ML --- a/src/Tools/Compute_Oracle/linker.ML Wed Jul 21 15:31:38 2010 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,472 +0,0 @@ -(* Title: Tools/Compute_Oracle/linker.ML - Author: Steven Obua - -This module solves the problem that the computing oracle does not -instantiate polymorphic rules. By going through the PCompute -interface, all possible instantiations are resolved by compiling new -programs, if necessary. The obvious disadvantage of this approach is -that in the worst case for each new term to be rewritten, a new -program may be compiled. -*) - -(* - Given constants/frees c_1::t_1, c_2::t_2, ...., c_n::t_n, - and constants/frees d_1::d_1, d_2::s_2, ..., d_m::s_m - - Find all substitutions S such that - a) the domain of S is tvars (t_1, ..., t_n) - b) there are indices i_1, ..., i_k, and j_1, ..., j_k with - 1. S (c_i_1::t_i_1) = d_j_1::s_j_1, ..., S (c_i_k::t_i_k) = d_j_k::s_j_k - 2. tvars (t_i_1, ..., t_i_k) = tvars (t_1, ..., t_n) -*) -signature LINKER = -sig - exception Link of string - - datatype constant = Constant of bool * string * typ - val constant_of : term -> constant - - type instances - type subst = Type.tyenv - - val empty : constant list -> instances - val typ_of_constant : constant -> typ - val add_instances : theory -> instances -> constant list -> subst list * instances - val substs_of : instances -> subst list - val is_polymorphic : constant -> bool - val distinct_constants : constant list -> constant list - val collect_consts : term list -> constant list -end - -structure Linker : LINKER = struct - -exception Link of string; - -type subst = Type.tyenv - -datatype constant = Constant of bool * string * typ -fun constant_of (Const (name, ty)) = Constant (false, name, ty) - | constant_of (Free (name, ty)) = Constant (true, name, ty) - | constant_of _ = raise Link "constant_of" - -fun bool_ord (x,y) = if x then (if y then EQUAL else GREATER) else (if y then LESS else EQUAL) -fun constant_ord (Constant (x1,x2,x3), Constant (y1,y2,y3)) = (prod_ord (prod_ord bool_ord fast_string_ord) Term_Ord.typ_ord) (((x1,x2),x3), ((y1,y2),y3)) -fun constant_modty_ord (Constant (x1,x2,_), Constant (y1,y2,_)) = (prod_ord bool_ord fast_string_ord) ((x1,x2), (y1,y2)) - - -structure Consttab = Table(type key = constant val ord = constant_ord); -structure ConsttabModTy = Table(type key = constant val ord = constant_modty_ord); - -fun typ_of_constant (Constant (_, _, ty)) = ty - -val empty_subst = (Vartab.empty : Type.tyenv) - -fun merge_subst (A:Type.tyenv) (B:Type.tyenv) = - SOME (Vartab.fold (fn (v, t) => - fn tab => - (case Vartab.lookup tab v of - NONE => Vartab.update (v, t) tab - | SOME t' => if t = t' then tab else raise Type.TYPE_MATCH)) A B) - handle Type.TYPE_MATCH => NONE - -fun subst_ord (A:Type.tyenv, B:Type.tyenv) = - (list_ord (prod_ord Term_Ord.fast_indexname_ord (prod_ord Term_Ord.sort_ord Term_Ord.typ_ord))) (Vartab.dest A, Vartab.dest B) - -structure Substtab = Table(type key = Type.tyenv val ord = subst_ord); - -fun substtab_union c = Substtab.fold Substtab.update c -fun substtab_unions [] = Substtab.empty - | substtab_unions [c] = c - | substtab_unions (c::cs) = substtab_union c (substtab_unions cs) - -datatype instances = Instances of unit ConsttabModTy.table * Type.tyenv Consttab.table Consttab.table * constant list list * unit Substtab.table - -fun is_polymorphic (Constant (_, _, ty)) = not (null (Term.add_tvarsT ty [])) - -fun distinct_constants cs = - Consttab.keys (fold (fn c => Consttab.update (c, ())) cs Consttab.empty) - -fun empty cs = - let - val cs = distinct_constants (filter is_polymorphic cs) - val old_cs = cs -(* fun collect_tvars ty tab = fold (fn v => fn tab => Typtab.update (TVar v, ()) tab) (OldTerm.typ_tvars ty) tab - val tvars_count = length (Typtab.keys (fold (fn c => fn tab => collect_tvars (typ_of_constant c) tab) cs Typtab.empty)) - fun tvars_of ty = collect_tvars ty Typtab.empty - val cs = map (fn c => (c, tvars_of (typ_of_constant c))) cs - - fun tyunion A B = - Typtab.fold - (fn (v,()) => fn tab => Typtab.update (v, case Typtab.lookup tab v of NONE => 1 | SOME n => n+1) tab) - A B - - fun is_essential A B = - Typtab.fold - (fn (v, ()) => fn essential => essential orelse (case Typtab.lookup B v of NONE => raise Link "is_essential" | SOME n => n=1)) - A false - - fun add_minimal (c', tvs') (tvs, cs) = - let - val tvs = tyunion tvs' tvs - val cs = (c', tvs')::cs - in - if forall (fn (c',tvs') => is_essential tvs' tvs) cs then - SOME (tvs, cs) - else - NONE - end - - fun is_spanning (tvs, _) = (length (Typtab.keys tvs) = tvars_count) - - fun generate_minimal_subsets subsets [] = subsets - | generate_minimal_subsets subsets (c::cs) = - let - val subsets' = map_filter (add_minimal c) subsets - in - generate_minimal_subsets (subsets@subsets') cs - end*) - - val minimal_subsets = [old_cs] (*map (fn (tvs, cs) => map fst cs) (filter is_spanning (generate_minimal_subsets [(Typtab.empty, [])] cs))*) - - val constants = Consttab.keys (fold (fold (fn c => Consttab.update (c, ()))) minimal_subsets Consttab.empty) - - in - Instances ( - fold (fn c => fn tab => ConsttabModTy.update (c, ()) tab) constants ConsttabModTy.empty, - Consttab.make (map (fn c => (c, Consttab.empty : Type.tyenv Consttab.table)) constants), - minimal_subsets, Substtab.empty) - end - -local -fun calc ctab substtab [] = substtab - | calc ctab substtab (c::cs) = - let - val csubsts = map snd (Consttab.dest (the (Consttab.lookup ctab c))) - fun merge_substs substtab subst = - Substtab.fold (fn (s,_) => - fn tab => - (case merge_subst subst s of NONE => tab | SOME s => Substtab.update (s, ()) tab)) - substtab Substtab.empty - val substtab = substtab_unions (map (merge_substs substtab) csubsts) - in - calc ctab substtab cs - end -in -fun calc_substs ctab (cs:constant list) = calc ctab (Substtab.update (empty_subst, ()) Substtab.empty) cs -end - -fun add_instances thy (Instances (cfilter, ctab,minsets,substs)) cs = - let -(* val _ = writeln (makestring ("add_instances: ", length_cs, length cs, length (Consttab.keys ctab)))*) - fun calc_instantiations (constant as Constant (free, name, ty)) instantiations = - Consttab.fold (fn (constant' as Constant (free', name', ty'), insttab) => - fn instantiations => - if free <> free' orelse name <> name' then - instantiations - else case Consttab.lookup insttab constant of - SOME _ => instantiations - | NONE => ((constant', (constant, Sign.typ_match thy (ty', ty) empty_subst))::instantiations - handle TYPE_MATCH => instantiations)) - ctab instantiations - val instantiations = fold calc_instantiations cs [] - (*val _ = writeln ("instantiations = "^(makestring (length instantiations)))*) - fun update_ctab (constant', entry) ctab = - (case Consttab.lookup ctab constant' of - NONE => raise Link "internal error: update_ctab" - | SOME tab => Consttab.update (constant', Consttab.update entry tab) ctab) - val ctab = fold update_ctab instantiations ctab - val new_substs = fold (fn minset => fn substs => substtab_union (calc_substs ctab minset) substs) - minsets Substtab.empty - val (added_substs, substs) = - Substtab.fold (fn (ns, _) => - fn (added, substtab) => - (case Substtab.lookup substs ns of - NONE => (ns::added, Substtab.update (ns, ()) substtab) - | SOME () => (added, substtab))) - new_substs ([], substs) - in - (added_substs, Instances (cfilter, ctab, minsets, substs)) - end - -fun substs_of (Instances (_,_,_,substs)) = Substtab.keys substs - -fun eq_to_meta th = (@{thm HOL.eq_reflection} OF [th] handle THM _ => th) - - -local - -fun collect (Var x) tab = tab - | collect (Bound _) tab = tab - | collect (a $ b) tab = collect b (collect a tab) - | collect (Abs (_, _, body)) tab = collect body tab - | collect t tab = Consttab.update (constant_of t, ()) tab - -in - fun collect_consts tms = Consttab.keys (fold collect tms Consttab.empty) -end - -end - -signature PCOMPUTE = -sig - type pcomputer - - val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer - val make_with_cache : Compute.machine -> theory -> term list -> thm list -> Linker.constant list -> pcomputer - - val add_instances : pcomputer -> Linker.constant list -> bool - val add_instances' : pcomputer -> term list -> bool - - val rewrite : pcomputer -> cterm list -> thm list - val simplify : pcomputer -> Compute.theorem -> thm - - val make_theorem : pcomputer -> thm -> string list -> Compute.theorem - val instantiate : pcomputer -> (string * cterm) list -> Compute.theorem -> Compute.theorem - val evaluate_prem : pcomputer -> int -> Compute.theorem -> Compute.theorem - val modus_ponens : pcomputer -> int -> thm -> Compute.theorem -> Compute.theorem - -end - -structure PCompute : PCOMPUTE = struct - -exception PCompute of string - -datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list -datatype pattern = MonoPattern of term | PolyPattern of term * Linker.instances * term list - -datatype pcomputer = - PComputer of theory_ref * Compute.computer * theorem list Unsynchronized.ref * - pattern list Unsynchronized.ref - -(*fun collect_consts (Var x) = [] - | collect_consts (Bound _) = [] - | collect_consts (a $ b) = (collect_consts a)@(collect_consts b) - | collect_consts (Abs (_, _, body)) = collect_consts body - | collect_consts t = [Linker.constant_of t]*) - -fun computer_of (PComputer (_,computer,_,_)) = computer - -fun collect_consts_of_thm th = - let - val th = prop_of th - val (prems, th) = (Logic.strip_imp_prems th, Logic.strip_imp_concl th) - val (left, right) = Logic.dest_equals th - in - (Linker.collect_consts [left], Linker.collect_consts (right::prems)) - end - -fun create_theorem th = -let - val (left, right) = collect_consts_of_thm th - val polycs = filter Linker.is_polymorphic left - val tytab = fold (fn p => fn tab => fold (fn n => fn tab => Typtab.update (TVar n, ()) tab) (OldTerm.typ_tvars (Linker.typ_of_constant p)) tab) polycs Typtab.empty - fun check_const (c::cs) cs' = - let - val tvars = OldTerm.typ_tvars (Linker.typ_of_constant c) - val wrong = fold (fn n => fn wrong => wrong orelse is_none (Typtab.lookup tytab (TVar n))) tvars false - in - if wrong then raise PCompute "right hand side of theorem contains type variables which do not occur on the left hand side" - else - if null (tvars) then - check_const cs (c::cs') - else - check_const cs cs' - end - | check_const [] cs' = cs' - val monocs = check_const right [] -in - if null (polycs) then - (monocs, MonoThm th) - else - (monocs, PolyThm (th, Linker.empty polycs, [])) -end - -fun create_pattern pat = -let - val cs = Linker.collect_consts [pat] - val polycs = filter Linker.is_polymorphic cs -in - if null (polycs) then - MonoPattern pat - else - PolyPattern (pat, Linker.empty polycs, []) -end - -fun create_computer machine thy pats ths = - let - fun add (MonoThm th) ths = th::ths - | add (PolyThm (_, _, ths')) ths = ths'@ths - fun addpat (MonoPattern p) pats = p::pats - | addpat (PolyPattern (_, _, ps)) pats = ps@pats - val ths = fold_rev add ths [] - val pats = fold_rev addpat pats [] - in - Compute.make_with_cache machine thy pats ths - end - -fun update_computer computer pats ths = - let - fun add (MonoThm th) ths = th::ths - | add (PolyThm (_, _, ths')) ths = ths'@ths - fun addpat (MonoPattern p) pats = p::pats - | addpat (PolyPattern (_, _, ps)) pats = ps@pats - val ths = fold_rev add ths [] - val pats = fold_rev addpat pats [] - in - Compute.update_with_cache computer pats ths - end - -fun conv_subst thy (subst : Type.tyenv) = - map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst) - -fun add_monos thy monocs pats ths = - let - val changed = Unsynchronized.ref false - fun add monocs (th as (MonoThm _)) = ([], th) - | add monocs (PolyThm (th, instances, instanceths)) = - let - val (newsubsts, instances) = Linker.add_instances thy instances monocs - val _ = if not (null newsubsts) then changed := true else () - val newths = map (fn subst => Thm.instantiate (conv_subst thy subst, []) th) newsubsts -(* val _ = if not (null newths) then (print ("added new theorems: ", newths); ()) else ()*) - val newmonos = fold (fn th => fn monos => (snd (collect_consts_of_thm th))@monos) newths [] - in - (newmonos, PolyThm (th, instances, instanceths@newths)) - end - fun addpats monocs (pat as (MonoPattern _)) = pat - | addpats monocs (PolyPattern (p, instances, instancepats)) = - let - val (newsubsts, instances) = Linker.add_instances thy instances monocs - val _ = if not (null newsubsts) then changed := true else () - val newpats = map (fn subst => Envir.subst_term_types subst p) newsubsts - in - PolyPattern (p, instances, instancepats@newpats) - end - fun step monocs ths = - fold_rev (fn th => - fn (newmonos, ths) => - let - val (newmonos', th') = add monocs th - in - (newmonos'@newmonos, th'::ths) - end) - ths ([], []) - fun loop monocs pats ths = - let - val (monocs', ths') = step monocs ths - val pats' = map (addpats monocs) pats - in - if null (monocs') then - (pats', ths') - else - loop monocs' pats' ths' - end - val result = loop monocs pats ths - in - (!changed, result) - end - -datatype cthm = ComputeThm of term list * sort list * term - -fun thm2cthm th = - let - val {hyps, prop, shyps, ...} = Thm.rep_thm th - in - ComputeThm (hyps, shyps, prop) - end - -val cthm_ord' = prod_ord (prod_ord (list_ord Term_Ord.term_ord) (list_ord Term_Ord.sort_ord)) Term_Ord.term_ord - -fun cthm_ord (ComputeThm (h1, sh1, p1), ComputeThm (h2, sh2, p2)) = cthm_ord' (((h1,sh1), p1), ((h2, sh2), p2)) - -structure CThmtab = Table(type key = cthm val ord = cthm_ord) - -fun remove_duplicates ths = - let - val counter = Unsynchronized.ref 0 - val tab = Unsynchronized.ref (CThmtab.empty : unit CThmtab.table) - val thstab = Unsynchronized.ref (Inttab.empty : thm Inttab.table) - fun update th = - let - val key = thm2cthm th - in - case CThmtab.lookup (!tab) key of - NONE => ((tab := CThmtab.update_new (key, ()) (!tab)); thstab := Inttab.update_new (!counter, th) (!thstab); counter := !counter + 1) - | _ => () - end - val _ = map update ths - in - map snd (Inttab.dest (!thstab)) - end - -fun make_with_cache machine thy pats ths cs = - let - val ths = remove_duplicates ths - val (monocs, ths) = fold_rev (fn th => - fn (monocs, ths) => - let val (m, t) = create_theorem th in - (m@monocs, t::ths) - end) - ths (cs, []) - val pats = map create_pattern pats - val (_, (pats, ths)) = add_monos thy monocs pats ths - val computer = create_computer machine thy pats ths - in - PComputer (Theory.check_thy thy, computer, Unsynchronized.ref ths, Unsynchronized.ref pats) - end - -fun make machine thy ths cs = make_with_cache machine thy [] ths cs - -fun add_instances (PComputer (thyref, computer, rths, rpats)) cs = - let - val thy = Theory.deref thyref - val (changed, (pats, ths)) = add_monos thy cs (!rpats) (!rths) - in - if changed then - (update_computer computer pats ths; - rths := ths; - rpats := pats; - true) - else - false - - end - -fun add_instances' pc ts = add_instances pc (Linker.collect_consts ts) - -fun rewrite pc cts = - let - val _ = add_instances' pc (map term_of cts) - val computer = (computer_of pc) - in - map (fn ct => Compute.rewrite computer ct) cts - end - -fun simplify pc th = Compute.simplify (computer_of pc) th - -fun make_theorem pc th vars = - let - val _ = add_instances' pc [prop_of th] - - in - Compute.make_theorem (computer_of pc) th vars - end - -fun instantiate pc insts th = - let - val _ = add_instances' pc (map (term_of o snd) insts) - in - Compute.instantiate (computer_of pc) insts th - end - -fun evaluate_prem pc prem_no th = Compute.evaluate_prem (computer_of pc) prem_no th - -fun modus_ponens pc prem_no th' th = - let - val _ = add_instances' pc [prop_of th'] - in - Compute.modus_ponens (computer_of pc) prem_no th' th - end - - -end diff -r c7ce7685e087 -r d83659570337 src/Tools/Compute_Oracle/report.ML --- a/src/Tools/Compute_Oracle/report.ML Wed Jul 21 15:31:38 2010 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,33 +0,0 @@ -structure Report = -struct - -local - - val report_depth = Unsynchronized.ref 0 - fun space n = if n <= 0 then "" else (space (n-1))^" " - fun report_space () = space (!report_depth) - -in - -fun timeit f = - let - val t1 = start_timing () - val x = f () - val t2 = #message (end_timing t1) - val _ = writeln ((report_space ()) ^ "--> "^t2) - in - x - end - -fun report s f = -let - val _ = writeln ((report_space ())^s) - val _ = report_depth := !report_depth + 1 - val x = timeit f - val _ = report_depth := !report_depth - 1 -in - x -end - -end -end \ No newline at end of file