src/Tools/Compute_Oracle/linker.ML
changeset 37872 d83659570337
parent 37871 c7ce7685e087
child 37873 66d90b2b87bc
equal deleted inserted replaced
37871:c7ce7685e087 37872:d83659570337
     1 (*  Title:      Tools/Compute_Oracle/linker.ML
       
     2     Author:     Steven Obua
       
     3 
       
     4 This module solves the problem that the computing oracle does not
       
     5 instantiate polymorphic rules. By going through the PCompute
       
     6 interface, all possible instantiations are resolved by compiling new
       
     7 programs, if necessary. The obvious disadvantage of this approach is
       
     8 that in the worst case for each new term to be rewritten, a new
       
     9 program may be compiled.
       
    10 *)
       
    11 
       
    12 (*
       
    13    Given constants/frees c_1::t_1, c_2::t_2, ...., c_n::t_n,
       
     14    and constants/frees d_1::s_1, d_2::s_2, ..., d_m::s_m
       
    15 
       
    16    Find all substitutions S such that
       
    17    a) the domain of S is tvars (t_1, ..., t_n)
       
    18    b) there are indices i_1, ..., i_k, and j_1, ..., j_k with
       
    19       1. S (c_i_1::t_i_1) = d_j_1::s_j_1, ..., S (c_i_k::t_i_k) = d_j_k::s_j_k
       
    20       2. tvars (t_i_1, ..., t_i_k) = tvars (t_1, ..., t_n)
       
    21 *)
       
    22 signature LINKER =
       
    23 sig
       
    24     exception Link of string
       
    25 
       
    26     datatype constant = Constant of bool * string * typ
       
    27     val constant_of : term -> constant
       
    28 
       
    29     type instances
       
    30     type subst = Type.tyenv
       
    31 
       
    32     val empty : constant list -> instances
       
    33     val typ_of_constant : constant -> typ
       
    34     val add_instances : theory -> instances -> constant list -> subst list * instances
       
    35     val substs_of : instances -> subst list
       
    36     val is_polymorphic : constant -> bool
       
    37     val distinct_constants : constant list -> constant list
       
    38     val collect_consts : term list -> constant list
       
    39 end
       
    40 
       
    41 structure Linker : LINKER = struct
       
    42 
       
    43 exception Link of string;
       
    44 
       
    45 type subst = Type.tyenv
       
    46 
       
    47 datatype constant = Constant of bool * string * typ
       
    48 fun constant_of (Const (name, ty)) = Constant (false, name, ty)
       
    49   | constant_of (Free (name, ty)) = Constant (true, name, ty)
       
    50   | constant_of _ = raise Link "constant_of"
       
    51 
       
    52 fun bool_ord (x,y) = if x then (if y then EQUAL else GREATER) else (if y then LESS else EQUAL)
       
    53 fun constant_ord (Constant (x1,x2,x3), Constant (y1,y2,y3)) = (prod_ord (prod_ord bool_ord fast_string_ord) Term_Ord.typ_ord) (((x1,x2),x3), ((y1,y2),y3))
       
    54 fun constant_modty_ord (Constant (x1,x2,_), Constant (y1,y2,_)) = (prod_ord bool_ord fast_string_ord) ((x1,x2), (y1,y2))
       
    55 
       
    56 
       
    57 structure Consttab = Table(type key = constant val ord = constant_ord);
       
    58 structure ConsttabModTy = Table(type key = constant val ord = constant_modty_ord);
       
    59 
       
    60 fun typ_of_constant (Constant (_, _, ty)) = ty
       
    61 
       
    62 val empty_subst = (Vartab.empty : Type.tyenv)
       
    63 
       
    64 fun merge_subst (A:Type.tyenv) (B:Type.tyenv) =
       
    65     SOME (Vartab.fold (fn (v, t) =>
       
    66                        fn tab =>
       
    67                           (case Vartab.lookup tab v of
       
    68                                NONE => Vartab.update (v, t) tab
       
    69                              | SOME t' => if t = t' then tab else raise Type.TYPE_MATCH)) A B)
       
    70     handle Type.TYPE_MATCH => NONE
       
    71 
       
    72 fun subst_ord (A:Type.tyenv, B:Type.tyenv) =
       
    73     (list_ord (prod_ord Term_Ord.fast_indexname_ord (prod_ord Term_Ord.sort_ord Term_Ord.typ_ord))) (Vartab.dest A, Vartab.dest B)
       
    74 
       
    75 structure Substtab = Table(type key = Type.tyenv val ord = subst_ord);
       
    76 
       
    77 fun substtab_union c = Substtab.fold Substtab.update c
       
    78 fun substtab_unions [] = Substtab.empty
       
    79   | substtab_unions [c] = c
       
    80   | substtab_unions (c::cs) = substtab_union c (substtab_unions cs)
       
    81 
       
    82 datatype instances = Instances of unit ConsttabModTy.table * Type.tyenv Consttab.table Consttab.table * constant list list * unit Substtab.table
       
    83 
       
    84 fun is_polymorphic (Constant (_, _, ty)) = not (null (Term.add_tvarsT ty []))
       
    85 
       
    86 fun distinct_constants cs =
       
    87     Consttab.keys (fold (fn c => Consttab.update (c, ())) cs Consttab.empty)
       
    88 
       
    89 fun empty cs =
       
    90     let
       
    91         val cs = distinct_constants (filter is_polymorphic cs)
       
    92         val old_cs = cs
       
    93 (*      fun collect_tvars ty tab = fold (fn v => fn tab => Typtab.update (TVar v, ()) tab) (OldTerm.typ_tvars ty) tab
       
    94         val tvars_count = length (Typtab.keys (fold (fn c => fn tab => collect_tvars (typ_of_constant c) tab) cs Typtab.empty))
       
    95         fun tvars_of ty = collect_tvars ty Typtab.empty
       
    96         val cs = map (fn c => (c, tvars_of (typ_of_constant c))) cs
       
    97 
       
    98         fun tyunion A B =
       
    99             Typtab.fold
       
   100                 (fn (v,()) => fn tab => Typtab.update (v, case Typtab.lookup tab v of NONE => 1 | SOME n => n+1) tab)
       
   101                 A B
       
   102 
       
   103         fun is_essential A B =
       
   104             Typtab.fold
       
   105             (fn (v, ()) => fn essential => essential orelse (case Typtab.lookup B v of NONE => raise Link "is_essential" | SOME n => n=1))
       
   106             A false
       
   107 
       
   108         fun add_minimal (c', tvs') (tvs, cs) =
       
   109             let
       
   110                 val tvs = tyunion tvs' tvs
       
   111                 val cs = (c', tvs')::cs
       
   112             in
       
   113                 if forall (fn (c',tvs') => is_essential tvs' tvs) cs then
       
   114                     SOME (tvs, cs)
       
   115                 else
       
   116                     NONE
       
   117             end
       
   118 
       
   119         fun is_spanning (tvs, _) = (length (Typtab.keys tvs) = tvars_count)
       
   120 
       
   121         fun generate_minimal_subsets subsets [] = subsets
       
   122           | generate_minimal_subsets subsets (c::cs) =
       
   123             let
       
   124                 val subsets' = map_filter (add_minimal c) subsets
       
   125             in
       
   126                 generate_minimal_subsets (subsets@subsets') cs
       
   127             end*)
       
   128 
       
   129         val minimal_subsets = [old_cs] (*map (fn (tvs, cs) => map fst cs) (filter is_spanning (generate_minimal_subsets [(Typtab.empty, [])] cs))*)
       
   130 
       
   131         val constants = Consttab.keys (fold (fold (fn c => Consttab.update (c, ()))) minimal_subsets Consttab.empty)
       
   132 
       
   133     in
       
   134         Instances (
       
   135         fold (fn c => fn tab => ConsttabModTy.update (c, ()) tab) constants ConsttabModTy.empty,
       
   136         Consttab.make (map (fn c => (c, Consttab.empty : Type.tyenv Consttab.table)) constants),
       
   137         minimal_subsets, Substtab.empty)
       
   138     end
       
   139 
       
   140 local
       
   141 fun calc ctab substtab [] = substtab
       
   142   | calc ctab substtab (c::cs) =
       
   143     let
       
   144         val csubsts = map snd (Consttab.dest (the (Consttab.lookup ctab c)))
       
   145         fun merge_substs substtab subst =
       
   146             Substtab.fold (fn (s,_) =>
       
   147                            fn tab =>
       
   148                               (case merge_subst subst s of NONE => tab | SOME s => Substtab.update (s, ()) tab))
       
   149                           substtab Substtab.empty
       
   150         val substtab = substtab_unions (map (merge_substs substtab) csubsts)
       
   151     in
       
   152         calc ctab substtab cs
       
   153     end
       
   154 in
       
   155 fun calc_substs ctab (cs:constant list) = calc ctab (Substtab.update (empty_subst, ()) Substtab.empty) cs
       
   156 end
       
   157 
       
   158 fun add_instances thy (Instances (cfilter, ctab,minsets,substs)) cs =
       
   159     let
       
   160 (*      val _ = writeln (makestring ("add_instances: ", length_cs, length cs, length (Consttab.keys ctab)))*)
       
   161         fun calc_instantiations (constant as Constant (free, name, ty)) instantiations =
       
   162             Consttab.fold (fn (constant' as Constant (free', name', ty'), insttab) =>
       
   163                            fn instantiations =>
       
   164                               if free <> free' orelse name <> name' then
       
   165                                   instantiations
       
   166                               else case Consttab.lookup insttab constant of
       
   167                                        SOME _ => instantiations
       
   168                                      | NONE => ((constant', (constant, Sign.typ_match thy (ty', ty) empty_subst))::instantiations
       
   169                                                 handle TYPE_MATCH => instantiations))
       
   170                           ctab instantiations
       
   171         val instantiations = fold calc_instantiations cs []
       
   172         (*val _ = writeln ("instantiations = "^(makestring (length instantiations)))*)
       
   173         fun update_ctab (constant', entry) ctab =
       
   174             (case Consttab.lookup ctab constant' of
       
   175                  NONE => raise Link "internal error: update_ctab"
       
   176                | SOME tab => Consttab.update (constant', Consttab.update entry tab) ctab)
       
   177         val ctab = fold update_ctab instantiations ctab
       
   178         val new_substs = fold (fn minset => fn substs => substtab_union (calc_substs ctab minset) substs)
       
   179                               minsets Substtab.empty
       
   180         val (added_substs, substs) =
       
   181             Substtab.fold (fn (ns, _) =>
       
   182                            fn (added, substtab) =>
       
   183                               (case Substtab.lookup substs ns of
       
   184                                    NONE => (ns::added, Substtab.update (ns, ()) substtab)
       
   185                                  | SOME () => (added, substtab)))
       
   186                           new_substs ([], substs)
       
   187     in
       
   188         (added_substs, Instances (cfilter, ctab, minsets, substs))
       
   189     end
       
   190 
       
   191 fun substs_of (Instances (_,_,_,substs)) = Substtab.keys substs
       
   192 
       
   193 fun eq_to_meta th = (@{thm HOL.eq_reflection} OF [th] handle THM _ => th)
       
   194 
       
   195 
       
   196 local
       
   197 
       
   198 fun collect (Var x) tab = tab
       
   199   | collect (Bound _) tab = tab
       
   200   | collect (a $ b) tab = collect b (collect a tab)
       
   201   | collect (Abs (_, _, body)) tab = collect body tab
       
   202   | collect t tab = Consttab.update (constant_of t, ()) tab
       
   203 
       
   204 in
       
   205   fun collect_consts tms = Consttab.keys (fold collect tms Consttab.empty)
       
   206 end
       
   207 
       
   208 end
       
   209 
       
   210 signature PCOMPUTE =
       
   211 sig
       
   212     type pcomputer
       
   213 
       
   214     val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer
       
   215     val make_with_cache : Compute.machine -> theory -> term list -> thm list -> Linker.constant list -> pcomputer
       
   216     
       
   217     val add_instances : pcomputer -> Linker.constant list -> bool 
       
   218     val add_instances' : pcomputer -> term list -> bool
       
   219 
       
   220     val rewrite : pcomputer -> cterm list -> thm list
       
   221     val simplify : pcomputer -> Compute.theorem -> thm
       
   222 
       
   223     val make_theorem : pcomputer -> thm -> string list -> Compute.theorem
       
   224     val instantiate : pcomputer -> (string * cterm) list -> Compute.theorem -> Compute.theorem
       
   225     val evaluate_prem : pcomputer -> int -> Compute.theorem -> Compute.theorem
       
   226     val modus_ponens : pcomputer -> int -> thm -> Compute.theorem -> Compute.theorem 
       
   227 
       
   228 end
       
   229 
       
   230 structure PCompute : PCOMPUTE = struct
       
   231 
       
   232 exception PCompute of string
       
   233 
       
   234 datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list
       
   235 datatype pattern = MonoPattern of term | PolyPattern of term * Linker.instances * term list
       
   236 
       
   237 datatype pcomputer =
       
   238   PComputer of theory_ref * Compute.computer * theorem list Unsynchronized.ref *
       
   239     pattern list Unsynchronized.ref 
       
   240 
       
   241 (*fun collect_consts (Var x) = []
       
   242   | collect_consts (Bound _) = []
       
   243   | collect_consts (a $ b) = (collect_consts a)@(collect_consts b)
       
   244   | collect_consts (Abs (_, _, body)) = collect_consts body
       
   245   | collect_consts t = [Linker.constant_of t]*)
       
   246 
       
   247 fun computer_of (PComputer (_,computer,_,_)) = computer
       
   248 
       
   249 fun collect_consts_of_thm th = 
       
   250     let
       
   251         val th = prop_of th
       
   252         val (prems, th) = (Logic.strip_imp_prems th, Logic.strip_imp_concl th)
       
   253         val (left, right) = Logic.dest_equals th
       
   254     in
       
   255         (Linker.collect_consts [left], Linker.collect_consts (right::prems))
       
   256     end
       
   257 
       
   258 fun create_theorem th =
       
   259 let
       
   260     val (left, right) = collect_consts_of_thm th
       
   261     val polycs = filter Linker.is_polymorphic left
       
   262     val tytab = fold (fn p => fn tab => fold (fn n => fn tab => Typtab.update (TVar n, ()) tab) (OldTerm.typ_tvars (Linker.typ_of_constant p)) tab) polycs Typtab.empty
       
   263     fun check_const (c::cs) cs' =
       
   264         let
       
   265             val tvars = OldTerm.typ_tvars (Linker.typ_of_constant c)
       
   266             val wrong = fold (fn n => fn wrong => wrong orelse is_none (Typtab.lookup tytab (TVar n))) tvars false
       
   267         in
       
   268             if wrong then raise PCompute "right hand side of theorem contains type variables which do not occur on the left hand side"
       
   269             else
       
   270                 if null (tvars) then
       
   271                     check_const cs (c::cs')
       
   272                 else
       
   273                     check_const cs cs'
       
   274         end
       
   275       | check_const [] cs' = cs'
       
   276     val monocs = check_const right []
       
   277 in
       
   278     if null (polycs) then
       
   279         (monocs, MonoThm th)
       
   280     else
       
   281         (monocs, PolyThm (th, Linker.empty polycs, []))
       
   282 end
       
   283 
       
   284 fun create_pattern pat = 
       
   285 let
       
   286     val cs = Linker.collect_consts [pat]
       
   287     val polycs = filter Linker.is_polymorphic cs
       
   288 in
       
   289     if null (polycs) then
       
   290         MonoPattern pat
       
   291     else
       
   292         PolyPattern (pat, Linker.empty polycs, [])
       
   293 end
       
   294              
       
   295 fun create_computer machine thy pats ths =
       
   296     let
       
   297         fun add (MonoThm th) ths = th::ths
       
   298           | add (PolyThm (_, _, ths')) ths = ths'@ths
       
   299         fun addpat (MonoPattern p) pats = p::pats
       
   300           | addpat (PolyPattern (_, _, ps)) pats = ps@pats
       
   301         val ths = fold_rev add ths []
       
   302         val pats = fold_rev addpat pats []
       
   303     in
       
   304         Compute.make_with_cache machine thy pats ths
       
   305     end
       
   306 
       
   307 fun update_computer computer pats ths = 
       
   308     let
       
   309         fun add (MonoThm th) ths = th::ths
       
   310           | add (PolyThm (_, _, ths')) ths = ths'@ths
       
   311         fun addpat (MonoPattern p) pats = p::pats
       
   312           | addpat (PolyPattern (_, _, ps)) pats = ps@pats
       
   313         val ths = fold_rev add ths []
       
   314         val pats = fold_rev addpat pats []
       
   315     in
       
   316         Compute.update_with_cache computer pats ths
       
   317     end
       
   318 
       
   319 fun conv_subst thy (subst : Type.tyenv) =
       
   320     map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst)
       
   321 
       
   322 fun add_monos thy monocs pats ths =
       
   323     let
       
   324         val changed = Unsynchronized.ref false
       
   325         fun add monocs (th as (MonoThm _)) = ([], th)
       
   326           | add monocs (PolyThm (th, instances, instanceths)) =
       
   327             let
       
   328                 val (newsubsts, instances) = Linker.add_instances thy instances monocs
       
   329                 val _ = if not (null newsubsts) then changed := true else ()
       
   330                 val newths = map (fn subst => Thm.instantiate (conv_subst thy subst, []) th) newsubsts
       
   331 (*              val _ = if not (null newths) then (print ("added new theorems: ", newths); ()) else ()*)
       
   332                 val newmonos = fold (fn th => fn monos => (snd (collect_consts_of_thm th))@monos) newths []
       
   333             in
       
   334                 (newmonos, PolyThm (th, instances, instanceths@newths))
       
   335             end
       
   336         fun addpats monocs (pat as (MonoPattern _)) = pat
       
   337           | addpats monocs (PolyPattern (p, instances, instancepats)) =
       
   338             let
       
   339                 val (newsubsts, instances) = Linker.add_instances thy instances monocs
       
   340                 val _ = if not (null newsubsts) then changed := true else ()
       
   341                 val newpats = map (fn subst => Envir.subst_term_types subst p) newsubsts
       
   342             in
       
   343                 PolyPattern (p, instances, instancepats@newpats)
       
   344             end 
       
   345         fun step monocs ths =
       
   346             fold_rev (fn th =>
       
   347                       fn (newmonos, ths) =>
       
   348                          let 
       
   349                              val (newmonos', th') = add monocs th 
       
   350                          in
       
   351                              (newmonos'@newmonos, th'::ths)
       
   352                          end)
       
   353                      ths ([], [])
       
   354         fun loop monocs pats ths =
       
   355             let 
       
   356                 val (monocs', ths') = step monocs ths 
       
   357                 val pats' = map (addpats monocs) pats
       
   358             in
       
   359                 if null (monocs') then
       
   360                     (pats', ths')
       
   361                 else
       
   362                     loop monocs' pats' ths'
       
   363             end
       
   364         val result = loop monocs pats ths
       
   365     in
       
   366         (!changed, result)
       
   367     end
       
   368 
       
   369 datatype cthm = ComputeThm of term list * sort list * term
       
   370 
       
   371 fun thm2cthm th =
       
   372     let
       
   373         val {hyps, prop, shyps, ...} = Thm.rep_thm th
       
   374     in
       
   375         ComputeThm (hyps, shyps, prop)
       
   376     end
       
   377 
       
   378 val cthm_ord' = prod_ord (prod_ord (list_ord Term_Ord.term_ord) (list_ord Term_Ord.sort_ord)) Term_Ord.term_ord
       
   379 
       
   380 fun cthm_ord (ComputeThm (h1, sh1, p1), ComputeThm (h2, sh2, p2)) = cthm_ord' (((h1,sh1), p1), ((h2, sh2), p2))
       
   381 
       
   382 structure CThmtab = Table(type key = cthm val ord = cthm_ord)
       
   383 
       
   384 fun remove_duplicates ths =
       
   385     let
       
   386         val counter = Unsynchronized.ref 0
       
   387         val tab = Unsynchronized.ref (CThmtab.empty : unit CThmtab.table)
       
   388         val thstab = Unsynchronized.ref (Inttab.empty : thm Inttab.table)
       
   389         fun update th =
       
   390             let
       
   391                 val key = thm2cthm th
       
   392             in
       
   393                 case CThmtab.lookup (!tab) key of
       
   394                     NONE => ((tab := CThmtab.update_new (key, ()) (!tab)); thstab := Inttab.update_new (!counter, th) (!thstab); counter := !counter + 1)
       
   395                   | _ => ()
       
   396             end
       
   397         val _ = map update ths
       
   398     in
       
   399         map snd (Inttab.dest (!thstab))
       
   400     end
       
   401 
       
   402 fun make_with_cache machine thy pats ths cs =
       
   403     let
       
   404         val ths = remove_duplicates ths
       
   405         val (monocs, ths) = fold_rev (fn th => 
       
   406                                       fn (monocs, ths) => 
       
   407                                          let val (m, t) = create_theorem th in 
       
   408                                              (m@monocs, t::ths)
       
   409                                          end)
       
   410                                      ths (cs, [])
       
   411         val pats = map create_pattern pats
       
   412         val (_, (pats, ths)) = add_monos thy monocs pats ths
       
   413         val computer = create_computer machine thy pats ths
       
   414     in
       
   415         PComputer (Theory.check_thy thy, computer, Unsynchronized.ref ths, Unsynchronized.ref pats)
       
   416     end
       
   417 
       
   418 fun make machine thy ths cs = make_with_cache machine thy [] ths cs
       
   419 
       
   420 fun add_instances (PComputer (thyref, computer, rths, rpats)) cs = 
       
   421     let
       
   422         val thy = Theory.deref thyref
       
   423         val (changed, (pats, ths)) = add_monos thy cs (!rpats) (!rths)
       
   424     in
       
   425         if changed then
       
   426             (update_computer computer pats ths;
       
   427              rths := ths;
       
   428              rpats := pats;
       
   429              true)
       
   430         else
       
   431             false
       
   432 
       
   433     end
       
   434 
       
   435 fun add_instances' pc ts = add_instances pc (Linker.collect_consts ts)
       
   436 
       
   437 fun rewrite pc cts =
       
   438     let
       
   439         val _ = add_instances' pc (map term_of cts)
       
   440         val computer = (computer_of pc)
       
   441     in
       
   442         map (fn ct => Compute.rewrite computer ct) cts
       
   443     end
       
   444 
       
   445 fun simplify pc th = Compute.simplify (computer_of pc) th
       
   446 
       
   447 fun make_theorem pc th vars = 
       
   448     let
       
   449         val _ = add_instances' pc [prop_of th]
       
   450 
       
   451     in
       
   452         Compute.make_theorem (computer_of pc) th vars
       
   453     end
       
   454 
       
   455 fun instantiate pc insts th = 
       
   456     let
       
   457         val _ = add_instances' pc (map (term_of o snd) insts)
       
   458     in
       
   459         Compute.instantiate (computer_of pc) insts th
       
   460     end
       
   461 
       
   462 fun evaluate_prem pc prem_no th = Compute.evaluate_prem (computer_of pc) prem_no th
       
   463 
       
   464 fun modus_ponens pc prem_no th' th =
       
   465     let
       
   466         val _ = add_instances' pc [prop_of th']
       
   467     in
       
   468         Compute.modus_ponens (computer_of pc) prem_no th' th
       
   469     end    
       
   470                                                                                                     
       
   471 
       
   472 end