src/Tools/Compute_Oracle/linker.ML
author wenzelm
Tue Sep 29 16:24:36 2009 +0200 (2009-09-29)
changeset 32740 9dd0a2f83429
parent 32035 8e77b6a250d5
child 32960 69916a850301
permissions -rw-r--r--
explicit indication of Unsynchronized.ref;
     1 (*  Title:      Tools/Compute_Oracle/linker.ML
     2     Author:     Steven Obua
     3 
     4 Linker.ML solves the problem that the computing oracle does not
     5 instantiate polymorphic rules. By going through the PCompute interface,
     6 all possible instantiations are resolved by compiling new programs, if
     7 necessary. The obvious disadvantage of this approach is that in the
     8 worst case for each new term to be rewritten, a new program may be
     9 compiled.
    10 *)
    11 
    12 (*
    13    Given constants/frees c_1::t_1, c_2::t_2, ...., c_n::t_n,
    14    and constants/frees d_1::s_1, d_2::s_2, ..., d_m::s_m
    15 
    16    Find all substitutions S such that
    17    a) the domain of S is tvars (t_1, ..., t_n)
    18    b) there are indices i_1, ..., i_k, and j_1, ..., j_k with
    19       1. S (c_i_1::t_i_1) = d_j_1::s_j_1, ..., S (c_i_k::t_i_k) = d_j_k::s_j_k
    20       2. tvars (t_i_1, ..., t_i_k) = tvars (t_1, ..., t_n)
    21 *)
    signature LINKER =
    sig
        (* Raised on malformed input and on internal inconsistencies. *)
        exception Link of string

        (* Constant (is_free, name, type): a Const (is_free = false) or a Free (is_free = true). *)
        datatype constant = Constant of bool * string * typ

        (* Inspecting and collecting constants. *)
        val constant_of : term -> constant
        val typ_of_constant : constant -> typ
        val is_polymorphic : constant -> bool
        val distinct_constants : constant list -> constant list
        val collect_consts : term list -> constant list

        (* Instantiation state and the type substitutions derived from it. *)
        type subst = Type.tyenv
        type instances
        val empty : constant list -> instances
        val add_instances : theory -> instances -> constant list -> subst list * instances
        val substs_of : instances -> subst list
    end
    40 
    structure Linker : LINKER = struct

    (* Raised for terms that are not constants/frees and for internal errors. *)
    exception Link of string;

    type subst = Type.tyenv

    (* Constant (is_free, name, ty): is_free is false for a Const, true for a Free. *)
    datatype constant = Constant of bool * string * typ

    (* Convert a Const or Free term; any other term raises Link. *)
    fun constant_of (Const (name, ty)) = Constant (false, name, ty)
      | constant_of (Free (name, ty)) = Constant (true, name, ty)
      | constant_of _ = raise Link "constant_of"

    (* Order on booleans with false < true. *)
    fun bool_ord (x,y) = if x then (if y then EQUAL else GREATER) else (if y then LESS else EQUAL)

    (* Lexicographic order on (is_free, name, type). *)
    fun constant_ord (Constant (x1,x2,x3), Constant (y1,y2,y3)) =
        (prod_ord (prod_ord bool_ord fast_string_ord) TermOrd.typ_ord) (((x1,x2),x3), ((y1,y2),y3))

    (* Like constant_ord, but ignores the type, so all type instances of one
       constant compare EQUAL. *)
    fun constant_modty_ord (Constant (x1,x2,_), Constant (y1,y2,_)) =
        (prod_ord bool_ord fast_string_ord) ((x1,x2), (y1,y2))

    structure Consttab = Table(type key = constant val ord = constant_ord);
    structure ConsttabModTy = Table(type key = constant val ord = constant_modty_ord);

    fun typ_of_constant (Constant (_, _, ty)) = ty

    val empty_subst = (Vartab.empty : Type.tyenv)

    (* Union of two type substitutions.  Returns NONE if they bind some type
       variable to different values. *)
    fun merge_subst (A:Type.tyenv) (B:Type.tyenv) =
        SOME (Vartab.fold (fn (v, t) =>
                           fn tab =>
                              (case Vartab.lookup tab v of
                                   NONE => Vartab.update (v, t) tab
                                 | SOME t' => if t = t' then tab else raise Type.TYPE_MATCH)) A B)
        handle Type.TYPE_MATCH => NONE

    (* Order on substitutions: compares their entry lists lexicographically. *)
    fun subst_ord (A:Type.tyenv, B:Type.tyenv) =
        (list_ord (prod_ord TermOrd.fast_indexname_ord (prod_ord TermOrd.sort_ord TermOrd.typ_ord)))
            (Vartab.dest A, Vartab.dest B)

    structure Substtab = Table(type key = Type.tyenv val ord = subst_ord);

    fun substtab_union c = Substtab.fold Substtab.update c
    fun substtab_unions [] = Substtab.empty
      | substtab_unions [c] = c
      | substtab_unions (c::cs) = substtab_union c (substtab_unions cs)

    (* Instances (cfilter, ctab, minsets, substs):
         cfilter - the tracked polymorphic constants, keyed modulo their type
                   (not consulted elsewhere in this file);
         ctab    - for each tracked constant, a table from the instances found
                   so far to the type substitution that matches it;
         minsets - the sets of tracked constants whose substitutions are combined;
         substs  - every combined substitution produced so far. *)
    datatype instances = Instances of unit ConsttabModTy.table * Type.tyenv Consttab.table Consttab.table * constant list list * unit Substtab.table

    (* True if the constant's type contains schematic type variables. *)
    fun is_polymorphic (Constant (_, _, ty)) = not (null (Term.add_tvarsT ty []))

    (* The given constants without duplicates. *)
    fun distinct_constants cs =
        Consttab.keys (fold (fn c => Consttab.update (c, ())) cs Consttab.empty)

    (* Fresh instantiation state tracking the polymorphic constants among cs. *)
    fun empty cs =
        let
            val cs = distinct_constants (filter is_polymorphic cs)
            (* An earlier version tried to split cs into minimal subsets that
               jointly cover all type variables.  That computation is disabled:
               all polymorphic constants form a single set. *)
            val minimal_subsets = [cs]
            val constants = Consttab.keys (fold (fold (fn c => Consttab.update (c, ()))) minimal_subsets Consttab.empty)
        in
            Instances (
            fold (fn c => fn tab => ConsttabModTy.update (c, ()) tab) constants ConsttabModTy.empty,
            Consttab.make (map (fn c => (c, Consttab.empty : Type.tyenv Consttab.table)) constants),
            minimal_subsets, Substtab.empty)
        end

    local
    (* Fold over the constants of a set.  Each recorded substitution of the
       current constant is merged with each substitution accumulated so far;
       incompatible combinations are dropped. *)
    fun calc ctab substtab [] = substtab
      | calc ctab substtab (c::cs) =
        let
            val csubsts = map snd (Consttab.dest (the (Consttab.lookup ctab c)))
            fun merge_substs substtab subst =
                Substtab.fold (fn (s,_) =>
                               fn tab =>
                                  (case merge_subst subst s of NONE => tab | SOME s => Substtab.update (s, ()) tab))
                              substtab Substtab.empty
            val substtab = substtab_unions (map (merge_substs substtab) csubsts)
        in
            calc ctab substtab cs
        end
    in
    (* All consistent combinations of recorded substitutions for the constants
       in cs, starting from the empty substitution. *)
    fun calc_substs ctab (cs:constant list) = calc ctab (Substtab.update (empty_subst, ()) Substtab.empty) cs
    end

    (* Process the constants cs.  A constant of cs is matched against every
       tracked constant with the same kind and name that has not yet recorded it
       as an instance; each successful type match is recorded.  Substitutions
       are then recomputed.  Returns the substitutions not produced before,
       together with the updated state. *)
    fun add_instances thy (Instances (cfilter, ctab, minsets, substs)) cs =
        let
            fun calc_instantiations (constant as Constant (free, name, ty)) instantiations =
                Consttab.fold (fn (constant' as Constant (free', name', ty'), insttab) =>
                               fn instantiations =>
                                  if free <> free' orelse name <> name' then
                                      instantiations
                                  else case Consttab.lookup insttab constant of
                                           SOME _ => instantiations
                                         | NONE => ((constant', (constant, Sign.typ_match thy (ty', ty) empty_subst))::instantiations
                                                    (* Must be qualified.  If no TYPE_MATCH constructor is
                                                       in scope, the bare name is a variable pattern that
                                                       catches every exception. *)
                                                    handle Type.TYPE_MATCH => instantiations))
                              ctab instantiations
            val instantiations = fold calc_instantiations cs []
            fun update_ctab (constant', entry) ctab =
                (case Consttab.lookup ctab constant' of
                     NONE => raise Link "internal error: update_ctab"
                   | SOME tab => Consttab.update (constant', Consttab.update entry tab) ctab)
            val ctab = fold update_ctab instantiations ctab
            val new_substs = fold (fn minset => fn substs => substtab_union (calc_substs ctab minset) substs)
                                  minsets Substtab.empty
            (* Keep only substitutions that were not already in substs. *)
            val (added_substs, substs) =
                Substtab.fold (fn (ns, _) =>
                               fn (added, substtab) =>
                                  (case Substtab.lookup substs ns of
                                       NONE => (ns::added, Substtab.update (ns, ()) substtab)
                                     | SOME () => (added, substtab)))
                              new_substs ([], substs)
        in
            (added_substs, Instances (cfilter, ctab, minsets, substs))
        end

    (* Every substitution produced so far. *)
    fun substs_of (Instances (_,_,_,substs)) = Substtab.keys substs

    (* Not exported through LINKER.  Theory "Main" and HOL.eq_reflection are
       looked up when this structure is elaborated. *)
    local
        fun get_thm thmname = PureThy.get_thm (theory "Main") thmname
        val eq_th = get_thm "HOL.eq_reflection"
    in
      fun eq_to_meta th = (eq_th OF [th] handle THM _ => th)
    end


    local

    (* Add every Const/Free occurring in a term; Vars and Bounds are skipped. *)
    fun collect (Var _) tab = tab
      | collect (Bound _) tab = tab
      | collect (a $ b) tab = collect b (collect a tab)
      | collect (Abs (_, _, body)) tab = collect body tab
      | collect t tab = Consttab.update (constant_of t, ()) tab

    in
      (* Duplicate-free list of the constants and frees occurring in tms. *)
      fun collect_consts tms = Consttab.keys (fold collect tms Consttab.empty)
    end

    end
   214 
   signature PCOMPUTE =
   sig
       (* A Compute.computer that instantiates polymorphic rewrite rules on demand. *)
       type pcomputer

       (* Construction; make_with_cache additionally takes cache patterns. *)
       val make_with_cache : Compute.machine -> theory -> term list -> thm list -> Linker.constant list -> pcomputer
       val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer

       (* Register constants (directly, or those occurring in terms); the result
          is true when new instances were generated. *)
       val add_instances' : pcomputer -> term list -> bool
       val add_instances : pcomputer -> Linker.constant list -> bool

       (* Working with Compute.theorem values. *)
       val make_theorem : pcomputer -> thm -> string list -> Compute.theorem
       val instantiate : pcomputer -> (string * cterm) list -> Compute.theorem -> Compute.theorem
       val evaluate_prem : pcomputer -> int -> Compute.theorem -> Compute.theorem
       val modus_ponens : pcomputer -> int -> thm -> Compute.theorem -> Compute.theorem

       (* Rewriting and simplification. *)
       val simplify : pcomputer -> Compute.theorem -> thm
       val rewrite : pcomputer -> cterm list -> thm list
   end
   234 
   structure PCompute : PCOMPUTE = struct

   exception PCompute of string

   (* A rewrite theorem is either monomorphic, or polymorphic together with its
      linker state and the monomorphic instances generated so far. *)
   datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list
   (* The same distinction for cache patterns. *)
   datatype pattern = MonoPattern of term | PolyPattern of term * Linker.instances * term list

   (* Theory, underlying computer, and the current theorems and patterns
      (updated in place by add_instances). *)
   datatype pcomputer =
     PComputer of theory_ref * Compute.computer * theorem list Unsynchronized.ref *
       pattern list Unsynchronized.ref

   fun computer_of (PComputer (_,computer,_,_)) = computer

   (* For a theorem "prems ==> left == right", return the constants of left,
      and the constants of right and prems. *)
   fun collect_consts_of_thm th =
       let
           val th = prop_of th
           val (prems, th) = (Logic.strip_imp_prems th, Logic.strip_imp_concl th)
           val (left, right) = Logic.dest_equals th
       in
           (Linker.collect_consts [left], Linker.collect_consts (right::prems))
       end

   (* Classify a rewrite theorem as MonoThm or PolyThm (depending on whether
      its left-hand side has polymorphic constants).  Also returns the
      monomorphic constants of the right-hand side and premises.  Raises
      PCompute if those sides use type variables absent from the polymorphic
      constants of the left-hand side. *)
   fun create_theorem th =
       let
           val (left, right) = collect_consts_of_thm th
           val polycs = filter Linker.is_polymorphic left
           val tytab = fold (fn p => fn tab => fold (fn n => fn tab => Typtab.update (TVar n, ()) tab) (OldTerm.typ_tvars (Linker.typ_of_constant p)) tab) polycs Typtab.empty
           fun check_const (c::cs) cs' =
               let
                   val tvars = OldTerm.typ_tvars (Linker.typ_of_constant c)
                   val wrong = fold (fn n => fn wrong => wrong orelse is_none (Typtab.lookup tytab (TVar n))) tvars false
               in
                   if wrong then raise PCompute "right hand side of theorem contains type variables which do not occur on the left hand side"
                   else
                       if null (tvars) then
                           check_const cs (c::cs')
                       else
                           check_const cs cs'
               end
             | check_const [] cs' = cs'
           val monocs = check_const right []
       in
           if null (polycs) then
               (monocs, MonoThm th)
           else
               (monocs, PolyThm (th, Linker.empty polycs, []))
       end

   (* Classify a cache pattern by whether it contains polymorphic constants. *)
   fun create_pattern pat =
       let
           val cs = Linker.collect_consts [pat]
           val polycs = filter Linker.is_polymorphic cs
       in
           if null (polycs) then
               MonoPattern pat
           else
               PolyPattern (pat, Linker.empty polycs, [])
       end

   (* The theorems handed to Compute: each MonoThm itself, and for each PolyThm
      only its generated instances, in list order. *)
   fun mono_theorems ths =
       let
           fun add (MonoThm th) ths = th::ths
             | add (PolyThm (_, _, ths')) ths = ths'@ths
       in
           fold_rev add ths []
       end

   (* The patterns handed to Compute, flattened like mono_theorems. *)
   fun mono_patterns pats =
       let
           fun addpat (MonoPattern p) pats = p::pats
             | addpat (PolyPattern (_, _, ps)) pats = ps@pats
       in
           fold_rev addpat pats []
       end

   fun create_computer machine thy pats ths =
       Compute.make_with_cache machine thy (mono_patterns pats) (mono_theorems ths)

   fun update_computer computer pats ths =
       Compute.update_with_cache computer (mono_patterns pats) (mono_theorems ths)

   (* Convert a type substitution into ctyp pairs for Thm.instantiate. *)
   fun conv_subst thy (subst : Type.tyenv) =
       map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst)

   (* Instantiate polymorphic theorems and patterns for the constants monocs.
      New theorem instances may contribute further monomorphic constants; the
      process repeats until none appear.  Returns whether anything new was
      generated, and the updated (patterns, theorems). *)
   fun add_monos thy monocs pats ths =
       let
           val changed = Unsynchronized.ref false
           fun add monocs (th as (MonoThm _)) = ([], th)
             | add monocs (PolyThm (th, instances, instanceths)) =
               let
                   val (newsubsts, instances) = Linker.add_instances thy instances monocs
                   val _ = if not (null newsubsts) then changed := true else ()
                   val newths = map (fn subst => Thm.instantiate (conv_subst thy subst, []) th) newsubsts
                   val newmonos = fold (fn th => fn monos => (snd (collect_consts_of_thm th))@monos) newths []
               in
                   (newmonos, PolyThm (th, instances, instanceths@newths))
               end
           fun addpats monocs (pat as (MonoPattern _)) = pat
             | addpats monocs (PolyPattern (p, instances, instancepats)) =
               let
                   val (newsubsts, instances) = Linker.add_instances thy instances monocs
                   val _ = if not (null newsubsts) then changed := true else ()
                   val newpats = map (fn subst => Envir.subst_term_types subst p) newsubsts
               in
                   PolyPattern (p, instances, instancepats@newpats)
               end
           fun step monocs ths =
               fold_rev (fn th =>
                         fn (newmonos, ths) =>
                            let
                                val (newmonos', th') = add monocs th
                            in
                                (newmonos'@newmonos, th'::ths)
                            end)
                        ths ([], [])
           fun loop monocs pats ths =
               let
                   val (monocs', ths') = step monocs ths
                   val pats' = map (addpats monocs) pats
               in
                   if null (monocs') then
                       (pats', ths')
                   else
                       loop monocs' pats' ths'
               end
           val result = loop monocs pats ths
       in
           (!changed, result)
       end

   (* Theorem identity used for deduplication: hypotheses, sort hypotheses and
      proposition. *)
   datatype cthm = ComputeThm of term list * sort list * term

   fun thm2cthm th =
       let
           val {hyps, prop, shyps, ...} = Thm.rep_thm th
       in
           ComputeThm (hyps, shyps, prop)
       end

   val cthm_ord' = prod_ord (prod_ord (list_ord TermOrd.term_ord) (list_ord TermOrd.sort_ord)) TermOrd.term_ord

   fun cthm_ord (ComputeThm (h1, sh1, p1), ComputeThm (h2, sh2, p2)) = cthm_ord' (((h1,sh1), p1), ((h2, sh2), p2))

   structure CThmtab = Table(type key = cthm val ord = cthm_ord)

   (* Drop theorems equal (as cthm) to an earlier one; the first occurrence of
      each is kept and the original order is preserved. *)
   fun remove_duplicates ths =
       let
           fun add th (seen, kept) =
               let
                   val key = thm2cthm th
               in
                   case CThmtab.lookup seen key of
                       NONE => (CThmtab.update_new (key, ()) seen, th :: kept)
                     | SOME _ => (seen, kept)
               end
           val (_, kept) = fold add ths (CThmtab.empty : unit CThmtab.table, [])
       in
           rev kept
       end

   fun make_with_cache machine thy pats ths cs =
       let
           val ths = remove_duplicates ths
           val (monocs, ths) = fold_rev (fn th =>
                                         fn (monocs, ths) =>
                                            let val (m, t) = create_theorem th in
                                                (m@monocs, t::ths)
                                            end)
                                        ths (cs, [])
           val pats = map create_pattern pats
           val (_, (pats, ths)) = add_monos thy monocs pats ths
           val computer = create_computer machine thy pats ths
       in
           PComputer (Theory.check_thy thy, computer, Unsynchronized.ref ths, Unsynchronized.ref pats)
       end

   fun make machine thy ths cs = make_with_cache machine thy [] ths cs

   (* Instantiate for the constants cs.  Only when something new was generated
      is the underlying computer updated and the stored lists replaced.
      Returns whether that happened. *)
   fun add_instances (PComputer (thyref, computer, rths, rpats)) cs =
       let
           val thy = Theory.deref thyref
           val (changed, (pats, ths)) = add_monos thy cs (!rpats) (!rths)
       in
           if changed then
               (update_computer computer pats ths;
                rths := ths;
                rpats := pats;
                true)
           else
               false
       end

   fun add_instances' pc ts = add_instances pc (Linker.collect_consts ts)

   (* The operations below first register the constants of their term
      arguments, then delegate to Compute. *)
   fun rewrite pc cts =
       let
           val _ = add_instances' pc (map term_of cts)
           val computer = computer_of pc
       in
           map (Compute.rewrite computer) cts
       end

   fun simplify pc th = Compute.simplify (computer_of pc) th

   fun make_theorem pc th vars =
       let
           val _ = add_instances' pc [prop_of th]
       in
           Compute.make_theorem (computer_of pc) th vars
       end

   fun instantiate pc insts th =
       let
           val _ = add_instances' pc (map (term_of o snd) insts)
       in
           Compute.instantiate (computer_of pc) insts th
       end

   fun evaluate_prem pc prem_no th = Compute.evaluate_prem (computer_of pc) prem_no th

   fun modus_ponens pc prem_no th' th =
       let
           val _ = add_instances' pc [prop_of th']
       in
           Compute.modus_ponens (computer_of pc) prem_no th' th
       end

   end