more flexible parsing (towards type class support)
authorblanchet
Tue, 31 May 2016 10:53:11 +0200
changeset 63188 38d6aabec460
parent 63187 da1cd3ce80c2
child 63189 d5974697765b
more flexible parsing (towards type class support)
src/HOL/Tools/BNF/bnf_gfp_grec_sugar.ML
--- a/src/HOL/Tools/BNF/bnf_gfp_grec_sugar.ML	Tue May 31 10:53:10 2016 +0200
+++ b/src/HOL/Tools/BNF/bnf_gfp_grec_sugar.ML	Tue May 31 10:53:11 2016 +0200
@@ -793,7 +793,7 @@
 fun parse_corec_equation ctxt fun_frees eq =
   let
     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop (drop_all eq))
-      handle TERM _ => error "Expected equation";
+      handle TERM _ => error "Expected HOL equation";
 
     val _ = check_corec_equation ctxt fun_frees (lhs, rhs);
 
@@ -1347,12 +1347,13 @@
     fun maybe_const_type ctxt (s, T) =
       Sign.const_type (Proof_Context.theory_of ctxt) s |> the_default T;
 
-    fun massage_polymorphic_const explore (params as {bound_Us, ...}) t =
+    fun massage_const polymorphic explore (params as {bound_Us, ...}) t =
       let val (fun_t, arg_ts) = strip_comb t in
         (case fun_t of
           Const (fun_x as (s, fun_T)) =>
-          let val general_T = maybe_const_type lthy fun_x in
-            if contains_res_T (body_type general_T) orelse is_constant t then
+          let val general_T = if polymorphic then maybe_const_type lthy fun_x else fun_T in
+            if fun_t aconv friend_tm orelse contains_res_T (body_type general_T) orelse
+               is_constant t then
               explore params t
             else
               let
@@ -1437,12 +1438,12 @@
                     val fun_t' = Const (s, fun_U);
                     val t' = build_function_after_encapsulation fun_t fun_t' params arg_ts arg_ts';
                   in
-                    (case try type_of1 (bound_Us, t') of
-                      SOME _ =>
+                    if can type_of1 (bound_Us, t') then
                       (if fun_T = fun_U orelse is_special_parametric_const (s, fun_T) then ()
                        else add_parametric_const s general_T fun_T fun_U;
                        t')
-                    | NONE => explore params t)
+                    else
+                      explore params t
                   end
                 | NONE => explore params t)
               end
@@ -1453,7 +1454,7 @@
     fun massage_rho explore =
       massage_star [massage_let, massage_if explore_cond, massage_case, massage_fun, massage_comp,
           massage_map, massage_ctr, massage_sel, massage_disc, massage_equality,
-          massage_polymorphic_const]
+          massage_const false, massage_const true]
         explore
     and massage_case explore (params as {bound_Ts, bound_Us, ...}) t =
       (case strip_comb t of
@@ -2240,20 +2241,23 @@
 
 fun friend_of_corec_cmd ((raw_fun_name, raw_fun_T_opt), raw_eq) lthy =
   let
-    val Const (fun_name, default_fun_T0) =
+    val Const (fun_name, _) =
       Proof_Context.read_const {proper = true, strict = false} lthy raw_fun_name;
-    val fun_T =
-      (case raw_fun_T_opt of
-        SOME raw_T => Syntax.read_typ lthy raw_T
-      | NONE => singleton (freeze_types lthy []) default_fun_T0);
+
+    val fake_lthy = lthy
+      |> (case raw_fun_T_opt of
+           SOME raw_T =>
+           Proof_Context.add_const_constraint (fun_name, SOME (Syntax.read_typ lthy raw_T))
+         | NONE => I);
 
-    val fun_t = Const (fun_name, fun_T);
     val fun_b = Binding.name (Long_Name.base_name fun_name);
+    val code_goal = Syntax.read_prop fake_lthy raw_eq;
 
-    val fake_lthy = lthy |> Proof_Context.add_const_constraint (fun_name, SOME fun_T)
-      handle TYPE (msg, _, _) => error msg;
-
-    val code_goal = Syntax.read_prop fake_lthy raw_eq;
+    val fun_T =
+      (case code_goal of
+        @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ t $ _) => fastype_of (head_of t)
+      | _ => error "Expected HOL equation");
+    val fun_t = Const (fun_name, fun_T);
 
     val (arg_Ts, res_T as Type (fpT_name, _)) = strip_type fun_T;
 
@@ -2279,7 +2283,7 @@
 
         val fun_free = Free (Binding.name_of fun_b, fun_T);
 
-        fun freeze_fun (t as Const (s, _)) = if s = fun_name then fun_free else t
+        fun freeze_fun (t as Const (s, T)) = if s = fun_name andalso T = fun_T then fun_free else t
           | freeze_fun t = t;
 
         val eq = Term.map_aterms freeze_fun code_goal;