properly fold over branches
authorblanchet
Wed, 25 Sep 2013 12:00:22 +0200
changeset 53870 5d45882b4f36
parent 53869 a6f6df7f01cf
child 53871 a1a52423601f
properly fold over branches
src/HOL/BNF/Tools/bnf_ctr_sugar.ML
src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
--- a/src/HOL/BNF/Tools/bnf_ctr_sugar.ML	Wed Sep 25 11:56:33 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_ctr_sugar.ML	Wed Sep 25 12:00:22 2013 +0200
@@ -40,9 +40,9 @@
   val mk_ctr: typ list -> term -> term
   val mk_case: typ list -> typ -> term -> term
   val mk_disc_or_sel: typ list -> term -> term
-
   val name_of_ctr: term -> string
   val name_of_disc: term -> string
+  val dest_case: Proof.context -> string -> typ list -> term -> (term list * term list) option
 
   val wrap_free_constructors: ({prems: thm list, context: Proof.context} -> tactic) list list ->
     (((bool * bool) * term list) * binding) *
@@ -215,6 +215,28 @@
 
 val base_name_of_ctr = Long_Name.base_name o name_of_ctr;
 
+fun dest_case ctxt s Ts t =
+  (case Term.strip_comb t of
+    (Const (c, _), args as _ :: _) =>
+    (case ctr_sugar_of ctxt s of
+      SOME {casex = Const (case_name, _), discs = discs0, selss = selss0, ...} =>
+      if case_name = c then
+        let
+          val n = length discs0;
+          val (branches, obj :: leftovers) = chop n args;
+          val discs = map (mk_disc_or_sel Ts) discs0;
+          val selss = map (map (mk_disc_or_sel Ts)) selss0;
+          val conds = map (rapp obj) discs;
+          val branch_argss = map (fn sels => map (rapp obj) sels @ leftovers) selss;
+          val branches' = map2 (curry Term.betapplys) branches branch_argss;
+        in
+          SOME (conds, branches')
+        end
+      else
+        NONE
+    | _ => NONE)
+  | _ => NONE);
+
 fun eta_expand_arg xs f_xs = fold_rev Term.lambda xs f_xs;
 
 fun prepare_wrap_free_constructors prep_term ((((no_discs_sels, rep_compat), raw_ctrs),
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Wed Sep 25 11:56:33 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Wed Sep 25 12:00:22 2013 +0200
@@ -174,7 +174,7 @@
             val map' = mk_map (length fs) Us ran_Ts map0;
             val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
           in
-            list_comb (map', fs')
+            Term.list_comb (map', fs')
           end
         | NONE => raise AINT_NO_MAP t)
       | massage_map _ _ t = raise AINT_NO_MAP t
@@ -196,18 +196,21 @@
     massage_call
   end;
 
-fun case_of ctxt = ctr_sugar_of ctxt #> Option.map (fst o dest_Const o #casex);
-
 fun fold_rev_let_if_case ctxt f bound_Ts =
   let
+    val thy = Proof_Context.theory_of ctxt;
+
     fun fld t =
       (case Term.strip_comb t of
         (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1))
       | (Const (@{const_name If}, _), _ :: branches) => fold_rev fld branches
       | (Const (c, _), args as _ :: _) =>
-        let val (branches, obj) = split_last args in
-          (case fastype_of1 (bound_Ts, obj) of
-            Type (T_name, _) => if case_of ctxt T_name = SOME c then fold_rev fld branches else f t
+        let val n = num_binder_types (Sign.the_const_type thy c) in
+          (case fastype_of1 (bound_Ts, nth args (n - 1)) of
+            Type (s, Ts) =>
+            (case dest_case ctxt s Ts t of
+              NONE => f t
+            | SOME (conds, branches) => fold_rev fld branches)
           | _ => f t)
         end
       | _ => f t)
@@ -215,6 +218,8 @@
     fld
   end;
 
+fun case_of ctxt = ctr_sugar_of ctxt #> Option.map (fst o dest_Const o #casex);
+
 fun massage_let_if_case ctxt has_call massage_leaf bound_Ts U =
   let
     val typof = curry fastype_of1 bound_Ts;
@@ -224,7 +229,7 @@
       (case Term.strip_comb t of
         (Const (@{const_name Let}, _), [arg1, arg2]) => massage_rec (betapply (arg2, arg1))
       | (Const (@{const_name If}, _), obj :: branches) =>
-        list_comb (If_const U $ tap check_obj obj, map massage_rec branches)
+        Term.list_comb (If_const U $ tap check_obj obj, map massage_rec branches)
       | (Const (c, _), args as _ :: _) =>
         let val (branches, obj) = split_last args in
           (case fastype_of1 (bound_Ts, obj) of
@@ -234,7 +239,7 @@
                 val branches' = map massage_rec branches;
                 val casex' = Const (c, map typof branches' ---> typof obj);
               in
-                list_comb (casex', branches') $ tap check_obj obj
+                Term.list_comb (casex', branches') $ tap check_obj obj
               end
             else
               massage_leaf t
@@ -269,7 +274,7 @@
             val map' = mk_map (length fs) dom_Ts Us map0;
             val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
           in
-            list_comb (map', fs')
+            Term.list_comb (map', fs')
           end
         | NONE => raise AINT_NO_MAP t)
       | massage_map _ _ t = raise AINT_NO_MAP t
@@ -288,7 +293,8 @@
             (case try (dest_ctr ctxt s) t of
               SOME (f, args) =>
               let val f' = mk_ctr Us f in
-                list_comb (f', map3 massage_call (binder_types (typof f')) (map typof args) args)
+                Term.list_comb (f',
+                  map3 massage_call (binder_types (typof f')) (map typof args) args)
               end
             | NONE =>
               (case t of
@@ -309,7 +315,8 @@
 
 fun expand_ctr_term ctxt s Ts t =
   (case ctr_sugar_of ctxt s of
-    SOME {ctrs, casex, ...} => list_comb (mk_case Ts (Type (s, Ts)) casex, map (mk_ctr Ts) ctrs) $ t
+    SOME {ctrs, casex, ...} =>
+    Term.list_comb (mk_case Ts (Type (s, Ts)) casex, map (mk_ctr Ts) ctrs) $ t
   | NONE => raise Fail "expand_ctr_term");
 
 fun expand_corec_code_rhs ctxt has_call bound_Ts t =