thread through bound types
authorblanchet
Wed, 25 Sep 2013 16:43:46 +0200
changeset 53890 5f647a5bd46e
parent 53889 d1bd94eb5d0e
child 53891 27da6373a64f
thread through bound types
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Wed Sep 25 16:43:46 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Wed Sep 25 16:43:46 2013 +0200
@@ -577,9 +577,9 @@
     if is_none maybe_sel_eqn then (I, I, I) else
     let
       val {fun_args, rhs_term, ... } = the maybe_sel_eqn;
-      fun rewrite_q t = if has_call t then @{term False} else @{term True};
-      fun rewrite_g t = if has_call t then undef_const else t;
-      fun rewrite_h t = if has_call t then HOLogic.mk_tuple (snd (strip_comb t)) else undef_const;
+      fun rewrite_q _ t = if has_call t then @{term False} else @{term True};
+      fun rewrite_g _ t = if has_call t then undef_const else t;
+      fun rewrite_h _ t = if has_call t then HOLogic.mk_tuple (snd (strip_comb t)) else undef_const;
       fun massage f t = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
     in
       (massage rewrite_q,
@@ -604,8 +604,7 @@
       | rewrite _ U T t = if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t;
     fun massage NONE t = t
       | massage (SOME {fun_args, rhs_term, ...}) t =
-        massage_indirect_corec_call lthy has_call (rewrite []) []
-          (range_type (fastype_of t)) rhs_term
+        massage_indirect_corec_call lthy has_call rewrite [] (range_type (fastype_of t)) rhs_term
         |> abs_tuple fun_args;
   in
     massage maybe_sel_eqn
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Wed Sep 25 16:43:46 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Wed Sep 25 16:43:46 2013 +0200
@@ -60,13 +60,13 @@
 
   val massage_indirect_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
     typ list -> term -> term -> term -> term
-  val massage_direct_corec_call: Proof.context -> (term -> bool) -> (term -> term) -> typ list ->
-    term -> term
+  val massage_direct_corec_call: Proof.context -> (term -> bool) -> (typ list -> term -> term) ->
+    typ list -> term -> term
   val massage_indirect_corec_call: Proof.context -> (term -> bool) ->
-    (typ -> typ -> term -> term) -> typ list -> typ -> term -> term
+    (typ list -> typ -> typ -> term -> term) -> typ list -> typ -> term -> term
   val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term
-  val massage_corec_code_rhs: Proof.context -> (term -> term list -> term) -> typ list -> term ->
-    term
+  val massage_corec_code_rhs: Proof.context -> (typ list -> term -> term list -> term) ->
+    typ list -> term -> term
   val fold_rev_corec_code_rhs: Proof.context -> (term list -> term -> term list -> 'a -> 'a) ->
     typ list -> term -> 'a -> 'a
 
@@ -249,38 +249,40 @@
 
 fun case_of ctxt = ctr_sugar_of ctxt #> Option.map (fst o dest_Const o #casex);
 
-fun massage_let_if_case ctxt has_call massage_leaf bound_Ts =
+fun massage_let_if_case ctxt has_call massage_leaf =
   let
     val thy = Proof_Context.theory_of ctxt;
 
-    val typof = curry fastype_of1 bound_Ts; (*###*)
     fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else ();
 
-    fun massage_rec t =
-      (case Term.strip_comb t of
-        (Const (@{const_name Let}, _), [arg1, arg2]) => massage_rec (betapply (arg2, arg1))
-      | (Const (@{const_name If}, _), obj :: (branches as [then_branch, _])) =>
-        let val branches' = map massage_rec branches in
-          Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches')
-        end
-      | (Const (c, _), args as _ :: _) =>
-        let val n = num_binder_types (Sign.the_const_type thy c) in
-          (case fastype_of1 (bound_Ts, nth args (n - 1)) of
-            Type (s, Ts) =>
-            if case_of ctxt s = SOME c then
-              let
-                val (branches, obj_leftovers) = chop n args;
-                val branches' = map massage_rec branches;
-                val casex' = Const (c, map typof branches' ---> map typof obj_leftovers --->
-                  typof t);
-              in
-                betapplys (casex', branches' @ tap (List.app check_no_call) obj_leftovers)
-              end
-            else
-              massage_leaf t
-          | _ => massage_leaf t)
-        end
-      | _ => massage_leaf t)
+    fun massage_rec bound_Ts t =
+      let val typof = curry fastype_of1 bound_Ts in
+        (case Term.strip_comb t of
+          (Const (@{const_name Let}, _), [arg1, arg2]) =>
+          massage_rec bound_Ts (betapply (arg2, arg1))
+        | (Const (@{const_name If}, _), obj :: (branches as [then_branch, _])) =>
+          let val branches' = map (massage_rec bound_Ts) branches in
+            Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches')
+          end
+        | (Const (c, _), args as _ :: _) =>
+          let val n = num_binder_types (Sign.the_const_type thy c) in
+            (case fastype_of1 (bound_Ts, nth args (n - 1)) of
+              Type (s, Ts) =>
+              if case_of ctxt s = SOME c then
+                let
+                  val (branches, obj_leftovers) = chop n args;
+                  val branches' = map (massage_rec bound_Ts) branches;
+                  val casex' = Const (c, map typof branches' ---> map typof obj_leftovers --->
+                    typof t);
+                in
+                  betapplys (casex', branches' @ tap (List.app check_no_call) obj_leftovers)
+                end
+              else
+                massage_leaf bound_Ts t
+            | _ => massage_leaf bound_Ts t)
+          end
+        | _ => massage_leaf bound_Ts t)
+      end
   in
     massage_rec
   end;
