src/HOL/Tools/Nitpick/nitpick_hol.ML
changeset 55080 b7c41accbff2
parent 55017 2df6ad1dbd66
child 55414 eab03e9cee8a
--- a/src/HOL/Tools/Nitpick/nitpick_hol.ML	Mon Jan 20 19:05:25 2014 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_hol.ML	Mon Jan 20 19:51:56 2014 +0100
@@ -110,7 +110,6 @@
   val strip_n_binders : int -> typ -> typ list * typ
   val nth_range_type : int -> typ -> typ
   val num_factors_in_type : typ -> int
-  val num_binder_types : typ -> int
   val curried_binder_types : typ -> typ list
   val mk_flat_tuple : typ -> term list -> term
   val dest_n_tuple : int -> term -> term list
@@ -293,6 +292,7 @@
 datatype boxability =
   InConstr | InSel | InExpr | InPair | InFunLHS | InFunRHS1 | InFunRHS2
 
+(* FIXME: Get rid of 'codatatypes' and related functionality *)
 structure Data = Generic_Data
 (
   type T = {frac_types: (string * (string * string) list) list,
@@ -543,9 +543,6 @@
 fun num_factors_in_type (Type (@{type_name prod}, [T1, T2])) =
     fold (Integer.add o num_factors_in_type) [T1, T2] 0
   | num_factors_in_type _ = 1
-fun num_binder_types (Type (@{type_name fun}, [_, T2])) =
-    1 + num_binder_types T2
-  | num_binder_types _ = 0
 val curried_binder_types = maps HOLogic.flatten_tupleT o binder_types
 fun maybe_curried_binder_types T =
   (if is_pair_type (body_type T) then binder_types else curried_binder_types) T
@@ -595,8 +592,8 @@
                  @{type_name int}, @{type_name natural}, @{type_name integer}] s orelse
   (s = @{type_name nat} andalso is_standard_datatype thy stds nat_T)
 
-fun repair_constr_type thy body_T' T =
-  varify_and_instantiate_type_global thy (body_type T) body_T' T
+fun repair_constr_type (Type (_, Ts)) T =
+  snd (dest_Const (Ctr_Sugar.mk_ctr Ts (Const (Name.uu, T))))
 
 fun register_frac_type_generic frac_s ersaetze generic =
   let
@@ -627,39 +624,41 @@
   register_ersatz_generic ersatz
 val register_ersatz_global = Context.theory_map o register_ersatz_generic
 
-fun register_codatatype_generic co_T case_name constr_xs generic =
+fun register_codatatype_generic coT case_name constr_xs generic =
   let
     val thy = Context.theory_of generic
     val {frac_types, ersatz_table, codatatypes} = Data.get generic
-    val constr_xs = map (apsnd (repair_constr_type thy co_T)) constr_xs
-    val (co_s, co_Ts) = dest_Type co_T
+    val constr_xs = map (apsnd (repair_constr_type coT)) constr_xs
+    val (co_s, coTs) = dest_Type coT
     val _ =
-      if forall is_TFree co_Ts andalso not (has_duplicates (op =) co_Ts) andalso
+      if forall is_TFree coTs andalso not (has_duplicates (op =) coTs) andalso
          co_s <> @{type_name fun} andalso
          not (is_basic_datatype thy [(NONE, true)] co_s) then
         ()
       else
-        raise TYPE ("Nitpick_HOL.register_codatatype_generic", [co_T], [])
+        raise TYPE ("Nitpick_HOL.register_codatatype_generic", [coT], [])
     val codatatypes = AList.update (op =) (co_s, (case_name, constr_xs))
                                    codatatypes
   in Data.put {frac_types = frac_types, ersatz_table = ersatz_table,
                codatatypes = codatatypes} generic end
 (* TODO: Consider morphism. *)
-fun register_codatatype co_T case_name constr_xs (_ : morphism) =
-  register_codatatype_generic co_T case_name constr_xs
+fun register_codatatype coT case_name constr_xs (_ : morphism) =
+  register_codatatype_generic coT case_name constr_xs
 val register_codatatype_global =
   Context.theory_map ooo register_codatatype_generic
 
-fun unregister_codatatype_generic co_T = register_codatatype_generic co_T "" []
+fun unregister_codatatype_generic coT = register_codatatype_generic coT "" []
 (* TODO: Consider morphism. *)
-fun unregister_codatatype co_T (_ : morphism) =
-  unregister_codatatype_generic co_T
+fun unregister_codatatype coT (_ : morphism) =
+  unregister_codatatype_generic coT
 val unregister_codatatype_global =
   Context.theory_map o unregister_codatatype_generic
 
 fun is_codatatype ctxt (Type (s, _)) =
-    s |> AList.lookup (op =) (#codatatypes (Data.get (Context.Proof ctxt)))
-      |> Option.map snd |> these |> null |> not
+    Option.map #fp (BNF_FP_Def_Sugar.fp_sugar_of ctxt s)
+        = SOME BNF_FP_Util.Greatest_FP orelse
+    not (null (these (Option.map snd (AList.lookup (op =)
+                            (#codatatypes (Data.get (Context.Proof ctxt))) s))))
   | is_codatatype _ _ = false
 fun is_registered_type ctxt T = is_frac_type ctxt T orelse is_codatatype ctxt T
 fun is_real_quot_type ctxt (Type (s, _)) =
@@ -782,15 +781,24 @@
     raise TYPE ("Nitpick_HOL.equiv_relation_for_quot_type", [T], [])
 
 fun is_coconstr ctxt (s, T) =
-  let val thy = Proof_Context.theory_of ctxt in
-    case body_type T of
-      co_T as Type (co_s, _) =>
-      let val {codatatypes, ...} = Data.get (Context.Proof ctxt) in
-        exists (fn (s', T') => s = s' andalso repair_constr_type thy co_T T' = T)
-               (AList.lookup (op =) codatatypes co_s |> Option.map snd |> these)
-      end
-    | _ => false
-  end
+  case body_type T of
+    coT as Type (co_s, _) =>
+    let
+      val ctrs1 =
+        co_s
+        |> AList.lookup (op =) (#codatatypes (Data.get (Context.Proof ctxt)))
+        |> Option.map snd |> these
+      val ctrs2 =
+        (case BNF_FP_Def_Sugar.fp_sugar_of ctxt co_s of
+           SOME (fp_sugar as {fp = BNF_FP_Util.Greatest_FP, ...}) =>
+           map dest_Const
+               (#ctrs (BNF_FP_Def_Sugar.of_fp_sugar #ctr_sugars fp_sugar))
+         | _ => [])
+    in
+      exists (fn (s', T') => s = s' andalso repair_constr_type coT T' = T)
+             (ctrs1 @ ctrs2)
+    end
+  | _ => false
 fun is_constr_like ctxt (s, T) =
   member (op =) [@{const_name FunBox}, @{const_name PairBox},
                  @{const_name Quot}, @{const_name Zero_Rep},
@@ -924,44 +932,49 @@
                               (T as Type (s, Ts)) =
     (case AList.lookup (op =) (#codatatypes (Data.get (Context.Proof ctxt)))
                        s of
-       SOME (_, xs' as (_ :: _)) => map (apsnd (repair_constr_type thy T)) xs'
+       SOME (_, xs' as (_ :: _)) => map (apsnd (repair_constr_type T)) xs'
      | _ =>
-       if is_frac_type ctxt T then
-         case typedef_info ctxt s of
-           SOME {abs_type, rep_type, Abs_name, ...} =>
-           [(Abs_name,
-             varify_and_instantiate_type ctxt abs_type T rep_type --> T)]
-         | NONE => [] (* impossible *)
-       else if is_datatype ctxt stds T then
-         case Datatype.get_info thy s of
-           SOME {index, descr, ...} =>
-           let
-             val (_, dtyps, constrs) = AList.lookup (op =) descr index |> the
-           in
-             map (apsnd (fn Us =>
-                            map (typ_of_dtyp descr (dtyps ~~ Ts)) Us ---> T))
-                 constrs
-           end
-         | NONE =>
-           if is_record_type T then
-             let
-               val s' = unsuffix Record.ext_typeN s ^ Record.extN
-               val T' = (Record.get_extT_fields thy T
-                        |> apsnd single |> uncurry append |> map snd) ---> T
-             in [(s', T')] end
-           else if is_real_quot_type ctxt T then
-             [(@{const_name Quot}, rep_type_for_quot_type ctxt T --> T)]
-           else case typedef_info ctxt s of
-             SOME {abs_type, rep_type, Abs_name, ...} =>
-             [(Abs_name,
-               varify_and_instantiate_type ctxt abs_type T rep_type --> T)]
-           | NONE =>
-             if T = @{typ ind} then
-               [dest_Const @{const Zero_Rep}, dest_Const @{const Suc_Rep}]
-             else
-               []
-       else
-         [])
+       (case BNF_FP_Def_Sugar.fp_sugar_of ctxt s of
+          SOME (fp_sugar as {fp = BNF_FP_Util.Greatest_FP, ...}) =>
+          map (apsnd (repair_constr_type T) o dest_Const)
+              (#ctrs (BNF_FP_Def_Sugar.of_fp_sugar #ctr_sugars fp_sugar))
+        | _ =>
+          if is_frac_type ctxt T then
+            case typedef_info ctxt s of
+              SOME {abs_type, rep_type, Abs_name, ...} =>
+              [(Abs_name,
+                varify_and_instantiate_type ctxt abs_type T rep_type --> T)]
+            | NONE => [] (* impossible *)
+          else if is_datatype ctxt stds T then
+            case Datatype.get_info thy s of
+              SOME {index, descr, ...} =>
+              let
+                val (_, dtyps, constrs) = AList.lookup (op =) descr index |> the
+              in
+                map (apsnd (fn Us =>
+                               map (typ_of_dtyp descr (dtyps ~~ Ts)) Us ---> T))
+                    constrs
+              end
+            | NONE =>
+              if is_record_type T then
+                let
+                  val s' = unsuffix Record.ext_typeN s ^ Record.extN
+                  val T' = (Record.get_extT_fields thy T
+                           |> apsnd single |> uncurry append |> map snd) ---> T
+                in [(s', T')] end
+              else if is_real_quot_type ctxt T then
+                [(@{const_name Quot}, rep_type_for_quot_type ctxt T --> T)]
+              else case typedef_info ctxt s of
+                SOME {abs_type, rep_type, Abs_name, ...} =>
+                [(Abs_name,
+                  varify_and_instantiate_type ctxt abs_type T rep_type --> T)]
+              | NONE =>
+                if T = @{typ ind} then
+                  [dest_Const @{const Zero_Rep}, dest_Const @{const Suc_Rep}]
+                else
+                  []
+          else
+            []))
   | uncached_datatype_constrs _ _ = []
 fun datatype_constrs (hol_ctxt as {constr_cache, ...}) T =
   case AList.lookup (op =) (!constr_cache) T of
@@ -1451,7 +1464,13 @@
                       cons (case_name, AList.lookup (op =) descr index
                                        |> the |> #3 |> length))
                 (Datatype.get_all thy) [] @
-    map (apsnd length o snd) (#codatatypes (Data.get (Context.Proof ctxt)))
+    map (apsnd length o snd) (#codatatypes (Data.get (Context.Proof ctxt))) @
+    maps (fn {fp, ctr_sugars, ...} =>
+             if fp = BNF_FP_Util.Greatest_FP then
+               map (apsnd num_binder_types o dest_Const o #casex) ctr_sugars
+             else
+               [])
+         (BNF_FP_Def_Sugar.fp_sugars_of ctxt)
   end
 
 fun fixpoint_kind_of_const thy table x =