simplified rewriting of map arguments
authorpanny
Sun, 01 Sep 2013 10:45:54 +0200
changeset 53357 46b0c7a08af7
parent 53356 c5a1629d8e45
child 53358 b46e6cd75dc6
simplified rewriting of map arguments
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sun Sep 01 16:38:04 2013 +1000
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sun Sep 01 10:45:54 2013 +0200
@@ -28,8 +28,11 @@
 fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
 
+val free_name = try (fn Free (v, _) => v);
+val const_name = try (fn Const (v, _) => v);
+
 fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
-fun abs_tuple t = if try (fst o dest_Const) t = SOME @{const_name undefined} then t else
+fun abs_tuple t = if const_name t = SOME @{const_name undefined} then t else
   strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
 
 val simp_attrs = @{attributes [simp]};
@@ -103,43 +106,34 @@
      user_eqn = eqn'}
   end;
 
-fun rewrite_map_arg funs_data get_indices y rec_type res_type =
+fun rewrite_map_arg funs_data get_indices rec_type res_type =
   let
-    val pT = HOLogic.mk_prodT (rec_type, res_type);
-    val fstx = fst_const pT;
-    val sndx = snd_const pT;
+    val fun_data = hd (the (find_first (equal rec_type o #rec_type o hd) funs_data));
+    val fun_name = #fun_name fun_data;
+    val ctr_pos = length (#left_args fun_data);
 
-    val SOME ({fun_name, left_args, ...} :: _) =
-      find_first (equal rec_type o #rec_type o hd) funs_data;
-    val ctr_pos = length left_args;
+    val pT = HOLogic.mk_prodT (rec_type, res_type);
 
-    fun subst _ d (t as Bound d') = t |> d = d' ? curry (op $) fstx
-      | subst l d (Abs (v, T, b)) = Abs (v, if d < 0 then pT else T, subst l (d + 1) b)
-      | subst l d t =
+    val maybe_suc = Option.map (fn x => x + 1);
+    fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
+      | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
+      | subst d t =
         let val (u, vs) = strip_comb t in
-          if try (fst o dest_Free) u = SOME fun_name then
-            if l andalso length vs = ctr_pos then
-              list_comb (sndx |> permute_args ctr_pos, vs)
-            else if length vs <= ctr_pos then
-              primrec_error_eqn "too few arguments in recursive call" t
-            else if nth vs ctr_pos |> member (op =) [y, Bound d] then
-              list_comb (sndx $ nth vs ctr_pos, nth_drop ctr_pos vs |> map (subst false d))
+          if free_name u = SOME fun_name then
+            if d = SOME ~1 andalso length vs = ctr_pos then
+              list_comb (permute_args ctr_pos (snd_const pT), vs)
+            else if length vs > ctr_pos andalso is_some d
+                andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
+              list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
             else
-              primrec_error_eqn "recursive call not directly applied to constructor argument" t
-          else if try (fst o dest_Const) u = SOME @{const_name comp} then
-            (hd vs |> get_indices |> null orelse
-              primrec_error_eqn "recursive call not directly applied to constructor argument" t;
-            list_comb
-              (u |> map_types (strip_type #>> (fn Ts => Ts
-                   |> nth_map (length Ts - 1) (K pT)
-                   |> nth_map (length Ts - 2) (strip_type #>> nth_map 0 (K pT) #> (op --->)))
-                 #> (op --->)),
-              nth_map 1 (subst l d) vs))
+              primrec_error_eqn ("recursive call not directly applied to constructor argument") t
+          else if d = SOME ~1 andalso const_name u = SOME @{const_name comp} then
+            list_comb (map_types (K dummyT) u, map2 subst [NONE, d] vs)
           else
-            list_comb (u, map (subst false d) vs)
+            list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
         end
   in
-    subst true ~1
+    subst (SOME ~1)
   end;
 
 (* FIXME get rid of funs_data or get_indices *)
@@ -164,7 +158,7 @@
           else if is_some maybe_indirect_y' then
             (if contains_fun g then t else y)
             |> massage_indirect_rec_call lthy contains_fun
-              (rewrite_map_arg funs_data get_indices y) bound_Ts y (the maybe_indirect_y')
+              (rewrite_map_arg funs_data get_indices) bound_Ts y (the maybe_indirect_y')
             |> (if contains_fun g then I else curry (op $) g)
           else
             t
@@ -426,7 +420,7 @@
     val fun_args = if is_none disc
       then imp_rhs |> perhaps (try HOLogic.dest_not) |> HOLogic.dest_eq |> fst |> strip_comb |> snd
       else the disc |> the_single o snd o strip_comb
-        |> (fn t => if try (fst o dest_Free o head_of) t = SOME fun_name
+        |> (fn t => if free_name (head_of t) = SOME fun_name
           then snd (strip_comb t) else []);
 
     val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};