@@ -289,63 +291,71 @@
 
 fun massage_indirect_corec_call ctxt has_call raw_massage_call bound_Ts U t =
   let
-    val typof = curry fastype_of1 bound_Ts;
     val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd)
 
-    fun massage_direct_call U T t =
-      if has_call t then factor_out_types ctxt raw_massage_call dest_sumT U T t
+    fun massage_direct_call bound_Ts U T t =
+      if has_call t then factor_out_types ctxt (raw_massage_call bound_Ts) dest_sumT U T t
       else build_map_Inl (T, U) $ t;
 
-    fun massage_direct_fun U T t =
-      let val var = Var ((Name.uu, Term.maxidx_of_term t + 1), domain_type (typof t)) in
-        Term.lambda var (massage_direct_call U T (t $ var))
+    fun massage_direct_fun bound_Ts U T t =
+      let
+        val var = Var ((Name.uu, Term.maxidx_of_term t + 1),
+          domain_type (fastype_of1 (bound_Ts, t)));
+      in
+        Term.lambda var (massage_direct_call bound_Ts U T (t $ var))
       end;
 
-    fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
+    fun massage_map bound_Ts (Type (_, Us)) (Type (s, Ts)) t =
         (case try (dest_map ctxt s) t of
           SOME (map0, fs) =>
           let
-            val Type (_, dom_Ts) = domain_type (typof t);
+            val Type (_, dom_Ts) = domain_type (fastype_of1 (bound_Ts, t));
             val map' = mk_map (length fs) dom_Ts Us map0;
-            val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
+            val fs' =
+              map_flattened_map_args ctxt s (map3 (massage_map_or_map_arg bound_Ts) Us Ts) fs;
           in
             Term.list_comb (map', fs')
           end
         | NONE => raise AINT_NO_MAP t)
-      | massage_map _ _ t = raise AINT_NO_MAP t
-    and massage_map_or_map_arg U T t =
+      | massage_map _ _ _ t = raise AINT_NO_MAP t
+    and massage_map_or_map_arg bound_Ts U T t =
       if T = U then
         if has_call t then unexpected_corec_call ctxt t else t
       else
-        massage_map U T t
-        handle AINT_NO_MAP _ => massage_direct_fun U T t;
+        massage_map bound_Ts U T t
+        handle AINT_NO_MAP _ => massage_direct_fun bound_Ts U T t;
 
-    fun massage_call U T =
-      massage_let_if_case ctxt has_call (fn t =>
+    fun massage_call bound_Ts U T =
+      massage_let_if_case ctxt has_call (fn bound_Ts => fn t =>
         if has_call t then
           (case U of
             Type (s, Us) =>
             (case try (dest_ctr ctxt s) t of
               SOME (f, args) =>
-              let val f' = mk_ctr Us f in
-                Term.list_comb (f',
-                  map3 massage_call (binder_types (typof f')) (map typof args) args)
+              let
+                val typof = curry fastype_of1 bound_Ts;
+                val f' = mk_ctr Us f
+                val f'_T = typof f';
+                val arg_Ts = map typof args;
+              in
+                Term.list_comb (f', map3 (massage_call bound_Ts) (binder_types f'_T) arg_Ts args)
               end
             | NONE =>
               (case t of
                 t1 $ t2 =>
                 (if has_call t2 then
-                  massage_direct_call U T t
+                  massage_direct_call bound_Ts U T t
                 else
-                  massage_map U T t1 $ t2
-                  handle AINT_NO_MAP _ => massage_direct_call U T t)
-              | Abs (s, T', t') => Abs (s, T', massage_call (range_type U) (range_type T) t')
-              | _ => massage_direct_call U T t))
+                  massage_map bound_Ts U T t1 $ t2
+                  handle AINT_NO_MAP _ => massage_direct_call bound_Ts U T t)
+              | Abs (s, T', t') =>
+                Abs (s, T', massage_call (T' :: bound_Ts) (range_type U) (range_type T) t')
+              | _ => massage_direct_call bound_Ts U T t))
           | _ => ill_formed_corec_call ctxt t)
         else
           build_map_Inl (T, U) $ t) bound_Ts;
   in
-    massage_call U (typof t) t
+    massage_call bound_Ts U (fastype_of1 (bound_Ts, t)) t
   end;
 
 fun expand_ctr_term ctxt s Ts t =
@@ -357,15 +367,16 @@
 fun expand_corec_code_rhs ctxt has_call bound_Ts t =
   (case fastype_of1 (bound_Ts, t) of
     Type (s, Ts) =>
-    massage_let_if_case ctxt has_call (fn t =>
+    massage_let_if_case ctxt has_call (fn bound_Ts => fn t =>
       if can (dest_ctr ctxt s) t then
         t
       else
-        massage_let_if_case ctxt has_call I bound_Ts (expand_ctr_term ctxt s Ts t)) bound_Ts t
+        massage_let_if_case ctxt has_call (K I) bound_Ts (expand_ctr_term ctxt s Ts t)) bound_Ts t
   | _ => raise Fail "expand_corec_code_rhs");
 
 fun massage_corec_code_rhs ctxt massage_ctr =
-  massage_let_if_case ctxt (K false) (uncurry massage_ctr o Term.strip_comb);
+  massage_let_if_case ctxt (K false)
+    (fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb);
 
 fun fold_rev_corec_code_rhs ctxt f =
   fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb);