src/HOL/Tools/BNF/bnf_lfp_compat.ML
changeset 58211 c1f3fa32d322
parent 58189 9d714be4f028
child 58213 6411ac1ef04d
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Mon Sep 08 14:03:01 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Mon Sep 08 14:03:01 2014 +0200
@@ -43,15 +43,152 @@
 
 open Ctr_Sugar
 open BNF_Util
+open BNF_Tactics
 open BNF_FP_Util
 open BNF_FP_Def_Sugar
 open BNF_FP_N2M_Sugar
 open BNF_LFP
 
-val compatN = "compat_";
+val compat_N = "compat_";
+val rec_fun_N = "rec_fun_";
 
 datatype nesting_preference = Keep_Nesting | Unfold_Nesting;
 
+fun mk_fun_rec_rhs ctxt fpTs Cs (recs as rec1 :: _) =
+  let
+    fun repair_rec_arg_args [] [] = []
+      | repair_rec_arg_args ((g_T as Type (@{type_name fun}, _)) :: g_Ts) (g :: gs) =
+        let
+          val (x_Ts, body_T) = strip_type g_T;
+        in
+          (case try HOLogic.dest_prodT body_T of
+            NONE => [g]
+          | SOME (fst_T, _) =>
+            if member (op =) fpTs fst_T then
+              let val (xs, _) = mk_Frees "x" x_Ts ctxt in
+                map (fn mk_proj => fold_rev Term.lambda xs (mk_proj (Term.list_comb (g, xs))))
+                  [HOLogic.mk_fst, HOLogic.mk_snd]
+              end
+            else
+              [g])
+          :: repair_rec_arg_args g_Ts gs
+        end
+      | repair_rec_arg_args (g_T :: g_Ts) (g :: gs) =
+        if member (op =) fpTs g_T then
+          let
+            val j = find_index (member (op =) Cs) g_Ts;
+            val h = nth gs j;
+            val g_Ts' = nth_drop j g_Ts;
+            val gs' = nth_drop j gs;
+          in
+            [g, h] :: repair_rec_arg_args g_Ts' gs'
+          end
+        else
+          [g] :: repair_rec_arg_args g_Ts gs;
+
+    fun repair_back_rec_arg f_T f' =
+      let
+        val g_Ts = Term.binder_types f_T;
+        val (gs, _) = mk_Frees "g" g_Ts ctxt;
+      in
+        fold_rev Term.lambda gs (Term.list_comb (f',
+          flat_rec_arg_args (repair_rec_arg_args g_Ts gs)))
+      end;
+
+    val f_Ts = binder_fun_types (fastype_of rec1);
+    val (fs', _) = mk_Frees "f" (replicate (length f_Ts) Term.dummyT) ctxt;
+
+    fun mk_rec' recx =
+      fold_rev Term.lambda fs' (Term.list_comb (recx, map2 repair_back_rec_arg f_Ts fs'))
+      |> Syntax.check_term ctxt;
+  in
+    map mk_rec' recs
+  end;
+
+fun define_fun_recs fpTs Cs recs lthy =
+  let
+    val b_names = Name.variant_list [] (map base_name_of_typ fpTs);
+
+    fun mk_binding b_name =
+      Binding.qualify true (compat_N ^ b_name)
+        (Binding.prefix_name rec_fun_N (Binding.name b_name));
+
+    val bs = map mk_binding b_names;
+    val rhss = mk_fun_rec_rhs lthy fpTs Cs recs;
+  in
+    fold_map3 (define_co_rec_as Least_FP Cs) fpTs bs rhss lthy
+  end;
+
+fun mk_fun_rec_thmss ctxt rec0_thmss (recs as rec1 :: _) rec_defs =
+  let
+    val f_Ts = binder_fun_types (fastype_of rec1);
+    val (fs, _) = mk_Frees "f" f_Ts ctxt;
+    val frecs = map (fn recx => Term.list_comb (recx, fs)) recs;
+
+    fun mk_ctrs_of (Type (T_name, As)) =
+      map (mk_ctr As) (#ctrs (the (ctr_sugar_of ctxt T_name)));
+
+    val fpTs = map (domain_type o body_fun_type o fastype_of) recs;
+    val fpTs_frecs = fpTs ~~ frecs;
+    val ctrss = map mk_ctrs_of fpTs;
+    val fss = unflat ctrss fs;
+
+    fun mk_rec_call g n (Type (@{type_name fun}, [dom_T, ran_T])) =
+        Abs (Name.uu, dom_T, mk_rec_call g (n + 1) ran_T)
+      | mk_rec_call g n fpT =
+        let
+          val frec = the (AList.lookup (op =) fpTs_frecs fpT);
+          val xg = Term.list_comb (g, map Bound (n - 1 downto 0));
+        in frec $ xg end;
+
+    fun mk_rec_arg_arg g_T g =
+      g :: (if exists_subtype_in fpTs g_T then [mk_rec_call g 0 g_T] else []);
+
+    fun mk_goal frec ctr f =
+      let
+        val g_Ts = binder_types (fastype_of ctr);
+        val (gs, _) = mk_Frees "g" g_Ts ctxt;
+        val gctr = Term.list_comb (ctr, gs);
+        val fgs = flat_rec_arg_args (map2 mk_rec_arg_arg g_Ts gs);
+      in
+        fold_rev (fold_rev Logic.all) [fs, gs]
+          (mk_Trueprop_eq (frec $ gctr, Term.list_comb (f, fgs)))
+      end;
+
+    fun mk_goals ctrs fs frec = map2 (mk_goal frec) ctrs fs;
+
+    val goalss = map3 mk_goals ctrss fss frecs;
+
+    fun tac ctxt =
+      unfold_thms_tac ctxt (@{thms o_apply fst_conv snd_conv} @ rec_defs @ flat rec0_thmss) THEN
+      HEADGOAL (rtac refl);
+
+    fun prove goal =
+      Goal.prove_sorry ctxt [] [] goal (tac o #context)
+      |> Thm.close_derivation;
+  in
+    map (map prove) goalss
+  end;
+
+fun define_fun_rec_derive_thms induct inducts recs0 rec_thmss fpTs lthy =
+  let
+    val thy = Proof_Context.theory_of lthy;
+
+    (* imperfect: will not yield the expected theorem for functions taking a large number of
+       arguments *)
+    val repair_induct = unfold_thms lthy @{thms all_mem_range};
+
+    val induct' = repair_induct induct;
+    val inducts' = map repair_induct inducts;
+
+    val Cs = map ((fn TVar ((s, _), S) => TFree (s, S)) o body_type o fastype_of) recs0;
+    val recs = map2 (mk_co_rec thy Least_FP Cs) fpTs recs0;
+    val ((recs', rec'_defs), lthy') = define_fun_recs fpTs Cs recs lthy |>> split_list;
+    val rec'_thmss = mk_fun_rec_thmss lthy' rec_thmss recs' rec'_defs;
+  in
+    ((induct', inducts', recs', rec'_thmss), lthy')
+  end;
+
 fun reindex_desc desc =
   let
     val kks = map fst desc;
@@ -130,10 +267,10 @@
 
     val dest_dtyp = Old_Datatype_Aux.typ_of_dtyp descr;
 
-    val Ts = Old_Datatype_Aux.get_rec_types descr;
-    val nn = length Ts;
+    val fpTs' = Old_Datatype_Aux.get_rec_types descr;
+    val nn = length fpTs';
 
-    val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
+    val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) fpTs';
     val ctr_Tsss = map (map (map dest_dtyp o snd) o #3 o snd) descr;
     val kkssss =
       map (map (map (fn Old_Datatype_Aux.DtRec kk => [kk] | _ => []) o snd) o #3 o snd) descr;
@@ -146,34 +283,47 @@
     val callssss =
       map2 (map2 (map2 (fn ctr_T => map (apply_comps (num_binder_types ctr_T))))) ctr_Tsss kkssss;
 
-    val b_names = Name.variant_list [] (map base_name_of_typ Ts);
-    val compat_b_names = map (prefix compatN) b_names;
+    val b_names = Name.variant_list [] (map base_name_of_typ fpTs');
+    val compat_b_names = map (prefix compat_N) b_names;
     val compat_bs = map Binding.name compat_b_names;
 
     val ((fp_sugars, (lfp_sugar_thms, _)), lthy') =
       if nn > nn_fp then
-        mutualize_fp_sugars Least_FP cliques compat_bs Ts callers callssss fp_sugars0 lthy
+        mutualize_fp_sugars Least_FP cliques compat_bs fpTs' callers callssss fp_sugars0 lthy
       else
         ((fp_sugars0, (NONE, NONE)), lthy);
 
-    val recs = map (fst o dest_Const o #co_rec) fp_sugars;
-    val rec_thms = maps #co_rec_thms fp_sugars;
-
     val {common_co_inducts = [induct], ...} :: _ = fp_sugars;
     val inducts = map (the_single o #co_inducts) fp_sugars;
 
+    val recs = map #co_rec fp_sugars;
+    val rec_thmss = map #co_rec_thms fp_sugars;
+
+    fun is_nested_rec_type (Type (@{type_name fun}, [_, T])) = member (op =) fpTs' (body_type T)
+      | is_nested_rec_type _ = false;
+
+    val ((induct', inducts', recs', rec'_thmss), lthy'') =
+      if nesting_pref = Unfold_Nesting andalso
+         exists (exists (exists is_nested_rec_type)) ctr_Tsss then
+        define_fun_rec_derive_thms induct inducts recs rec_thmss fpTs' lthy'
+      else
+        ((induct, inducts, recs, rec_thmss), lthy');
+
+    val rec'_names = map (fst o dest_Const) recs';
+    val rec'_thms = flat rec'_thmss;
+
     fun mk_info (kk, {T = Type (T_name0, _), ctr_sugar = {casex, exhaust, nchotomy, injects,
         distincts, case_thms, case_cong, case_cong_weak, split, split_asm, ...}, ...} : fp_sugar) =
       (T_name0,
-       {index = kk, descr = descr, inject = injects, distinct = distincts, induct = induct,
-        inducts = inducts, exhaust = exhaust, nchotomy = nchotomy, rec_names = recs,
-        rec_rewrites = rec_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
+       {index = kk, descr = descr, inject = injects, distinct = distincts, induct = induct',
+        inducts = inducts', exhaust = exhaust, nchotomy = nchotomy, rec_names = rec'_names,
+        rec_rewrites = rec'_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
         case_cong = case_cong, case_cong_weak = case_cong_weak, split = split,
         split_asm = split_asm});
 
     val infos = map_index mk_info (take nn_fp fp_sugars);
   in
-    (nn, b_names, compat_b_names, lfp_sugar_thms, infos, lthy')
+    (nn, b_names, compat_b_names, lfp_sugar_thms, infos, lthy'')
   end;
 
 fun infos_of_new_datatype_mutual_cluster lthy fpT_name =
@@ -298,7 +448,7 @@
         NONE => []
       | SOME ((induct_thms, induct_thm, induct_attrs), (rec_thmss, _)) =>
         let
-          val common_name = compatN ^ mk_common_name b_names;
+          val common_name = compat_N ^ mk_common_name b_names;
 
           val common_notes =
             (if nn > 1 then [(inductN, [induct_thm], induct_attrs)] else [])