# HG changeset patch # User obua # Date 1196700455 -3600 # Node ID e123c81257a5aadf0796452571728ddfbd72e863 # Parent 8570745cb40bb6e5937cf5670af511e7554bfe46 improvements diff -r 8570745cb40b -r e123c81257a5 src/Tools/Compute_Oracle/am.ML --- a/src/Tools/Compute_Oracle/am.ML Mon Dec 03 16:04:17 2007 +0100 +++ b/src/Tools/Compute_Oracle/am.ML Mon Dec 03 17:47:35 2007 +0100 @@ -13,7 +13,7 @@ (* The de-Bruijn index 0 occurring on the right hand side refers to the LAST pattern variable, when traversing the pattern from left to right, 1 to the second last, and so on. *) -val compile : (guard list * pattern * term) list -> program +val compile : pattern list -> (int -> int option) -> (guard list * pattern * term) list -> program val discard : program -> unit diff -r 8570745cb40b -r e123c81257a5 src/Tools/Compute_Oracle/am_compiler.ML --- a/src/Tools/Compute_Oracle/am_compiler.ML Mon Dec 03 16:04:17 2007 +0100 +++ b/src/Tools/Compute_Oracle/am_compiler.ML Mon Dec 03 17:47:35 2007 +0100 @@ -192,7 +192,7 @@ | SOME r => (compiled_rewriter := NONE; r) end -fun compile eqs = +fun compile cache_patterns const_arity eqs = let val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () val eqs = map (fn (a,b,c) => (b,c)) eqs diff -r 8570745cb40b -r e123c81257a5 src/Tools/Compute_Oracle/am_ghc.ML --- a/src/Tools/Compute_Oracle/am_ghc.ML Mon Dec 03 16:04:17 2007 +0100 +++ b/src/Tools/Compute_Oracle/am_ghc.ML Mon Dec 03 17:47:35 2007 +0100 @@ -219,7 +219,7 @@ fun fileExists name = ((OS.FileSys.fileSize name; true) handle OS.SysErr _ => false) -fun compile eqs = +fun compile cache_patterns const_arity eqs = let val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () val eqs = map (fn (a,b,c) => (b,c)) eqs diff -r 8570745cb40b -r e123c81257a5 src/Tools/Compute_Oracle/am_interpreter.ML --- a/src/Tools/Compute_Oracle/am_interpreter.ML Mon Dec 03 16:04:17 2007 +0100 +++ b/src/Tools/Compute_Oracle/am_interpreter.ML Mon Dec 03 17:47:35 2007 +0100 @@ -19,7 +19,7 @@ structure prog_struct = TableFun(type key = int*int val ord = prod_ord int_ord int_ord); -datatype program = Program of ((pattern * closure) list) prog_struct.table +datatype program = Program of ((pattern * closure * (closure*closure) list) list) prog_struct.table datatype stack = SEmpty | SAppL of closure * stack | SAppR of closure * stack | SAbs of stack @@ -101,32 +101,19 @@ | check_freevars free (Abs m) = check_freevars (free+1) m | check_freevars free (Computed t) = check_freevars free t -fun compile eqs = +fun compile cache_patterns const_arity eqs = let - val _ = if exists (fn (a,b,c) => not (null a)) eqs then raise Compile ("cannot deal with guards") else () - val eqs = map (fn (a,b,c) => (b,c)) eqs - fun check (p, r) = if check_freevars (count_patternvars p) r then () else raise Compile ("unbound variables in rule") - val eqs = map (fn (p, r) => (check (p,r); (pattern_key p, (p, clos_of_term r)))) eqs + fun check p r = if check_freevars p r then () else raise Compile ("unbound variables in rule") + fun check_guard p (Guard (a,b)) = (check p a; check p b) + fun clos_of_guard (Guard (a,b)) = (clos_of_term a, clos_of_term b) + val eqs = map (fn (guards, p, r) => let val pcount = count_patternvars p val _ = map (check_guard pcount) (guards) val _ = check pcount r in + (pattern_key p, (p, clos_of_term r, map clos_of_guard guards)) end) eqs fun merge (k, a) table = prog_struct.update (k, case prog_struct.lookup table k of NONE => [a] | SOME l => a::l) table val p = fold merge eqs prog_struct.empty in Program p end -fun match_rules n [] clos = NONE - | match_rules n ((p,eq)::rs) clos = - case pattern_match [] p clos of - NONE => match_rules (n+1) rs clos - | SOME args => SOME (Closure (args, eq)) - -fun match_closure (Program prog) clos = - case len_head_of_closure 0 clos of - (len, CConst c) => - (case prog_struct.lookup prog (c, len) of - NONE => NONE - | SOME rules => match_rules 0 rules clos) - | _ => NONE - type state = bool * program * stack * closure @@ -158,7 +145,21 @@ | NONE => proj_S (!s) end -fun weak_reduce (false, prog, stack, Closure (e, CApp (a, b))) = Continue (false, prog, SAppL (Closure (e, b), stack), Closure (e, a)) +fun match_rules prog n [] clos = NONE + | match_rules prog n ((p,eq,guards)::rs) clos = + case pattern_match [] p clos of + NONE => match_rules prog (n+1) rs clos + | SOME args => if forall (guard_checks prog args) guards then SOME (Closure (args, eq)) else match_rules prog (n+1) rs clos +and guard_checks prog args (a,b) = (simp prog (Closure (args, a)) = simp prog (Closure (args, b))) +and match_closure (p as (Program prog)) clos = + case len_head_of_closure 0 clos of + (len, CConst c) => + (case prog_struct.lookup prog (c, len) of + NONE => NONE + | SOME rules => match_rules p 0 rules clos) + | _ => NONE + +and weak_reduce (false, prog, stack, Closure (e, CApp (a, b))) = Continue (false, prog, SAppL (Closure (e, b), stack), Closure (e, a)) | weak_reduce (false, prog, SAppL (b, stack), Closure (e, CAbs m)) = Continue (false, prog, stack, Closure (b::e, m)) | weak_reduce (false, prog, stack, Closure (e, CVar n)) = Continue (false, prog, stack, case List.nth (e, n) of CDummy => CVar n | r => r) | weak_reduce (false, prog, stack, Closure (e, c as CConst _)) = Continue (false, prog, stack, c) @@ -170,7 +171,7 @@ | weak_reduce (true, prog, s as (SAppL (b, stack)), a) = Continue (false, prog, SAppR (a, stack), b) | weak_reduce (true, prog, stack, c) = Stop (stack, c) -fun strong_reduce (false, prog, stack, Closure (e, CAbs m)) = +and strong_reduce (false, prog, stack, Closure (e, CAbs m)) = (let val (stack', wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure (CDummy::e, m)) in @@ -185,6 +186,18 @@ | strong_reduce (true, prog, SAppR (a, stack), b) = Continue (true, prog, stack, CApp (a, b)) | strong_reduce (true, prog, stack, clos) = Stop (stack, clos) +and simp prog t = + (let + val (stack, wnf) = do_reduction weak_reduce (false, prog, SEmpty, t) + in + case stack of + SEmpty => (case do_reduction strong_reduce (false, prog, SEmpty, wnf) of + (SEmpty, snf) => snf + | _ => raise (Run "internal error in run: strong failed")) + | _ => raise (Run "internal error in run: weak failed") + end handle InterruptedExecution state => resolve state) + + fun run prog t = (let val (stack, wnf) = do_reduction weak_reduce (false, prog, SEmpty, Closure ([], clos_of_term t)) diff -r 8570745cb40b -r e123c81257a5 src/Tools/Compute_Oracle/am_sml.ML --- a/src/Tools/Compute_Oracle/am_sml.ML Mon Dec 03 16:04:17 2007 +0100 +++ b/src/Tools/Compute_Oracle/am_sml.ML Mon Dec 03 17:47:35 2007 +0100 @@ -13,12 +13,15 @@ val save_result : (string * term) -> unit val set_compiled_rewriter : (term -> term) -> unit val list_nth : 'a list * int -> 'a + val dump_output : (string option) ref end structure AM_SML : AM_SML = struct open AbstractMachine; +val dump_output = ref NONE + type program = string * string * (int Inttab.table) * (int Inttab.table) * (term Inttab.table) * (term -> term) val saved_result = ref (NONE:(string*term)option) @@ -155,8 +158,8 @@ val toplevel_arity = fold (fn (_, p, t) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty fun arity_of c = the (Inttab.lookup arity c) fun toplevel_arity_of c = the (Inttab.lookup toplevel_arity c) - fun adjust_pattern PVar = PVar - | adjust_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else C + fun test_pattern PVar = () + | test_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else (map test_pattern args; ()) fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable") | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters") | adjust_rule (rule as (prems, p as PConst (c, args),t)) = @@ -165,13 +168,13 @@ fun check_fv t = check_freevars patternvars_counted t val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else () val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else () - val args = map adjust_pattern args + val _ = map test_pattern args val len = length args val arity = arity_of c val lift = nlift 0 - fun adjust_tm n t = if n=0 then t else adjust_tm (n-1) (App (t, Var (n-1))) - fun adjust_term n t = adjust_tm n (lift n t) - fun adjust_guard n (Guard (a,b)) = Guard (adjust_term n a, adjust_term n b) + fun addapps_tm n t = if n=0 then t else addapps_tm (n-1) (App (t, Var (n-1))) + fun adjust_term n t = addapps_tm n (lift n t) + fun adjust_guard n (Guard (a,b)) = Guard (lift n a, lift n b) in if len = arity then rule @@ -179,8 +182,7 @@ (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t) else (raise Compile "internal error in adjust_rule") end - fun beta_guard (Guard (a,b)) = Guard (beta a, beta b) - fun beta_rule (prems, p, t) = ((map beta_guard prems, p, beta t) handle Match => raise Compile "beta_rule") + fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule") in (arity, toplevel_arity, map (beta_rule o adjust_rule) rules) end @@ -493,13 +495,13 @@ fun use_source src = use_text "" Output.ml_output false src -fun compile eqs = +fun compile cache_patterns const_arity eqs = let val guid = get_guid () val code = Real.toString (random ()) val module = "AMSML_"^guid val (arity, toplevel_arity, inlinetab, source) = sml_prog module code eqs - (*val _ = writeTextFile "Gencode.ML" source*) + val _ = case !dump_output of NONE => () | SOME p => writeTextFile p source val _ = compiled_rewriter := NONE val _ = use_source source in diff -r 8570745cb40b -r e123c81257a5 src/Tools/Compute_Oracle/compute.ML --- a/src/Tools/Compute_Oracle/compute.ML Mon Dec 03 16:04:17 2007 +0100 +++ b/src/Tools/Compute_Oracle/compute.ML Mon Dec 03 17:47:35 2007 +0100 @@ -15,10 +15,12 @@ exception Make of string val make : machine -> theory -> thm list -> computer + val make_with_cache : machine -> theory -> term list -> thm list -> computer val theory_of : computer -> theory val hyps_of : computer -> term list val shyps_of : computer -> sort list (* ! *) val update : computer -> thm list -> unit + (* ! *) val update_with_cache : computer -> term list -> thm list -> unit (* ! *) val discard : computer -> unit (* ! *) val set_naming : computer -> naming -> unit @@ -35,16 +37,12 @@ val setup_compute : theory -> theory - val print_encoding : bool ref - end structure Compute :> COMPUTE = struct open Report; -val print_encoding = ref false - datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML (* Terms are mapped to integer codes *) @@ -57,7 +55,7 @@ val lookup_term : int -> encoding -> term option val remove_code : int -> encoding -> encoding val remove_term : term -> encoding -> encoding - val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a + val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a end = struct @@ -81,7 +79,7 @@ fun remove_term t (e as (count, term2int, int2term)) = (case lookup_code t e of NONE => e | SOME c => (count, Termtab.delete t term2int, Inttab.delete c int2term)) -fun fold f (_, term2int, _) = Termtab.fold f term2int +fun fold f (_, term2int, _) = Termtab.fold f term2int end @@ -206,11 +204,32 @@ ComputeThm (hyps, shyps, prop) end -fun make_internal machine thy stamp encoding raw_ths = +fun make_internal machine thy stamp encoding cache_pattern_terms raw_ths = let fun transfer (x:thm) = Thm.transfer thy x val ths = map (thm2cthm o Thm.strip_shyps o transfer) raw_ths + fun make_pattern encoding n vars (var as AbstractMachine.Abs _) = + raise (Make "no lambda abstractions allowed in pattern") + | make_pattern encoding n vars (var as AbstractMachine.Var _) = + raise (Make "no bound variables allowed in pattern") + | make_pattern encoding n vars (AbstractMachine.Const code) = + (case the (Encode.lookup_term code encoding) of + Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar) + handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed")) + | _ => (n, vars, AbstractMachine.PConst (code, []))) + | make_pattern encoding n vars (AbstractMachine.App (a, b)) = + let + val (n, vars, pa) = make_pattern encoding n vars a + val (n, vars, pb) = make_pattern encoding n vars b + in + case pa of + AbstractMachine.PVar => + raise (Make "patterns may not start with a variable") + | AbstractMachine.PConst (c, args) => + (n, vars, AbstractMachine.PConst (c, args@[pb])) + end + fun thm2rule (encoding, hyptable, shyptable) th = let val (ComputeThm (hyps, shyps, prop)) = th @@ -234,27 +253,6 @@ (encoding, AbstractMachine.Guard (t1, t2)) end handle TERM _ => raise (Make "guards must be meta-level equations")) val (encoding, prems) = fold_rev (fn p => fn (encoding, ps) => let val (e, p) = remove_types_of_guard encoding p in (e, p::ps) end) prems (encoding, []) - - fun make_pattern encoding n vars (var as AbstractMachine.Abs _) = - raise (Make "no lambda abstractions allowed in pattern") - | make_pattern encoding n vars (var as AbstractMachine.Var _) = - raise (Make "no bound variables allowed in pattern") - | make_pattern encoding n vars (AbstractMachine.Const code) = - (case the (Encode.lookup_term code encoding) of - Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar) - handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed")) - | _ => (n, vars, AbstractMachine.PConst (code, []))) - | make_pattern encoding n vars (AbstractMachine.App (a, b)) = - let - val (n, vars, pa) = make_pattern encoding n vars a - val (n, vars, pb) = make_pattern encoding n vars b - in - case pa of - AbstractMachine.PVar => - raise (Make "patterns may not start with a variable") - | AbstractMachine.PConst (c, args) => - (n, vars, AbstractMachine.PConst (c, args@[pb])) - end (* Principally, a check should be made here to see if the (meta-) hyps contain any of the variables of the rule. As it is, all variables of the rule are schematic, and there are no schematic variables in meta-hyps, therefore @@ -294,12 +292,32 @@ in (encoding_hyptable, rule::rules) end) ths ((encoding, Termtab.empty, Sorttab.empty), []) + fun make_cache_pattern t (encoding, cache_patterns) = + let + val (encoding, a) = remove_types encoding t + val (_,_,p) = make_pattern encoding 0 Inttab.empty a + in + (encoding, p::cache_patterns) + end + + val (encoding, cache_patterns) = fold_rev make_cache_pattern cache_pattern_terms (encoding, []) + + fun arity (Type ("fun", [a,b])) = 1 + arity b + | arity _ = 0 + + fun make_arity (Const (s, _), i) tab = + (Inttab.update (i, arity (Sign.the_const_type thy s)) tab handle TYPE _ => tab) + | make_arity _ tab = tab + + val const_arity_tab = Encode.fold make_arity encoding Inttab.empty + fun const_arity x = Inttab.lookup const_arity_tab x + val prog = case machine of - BARRAS => ProgBarras (AM_Interpreter.compile rules) - | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile rules) - | HASKELL => ProgHaskell (AM_GHC.compile rules) - | SML => ProgSML (AM_SML.compile rules) + BARRAS => ProgBarras (AM_Interpreter.compile cache_patterns const_arity rules) + | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile cache_patterns const_arity rules) + | HASKELL => ProgHaskell (AM_GHC.compile cache_patterns const_arity rules) + | SML => ProgSML (AM_SML.compile cache_patterns const_arity rules) fun has_witness s = not (null (Sign.witness_sorts thy [] [s])) @@ -307,17 +325,21 @@ in (Theory.check_thy thy, encoding, Termtab.keys hyptable, shyptable, prog, stamp, default_naming) end -fun make machine thy raw_thms = Computer (ref (SOME (make_internal machine thy (ref ()) Encode.empty raw_thms))) +fun make_with_cache machine thy cache_patterns raw_thms = Computer (ref (SOME (make_internal machine thy (ref ()) Encode.empty cache_patterns raw_thms))) -fun update computer raw_thms = +fun make machine thy raw_thms = make_with_cache machine thy [] raw_thms + +fun update_with_cache computer cache_patterns raw_thms = let val c = make_internal (machine_of_prog (prog_of computer)) (theory_of computer) (stamp_of computer) - (encoding_of computer) raw_thms + (encoding_of computer) cache_patterns raw_thms val _ = (ref_of computer) := SOME c in () end +fun update computer raw_thms = update_with_cache computer [] raw_thms + fun discard computer = let val _ = diff -r 8570745cb40b -r e123c81257a5 src/Tools/Compute_Oracle/linker.ML --- a/src/Tools/Compute_Oracle/linker.ML Mon Dec 03 16:04:17 2007 +0100 +++ b/src/Tools/Compute_Oracle/linker.ML Mon Dec 03 17:47:35 2007 +0100 @@ -218,6 +218,7 @@ type pcomputer val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer + val make_with_cache : Compute.machine -> theory -> term list -> thm list -> Linker.constant list -> pcomputer val add_instances : pcomputer -> Linker.constant list -> bool val add_instances' : pcomputer -> term list -> bool @@ -237,8 +238,9 @@ exception PCompute of string datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list +datatype pattern = MonoPattern of term | PolyPattern of term * Linker.instances * term list -datatype pcomputer = PComputer of theory_ref * Compute.computer * theorem list ref +datatype pcomputer = PComputer of theory_ref * Compute.computer * theorem list ref * pattern list ref (*fun collect_consts (Var x) = [] | collect_consts (Bound _) = [] @@ -246,7 +248,7 @@ | collect_consts (Abs (_, _, body)) = collect_consts body | collect_consts t = [Linker.constant_of t]*) -fun computer_of (PComputer (_,computer,_)) = computer +fun computer_of (PComputer (_,computer,_,_)) = computer fun collect_consts_of_thm th = let @@ -283,28 +285,45 @@ (monocs, PolyThm (th, Linker.empty polycs, [])) end -fun create_computer machine thy ths = +fun create_pattern pat = +let + val cs = Linker.collect_consts [pat] + val polycs = filter Linker.is_polymorphic cs +in + if null (polycs) then + MonoPattern pat + else + PolyPattern (pat, Linker.empty polycs, []) +end + +fun create_computer machine thy pats ths = let fun add (MonoThm th) ths = th::ths | add (PolyThm (_, _, ths')) ths = ths'@ths + fun addpat (MonoPattern p) pats = p::pats + | addpat (PolyPattern (_, _, ps)) pats = ps@pats val ths = fold_rev add ths [] + val pats = fold_rev addpat pats [] in - Compute.make machine thy ths + Compute.make_with_cache machine thy pats ths end -fun update_computer computer ths = +fun update_computer computer pats ths = let fun add (MonoThm th) ths = th::ths | add (PolyThm (_, _, ths')) ths = ths'@ths + fun addpat (MonoPattern p) pats = p::pats + | addpat (PolyPattern (_, _, ps)) pats = ps@pats val ths = fold_rev add ths [] + val pats = fold_rev addpat pats [] in - Compute.update computer ths + Compute.update_with_cache computer pats ths end fun conv_subst thy (subst : Type.tyenv) = map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst) -fun add_monos thy monocs ths = +fun add_monos thy monocs pats ths = let val changed = ref false fun add monocs (th as (MonoThm _)) = ([], th) @@ -318,21 +337,35 @@ in (newmonos, PolyThm (th, instances, instanceths@newths)) end + fun addpats monocs (pat as (MonoPattern _)) = pat + | addpats monocs (PolyPattern (p, instances, instancepats)) = + let + val (newsubsts, instances) = Linker.add_instances thy instances monocs + val _ = if not (null newsubsts) then changed := true else () + val newpats = map (fn subst => Envir.subst_TVars subst p) newsubsts + in + PolyPattern (p, instances, instancepats@newpats) + end fun step monocs ths = fold_rev (fn th => fn (newmonos, ths) => - let val (newmonos', th') = add monocs th in + let + val (newmonos', th') = add monocs th + in (newmonos'@newmonos, th'::ths) end) ths ([], []) - fun loop monocs ths = - let val (monocs', ths') = step monocs ths in + fun loop monocs pats ths = + let + val (monocs', ths') = step monocs ths + val pats' = map (addpats monocs) pats + in if null (monocs') then - ths' + (pats', ths') else - loop monocs' ths' + loop monocs' pats' ths' end - val result = loop monocs ths + val result = loop monocs pats ths in (!changed, result) end @@ -370,7 +403,7 @@ map snd (Inttab.dest (!thstab)) end -fun make machine thy ths cs = +fun make_with_cache machine thy pats ths cs = let val ths = remove_duplicates ths val (monocs, ths) = fold_rev (fn th => @@ -379,20 +412,24 @@ (m@monocs, t::ths) end) ths (cs, []) - val (_, ths) = add_monos thy monocs ths - val computer = create_computer machine thy ths + val pats = map create_pattern pats + val (_, (pats, ths)) = add_monos thy monocs pats ths + val computer = create_computer machine thy pats ths in - PComputer (Theory.check_thy thy, computer, ref ths) + PComputer (Theory.check_thy thy, computer, ref ths, ref pats) end -fun add_instances (PComputer (thyref, computer, rths)) cs = +fun make machine thy ths cs = make_with_cache machine thy [] ths cs + +fun add_instances (PComputer (thyref, computer, rths, rpats)) cs = let val thy = Theory.deref thyref - val (changed, ths) = add_monos thy cs (!rths) + val (changed, (pats, ths)) = add_monos thy cs (!rpats) (!rths) in if changed then - (update_computer computer ths; + (update_computer computer pats ths; rths := ths; + rpats := pats; true) else false