various refactoring;
authorpanny
Wed, 04 Sep 2013 02:11:50 +0200
changeset 53401 2101a97e6220
parent 53395 a1a78a271682
child 53402 50cc036f1522
various refactoring; handle self-mappings; handle range types containing function types;
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Sep 03 21:46:42 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Wed Sep 04 02:11:50 2013 +0200
@@ -36,8 +36,7 @@
 
 fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
   |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n);
-fun abs_tuple t = if t = undef_const then t else
-  strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
+val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple;
 
 val simp_attrs = @{attributes [simp]};
 
@@ -107,7 +106,7 @@
      user_eqn = eqn'}
   end;
 
-fun rewrite_map_arg fun_name_ctr_pos_list rec_type res_type =
+fun rewrite_map_arg get_ctr_pos rec_type res_type =
   let
     val pT = HOLogic.mk_prodT (rec_type, res_type);
 
@@ -117,11 +116,9 @@
       | subst d t =
         let
           val (u, vs) = strip_comb t;
-          val maybe_fun_name_ctr_pos =
-            find_first (equal (free_name u) o SOME o fst) fun_name_ctr_pos_list;
-          val (fun_name, ctr_pos) = the_default ("", ~1) maybe_fun_name_ctr_pos;
+          val ctr_pos = try (get_ctr_pos o the) (free_name u) |> the_default ~1;
         in
-          if is_some maybe_fun_name_ctr_pos then
+          if ctr_pos >= 0 then
             if d = SOME ~1 andalso length vs = ctr_pos then
               list_comb (permute_args ctr_pos (snd_const pT), vs)
             else if length vs > ctr_pos andalso is_some d
@@ -138,7 +135,7 @@
     subst (SOME ~1)
   end;
 
