simplified interfaces, some restructuring
authorkrauss
Fri, 01 Jun 2007 15:57:45 +0200
changeset 23189 4574ab8f3b21
parent 23188 595a0e24bd8e
child 23190 d45c4d6c5f15
simplified interfaces, some restructuring
src/HOL/Tools/function_package/fundef_common.ML
src/HOL/Tools/function_package/fundef_core.ML
src/HOL/Tools/function_package/fundef_datatype.ML
src/HOL/Tools/function_package/fundef_package.ML
src/HOL/Tools/function_package/mutual.ML
--- a/src/HOL/Tools/function_package/fundef_common.ML	Fri Jun 01 15:20:53 2007 +0200
+++ b/src/HOL/Tools/function_package/fundef_common.ML	Fri Jun 01 15:57:45 2007 +0200
@@ -224,6 +224,51 @@
   fun fundef_parser default_cfg = config_parser default_cfg -- P.fixes --| P.$$$ "where" -- statements_ow
 end
 
+
+
+(* Common operations on equations *)
+
+fun open_all_all (Const ("all", _) $ Abs (n, T, b)) = apfst (cons (n, T)) (open_all_all b)
+  | open_all_all t = ([], t)
+
+exception MalformedEquation of term
+
+fun split_def geq =
+    let
+      val (qs, imp) = open_all_all geq
+
+      val gs = Logic.strip_imp_prems imp
+      val eq = Logic.strip_imp_concl imp
+
+      val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
+          handle TERM _ => raise MalformedEquation geq
+
+      val (head, args) = strip_comb f_args
+
+      val fname = fst (dest_Free head)
+          handle TERM _ => raise MalformedEquation geq
+    in
+      (fname, qs, gs, args, rhs)
+    end
+
+exception ArgumentCount of string
+
+fun mk_arities fqgars =
+    let fun f (fname, _, _, args, _) arities =
+            let val k = length args
+            in
+              case Symtab.lookup arities fname of
+                NONE => Symtab.update (fname, k) arities
+              | SOME i => (if i = k then arities else raise ArgumentCount fname)
+            end
+    in
+      fold f fqgars Symtab.empty
+    end
+
+
+
+
+
 end
 
 (* Common Abbreviations *)
--- a/src/HOL/Tools/function_package/fundef_core.ML	Fri Jun 01 15:20:53 2007 +0200
+++ b/src/HOL/Tools/function_package/fundef_core.ML	Fri Jun 01 15:57:45 2007 +0200
@@ -10,9 +10,8 @@
 sig
     val prepare_fundef : FundefCommon.fundef_config
                          -> string (* defname *)
-                         -> (string * typ * mixfix) (* defined symbol *)
+                         -> ((string * typ) * mixfix) list (* defined symbol *)
                          -> ((string * typ) list * term list * term * term) list (* specification *)
