process code-style inputs
authorpanny
Fri, 04 Oct 2013 18:27:07 +0200
changeset 54065 e30e63d05e58
parent 54064 183cfce3f827
child 54066 4a7aa85b6b47
process code-style inputs
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Fri Oct 04 17:00:35 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Fri Oct 04 18:27:07 2013 +0200
@@ -34,10 +34,10 @@
 open BNF_FP_Rec_Sugar_Util
 open BNF_FP_Rec_Sugar_Tactics
 
-val codeN = "code"
-val ctrN = "ctr"
-val discN = "disc"
-val selN = "sel"
+val codeN = "code";
+val ctrN = "ctr";
+val discN = "disc";
+val selN = "sel";
 
 val nitpick_attrs = @{attributes [nitpick_simp]};
 val simp_attrs = @{attributes [simp]};
@@ -472,7 +472,7 @@
   Disc of co_eqn_data_disc |
   Sel of co_eqn_data_sel;
 
-fun co_dissect_eqn_disc sequential fun_names (corec_specs : corec_spec list) prems' concl
+fun co_dissect_eqn_disc seq fun_names (corec_specs : corec_spec list) prems' concl
     matchedsss =
   let
     fun find_subterm p = let (* FIXME \<exists>? *)
@@ -507,11 +507,11 @@
     val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default [];
     val prems = map (abstract (List.rev fun_args)) prems';
     val real_prems =
-      (if catch_all orelse sequential then maps negate_disj matchedss else []) @
+      (if catch_all orelse seq then maps negate_disj matchedss else []) @
       (if catch_all then [] else prems);
 
     val matchedsss' = AList.delete (op =) fun_name matchedsss
-      |> cons (fun_name, if sequential then matchedss @ [prems] else matchedss @ [real_prems]);
+      |> cons (fun_name, if seq then matchedss @ [prems] else matchedss @ [real_prems]);
 
     val user_eqn =
       (real_prems, betapply (#disc (nth ctr_specs ctr_no), applied_fun))
@@ -560,49 +560,72 @@
     }
   end;
 
-fun co_dissect_eqn_ctr sequential fun_names (corec_specs : corec_spec list) eqn' imp_prems imp_rhs
+fun co_dissect_eqn_ctr seq fun_names (corec_specs : corec_spec list) eqn' prems concl
     matchedsss =
   let
-    val (lhs, rhs) = HOLogic.dest_eq imp_rhs;
+    val (lhs, rhs) = HOLogic.dest_eq concl;
     val fun_name = head_of lhs |> fst o dest_Free;
     val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
     val (ctr, ctr_args) = strip_comb rhs;
     val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs)
       handle Option.Option => primrec_error_eqn "not a constructor" ctr;
 
-    val disc_imp_rhs = betapply (disc, lhs);
+    val disc_concl = betapply (disc, lhs);
     val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1
       then (NONE, matchedsss)
       else apfst SOME (co_dissect_eqn_disc
-          sequential fun_names corec_specs imp_prems disc_imp_rhs matchedsss);
+          seq fun_names corec_specs prems disc_concl matchedsss);
 
-    val sel_imp_rhss = (sels ~~ ctr_args)
+    val sel_concls = (sels ~~ ctr_args)
       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
 
 (*
-val _ = tracing ("reduced\n    " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n    \<cdot> " ^
- (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n    \<cdot> ")) "" ^
- space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));
+val _ = tracing ("reduced\n    " ^ Syntax.string_of_term @{context} concl ^ "\nto\n    \<cdot> " ^
+ (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_concl ^ "\n    \<cdot> ")) "" ^
+ space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_concls));
 *)
 
