handle direct corecursion
authorpanny
Mon, 02 Sep 2013 15:13:00 +0200
changeset 53360 7ffc4a746a73
parent 53359 ef65d5ee60cf
child 53363 f6629734dd2b
child 53364 a4fff0c0599c
handle direct corecursion
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Mon Sep 02 11:03:02 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Mon Sep 02 15:13:00 2013 +0200
@@ -499,7 +499,7 @@
     val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args)
       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
 
-val _ = warning ("reduced\n    " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n    \<cdot> " ^
+val _ = tracing ("reduced\n    " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n    \<cdot> " ^
  (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n    \<cdot> ")) "" ^
  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));
 
@@ -570,45 +570,57 @@
     fold2 (fn idx => nth_map idx o K o abs_tuple) (map_filter #pred ctr_specs) args'
   end;
 
-fun build_corec_args_sel all_sel_eqns ctr_spec =
+fun build_corec_arg_no_call sel_eqns sel = find_first (equal sel o #sel) sel_eqns
+  |> try (fn SOME sel_eqn => (#fun_args sel_eqn |> map dest_Free, #rhs_term sel_eqn))
+  |> the_default ([], undef_const)
+  |-> abs_tuple oo fold_rev absfree;
+
+fun build_corec_arg_direct_call lthy has_call sel_eqns sel =
+  let
+    val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns
+
+    fun rewrite U T t =
+      if U = @{typ bool} then @{term True} |> has_call t ? K @{term False} (* stop? *)
+      else if T = U = has_call t then undef_const
+      else if T = U then t (* end *)
+      else HOLogic.mk_tuple (snd (strip_comb t)); (* continue *)
+    fun massage rhs_term t =
+      massage_direct_corec_call lthy has_call rewrite [] (body_type (fastype_of t)) rhs_term;
+    val abstract = abs_tuple oo fold_rev absfree o map dest_Free;
+  in
+    if is_none maybe_sel_eqn then I else
+      massage (#rhs_term (the maybe_sel_eqn)) #> abstract (#fun_args (the maybe_sel_eqn))
+  end;
+
+fun build_corec_arg_indirect_call sel_eqns sel =
+  primrec_error "indirect corecursion not implemented yet";
+
+fun build_corec_args_sel lthy has_call all_sel_eqns ctr_spec =
   let val sel_eqns = filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns in
     if null sel_eqns then I else
       let
         val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec;
 
-val _ = warning ("sels / calls:\n    \<cdot> " ^ space_implode "\n    \<cdot> " (map ((op ^) o
+val _ = tracing ("sels / calls:\n    \<cdot> " ^ space_implode "\n    \<cdot> " (map ((op ^) o
  apfst (Syntax.string_of_term @{context}) o apsnd (curry (op ^) " / " o @{make_string}))
   (sel_call_list)));
 
-        (* FIXME get rid of dummy_no_calls' *)
-        val dummy_no_calls' = map_filter (try (apsnd (fn Dummy_No_Corec n => n))) sel_call_list;
         val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list;
         val direct_calls' = map_filter (try (apsnd (fn Direct_Corec n => n))) sel_call_list;
         val indirect_calls' = map_filter (try (apsnd (fn Indirect_Corec n => n))) sel_call_list;
-
-        fun build_arg_no_call sel = find_first (equal sel o #sel) sel_eqns |> #rhs_term o the;
-        fun build_arg_direct_call sel = primrec_error "not implemented yet";
-        fun build_arg_indirect_call sel = primrec_error "not implemented yet";
-
-        val update_args = I
-          #> fold (fn (sel, rec_arg_idx) => nth_map rec_arg_idx
-            (build_arg_no_call sel |> K)) no_calls'
-          #> fold (fn (sel, rec_arg_idx) => nth_map rec_arg_idx
-            (build_arg_indirect_call sel |> K)) indirect_calls'
-          #> fold (fn (sel, (q_idx, g_idx, h_idx)) =>
-            let val (q, g, h) = build_arg_indirect_call sel in
-              nth_map q_idx (K q) o nth_map g_idx (K g) o nth_map h_idx (K h) end) direct_calls';
-  
-        val arg_idxs = maps (fn (_, (x, y, z)) => [x, y, z]) direct_calls' @
-            maps (map snd) [dummy_no_calls', no_calls', indirect_calls'];
-        val abs_args = fold (fn idx => nth_map idx
-          (abs_tuple o fold_rev absfree (sel_eqns |> #fun_args o hd |> map dest_Free))) arg_idxs;
       in
-        abs_args o update_args
+        I
+        #> fold (fn (sel, n) => nth_map n
+          (build_corec_arg_no_call sel_eqns sel |> K)) no_calls'
+        #> fold (fn (sel, (q, g, h)) =>
+          let val f = build_corec_arg_direct_call lthy has_call sel_eqns sel in
+            nth_map h f o nth_map g f o nth_map q f end) direct_calls'
+        #> fold (fn (sel, n) => nth_map n
+          (build_corec_arg_indirect_call sel_eqns sel |> K)) indirect_calls'
       end
   end;
 
-fun co_build_defs lthy sequential bs mxs arg_Tss fun_name_corec_spec_list eqns_data =
+fun co_build_defs lthy sequential bs mxs has_call arg_Tss fun_name_corec_spec_list eqns_data =
   let
     val fun_names = map Binding.name_of bs;
 
@@ -622,24 +634,24 @@
         primrec_error_eqns "excess discriminator equations in definition"
           (maps (fn t => filter (equal (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end);
 
-val _ = warning ("disc_eqnss:\n    \<cdot> " ^ space_implode "\n    \<cdot> " (map @{make_string} disc_eqnss));
+(*val _ = tracing ("disc_eqnss:\n    \<cdot> " ^ space_implode "\n    \<cdot> " (map @{make_string} disc_eqnss));*)
 
     val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data
       |> partition_eq ((op =) o pairself #fun_name)
       |> finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names |> fst
       |> map (flat o snd);
 
-val _ = warning ("sel_eqnss:\n    \<cdot> " ^ space_implode "\n    \<cdot> " (map @{make_string} sel_eqnss));
+(*val _ = tracing ("sel_eqnss:\n    \<cdot> " ^ space_implode "\n    \<cdot> " (map @{make_string} sel_eqnss));*)
 
     val corecs = map (#corec o snd) fun_name_corec_spec_list;
     val ctr_specss = map (#ctr_specs o snd) fun_name_corec_spec_list;
-    val n_args = fold (curry (op +)) (map (K 1) (maps (map_filter #pred) ctr_specss) @
-      map (fn Direct_Corec _ => 3 | _ => 1) (maps (maps #calls) ctr_specss)) 0;
-    val corec_args = replicate n_args undef_const
+    val corec_args = hd corecs
+      |> fst o split_last o binder_types o fastype_of
+      |> map (Const o pair @{const_name undefined})
       |> fold2 build_corec_args_discs disc_eqnss ctr_specss
-      |> fold2 (fn sel_eqns => fold (build_corec_args_sel sel_eqns)) sel_eqnss ctr_specss;
+      |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
 
-val _ = warning ("corecursor arguments:\n    \<cdot> " ^
+val _ = tracing ("corecursor arguments:\n    \<cdot> " ^
  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) corec_args));
 
     fun uneq_pairs_rev xs = xs (* FIXME \<exists>? *)
@@ -685,16 +697,17 @@
     val fun_names = map Binding.name_of bs;
 
     val fun_name_corec_spec_list = (fun_names ~~ res_Ts, corec_specs)
-      |> uncurry (finds (fn ((v, T), {corec, ...}) => T = body_type (fastype_of corec))) |> fst
+      |> uncurry (finds (fn ((_, T), {corec, ...}) => T = body_type (fastype_of corec))) |> fst
       |> map (apfst fst #> apsnd the_single); (*###*)
 
     val (eqns_data, _) =
       fold_map (co_dissect_eqn sequential fun_name_corec_spec_list) (map snd specs) []
       |>> flat;
 
+    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
+    val arg_Tss = map (binder_types o snd o fst) fixes;
     val (defs, proof_obligations) =
-      co_build_defs lthy' sequential bs mxs (map (binder_types o snd o fst) fixes)
-        fun_name_corec_spec_list eqns_data;
+      co_build_defs lthy' sequential bs mxs has_call arg_Tss fun_name_corec_spec_list eqns_data;
   in
     lthy'
     |> fold_map Local_Theory.define defs |> snd