src/Tools/Compute_Oracle/linker.ML
author wenzelm
Thu Mar 20 00:20:44 2008 +0100 (2008-03-20)
changeset 26343 0dd2eab7b296
parent 26336 a0e2b706ce73
child 29269 5c25a2012975
permissions -rw-r--r--
simplified get_thm(s): back to plain name argument;
     1 (*  Title:      Tools/Compute_Oracle/linker.ML
     2     ID:         $Id$
     3     Author:     Steven Obua
     4 
     5 Linker.ML solves the problem that the computing oracle does not
     6 instantiate polymorphic rules. By going through the PCompute interface,
     7 all possible instantiations are resolved by compiling new programs, if
     8 necessary. The obvious disadvantage of this approach is that in the
     9 worst case for each new term to be rewritten, a new program may be
    10 compiled.
    11 *)
    12 
    13 (*
    14    Given constants/frees c_1::t_1, c_2::t_2, ...., c_n::t_n,
    15    and constants/frees d_1::s_1, d_2::s_2, ..., d_m::s_m
    16 
    17    Find all substitutions S such that
    18    a) the domain of S is tvars (t_1, ..., t_n)
    19    b) there are indices i_1, ..., i_k, and j_1, ..., j_k with
    20       1. S (c_i_1::t_i_1) = d_j_1::s_j_1, ..., S (c_i_k::t_i_k) = d_j_k::s_j_k
    21       2. tvars (t_i_1, ..., t_i_k) = tvars (t_1, ..., t_n)
    22 *)
    23 signature LINKER =
    24 sig
    25     exception Link of string
    26 
    27     datatype constant = Constant of bool * string * typ
    28     val constant_of : term -> constant
    29 
    30     type instances
    31     type subst = Type.tyenv
    32 
    33     val empty : constant list -> instances
    34     val typ_of_constant : constant -> typ
    35     val add_instances : theory -> instances -> constant list -> subst list * instances
    36     val substs_of : instances -> subst list
    37     val is_polymorphic : constant -> bool
    38     val distinct_constants : constant list -> constant list
    39     val collect_consts : term list -> constant list
    40 end
    41 
    42 structure Linker : LINKER = struct
    43 
    44 exception Link of string;
    45 
    46 type subst = Type.tyenv
    47 
    48 datatype constant = Constant of bool * string * typ
    49 fun constant_of (Const (name, ty)) = Constant (false, name, ty)
    50   | constant_of (Free (name, ty)) = Constant (true, name, ty)
    51   | constant_of _ = raise Link "constant_of"
    52 
    53 fun bool_ord (x,y) = if x then (if y then EQUAL else GREATER) else (if y then LESS else EQUAL)
    54 fun constant_ord (Constant (x1,x2,x3), Constant (y1,y2,y3)) = (prod_ord (prod_ord bool_ord fast_string_ord) Term.typ_ord) (((x1,x2),x3), ((y1,y2),y3))
    55 fun constant_modty_ord (Constant (x1,x2,_), Constant (y1,y2,_)) = (prod_ord bool_ord fast_string_ord) ((x1,x2), (y1,y2))
    56 
    57 
    58 structure Consttab = TableFun(type key = constant val ord = constant_ord);
    59 structure ConsttabModTy = TableFun(type key = constant val ord = constant_modty_ord);
    60 
    61 fun typ_of_constant (Constant (_, _, ty)) = ty
    62 
    63 val empty_subst = (Vartab.empty : Type.tyenv)
    64 
    65 fun merge_subst (A:Type.tyenv) (B:Type.tyenv) =
    66     SOME (Vartab.fold (fn (v, t) =>
    67                        fn tab =>
    68                           (case Vartab.lookup tab v of
    69                                NONE => Vartab.update (v, t) tab
    70                              | SOME t' => if t = t' then tab else raise Type.TYPE_MATCH)) A B)
    71     handle Type.TYPE_MATCH => NONE
    72 
    73 fun subst_ord (A:Type.tyenv, B:Type.tyenv) =
    74     (list_ord (prod_ord Term.fast_indexname_ord (prod_ord Term.sort_ord Term.typ_ord))) (Vartab.dest A, Vartab.dest B)
    75 
    76 structure Substtab = TableFun(type key = Type.tyenv val ord = subst_ord);
    77 
    78 fun substtab_union c = Substtab.fold Substtab.update c
    79 fun substtab_unions [] = Substtab.empty
    80   | substtab_unions [c] = c
    81   | substtab_unions (c::cs) = substtab_union c (substtab_unions cs)
    82 
    83 datatype instances = Instances of unit ConsttabModTy.table * Type.tyenv Consttab.table Consttab.table * constant list list * unit Substtab.table
    84 
    85 fun is_polymorphic (Constant (_, _, ty)) = not (null (typ_tvars ty))
    86 
    87 fun distinct_constants cs =
    88     Consttab.keys (fold (fn c => Consttab.update (c, ())) cs Consttab.empty)
    89 
    90 fun empty cs =
    91     let
    92         val cs = distinct_constants (filter is_polymorphic cs)
    93         val old_cs = cs
    94 (*      fun collect_tvars ty tab = fold (fn v => fn tab => Typtab.update (TVar v, ()) tab) (typ_tvars ty) tab
    95         val tvars_count = length (Typtab.keys (fold (fn c => fn tab => collect_tvars (typ_of_constant c) tab) cs Typtab.empty))
    96         fun tvars_of ty = collect_tvars ty Typtab.empty
    97         val cs = map (fn c => (c, tvars_of (typ_of_constant c))) cs
    98 
    99         fun tyunion A B =
   100             Typtab.fold
   101                 (fn (v,()) => fn tab => Typtab.update (v, case Typtab.lookup tab v of NONE => 1 | SOME n => n+1) tab)
   102                 A B
   103 
   104         fun is_essential A B =
   105             Typtab.fold
   106             (fn (v, ()) => fn essential => essential orelse (case Typtab.lookup B v of NONE => raise Link "is_essential" | SOME n => n=1))
   107             A false
   108 
   109         fun add_minimal (c', tvs') (tvs, cs) =
   110             let
   111                 val tvs = tyunion tvs' tvs
   112                 val cs = (c', tvs')::cs
   113             in
   114                 if forall (fn (c',tvs') => is_essential tvs' tvs) cs then
   115                     SOME (tvs, cs)
   116                 else
   117                     NONE
   118             end
   119 
   120         fun is_spanning (tvs, _) = (length (Typtab.keys tvs) = tvars_count)
   121 
   122         fun generate_minimal_subsets subsets [] = subsets
   123           | generate_minimal_subsets subsets (c::cs) =
   124             let
   125                 val subsets' = map_filter (add_minimal c) subsets
   126             in
   127                 generate_minimal_subsets (subsets@subsets') cs
   128             end*)
   129 
   130         val minimal_subsets = [old_cs] (*map (fn (tvs, cs) => map fst cs) (filter is_spanning (generate_minimal_subsets [(Typtab.empty, [])] cs))*)
   131 
   132         val constants = Consttab.keys (fold (fold (fn c => Consttab.update (c, ()))) minimal_subsets Consttab.empty)
   133 
   134     in
   135         Instances (
   136         fold (fn c => fn tab => ConsttabModTy.update (c, ()) tab) constants ConsttabModTy.empty,
   137         Consttab.make (map (fn c => (c, Consttab.empty : Type.tyenv Consttab.table)) constants),
   138         minimal_subsets, Substtab.empty)
   139     end
   140 
   141 local
   142 fun calc ctab substtab [] = substtab
   143   | calc ctab substtab (c::cs) =
   144     let
   145         val csubsts = map snd (Consttab.dest (the (Consttab.lookup ctab c)))
   146         fun merge_substs substtab subst =
   147             Substtab.fold (fn (s,_) =>
   148                            fn tab =>
   149                               (case merge_subst subst s of NONE => tab | SOME s => Substtab.update (s, ()) tab))
   150                           substtab Substtab.empty
   151         val substtab = substtab_unions (map (merge_substs substtab) csubsts)
   152     in
   153         calc ctab substtab cs
   154     end
   155 in
   156 fun calc_substs ctab (cs:constant list) = calc ctab (Substtab.update (empty_subst, ()) Substtab.empty) cs
   157 end
   158 
   159 fun add_instances thy (Instances (cfilter, ctab,minsets,substs)) cs =
   160     let
   161 (*      val _ = writeln (makestring ("add_instances: ", length_cs, length cs, length (Consttab.keys ctab)))*)
   162         fun calc_instantiations (constant as Constant (free, name, ty)) instantiations =
   163             Consttab.fold (fn (constant' as Constant (free', name', ty'), insttab) =>
   164                            fn instantiations =>
   165                               if free <> free' orelse name <> name' then
   166                                   instantiations
   167                               else case Consttab.lookup insttab constant of
   168                                        SOME _ => instantiations
   169                                      | NONE => ((constant', (constant, Sign.typ_match thy (ty', ty) empty_subst))::instantiations
   170                                                 handle TYPE_MATCH => instantiations))
   171                           ctab instantiations
   172         val instantiations = fold calc_instantiations cs []
   173         (*val _ = writeln ("instantiations = "^(makestring (length instantiations)))*)
   174         fun update_ctab (constant', entry) ctab =
   175             (case Consttab.lookup ctab constant' of
   176                  NONE => raise Link "internal error: update_ctab"
   177                | SOME tab => Consttab.update (constant', Consttab.update entry tab) ctab)
   178         val ctab = fold update_ctab instantiations ctab
   179         val new_substs = fold (fn minset => fn substs => substtab_union (calc_substs ctab minset) substs)
   180                               minsets Substtab.empty
   181         val (added_substs, substs) =
   182             Substtab.fold (fn (ns, _) =>
   183                            fn (added, substtab) =>
   184                               (case Substtab.lookup substs ns of
   185                                    NONE => (ns::added, Substtab.update (ns, ()) substtab)
   186                                  | SOME () => (added, substtab)))
   187                           new_substs ([], substs)
   188     in
   189         (added_substs, Instances (cfilter, ctab, minsets, substs))
   190     end
   191 
   192 fun substs_of (Instances (_,_,_,substs)) = Substtab.keys substs
   193 
   194 local
   195     fun get_thm thmname = PureThy.get_thm (theory "Main") thmname
   196     val eq_th = get_thm "HOL.eq_reflection"
   197 in
   198   fun eq_to_meta th = (eq_th OF [th] handle THM _ => th)
   199 end
   200 
   201 
   202 local
   203 
   204 fun collect (Var x) tab = tab
   205   | collect (Bound _) tab = tab
   206   | collect (a $ b) tab = collect b (collect a tab)
   207   | collect (Abs (_, _, body)) tab = collect body tab
   208   | collect t tab = Consttab.update (constant_of t, ()) tab
   209 
   210 in
   211   fun collect_consts tms = Consttab.keys (fold collect tms Consttab.empty)
   212 end
   213 
   214 end
   215 
   216 signature PCOMPUTE =
   217 sig
   218     type pcomputer
   219 
   220     val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer
   221     val make_with_cache : Compute.machine -> theory -> term list -> thm list -> Linker.constant list -> pcomputer
   222     
   223     val add_instances : pcomputer -> Linker.constant list -> bool 
   224     val add_instances' : pcomputer -> term list -> bool
   225 
   226     val rewrite : pcomputer -> cterm list -> thm list
   227     val simplify : pcomputer -> Compute.theorem -> thm
   228 
   229     val make_theorem : pcomputer -> thm -> string list -> Compute.theorem
   230     val instantiate : pcomputer -> (string * cterm) list -> Compute.theorem -> Compute.theorem
   231     val evaluate_prem : pcomputer -> int -> Compute.theorem -> Compute.theorem
   232     val modus_ponens : pcomputer -> int -> thm -> Compute.theorem -> Compute.theorem 
   233 
   234 end
   235 
   236 structure PCompute : PCOMPUTE = struct
   237 
   238 exception PCompute of string
   239 
   240 datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list
   241 datatype pattern = MonoPattern of term | PolyPattern of term * Linker.instances * term list
   242 
   243 datatype pcomputer = PComputer of theory_ref * Compute.computer * theorem list ref * pattern list ref 
   244 
   245 (*fun collect_consts (Var x) = []
   246   | collect_consts (Bound _) = []
   247   | collect_consts (a $ b) = (collect_consts a)@(collect_consts b)
   248   | collect_consts (Abs (_, _, body)) = collect_consts body
   249   | collect_consts t = [Linker.constant_of t]*)
   250 
   251 fun computer_of (PComputer (_,computer,_,_)) = computer
   252 
   253 fun collect_consts_of_thm th = 
   254     let
   255         val th = prop_of th
   256         val (prems, th) = (Logic.strip_imp_prems th, Logic.strip_imp_concl th)
   257         val (left, right) = Logic.dest_equals th
   258     in
   259         (Linker.collect_consts [left], Linker.collect_consts (right::prems))
   260     end
   261 
   262 fun create_theorem th =
   263 let
   264     val (left, right) = collect_consts_of_thm th
   265     val polycs = filter Linker.is_polymorphic left
   266     val tytab = fold (fn p => fn tab => fold (fn n => fn tab => Typtab.update (TVar n, ()) tab) (typ_tvars (Linker.typ_of_constant p)) tab) polycs Typtab.empty
   267     fun check_const (c::cs) cs' =
   268         let
   269             val tvars = typ_tvars (Linker.typ_of_constant c)
   270             val wrong = fold (fn n => fn wrong => wrong orelse is_none (Typtab.lookup tytab (TVar n))) tvars false
   271         in
   272             if wrong then raise PCompute "right hand side of theorem contains type variables which do not occur on the left hand side"
   273             else
   274                 if null (tvars) then
   275                     check_const cs (c::cs')
   276                 else
   277                     check_const cs cs'
   278         end
   279       | check_const [] cs' = cs'
   280     val monocs = check_const right []
   281 in
   282     if null (polycs) then
   283         (monocs, MonoThm th)
   284     else
   285         (monocs, PolyThm (th, Linker.empty polycs, []))
   286 end
   287 
   288 fun create_pattern pat = 
   289 let
   290     val cs = Linker.collect_consts [pat]
   291     val polycs = filter Linker.is_polymorphic cs
   292 in
   293     if null (polycs) then
   294 	MonoPattern pat
   295     else
   296 	PolyPattern (pat, Linker.empty polycs, [])
   297 end
   298 	     
   299 fun create_computer machine thy pats ths =
   300     let
   301         fun add (MonoThm th) ths = th::ths
   302           | add (PolyThm (_, _, ths')) ths = ths'@ths
   303 	fun addpat (MonoPattern p) pats = p::pats
   304 	  | addpat (PolyPattern (_, _, ps)) pats = ps@pats
   305         val ths = fold_rev add ths []
   306 	val pats = fold_rev addpat pats []
   307     in
   308         Compute.make_with_cache machine thy pats ths
   309     end
   310 
   311 fun update_computer computer pats ths = 
   312     let
   313 	fun add (MonoThm th) ths = th::ths
   314 	  | add (PolyThm (_, _, ths')) ths = ths'@ths
   315 	fun addpat (MonoPattern p) pats = p::pats
   316 	  | addpat (PolyPattern (_, _, ps)) pats = ps@pats
   317 	val ths = fold_rev add ths []
   318 	val pats = fold_rev addpat pats []
   319     in
   320 	Compute.update_with_cache computer pats ths
   321     end
   322 
   323 fun conv_subst thy (subst : Type.tyenv) =
   324     map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst)
   325 
   326 fun add_monos thy monocs pats ths =
   327     let
   328         val changed = ref false
   329         fun add monocs (th as (MonoThm _)) = ([], th)
   330           | add monocs (PolyThm (th, instances, instanceths)) =
   331             let
   332                 val (newsubsts, instances) = Linker.add_instances thy instances monocs
   333                 val _ = if not (null newsubsts) then changed := true else ()
   334                 val newths = map (fn subst => Thm.instantiate (conv_subst thy subst, []) th) newsubsts
   335 (*              val _ = if not (null newths) then (print ("added new theorems: ", newths); ()) else ()*)
   336                 val newmonos = fold (fn th => fn monos => (snd (collect_consts_of_thm th))@monos) newths []
   337             in
   338                 (newmonos, PolyThm (th, instances, instanceths@newths))
   339             end
   340 	fun addpats monocs (pat as (MonoPattern _)) = pat
   341 	  | addpats monocs (PolyPattern (p, instances, instancepats)) =
   342 	    let
   343 		val (newsubsts, instances) = Linker.add_instances thy instances monocs
   344 		val _ = if not (null newsubsts) then changed := true else ()
   345 		val newpats = map (fn subst => Envir.subst_TVars subst p) newsubsts
   346 	    in
   347 		PolyPattern (p, instances, instancepats@newpats)
   348 	    end 
   349         fun step monocs ths =
   350             fold_rev (fn th =>
   351                       fn (newmonos, ths) =>
   352                          let 
   353 			     val (newmonos', th') = add monocs th 
   354 			 in
   355                              (newmonos'@newmonos, th'::ths)
   356                          end)
   357                      ths ([], [])
   358         fun loop monocs pats ths =
   359             let 
   360 		val (monocs', ths') = step monocs ths 
   361 		val pats' = map (addpats monocs) pats
   362 	    in
   363                 if null (monocs') then
   364                     (pats', ths')
   365                 else
   366                     loop monocs' pats' ths'
   367             end
   368         val result = loop monocs pats ths
   369     in
   370         (!changed, result)
   371     end
   372 
   373 datatype cthm = ComputeThm of term list * sort list * term
   374 
   375 fun thm2cthm th =
   376     let
   377         val {hyps, prop, shyps, ...} = Thm.rep_thm th
   378     in
   379         ComputeThm (hyps, shyps, prop)
   380     end
   381 
   382 val cthm_ord' = prod_ord (prod_ord (list_ord Term.term_ord) (list_ord Term.sort_ord)) Term.term_ord
   383 
   384 fun cthm_ord (ComputeThm (h1, sh1, p1), ComputeThm (h2, sh2, p2)) = cthm_ord' (((h1,sh1), p1), ((h2, sh2), p2))
   385 
   386 structure CThmtab = TableFun (type key = cthm val ord = cthm_ord)
   387 
   388 fun remove_duplicates ths =
   389     let
   390         val counter = ref 0
   391         val tab = ref (CThmtab.empty : unit CThmtab.table)
   392         val thstab = ref (Inttab.empty : thm Inttab.table)
   393         fun update th =
   394             let
   395                 val key = thm2cthm th
   396             in
   397                 case CThmtab.lookup (!tab) key of
   398                     NONE => ((tab := CThmtab.update_new (key, ()) (!tab)); thstab := Inttab.update_new (!counter, th) (!thstab); counter := !counter + 1)
   399                   | _ => ()
   400             end
   401         val _ = map update ths
   402     in
   403         map snd (Inttab.dest (!thstab))
   404     end
   405 
   406 fun make_with_cache machine thy pats ths cs =
   407     let
   408 	val ths = remove_duplicates ths
   409 	val (monocs, ths) = fold_rev (fn th => 
   410 				      fn (monocs, ths) => 
   411 					 let val (m, t) = create_theorem th in 
   412 					     (m@monocs, t::ths)
   413 					 end)
   414 				     ths (cs, [])
   415 	val pats = map create_pattern pats
   416 	val (_, (pats, ths)) = add_monos thy monocs pats ths
   417 	val computer = create_computer machine thy pats ths
   418     in
   419 	PComputer (Theory.check_thy thy, computer, ref ths, ref pats)
   420     end
   421 
   422 fun make machine thy ths cs = make_with_cache machine thy [] ths cs
   423 
   424 fun add_instances (PComputer (thyref, computer, rths, rpats)) cs = 
   425     let
   426         val thy = Theory.deref thyref
   427         val (changed, (pats, ths)) = add_monos thy cs (!rpats) (!rths)
   428     in
   429 	if changed then
   430 	    (update_computer computer pats ths;
   431 	     rths := ths;
   432 	     rpats := pats;
   433 	     true)
   434 	else
   435 	    false
   436 
   437     end
   438 
   439 fun add_instances' pc ts = add_instances pc (Linker.collect_consts ts)
   440 
   441 fun rewrite pc cts =
   442     let
   443 	val _ = add_instances' pc (map term_of cts)
   444 	val computer = (computer_of pc)
   445     in
   446 	map (fn ct => Compute.rewrite computer ct) cts
   447     end
   448 
   449 fun simplify pc th = Compute.simplify (computer_of pc) th
   450 
   451 fun make_theorem pc th vars = 
   452     let
   453 	val _ = add_instances' pc [prop_of th]
   454 
   455     in
   456 	Compute.make_theorem (computer_of pc) th vars
   457     end
   458 
   459 fun instantiate pc insts th = 
   460     let
   461 	val _ = add_instances' pc (map (term_of o snd) insts)
   462     in
   463 	Compute.instantiate (computer_of pc) insts th
   464     end
   465 
   466 fun evaluate_prem pc prem_no th = Compute.evaluate_prem (computer_of pc) prem_no th
   467 
   468 fun modus_ponens pc prem_no th' th =
   469     let
   470 	val _ = add_instances' pc [prop_of th']
   471     in
   472 	Compute.modus_ponens (computer_of pc) prem_no th' th
   473     end    
   474 		 							      			    
   475 
   476 end