src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 53341 63015d035301
parent 53335 585b2fee55e5
child 53350 17632ef6cfe8
child 53352 43a1cc050943
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sat Aug 31 00:40:21 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sat Aug 31 18:18:33 2013 +0200
@@ -29,6 +29,8 @@
 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
 
 fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
+fun abs_tuple t = if try (fst o dest_Const) t = SOME @{const_name undefined} then t else
+  strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
 
 val simp_attrs = @{attributes [simp]};
 
@@ -398,23 +400,23 @@
 
 (* Primcorec *)
 
-type co_eqn_data_dtr_disc = {
+type co_eqn_data_disc = {
   fun_name: string,
-  ctr_no: int,
+  ctr_no: int, (*###*)
   cond: term,
   user_eqn: term
 };
-type co_eqn_data_dtr_sel = {
+type co_eqn_data_sel = {
   fun_name: string,
-  ctr_no: int,
-  sel_no: int,
+  ctr: term,
+  sel: term,
   fun_args: term list,
   rhs_term: term,
   user_eqn: term
 };
 datatype co_eqn_data =
-  Dtr_Disc of co_eqn_data_dtr_disc |
-  Dtr_Sel of co_eqn_data_dtr_sel
+  Disc of co_eqn_data_disc |
+  Sel of co_eqn_data_sel;
 
 fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
   let
@@ -476,7 +478,7 @@
       then (fun_name, cond) :: filter (not_equal fun_name o fst) matched_conds_ps
       else (fun_name, matched_cond) :: matched_conds_ps;
   in
-    (Dtr_Disc {
+    (Disc {
       fun_name = fun_name,
       ctr_no = ctr_no,
       cond = cond,
@@ -495,15 +497,14 @@
         primrec_error_eqn "malformed selector argument in left-hand side" eqn;
     val corec_spec = the (AList.lookup (op =) fun_name_corec_spec_list fun_name)
       handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn;
-    val ((ctr_spec, ctr_no), sel) = #ctr_specs corec_spec
+    val (ctr_spec, sel) = #ctr_specs corec_spec
       |> the o get_index (try (the o find_first (equal sel) o #sels))
-      |>> `(nth (#ctr_specs corec_spec));
-    val sel_no = find_index (equal sel) (#sels ctr_spec);
+      |>> nth (#ctr_specs corec_spec);
   in
-    Dtr_Sel {
+    Sel {
       fun_name = fun_name,
-      ctr_no = ctr_no,
-      sel_no = sel_no,
+      ctr = #ctr ctr_spec,
+      sel = sel,
       fun_args = fun_args,
       rhs_term = rhs,
       user_eqn = eqn'
@@ -518,21 +519,24 @@
     val (ctr, ctr_args) = strip_comb rhs;
     val ctr_spec = the (find_first (equal ctr o #ctr) (#ctr_specs corec_spec))
       handle Option.Option => primrec_error_eqn "not a constructor" ctr;
+
     val disc_imp_rhs = betapply (#disc ctr_spec, lhs);
-    val (eqn_data_disc, matched_conds_ps') = co_dissect_eqn_disc
-        sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps;
+    val (maybe_eqn_data_disc, matched_conds_ps') = if length (#ctr_specs corec_spec) = 1
+      then (NONE, matched_conds_ps)
+      else apfst SOME (co_dissect_eqn_disc
+          sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps);
 
     val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args)
       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
 
 val _ = warning ("reduced\n    " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n    \<cdot> " ^
- Syntax.string_of_term @{context} disc_imp_rhs ^ "\n    \<cdot> " ^
+ (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n    \<cdot> ")) "" ^
  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));
 
     val eqns_data_sel =
-      map (co_dissect_eqn_sel fun_name_corec_spec_list @{const True}(*###*)) sel_imp_rhss;
+      map (co_dissect_eqn_sel fun_name_corec_spec_list eqn') sel_imp_rhss;
   in
-    (eqn_data_disc :: eqns_data_sel, matched_conds_ps')
+    (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds_ps')
   end;
 
 fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds_ps =
@@ -540,9 +544,8 @@
     val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev,
         strip_qnt_body @{const_name all} eqn')
         handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
-    val (imp_lhs', imp_rhs) =
-      (map HOLogic.dest_Trueprop (Logic.strip_imp_prems eqn),
-       HOLogic.dest_Trueprop (Logic.strip_imp_concl eqn));
+    val (imp_lhs', imp_rhs) = Logic.strip_horn eqn
+      |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop;
 
     val head = imp_rhs
       |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
@@ -568,10 +571,10 @@
       primrec_error_eqn "malformed function equation" eqn
   end;
 
-fun build_corec_args_discs ctr_specs disc_eqns =
+fun build_corec_args_discs disc_eqns ctr_specs =
   let
     val conds = map #cond disc_eqns;
-    val args =
+    val args' =
       if length ctr_specs = 1 then []
       else if length disc_eqns = length ctr_specs then
         fst (split_last conds)
@@ -592,33 +595,54 @@
           |> Option.map #cond
           |> the_default (Const (@{const_name undefined}, dummyT)))
         |> fst o split_last;
-    fun finish t =
-      let val n = length (fastype_of t |> binder_types) in
-        if t = Const (@{const_name undefined}, dummyT) then t
-        else if n = 0 then Abs (Name.uu_, @{typ unit}, t)
-        else if n = 1 then t
-        else Const (@{const_name prod_case}, dummyT) $ t
-      end;
   in
-    map finish args
+    (* FIXME: deal with #preds above *)
+    fold2 (fn idx => nth_map idx o K o abs_tuple) (map_filter #pred ctr_specs) args'
   end;
 
-fun build_corec_args_sel sel_eqns ctr_spec =
-  let
-    (* FIXME *)
-    val n_args = fold (curry (op +)) (map (fn Direct_Corec _ => 3 | _ => 1) (#calls ctr_spec)) 0;
-  in
-    replicate n_args (Const (@{const_name undefined}, dummyT))
+fun build_corec_args_sel all_sel_eqns ctr_spec =
+  let val sel_eqns = filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns in
+    if null sel_eqns then I else
+      let
+        val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec;
+
+val _ = warning ("sels / calls:\n    \<cdot> " ^ space_implode "\n    \<cdot> " (map ((op ^) o
+ apfst (Syntax.string_of_term @{context}) o apsnd (curry (op ^) " / " o @{make_string}))
+  (sel_call_list)));
+
+        (* FIXME get rid of dummy_no_calls' *)
+        val dummy_no_calls' = map_filter (try (apsnd (fn Dummy_No_Corec n => n))) sel_call_list;
+        val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list;
+        val direct_calls' = map_filter (try (apsnd (fn Direct_Corec n => n))) sel_call_list;
+        val indirect_calls' = map_filter (try (apsnd (fn Indirect_Corec n => n))) sel_call_list;
+
+        fun build_arg_no_call sel = find_first (equal sel o #sel) sel_eqns |> #rhs_term o the;
+        fun build_arg_direct_call sel = primrec_error "not implemented yet";
+        fun build_arg_indirect_call sel = primrec_error "not implemented yet";
+
+        val update_args = I
+          #> fold (fn (sel, rec_arg_idx) => nth_map rec_arg_idx
+            (build_arg_no_call sel |> K)) no_calls'
+          #> fold (fn (sel, rec_arg_idx) => nth_map rec_arg_idx
+            (build_arg_indirect_call sel |> K)) indirect_calls'
+          #> fold (fn (sel, (q_idx, g_idx, h_idx)) =>
+            let val (q, g, h) = build_arg_indirect_call sel in
+              nth_map q_idx (K q) o nth_map g_idx (K g) o nth_map h_idx (K h) end) direct_calls';
+  
+        val arg_idxs = maps (fn (_, (x, y, z)) => [x, y, z]) direct_calls' @
+            maps (map snd) [dummy_no_calls', no_calls', indirect_calls'];
+        val abs_args = fold (fn idx => nth_map idx
+          (abs_tuple o fold_rev absfree (sel_eqns |> #fun_args o hd |> map dest_Free))) arg_idxs;
+      in
+        abs_args o update_args
+      end
   end;
 
 fun co_build_defs lthy sequential bs arg_Tss fun_name_corec_spec_list eqns_data =
   let
     val fun_names = map Binding.name_of bs;
 
-(*    fun group _ [] = [] (* FIXME \<exists>? *)
-      | group eq (x :: xs) =
-        let val (xs', ys) = List.partition (eq x) xs in (x :: xs') :: group eq ys end;*)
-    val disc_eqnss = map_filter (try (fn Dtr_Disc x => x)) eqns_data
+    val disc_eqnss = map_filter (try (fn Disc x => x)) eqns_data
       |> partition_eq ((op =) o pairself #fun_name)
       |> finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names |> fst
       |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd);
@@ -630,20 +654,20 @@
 
 val _ = warning ("disc_eqnss:\n    \<cdot> " ^ space_implode "\n    \<cdot> " (map @{make_string} disc_eqnss));
 
-    val sel_eqnss = map_filter (try (fn Dtr_Sel x => x)) eqns_data
+    val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data
       |> partition_eq ((op =) o pairself #fun_name)
       |> finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names |> fst
-      |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd);
+      |> map (flat o snd);
 
 val _ = warning ("sel_eqnss:\n    \<cdot> " ^ space_implode "\n    \<cdot> " (map @{make_string} sel_eqnss));
 
-    fun splice (xs :: xss) (ys :: yss) = xs @ ys @ splice xss yss (* FIXME \<exists>? *)
-      | splice xss yss = flat xss @ flat yss;
     val corecs = map (#corec o snd) fun_name_corec_spec_list;
-    val corec_args = (map snd fun_name_corec_spec_list ~~ disc_eqnss ~~ sel_eqnss)
-      |> maps (fn (({ctr_specs, ...}, disc_eqns), sel_eqns) =>
-        splice (build_corec_args_discs ctr_specs disc_eqns |> map single)
-          (map (build_corec_args_sel sel_eqns) ctr_specs));
+    val ctr_specss = map (#ctr_specs o snd) fun_name_corec_spec_list;
+    val n_args = fold (curry (op +)) (map (K 1) (maps (map_filter #pred) ctr_specss) @
+      map (fn Direct_Corec _ => 3 | _ => 1) (maps (maps #calls) ctr_specss)) 0;
+    val corec_args = replicate n_args (Const (@{const_name undefined}, dummyT))
+      |> fold2 build_corec_args_discs disc_eqnss ctr_specss
+      |> fold2 (fn sel_eqns => fold (build_corec_args_sel sel_eqns)) sel_eqnss ctr_specss;
 
 val _ = warning ("corecursor arguments:\n    \<cdot> " ^
  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) corec_args));