src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 53358 b46e6cd75dc6
parent 53357 46b0c7a08af7
child 53360 7ffc4a746a73
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sun Sep 01 10:45:54 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sun Sep 01 14:00:05 2013 +0200
@@ -28,11 +28,15 @@
 fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
 
+fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
+
 val free_name = try (fn Free (v, _) => v);
 val const_name = try (fn Const (v, _) => v);
+val undef_const = Const (@{const_name undefined}, dummyT);
 
-fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
-fun abs_tuple t = if const_name t = SOME @{const_name undefined} then t else
+fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
+  |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n);
+fun abs_tuple t = if t = undef_const then t else
   strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
 
 val simp_attrs = @{attributes [simp]};
@@ -53,9 +57,6 @@
   user_eqn: term
 };
 
-fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
-  |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n);
-
 fun dissect_eqn lthy fun_names eqn' =
   let
     val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev,
@@ -106,20 +107,21 @@
      user_eqn = eqn'}
   end;
 
-fun rewrite_map_arg funs_data get_indices rec_type res_type =
+fun rewrite_map_arg fun_name_ctr_pos_list rec_type res_type =
   let
-    val fun_data = hd (the (find_first (equal rec_type o #rec_type o hd) funs_data));
-    val fun_name = #fun_name fun_data;
-    val ctr_pos = length (#left_args fun_data);
-
     val pT = HOLogic.mk_prodT (rec_type, res_type);
 
     val maybe_suc = Option.map (fn x => x + 1);
     fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
       | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
       | subst d t =
-        let val (u, vs) = strip_comb t in
-          if free_name u = SOME fun_name then
+        let
+          val (u, vs) = strip_comb t;
+          val maybe_fun_name_ctr_pos =
+            find_first (equal (free_name u) o SOME o fst) fun_name_ctr_pos_list;
+          val (fun_name, ctr_pos) = the_default ("", ~1) maybe_fun_name_ctr_pos;
+        in
+          if is_some maybe_fun_name_ctr_pos then
             if d = SOME ~1 andalso length vs = ctr_pos then
               list_comb (permute_args ctr_pos (snd_const pT), vs)
             else if length vs > ctr_pos andalso is_some d
@@ -136,42 +138,40 @@
     subst (SOME ~1)
   end;
 
-(* FIXME get rid of funs_data or get_indices *)
-fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t =
+fun subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls t =
   let
-    val contains_fun = not o null o get_indices;
     fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
-      | subst bound_Ts (t as g $ y) =
+      | subst bound_Ts (t as g' $ y) =
         let
-          val is_ctr_arg = exists (exists (exists (equal y) o #ctr_args)) funs_data;
           val maybe_direct_y' = AList.lookup (op =) direct_calls y;
           val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
-          val (g_head, g_args) = strip_comb g;
+          val (g, g_args) = strip_comb g';
+          val maybe_ctr_pos =
+            try (snd o the o find_first (equal (free_name g) o SOME o fst)) fun_name_ctr_pos_list;
+          val _ = is_none maybe_ctr_pos orelse length g_args >= the maybe_ctr_pos orelse
+            primrec_error_eqn "too few arguments in recursive call" t;
         in
-          if not is_ctr_arg then
-            pairself (subst bound_Ts) (g, y) |> (op $)
-          else if contains_fun g_head then
-            (length g_args >= the (funs_data |> get_first (fn {fun_name, left_args, ...} :: _ =>
-              if fst (dest_Free g_head) = fun_name then SOME (length left_args) else NONE)) (*###*)
-                orelse primrec_error_eqn "too few arguments in recursive call" t;
-            list_comb (the maybe_direct_y', g_args))
+          if not (member (op =) ctr_args y) then
+            pairself (subst bound_Ts) (g', y) |> (op $)
+          else if is_some maybe_ctr_pos then
+            list_comb (the maybe_direct_y', g_args)
           else if is_some maybe_indirect_y' then
-            (if contains_fun g then t else y)
-            |> massage_indirect_rec_call lthy contains_fun
-              (rewrite_map_arg funs_data get_indices) bound_Ts y (the maybe_indirect_y')
-            |> (if contains_fun g then I else curry (op $) g)
+            (if has_call g' then t else y)
+            |> massage_indirect_rec_call lthy has_call
+              (rewrite_map_arg fun_name_ctr_pos_list) bound_Ts y (the maybe_indirect_y')
+            |> (if has_call g' then I else curry (op $) g')
           else
             t
         end
       | subst _ t = t
   in
     subst [] t
-    |> (fn u => ((contains_fun u andalso (* FIXME detect this case earlier *)
-      primrec_error_eqn "recursive call not directly applied to constructor argument" t); u))
+    |> tap (fn u => has_call u andalso (* FIXME detect this case earlier *)
+      primrec_error_eqn "recursive call not directly applied to constructor argument" t)
   end;
 
-fun build_rec_arg lthy get_indices funs_data ctr_spec maybe_eqn_data =
-  if is_none maybe_eqn_data then Const (@{const_name undefined}, dummyT) else
+fun build_rec_arg lthy funs_data has_call ctr_spec maybe_eqn_data =
+  if is_none maybe_eqn_data then undef_const else
     let
       val eqn_data = the maybe_eqn_data;
       val t = #rhs_term eqn_data;
@@ -215,13 +215,15 @@
       val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
 
       val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data);
+      val fun_name_ctr_pos_list =
+        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
     in
       t
-      |> subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls
+      |> subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls
       |> fold_rev absfree abstractions
     end;
 
-fun build_defs lthy bs mxs funs_data rec_specs get_indices =
+fun build_defs lthy bs mxs funs_data rec_specs has_call =
   let
     val n_funs = length funs_data;
 
@@ -239,7 +241,7 @@
     val recs = take n_funs rec_specs |> map #recx;
     val rec_args = ctr_spec_eqn_data_list
       |> sort ((op <) o pairself (#offset o fst) |> make_ord)
-      |> map (uncurry (build_rec_arg lthy get_indices funs_data) o apsnd (try the_single));
+      |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
     val ctr_poss = map (fn x =>
       if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
         primrec_error ("inconstant constructor pattern position for function " ^
@@ -253,7 +255,7 @@
     |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
   end;
 
-fun find_rec_calls get_indices eqn_data =
+fun find_rec_calls has_call eqn_data =
   let
     fun find (Abs (_, _, b)) ctr_arg = find b ctr_arg
       | find (t as _ $ _) ctr_arg =
@@ -265,7 +267,7 @@
             find f' ctr_arg @ maps (fn x => find x ctr_arg) args'
           else
             let val (f, args) = chop n args' |>> curry list_comb f' in
-              if exists_subterm (not o null o get_indices) f then
+              if has_call f then
                 f :: maps (fn x => find x ctr_arg) args
               else
                 find f ctr_arg @ maps (fn x => find x ctr_arg) args
@@ -288,16 +290,16 @@
       |> map (fn (x, y) => the_single y handle List.Empty =>
           primrec_error ("missing equations for function " ^ quote x));
 
-    fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes
-      |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
-      |> map_filter I;
-
+    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
     val arg_Ts = map (#rec_type o hd) funs_data;
     val res_Ts = map (#res_type o hd) funs_data;
     val callssss = funs_data
       |> map (partition_eq ((op =) o pairself #ctr))
-      |> map (maps (map_filter (find_rec_calls get_indices)));
+      |> map (maps (map_filter (find_rec_calls has_call)));
 
+    fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes
+      |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
+      |> map_filter I;
     val ((nontriv, rec_specs, _, induct_thm, induct_thms), lthy') =
       rec_specs_of bs arg_Ts res_Ts get_indices callssss lthy;
 
@@ -308,7 +310,7 @@
         primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
           " is not a constructor in left-hand side") user_eqn) eqns_data end;
 
-    val defs = build_defs lthy' bs mxs funs_data rec_specs get_indices;
+    val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
 
     fun prove def_thms' {ctr_specs, nested_map_idents, nested_map_comps, ...} induct_thm fun_data
         lthy =
@@ -561,7 +563,7 @@
         0 upto length ctr_specs - 1
         |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
           |> Option.map #cond
-          |> the_default (Const (@{const_name undefined}, dummyT)))
+          |> the_default undef_const)
         |> fst o split_last;
   in
     (* FIXME: deal with #preds above *)
@@ -633,7 +635,7 @@
     val ctr_specss = map (#ctr_specs o snd) fun_name_corec_spec_list;
     val n_args = fold (curry (op +)) (map (K 1) (maps (map_filter #pred) ctr_specss) @
       map (fn Direct_Corec _ => 3 | _ => 1) (maps (maps #calls) ctr_specss)) 0;
-    val corec_args = replicate n_args (Const (@{const_name undefined}, dummyT))
+    val corec_args = replicate n_args undef_const
       |> fold2 build_corec_args_discs disc_eqnss ctr_specss
       |> fold2 (fn sel_eqns => fold (build_corec_args_sel sel_eqns)) sel_eqnss ctr_specss;