-    val eqns_data_sel =
-      map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_imp_rhss;
+    val eqns_data_sel = map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_concls;
   in
     (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
   end;
 
-fun co_dissect_eqn sequential fun_names (corec_specs : corec_spec list) eqn' of_spec matchedsss =
+fun co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss =
+  let
+    val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
+    val fun_name = head_of lhs |> fst o dest_Free;
+    val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
+
+    val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
+        if member ((op =) o apsnd #ctr) ctr_specs ctr
+        then cons (ctr, cs)
+        else primrec_error_eqn "not a constructor" ctr) [] rhs' []
+      |> AList.group (op =);
+
+    val ctr_premss = map (single o mk_disjs o map mk_conjs o snd) cond_ctrs;
+    val ctr_concls = cond_ctrs |> map (fn (ctr, _) =>
+        binder_types (fastype_of ctr)
+        |> map_index (fn (n, T) => massage_corec_code_rhs lthy (fn _ => fn ctr' => fn args =>
+          if ctr' = ctr then nth args n else Const (@{const_name undefined}, T)) [] rhs')
+        |> curry list_comb ctr
+        |> curry HOLogic.mk_eq lhs);
+  in
+    fold_map2 (co_dissect_eqn_ctr false fun_names corec_specs eqn') ctr_premss ctr_concls matchedsss
+  end;
+
+fun co_dissect_eqn lthy has_call seq fun_names (corec_specs : corec_spec list) eqn' of_spec
+    matchedsss =
   let
     val eqn = drop_All eqn'
       handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
-    val (imp_prems, imp_rhs) = Logic.strip_horn eqn
+    val (prems, concl) = Logic.strip_horn eqn
       |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop;
 
-    val head = imp_rhs
+    val head = concl
       |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
       |> head_of;
 
-    val maybe_rhs = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
+    val maybe_rhs = concl |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
 
     val discs = maps #ctr_specs corec_specs |> map #disc;
     val sels = maps #ctr_specs corec_specs |> maps #sels;
@@ -611,12 +634,17 @@
     if member (op =) discs head orelse
       is_some maybe_rhs andalso
         member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
-      co_dissect_eqn_disc sequential fun_names corec_specs imp_prems imp_rhs matchedsss
+      co_dissect_eqn_disc seq fun_names corec_specs prems concl matchedsss
       |>> single
     else if member (op =) sels head then
-      ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec imp_rhs], matchedsss)
-    else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
-      co_dissect_eqn_ctr sequential fun_names corec_specs eqn' imp_prems imp_rhs matchedsss
+      ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec concl], matchedsss)
+    else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
+      member (op =) ctrs (head_of (the maybe_rhs)) then
+      co_dissect_eqn_ctr seq fun_names corec_specs eqn' prems concl matchedsss
+    else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
+      null prems then
+      co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss
+      |>> flat
     else
       primrec_error_eqn "malformed function equation" eqn
   end;
@@ -646,7 +674,7 @@
       fun rewrite_g _ t = if has_call t then undef_const else t;
       fun rewrite_h bound_Ts t =
         if has_call t then mk_tuple1 bound_Ts (snd (strip_comb t)) else undef_const;
-      fun massage f t = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
+      fun massage f _ = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
     in
       (massage rewrite_q,
        massage rewrite_g,
@@ -763,7 +791,7 @@
       chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
     end;
 
-fun add_primcorec simple sequential fixes specs of_specs lthy =
+fun add_primcorec simple seq fixes specs of_specs lthy =
   let
     val (bs, mxs) = map_split (apfst fst) fixes;
     val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
@@ -778,8 +806,10 @@
     val fun_names = map Binding.name_of bs;
     val corec_specs = take actual_nn corec_specs'; (*###*)
 
+    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
     val eqns_data =
-      fold_map2 (co_dissect_eqn sequential fun_names corec_specs) (map snd specs) of_specs []
+      fold_map2 (co_dissect_eqn lthy has_call seq fun_names corec_specs)
+        (map snd specs) of_specs []
       |> flat o fst;
 
     val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
@@ -796,14 +826,13 @@
       |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
       |> map (flat o snd);
 
-    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
     val arg_Tss = map (binder_types o snd o fst) fixes;
     val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss';
     val (defs, exclss') =
       co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;
 
     fun excl_tac (c, c', a) =
-      if a orelse c = c' orelse sequential then
+      if a orelse c = c' orelse seq then
         SOME (K (HEADGOAL (mk_primcorec_assumption_tac lthy [])))
       else if simple then
         SOME (K (auto_tac lthy))