simplified code; eliminated some dummyTs
authorpanny
Thu, 19 Sep 2013 16:12:43 +0200
changeset 53735 99331dac1e1c
parent 53734 7613573f023a
child 53736 82799e03fff7
simplified code; eliminated some dummyTs
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Thu Sep 19 12:20:12 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Thu Sep 19 16:12:43 2013 +0200
@@ -56,6 +56,8 @@
         | a n t = let val idx = find_index (equal t) vs in
             if idx < 0 then t else Bound (n + idx) end
   in a 0 end;
+fun mk_prod1 Ts (t, u) = HOLogic.pair_const (fastype_of1 (Ts, t)) (fastype_of1 (Ts, u)) $ t $ u;
+fun mk_tuple1 Ts = the_default HOLogic.unit o try (foldr1 (mk_prod1 Ts));
 
 val simp_attrs = @{attributes [simp]};
 
@@ -561,48 +563,43 @@
   |> the_default undef_const
   |> K;
 
-fun build_corec_arg_direct_call lthy has_call sel_eqns sel =
+fun build_corec_args_direct_call lthy has_call sel_eqns sel =
   let
     val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns;
-    fun massage rhs_term is_end t =
-      let
-        val U = range_type (fastype_of t);
-        fun rewrite t =
-          if U = @{typ bool} then (if has_call t then @{term False} else @{term True}) (* stop? *)
-          else if is_end = has_call t then undef_const
-          else if is_end then t (* end *)
-          else HOLogic.mk_tuple (snd (strip_comb t)); (* continue *)
-      in
-        massage_direct_corec_call lthy has_call rewrite U rhs_term
-      end;
+    fun rewrite_q t = if has_call t then @{term False} else @{term True};
+    fun rewrite_g t = if has_call t then undef_const else t;
+    fun rewrite_h t = if has_call t then HOLogic.mk_tuple (snd (strip_comb t)) else undef_const;
+    fun massage _ NONE t = t
+      | massage f (SOME {fun_args, rhs_term, ...}) t =
+        massage_direct_corec_call lthy has_call f (range_type (fastype_of t)) rhs_term
+        |> abs_tuple fun_args;
   in
-    if is_none maybe_sel_eqn then K I else
-      abs_tuple (#fun_args (the maybe_sel_eqn)) oo massage (#rhs_term (the maybe_sel_eqn))
+    (massage rewrite_q maybe_sel_eqn,
+     massage rewrite_g maybe_sel_eqn,
+     massage rewrite_h maybe_sel_eqn)
   end;
 
 fun build_corec_arg_indirect_call lthy has_call sel_eqns sel =
   let
     val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns;
-    fun rewrite (Abs (v, T, b)) = Abs (v, T, rewrite b)
-      | rewrite t =
+    fun rewrite bound_Ts U T (Abs (v, V, b)) = Abs (v, V, rewrite (V :: bound_Ts) U T b)
+      | rewrite bound_Ts U T (t as _ $ _) =
         let val (u, vs) = strip_comb t in
           if is_Free u andalso has_call u then
-            Const (@{const_name Inr}, dummyT) $
-              (if null vs then HOLogic.unit
-               else foldr1 (fn (x, y) => Const (@{const_name Pair}, dummyT) $ x $ y) vs)
+            Inr_const U T $ mk_tuple1 bound_Ts vs
           else if try (fst o dest_Const) u = SOME @{const_name prod_case} then
-            list_comb (u |> map_types (K dummyT), map rewrite vs)
-          else if null vs then
-            u
+            list_comb (map_types (K dummyT) u, map (rewrite bound_Ts U T) vs)
           else
-            list_comb (rewrite u, map rewrite vs)
-        end;
-    fun massage rhs_term t =
-      massage_indirect_corec_call lthy has_call (K (K rewrite)) [] (range_type (fastype_of t))
-        rhs_term;
+            list_comb (rewrite bound_Ts U T u, map (rewrite bound_Ts U T) vs)
+        end
+      | rewrite _ U T t = if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t;
+    fun massage NONE t = t
+      | massage (SOME {fun_args, rhs_term, ...}) t =
+        massage_indirect_corec_call lthy has_call (rewrite []) []
+          (range_type (fastype_of t)) rhs_term
+        |> abs_tuple fun_args;
   in
-    if is_none maybe_sel_eqn then I else
-      abs_tuple (#fun_args (the maybe_sel_eqn)) o massage (#rhs_term (the maybe_sel_eqn))
+    massage maybe_sel_eqn
   end;
 
 fun build_corec_args_sel lthy has_call all_sel_eqns ctr_spec =
@@ -616,11 +613,10 @@
         val indirect_calls' = map_filter (try (apsnd (fn Indirect_Corec n => n))) sel_call_list;
       in
         I
-        #> fold (fn (sel, n) => nth_map n
-          (build_corec_arg_no_call sel_eqns sel)) no_calls'
+        #> fold (fn (sel, n) => nth_map n (build_corec_arg_no_call sel_eqns sel)) no_calls'
         #> fold (fn (sel, (q, g, h)) =>
-          let val f = build_corec_arg_direct_call lthy has_call sel_eqns sel in
-            nth_map h (f false) o nth_map g (f true) o nth_map q (f true) end) direct_calls'
+          let val (fq, fg, fh) = build_corec_args_direct_call lthy has_call sel_eqns sel in
+            nth_map q fq o nth_map g fg o nth_map h fh end) direct_calls'
         #> fold (fn (sel, n) => nth_map n
           (build_corec_arg_indirect_call lthy has_call sel_eqns sel)) indirect_calls'
       end