-                         -> string (* default_value, not parsed yet *)
                          -> local_theory
 
                          -> (term   (* f *)
@@ -858,9 +857,9 @@
     end
 
 
-fun prepare_fundef config defname (fname, fT, mixfix) abstract_qglrs default_str lthy =
+fun prepare_fundef config defname [((fname, fT), mixfix)] abstract_qglrs lthy =
     let
-      val FundefConfig {domintros, tailrec, ...} = config 
+      val FundefConfig {domintros, tailrec, default=default_str, ...} = config 
                                                          
       val fvar = Free (fname, fT)
       val domT = domain_type fT
--- a/src/HOL/Tools/function_package/fundef_datatype.ML	Fri Jun 01 15:20:53 2007 +0200
+++ b/src/HOL/Tools/function_package/fundef_datatype.ML	Fri Jun 01 15:57:45 2007 +0200
@@ -20,6 +20,40 @@
 open FundefLib
 open FundefCommon
 
+fun check_constr_pattern thy err (Bound _) = ()
+  | check_constr_pattern thy err t =
+    let
+      val (hd, args) = strip_comb t
+    in
+      (((case DatatypePackage.datatype_of_constr thy (fst (dest_Const hd)) of
+           SOME _ => ()
+         | NONE => err t)
+        handle TERM ("dest_Const", _) => err t);
+       map (check_constr_pattern thy err) args; 
+       ())
+    end
+
+
+fun check_pats ctxt geq =
+    let 
+      fun err str = error (cat_lines ["Malformed \"fun\" definition:",
+                                      str,
+                                      ProofContext.string_of_term ctxt geq])
+      val thy = ProofContext.theory_of ctxt
+
+      val (fname, qs, gs, args, rhs) = split_def geq 
+
+      val _ = if not (null gs) then err "Conditional equations not allowed with \"fun\"" else ()
+      val _ = map (check_constr_pattern thy (fn t => err "Not a constructor pattern")) args
+
+                  (* just count occurrences to check linearity *)
+      val _ = if fold (fold_aterms (fn Bound _ => curry (op +) 1 | _ => I)) args 0 < length qs
+              then err "Nonlinear pattern" else ()
+    in
+      ()
+    end
+
+
 fun mk_argvar i T = Free ("_av" ^ (string_of_int i), T)
 fun mk_patvar i T = Free ("_pv" ^ (string_of_int i), T)
 
--- a/src/HOL/Tools/function_package/fundef_package.ML	Fri Jun 01 15:20:53 2007 +0200
+++ b/src/HOL/Tools/function_package/fundef_package.ML	Fri Jun 01 15:57:45 2007 +0200
@@ -32,7 +32,7 @@
 end
 
 
-structure FundefPackage : FUNDEF_PACKAGE =
+structure FundefPackage (*: FUNDEF_PACKAGE*) =
 struct
 
 open FundefLib
@@ -42,6 +42,67 @@
 
 fun mk_defname fixes = fixes |> map (fst o fst) |> space_implode "_" 
 
+
+(* Check for all sorts of errors in the input *)
+fun check_def ctxt fixes eqs =
+    let
+      val fnames = map (fst o fst) fixes
+                                
+      fun check geq = 
+          let
+            fun input_error msg = cat_lines [msg, ProofContext.string_of_term ctxt geq]
+                                  
+            val fqgar as (fname, qs, gs, args, rhs) = split_def geq
+                                 
+            val _ = fname mem fnames 
+                    orelse error (input_error ("Head symbol of left hand side must be " ^ plural "" "one out of " fnames 
+                                               ^ commas_quote fnames))
+                                            
+            fun add_bvs t is = add_loose_bnos (t, 0, is)
+            val rvs = (add_bvs rhs [] \\ fold add_bvs args [])
+                        |> map (fst o nth (rev qs))
+                      
+            val _ = null rvs orelse error (input_error ("Variable" ^ plural " " "s " rvs ^ commas_quote rvs
+                                                        ^ " occur" ^ plural "s" "" rvs ^ " on right hand side only:"))
+                                    
+            val _ = forall (forall_aterms (fn Free (n, _) => not (n mem fnames) | _ => true)) gs orelse
+                    error (input_error "Recursive Calls not allowed in premises")
+          in
+            fqgar
+          end
+    in
+      (mk_arities (map check eqs); ())
+      handle ArgumentCount fname => 
+             error ("Function " ^ quote fname ^ " has different numbers of arguments in different equations")
+    end
+
+
+fun mk_catchall fixes arities =
+    let
+      fun mk_eqn ((fname, fT), _) =
+          let 
+            val n = the (Symtab.lookup arities fname)
+            val (argTs, rT) = chop n (binder_types fT)
+                                   |> apsnd (fn Ts => Ts ---> body_type fT) 
+                              
+            val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
+          in
+            HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
+                          Const ("HOL.undefined", rT))
+              |> HOLogic.mk_Trueprop
+              |> fold_rev mk_forall qs
+          end
+    in
+      map mk_eqn fixes
+    end
+
+fun add_catchall fixes spec =
+    let 
+      val catchalls = mk_catchall fixes (mk_arities (map split_def (map (snd o snd) spec)))
+    in
+      spec @ map (pair ("",[]) o pair true) catchalls
+    end
+
 fun burrow_snd f ps = (* ('a list -> 'b list) -> ('c * 'a) list -> ('c * 'b) list *)
     let val (xs, ys) = split_list ps
     in xs ~~ f ys end
