src/HOL/Tools/Predicate_Compile/predicate_compile_pred.ML
author wenzelm
Sat Jul 25 23:41:53 2015 +0200 (2015-07-25)
changeset 60781 2da59cdf531c
parent 60752 b48830b670a1
child 61268 abe08fb15a12
permissions -rw-r--r--
updated to infer_instantiate;
tuned;
     1 (*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_pred.ML
     2     Author:     Lukas Bulwahn, TU Muenchen
     3 
     4 Preprocessing definitions of predicates to introduction rules.
     5 *)
     6 
     7 signature PREDICATE_COMPILE_PRED =
     8 sig
     9   (* preprocesses an equation to a set of intro rules; defines new constants *)
    10   val preprocess : Predicate_Compile_Aux.options -> (string * thm list) -> theory
    11     -> ((string * thm list) list * theory) 
    12   val flat_higher_order_arguments : ((string * thm list) list * theory)
    13     -> ((string * thm list) list * ((string * thm list) list * theory))
    14 end;
    15 
    16 
    17 structure Predicate_Compile_Pred : PREDICATE_COMPILE_PRED =
    18 struct
    19 
    20 open Predicate_Compile_Aux
    21 
    22 fun is_compound ((Const (@{const_name Not}, _)) $ _) =
    23       error "is_compound: Negation should not occur; preprocessing is defect"
    24   | is_compound ((Const (@{const_name Ex}, _)) $ _) = true
    25   | is_compound ((Const (@{const_name HOL.disj}, _)) $ _ $ _) = true
    26   | is_compound ((Const (@{const_name HOL.conj}, _)) $ _ $ _) =
    27       error "is_compound: Conjunction should not occur; preprocessing is defect"
    28   | is_compound _ = false
    29 
    30 fun try_destruct_case thy names atom =
    31   (case find_split_thm thy (fst (strip_comb atom)) of
    32     NONE => NONE
    33   | SOME raw_split_thm =>
    34     let
    35       val split_thm = prepare_split_thm (Proof_Context.init_global thy) raw_split_thm
    36       (* TODO: contextify things - this line is to unvarify the split_thm *)
    37       (*val ((_, [isplit_thm]), _) =
    38         Variable.import true [split_thm] (Proof_Context.init_global thy)*)
    39       val (assms, concl) = Logic.strip_horn (Thm.prop_of split_thm)
    40       val (_, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl) 
    41       val atom' = case_betapply thy atom
    42       val subst = Pattern.match thy (split_t, atom') (Vartab.empty, Vartab.empty)
    43       val names' = Term.add_free_names atom' names
    44       fun mk_subst_rhs assm =
    45         let
    46           val (vTs, assm') = strip_all (Envir.beta_norm (Envir.subst_term subst assm))
    47           val var_names = Name.variant_list names' (map fst vTs)
    48           val vars = map Free (var_names ~~ (map snd vTs))
    49           val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
    50           fun partition_prem_subst prem =
    51             (case HOLogic.dest_eq (HOLogic.dest_Trueprop prem) of
    52               (Free (x, T), r) => (NONE, SOME ((x, T), r))
    53             | _ => (SOME prem, NONE))
    54           fun partition f xs =
    55             let
    56               fun partition' acc1 acc2 [] = (rev acc1, rev acc2)
    57                 | partition' acc1 acc2 (x :: xs) =
    58                   let
    59                     val (y, z) = f x
    60                     val acc1' = (case y of NONE => acc1 | SOME y' => y' :: acc1)
    61                     val acc2' = (case z of NONE => acc2 | SOME z' => z' :: acc2)
    62                   in partition' acc1' acc2' xs end
    63             in partition' [] [] xs end
    64           val (prems'', subst) = partition partition_prem_subst prems'
    65           val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
    66           val pre_rhs =
    67             fold (curry HOLogic.mk_conj) (map HOLogic.dest_Trueprop prems'') inner_t
    68           val rhs = Envir.expand_term_frees subst pre_rhs
    69         in
    70           (case try_destruct_case thy (var_names @ names') rhs of
    71             NONE => [(subst, rhs)]
    72           | SOME (_, srs) => map (fn (subst', rhs') => (subst @ subst', rhs')) srs)
    73         end
    74      in SOME (atom', maps mk_subst_rhs assms) end)
    75      
    76 fun flatten constname atom (defs, thy) =
    77   if is_compound atom then
    78     let
    79       val atom = Envir.beta_norm (Envir.eta_long [] atom)
    80       val constname =
    81         singleton (Name.variant_list (map (Long_Name.base_name o fst) defs))
    82           ((Long_Name.base_name constname) ^ "_aux")
    83       val full_constname = Sign.full_bname thy constname
    84       val (params, args) = List.partition (is_predT o fastype_of)
    85         (map Free (Term.add_frees atom []))
    86       val constT = map fastype_of (params @ args) ---> HOLogic.boolT
    87       val lhs = list_comb (Const (full_constname, constT), params @ args)
    88       val def = Logic.mk_equals (lhs, atom)
    89       val ([definition], thy') = thy
    90         |> Sign.add_consts [(Binding.name constname, constT, NoSyn)]
    91         |> Global_Theory.add_defs false [((Binding.name (Thm.def_name constname), def), [])]
    92     in
    93       (lhs, ((full_constname, [definition]) :: defs, thy'))
    94     end
    95   else
    96     (case (fst (strip_comb atom)) of
    97       (Const (@{const_name If}, _)) =>
    98         let
    99           val if_beta = @{lemma "(if c then x else y) z = (if c then x z else y z)" by simp}
   100           val atom' = Raw_Simplifier.rewrite_term thy
   101             (map (fn th => th RS @{thm eq_reflection}) [@{thm if_bool_eq_disj}, if_beta]) [] atom
   102           val _ = @{assert} (not (atom = atom'))
   103         in
   104           flatten constname atom' (defs, thy)
   105         end
   106     | _ =>
   107         (case try_destruct_case thy [] atom of
   108           NONE => (atom, (defs, thy))
   109         | SOME (atom', srs) =>
   110             let      
   111               val frees = map Free (Term.add_frees atom' [])
   112               val constname =
   113                 singleton (Name.variant_list (map (Long_Name.base_name o fst) defs))
   114                   ((Long_Name.base_name constname) ^ "_aux")
   115               val full_constname = Sign.full_bname thy constname
   116               val constT = map fastype_of frees ---> HOLogic.boolT
   117               val lhs = list_comb (Const (full_constname, constT), frees)
   118               fun mk_def (subst, rhs) =
   119                 Logic.mk_equals (fold Envir.expand_term_frees (map single subst) lhs, rhs)
   120               val new_defs = map mk_def srs
   121               val (definition, thy') = thy
   122               |> Sign.add_consts [(Binding.name constname, constT, NoSyn)]
   123               |> fold_map Specification.axiom  (* FIXME !?!?!?! *)
   124                 (map_index (fn (i, t) =>
   125                   ((Binding.name (constname ^ "_def" ^ string_of_int i), []), t)) new_defs)
   126             in
   127               (lhs, ((full_constname, map Drule.export_without_context definition) :: defs, thy'))
   128             end))
   129 
   130 fun flatten_intros constname intros thy =
   131   let
   132     val ctxt = Proof_Context.init_global thy  (* FIXME proper context!? *)
   133     val ((_, intros), ctxt') = Variable.import true intros ctxt
   134     val (intros', (local_defs, thy')) = (fold_map o fold_map_atoms)
   135       (flatten constname) (map Thm.prop_of intros) ([], thy)
   136     val ctxt'' = Proof_Context.transfer thy' ctxt'
   137     val intros'' =
   138       map (fn t => Goal.prove ctxt'' [] [] t (fn _ => ALLGOALS (Skip_Proof.cheat_tac ctxt''))) intros'
   139       |> Variable.export ctxt'' ctxt
   140   in
   141     (intros'', (local_defs, thy'))
   142   end
   143 
   144 fun introrulify ctxt ths = 
   145   let
   146     val ((_, ths'), ctxt') = Variable.import true ths ctxt
   147     fun introrulify' th =
   148       let
   149         val (lhs, rhs) = Logic.dest_equals (Thm.prop_of th)
   150         val frees = Term.add_free_names rhs []
   151         val disjuncts = HOLogic.dest_disj rhs
   152         val nctxt = Name.make_context frees
   153         fun mk_introrule t =
   154           let
   155             val ((ps, t'), _) = focus_ex t nctxt
   156             val prems = map HOLogic.mk_Trueprop (HOLogic.dest_conj t')
   157           in
   158             (ps, Logic.list_implies (prems, HOLogic.mk_Trueprop lhs))
   159           end
   160         val Var (x, _) =
   161           (the_single o snd o strip_comb o HOLogic.dest_Trueprop o fst o
   162             Logic.dest_implies o Thm.prop_of) @{thm exI}
   163         fun prove_introrule (index, (ps, introrule)) =
   164           let
   165             val tac = Simplifier.simp_tac (put_simpset HOL_basic_ss ctxt' addsimps [th]) 1
   166               THEN Inductive.select_disj_tac ctxt' (length disjuncts) (index + 1) 1
   167               THEN (EVERY (map (fn y =>
   168                 resolve_tac ctxt'
   169                   [infer_instantiate ctxt' [(x, Thm.cterm_of ctxt' (Free y))] @{thm exI}] 1) ps))
   170               THEN REPEAT_DETERM (resolve_tac ctxt' @{thms conjI} 1 THEN assume_tac ctxt' 1)
   171               THEN TRY (assume_tac ctxt' 1)
   172           in
   173             Goal.prove ctxt' (map fst ps) [] introrule (fn _ => tac)
   174           end
   175       in
   176         map_index prove_introrule (map mk_introrule disjuncts)
   177       end
   178   in maps introrulify' ths' |> Variable.export ctxt' ctxt end
   179 
   180 fun rewrite ctxt =
   181   Simplifier.simplify (put_simpset HOL_basic_ss ctxt addsimps [@{thm Ball_def}, @{thm Bex_def}])
   182   #> Simplifier.simplify (put_simpset HOL_basic_ss ctxt addsimps [@{thm all_not_ex}])
   183   #> Conv.fconv_rule (nnf_conv ctxt)
   184   #> Simplifier.simplify (put_simpset HOL_basic_ss ctxt addsimps [@{thm ex_disj_distrib}])
   185 
   186 fun rewrite_intros ctxt =
   187   Simplifier.full_simplify (put_simpset HOL_basic_ss ctxt addsimps [@{thm all_not_ex}])
   188   #> Simplifier.full_simplify
   189     (put_simpset HOL_basic_ss ctxt
   190       addsimps (tl @{thms bool_simps}) addsimps @{thms nnf_simps})
   191   #> split_conjuncts_in_assms ctxt
   192 
   193 fun print_specs options thy msg ths =
   194   if show_intermediate_results options then
   195     (tracing (msg); tracing (commas (map (Display.string_of_thm_global thy) ths)))
   196   else
   197     ()
   198 
   199 fun preprocess options (constname, specs) thy =
   200 (*  case Predicate_Compile_Data.processed_specs thy constname of
   201     SOME specss => (specss, thy)
   202   | NONE =>*)
   203     let
   204       val ctxt = Proof_Context.init_global thy  (* FIXME proper context!? *)
   205       val intros =
   206         if forall is_pred_equation specs then 
   207           map (split_conjuncts_in_assms ctxt) (introrulify ctxt (map (rewrite ctxt) specs))
   208         else if forall (is_intro constname) specs then
   209           map (rewrite_intros ctxt) specs
   210         else
   211           error ("unexpected specification for constant " ^ quote constname ^ ":\n"
   212             ^ commas (map (quote o Display.string_of_thm_global thy) specs))
   213       val if_beta = @{lemma "(if c then x else y) z = (if c then x z else y z)" by simp}
   214       val intros = map (rewrite_rule ctxt [if_beta RS @{thm eq_reflection}]) intros
   215       val _ = print_specs options thy "normalized intros" intros
   216       (*val intros = maps (split_cases thy) intros*)
   217       val (intros', (local_defs, thy')) = flatten_intros constname intros thy
   218       val (intross, thy'') = fold_map (preprocess options) local_defs thy'
   219       val full_spec = (constname, intros') :: flat intross
   220       (*val thy''' = Predicate_Compile_Data.store_processed_specs (constname, full_spec) thy''*)
   221     in
   222       (full_spec, thy'')
   223     end;
   224 
   225 fun flat_higher_order_arguments (intross, thy) =
   226   let
   227     fun process constname atom (new_defs, thy) =
   228       let
   229         val (pred, args) = strip_comb atom
   230         fun replace_abs_arg (abs_arg as Abs _ ) (new_defs, thy) =
   231           let
   232             val vars = map Var (Term.add_vars abs_arg [])
   233             val abs_arg' = Logic.unvarify_global abs_arg
   234             val frees = map Free (Term.add_frees abs_arg' [])
   235             val constname =
   236               singleton (Name.variant_list (map (Long_Name.base_name o fst) new_defs))
   237                 ((Long_Name.base_name constname) ^ "_hoaux")
   238             val full_constname = Sign.full_bname thy constname
   239             val constT = map fastype_of frees ---> (fastype_of abs_arg')
   240             val const = Const (full_constname, constT)
   241             val lhs = list_comb (const, frees)
   242             val def = Logic.mk_equals (lhs, abs_arg')
   243             val ([definition], thy') = thy
   244               |> Sign.add_consts [(Binding.name constname, constT, NoSyn)]
   245               |> Global_Theory.add_defs false [((Binding.name (Thm.def_name constname), def), [])]
   246           in
   247             (list_comb (Logic.varify_global const, vars),
   248               ((full_constname, [definition])::new_defs, thy'))
   249           end
   250         | replace_abs_arg arg (new_defs, thy) =
   251           if is_some (try HOLogic.dest_prodT (fastype_of arg)) then
   252             (case try HOLogic.dest_prod arg of
   253               SOME (t1, t2) =>
   254                 (new_defs, thy)
   255                 |> process constname t1 
   256                 ||>> process constname t2
   257                 |>> HOLogic.mk_prod
   258             | NONE =>
   259               (warning ("Replacing higher order arguments " ^
   260                 "is not applied in an undestructable product type"); (arg, (new_defs, thy))))
   261           else if (is_predT (fastype_of arg)) then
   262             process constname arg (new_defs, thy)
   263           else
   264             (arg, (new_defs, thy))
   265 
   266         val (args', (new_defs', thy')) = fold_map replace_abs_arg
   267           (map Envir.beta_eta_contract args) (new_defs, thy)
   268       in
   269         (list_comb (pred, args'), (new_defs', thy'))
   270       end
   271     fun flat_intro intro (new_defs, thy) =
   272       let
   273         val constname = fst (dest_Const (fst (strip_comb
   274           (HOLogic.dest_Trueprop (Logic.strip_imp_concl (Thm.prop_of intro))))))
   275         val (intro_ts, (new_defs, thy)) =
   276           fold_map_atoms (process constname) (Thm.prop_of intro) (new_defs, thy)
   277         val th = Skip_Proof.make_thm thy intro_ts
   278       in
   279         (th, (new_defs, thy))
   280       end
   281     fun fold_map_spec f [] s = ([], s)
   282       | fold_map_spec f ((c, ths) :: specs) s =
   283         let
   284           val (ths', s') = f ths s
   285           val (specs', s'') = fold_map_spec f specs s'
   286         in ((c, ths') :: specs', s'') end
   287     val (intross', (new_defs, thy')) = fold_map_spec (fold_map flat_intro) intross ([], thy)
   288   in
   289     (intross', (new_defs, thy'))
   290   end
   291 
   292 end