--- a/src/Tools/Compute_Oracle/am.ML Mon Dec 03 16:04:17 2007 +0100
+++ b/src/Tools/Compute_Oracle/am.ML Mon Dec 03 17:47:35 2007 +0100
@@ -13,7 +13,7 @@
(* The de-Bruijn index 0 occurring on the right hand side refers to the LAST pattern variable, when traversing the pattern from left to right,
1 to the second last, and so on. *)
-val compile : (guard list * pattern * term) list -> program
+val compile : pattern list -> (int -> int option) -> (guard list * pattern * term) list -> program
val discard : program -> unit
--- a/src/Tools/Compute_Oracle/am_compiler.ML Mon Dec 03 16:04:17 2007 +0100
+++ b/src/Tools/Compute_Oracle/am_compiler.ML Mon Dec 03 17:47:35 2007 +0100
@@ -192,7 +192,7 @@
| SOME r => (compiled_rewriter := NONE; r)
end
-fun compile eqs =
+fun compile cache_patterns const_arity eqs =
let
val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else ()
val eqs = map (fn (a,b,c) => (b,c)) eqs
--- a/src/Tools/Compute_Oracle/am_ghc.ML Mon Dec 03 16:04:17 2007 +0100
+++ b/src/Tools/Compute_Oracle/am_ghc.ML Mon Dec 03 17:47:35 2007 +0100
@@ -219,7 +219,7 @@
fun fileExists name = ((OS.FileSys.fileSize name; true) handle OS.SysErr _ => false)
-fun compile eqs =
+fun compile cache_patterns const_arity eqs =
let
val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else ()
val eqs = map (fn (a,b,c) => (b,c)) eqs
--- a/src/Tools/Compute_Oracle/am_interpreter.ML Mon Dec 03 16:04:17 2007 +0100
+++ b/src/Tools/Compute_Oracle/am_interpreter.ML Mon Dec 03 17:47:35 2007 +0100
@@ -19,7 +19,7 @@
structure prog_struct = TableFun(type key = int*int val ord = prod_ord int_ord int_ord);
-datatype program = Program of ((pattern * closure) list) prog_struct.table
+datatype program = Program of ((pattern * closure * (closure*closure) list) list) prog_struct.table
datatype stack = SEmpty | SAppL of closure * stack | SAppR of closure * stack | SAbs of stack
@@ -101,32 +101,19 @@
| check_freevars free (Abs m) = check_freevars (free+1) m
| check_freevars free (Computed t) = check_freevars free t
-fun compile eqs =
+fun compile cache_patterns const_arity eqs =
let
- val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else ()
- val eqs = map (fn (a,b,c) => (b,c)) eqs
- fun check (p, r) = if check_freevars (count_patternvars p) r then () else raise Compile ("unbound variables in rule")
- val eqs = map (fn (p, r) => (check (p,r); (pattern_key p, (p, clos_of_term r)))) eqs
+ fun check p r = if check_freevars p r then () else raise Compile ("unbound variables in rule")
+ fun check_guard p (Guard (a,b)) = (check p a; check p b)
+ fun clos_of_guard (Guard (a,b)) = (clos_of_term a, clos_of_term b)
+ val eqs = map (fn (guards, p, r) => let val pcount = count_patternvars p val _ = map (check_guard pcount) (guards) val _ = check pcount r in
+ (pattern_key p, (p, clos_of_term r, map clos_of_guard guards)) end) eqs
fun merge (k, a) table = prog_struct.update (k, case prog_struct.lookup table k of NONE => [a] | SOME l => a::l) table
val p = fold merge eqs prog_struct.empty
in
Program p
end
-fun match_rules n [] clos = NONE
- | match_rules n ((p,eq)::rs) clos =
- case pattern_match [] p clos of
- NONE => match_rules (n+1) rs clos
- | SOME args => SOME (Closure (args, eq))
-
-fun match_closure (Program prog) clos =
- case len_head_of_closure 0 clos of
- (len, CConst c) =>
- (case prog_struct.lookup prog (c, len) of
- NONE => NONE
- | SOME rules => match_rules 0 rules clos)
- | _ => NONE
-
type state = bool * program * stack * closure
@@ -158,7 +145,21 @@
| NONE => proj_S (!s)
end
-fun weak_reduce (false, prog, stack, Closure (e, CApp (a, b))) = Continue (false, prog, SAppL (Closure (e, b), stack), Closure (e, a))
+fun match_rules prog n [] clos = NONE
+ | match_rules prog n ((p,eq,guards)::rs) clos =
+ case pattern_match [] p clos of
+ NONE => match_rules prog (n+1) rs clos
+ | SOME args => if forall (guard_checks prog args) guards then SOME (Closure (args, eq)) else match_rules prog (n+1) rs clos
+and guard_checks prog args (a,b) = (simp prog (Closure (args, a)) = simp prog (Closure (args, b)))
+and match_closure (p as (Program prog)) clos =
+ case len_head_of_closure 0 clos of
+ (len, CConst c) =>
+ (case prog_struct.lookup prog (c, len) of
+ NONE => NONE
+ | SOME rules => match_rules p 0 rules clos)
+ | _ => NONE
+
+and weak_reduce (false, prog, stack, Closure (e, CApp (a, b))) = Continue (false, prog, SAppL (Closure (e, b), stack), Closure (e, a))
| weak_reduce (false, prog, SAppL (b, stack), Closure (e, CAbs m)) = Continue (false, prog, stack, Closure (b::e, m))
| weak_reduce (false, prog, stack, Closure (e, CVar n)) = Continue (false, prog, stack, case List.nth (e, n) of CDummy => CVar n | r => r)
| weak_reduce (false, prog, stack, Closure (e, c as CConst _)) = Continue (false, prog, stack, c)
@@ -170,7 +171,7 @@
| weak_reduce (true, prog, s as (SAppL (b, stack)), a) = Continue (false, prog, SAppR (a, stack), b)
| weak_reduce (true, prog, stack, c) = Stop (stack, c)
-fun strong_reduce (false, prog, stack, Closure (e, CAbs m)) =
+and strong_reduce (false, prog, stack, Closure (e, CAbs m)) =
(let
val (stack', wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure (CDummy::e, m))
in
@@ -185,6 +186,18 @@
| strong_reduce (true, prog, SAppR (a, stack), b) = Continue (true, prog, stack, CApp (a, b))
| strong_reduce (true, prog, stack, clos) = Stop (stack, clos)
+and simp prog t =
+ (let
+ val (stack, wnf) = do_reduction weak_reduce (false, prog, SEmpty, t)
+ in
+ case stack of
+ SEmpty => (case do_reduction strong_reduce (false, prog, SEmpty, wnf) of
+ (SEmpty, snf) => snf
+ | _ => raise (Run "internal error in run: strong failed"))
+ | _ => raise (Run "internal error in run: weak failed")
+ end handle InterruptedExecution state => resolve state)
+
+
fun run prog t =
(let
val (stack, wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure ([], clos_of_term t))
--- a/src/Tools/Compute_Oracle/am_sml.ML Mon Dec 03 16:04:17 2007 +0100
+++ b/src/Tools/Compute_Oracle/am_sml.ML Mon Dec 03 17:47:35 2007 +0100
@@ -13,12 +13,15 @@
val save_result : (string * term) -> unit
val set_compiled_rewriter : (term -> term) -> unit
val list_nth : 'a list * int -> 'a
+ val dump_output : (string option) ref
end
structure AM_SML : AM_SML = struct
open AbstractMachine;
+val dump_output = ref NONE
+
type program = string * string * (int Inttab.table) * (int Inttab.table) * (term Inttab.table) * (term -> term)
val saved_result = ref (NONE:(string*term)option)
@@ -155,8 +158,8 @@
val toplevel_arity = fold (fn (_, p, t) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty
fun arity_of c = the (Inttab.lookup arity c)
fun toplevel_arity_of c = the (Inttab.lookup toplevel_arity c)
- fun adjust_pattern PVar = PVar
- | adjust_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else C
+ fun test_pattern PVar = ()
+ | test_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else (map test_pattern args; ())
fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable")
| adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters")
| adjust_rule (rule as (prems, p as PConst (c, args),t)) =
@@ -165,13 +168,13 @@
fun check_fv t = check_freevars patternvars_counted t
val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else ()
val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else ()
- val args = map adjust_pattern args
+ val _ = map test_pattern args
val len = length args
val arity = arity_of c
val lift = nlift 0
- fun adjust_tm n t = if n=0 then t else adjust_tm (n-1) (App (t, Var (n-1)))
- fun adjust_term n t = adjust_tm n (lift n t)
- fun adjust_guard n (Guard (a,b)) = Guard (adjust_term n a, adjust_term n b)
+ fun addapps_tm n t = if n=0 then t else addapps_tm (n-1) (App (t, Var (n-1)))
+ fun adjust_term n t = addapps_tm n (lift n t)
+ fun adjust_guard n (Guard (a,b)) = Guard (lift n a, lift n b)
in
if len = arity then
rule
@@ -179,8 +182,7 @@
(map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t)
else (raise Compile "internal error in adjust_rule")
end
- fun beta_guard (Guard (a,b)) = Guard (beta a, beta b)
- fun beta_rule (prems, p, t) = ((map beta_guard prems, p, beta t) handle Match => raise Compile "beta_rule")
+ fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule")
in
(arity, toplevel_arity, map (beta_rule o adjust_rule) rules)
end
@@ -493,13 +495,13 @@
fun use_source src = use_text "" Output.ml_output false src
-fun compile eqs =
+fun compile cache_patterns const_arity eqs =
let
val guid = get_guid ()
val code = Real.toString (random ())
val module = "AMSML_"^guid
val (arity, toplevel_arity, inlinetab, source) = sml_prog module code eqs
- (*val _ = writeTextFile "Gencode.ML" source*)
+ val _ = case !dump_output of NONE => () | SOME p => writeTextFile p source
val _ = compiled_rewriter := NONE
val _ = use_source source
in
--- a/src/Tools/Compute_Oracle/compute.ML Mon Dec 03 16:04:17 2007 +0100
+++ b/src/Tools/Compute_Oracle/compute.ML Mon Dec 03 17:47:35 2007 +0100
@@ -15,10 +15,12 @@
exception Make of string
val make : machine -> theory -> thm list -> computer
+ val make_with_cache : machine -> theory -> term list -> thm list -> computer
val theory_of : computer -> theory
val hyps_of : computer -> term list
val shyps_of : computer -> sort list
(* ! *) val update : computer -> thm list -> unit
+ (* ! *) val update_with_cache : computer -> term list -> thm list -> unit
(* ! *) val discard : computer -> unit
(* ! *) val set_naming : computer -> naming -> unit
@@ -35,16 +37,12 @@
val setup_compute : theory -> theory
- val print_encoding : bool ref
-
end
structure Compute :> COMPUTE = struct
open Report;
-val print_encoding = ref false
-
datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML
(* Terms are mapped to integer codes *)
@@ -57,7 +55,7 @@
val lookup_term : int -> encoding -> term option
val remove_code : int -> encoding -> encoding
val remove_term : term -> encoding -> encoding
- val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a
+ val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a
end
=
struct
@@ -81,7 +79,7 @@
fun remove_term t (e as (count, term2int, int2term)) =
(case lookup_code t e of NONE => e | SOME c => (count, Termtab.delete t term2int, Inttab.delete c int2term))
-fun fold f (_, term2int, _) = Termtab.fold f term2int
+fun fold f (_, term2int, _) = Termtab.fold f term2int
end
@@ -206,11 +204,32 @@
ComputeThm (hyps, shyps, prop)
end
-fun make_internal machine thy stamp encoding raw_ths =
+fun make_internal machine thy stamp encoding cache_pattern_terms raw_ths =
let
fun transfer (x:thm) = Thm.transfer thy x
val ths = map (thm2cthm o Thm.strip_shyps o transfer) raw_ths
+ fun make_pattern encoding n vars (var as AbstractMachine.Abs _) =
+ raise (Make "no lambda abstractions allowed in pattern")
+ | make_pattern encoding n vars (var as AbstractMachine.Var _) =
+ raise (Make "no bound variables allowed in pattern")
+ | make_pattern encoding n vars (AbstractMachine.Const code) =
+ (case the (Encode.lookup_term code encoding) of
+ Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar)
+ handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed"))
+ | _ => (n, vars, AbstractMachine.PConst (code, [])))
+ | make_pattern encoding n vars (AbstractMachine.App (a, b)) =
+ let
+ val (n, vars, pa) = make_pattern encoding n vars a
+ val (n, vars, pb) = make_pattern encoding n vars b
+ in
+ case pa of
+ AbstractMachine.PVar =>
+ raise (Make "patterns may not start with a variable")
+ | AbstractMachine.PConst (c, args) =>
+ (n, vars, AbstractMachine.PConst (c, args@[pb]))
+ end
+
fun thm2rule (encoding, hyptable, shyptable) th =
let
val (ComputeThm (hyps, shyps, prop)) = th
@@ -234,27 +253,6 @@
(encoding, AbstractMachine.Guard (t1, t2))
end handle TERM _ => raise (Make "guards must be meta-level equations"))
val (encoding, prems) = fold_rev (fn p => fn (encoding, ps) => let val (e, p) = remove_types_of_guard encoding p in (e, p::ps) end) prems (encoding, [])
-
- fun make_pattern encoding n vars (var as AbstractMachine.Abs _) =
- raise (Make "no lambda abstractions allowed in pattern")
- | make_pattern encoding n vars (var as AbstractMachine.Var _) =
- raise (Make "no bound variables allowed in pattern")
- | make_pattern encoding n vars (AbstractMachine.Const code) =
- (case the (Encode.lookup_term code encoding) of
- Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar)
- handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed"))
- | _ => (n, vars, AbstractMachine.PConst (code, [])))
- | make_pattern encoding n vars (AbstractMachine.App (a, b)) =
- let
- val (n, vars, pa) = make_pattern encoding n vars a
- val (n, vars, pb) = make_pattern encoding n vars b
- in
- case pa of
- AbstractMachine.PVar =>
- raise (Make "patterns may not start with a variable")
- | AbstractMachine.PConst (c, args) =>
- (n, vars, AbstractMachine.PConst (c, args@[pb]))
- end
(* Principally, a check should be made here to see if the (meta-) hyps contain any of the variables of the rule.
As it is, all variables of the rule are schematic, and there are no schematic variables in meta-hyps, therefore
@@ -294,12 +292,32 @@
in (encoding_hyptable, rule::rules) end)
ths ((encoding, Termtab.empty, Sorttab.empty), [])
+ fun make_cache_pattern t (encoding, cache_patterns) =
+ let
+ val (encoding, a) = remove_types encoding t
+ val (_,_,p) = make_pattern encoding 0 Inttab.empty a
+ in
+ (encoding, p::cache_patterns)
+ end
+
+ val (encoding, cache_patterns) = fold_rev make_cache_pattern cache_pattern_terms (encoding, [])
+
+ fun arity (Type ("fun", [a,b])) = 1 + arity b
+ | arity _ = 0
+
+ fun make_arity (Const (s, _), i) tab =
+ (Inttab.update (i, arity (Sign.the_const_type thy s)) tab handle TYPE _ => tab)
+ | make_arity _ tab = tab
+
+ val const_arity_tab = Encode.fold make_arity encoding Inttab.empty
+ fun const_arity x = Inttab.lookup const_arity_tab x
+
val prog =
case machine of
- BARRAS => ProgBarras (AM_Interpreter.compile rules)
- | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile rules)
- | HASKELL => ProgHaskell (AM_GHC.compile rules)
- | SML => ProgSML (AM_SML.compile rules)
+ BARRAS => ProgBarras (AM_Interpreter.compile cache_patterns const_arity rules)
+ | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile cache_patterns const_arity rules)
+ | HASKELL => ProgHaskell (AM_GHC.compile cache_patterns const_arity rules)
+ | SML => ProgSML (AM_SML.compile cache_patterns const_arity rules)
fun has_witness s = not (null (Sign.witness_sorts thy [] [s]))
@@ -307,17 +325,21 @@
in (Theory.check_thy thy, encoding, Termtab.keys hyptable, shyptable, prog, stamp, default_naming) end
-fun make machine thy raw_thms = Computer (ref (SOME (make_internal machine thy (ref ()) Encode.empty raw_thms)))
+fun make_with_cache machine thy cache_patterns raw_thms = Computer (ref (SOME (make_internal machine thy (ref ()) Encode.empty cache_patterns raw_thms)))
-fun update computer raw_thms =
+fun make machine thy raw_thms = make_with_cache machine thy [] raw_thms
+
+fun update_with_cache computer cache_patterns raw_thms =
let
val c = make_internal (machine_of_prog (prog_of computer)) (theory_of computer) (stamp_of computer)
- (encoding_of computer) raw_thms
+ (encoding_of computer) cache_patterns raw_thms
val _ = (ref_of computer) := SOME c
in
()
end
+fun update computer raw_thms = update_with_cache computer [] raw_thms
+
fun discard computer =
let
val _ =
--- a/src/Tools/Compute_Oracle/linker.ML Mon Dec 03 16:04:17 2007 +0100
+++ b/src/Tools/Compute_Oracle/linker.ML Mon Dec 03 17:47:35 2007 +0100
@@ -218,6 +218,7 @@
type pcomputer
val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer
+ val make_with_cache : Compute.machine -> theory -> term list -> thm list -> Linker.constant list -> pcomputer
val add_instances : pcomputer -> Linker.constant list -> bool
val add_instances' : pcomputer -> term list -> bool
@@ -237,8 +238,9 @@
exception PCompute of string
datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list
+datatype pattern = MonoPattern of term | PolyPattern of term * Linker.instances * term list
-datatype pcomputer = PComputer of theory_ref * Compute.computer * theorem list ref
+datatype pcomputer = PComputer of theory_ref * Compute.computer * theorem list ref * pattern list ref
(*fun collect_consts (Var x) = []
| collect_consts (Bound _) = []
@@ -246,7 +248,7 @@
| collect_consts (Abs (_, _, body)) = collect_consts body
| collect_consts t = [Linker.constant_of t]*)
-fun computer_of (PComputer (_,computer,_)) = computer
+fun computer_of (PComputer (_,computer,_,_)) = computer
fun collect_consts_of_thm th =
let
@@ -283,28 +285,45 @@
(monocs, PolyThm (th, Linker.empty polycs, []))
end
-fun create_computer machine thy ths =
+fun create_pattern pat =
+let
+ val cs = Linker.collect_consts [pat]
+ val polycs = filter Linker.is_polymorphic cs
+in
+ if null (polycs) then
+ MonoPattern pat
+ else
+ PolyPattern (pat, Linker.empty polycs, [])
+end
+
+fun create_computer machine thy pats ths =
let
fun add (MonoThm th) ths = th::ths
| add (PolyThm (_, _, ths')) ths = ths'@ths
+ fun addpat (MonoPattern p) pats = p::pats
+ | addpat (PolyPattern (_, _, ps)) pats = ps@pats
val ths = fold_rev add ths []
+ val pats = fold_rev addpat pats []
in
- Compute.make machine thy ths
+ Compute.make_with_cache machine thy pats ths
end
-fun update_computer computer ths =
+fun update_computer computer pats ths =
let
fun add (MonoThm th) ths = th::ths
| add (PolyThm (_, _, ths')) ths = ths'@ths
+ fun addpat (MonoPattern p) pats = p::pats
+ | addpat (PolyPattern (_, _, ps)) pats = ps@pats
val ths = fold_rev add ths []
+ val pats = fold_rev addpat pats []
in
- Compute.update computer ths
+ Compute.update_with_cache computer pats ths
end
fun conv_subst thy (subst : Type.tyenv) =
map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst)
-fun add_monos thy monocs ths =
+fun add_monos thy monocs pats ths =
let
val changed = ref false
fun add monocs (th as (MonoThm _)) = ([], th)
@@ -318,21 +337,35 @@
in
(newmonos, PolyThm (th, instances, instanceths@newths))
end
+ fun addpats monocs (pat as (MonoPattern _)) = pat
+ | addpats monocs (PolyPattern (p, instances, instancepats)) =
+ let
+ val (newsubsts, instances) = Linker.add_instances thy instances monocs
+ val _ = if not (null newsubsts) then changed := true else ()
+ val newpats = map (fn subst => Envir.subst_TVars subst p) newsubsts
+ in
+ PolyPattern (p, instances, instancepats@newpats)
+ end
fun step monocs ths =
fold_rev (fn th =>
fn (newmonos, ths) =>
- let val (newmonos', th') = add monocs th in
+ let
+ val (newmonos', th') = add monocs th
+ in
(newmonos'@newmonos, th'::ths)
end)
ths ([], [])
- fun loop monocs ths =
- let val (monocs', ths') = step monocs ths in
+ fun loop monocs pats ths =
+ let
+ val (monocs', ths') = step monocs ths
+ val pats' = map (addpats monocs) pats
+ in
if null (monocs') then
- ths'
+ (pats', ths')
else
- loop monocs' ths'
+ loop monocs' pats' ths'
end
- val result = loop monocs ths
+ val result = loop monocs pats ths
in
(!changed, result)
end
@@ -370,7 +403,7 @@
map snd (Inttab.dest (!thstab))
end
-fun make machine thy ths cs =
+fun make_with_cache machine thy pats ths cs =
let
val ths = remove_duplicates ths
val (monocs, ths) = fold_rev (fn th =>
@@ -379,20 +412,24 @@
(m@monocs, t::ths)
end)
ths (cs, [])
- val (_, ths) = add_monos thy monocs ths
- val computer = create_computer machine thy ths
+ val pats = map create_pattern pats
+ val (_, (pats, ths)) = add_monos thy monocs pats ths
+ val computer = create_computer machine thy pats ths
in
- PComputer (Theory.check_thy thy, computer, ref ths)
+ PComputer (Theory.check_thy thy, computer, ref ths, ref pats)
end
-fun add_instances (PComputer (thyref, computer, rths)) cs =
+fun make machine thy ths cs = make_with_cache machine thy [] ths cs
+
+fun add_instances (PComputer (thyref, computer, rths, rpats)) cs =
let
val thy = Theory.deref thyref
- val (changed, ths) = add_monos thy cs (!rths)
+ val (changed, (pats, ths)) = add_monos thy cs (!rpats) (!rths)
in
if changed then
- (update_computer computer ths;
+ (update_computer computer pats ths;
rths := ths;
+ rpats := pats;
true)
else
false