diff -r 91d06b04951f -r 84b5c89b8b49 src/Tools/Compute_Oracle/compute.ML --- a/src/Tools/Compute_Oracle/compute.ML Mon Jul 09 11:44:23 2007 +0200 +++ b/src/Tools/Compute_Oracle/compute.ML Mon Jul 09 17:36:25 2007 +0200 @@ -1,4 +1,4 @@ -(* Title: Tools/Compute_Oracle/compute.ML +(* Title: Pure/Tools/compute.ML ID: $Id$ Author: Steven Obua *) @@ -7,272 +7,356 @@ type computer - exception Make of string + datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML - val basic_make : theory -> thm list -> computer - val make : theory -> thm list -> computer + exception Make of string + val make : machine -> theory -> thm list -> computer + exception Compute of string val compute : computer -> (int -> string) -> cterm -> term val theory_of : computer -> theory + val hyps_of : computer -> term list + val shyps_of : computer -> sort list - val default_naming: int -> string - val oracle_fn: theory -> computer * (int -> string) * cterm -> term + val rewrite_param : computer -> (int -> string) -> cterm -> thm + val rewrite : computer -> cterm -> thm + + val discard : computer -> unit + + val setup : theory -> theory + end -structure Compute: COMPUTE = struct +structure Compute :> COMPUTE = struct + +datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML + +(* Terms are mapped to integer codes *) +structure Encode :> +sig + type encoding + val empty : encoding + val insert : term -> encoding -> int * encoding + val lookup_code : term -> encoding -> int option + val lookup_term : int -> encoding -> term option + val remove_code : int -> encoding -> encoding + val remove_term : term -> encoding -> encoding + val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a +end += +struct + +type encoding = int * (int Termtab.table) * (term Inttab.table) + +val empty = (0, Termtab.empty, Inttab.empty) + +fun insert t (e as (count, term2int, int2term)) = + (case Termtab.lookup term2int t of + NONE => (count, (count+1, Termtab.update_new (t, count) term2int, Inttab.update_new (count, t) int2term)) + | SOME code => (code, e)) + +fun lookup_code t (_, term2int, _) = Termtab.lookup term2int t + +fun lookup_term c (_, _, int2term) = Inttab.lookup int2term c + +fun remove_code c (e as (count, term2int, int2term)) = + (case lookup_term c e of NONE => e | SOME t => (count, Termtab.delete t term2int, Inttab.delete c int2term)) + +fun remove_term t (e as (count, term2int, int2term)) = + (case lookup_code t e of NONE => e | SOME c => (count, Termtab.delete t term2int, Inttab.delete c int2term)) + +fun fold f (_, term2int, _) = Termtab.fold f term2int + +end + exception Make of string; - -fun is_mono_typ (Type (_, list)) = forall is_mono_typ list - | is_mono_typ _ = false - -fun is_mono_term (Const (_, t)) = is_mono_typ t - | is_mono_term (Var (_, t)) = is_mono_typ t - | is_mono_term (Free (_, t)) = is_mono_typ t - | is_mono_term (Bound _) = true - | is_mono_term (a $ b) = is_mono_term a andalso is_mono_term b - | is_mono_term (Abs (_, ty, t)) = is_mono_typ ty andalso is_mono_term t - -structure AMTermTab = TableFun (type key = AbstractMachine.term val ord = AM_Util.term_ord) - -fun add x y = x + y : int; -fun inc x = x + 1; - -exception Mono of term; +exception Compute of string; -val remove_types = - let - fun remove_types_var table invtable ccount vcount ldepth t = - (case Termtab.lookup table t of - NONE => - let - val a = AbstractMachine.Var vcount - in - (Termtab.update (t, a) table, - AMTermTab.update (a, t) invtable, - ccount, - inc vcount, - AbstractMachine.Var (add vcount ldepth)) - end - | SOME (AbstractMachine.Var v) => - (table, invtable, ccount, vcount, AbstractMachine.Var (add v ldepth)) - | SOME _ => sys_error "remove_types_var: lookup should be a var") - - fun remove_types_const table invtable ccount vcount ldepth t = - (case Termtab.lookup table t of - NONE => - let - val a = AbstractMachine.Const ccount - in - (Termtab.update (t, a) table, - AMTermTab.update (a, t) invtable, - inc ccount, - vcount, - a) - end - | SOME (c as AbstractMachine.Const _) => - (table, invtable, ccount, vcount, c) - | SOME _ => sys_error "remove_types_const: lookup should be a const") +local + fun make_constant t ty encoding = + let + val (code, encoding) = Encode.insert t encoding + in + (encoding, AbstractMachine.Const code) + end +in - fun remove_types table invtable ccount vcount ldepth t = - case t of - Var (_, ty) => - if is_mono_typ ty then remove_types_var table invtable ccount vcount ldepth t - else raise (Mono t) - | Free (_, ty) => - if is_mono_typ ty then remove_types_var table invtable ccount vcount ldepth t - else raise (Mono t) - | Const (_, ty) => - if is_mono_typ ty then remove_types_const table invtable ccount vcount ldepth t - else raise (Mono t) - | Abs (_, ty, t') => - if is_mono_typ ty then - let - val (table, invtable, ccount, vcount, t') = - remove_types table invtable ccount vcount (inc ldepth) t' - in - (table, invtable, ccount, vcount, AbstractMachine.Abs t') - end - else - raise (Mono t) - | a $ b => - let - val (table, invtable, ccount, vcount, a) = - remove_types table invtable ccount vcount ldepth a - val (table, invtable, ccount, vcount, b) = - remove_types table invtable ccount vcount ldepth b - in - (table, invtable, ccount, vcount, AbstractMachine.App (a,b)) - end - | Bound b => (table, invtable, ccount, vcount, AbstractMachine.Var b) - in - fn (table, invtable, ccount, vcount) => remove_types table invtable ccount vcount 0 - end - -fun infer_types naming = +fun remove_types encoding t = + case t of + Var (_, ty) => make_constant t ty encoding + | Free (_, ty) => make_constant t ty encoding + | Const (_, ty) => make_constant t ty encoding + | Abs (_, ty, t') => + let val (encoding, t'') = remove_types encoding t' in + (encoding, AbstractMachine.Abs t'') + end + | a $ b => + let + val (encoding, a) = remove_types encoding a + val (encoding, b) = remove_types encoding b + in + (encoding, AbstractMachine.App (a,b)) + end + | Bound b => (encoding, AbstractMachine.Var b) +end + +local + fun type_of (Free (_, ty)) = ty + | type_of (Const (_, ty)) = ty + | type_of (Var (_, ty)) = ty + | type_of _ = sys_error "infer_types: type_of error" +in +fun infer_types naming encoding = let - fun infer_types invtable ldepth bounds ty (AbstractMachine.Var v) = - if v < ldepth then (Bound v, List.nth (bounds, v)) else - (case AMTermTab.lookup invtable (AbstractMachine.Var (v-ldepth)) of - SOME (t as Var (_, ty)) => (t, ty) - | SOME (t as Free (_, ty)) => (t, ty) - | _ => sys_error "infer_types: lookup should deliver Var or Free") - | infer_types invtable ldepth _ ty (c as AbstractMachine.Const _) = - (case AMTermTab.lookup invtable c of - SOME (c as Const (_, ty)) => (c, ty) - | _ => sys_error "infer_types: lookup should deliver Const") - | infer_types invtable ldepth bounds (n,ty) (AbstractMachine.App (a, b)) = - let - val (a, aty) = infer_types invtable ldepth bounds (n+1, ty) a - val (adom, arange) = + fun infer_types _ bounds _ (AbstractMachine.Var v) = (Bound v, List.nth (bounds, v)) + | infer_types _ bounds _ (AbstractMachine.Const code) = + let + val c = the (Encode.lookup_term code encoding) + in + (c, type_of c) + end + | infer_types level bounds _ (AbstractMachine.App (a, b)) = + let + val (a, aty) = infer_types level bounds NONE a + val (adom, arange) = case aty of Type ("fun", [dom, range]) => (dom, range) | _ => sys_error "infer_types: function type expected" - val (b, bty) = infer_types invtable ldepth bounds (0, adom) b - in - (a $ b, arange) - end - | infer_types invtable ldepth bounds (0, ty as Type ("fun", [dom, range])) - (AbstractMachine.Abs m) = + val (b, bty) = infer_types level bounds (SOME adom) b + in + (a $ b, arange) + end + | infer_types level bounds (SOME (ty as Type ("fun", [dom, range]))) (AbstractMachine.Abs m) = let - val (m, _) = infer_types invtable (ldepth+1) (dom::bounds) (0, range) m + val (m, _) = infer_types (level+1) (dom::bounds) (SOME range) m in - (Abs (naming ldepth, dom, m), ty) + (Abs (naming level, dom, m), ty) end - | infer_types invtable ldepth bounds ty (AbstractMachine.Abs m) = - sys_error "infer_types: cannot infer type of abstraction" + | infer_types _ _ NONE (AbstractMachine.Abs m) = sys_error "infer_types: cannot infer type of abstraction" - fun infer invtable ty term = + fun infer ty term = let - val (term', _) = infer_types invtable 0 [] (0, ty) term + val (term', _) = infer_types 0 [] (SOME ty) term in term' end in infer end +end -datatype computer = - Computer of theory_ref * - (AbstractMachine.term Termtab.table * term AMTermTab.table * int * AbstractMachine.program) +datatype prog = + ProgBarras of AM_Interpreter.program + | ProgBarrasC of AM_Compiler.program + | ProgHaskell of AM_GHC.program + | ProgSML of AM_SML.program -fun basic_make thy raw_ths = +structure Sorttab = TableFun(type key = sort val ord = Term.sort_ord) + +datatype computer = Computer of theory_ref * Encode.encoding * term list * unit Sorttab.table * prog + +datatype cthm = ComputeThm of term list * sort list * term + +fun thm2cthm th = let - val ths = map (Thm.transfer thy) raw_ths; + val {hyps, prop, tpairs, shyps, ...} = Thm.rep_thm th + val _ = if not (null tpairs) then raise Make "theorems may not contain tpairs" else () + in + ComputeThm (hyps, shyps, prop) + end - fun thm2rule table invtable ccount th = - let - val prop = Thm.plain_prop_of th - handle THM _ => raise (Make "theorems must be plain propositions") - val (a, b) = Logic.dest_equals prop - handle TERM _ => raise (Make "theorems must be meta-level equations") +fun make machine thy raw_ths = + let + fun transfer (x:thm) = Thm.transfer thy x + val ths = map (thm2cthm o Thm.strip_shyps o transfer) raw_ths - val (table, invtable, ccount, vcount, prop) = - remove_types (table, invtable, ccount, 0) (a$b) - handle Mono _ => raise (Make "no type variables allowed") - val (left, right) = - (case prop of AbstractMachine.App x => x | _ => - sys_error "make: remove_types should deliver application") + fun thm2rule (encoding, hyptable, shyptable) th = + let + val (ComputeThm (hyps, shyps, prop)) = th + val hyptable = fold (fn h => Termtab.update (h, ())) hyps hyptable + val shyptable = fold (fn sh => Sorttab.update (sh, ())) shyps shyptable + val (prems, prop) = (Logic.strip_imp_prems prop, Logic.strip_imp_concl prop) + val (a, b) = Logic.dest_equals prop + handle TERM _ => raise (Make "theorems must be meta-level equations (with optional guards)") + val a = Envir.eta_contract a + val b = Envir.eta_contract b + val prems = map Envir.eta_contract prems - fun make_pattern table invtable n vars (var as AbstractMachine.Var v) = + val (encoding, left) = remove_types encoding a + val (encoding, right) = remove_types encoding b + fun remove_types_of_guard encoding g = + (let + val (t1, t2) = Logic.dest_equals g + val (encoding, t1) = remove_types encoding t1 + val (encoding, t2) = remove_types encoding t2 + in + (encoding, AbstractMachine.Guard (t1, t2)) + end handle TERM _ => raise (Make "guards must be meta-level equations")) + val (encoding, prems) = fold_rev (fn p => fn (encoding, ps) => let val (e, p) = remove_types_of_guard encoding p in (e, p::ps) end) prems (encoding, []) + + fun make_pattern encoding n vars (var as AbstractMachine.Abs _) = + raise (Make "no lambda abstractions allowed in pattern") + | make_pattern encoding n vars (var as AbstractMachine.Var _) = + raise (Make "no bound variables allowed in pattern") + | make_pattern encoding n vars (AbstractMachine.Const code) = + (case the (Encode.lookup_term code encoding) of + Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar) + handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed")) + | _ => (n, vars, AbstractMachine.PConst (code, []))) + | make_pattern encoding n vars (AbstractMachine.App (a, b)) = let - val var' = the (AMTermTab.lookup invtable var) - val table = Termtab.delete var' table - val invtable = AMTermTab.delete var invtable - val vars = Inttab.update_new (v, n) vars handle Inttab.DUP _ => - raise (Make "no duplicate variable in pattern allowed") - in - (table, invtable, n+1, vars, AbstractMachine.PVar) - end - | make_pattern table invtable n vars (AbstractMachine.Abs _) = - raise (Make "no lambda abstractions allowed in pattern") - | make_pattern table invtable n vars (AbstractMachine.Const c) = - (table, invtable, n, vars, AbstractMachine.PConst (c, [])) - | make_pattern table invtable n vars (AbstractMachine.App (a, b)) = - let - val (table, invtable, n, vars, pa) = - make_pattern table invtable n vars a - val (table, invtable, n, vars, pb) = - make_pattern table invtable n vars b + val (n, vars, pa) = make_pattern encoding n vars a + val (n, vars, pb) = make_pattern encoding n vars b in case pa of AbstractMachine.PVar => raise (Make "patterns may not start with a variable") | AbstractMachine.PConst (c, args) => - (table, invtable, n, vars, AbstractMachine.PConst (c, args@[pb])) + (n, vars, AbstractMachine.PConst (c, args@[pb])) end - val (table, invtable, vcount, vars, pattern) = - make_pattern table invtable 0 Inttab.empty left + (* Principally, a check should be made here to see if the (meta-) hyps contain any of the variables of the rule. + As it is, all variables of the rule are schematic, and there are no schematic variables in meta-hyps, therefore + this check can be left out. *) + + val (vcount, vars, pattern) = make_pattern encoding 0 Inttab.empty left val _ = (case pattern of - AbstractMachine.PVar => + AbstractMachine.PVar => raise (Make "patterns may not start with a variable") - | _ => ()) - - (* at this point, there shouldn't be any variables - left in table or invtable, only constants *) + (* | AbstractMachine.PConst (_, []) => + (print th; raise (Make "no parameter rewrite found"))*) + | _ => ()) (* finally, provide a function for renaming the - pattern bound variables on the right hand side *) + pattern bound variables on the right hand side *) - fun rename ldepth vars (var as AbstractMachine.Var v) = - if v < ldepth then var - else (case Inttab.lookup vars (v - ldepth) of - NONE => raise (Make "new variable on right hand side") - | SOME n => AbstractMachine.Var ((vcount-n-1)+ldepth)) - | rename ldepth vars (c as AbstractMachine.Const _) = c - | rename ldepth vars (AbstractMachine.App (a, b)) = - AbstractMachine.App (rename ldepth vars a, rename ldepth vars b) - | rename ldepth vars (AbstractMachine.Abs m) = - AbstractMachine.Abs (rename (ldepth+1) vars m) - + fun rename level vars (var as AbstractMachine.Var _) = var + | rename level vars (c as AbstractMachine.Const code) = + (case Inttab.lookup vars code of + NONE => c + | SOME n => AbstractMachine.Var (vcount-n-1+level)) + | rename level vars (AbstractMachine.App (a, b)) = + AbstractMachine.App (rename level vars a, rename level vars b) + | rename level vars (AbstractMachine.Abs m) = + AbstractMachine.Abs (rename (level+1) vars m) + + fun rename_guard (AbstractMachine.Guard (a,b)) = + AbstractMachine.Guard (rename 0 vars a, rename 0 vars b) in - (table, invtable, ccount, (pattern, rename 0 vars right)) + ((encoding, hyptable, shyptable), (map rename_guard prems, pattern, rename 0 vars right)) end - val (table, invtable, ccount, rules) = - fold_rev (fn th => fn (table, invtable, ccount, rules) => + val ((encoding, hyptable, shyptable), rules) = + fold_rev (fn th => fn (encoding_hyptable, rules) => let - val (table, invtable, ccount, rule) = - thm2rule table invtable ccount th - in (table, invtable, ccount, rule::rules) end) - ths (Termtab.empty, AMTermTab.empty, 0, []) + val (encoding_hyptable, rule) = thm2rule encoding_hyptable th + in (encoding_hyptable, rule::rules) end) + ths ((Encode.empty, Termtab.empty, Sorttab.empty), []) - val prog = AbstractMachine.compile rules + val prog = + case machine of + BARRAS => ProgBarras (AM_Interpreter.compile rules) + | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile rules) + | HASKELL => ProgHaskell (AM_GHC.compile rules) + | SML => ProgSML (AM_SML.compile rules) - in Computer (Theory.self_ref thy, (table, invtable, ccount, prog)) end +(* val _ = print (Encode.fold (fn x => fn s => x::s) encoding [])*) + + fun has_witness s = not (null (Sign.witness_sorts thy [] [s])) -fun make thy ths = + val shyptable = fold Sorttab.delete (filter has_witness (Sorttab.keys (shyptable))) shyptable + + in Computer (Theory.self_ref thy, encoding, Termtab.keys hyptable, shyptable, prog) end + +(*fun timeit f = let - val (_, {mk_rews={mk=mk,mk_eq_True=emk, ...},...}) = rep_ss (simpset_of thy) - fun mk_eq_True th = (case emk th of NONE => [th] | SOME th' => [th, th']) + val t1 = Time.toMicroseconds (Time.now ()) + val x = f () + val t2 = Time.toMicroseconds (Time.now ()) + val _ = writeln ("### time = "^(Real.toString ((Real.fromLargeInt t2 - Real.fromLargeInt t1)/(1000000.0)))^"s") in - basic_make thy (maps mk (maps mk_eq_True ths)) - end + x + end*) + +fun report s f = f () (*writeln s; timeit f*) -fun compute (Computer r) naming ct = +fun compute (Computer (rthy, encoding, hyps, shyptable, prog)) naming ct = let + fun run (ProgBarras p) = AM_Interpreter.run p + | run (ProgBarrasC p) = AM_Compiler.run p + | run (ProgHaskell p) = AM_GHC.run p + | run (ProgSML p) = AM_SML.run p val {t=t, T=ty, thy=ctthy, ...} = rep_cterm ct - val (rthy, (table, invtable, ccount, prog)) = r val thy = Theory.merge (Theory.deref rthy, ctthy) - val (table, invtable, ccount, vcount, t) = remove_types (table, invtable, ccount, 0) t - val t = AbstractMachine.run prog t - val t = infer_types naming invtable ty t + val (encoding, t) = report "remove_types" (fn () => remove_types encoding t) + val t = report "run" (fn () => run prog t) + val t = report "infer_types" (fn () => infer_types naming encoding ty t) in t end -fun theory_of (Computer (rthy, _)) = Theory.deref rthy +fun discard (Computer (rthy, encoding, hyps, shyptable, prog)) = + (case prog of + ProgBarras p => AM_Interpreter.discard p + | ProgBarrasC p => AM_Compiler.discard p + | ProgHaskell p => AM_GHC.discard p + | ProgSML p => AM_SML.discard p) + +fun theory_of (Computer (rthy, _, _,_,_)) = Theory.deref rthy +fun hyps_of (Computer (_, _, hyps, _, _)) = hyps +fun shyps_of (Computer (_, _, _, shyptable, _)) = Sorttab.keys (shyptable) +fun shyptab_of (Computer (_, _, _, shyptable, _)) = shyptable fun default_naming i = "v_" ^ Int.toString i +exception Param of computer * (int -> string) * cterm; -fun oracle_fn thy (r, naming, ct) = +fun rewrite_param r n ct = + let + val thy = theory_of_cterm ct + val th = timeit (fn () => invoke_oracle_i thy "Compute_Oracle.compute" (thy, Param (r, n, ct))) + val hyps = map (fn h => assume (cterm_of thy h)) (hyps_of r) + in + fold (fn h => fn p => implies_elim p h) hyps th + end + +(*fun rewrite_param r n ct = + let + val hyps = hyps_of r + val shyps = shyps_of r + val thy = theory_of_cterm ct + val _ = Theory.assert_super (theory_of r) thy + val t' = timeit (fn () => compute r n ct) + val eq = Logic.mk_equals (term_of ct, t') + in + Thm.unchecked_oracle thy "Compute.compute" (eq, hyps, shyps) + end*) + +fun rewrite r ct = rewrite_param r default_naming ct + +(* theory setup *) + +fun compute_oracle (thy, Param (r, naming, ct)) = let val _ = Theory.assert_super (theory_of r) thy val t' = compute r naming ct + val eq = Logic.mk_equals (term_of ct, t') + val hyps = hyps_of r + val shyptab = shyptab_of r + fun delete s shyptab = Sorttab.delete s shyptab handle Sorttab.UNDEF _ => shyptab + fun delete_term t shyptab = fold delete (Sorts.insert_term t []) shyptab + val shyps = if Sorttab.is_empty shyptab then [] else Sorttab.keys (fold delete_term (eq::hyps) shyptab) + val _ = if not (null shyps) then raise Compute ("dangling sort hypotheses: "^(makestring shyps)) else () in - Logic.mk_equals (term_of ct, t') + fold_rev (fn hyp => fn p => Logic.mk_implies (hyp, p)) hyps eq end + | compute_oracle _ = raise Match + + +val setup = (fn thy => (writeln "install oracle"; Theory.add_oracle ("compute", compute_oracle) thy)) + +(*val _ = Context.add_setup (Theory.add_oracle ("compute", compute_oracle))*) end +