@@ -107,16 +168,23 @@
       fun prep_eqn e = the_single (snd (fst (prep [] [e] ctxt')))
                          |> apsnd the_single
 
-      val spec = map prep_eqn eqns
+      val raw_spec = map prep_eqn eqns
                      |> map (apsnd (fn t => fold_rev (mk_forall o Free) (frees_in_term ctxt' t) t)) (* Add quantifiers *)
-                     |> burrow_snd (fn ts => FundefSplit.split_some_equations ctxt' (flags ~~ ts))
+
+      val _ = check_def ctxt' fixes (map snd raw_spec)
+
+      val spec = raw_spec
+                     |> burrow_snd (fn ts => flags ~~ ts)
+                     (*|> (if global_flag then add_catchall fixes else I) *) (* Completion: still disabled *)
+                     |> burrow_snd (FundefSplit.split_some_equations ctxt')
+
     in
       ((fixes, spec), ctxt')
     end
 
 fun gen_add_fundef prep_spec fixspec eqnss_flags config lthy =
     let
-      val FundefConfig {sequential, default, tailrec, ...} = config
+      val FundefConfig {sequential, ...} = config
 
       val ((fixes, spec), ctxt') = prep_with_flags prep_spec fixspec eqnss_flags sequential lthy
 
@@ -125,7 +193,7 @@
       val t_eqns = spec |> map snd |> flat (* flatten external structure *)
 
       val ((goalstate, cont, sort_cont), lthy) =
-          FundefMutual.prepare_fundef_mutual config defname fixes t_eqns default lthy
+          FundefMutual.prepare_fundef_mutual config defname fixes t_eqns lthy
 
       val afterqed = fundef_afterqed config fixes spec defname cont sort_cont
     in
--- a/src/HOL/Tools/function_package/mutual.ML	Fri Jun 01 15:20:53 2007 +0200
+++ b/src/HOL/Tools/function_package/mutual.ML	Fri Jun 01 15:57:45 2007 +0200
@@ -13,7 +13,6 @@
                               -> string (* defname *)
                               -> ((string * typ) * mixfix) list
                               -> term list
-                              -> string (* default, unparsed term *)
                               -> local_theory
                               -> ((thm (* goalstate *)
                                    * (thm -> FundefCommon.fundef_result) (* proof continuation *)
@@ -72,50 +71,6 @@
     if n < 5 then fst (chop n ["P","Q","R","S"])
     else map (fn i => "P" ^ string_of_int i) (1 upto n)
 
-
-fun open_all_all (Const ("all", _) $ Abs (n, T, b)) = apfst (cons (n, T)) (open_all_all b)
-  | open_all_all t = ([], t)
-
-(* Builds a curried clause description in abstracted form *)
-fun split_def ctxt fnames geq arities =
-    let
-      fun input_error msg = cat_lines [msg, ProofContext.string_of_term ctxt geq]
-                            
-      val (qs, imp) = open_all_all geq
-
-      val gs = Logic.strip_imp_prems imp
-      val eq = Logic.strip_imp_concl imp
-
-      val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
-      val (head, args) = strip_comb f_args
-
-      val invalid_head_msg = "Head symbol of left hand side must be " ^ plural "" "one out of " fnames ^ commas_quote fnames
-      val fname = fst (dest_Free head)
-          handle TERM _ => error (input_error invalid_head_msg)
-
-      val _ = fname mem fnames orelse error (input_error invalid_head_msg)
-
-      fun add_bvs t is = add_loose_bnos (t, 0, is)
-      val rvs = (add_bvs rhs [] \\ fold add_bvs args [])
-                  |> map (fst o nth (rev qs))
-                
-      val _ = null rvs orelse error (input_error ("Variable" ^ plural " " "s " rvs ^ commas_quote rvs
-                                              ^ " occur" ^ plural "s" "" rvs ^ " on right hand side only:"))
-
-      val _ = forall (forall_aterms (fn Free (n, _) => not (n mem fnames) | _ => true)) gs orelse
-                     error (input_error "Recursive Calls not allowed in premises")
-
-      val k = length args
-
-      val arities' = case Symtab.lookup arities fname of
-                       NONE => Symtab.update (fname, k) arities
-                     | SOME i => (i = k orelse
-                                  error (input_error ("Function " ^ quote fname ^ " has different numbers of arguments in different equations"));
-                                  arities)
-    in
-      ((fname, qs, gs, args, rhs), arities')
-    end
-    
 fun get_part fname =
     the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)
                      
@@ -133,7 +88,8 @@
 fun analyze_eqs ctxt defname fs eqs =
     let
         val fnames = map fst fs
-        val (fqgars, arities) = fold_map (split_def ctxt fnames) eqs Symtab.empty
+        val fqgars = map split_def eqs
+        val arities = mk_arities fqgars
 
         fun curried_types (fname, fT) =
             let
@@ -325,15 +281,17 @@
       fun mk_mpsimp fqgar sum_psimp =
           in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
           
+      val rew_ss = HOL_basic_ss addsimps all_f_defs
       val mpsimps = map2 mk_mpsimp fqgars psimps
       val mtrsimps = map_option (map2 mk_mpsimp fqgars) trsimps
       val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
-      val mtermination = full_simplify (HOL_basic_ss addsimps all_f_defs) termination
+      val mtermination = full_simplify rew_ss termination
+      val mdomintros = map_option (map (full_simplify rew_ss)) domintros
     in
       FundefResult { fs=fs, G=G, R=R,
                      psimps=mpsimps, subset_pinducts=[subset_pinduct], simple_pinducts=minducts,
                      cases=cases, termination=mtermination,
-                     domintros=domintros,
+                     domintros=mdomintros,
                      trsimps=mtrsimps}
     end
       
@@ -351,13 +309,13 @@
              |> map (snd #> map snd)                     (* and remove the labels afterwards *)
 
 
-fun prepare_fundef_mutual config defname fixes eqss default lthy =
+fun prepare_fundef_mutual config defname fixes eqss lthy =
     let
       val mutual = analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)
       val Mutual {fsum_var=(n, T), qglrs, ...} = mutual
           
       val ((fsum, goalstate, cont), lthy') =
-          FundefCore.prepare_fundef config defname (n, T, NoSyn) qglrs default lthy
+          FundefCore.prepare_fundef config defname [((n, T), NoSyn)] qglrs lthy
           
       val (mutual', lthy'') = define_projections fixes mutual fsum lthy'