more robust n2m w.r.t. 'let's
authorblanchet
Mon, 04 Nov 2013 14:46:38 +0100
changeset 54243 a596292be9a8
parent 54242 99ef8036fb3d
child 54244 0753e8866ac8
more robust n2m w.r.t. 'let's
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Mon Nov 04 12:40:28 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Mon Nov 04 14:46:38 2013 +0100
@@ -39,7 +39,6 @@
     'a list
   val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term
   val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
-  val dest_map: Proof.context -> string -> term -> term * term list
 
   type lfp_sugar_thms =
     (thm list * thm * Args.src list)
@@ -285,27 +284,6 @@
   | unzip_corecT _ (Type (@{type_name sum}, Ts)) = Ts
   | unzip_corecT _ T = [T];
 
-val dummy_var_name = "?f"
-
-fun mk_map_pattern ctxt s =
-  let
-    val bnf = the (bnf_of ctxt s);
-    val mapx = map_of_bnf bnf;
-    val live = live_of_bnf bnf;
-    val (f_Ts, _) = strip_typeN live (fastype_of mapx);
-    val fs = map_index (fn (i, T) => Var ((dummy_var_name, i), T)) f_Ts;
-  in
-    (mapx, betapplys (mapx, fs))
-  end;
-
-fun dest_map ctxt s call =
-  let
-    val (map0, pat) = mk_map_pattern ctxt s;
-    val (_, tenv) = fo_match ctxt call pat;
-  in
-    (map0, Vartab.fold_rev (fn (_, (_, f)) => cons f) tenv [])
-  end;
-
 fun liveness_of_fp_bnf n bnf =
   (case T_of_bnf bnf of
     Type (_, Ts) => map (not o member (op =) (deads_of_bnf bnf)) Ts
--- a/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Mon Nov 04 12:40:28 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Mon Nov 04 14:46:38 2013 +0100
@@ -7,6 +7,9 @@
 
 signature BNF_FP_N2M_SUGAR =
 sig
+  val unfold_let: term -> term
+  val dest_map: Proof.context -> string -> term -> term * term list
+
   val mutualize_fp_sugars: bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
     (term -> int list) -> term list list list list -> BNF_FP_Def_Sugar.fp_sugar list ->
     local_theory ->
@@ -34,24 +37,61 @@
 
 val n2mN = "n2m_"
 
-fun dest_applied_map_or_ctr ctxt s (t as t1 $ _) =
-  (case try (dest_map ctxt s) t1 of
-    SOME res => res
-  | NONE =>
-    let
-      val thy = Proof_Context.theory_of ctxt;
-      val map_thms = of_fp_sugar #mapss (the (fp_sugar_of ctxt s))
-      val map_thms' = map (fn thm => thm RS sym RS eq_reflection) map_thms;
-      val t' = Raw_Simplifier.rewrite_term thy map_thms' [] t;
-    in
-      if t aconv t' then raise Fail "dest_applied_map_or_ctr"
-      else dest_map ctxt s (fst (dest_comb t'))
-    end);
+fun unfold_let (Const (@{const_name Let}, _) $ arg1 $ arg2) = unfold_let (betapply (arg2, arg1))
+  | unfold_let (Const (@{const_name prod_case}, _) $ t) =
+    (case unfold_let t of
+      t' as Abs (s1, T1, Abs (s2, T2, _)) =>
+      let
+        val x = (s1 ^ s2, Term.maxidx_of_term t + 1);
+        val v = Var (x, HOLogic.mk_prodT (T1, T2));
+      in
+        lambda v (unfold_let (betapplys (t', [HOLogic.mk_fst v, HOLogic.mk_snd v])))
+      end
+    | _ => t)
+  | unfold_let (t $ u) = betapply (unfold_let t, unfold_let u)
+  | unfold_let (Abs (s, T, t)) = Abs (s, T, unfold_let t)
+  | unfold_let t = t;
+
+val dummy_var_name = "?f"
+
+fun mk_map_pattern ctxt s =
+  let
+    val bnf = the (bnf_of ctxt s);
+    val mapx = map_of_bnf bnf;
+    val live = live_of_bnf bnf;
+    val (f_Ts, _) = strip_typeN live (fastype_of mapx);
+    val fs = map_index (fn (i, T) => Var ((dummy_var_name, i), T)) f_Ts;
+  in
+    (mapx, betapplys (mapx, fs))
+  end;
+
+fun dest_map ctxt s call =
+  let
+    val (map0, pat) = mk_map_pattern ctxt s;
+    val (_, tenv) = fo_match ctxt call pat;
+  in
+    (map0, Vartab.fold_rev (fn (_, (_, f)) => cons f) tenv [])
+  end;
+
+fun dest_abs_or_applied_map_or_ctr _ _ (Abs (_, _, t)) = (Term.dummy, [t])
+  | dest_abs_or_applied_map_or_ctr ctxt s (t as t1 $ _) =
+    (case try (dest_map ctxt s) t1 of
+      SOME res => res
+    | NONE =>
+      let
+        val thy = Proof_Context.theory_of ctxt;
+        val map_thms = of_fp_sugar #mapss (the (fp_sugar_of ctxt s))
+        val map_thms' = map (fn thm => thm RS sym RS eq_reflection) map_thms;
+        val t' = Raw_Simplifier.rewrite_term thy map_thms' [] t;
+      in
+        if t aconv t' then raise Fail "dest_applied_map_or_ctr"
+        else dest_map ctxt s (fst (dest_comb t'))
+      end);
 
 (* TODO: test with sort constraints on As *)
 (* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
    as deads? *)
-fun mutualize_fp_sugars has_nested fp bs fpTs get_indices callssss fp_sugars0 no_defs_lthy0 =
+fun mutualize_fp_sugars has_nested fp bs fpTs _ callssss fp_sugars0 no_defs_lthy0 =
   if has_nested orelse has_duplicates (op =) fpTs then
     let
       val thy = Proof_Context.theory_of no_defs_lthy0;
@@ -99,7 +139,7 @@
       and freeze_fp calls (T as Type (s, Ts)) =
           (case map_filter (try (snd o dest_map no_defs_lthy s)) calls of
             [] =>
-            (case map_filter (try (snd o dest_applied_map_or_ctr no_defs_lthy s)) calls of
+            (case map_filter (try (snd o dest_abs_or_applied_map_or_ctr no_defs_lthy s)) calls of
               [] => freeze_fp_default T
             | callss => freeze_fp_map callss s Ts)
           | callss => freeze_fp_map callss s Ts)
@@ -192,7 +232,7 @@
     fun do_ctr ctr =
       (case AList.lookup Term.aconv_untyped callsss ctr of
         NONE => replicate (num_binder_types (fastype_of ctr)) []
-      | SOME callss => map (map Envir.beta_eta_contract) callss);
+      | SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss);
   in
     map do_ctr ctrs
   end;
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Mon Nov 04 12:40:28 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Mon Nov 04 14:46:38 2013 +0100
@@ -31,6 +31,7 @@
 
 open BNF_Util
 open BNF_FP_Util
+open BNF_FP_N2M_Sugar
 open BNF_FP_Rec_Sugar_Util
 open BNF_FP_Rec_Sugar_Tactics
 
@@ -281,9 +282,6 @@
       bs mxs
   end;
 
-fun massage_comp ctxt has_call bound_Ts t = (* FIXME unused *)
-  massage_nested_corec_call ctxt has_call (K (K (K I))) bound_Ts (fastype_of1 (bound_Ts, t)) t;
-
 fun find_rec_calls ctxt has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
   let
     fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
@@ -802,11 +800,11 @@
       chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
     end;
 
-fun find_corec_calls has_call basic_ctr_specs ({ctr, sel, rhs_term, ...} : coeqn_data_sel) =
+fun find_corec_calls ctxt has_call basic_ctr_specs ({ctr, sel, rhs_term, ...} : coeqn_data_sel) =
   let
     val sel_no = find_first (equal ctr o #ctr) basic_ctr_specs
       |> find_index (equal sel) o #sels o the;
-    fun find t = if has_call t then [t] else [];
+    fun find t = if has_call t then snd (fold_rev_corec_call ctxt (K cons) [] t []) else [];
   in
     find rhs_term
     |> K |> nth_map sel_no |> AList.map_entry (op =) ctr
@@ -830,7 +828,7 @@
       |> partition_eq ((op =) o pairself #fun_name)
       |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
       |> map (flat o snd)
-      |> map2 (fold o find_corec_calls has_call) basic_ctr_specss
+      |> map2 (fold o find_corec_calls lthy has_call) basic_ctr_specss
       |> map2 (curry (op |>)) (map (map (fn {ctr, sels, ...} =>
         (ctr, map (K []) sels))) basic_ctr_specss);
 
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Mon Nov 04 12:40:28 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Mon Nov 04 14:46:38 2013 +0100
@@ -66,12 +66,11 @@
 
   val massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
     typ list -> term -> term -> term -> term
-  val unfold_let: term -> term
   val massage_mutual_corec_call: Proof.context -> (term -> bool) -> (typ list -> term -> term) ->
     typ list -> term -> term
   val massage_nested_corec_call: Proof.context -> (term -> bool) ->
     (typ list -> typ -> typ -> term -> term) -> typ list -> typ -> term -> term
-  val fold_rev_corec_call:  Proof.context -> (term list -> term -> 'a -> 'a) -> typ list -> term ->
+  val fold_rev_corec_call: Proof.context -> (term list -> term -> 'a -> 'a) -> typ list -> term ->
     'a -> string list * 'a
   val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term
   val massage_corec_code_rhs: Proof.context -> (typ list -> term -> term list -> term) ->
@@ -305,20 +304,6 @@
     massage_call
   end;
 
-fun unfold_let (Const (@{const_name Let}, _) $ arg1 $ arg2) = unfold_let (betapply (arg2, arg1))
-  | unfold_let (Const (@{const_name prod_case}, _) $ t) =
-    (case unfold_let t of
-      t' as Abs (s1, T1, Abs (s2, T2, _)) =>
-      let
-        val x = (s1 ^ s2, Term.maxidx_of_term t + 1);
-        val v = Var (x, HOLogic.mk_prodT (T1, T2));
-      in
-        lambda v (unfold_let (betapplys (t', [HOLogic.mk_fst v, HOLogic.mk_snd v])))
-      end
-    | _ => t)
-  | unfold_let (t $ u) = betapply (unfold_let t, u)
-  | unfold_let t = t;
-
 fun fold_rev_let_if_case ctxt f bound_Ts t =
   let
     val thy = Proof_Context.theory_of ctxt;