Function package can now do automatic splits of overlapping datatype patterns
authorkrauss
Mon, 31 Jul 2006 18:07:42 +0200
changeset 20270 3abe7dae681e
parent 20269 c40070317ab8
child 20271 e76e77e0d615
Function package can now do automatic splits of overlapping datatype patterns
src/HOL/FunDef.thy
src/HOL/Tools/function_package/fundef_common.ML
src/HOL/Tools/function_package/fundef_package.ML
src/HOL/Tools/function_package/pattern_split.ML
src/HOL/Tools/function_package/termination.ML
src/HOL/ex/Fundefs.thy
--- a/src/HOL/FunDef.thy	Mon Jul 31 18:05:40 2006 +0200
+++ b/src/HOL/FunDef.thy	Mon Jul 31 18:07:42 2006 +0200
@@ -9,6 +9,7 @@
 ("Tools/function_package/fundef_proof.ML")
 ("Tools/function_package/termination.ML")
 ("Tools/function_package/mutual.ML")
+("Tools/function_package/pattern_split.ML")
 ("Tools/function_package/fundef_package.ML")
 ("Tools/function_package/fundef_datatype.ML")
 ("Tools/function_package/auto_term.ML")
@@ -71,6 +72,7 @@
 use "Tools/function_package/fundef_proof.ML"
 use "Tools/function_package/termination.ML"
 use "Tools/function_package/mutual.ML"
+use "Tools/function_package/pattern_split.ML"
 use "Tools/function_package/fundef_package.ML"
 
 setup FundefPackage.setup
--- a/src/HOL/Tools/function_package/fundef_common.ML	Mon Jul 31 18:05:40 2006 +0200
+++ b/src/HOL/Tools/function_package/fundef_common.ML	Mon Jul 31 18:07:42 2006 +0200
@@ -150,7 +150,8 @@
      }
 
 
