removed "sum_tools.ML" in favour of BalancedTree
authorkrauss
Mon, 25 Jun 2007 12:16:27 +0200
changeset 23494 f985f9239e0d
parent 23493 a056eefb76e5
child 23495 e4dd6beeafab
removed "sum_tools.ML" in favour of BalancedTree
src/HOL/FunDef.thy
src/HOL/IsaMakefile
src/HOL/Tools/function_package/mutual.ML
src/HOL/Tools/function_package/sum_tools.ML
--- a/src/HOL/FunDef.thy	Mon Jun 25 00:36:42 2007 +0200
+++ b/src/HOL/FunDef.thy	Mon Jun 25 12:16:27 2007 +0200
@@ -8,7 +8,6 @@
 theory FunDef
 imports Datatype Accessible_Part
 uses
-  ("Tools/function_package/sum_tools.ML")
   ("Tools/function_package/fundef_lib.ML")
   ("Tools/function_package/fundef_common.ML")
   ("Tools/function_package/inductive_wrap.ML")
@@ -86,8 +85,6 @@
     by (rule THE_default_none)
 qed
 
-
-use "Tools/function_package/sum_tools.ML"
 use "Tools/function_package/fundef_lib.ML"
 use "Tools/function_package/fundef_common.ML"
 use "Tools/function_package/inductive_wrap.ML"
--- a/src/HOL/IsaMakefile	Mon Jun 25 00:36:42 2007 +0200
+++ b/src/HOL/IsaMakefile	Mon Jun 25 12:16:27 2007 +0200
@@ -117,7 +117,7 @@
   Tools/function_package/lexicographic_order.ML				\
   Tools/function_package/mutual.ML					\
   Tools/function_package/pattern_split.ML				\
-  Tools/function_package/sum_tools.ML Tools/inductive_codegen.ML	\
+  Tools/inductive_codegen.ML	\
   Tools/inductive_package.ML Tools/inductive_realizer.ML Tools/meson.ML	\
   Tools/metis_tools.ML Tools/numeral_syntax.ML 				\
   Tools/old_inductive_package.ML Tools/polyhash.ML 			\
--- a/src/HOL/Tools/function_package/mutual.ML	Mon Jun 25 00:36:42 2007 +0200
+++ b/src/HOL/Tools/function_package/mutual.ML	Mon Jun 25 12:16:27 2007 +0200
@@ -31,6 +31,36 @@
 (* Theory dependencies *)
 val sum_case_rules = thms "Sum_Type.sum_cases"
 val split_apply = thm "Product_Type.split"
+val projl_inl = thm "Sum_Type.Projl_Inl"
+val projr_inr = thm "Sum_Type.Projr_Inr"
+
+
+(* Sum types *)
+fun mk_sumT LT RT = Type ("+", [LT, RT])
+fun mk_sumcase TL TR T l r = Const (@{const_name "Sum_Type.sum_case"}, (TL --> T) --> (TR --> T) --> mk_sumT TL TR --> T) $ l $ r
+
+val App = curry op $
+
+fun mk_inj ST n i = 
+    BalancedTree.access 
+    { init = (ST, I : term -> term),
+      left = (fn (T as Type ("+", [LT, RT]), inj) => (LT, App (Const (@{const_name "Inl"}, LT --> T)) o inj)),
+      right =(fn (T as Type ("+", [LT, RT]), inj) => (RT, App (Const (@{const_name "Inr"}, RT --> T)) o inj))} n i 
+    |> snd
+
+fun mk_proj ST n i = 
+    BalancedTree.access 
+    { init = (ST, I : term -> term),
+      left = (fn (T as Type ("+", [LT, RT]), proj) => (LT, proj o App (Const (@{const_name "Projl"}, T --> LT)))),
+      right =(fn (T as Type ("+", [LT, RT]), proj) => (RT, proj o App (Const (@{const_name "Projr"}, T --> RT))))} n i
+    |> snd
+
+fun mk_sumcases T fs =
+      BalancedTree.make (fn ((f, fT), (g, gT)) => (mk_sumcase fT gT T f g, mk_sumT fT gT)) 
+                        (map (fn f => (f, domain_type (fastype_of f))) fs)
+      |> fst
+
+
 
 type qgar = string * (string * typ) list * term list * term list * term
 
