src/Tools/Compute_Oracle/linker.ML
author wenzelm
Fri Jul 17 23:11:40 2009 +0200 (2009-07-17)
changeset 32035 8e77b6a250d5
parent 31971 8c1b845ed105
child 32740 9dd0a2f83429
permissions -rw-r--r--
tuned/modernized Envir.subst_XXX;
     1 (*  Title:      Tools/Compute_Oracle/linker.ML
     2     Author:     Steven Obua
     3 
     4 Linker.ML solves the problem that the computing oracle does not
     5 instantiate polymorphic rules. By going through the PCompute interface,
     6 all possible instantiations are resolved by compiling new programs, if
     7 necessary. The obvious disadvantage of this approach is that in the
     8 worst case for each new term to be rewritten, a new program may be
     9 compiled.
    10 *)
    11 
    12 (*
    13    Given constants/frees c_1::t_1, c_2::t_2, ...., c_n::t_n,
    14    and constants/frees d_1::s_1, d_2::s_2, ..., d_m::s_m
    15 
    16    Find all substitutions S such that
    17    a) the domain of S is tvars (t_1, ..., t_n)
    18    b) there are indices i_1, ..., i_k, and j_1, ..., j_k with
    19       1. S (c_i_1::t_i_1) = d_j_1::s_j_1, ..., S (c_i_k::t_i_k) = d_j_k::s_j_k
    20       2. tvars (t_i_1, ..., t_i_k) = tvars (t_1, ..., t_n)
    21 *)
    22 signature LINKER =
    23 sig
    24     exception Link of string
    25 
    26     datatype constant = Constant of bool * string * typ
    27     val constant_of : term -> constant
    28 
    29     type instances
    30     type subst = Type.tyenv
    31 
    32     val empty : constant list -> instances
    33     val typ_of_constant : constant -> typ
    34     val add_instances : theory -> instances -> constant list -> subst list * instances
    35     val substs_of : instances -> subst list
    36     val is_polymorphic : constant -> bool
    37     val distinct_constants : constant list -> constant list
    38     val collect_consts : term list -> constant list
    39 end
    40 
    41 structure Linker : LINKER = struct
    42 
    43 exception Link of string;
    44 
    45 type subst = Type.tyenv
    46 
    47 datatype constant = Constant of bool * string * typ
    48 fun constant_of (Const (name, ty)) = Constant (false, name, ty)
    49   | constant_of (Free (name, ty)) = Constant (true, name, ty)
    50   | constant_of _ = raise Link "constant_of"
    51 
    52 fun bool_ord (x,y) = if x then (if y then EQUAL else GREATER) else (if y then LESS else EQUAL)
    53 fun constant_ord (Constant (x1,x2,x3), Constant (y1,y2,y3)) = (prod_ord (prod_ord bool_ord fast_string_ord) TermOrd.typ_ord) (((x1,x2),x3), ((y1,y2),y3))
    54 fun constant_modty_ord (Constant (x1,x2,_), Constant (y1,y2,_)) = (prod_ord bool_ord fast_string_ord) ((x1,x2), (y1,y2))
    55 
    56 
    57 structure Consttab = Table(type key = constant val ord = constant_ord);
    58 structure ConsttabModTy = Table(type key = constant val ord = constant_modty_ord);
    59 
    60 fun typ_of_constant (Constant (_, _, ty)) = ty
    61 
    62 val empty_subst = (Vartab.empty : Type.tyenv)
    63 
    64 fun merge_subst (A:Type.tyenv) (B:Type.tyenv) =
    65     SOME (Vartab.fold (fn (v, t) =>
    66                        fn tab =>
    67                           (case Vartab.lookup tab v of
    68                                NONE => Vartab.update (v, t) tab
    69                              | SOME t' => if t = t' then tab else raise Type.TYPE_MATCH)) A B)
    70     handle Type.TYPE_MATCH => NONE
    71 
    72 fun subst_ord (A:Type.tyenv, B:Type.tyenv) =
    73     (list_ord (prod_ord TermOrd.fast_indexname_ord (prod_ord TermOrd.sort_ord TermOrd.typ_ord))) (Vartab.dest A, Vartab.dest B)
    74 
    75 structure Substtab = Table(type key = Type.tyenv val ord = subst_ord);
    76 
    77 fun substtab_union c = Substtab.fold Substtab.update c
    78 fun substtab_unions [] = Substtab.empty
    79   | substtab_unions [c] = c
    80   | substtab_unions (c::cs) = substtab_union c (substtab_unions cs)
    81 
    82 datatype instances = Instances of unit ConsttabModTy.table * Type.tyenv Consttab.table Consttab.table * constant list list * unit Substtab.table
    83 
    84 fun is_polymorphic (Constant (_, _, ty)) = not (null (Term.add_tvarsT ty []))
    85 
    86 fun distinct_constants cs =
    87     Consttab.keys (fold (fn c => Consttab.update (c, ())) cs Consttab.empty)
    88 
    89 fun empty cs =
    90     let
    91         val cs = distinct_constants (filter is_polymorphic cs)
    92         val old_cs = cs
    93 (*      fun collect_tvars ty tab = fold (fn v => fn tab => Typtab.update (TVar v, ()) tab) (OldTerm.typ_tvars ty) tab
    94         val tvars_count = length (Typtab.keys (fold (fn c => fn tab => collect_tvars (typ_of_constant c) tab) cs Typtab.empty))
    95         fun tvars_of ty = collect_tvars ty Typtab.empty
    96         val cs = map (fn c => (c, tvars_of (typ_of_constant c))) cs
    97 
    98         fun tyunion A B =
    99             Typtab.fold
   100                 (fn (v,()) => fn tab => Typtab.update (v, case Typtab.lookup tab v of NONE => 1 | SOME n => n+1) tab)
   101                 A B
   102 
   103         fun is_essential A B =
   104             Typtab.fold
   105             (fn (v, ()) => fn essential => essential orelse (case Typtab.lookup B v of NONE => raise Link "is_essential" | SOME n => n=1))
   106             A false
   107 
   108         fun add_minimal (c', tvs') (tvs, cs) =
   109             let
   110                 val tvs = tyunion tvs' tvs
   111                 val cs = (c', tvs')::cs
   112             in
   113                 if forall (fn (c',tvs') => is_essential tvs' tvs) cs then
   114                     SOME (tvs, cs)
   115                 else
   116                     NONE
   117             end
   118 
   119         fun is_spanning (tvs, _) = (length (Typtab.keys tvs) = tvars_count)
   120 
   121         fun generate_minimal_subsets subsets [] = subsets
   122           | generate_minimal_subsets subsets (c::cs) =
   123             let
   124                 val subsets' = map_filter (add_minimal c) subsets
   125             in
   126                 generate_minimal_subsets (subsets@subsets') cs
   127             end*)
   128 
   129         val minimal_subsets = [old_cs] (*map (fn (tvs, cs) => map fst cs) (filter is_spanning (generate_minimal_subsets [(Typtab.empty, [])] cs))*)
   130 
   131         val constants = Consttab.keys (fold (fold (fn c => Consttab.update (c, ()))) minimal_subsets Consttab.empty)
   132 
   133     in
   134         Instances (
   135         fold (fn c => fn tab => ConsttabModTy.update (c, ()) tab) constants ConsttabModTy.empty,
   136         Consttab.make (map (fn c => (c, Consttab.empty : Type.tyenv Consttab.table)) constants),
   137         minimal_subsets, Substtab.empty)
   138     end
   139 
   140 local
   141 fun calc ctab substtab [] = substtab
   142   | calc ctab substtab (c::cs) =
   143     let
   144         val csubsts = map snd (Consttab.dest (the (Consttab.lookup ctab c)))
   145         fun merge_substs substtab subst =
   146             Substtab.fold (fn (s,_) =>
   147                            fn tab =>
   148                               (case merge_subst subst s of NONE => tab | SOME s => Substtab.update (s, ()) tab))
   149                           substtab Substtab.empty
   150         val substtab = substtab_unions (map (merge_substs substtab) csubsts)
   151     in
   152         calc ctab substtab cs
   153     end
   154 in
   155 fun calc_substs ctab (cs:constant list) = calc ctab (Substtab.update (empty_subst, ()) Substtab.empty) cs
   156 end
   157 
   158 fun add_instances thy (Instances (cfilter, ctab,minsets,substs)) cs =
   159     let
   160 (*      val _ = writeln (makestring ("add_instances: ", length_cs, length cs, length (Consttab.keys ctab)))*)
   161         fun calc_instantiations (constant as Constant (free, name, ty)) instantiations =
   162             Consttab.fold (fn (constant' as Constant (free', name', ty'), insttab) =>
   163                            fn instantiations =>
   164                               if free <> free' orelse name <> name' then
   165                                   instantiations
   166                               else case Consttab.lookup insttab constant of
   167                                        SOME _ => instantiations
   168                                      | NONE => ((constant', (constant, Sign.typ_match thy (ty', ty) empty_subst))::instantiations
   169                                                 handle TYPE_MATCH => instantiations))
   170                           ctab instantiations
   171         val instantiations = fold calc_instantiations cs []
   172         (*val _ = writeln ("instantiations = "^(makestring (length instantiations)))*)
   173         fun update_ctab (constant', entry) ctab =
   174             (case Consttab.lookup ctab constant' of
   175                  NONE => raise Link "internal error: update_ctab"
   176                | SOME tab => Consttab.update (constant', Consttab.update entry tab) ctab)
   177         val ctab = fold update_ctab instantiations ctab
   178         val new_substs = fold (fn minset => fn substs => substtab_union (calc_substs ctab minset) substs)
   179                               minsets Substtab.empty
   180         val (added_substs, substs) =
   181             Substtab.fold (fn (ns, _) =>
   182                            fn (added, substtab) =>
   183                               (case Substtab.lookup substs ns of
   184                                    NONE => (ns::added, Substtab.update (ns, ()) substtab)
   185                                  | SOME () => (added, substtab)))
   186                           new_substs ([], substs)
   187     in
   188         (added_substs, Instances (cfilter, ctab, minsets, substs))
   189     end
   190 
   191 fun substs_of (Instances (_,_,_,substs)) = Substtab.keys substs
   192 
   193 local
   194     fun get_thm thmname = PureThy.get_thm (theory "Main") thmname
   195     val eq_th = get_thm "HOL.eq_reflection"
   196 in
   197   fun eq_to_meta th = (eq_th OF [th] handle THM _ => th)
   198 end
   199 
   200 
   201 local
   202 
   203 fun collect (Var x) tab = tab
   204   | collect (Bound _) tab = tab
   205   | collect (a $ b) tab = collect b (collect a tab)
   206   | collect (Abs (_, _, body)) tab = collect body tab
   207   | collect t tab = Consttab.update (constant_of t, ()) tab
   208 
   209 in
   210   fun collect_consts tms = Consttab.keys (fold collect tms Consttab.empty)
   211 end
   212 
   213 end
   214 
   215 signature PCOMPUTE =
   216 sig
   217     type pcomputer
   218 
   219     val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer
   220     val make_with_cache : Compute.machine -> theory -> term list -> thm list -> Linker.constant list -> pcomputer
   221     
   222     val add_instances : pcomputer -> Linker.constant list -> bool 
   223     val add_instances' : pcomputer -> term list -> bool
   224 
   225     val rewrite : pcomputer -> cterm list -> thm list
   226     val simplify : pcomputer -> Compute.theorem -> thm
   227 
   228     val make_theorem : pcomputer -> thm -> string list -> Compute.theorem
   229     val instantiate : pcomputer -> (string * cterm) list -> Compute.theorem -> Compute.theorem
   230     val evaluate_prem : pcomputer -> int -> Compute.theorem -> Compute.theorem
   231     val modus_ponens : pcomputer -> int -> thm -> Compute.theorem -> Compute.theorem 
   232 
   233 end
   234 
   235 structure PCompute : PCOMPUTE = struct
   236 
   237 exception PCompute of string
   238 
   239 datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list
   240 datatype pattern = MonoPattern of term | PolyPattern of term * Linker.instances * term list
   241 
   242 datatype pcomputer = PComputer of theory_ref * Compute.computer * theorem list ref * pattern list ref 
   243 
   244 (*fun collect_consts (Var x) = []
   245   | collect_consts (Bound _) = []
   246   | collect_consts (a $ b) = (collect_consts a)@(collect_consts b)
   247   | collect_consts (Abs (_, _, body)) = collect_consts body
   248   | collect_consts t = [Linker.constant_of t]*)
   249 
   250 fun computer_of (PComputer (_,computer,_,_)) = computer
   251 
   252 fun collect_consts_of_thm th = 
   253     let
   254         val th = prop_of th
   255         val (prems, th) = (Logic.strip_imp_prems th, Logic.strip_imp_concl th)
   256         val (left, right) = Logic.dest_equals th
   257     in
   258         (Linker.collect_consts [left], Linker.collect_consts (right::prems))
   259     end
   260 
   261 fun create_theorem th =
   262 let
   263     val (left, right) = collect_consts_of_thm th
   264     val polycs = filter Linker.is_polymorphic left
   265     val tytab = fold (fn p => fn tab => fold (fn n => fn tab => Typtab.update (TVar n, ()) tab) (OldTerm.typ_tvars (Linker.typ_of_constant p)) tab) polycs Typtab.empty
   266     fun check_const (c::cs) cs' =
   267         let
   268             val tvars = OldTerm.typ_tvars (Linker.typ_of_constant c)
   269             val wrong = fold (fn n => fn wrong => wrong orelse is_none (Typtab.lookup tytab (TVar n))) tvars false
   270         in
   271             if wrong then raise PCompute "right hand side of theorem contains type variables which do not occur on the left hand side"
   272             else
   273                 if null (tvars) then
   274                     check_const cs (c::cs')
   275                 else
   276                     check_const cs cs'
   277         end
   278       | check_const [] cs' = cs'
   279     val monocs = check_const right []
   280 in
   281     if null (polycs) then
   282         (monocs, MonoThm th)
   283     else
   284         (monocs, PolyThm (th, Linker.empty polycs, []))
   285 end
   286 
   287 fun create_pattern pat = 
   288 let
   289     val cs = Linker.collect_consts [pat]
   290     val polycs = filter Linker.is_polymorphic cs
   291 in
   292     if null (polycs) then
   293 	MonoPattern pat
   294     else
   295 	PolyPattern (pat, Linker.empty polycs, [])
   296 end
   297 	     
   298 fun create_computer machine thy pats ths =
   299     let
   300         fun add (MonoThm th) ths = th::ths
   301           | add (PolyThm (_, _, ths')) ths = ths'@ths
   302 	fun addpat (MonoPattern p) pats = p::pats
   303 	  | addpat (PolyPattern (_, _, ps)) pats = ps@pats
   304         val ths = fold_rev add ths []
   305 	val pats = fold_rev addpat pats []
   306     in
   307         Compute.make_with_cache machine thy pats ths
   308     end
   309 
   310 fun update_computer computer pats ths = 
   311     let
   312 	fun add (MonoThm th) ths = th::ths
   313 	  | add (PolyThm (_, _, ths')) ths = ths'@ths
   314 	fun addpat (MonoPattern p) pats = p::pats
   315 	  | addpat (PolyPattern (_, _, ps)) pats = ps@pats
   316 	val ths = fold_rev add ths []
   317 	val pats = fold_rev addpat pats []
   318     in
   319 	Compute.update_with_cache computer pats ths
   320     end
   321 
   322 fun conv_subst thy (subst : Type.tyenv) =
   323     map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst)
   324 
   325 fun add_monos thy monocs pats ths =
   326     let
   327         val changed = ref false
   328         fun add monocs (th as (MonoThm _)) = ([], th)
   329           | add monocs (PolyThm (th, instances, instanceths)) =
   330             let
   331                 val (newsubsts, instances) = Linker.add_instances thy instances monocs
   332                 val _ = if not (null newsubsts) then changed := true else ()
   333                 val newths = map (fn subst => Thm.instantiate (conv_subst thy subst, []) th) newsubsts
   334 (*              val _ = if not (null newths) then (print ("added new theorems: ", newths); ()) else ()*)
   335                 val newmonos = fold (fn th => fn monos => (snd (collect_consts_of_thm th))@monos) newths []
   336             in
   337                 (newmonos, PolyThm (th, instances, instanceths@newths))
   338             end
   339 	fun addpats monocs (pat as (MonoPattern _)) = pat
   340 	  | addpats monocs (PolyPattern (p, instances, instancepats)) =
   341 	    let
   342 		val (newsubsts, instances) = Linker.add_instances thy instances monocs
   343 		val _ = if not (null newsubsts) then changed := true else ()
   344 		val newpats = map (fn subst => Envir.subst_term_types subst p) newsubsts
   345 	    in
   346 		PolyPattern (p, instances, instancepats@newpats)
   347 	    end 
   348         fun step monocs ths =
   349             fold_rev (fn th =>
   350                       fn (newmonos, ths) =>
   351                          let 
   352 			     val (newmonos', th') = add monocs th 
   353 			 in
   354                              (newmonos'@newmonos, th'::ths)
   355                          end)
   356                      ths ([], [])
   357         fun loop monocs pats ths =
   358             let 
   359 		val (monocs', ths') = step monocs ths 
   360 		val pats' = map (addpats monocs) pats
   361 	    in
   362                 if null (monocs') then
   363                     (pats', ths')
   364                 else
   365                     loop monocs' pats' ths'
   366             end
   367         val result = loop monocs pats ths
   368     in
   369         (!changed, result)
   370     end
   371 
   372 datatype cthm = ComputeThm of term list * sort list * term
   373 
   374 fun thm2cthm th =
   375     let
   376         val {hyps, prop, shyps, ...} = Thm.rep_thm th
   377     in
   378         ComputeThm (hyps, shyps, prop)
   379     end
   380 
   381 val cthm_ord' = prod_ord (prod_ord (list_ord TermOrd.term_ord) (list_ord TermOrd.sort_ord)) TermOrd.term_ord
   382 
   383 fun cthm_ord (ComputeThm (h1, sh1, p1), ComputeThm (h2, sh2, p2)) = cthm_ord' (((h1,sh1), p1), ((h2, sh2), p2))
   384 
   385 structure CThmtab = Table(type key = cthm val ord = cthm_ord)
   386 
   387 fun remove_duplicates ths =
   388     let
   389         val counter = ref 0
   390         val tab = ref (CThmtab.empty : unit CThmtab.table)
   391         val thstab = ref (Inttab.empty : thm Inttab.table)
   392         fun update th =
   393             let
   394                 val key = thm2cthm th
   395             in
   396                 case CThmtab.lookup (!tab) key of
   397                     NONE => ((tab := CThmtab.update_new (key, ()) (!tab)); thstab := Inttab.update_new (!counter, th) (!thstab); counter := !counter + 1)
   398                   | _ => ()
   399             end
   400         val _ = map update ths
   401     in
   402         map snd (Inttab.dest (!thstab))
   403     end
   404 
   405 fun make_with_cache machine thy pats ths cs =
   406     let
   407 	val ths = remove_duplicates ths
   408 	val (monocs, ths) = fold_rev (fn th => 
   409 				      fn (monocs, ths) => 
   410 					 let val (m, t) = create_theorem th in 
   411 					     (m@monocs, t::ths)
   412 					 end)
   413 				     ths (cs, [])
   414 	val pats = map create_pattern pats
   415 	val (_, (pats, ths)) = add_monos thy monocs pats ths
   416 	val computer = create_computer machine thy pats ths
   417     in
   418 	PComputer (Theory.check_thy thy, computer, ref ths, ref pats)
   419     end
   420 
   421 fun make machine thy ths cs = make_with_cache machine thy [] ths cs
   422 
   423 fun add_instances (PComputer (thyref, computer, rths, rpats)) cs = 
   424     let
   425         val thy = Theory.deref thyref
   426         val (changed, (pats, ths)) = add_monos thy cs (!rpats) (!rths)
   427     in
   428 	if changed then
   429 	    (update_computer computer pats ths;
   430 	     rths := ths;
   431 	     rpats := pats;
   432 	     true)
   433 	else
   434 	    false
   435 
   436     end
   437 
   438 fun add_instances' pc ts = add_instances pc (Linker.collect_consts ts)
   439 
   440 fun rewrite pc cts =
   441     let
   442 	val _ = add_instances' pc (map term_of cts)
   443 	val computer = (computer_of pc)
   444     in
   445 	map (fn ct => Compute.rewrite computer ct) cts
   446     end
   447 
   448 fun simplify pc th = Compute.simplify (computer_of pc) th
   449 
   450 fun make_theorem pc th vars = 
   451     let
   452 	val _ = add_instances' pc [prop_of th]
   453 
   454     in
   455 	Compute.make_theorem (computer_of pc) th vars
   456     end
   457 
   458 fun instantiate pc insts th = 
   459     let
   460 	val _ = add_instances' pc (map (term_of o snd) insts)
   461     in
   462 	Compute.instantiate (computer_of pc) insts th
   463     end
   464 
   465 fun evaluate_prem pc prem_no th = Compute.evaluate_prem (computer_of pc) prem_no th
   466 
   467 fun modus_ponens pc prem_no th' th =
   468     let
   469 	val _ = add_instances' pc [prop_of th']
   470     in
   471 	Compute.modus_ponens (computer_of pc) prem_no th' th
   472     end    
   473 		 							      			    
   474 
   475 end