-type result_with_names = fundef_mresult * mutual_info * string list list * attribute list list list
+type fundef_spec = ((string * attribute list) * term list) list list
+type result_with_names = fundef_mresult * mutual_info * fundef_spec
 
 structure FundefData = TheoryDataFun
 (struct
--- a/src/HOL/Tools/function_package/fundef_package.ML	Mon Jul 31 18:05:40 2006 +0200
+++ b/src/HOL/Tools/function_package/fundef_package.ML	Mon Jul 31 18:07:42 2006 +0200
@@ -1,3 +1,4 @@
+
 (*  Title:      HOL/Tools/function_package/fundef_package.ML
     ID:         $Id$
     Author:     Alexander Krauss, TU Muenchen
@@ -9,7 +10,7 @@
 
 signature FUNDEF_PACKAGE = 
 sig
-    val add_fundef : ((bstring * Attrib.src list) * string) list list -> theory -> Proof.state (* Need an _i variant *)
+    val add_fundef : ((bstring * (Attrib.src list * bool)) * string) list list -> bool -> theory -> Proof.state (* Need an _i variant *)
 
     val cong_add: attribute
     val cong_del: attribute
@@ -27,17 +28,20 @@
 val True_implies = thm "True_implies"
 
 
-fun add_simps label moreatts (MutualPart {f_name, ...}, psimps) (names, attss) thy =
+fun add_simps label moreatts (MutualPart {f_name, ...}, psimps) spec_part thy =
     let 
+      val psimpss = Library.unflat (map snd spec_part) psimps
+      val (names, attss) = split_list (map fst spec_part) 
+
       val thy = thy |> Theory.add_path f_name 
                 
       val thy = thy |> Theory.add_path label
-      val spsimps = map standard psimps
-      val add_list = (names ~~ spsimps) ~~ attss
-      val (_, thy) = PureThy.add_thms add_list thy
+      val spsimpss = map (map standard) psimpss (* FIXME *)
+      val add_list = (names ~~ spsimpss) ~~ attss
+      val (_, thy) = PureThy.add_thmss add_list thy
       val thy = thy |> Theory.parent_path
                 
-      val (_, thy) = PureThy.add_thmss [((label, spsimps), Simplifier.simp_add :: moreatts)] thy
+      val (_, thy) = PureThy.add_thmss [((label, flat spsimpss), Simplifier.simp_add :: moreatts)] thy
       val thy = thy |> Theory.parent_path
     in
       thy
@@ -48,7 +52,7 @@
 
 
 
-fun fundef_afterqed congs mutual_info name data names atts [[result]] thy =
+fun fundef_afterqed congs mutual_info name data spec [[result]] thy =
     let
 	val fundef_data = FundefMutual.mk_partial_rules_mutual thy mutual_info data result
 	val FundefMResult {psimps, subset_pinducts, simple_pinducts, termination, domintros, cases, ...} = fundef_data
@@ -58,44 +62,43 @@
 	val dom_abbrev = Logic.mk_equals (Free (name ^ "_dom", fastype_of accR), accR)
 	val (_, thy) = LocalTheory.mapping NONE (Specification.abbreviation_i ("", false) [(NONE, dom_abbrev)]) thy
 
-        val thy = fold2 (add_simps "psimps" []) (parts ~~ psimps) (names ~~ atts) thy
+        val thy = fold2 (add_simps "psimps" []) (parts ~~ psimps) spec thy
+
+        val casenames = flat (map (map (fst o fst)) spec)
 
 	val thy = thy |> Theory.add_path name
-	val (_, thy) = PureThy.add_thms [(("cases", cases), [RuleCases.case_names (flat names)])] thy
+	val (_, thy) = PureThy.add_thms [(("cases", cases), [RuleCases.case_names casenames])] thy
 	val (_, thy) = PureThy.add_thmss [(("domintros", domintros), [])] thy
 	val (_, thy) = PureThy.add_thms [(("termination", standard termination), [])] thy
-	val (_,thy) = PureThy.add_thmss [(("pinduct", map standard simple_pinducts), [RuleCases.case_names (flat names), InductAttrib.induct_set ""])] thy
+	val (_,thy) = PureThy.add_thmss [(("pinduct", map standard simple_pinducts), [RuleCases.case_names casenames, InductAttrib.induct_set ""])] thy
 	val thy = thy |> Theory.parent_path
     in
-	add_fundef_data name (fundef_data, mutual_info, names, atts) thy
+      add_fundef_data name (fundef_data, mutual_info, spec) thy
     end
 
-fun gen_add_fundef prep_att eqns_attss thy =
+fun gen_add_fundef prep_att eqns_attss preprocess thy =
     let
-	fun split eqns_atts =
-	    let 
-		val (natts, eqns) = split_list eqns_atts
-		val (names, raw_atts) = split_list natts
-		val atts = map (map (prep_att thy)) raw_atts
-	    in
-		((names, atts), eqns)
-	    end
-
+      fun prep_eqns neqs =
+          neqs
+            |> map (apsnd (Sign.read_prop thy))    
+            |> map (apfst (apsnd (apfst (map (prep_att thy)))))
+            |> FundefSplit.split_some_equations (ProofContext.init thy)
+      
+      val spec = map prep_eqns eqns_attss
+      val t_eqnss = map (flat o map snd) spec
 
-	val (natts, eqns) = split_list (map split_list eqns_attss)
-	val (names, raw_atts) = split_list (map split_list natts)
-
-	val atts = map (map (map (prep_att thy))) raw_atts
+(*
+ val t_eqns = if preprocess then map (FundefSplit.split_all_equations (ProofContext.init thy)) t_eqns
+              else t_eqns
+*)
 
-	val congs = get_fundef_congs (Context.Theory thy)
+      val congs = get_fundef_congs (Context.Theory thy)
 
-	val t_eqns = map (map (Sign.read_prop thy)) eqns
-
-	val (mutual_info, name, (data, thy)) = FundefMutual.prepare_fundef_mutual congs t_eqns thy
-	val Prep {goal, goalI, ...} = data
+      val (mutual_info, name, (data, thy)) = FundefMutual.prepare_fundef_mutual congs t_eqnss thy
+      val Prep {goal, goalI, ...} = data
     in
 	thy |> ProofContext.init
-	    |> Proof.theorem_i PureThy.internalK NONE (fundef_afterqed congs mutual_info name data names atts) NONE ("", [])
+	    |> Proof.theorem_i PureThy.internalK NONE (fundef_afterqed congs mutual_info name data spec) NONE ("", [])
 	    [(("", []), [(goal, [])])]
             |> Proof.refine (Method.primitive_text (fn _ => goalI))
             |> Seq.hd
@@ -106,7 +109,7 @@
     let
 	val totality = hd (hd thmss)
 
-	val (FundefMResult {psimps, simple_pinducts, ... }, Mutual {parts, ...}, names, atts)
+	val (FundefMResult {psimps, simple_pinducts, ... }, Mutual {parts, ...}, spec)
 	  = the (get_fundef_data name thy)
 
 	val remove_domain_condition = full_simplify (HOL_basic_ss addsimps [totality, True_implies])
@@ -117,7 +120,7 @@
         val has_guards = exists ((fn (Const ("Trueprop", _) $ _) => false | _ => true) o prop_of) (flat tsimps)
         val allatts = if has_guards then [] else [RecfunCodegen.add NONE]
 
-        val thy = fold2 (add_simps "simps" allatts) (parts ~~ tsimps) (names ~~ atts) thy
+        val thy = fold2 (add_simps "simps" allatts) (parts ~~ tsimps) spec thy
 
 	val thy = Theory.add_path name thy
 		  
@@ -161,7 +164,7 @@
 	val data = the (get_fundef_data name thy)
                    handle Option.Option => raise ERROR ("No such function definition: " ^ name)
 
-	val (res as FundefMResult {termination, ...}, mutual, _, _) = data
+	val (res as FundefMResult {termination, ...}, mutual, _) = data
 	val goal = FundefTermination.mk_total_termination_goal data
     in
 	thy |> ProofContext.init
@@ -206,13 +209,29 @@
 
 local structure P = OuterParse and K = OuterKeyword in
 
+
+
+val star = Scan.one (fn t => (OuterLex.val_of t = "*"));
+
+
+val attribs_with_star = P.$$$ "[" |-- P.!!! ((P.list (star >> K NONE || P.attrib >> SOME)) 
+                                               >> (fn x => (map_filter I x, exists is_none x)))
+                              --| P.$$$ "]";
+
+val opt_attribs_with_star = Scan.optional attribs_with_star ([], false);
+
+fun opt_thm_name_star s =
+  Scan.optional ((P.name -- opt_attribs_with_star || (attribs_with_star >> pair "")) --| P.$$$ s) ("", ([], false));
+
+
 val function_decl =
-    Scan.repeat1 (P.opt_thm_name ":" -- P.prop);
+    Scan.repeat1 (opt_thm_name_star ":" -- P.prop);
 
 val functionP =
   OuterSyntax.command "function" "define general recursive functions" K.thy_goal
-    (P.and_list1 function_decl >> (fn eqnss =>
-      Toplevel.print o Toplevel.theory_to_proof (add_fundef eqnss)));
+  (((Scan.optional (P.$$$ "(" -- P.!!! (P.$$$ "pre" -- P.$$$ ")") >> K true) false) --    
+  P.and_list1 function_decl) >> (fn (prepr, eqnss) =>
+                                    Toplevel.print o Toplevel.theory_to_proof (add_fundef eqnss prepr)));
 
 val terminationP =
   OuterSyntax.command "termination" "prove termination of a recursive function" K.thy_goal
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/function_package/pattern_split.ML	Mon Jul 31 18:07:42 2006 +0200
@@ -0,0 +1,133 @@
+(*  Title:      HOL/Tools/function_package/fundef_package.ML
+    ID:         $Id$
+    Author:     Alexander Krauss, TU Muenchen
+
+A package for general recursive function definitions. 
+
+Automatic splitting of overlapping constructor patterns. This is a preprocessing step which 
+turns a specification with overlaps into an overlap-free specification.
+
+*)
+
+signature FUNDEF_SPLIT = 
+sig
+  val split_some_equations : ProofContext.context -> (('a * ('b * bool)) * Term.term) list 
+                             -> (('a * 'b) * Term.term list) list
+
+end
+
+structure FundefSplit : FUNDEF_SPLIT = 
+struct
+
+
+(* We use proof context for the variable management *)
+(* FIXME: no __ *)
+
+fun new_var ctx vs T = 
+    let 
+      val [v] = Variable.variant_frees ctx vs [("v", T)]
+    in
+      (Free v :: vs, Free v)
+    end
+
+fun saturate ctx vs t =
+    fold (fn T => fn (vs, t) => new_var ctx vs T |> apsnd (curry op $ t))
+         (binder_types (fastype_of t)) (vs, t)
+
+
+(* This is copied from "fundef_datatype.ML" *)
+fun inst_constrs_of thy (T as Type (name, _)) =
+	map (fn (Cn,CT) => Envir.subst_TVars (Type.typ_match (Sign.tsig_of thy) (body_type CT, T) Vartab.empty) (Const (Cn, CT)))
+	    (the (DatatypePackage.get_datatype_constrs thy name))
+  | inst_constrs_of thy t = (print t; sys_error "inst_constrs_of")
+
+
+
+fun pattern_subtract_subst ctx vs _ (Free v2) = []
+  | pattern_subtract_subst ctx vs (v as (Free (_, T))) t' =
+    let 
+      fun foo constr = 
+          let 
+            val (vs', t) = saturate ctx vs constr
+            val substs = pattern_subtract_subst ctx vs' t t'
+          in
+            map (cons (v, t)) substs
+          end
+    in
+      flat (map foo (inst_constrs_of (ProofContext.theory_of ctx) T))
+    end
+  | pattern_subtract_subst ctx vs t t' =
+    let
+      val (C, ps) = strip_comb t
+      val (C', qs) = strip_comb t'
+    in
+      if C = C'
+      then flat (map2 (pattern_subtract_subst ctx vs) ps qs)
+      else [[]]
+    end
+
+fun pattern_subtract_parallel ctx vs ps qs =
+    flat (map2 (pattern_subtract_subst ctx vs) ps qs)
+
+
+
+(* ps - qs *)
+fun pattern_subtract ctx eq2 eq1 =
+    let
+      val _ $ (_ $ lhs1 $ _) = eq1
+      val _ $ (_ $ lhs2 $ _) = eq2
+
+      val thy = ProofContext.theory_of ctx
+      val vs = term_frees eq1
+    in
+      map (fn sigma => Pattern.rewrite_term thy sigma [] eq1) (pattern_subtract_subst ctx vs lhs1 lhs2)
+    end
+
+
+(* ps - p' *)
+fun pattern_subtract_from_many ctx p'=
+    flat o map (pattern_subtract ctx p')
+
+(* in reverse order *)
+fun pattern_subtract_many ctx ps' =
+    fold_rev (pattern_subtract_from_many ctx) ps'
+
+
+
+fun split_all_equations ctx eqns =
+    let 
+      fun split_aux prev [] = []
+        | split_aux prev (e::es) = pattern_subtract_many ctx prev [e] @ split_aux (e::prev) es
+    in
+      split_aux [] eqns
+end
+
+
+
+fun split_some_equations ctx eqns =
+    let
+      fun split_aux prevs [] = []
+        | split_aux prev (((n, (att, true)), eq) :: es) = ((n, att), pattern_subtract_many ctx prev [eq])
+                                                          :: split_aux (eq :: prev) es
+        | split_aux prev (((n, (att, false)), eq) :: es) = ((n, att), [eq]) 
+                                                                :: split_aux (eq :: prev) es
+    in
+      split_aux [] eqns
+    end
+
+
+
+
+
+
+end
+
+
+
+
+
+
+
+
+
+
--- a/src/HOL/Tools/function_package/termination.ML	Mon Jul 31 18:05:40 2006 +0200
+++ b/src/HOL/Tools/function_package/termination.ML	Mon Jul 31 18:07:42 2006 +0200
@@ -20,7 +20,7 @@
 open FundefCommon
 open FundefAbbrev
 
-fun mk_total_termination_goal (FundefMResult {R, f, ... }, _, _, _) =
+fun mk_total_termination_goal (FundefMResult {R, f, ... }, _, _) =
     let
 	val domT = domain_type (fastype_of f)
 	val x = Free ("x", domT)
@@ -28,7 +28,7 @@
 	Trueprop (mk_mem (x, Const (acc_const_name, fastype_of R --> HOLogic.mk_setT domT) $ R))
     end
 
-fun mk_partial_termination_goal thy (FundefMResult {R, f, ... }, _, _, _) dom =
+fun mk_partial_termination_goal thy (FundefMResult {R, f, ... }, _, _) dom =
     let
 	val domT = domain_type (fastype_of f)
 	val D = Sign.simple_read_term thy (Logic.varifyT (HOLogic.mk_setT domT)) dom
--- a/src/HOL/ex/Fundefs.thy	Mon Jul 31 18:05:40 2006 +0200
+++ b/src/HOL/ex/Fundefs.thy	Mon Jul 31 18:07:42 2006 +0200
@@ -92,13 +92,11 @@
   show "wf ?R" ..
 
   fix n::nat assume "~ 100 < n" (* Inner call *)
-  thus "(n + 11, n) : ?R"
-    by simp arith
+  thus "(n + 11, n) : ?R" by simp 
 
   assume inner_trm: "n + 11 : f91_dom" (* Outer call *)
   with f91_estimate have "n + 11 < f91 (n + 11) + 11" .
-  with `~ 100 < n` show "(f91 (n + 11), n) : ?R"
-    by simp arith
+  with `~ 100 < n` show "(f91 (n + 11), n) : ?R" by simp 
 qed
 
 
@@ -108,7 +106,7 @@
 subsection {* Overlapping patterns *}
 
 text {* Currently, patterns must always be compatible with each other, since
-no automatich splitting takes place. But the following definition of
+no automatic splitting takes place. But the following definition of
 gcd is ok, although patterns overlap: *}
 
 consts gcd2 :: "nat \<Rightarrow> nat \<Rightarrow> nat"