ported predicate compiler to 'ctr_sugar'
authorblanchet
Wed, 12 Feb 2014 08:35:56 +0100
changeset 55399 5c8e91f884af
parent 55398 67e9fdd9ae9e
child 55400 1e8dd9cd320b
ported predicate compiler to 'ctr_sugar' * * * ported predicate compiler to 'ctr_sugar', part 2
src/HOL/Tools/Predicate_Compile/code_prolog.ML
src/HOL/Tools/Predicate_Compile/mode_inference.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_data.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_proof.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML
--- a/src/HOL/Tools/Predicate_Compile/code_prolog.ML	Wed Feb 12 08:35:56 2014 +0100
+++ b/src/HOL/Tools/Predicate_Compile/code_prolog.ML	Wed Feb 12 08:35:56 2014 +0100
@@ -511,13 +511,6 @@
 
 fun mk_lim_relname T = "lim_" ^  mk_relname T
 
-(* This is copied from "pat_completeness.ML" *)
-fun inst_constrs_of thy (T as Type (name, _)) =
-  map (fn (Cn,CT) =>
-    Envir.subst_term_types (Sign.typ_match thy (body_type CT, T) Vartab.empty) (Const (Cn, CT)))
-    (the (Datatype.get_constrs thy name))
-  | inst_constrs_of thy T = raise TYPE ("inst_constrs_of", [T], [])
-
 fun is_recursive_constr T (Const (constr_name, T')) = member (op =) (binder_types T') T
   
 fun mk_ground_impl ctxt limited_types (T as Type (Tcon, Targs)) (seen, constant_table) =
@@ -549,7 +542,7 @@
         in
           (clause :: flat rec_clauses, (seen', constant_table''))
         end
-      val constrs = inst_constrs_of (Proof_Context.theory_of ctxt) T
+      val constrs = Function_Lib.inst_constrs_of (Proof_Context.theory_of ctxt) T
       val constrs' = (constrs ~~ map (is_recursive_constr T) constrs)
         |> (fn cs => filter_out snd cs @ filter snd cs)
       val (clauses, constant_table') =
--- a/src/HOL/Tools/Predicate_Compile/mode_inference.ML	Wed Feb 12 08:35:56 2014 +0100
+++ b/src/HOL/Tools/Predicate_Compile/mode_inference.ML	Wed Feb 12 08:35:56 2014 +0100
@@ -183,19 +183,19 @@
   | collect_non_invertible_subterms ctxt t (names, eqs) =
     case (strip_comb t) of (f, args) =>
       if is_invertible_function ctxt f then
-          let
-            val (args', (names', eqs')) =
-              fold_map (collect_non_invertible_subterms ctxt) args (names, eqs)
-          in
-            (list_comb (f, args'), (names', eqs'))
-          end
-        else
-          let
-            val s = singleton (Name.variant_list names) "x"
-            val v = Free (s, fastype_of t)
-          in
-            (v, (s :: names, HOLogic.mk_eq (v, t) :: eqs))
-          end
+        let
+          val (args', (names', eqs')) =
+            fold_map (collect_non_invertible_subterms ctxt) args (names, eqs)
+        in
+          (list_comb (f, args'), (names', eqs'))
+        end
+      else
+        let
+          val s = singleton (Name.variant_list names) "x"
+          val v = Free (s, fastype_of t)
+        in
+          (v, (s :: names, HOLogic.mk_eq (v, t) :: eqs))
+        end
 (*
   if is_constrt thy t then (t, (names, eqs)) else
     let
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Wed Feb 12 08:35:56 2014 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Wed Feb 12 08:35:56 2014 +0100
@@ -47,6 +47,7 @@
   val is_pred_equation : thm -> bool
   val is_intro : string -> thm -> bool
   val is_predT : typ -> bool
+  val get_constrs : theory -> (string * (int * string)) list
   val is_constrt : theory -> term -> bool
   val is_constr : Proof.context -> string -> bool
   val strip_ex : term -> (string * typ) list * term
@@ -477,15 +478,22 @@
 fun is_predT (T as Type("fun", [_, _])) = (body_type T = @{typ bool})
   | is_predT _ = false
 
+fun get_constrs thy =
+  let
+    val ctxt = Proof_Context.init_global thy
+  in
+    Ctr_Sugar.ctr_sugars_of ctxt
+    |> maps (map_filter (try dest_Const) o #ctrs)
+    |> map (apsnd (fn T => (BNF_Util.num_binder_types T, fst (dest_Type (body_type T)))))
+  end
+
 (*** check if a term contains only constructor functions ***)
 (* TODO: another copy in the core! *)
 (* FIXME: constructor terms are supposed to be seen in the way the code generator
   sees constructors.*)
 fun is_constrt thy =
   let
-    val cnstrs = flat (maps
-      (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
-      (Symtab.dest (Datatype.get_all thy)));
+    val cnstrs = get_constrs thy
     fun check t = (case strip_comb t of
         (Var _, []) => true
       | (Free _, []) => true
@@ -495,23 +503,6 @@
       | _ => false)
   in check end;
 
-(* returns true if t is an application of an datatype constructor *)
-(* which then consequently would be splitted *)
-(* else false *)
-(*
-fun is_constructor thy t =
-  if (is_Type (fastype_of t)) then
-    (case DatatypePackage.get_datatype thy ((fst o dest_Type o fastype_of) t) of
-      NONE => false
-    | SOME info => (let
-      val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
-      val (c, _) = strip_comb t
-      in (case c of
-        Const (name, _) => name mem_string constr_consts
-        | _ => false) end))
-  else false
-*)
-
 val is_constr = Code.is_constr o Proof_Context.theory_of;
 
 fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)
@@ -601,7 +592,8 @@
     |> Local_Defs.unfold ctxt [@{thm atomize_conjL[symmetric]},
       @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
 
-fun find_split_thm thy (Const (name, _)) = Option.map #split (Datatype.info_of_case thy name)
+fun find_split_thm thy (Const (name, _)) =
+    Option.map #split (Ctr_Sugar.ctr_sugar_of_case (Proof_Context.init_global thy) name)
   | find_split_thm thy _ = NONE
 
 (* lifting term operations to theorems *)
@@ -880,76 +872,72 @@
 (** making case distributivity rules **)
 (*** this should be part of the datatype package ***)
 
-fun datatype_names_of_case_name thy case_name =
-  map (#1 o #2) (#descr (the (Datatype.info_of_case thy case_name)))
+fun datatype_name_of_case_name thy =
+  Ctr_Sugar.ctr_sugar_of_case (Proof_Context.init_global thy)
+  #> the #> #ctrs #> hd #> fastype_of #> body_type #> dest_Type #> fst
 
-fun make_case_distribs case_names descr thy =
+fun make_case_comb thy Tcon =
   let
-    val case_combs = Datatype_Prop.make_case_combs case_names descr thy "f";
-    fun make comb =
-      let
-        val Type ("fun", [T, T']) = fastype_of comb;
-        val (Const (case_name, _), fs) = strip_comb comb
-        val used = Term.add_tfree_names comb []
-        val U = TFree (singleton (Name.variant_list used) "'t", HOLogic.typeS)
-        val x = Free ("x", T)
-        val f = Free ("f", T' --> U)
-        fun apply_f f' =
-          let
-            val Ts = binder_types (fastype_of f')
-            val bs = map Bound ((length Ts - 1) downto 0)
-          in
-            fold_rev absdummy Ts (f $ (list_comb (f', bs)))
-          end
-        val fs' = map apply_f fs
-        val case_c' = Const (case_name, (map fastype_of fs') @ [T] ---> U)
-      in
-        HOLogic.mk_eq (f $ (comb $ x), list_comb (case_c', fs') $ x)
-      end
+    val ctxt = Proof_Context.init_global thy
+    val SOME {casex, ...} = Ctr_Sugar.ctr_sugar_of ctxt Tcon
+    val casex' = Type.legacy_freeze casex
+    val Ts = BNF_Util.binder_fun_types (fastype_of casex')
   in
-    map make case_combs
+    list_comb (casex', map_index (fn (j, T) => Free ("f" ^ string_of_int j,  T)) Ts)
   end
 
-fun case_rewrites thy Tcon =
+fun make_case_distrib thy Tcon =
   let
-    val {descr, case_name, ...} = Datatype.the_info thy Tcon
+    val comb = make_case_comb thy Tcon;
+    val Type ("fun", [T, T']) = fastype_of comb;
+    val (Const (case_name, _), fs) = strip_comb comb
+    val used = Term.add_tfree_names comb []
+    val U = TFree (singleton (Name.variant_list used) "'t", HOLogic.typeS)
+    val x = Free ("x", T)
+    val f = Free ("f", T' --> U)
+    fun apply_f f' =
+      let
+        val Ts = binder_types (fastype_of f')
+        val bs = map Bound ((length Ts - 1) downto 0)
+      in
+        fold_rev absdummy Ts (f $ (list_comb (f', bs)))
+      end
+    val fs' = map apply_f fs
+    val case_c' = Const (case_name, (map fastype_of fs') @ [T] ---> U)
   in
-    map (Drule.export_without_context o Skip_Proof.make_thm thy o HOLogic.mk_Trueprop)
-      (make_case_distribs [case_name] [descr] thy)
+    HOLogic.mk_eq (f $ (comb $ x), list_comb (case_c', fs') $ x)
   end
 
-fun instantiated_case_rewrites thy Tcon =
+fun case_rewrite thy Tcon =
+  (Drule.export_without_context o Skip_Proof.make_thm thy o HOLogic.mk_Trueprop)
+    (make_case_distrib thy Tcon)
+
+fun instantiated_case_rewrite thy Tcon =
   let
-    val rew_ths = case_rewrites thy Tcon
+    val th = case_rewrite thy Tcon
     val ctxt = Proof_Context.init_global thy
-    fun instantiate th =
-    let
-      val f = (fst (strip_comb (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of th))))))
-      val Type ("fun", [uninst_T, uninst_T']) = fastype_of f
-      val ([_, tname', uname, yname], ctxt') = Variable.add_fixes ["'t", "'t'", "'u", "y"] ctxt
-      val T' = TFree (tname', HOLogic.typeS)
-      val U = TFree (uname, HOLogic.typeS)
-      val y = Free (yname, U)
-      val f' = absdummy (U --> T') (Bound 0 $ y)
-      val th' = Thm.certify_instantiate
-        ([(dest_TVar uninst_T, U --> T'), (dest_TVar uninst_T', T')],
-         [((fst (dest_Var f), (U --> T') --> T'), f')]) th
-      val [th'] = Variable.export ctxt' ctxt [th']
-   in
-     th'
-   end
- in
-   map instantiate rew_ths
- end
+    val f = (fst (strip_comb (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of th))))))
+    val Type ("fun", [uninst_T, uninst_T']) = fastype_of f
+    val ([_, tname', uname, yname], ctxt') = Variable.add_fixes ["'t", "'t'", "'u", "y"] ctxt
+    val T' = TFree (tname', HOLogic.typeS)
+    val U = TFree (uname, HOLogic.typeS)
+    val y = Free (yname, U)
+    val f' = absdummy (U --> T') (Bound 0 $ y)
+    val th' = Thm.certify_instantiate
+      ([(dest_TVar uninst_T, U --> T'), (dest_TVar uninst_T', T')],
+       [((fst (dest_Var f), (U --> T') --> T'), f')]) th
+    val [th'] = Variable.export ctxt' ctxt [th']
+  in
+    th'
+  end
 
 fun case_betapply thy t =
   let
     val case_name = fst (dest_Const (fst (strip_comb t)))
-    val Tcons = datatype_names_of_case_name thy case_name
-    val ths = maps (instantiated_case_rewrites thy) Tcons
+    val Tcon = datatype_name_of_case_name thy case_name
+    val th = instantiated_case_rewrite thy Tcon
   in
-    Raw_Simplifier.rewrite_term thy
-      (map (fn th => th RS @{thm eq_reflection}) ths) [] t
+    Raw_Simplifier.rewrite_term thy [th RS @{thm eq_reflection}] [] t
   end
 
 (*** conversions ***)
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Wed Feb 12 08:35:56 2014 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Wed Feb 12 08:35:56 2014 +0100
@@ -814,13 +814,14 @@
     case T of
       TFree _ => NONE
     | Type (Tcon, _) =>
-      (case Datatype.get_constrs (Proof_Context.theory_of ctxt) Tcon of
+      (case Ctr_Sugar.ctr_sugar_of ctxt Tcon of
         NONE => NONE
-      | SOME cs =>
+      | SOME {ctrs, ...} =>
         (case strip_comb t of
           (Var _, []) => NONE
         | (Free _, []) => NONE
-        | (Const (c, T), _) => if AList.defined (op =) cs c then SOME (c, T) else NONE))
+        | (Const (c, T), _) =>
+          if AList.defined (op =) (map_filter (try dest_Const) ctrs) c then SOME (c, T) else NONE))
   end
 
 fun partition_clause ctxt pos moded_clauses =
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_data.ML	Wed Feb 12 08:35:56 2014 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_data.ML	Wed Feb 12 08:35:56 2014 +0100
@@ -260,8 +260,11 @@
     val ctxt = Proof_Context.init_global thy
     fun is_nondefining_const (c, _) = member (op =) logic_operator_names c
     fun has_code_pred_intros (c, _) = can (Core_Data.intros_of ctxt) c
-    fun case_consts (c, _) = is_some (Datatype.info_of_case thy c)
-    fun is_datatype_constructor (c, T) = is_some (Datatype.info_of_constr thy (c, T))
+    fun case_consts (c, _) = is_some (Ctr_Sugar.ctr_sugar_of_case ctxt c)
+    fun is_datatype_constructor (x as (_, T)) =
+      (case body_type T of
+        Type (Tcon, _) => can (Ctr_Sugar.dest_ctr ctxt Tcon) (Const x)
+      | _ => false)
     fun defiants_of specs =
       fold (Term.add_consts o prop_of) specs []
       |> filter_out is_datatype_constructor
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_proof.ML	Wed Feb 12 08:35:56 2014 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_proof.ML	Wed Feb 12 08:35:56 2014 +0100
@@ -46,30 +46,19 @@
 
 (* auxillary functions *)
 
-fun is_Type (Type _) = true
-  | is_Type _ = false
-
-(* returns true if t is an application of an datatype constructor *)
+(* returns true if t is an application of a datatype constructor *)
 (* which then consequently would be splitted *)
-(* else false *)
-fun is_constructor thy t =
-  if (is_Type (fastype_of t)) then
-    (case Datatype.get_info thy ((fst o dest_Type o fastype_of) t) of
-      NONE => false
-    | SOME info => (let
-      val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
-      val (c, _) = strip_comb t
-      in (case c of
-        Const (name, _) => member (op =) constr_consts name
-        | _ => false) end))
-  else false
+fun is_constructor ctxt t =
+  (case fastype_of t of
+    Type (s, _) => s <> @{type_name fun} andalso can (Ctr_Sugar.dest_ctr ctxt s) t
+  | _ => false);
 
 (* MAJOR FIXME:  prove_params should be simple
  - different form of introrule for parameters ? *)
 
 fun prove_param options ctxt nargs t deriv =
   let
-    val  (f, args) = strip_comb (Envir.eta_contract t)
+    val (f, args) = strip_comb (Envir.eta_contract t)
     val mode = head_mode_of deriv
     val param_derivations = param_derivations_of deriv
     val ho_args = ho_args_of mode args
@@ -139,15 +128,14 @@
 
 fun prove_match options ctxt nargs out_ts =
   let
-    val thy = Proof_Context.theory_of ctxt
     val eval_if_P =
       @{lemma "P ==> Predicate.eval x z ==> Predicate.eval (if P then x else y) z" by simp} 
     fun get_case_rewrite t =
-      if (is_constructor thy t) then
+      if is_constructor ctxt t then
         let
-          val {case_rewrites, ...} = Datatype.the_info thy (fst (dest_Type (fastype_of t)))
+          val SOME {case_thms, ...} = Ctr_Sugar.ctr_sugar_of ctxt (fst (dest_Type (fastype_of t)))
         in
-          fold (union Thm.eq_thm) (case_rewrites :: map get_case_rewrite (snd (strip_comb t))) []
+          fold (union Thm.eq_thm) (case_thms :: map get_case_rewrite (snd (strip_comb t))) []
         end
       else []
     val simprules = insert Thm.eq_thm @{thm "unit.cases"} (insert Thm.eq_thm @{thm "prod.cases"}
@@ -309,14 +297,13 @@
 
 fun prove_match2 options ctxt out_ts =
   let
-    val thy = Proof_Context.theory_of ctxt
     fun split_term_tac (Free _) = all_tac
       | split_term_tac t =
-        if (is_constructor thy t) then
+        if is_constructor ctxt t then
           let
-            val {case_rewrites, split_asm, ...} =
-              Datatype.the_info thy (fst (dest_Type (fastype_of t)))
-            val num_of_constrs = length case_rewrites
+            val SOME {case_thms, split_asm, ...} =
+              Ctr_Sugar.ctr_sugar_of ctxt (fst (dest_Type (fastype_of t)))
+            val num_of_constrs = length case_thms
             val (_, ts) = strip_comb t
           in
             print_tac options ("Term " ^ (Syntax.string_of_term ctxt t) ^ 
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML	Wed Feb 12 08:35:56 2014 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML	Wed Feb 12 08:35:56 2014 +0100
@@ -41,9 +41,7 @@
 (* patterns only constructed of variables and pairs/tuples are trivial constructor terms*)
 fun is_nontrivial_constrt thy t =
   let
-    val cnstrs = flat (maps
-      (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
-      (Symtab.dest (Datatype.get_all thy)));
+    val cnstrs = get_constrs thy
     fun check t = (case strip_comb t of
         (Var _, []) => (true, true)
       | (Free _, []) => (true, true)
@@ -107,6 +105,7 @@
 
 and find_specialisations black_list specs thy =
   let
+    val ctxt = Proof_Context.init_global thy
     val add_vars = fold_aterms (fn Var v => cons v | _ => I);
     fun fresh_free T free_names =
       let
@@ -132,10 +131,11 @@
       | restrict_pattern' thy ((T as TFree _, t) :: Tts) free_names =
         replace_term_and_restrict thy T t Tts free_names
       | restrict_pattern' thy ((T as Type (Tcon, _), t) :: Tts) free_names =
-        case Datatype.get_constrs thy Tcon of
+        case Ctr_Sugar.ctr_sugar_of ctxt Tcon of
           NONE => replace_term_and_restrict thy T t Tts free_names
-        | SOME constrs => (case strip_comb t of
-          (Const (s, _), ats) => (case AList.lookup (op =) constrs s of
+        | SOME {ctrs, ...} => (case strip_comb t of
+          (Const (s, _), ats) =>
+          (case AList.lookup (op =) (map_filter (try dest_Const) ctrs) s of
             SOME constr_T =>
               let
                 val (Ts', T') = strip_type constr_T