-fun subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls t =
+fun subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls t =
   let
     fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
       | subst bound_Ts (t as g' $ y) =
@@ -146,19 +143,18 @@
           val maybe_direct_y' = AList.lookup (op =) direct_calls y;
           val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
           val (g, g_args) = strip_comb g';
-          val maybe_ctr_pos =
-            try (snd o the o find_first (equal (free_name g) o SOME o fst)) fun_name_ctr_pos_list;
-          val _ = is_none maybe_ctr_pos orelse length g_args >= the maybe_ctr_pos orelse
+          val ctr_pos = try (get_ctr_pos o the) (free_name g) |> the_default ~1;
+          val _ = ctr_pos < 0 orelse length g_args >= ctr_pos orelse
             primrec_error_eqn "too few arguments in recursive call" t;
         in
           if not (member (op =) ctr_args y) then
             pairself (subst bound_Ts) (g', y) |> (op $)
-          else if is_some maybe_ctr_pos then
+          else if ctr_pos >= 0 then
             list_comb (the maybe_direct_y', g_args)
           else if is_some maybe_indirect_y' then
             (if has_call g' then t else y)
             |> massage_indirect_rec_call lthy has_call
-              (rewrite_map_arg fun_name_ctr_pos_list) bound_Ts y (the maybe_indirect_y')
+              (rewrite_map_arg get_ctr_pos) bound_Ts y (the maybe_indirect_y')
             |> (if has_call g' then I else curry (op $) g')
           else
             t
@@ -211,16 +207,17 @@
             nth_map arg_idx (K (nth ctr_args ctr_arg_idx |> map_types make_indirect_type)))
           indirect_calls';
 
+      val fun_name_ctr_pos_list =
+        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
+      val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
       val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls';
       val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
 
-      val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data);
-      val fun_name_ctr_pos_list =
-        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
+      val abstractions = args @ #left_args eqn_data @ #right_args eqn_data;
     in
       t
-      |> subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls
-      |> fold_rev absfree abstractions
+      |> subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls
+      |> fold_rev lambda abstractions
     end;
 
 fun build_defs lthy bs mxs funs_data rec_specs has_call =
@@ -372,15 +369,16 @@
 
 type co_eqn_data_disc = {
   fun_name: string,
+  fun_args: term list,
   ctr_no: int, (*###*)
   cond: term,
   user_eqn: term
 };
 type co_eqn_data_sel = {
   fun_name: string,
+  fun_args: term list,
   ctr: term,
   sel: term,
-  fun_args: term list,
   rhs_term: term,
   user_eqn: term
 };
@@ -388,11 +386,10 @@
   Disc of co_eqn_data_disc |
   Sel of co_eqn_data_sel;
 
-fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
+fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
   let
     fun find_subterm p = let (* FIXME \<exists>? *)
-      fun f (t as u $ v) =
-        fold_rev (curry merge_options) [if p t then SOME t else NONE, f u, f v] NONE
+      fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
         | f t = if p t then SOME t else NONE
       in f end;
 
@@ -406,9 +403,8 @@
 
     val discs = #ctr_specs corec_spec |> map #disc;
     val ctrs = #ctr_specs corec_spec |> map #ctr;
-    val n_ctrs = length ctrs;
     val not_disc = head_of imp_rhs = @{term Not};
-    val _ = not_disc andalso n_ctrs <> 2 andalso
+    val _ = not_disc andalso length ctrs <> 2 andalso
       primrec_error_eqn "\<not>ed discriminator for a type with \<noteq> 2 constructors" imp_rhs;
     val disc = find_subterm (member (op =) discs o head_of) imp_rhs;
     val eq_ctr0 = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd)
@@ -428,32 +424,28 @@
     val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
     val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
     val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_;
-    val matched_conds = filter (equal fun_name o fst) matched_conds_ps |> map snd;
-    val imp_lhs = mk_conjs imp_lhs';
+    val matched_cond = filter (equal fun_name o fst) matched_conds |> map snd |> mk_disjs;
+    val imp_lhs = mk_conjs imp_lhs'
+      |> incr_boundvars (length fun_args)
+      |> subst_atomic (fun_args ~~ map Bound (length fun_args - 1 downto 0))
     val cond =
       if catch_all then
-        if null matched_conds then fold_rev absfree (map dest_Free fun_args) @{const True} else
-          (strip_abs_vars (hd matched_conds),
-            mk_disjs (map strip_abs_body matched_conds) |> HOLogic.mk_not)
-          |-> fold_rev (fn (v, T) => fn u => Abs (v, T, u))
+        matched_cond |> HOLogic.mk_not
       else if sequential then
-        HOLogic.mk_conj (HOLogic.mk_not (mk_disjs (map strip_abs_body matched_conds)), imp_lhs)
-        |> fold_rev absfree (map dest_Free fun_args)
+        HOLogic.mk_conj (HOLogic.mk_not matched_cond, imp_lhs)
       else
-        imp_lhs |> fold_rev absfree (map dest_Free fun_args);
-    val matched_cond =
-      if sequential then fold_rev absfree (map dest_Free fun_args) imp_lhs else cond;
+        imp_lhs;
 
-    val matched_conds_ps' = if catch_all
-      then (fun_name, cond) :: filter (not_equal fun_name o fst) matched_conds_ps
-      else (fun_name, matched_cond) :: matched_conds_ps;
+    val matched_conds' =
+      (fun_name, if catch_all orelse not sequential then cond else imp_lhs) :: matched_conds;
   in
     (Disc {
       fun_name = fun_name,
+      fun_args = fun_args,
       ctr_no = ctr_no,
       cond = cond,
       user_eqn = eqn'
-    }, matched_conds_ps')
+    }, matched_conds')
   end;
 
 fun co_dissect_eqn_sel fun_name_corec_spec_list eqn' eqn =
@@ -473,15 +465,15 @@
   in
     Sel {
       fun_name = fun_name,
+      fun_args = fun_args,
       ctr = #ctr ctr_spec,
       sel = sel,
-      fun_args = fun_args,
       rhs_term = rhs,
       user_eqn = eqn'
     }
   end;
 
-fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
+fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
   let 
     val (lhs, rhs) = HOLogic.dest_eq imp_rhs;
     val fun_name = head_of lhs |> fst o dest_Free;
@@ -491,10 +483,10 @@
       handle Option.Option => primrec_error_eqn "not a constructor" ctr;
 
     val disc_imp_rhs = betapply (#disc ctr_spec, lhs);
-    val (maybe_eqn_data_disc, matched_conds_ps') = if length (#ctr_specs corec_spec) = 1
-      then (NONE, matched_conds_ps)
+    val (maybe_eqn_data_disc, matched_conds') = if length (#ctr_specs corec_spec) = 1
+      then (NONE, matched_conds)
       else apfst SOME (co_dissect_eqn_disc
-          sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps);
+          sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds);
 
     val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args)
       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
@@ -506,10 +498,10 @@
     val eqns_data_sel =
       map (co_dissect_eqn_sel fun_name_corec_spec_list eqn') sel_imp_rhss;
   in
-    (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds_ps')
+    (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds')
   end;
 
-fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds_ps =
+fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds =
   let
     val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev,
         strip_qnt_body @{const_name all} eqn')
@@ -531,65 +523,68 @@
     if member (op =) discs head orelse
       is_some maybe_rhs andalso
         member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
-      co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps
+      co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
       |>> single
     else if member (op =) sels head then
-      ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds_ps)
+      ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds)
     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
-      co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps
+      co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
     else
       primrec_error_eqn "malformed function equation" eqn
   end;
 
 fun build_corec_args_discs disc_eqns ctr_specs =
-  let
-    val conds = map #cond disc_eqns;
-    val args' =
-      if length ctr_specs = 1 then []
-      else if length disc_eqns = length ctr_specs then
-        fst (split_last conds)
-      else if length disc_eqns = length ctr_specs - 1 then
-        let val n = 0 upto length ctr_specs - 1
-            |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in
-          if n = length ctr_specs - 1 then
-            conds
-          else
-            split_last conds
-            ||> (fn t => fold_rev absfree (strip_abs_vars t) (strip_abs_body t |> HOLogic.mk_not))
-            |>> chop n
-            |> (fn ((l, r), x) => l @ (x :: r))
-        end
-      else
-        0 upto length ctr_specs - 1
-        |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
-          |> Option.map #cond
-          |> the_default undef_const)
-        |> fst o split_last;
-  in
-    (* FIXME: deal with #preds above *)
-    fold2 (fn idx => nth_map idx o K o abs_tuple) (map_filter #pred ctr_specs) args'
-  end;
+  if null disc_eqns then I else
+    let
+      val conds = map #cond disc_eqns;
+      val fun_args = #fun_args (hd disc_eqns);
+      val args =
+        if length ctr_specs = 1 then []
+        else if length disc_eqns = length ctr_specs then
+          fst (split_last conds)
+        else if length disc_eqns = length ctr_specs - 1 then
+          let val n = 0 upto length ctr_specs - 1
+              |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in
+            if n = length ctr_specs - 1 then
+              conds
+            else
+              split_last conds
+              ||> HOLogic.mk_not
+              |>> chop n
+              |> (fn ((l, r), x) => l @ (x :: r))
+          end
+        else
+          0 upto length ctr_specs - 1
+          |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
+            |> Option.map #cond
+            |> the_default undef_const)
+          |> fst o split_last;
+    in
+      (* FIXME deal with #preds above *)
+      (map_filter #pred ctr_specs, args)
+      |-> fold2 (fn idx => fn t => nth_map idx
+        (K (subst_bounds (List.rev fun_args, t)
+          |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args))))
+    end;
 
 fun build_corec_arg_no_call sel_eqns sel = find_first (equal sel o #sel) sel_eqns
-  |> try (fn SOME sel_eqn => (#fun_args sel_eqn |> map dest_Free, #rhs_term sel_eqn))
+  |> try (fn SOME sel_eqn => (#fun_args sel_eqn, #rhs_term sel_eqn))
   |> the_default ([], undef_const)
-  |-> abs_tuple oo fold_rev absfree;
+  |-> abs_tuple;
 
 fun build_corec_arg_direct_call lthy has_call sel_eqns sel =
   let
     val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns
-
-    fun rewrite U T t =
+    fun rewrite is_end U T t =
       if U = @{typ bool} then @{term True} |> has_call t ? K @{term False} (* stop? *)
-      else if T = U = has_call t then undef_const
-      else if T = U then t (* end *)
+      else if is_end = has_call t then undef_const
+      else if is_end then t (* end *)
       else HOLogic.mk_tuple (snd (strip_comb t)); (* continue *)
-    fun massage rhs_term t =
-      massage_direct_corec_call lthy has_call rewrite [] (body_type (fastype_of t)) rhs_term;
-    val abstract = abs_tuple oo fold_rev absfree o map dest_Free;
+    fun massage rhs_term is_end t = massage_direct_corec_call
+      lthy has_call (rewrite is_end) [] (range_type (fastype_of t)) rhs_term;
   in
-    if is_none maybe_sel_eqn then I else
-      massage (#rhs_term (the maybe_sel_eqn)) #> abstract (#fun_args (the maybe_sel_eqn))
+    if is_none maybe_sel_eqn then K I else
+      abs_tuple (#fun_args (the maybe_sel_eqn)) oo massage (#rhs_term (the maybe_sel_eqn))
   end;
 
 fun build_corec_arg_indirect_call sel_eqns sel =
@@ -614,7 +609,7 @@
           (build_corec_arg_no_call sel_eqns sel |> K)) no_calls'
         #> fold (fn (sel, (q, g, h)) =>
           let val f = build_corec_arg_direct_call lthy has_call sel_eqns sel in
-            nth_map h f o nth_map g f o nth_map q f end) direct_calls'
+            nth_map h (f false) o nth_map g (f true) o nth_map q (f true) end) direct_calls'
         #> fold (fn (sel, n) => nth_map n
           (build_corec_arg_indirect_call sel_eqns sel |> K)) indirect_calls'
       end
@@ -651,24 +646,25 @@
       |> fold2 build_corec_args_discs disc_eqnss ctr_specss
       |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
 
+    fun currys Ts t = if length Ts <= 1 then t else
+      t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
+        (length Ts - 1 downto 0 |> map Bound)
+      |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts;
+
 val _ = tracing ("corecursor arguments:\n    \<cdot> " ^
  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) corec_args));
 
     fun uneq_pairs_rev xs = xs (* FIXME \<exists>? *)
       |> these o try (split_last #> (fn (ys, y) => uneq_pairs_rev ys @ map (pair y) ys));
     val proof_obligations = if sequential then [] else
-      maps (uneq_pairs_rev o map #cond) disc_eqnss
-      |> map (fn (x, y) => ((strip_abs_body x, strip_abs_body y), strip_abs_vars x))
-      |> map (apfst (apsnd HOLogic.mk_not #> pairself HOLogic.mk_Trueprop
-        #> apfst (curry (op $) @{const ==>}) #> (op $)))
-      |> map (fn (t, abs_vars) => fold_rev (fn (v, T) => fn u =>
-          Const (@{const_name all}, (T --> @{typ prop}) --> @{typ prop}) $
-            Abs (v, T, u)) abs_vars t);
+      maps (uneq_pairs_rev o map (fn {fun_args, cond, ...} => (fun_args, cond))) disc_eqnss
+      |> map (fn ((fun_args, x), (_, y)) => [x, HOLogic.mk_not y]
+        |> map (HOLogic.mk_Trueprop o curry subst_bounds (List.rev fun_args))
+        |> curry list_comb @{const ==>});
 
-    fun currys Ts t = if length Ts <= 1 then t else
-      t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
-        (length Ts - 1 downto 0 |> map Bound)
-      |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts;
+val _ = tracing ("proof obligations:\n    \<cdot> " ^
+ space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) proof_obligations));
+
   in
     map (list_comb o rpair corec_args) corecs
     |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss