src/HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML
changeset 55440 721b4561007a
parent 55399 5c8e91f884af
parent 55437 3fd63b92ea3b
child 56239 17df7145a871
equal deleted inserted replaced
55428:0ab52bf7b5e6 55440:721b4561007a
     4 Deriving specialised predicates and their intro rules
     4 Deriving specialised predicates and their intro rules
     5 *)
     5 *)
     6 
     6 
     7 signature PREDICATE_COMPILE_SPECIALISATION =
     7 signature PREDICATE_COMPILE_SPECIALISATION =
     8 sig
     8 sig
     9   val find_specialisations : string list -> (string * thm list) list -> theory -> (string * thm list) list * theory
     9   val find_specialisations : string list -> (string * thm list) list ->
       
    10     theory -> (string * thm list) list * theory
    10 end;
    11 end;
    11 
    12 
    12 structure Predicate_Compile_Specialisation : PREDICATE_COMPILE_SPECIALISATION =
    13 structure Predicate_Compile_Specialisation : PREDICATE_COMPILE_SPECIALISATION =
    13 struct
    14 struct
    14 
    15 
    15 open Predicate_Compile_Aux;
    16 open Predicate_Compile_Aux;
    16 
    17 
    17 (* table of specialisations *)
    18 (* table of specialisations *)
    18 structure Specialisations = Theory_Data
    19 structure Specialisations = Theory_Data
    19 (
    20 (
    20   type T = (term * term) Item_Net.T;
    21   type T = (term * term) Item_Net.T
    21   val empty : T = Item_Net.init (op aconv o pairself fst) (single o fst);
    22   val empty : T = Item_Net.init (op aconv o pairself fst) (single o fst)
    22   val extend = I;
    23   val extend = I
    23   val merge = Item_Net.merge;
    24   val merge = Item_Net.merge
    24 )
    25 )
    25 
    26 
    26 fun specialisation_of thy atom =
    27 fun specialisation_of thy atom =
    27   Item_Net.retrieve (Specialisations.get thy) atom
    28   Item_Net.retrieve (Specialisations.get thy) atom
    28 
    29 
    29 fun import (_, intros) args ctxt =
    30 fun import (_, intros) args ctxt =
    30   let
    31   let
    31     val ((_, intros'), ctxt') = Variable.importT intros ctxt
    32     val ((_, intros'), ctxt') = Variable.importT intros ctxt
    32     val pred' = fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of (hd intros')))))
    33     val pred' =
       
    34       fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of (hd intros')))))
    33     val Ts = binder_types (fastype_of pred')
    35     val Ts = binder_types (fastype_of pred')
    34     val argTs = map fastype_of args
    36     val argTs = map fastype_of args
    35     val Tsubst = Type.raw_matches (argTs, Ts) Vartab.empty
    37     val Tsubst = Type.raw_matches (argTs, Ts) Vartab.empty
    36     val args' = map (Envir.subst_term_types Tsubst) args
    38     val args' = map (Envir.subst_term_types Tsubst) args
    37   in
    39   in
    40 
    42 
    41 (* patterns only constructed of variables and pairs/tuples are trivial constructor terms*)
    43 (* patterns only constructed of variables and pairs/tuples are trivial constructor terms*)
    42 fun is_nontrivial_constrt thy t =
    44 fun is_nontrivial_constrt thy t =
    43   let
    45   let
    44     val cnstrs = get_constrs thy
    46     val cnstrs = get_constrs thy
    45     fun check t = (case strip_comb t of
    47     fun check t =
       
    48       (case strip_comb t of
    46         (Var _, []) => (true, true)
    49         (Var _, []) => (true, true)
    47       | (Free _, []) => (true, true)
    50       | (Free _, []) => (true, true)
    48       | (Const (@{const_name Pair}, _), ts) =>
    51       | (Const (@{const_name Pair}, _), ts) =>
    49         pairself (forall I) (split_list (map check ts))
    52         pairself (forall I) (split_list (map check ts))
    50       | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
    53       | (Const (s, T), ts) =>
       
    54           (case (AList.lookup (op =) cnstrs s, body_type T) of
    51             (SOME (i, Tname), Type (Tname', _)) => (false,
    55             (SOME (i, Tname), Type (Tname', _)) => (false,
    52               length ts = i andalso Tname = Tname' andalso forall (snd o check) ts)
    56               length ts = i andalso Tname = Tname' andalso forall (snd o check) ts)
    53           | _ => (false, false))
    57           | _ => (false, false))
    54       | _ => (false, false))
    58       | _ => (false, false))
    55   in check t = (false, true) end;
    59   in check t = (false, true) end
    56 
    60 
    57 fun specialise_intros black_list (pred, intros) pats thy =
    61 fun specialise_intros black_list (pred, intros) pats thy =
    58   let
    62   let
    59     val ctxt = Proof_Context.init_global thy
    63     val ctxt = Proof_Context.init_global thy
    60     val maxidx = fold (Term.maxidx_term o prop_of) intros ~1
    64     val maxidx = fold (Term.maxidx_term o prop_of) intros ~1
    87         val intro = Logic.list_implies (prems, concl)
    91         val intro = Logic.list_implies (prems, concl)
    88       in
    92       in
    89         SOME intro
    93         SOME intro
    90       end handle Pattern.Unif => NONE)
    94       end handle Pattern.Unif => NONE)
    91     val specialised_intros_t = map_filter I (map specialise_intro intros)
    95     val specialised_intros_t = map_filter I (map specialise_intro intros)
    92     val thy' = Sign.add_consts_i [(Binding.name (Long_Name.base_name constname), constT, NoSyn)] thy
    96     val thy' =
       
    97       Sign.add_consts_i [(Binding.name (Long_Name.base_name constname), constT, NoSyn)] thy
    93     val specialised_intros = map (Skip_Proof.make_thm thy') specialised_intros_t
    98     val specialised_intros = map (Skip_Proof.make_thm thy') specialised_intros_t
    94     val exported_intros = Variable.exportT ctxt' ctxt specialised_intros
    99     val exported_intros = Variable.exportT ctxt' ctxt specialised_intros
    95     val [t, specialised_t] = Variable.exportT_terms ctxt' ctxt
   100     val [t, specialised_t] = Variable.exportT_terms ctxt' ctxt
    96       [list_comb (pred, pats), list_comb (specialised_const, result_pats)]
   101       [list_comb (pred, pats), list_comb (specialised_const, result_pats)]
    97     val thy'' = Specialisations.map (Item_Net.update (t, specialised_t)) thy'
   102     val thy'' = Specialisations.map (Item_Net.update (t, specialised_t)) thy'
   121       in
   126       in
   122         (free :: ts', free_names'')
   127         (free :: ts', free_names'')
   123       end
   128       end
   124     and restrict_pattern' thy [] free_names = ([], free_names)
   129     and restrict_pattern' thy [] free_names = ([], free_names)
   125       | restrict_pattern' thy ((T, Free (x, _)) :: Tts) free_names =
   130       | restrict_pattern' thy ((T, Free (x, _)) :: Tts) free_names =
   126       let
   131           let
   127         val (ts', free_names') = restrict_pattern' thy Tts free_names
   132             val (ts', free_names') = restrict_pattern' thy Tts free_names
   128       in
   133           in
   129         (Free (x, T) :: ts', free_names')
   134             (Free (x, T) :: ts', free_names')
   130       end
   135           end
   131       | restrict_pattern' thy ((T as TFree _, t) :: Tts) free_names =
   136       | restrict_pattern' thy ((T as TFree _, t) :: Tts) free_names =
   132         replace_term_and_restrict thy T t Tts free_names
   137           replace_term_and_restrict thy T t Tts free_names
   133       | restrict_pattern' thy ((T as Type (Tcon, _), t) :: Tts) free_names =
   138       | restrict_pattern' thy ((T as Type (Tcon, _), t) :: Tts) free_names =
   134         case Ctr_Sugar.ctr_sugar_of ctxt Tcon of
   139         case Ctr_Sugar.ctr_sugar_of ctxt Tcon of
   135           NONE => replace_term_and_restrict thy T t Tts free_names
   140           NONE => replace_term_and_restrict thy T t Tts free_names
   136         | SOME {ctrs, ...} => (case strip_comb t of
   141         | SOME {ctrs, ...} =>
   137           (Const (s, _), ats) =>
   142           (case strip_comb t of
   138           (case AList.lookup (op =) (map_filter (try dest_Const) ctrs) s of
   143             (Const (s, _), ats) =>
   139             SOME constr_T =>
   144               (case AList.lookup (op =) (map_filter (try dest_Const) ctrs) s of
   140               let
   145                 SOME constr_T =>
   141                 val (Ts', T') = strip_type constr_T
   146                   let
   142                 val Tsubst = Type.raw_match (T', T) Vartab.empty
   147                     val (Ts', T') = strip_type constr_T
   143                 val Ts = map (Envir.subst_type Tsubst) Ts'
   148                     val Tsubst = Type.raw_match (T', T) Vartab.empty
   144                 val (bts', free_names') = restrict_pattern' thy ((Ts ~~ ats) @ Tts) free_names
   149                     val Ts = map (Envir.subst_type Tsubst) Ts'
   145                 val (ats', ts') = chop (length ats) bts'
   150                     val (bts', free_names') = restrict_pattern' thy ((Ts ~~ ats) @ Tts) free_names
   146               in
   151                     val (ats', ts') = chop (length ats) bts'
   147                 (list_comb (Const (s, map fastype_of ats' ---> T), ats') :: ts', free_names')
   152                   in
   148               end
   153                     (list_comb (Const (s, map fastype_of ats' ---> T), ats') :: ts', free_names')
   149             | NONE => replace_term_and_restrict thy T t Tts free_names))
   154                   end
       
   155               | NONE => replace_term_and_restrict thy T t Tts free_names))
   150     fun restrict_pattern thy Ts args =
   156     fun restrict_pattern thy Ts args =
   151       let
   157       let
   152         val args = map Logic.unvarify_global args
   158         val args = map Logic.unvarify_global args
   153         val Ts = map Logic.unvarifyT_global Ts
   159         val Ts = map Logic.unvarifyT_global Ts
   154         val free_names = fold Term.add_free_names args []
   160         val free_names = fold Term.add_free_names args []
   155         val (pat, _) = restrict_pattern' thy (Ts ~~ args) free_names
   161         val (pat, _) = restrict_pattern' thy (Ts ~~ args) free_names
   156       in map Logic.varify_global pat end
   162       in map Logic.varify_global pat end
   157     fun detect' atom thy =
   163     fun detect' atom thy =
   158       case strip_comb atom of
   164       (case strip_comb atom of
   159         (pred as Const (pred_name, _), args) =>
   165         (pred as Const (pred_name, _), args) =>
   160           let
   166           let
   161           val Ts = binder_types (Sign.the_const_type thy pred_name)
   167             val Ts = binder_types (Sign.the_const_type thy pred_name)
   162           val pats = restrict_pattern thy Ts args
   168             val pats = restrict_pattern thy Ts args
   163         in
   169           in
   164           if (exists (is_nontrivial_constrt thy) pats)
   170             if (exists (is_nontrivial_constrt thy) pats)
   165             orelse (has_duplicates (op =) (fold add_vars pats [])) then
   171               orelse (has_duplicates (op =) (fold add_vars pats [])) then
   166             let
   172               let
   167               val thy' =
   173                 val thy' =
   168                 case specialisation_of thy atom of
   174                   (case specialisation_of thy atom of
   169                   [] =>
   175                     [] =>
   170                     if member (op =) ((map fst specs) @ black_list) pred_name then
   176                       if member (op =) ((map fst specs) @ black_list) pred_name then
   171                       thy
   177                         thy
   172                     else
   178                       else
   173                       (case try (Core_Data.intros_of (Proof_Context.init_global thy)) pred_name of
   179                         (case try (Core_Data.intros_of (Proof_Context.init_global thy)) pred_name of
   174                         NONE => thy
   180                           NONE => thy
   175                       | SOME [] => thy
   181                         | SOME [] => thy
   176                       | SOME intros =>
   182                         | SOME intros =>
   177                           specialise_intros ((map fst specs) @ (pred_name :: black_list))
   183                             specialise_intros ((map fst specs) @ (pred_name :: black_list))
   178                             (pred, intros) pats thy)
   184                               (pred, intros) pats thy)
   179                   | _ :: _ => thy
   185                   | _ :: _ => thy)
   180                 val atom' =
   186                 val atom' =
   181                   case specialisation_of thy' atom of
   187                   (case specialisation_of thy' atom of
   182                     [] => atom
   188                     [] => atom
   183                   | (t, specialised_t) :: _ =>
   189                   | (t, specialised_t) :: _ =>
   184                     let
   190                     let
   185                       val subst = Pattern.match thy' (t, atom) (Vartab.empty, Vartab.empty)
   191                       val subst = Pattern.match thy' (t, atom) (Vartab.empty, Vartab.empty)
   186                     in Envir.subst_term subst specialised_t end handle Pattern.MATCH => atom
   192                     in Envir.subst_term subst specialised_t end handle Pattern.MATCH => atom)
   187                     (*FIXME: this exception could be caught earlier in specialisation_of *)
   193                     (*FIXME: this exception could be handled earlier in specialisation_of *)
   188             in
   194               in
   189               (atom', thy')
   195                 (atom', thy')
   190             end
   196               end
   191           else (atom, thy)
   197             else (atom, thy)
   192         end
   198           end
   193       | _ => (atom, thy)
   199       | _ => (atom, thy))
   194     fun specialise' (constname, intros) thy =
   200     fun specialise' (constname, intros) thy =
   195       let
   201       let
   196         (* FIXME: only necessary because of sloppy Logic.unvarify in restrict_pattern *)
   202         (* FIXME: only necessary because of sloppy Logic.unvarify in restrict_pattern *)
   197         val intros = Drule.zero_var_indexes_list intros
   203         val intros = Drule.zero_var_indexes_list intros
   198         val (intros_t', thy') = (fold_map o fold_map_atoms) detect' (map prop_of intros) thy
   204         val (intros_t', thy') = (fold_map o fold_map_atoms) detect' (map prop_of intros) thy
   201       end
   207       end
   202   in
   208   in
   203     fold_map specialise' specs thy
   209     fold_map specialise' specs thy
   204   end
   210   end
   205 
   211 
   206 end;
   212 end