@@ -39,10 +69,10 @@
 datatype mutual_part =
   MutualPart of 
    {
+    i : int,
+    i' : int,
     fvar : string * typ,
     cargTs: typ list,
-    pthA: SumTools.sum_path,
-    pthR: SumTools.sum_path,
     f_def: term,
 
     f: term option,
@@ -53,12 +83,12 @@
 datatype mutual_info =
   Mutual of 
    { 
+    n : int,
+    n' : int,
     fsum_var : string * typ,
 
     ST: typ,
     RST: typ,
-    streeA: SumTools.sum_tree,
-    streeR: SumTools.sum_tree,
 
     parts: mutual_part list,
     fqgars: qgar list,
@@ -87,6 +117,7 @@
 
 fun analyze_eqs ctxt defname fs eqs =
     let
+      val num = length fs
         val fnames = map fst fs
         val fqgars = map split_def eqs
         val arities = mk_arities fqgars
@@ -102,40 +133,44 @@
         val (caTss, resultTs) = split_list (map curried_types fs)
         val argTs = map (foldr1 HOLogic.mk_prodT) caTss
 
-        val (RST,streeR, pthsR) = SumTools.mk_tree_distinct resultTs
-        val (ST, streeA, pthsA) = SumTools.mk_tree argTs
+        val dresultTs = distinct (Type.eq_type Vartab.empty) resultTs
+        val n' = length dresultTs
+
+        val RST = BalancedTree.make (uncurry mk_sumT) dresultTs
+        val ST = BalancedTree.make (uncurry mk_sumT) argTs
 
         val fsum_type = ST --> RST
 
         val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
         val fsum_var = (fsum_var_name, fsum_type)
 
-        fun define (fvar as (n, T)) caTs pthA pthR =
+        fun define (fvar as (n, T)) caTs resultT i =
             let
-                val vars = map_index (fn (i,T) => Free ("x" ^ string_of_int i, T)) caTs (* FIXME: Bind xs properly *)
+                val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
+                val i' = find_index (fn Ta => Type.eq_type Vartab.empty (Ta, resultT)) dresultTs + 1 
 
-                val f_exp = SumTools.mk_proj streeR pthR (Free fsum_var $ SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod vars))
+                val f_exp = mk_proj RST n' i' (Free fsum_var $ mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
                 val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
 
                 val rew = (n, fold_rev lambda vars f_exp)
             in
-                (MutualPart {fvar=fvar,cargTs=caTs,pthA=pthA,pthR=pthR,f_def=def,f=NONE,f_defthm=NONE}, rew)
+                (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
             end
             
-        val (parts, rews) = split_list (map4 define fs caTss pthsA pthsR)
+        val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))
 
         fun convert_eqs (f, qs, gs, args, rhs) =
             let
-              val MutualPart {pthA, pthR, ...} = get_part f parts
+              val MutualPart {i, i', ...} = get_part f parts
             in
-              (qs, gs, SumTools.mk_inj streeA pthA (foldr1 (mk_prod_abs qs) args),
-               SumTools.mk_inj streeR pthR (replace_frees rews rhs)
-                               |> Envir.norm_term (Envir.empty 0))
+              (qs, gs, mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
+               mk_inj RST n' i' (replace_frees rews rhs)
+                               |> Envir.beta_norm)
             end
 
         val qglrs = map convert_eqs fqgars
     in
-        Mutual {fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR,
+        Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, 
                 parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
     end
 
@@ -144,22 +179,22 @@
 
 fun define_projections fixes mutual fsum lthy =
     let
-      fun def ((MutualPart {fvar=(fname, fT), cargTs, pthA, pthR, f_def, ...}), (_, mixfix)) lthy =
+      fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
           let
             val ((f, (_, f_defthm)), lthy') =
               LocalTheory.def Thm.internalK ((fname, mixfix),
                                             ((fname ^ "_def", []), Term.subst_bound (fsum, f_def)))
                               lthy
           in
-            (MutualPart {fvar=(fname, fT), cargTs=cargTs, pthA=pthA, pthR=pthR, f_def=f_def,
+            (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
                          f=SOME f, f_defthm=SOME f_defthm },
              lthy')
           end
           
-      val Mutual { fsum_var, ST, RST, streeA, streeR, parts, fqgars, qglrs, ... } = mutual
+      val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
       val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
     in
-      (Mutual { fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR, parts=parts',
+      (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
                 fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
        lthy')
     end
@@ -184,7 +219,7 @@
       val ags = map (assume o cterm_of thy) gs
 
       val import = fold forall_elim cqs
-                   #> fold implies_elim_swp ags
+                   #> fold (flip implies_elim) ags
 
       val export = fold_rev (implies_intr o cprop_of) ags
                    #> fold_rev forall_intr_rename (oqnames ~~ cqs)
@@ -194,7 +229,7 @@
 
 fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs) import (export : thm -> thm) sum_psimp_eq =
     let
-      val (MutualPart {f=SOME f, f_defthm=SOME f_def, pthR, ...}) = get_part fname parts
+      val (MutualPart {f=SOME f, f_defthm=SOME f_def, ...}) = get_part fname parts
 
       val psimp = import sum_psimp_eq
       val (simp, restore_cond) = case cprems_of psimp of
@@ -206,7 +241,7 @@
                  (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
                  (fn _ => SIMPSET (unfold_tac all_orig_fdefs)
                           THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
-                          THEN SIMPSET' (fn ss => simp_tac (ss addsimps [SumTools.projl_inl, SumTools.projr_inr])) 1)
+                          THEN SIMPSET' (fn ss => simp_tac (ss addsimps [projl_inl, projr_inr])) 1)
         |> restore_cond 
         |> export
     end
@@ -223,9 +258,9 @@
            |> fold_rev forall_intr xs
            |> forall_elim_vars 0
     end
-    
+
 
-fun mutual_induct_rules lthy induct all_f_defs (Mutual {RST, parts, streeA, ...}) =
+fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, RST, parts, ...}) =
     let
       val cert = cterm_of (ProofContext.theory_of lthy)
       val newPs = map2 (fn Pname => fn MutualPart {cargTs, ...} => 
@@ -242,26 +277,25 @@
           end
           
       val Ps = map2 mk_P parts newPs
-      val case_exp = SumTools.mk_sumcases streeA HOLogic.boolT Ps
+      val case_exp = mk_sumcases HOLogic.boolT Ps
                      
       val induct_inst =
           forall_elim (cert case_exp) induct
                       |> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules))
                       |> full_simplify (HOL_basic_ss addsimps all_f_defs)
           
-      fun mk_proj rule (MutualPart {cargTs, pthA, ...}) =
+      fun project rule (MutualPart {cargTs, i, ...}) =
           let
-            val afs = map_index (fn (i,T) => Free ("a" ^ string_of_int i, T)) cargTs
-            val inj = SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod afs)
+            val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int j, T)) cargTs
+            val inj = mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
           in
             rule
               |> forall_elim (cert inj)
               |> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules))
               |> fold_rev (forall_intr o cert) (afs @ newPs)
           end
-          
     in
-      map (mk_proj induct_inst) parts
+      map (project induct_inst) parts
     end
     
 
@@ -295,7 +329,6 @@
                      trsimps=mtrsimps}
     end
       
-
 (* puts an object in the "right bucket" *)
 fun store_grouped P x [] = []
   | store_grouped P x ((l, xs)::bs) =
--- a/src/HOL/Tools/function_package/sum_tools.ML	Mon Jun 25 00:36:42 2007 +0200
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,124 +0,0 @@
-(*  Title:      HOL/Tools/function_package/sum_tools.ML
-    ID:         $Id$
-    Author:     Alexander Krauss, TU Muenchen
-
-A package for general recursive function definitions. 
-Tools for mutual recursive definitions. This could actually be useful for other packages, too, but needs
-some cleanup first...
-
-*)
-
-signature SUM_TOOLS =
-sig
-  type sum_tree
-  type sum_path
-
-  val projl_inl: thm
-  val projr_inr: thm
-
-  val mk_tree : typ list -> typ * sum_tree * sum_path list
-  val mk_tree_distinct : typ list -> typ * sum_tree * sum_path list
-
-  val mk_proj: sum_tree -> sum_path -> term -> term
-  val mk_inj: sum_tree -> sum_path -> term -> term
-
-  val mk_sumcases: sum_tree -> typ -> term list -> term
-end
-
-
-structure SumTools: SUM_TOOLS =
-struct
-
-val inlN = "Sum_Type.Inl"
-val inrN = "Sum_Type.Inr"
-val sumcaseN = "Sum_Type.sum_case"
-
-val projlN = "Sum_Type.Projl"
-val projrN = "Sum_Type.Projr"
-val projl_inl = thm "Sum_Type.Projl_Inl"
-val projr_inr = thm "Sum_Type.Projr_Inr"
-
-fun mk_sumT LT RT = Type ("+", [LT, RT])
-fun mk_sumcase TL TR T l r = Const (sumcaseN, (TL --> T) --> (TR --> T) --> mk_sumT TL TR --> T) $ l $ r
-
-datatype sum_tree 
-  = Leaf of typ
-  | Branch of (typ * (typ * sum_tree) * (typ * sum_tree))
-
-type sum_path = bool list (* true: left, false: right *)
-                
-fun sum_type_of (Leaf T) = T
-  | sum_type_of (Branch (ST,(LT,_),(RT,_))) = ST
-                                              
-                                              
-fun mk_tree Ts =
-    let 
-      fun mk_tree' 1 [T] = (T, Leaf T, [[]])
-        | mk_tree' n Ts =
-          let 
-            val n2 = n div 2
-            val (lTs, rTs) = chop n2 Ts
-            val (TL, ltree, lpaths) = mk_tree' n2 lTs
-            val (TR, rtree, rpaths) = mk_tree' (n - n2) rTs
-            val T = mk_sumT TL TR
-            val pths = map (cons true) lpaths @ map (cons false) rpaths 
-          in
-            (T, Branch (T, (TL, ltree), (TR, rtree)), pths)
-          end
-    in
-      mk_tree' (length Ts) Ts
-    end
-    
-    
-fun mk_tree_distinct Ts =
-    let
-      fun insert_once T Ts =
-          let
-            val i = find_index_eq T Ts
-          in
-            if i = ~1 then (length Ts, Ts @ [T]) else (i, Ts)
-          end
-          
-      val (idxs, dist_Ts) = fold_map insert_once Ts []
-                            
-      val (ST, tree, pths) = mk_tree dist_Ts
-    in
-      (ST, tree, map (nth pths) idxs)
-    end
-
-
-fun mk_inj (Leaf _) [] t = t
-  | mk_inj (Branch (ST, (LT, tr), _)) (true::pth) t = 
-    Const (inlN, LT --> ST) $ mk_inj tr pth t
-  | mk_inj (Branch (ST, _, (RT, tr))) (false::pth) t = 
-    Const (inrN, RT --> ST) $ mk_inj tr pth t
-  | mk_inj _ _ _ = sys_error "mk_inj"
-
-fun mk_proj (Leaf _) [] t = t
-  | mk_proj (Branch (ST, (LT, tr), _)) (true::pth) t = 
-    mk_proj tr pth (Const (projlN, ST --> LT) $ t)
-  | mk_proj (Branch (ST, _, (RT, tr))) (false::pth) t = 
-    mk_proj tr pth (Const (projrN, ST --> RT) $ t)
-  | mk_proj _ _ _ = sys_error "mk_proj"
-
-
-fun mk_sumcases tree T ts =
-    let
-      fun mk_sumcases' (Leaf _) (t::ts) = (t,ts)
-        | mk_sumcases' (Branch (ST, (LT, ltr), (RT, rtr))) ts =
-          let
-            val (lcase, ts') = mk_sumcases' ltr ts
-            val (rcase, ts'') = mk_sumcases' rtr ts'
-          in
-            (mk_sumcase LT RT T lcase rcase, ts'')
-          end
-        | mk_sumcases' _ [] = sys_error "mk_sumcases"
-    in
-      fst (mk_sumcases' tree ts)
-    end
-    
-end
-
-
-
-