@@ -636,10 +632,9 @@
       |> map (Const o pair @{const_name undefined})
       |> fold2 (fold o build_corec_arg_disc) ctr_specss disc_eqnss
       |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
-    fun currys Ts t = if length Ts <= 1 then t else
-      t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
-        (length Ts - 1 downto 0 |> map Bound)
-      |> fold_rev (Term.abs o pair Name.uu) Ts;
+    fun currys [] t = t
+      | currys Ts t = t $ mk_tuple1 (List.rev Ts) (map Bound (length Ts - 1 downto 0))
+          |> fold_rev (Term.abs o pair Name.uu) Ts;
 
 val _ = tracing ("corecursor arguments:\n    \<cdot> " ^
  space_implode "\n    \<cdot> " (map (Syntax.string_of_term lthy) corec_args));
@@ -792,7 +787,6 @@
 
         fun prove_ctr (_, disc_thms) (_, sel_thms') disc_eqns sel_eqns
             {ctr, disc, sels, collapse, ...} =
-let val _ = tracing ("disc = " ^ @{make_string} disc); in
           if not (exists (equal ctr o #ctr) disc_eqns)
               andalso not (exists (equal ctr o #ctr) sel_eqns)
 andalso (warning ("no eqns for ctr " ^ Syntax.string_of_term lthy ctr); true)
@@ -804,7 +798,7 @@
           then [] else
             let
 val _ = tracing ("ctr = " ^ Syntax.string_of_term lthy ctr);
-val _ = tracing (the_default "NO disc_eqn" (Option.map (curry (op ^) "disc = " o Syntax.string_of_term lthy o #disc) (find_first (equal ctr o #ctr) disc_eqns)));
+val _ = tracing (the_default "no disc_eqn" (Option.map (curry (op ^) "disc = " o Syntax.string_of_term lthy o #disc) (find_first (equal ctr o #ctr) disc_eqns)));
               val (fun_name, fun_T, fun_args, prems) =
                 (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
                 |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x))
@@ -831,8 +825,6 @@
               mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
               |> K |> Goal.prove lthy [] [] t
               |> single
-(*handle ERROR x => (warning x; []))*)
-end
           end;
 
         val (disc_notes, disc_thmss) =