some optimizations, cleanup
authorkrauss
Tue, 22 May 2007 17:25:26 +0200
changeset 23074 a53cb8ddb052
parent 23073 d810dc04b96d
child 23075 69e30a7e8880
some optimizations, cleanup
src/HOL/Library/sct.ML
src/HOL/Tools/function_package/lexicographic_order.ML
--- a/src/HOL/Library/sct.ML	Tue May 22 16:47:22 2007 +0200
+++ b/src/HOL/Library/sct.ML	Tue May 22 17:25:26 2007 +0200
@@ -149,10 +149,10 @@
 
 
 (* very primitive *)
-fun measures_of RD =
+fun measures_of thy RD =
     let
       val domT = range_type (fastype_of (fst (HOLogic.dest_prod (snd (HOLogic.dest_prod RD)))))
-      val measures = LexicographicOrder.mk_base_funs domT
+      val measures = LexicographicOrder.mk_base_funs thy domT
     in
       measures
     end
@@ -311,7 +311,7 @@
       val RDs = HOLogic.dest_list RDlist
       val n = length RDs
 
-      val Mss = map measures_of RDs
+      val Mss = map (measures_of thy) RDs
 
       val domT = domain_type (fastype_of (hd (hd Mss)))
 
--- a/src/HOL/Tools/function_package/lexicographic_order.ML	Tue May 22 16:47:22 2007 +0200
+++ b/src/HOL/Tools/function_package/lexicographic_order.ML	Tue May 22 17:25:26 2007 +0200
@@ -11,7 +11,7 @@
 
   (* exported for use by size-change termination prototype.
      FIXME: provide a common interface later *)
-  val mk_base_funs : typ -> term list
+  val mk_base_funs : theory -> typ -> term list
   (* exported for debugging *)
   val setup: theory -> theory
 end
@@ -19,108 +19,117 @@
 structure LexicographicOrder : LEXICOGRAPHIC_ORDER =
 struct
 
-(* Theory dependencies *)
-val measures = "List.measures"
-val wf_measures = thm "wf_measures"
-val measures_less = thm "measures_less"
-val measures_lesseq = thm "measures_lesseq"
-                      
-fun del_index (n, []) = []
-  | del_index (n, x :: xs) =
-    if n>0 then x :: del_index (n - 1, xs) else xs 
+(** General stuff **)
+
+fun mk_measures domT mfuns =
+    let val list = HOLogic.mk_list (domT --> HOLogic.natT) mfuns
+    in
+      Const (@{const_name "List.measures"}, fastype_of list --> (HOLogic.mk_setT (HOLogic.mk_prodT (domT, domT)))) $ list
+    end
+
+fun del_index n [] = []
+  | del_index n (x :: xs) =
+    if n > 0 then x :: del_index (n - 1) xs else xs 
 
 fun transpose ([]::_) = []
   | transpose xss = map hd xss :: transpose (map tl xss)
 
-fun mk_sum_case (f1, f2) =
-    case (fastype_of f1, fastype_of f2) of
-      (Type("fun", [A, B]), Type("fun", [C, D])) =>
-      if (B = D) then
-        Const("Datatype.sum.sum_case", (A --> B) --> (C --> D) --> Type("+", [A,C]) --> B) $ f1 $ f2
-      else raise TERM ("mk_sum_case: range type mismatch", [f1, f2]) 
-    | _ => raise TERM ("mk_sum_case", [f1, f2])
-                 
-fun dest_wf (Const ("Wellfounded_Recursion.wf", _) $ t) = t
-  | dest_wf t = raise TERM ("dest_wf", [t])
-                      
+(** Matrix cell datatype **)
+
 datatype cell = Less of thm | LessEq of thm | None of thm | False of thm;
          
-fun is_Less cell = case cell of (Less _) => true | _ => false  
+fun is_Less (Less _) = true
+  | is_Less _ = false
                                                         
-fun is_LessEq cell = case cell of (LessEq _) => true | _ => false
+fun is_LessEq (LessEq _) = true
+  | is_LessEq _ = false
                                                             
-fun thm_of_cell cell =
-    case cell of 
-      Less thm => thm
-    | LessEq thm => thm
-    | False thm => thm
-    | None thm => thm
+fun thm_of_cell (Less thm) = thm 
+  | thm_of_cell (LessEq thm) = thm 
+  | thm_of_cell (False thm) = thm 
+  | thm_of_cell (None thm) = thm 
                   
-fun mk_base_fun_bodys (t : term) (tt : typ) =
-    case tt of
-      Type("*", [ft, st]) => (mk_base_fun_bodys (Const("fst", tt --> ft) $ t) ft) @ (mk_base_fun_bodys (Const("snd", tt --> st) $ t) st)      
-    | _ => [(t, tt)]
-           
-fun mk_base_fun_header fulltyp (t, typ) =
-    Abs ("x", fulltyp, HOLogic.size_const typ $ t)
-         
-fun mk_base_funs (tt: typ) = 
-    mk_base_fun_bodys (Bound 0) tt |>
-                      map (mk_base_fun_header tt)
+fun pr_cell (Less _ ) = " <  "
+  | pr_cell (LessEq _) = " <= " 
+  | pr_cell (None _) = " N  "
+  | pr_cell (False _) = " F  "
+
+
+(** Generating Measure Functions **)
+
+fun mk_comp g f = 
+    let 
+      val fT = fastype_of f 
+      val gT as (Type ("fun", [xT, _])) = fastype_of g
+      val comp = Abs ("f", fT, Abs ("g", gT, Abs ("x", xT, Bound 2 $ (Bound 1 $ Bound 0))))
+    in
+      Envir.beta_norm (comp $ f $ g)
+    end
+
+fun mk_base_funs thy (T as Type("*", [fT, sT])) = (* products *)
+      map (mk_comp (Const ("fst", T --> fT))) (mk_base_funs thy fT)
+    @ map (mk_comp (Const ("snd", T --> sT))) (mk_base_funs thy sT)
 
-fun mk_funorder_funs (tt : typ) (one : bool) : Term.term list =
-    case tt of
-      Type("+", [ft, st]) => let
-                               val ft_funs = mk_funorder_funs ft
-                               val st_funs = mk_funorder_funs st 
-                             in
-                               (if one then 
-                                 (product (ft_funs true) (st_funs false)) @ (product (ft_funs false) (st_funs true))
-                               else
-                                 product (ft_funs false) (st_funs false)) 
-                                    |> map mk_sum_case
-                             end
-    | _ => if one then [Abs ("x", tt, HOLogic.Suc_zero)] else [Abs ("x", tt, HOLogic.zero)]
+  | mk_base_funs thy T = (* default: size function, if available *)
+    if Sorts.of_sort (Sign.classes_of thy) (T, [HOLogic.class_size])
+    then [HOLogic.size_const T]
+    else []
+
+fun mk_sum_case f1 f2 =
+    let
+      val Type ("fun", [fT, Q]) = fastype_of f1 
+      val Type ("fun", [sT, _]) = fastype_of f2
+    in
+      Const (@{const_name "Sum_Type.sum_case"}, (fT --> Q) --> (sT --> Q) --> Type("+", [fT, sT]) --> Q) $ f1 $ f2
+    end
+                 
+fun constant_0 T = Abs ("x", T, HOLogic.zero)
+fun constant_1 T = Abs ("x", T, HOLogic.Suc_zero)
 
-fun mk_ext_base_funs (tt : typ) =
-    case tt of
-      Type("+", [ft, st]) =>
-      product (mk_ext_base_funs ft) (mk_ext_base_funs st)
-              |> map mk_sum_case
-    | _ => mk_base_funs tt
+fun mk_funorder_funs (Type ("+", [fT, sT])) =
+      map (fn m => mk_sum_case m (constant_0 sT)) (mk_funorder_funs fT)
+    @ map (fn m => mk_sum_case (constant_0 fT) m) (mk_funorder_funs sT)
+  | mk_funorder_funs T = [ constant_1 T ] 
 
-fun mk_all_measure_funs (tt : typ) =
-    case tt of
-      Type("+", _) => (mk_ext_base_funs tt) @ (mk_funorder_funs tt true)
-    | _ => mk_base_funs tt
+fun mk_ext_base_funs thy (Type("+", [fT, sT])) =
+    product (mk_ext_base_funs thy fT) (mk_ext_base_funs thy sT)
+       |> map (uncurry mk_sum_case)
+  | mk_ext_base_funs thy T = mk_base_funs thy T
+
+fun mk_all_measure_funs thy (T as Type ("+", _)) =
+    mk_ext_base_funs thy T @ mk_funorder_funs T
+  | mk_all_measure_funs thy T = mk_base_funs thy T
+
+
+(** Proof attempts to build the matrix **)
            
 fun dest_term (t : term) =
     let
-      val (vars, prop) = (FundefLib.dest_all_all t)
+      val (vars, prop) = FundefLib.dest_all_all t
       val prems = Logic.strip_imp_prems prop
-      val (tuple, rel) = Logic.strip_imp_concl prop
+      val (lhs, rhs) = Logic.strip_imp_concl prop
                          |> HOLogic.dest_Trueprop 
-                         |> HOLogic.dest_mem
-      val (lhs, rhs) = HOLogic.dest_prod tuple
+                         |> HOLogic.dest_mem |> fst
+                         |> HOLogic.dest_prod 
     in
-      (vars, prems, lhs, rhs, rel)
+      (vars, prems, lhs, rhs)
     end
     
 fun mk_goal (vars, prems, lhs, rhs) rel =
     let 
       val concl = HOLogic.mk_binrel rel (lhs, rhs) |> HOLogic.mk_Trueprop
     in  
-      Logic.list_implies (prems, concl) |>
-                         fold_rev FundefLib.mk_forall vars
+      Logic.list_implies (prems, concl) 
+        |> fold_rev FundefLib.mk_forall vars
     end
     
 fun prove (thy: theory) solve_tac (t: term) =
     cterm_of thy t |> Goal.init 
     |> SINGLE solve_tac |> the
     
-fun mk_cell (thy : theory) solve_tac (vars, prems) (lhs, rhs) = 
+fun mk_cell (thy : theory) solve_tac (vars, prems, lhs, rhs) mfun = 
     let 
-      val goals = mk_goal (vars, prems, lhs, rhs) 
+      val goals = mk_goal (vars, prems, mfun $ lhs, mfun $ rhs) 
       val less_thm = goals "Orderings.ord_class.less" |> prove thy solve_tac
     in
       if Thm.no_prems less_thm then
@@ -136,28 +145,15 @@
             else None lesseq_thm
         end
     end
-    
-fun mk_row (thy: theory) solve_tac measure_funs (t : term) =
-    let
-      val (vars, prems, lhs, rhs, _) = dest_term t
-      val lhs_list = map (fn x => x $ lhs) measure_funs
-      val rhs_list = map (fn x => x $ rhs) measure_funs
-    in
-      map (mk_cell thy solve_tac (vars, prems)) (lhs_list ~~ rhs_list)
-    end
-    
-fun pr_cell cell = case cell of Less _ => " <  " 
-                              | LessEq _ => " <= " 
-                              | None _ => " N  "
-                              | False _ => " F  "
+
+
+(** Search algorithms **)
 
-fun pr_table table = writeln (cat_lines (map (fn r => concat (map pr_cell r)) table))
-
-fun check_col ls = (forall (fn c => is_Less c orelse is_LessEq c) ls) andalso not (forall (fn c => is_LessEq c) ls)
+fun check_col ls = forall (fn c => is_Less c orelse is_LessEq c) ls andalso not (forall (is_LessEq) ls)
 
-fun transform_table table col = table |> filter_out (fn x => is_Less (nth x col)) |> map (curry del_index col)
+fun transform_table table col = table |> filter_out (fn x => is_Less (nth x col)) |> map (del_index col)
 
-fun transform_order col order = map (fn x => if x>=col then x+1 else x) order
+fun transform_order col order = map (fn x => if x >= col then x + 1 else x) order
       
 (* simple depth-first search algorithm for the table *)
 fun search_table table =
@@ -173,15 +169,15 @@
              val order_opt = (table, col) |-> transform_table |> search_table
            in case order_opt of
                 NONE => NONE
-              | SOME order =>SOME (col::(transform_order col order))
+              | SOME order =>SOME (col :: transform_order col order)
            end
       end
 
 (* find all positions of elements in a list *) 
-fun find_index_list pred =
-  let fun find _ [] = []
-        | find n (x :: xs) = if pred x then n::(find (n + 1) xs) else find (n + 1) xs;
-  in find 0 end;
+fun find_index_list P =
+    let fun find _ [] = []
+          | find n (x :: xs) = if P x then n :: find (n + 1) xs else find (n + 1) xs
+    in find 0 end
 
 (* simple breadth-first search algorithm for the table *) 
 fun bfs_search_table nodes =
@@ -191,7 +187,7 @@
 	val (order, table) = node
       in
         case table of
-          [] => SOME (foldr (fn (c, order) => c::transform_order c order) [] (rev order))
+          [] => SOME (foldr (fn (c, order) => c :: transform_order c order) [] (rev order))
         | _ => let
 	    val cols = find_index_list (check_col) (transpose table)
           in
@@ -199,7 +195,7 @@
 	      [] => NONE
             | _ => let 
               val newtables = map (transform_table table) cols 
-              val neworders = map (fn c => c::order) cols
+              val neworders = map (fn c => c :: order) cols
               val newnodes = neworders ~~ newtables
             in
               bfs_search_table (rnodes @ newnodes)
@@ -209,96 +205,77 @@
 
 fun nsearch_table table = bfs_search_table [([], table)] 	       
 
-fun prove_row row (st : thm) =
-    case row of
-      [] => sys_error "INTERNAL ERROR IN lexicographic order termination tactic - fun prove_row (row is empty)" 
-    | cell::tail =>
-      case cell of
-        Less less_thm =>
-        let
-          val next_thm = st |> SINGLE (rtac measures_less 1) |> the
-        in
-          implies_elim next_thm less_thm
-        end
-      | LessEq lesseq_thm =>
-        let
-          val next_thm = st |> SINGLE (rtac measures_lesseq 1) |> the
-        in
-          implies_elim next_thm lesseq_thm 
-          |> prove_row tail
-        end
-      | _ => sys_error "INTERNAL ERROR IN lexicographic order termination tactic - fun prove_row (Only expecting Less or LessEq)"
-             
+(** Proof Reconstruction **)
+
+(* prove row :: cell list -> tactic *)
+fun prove_row (Less less_thm :: _) =
+    (rtac @{thm "measures_less"} 1)
+    THEN PRIMITIVE (flip implies_elim less_thm)
+  | prove_row (LessEq lesseq_thm :: tail) =
+    (rtac @{thm "measures_lesseq"} 1)
+    THEN PRIMITIVE (flip implies_elim lesseq_thm)
+    THEN prove_row tail
+  | prove_row _ = sys_error "lexicographic_order"
+
+
+(** Error reporting **)
+
+fun pr_table table = writeln (cat_lines (map (fn r => concat (map pr_cell r)) table))
+       
 fun pr_unprovable_subgoals table =
-    filter (fn x => not (is_Less x) andalso not (is_LessEq x)) (flat table)
+    filter_out (fn x => is_Less x orelse is_LessEq x) (flat table)
     |> map ((fn th => Pretty.string_of (Pretty.chunks (Display.pretty_goals (Thm.nprems_of th) th))) o thm_of_cell)
     
-fun pr_goal thy t i = 
-    let
-      val (_, prems, lhs, rhs, _) = dest_term t 
-      val prterm = string_of_cterm o (cterm_of thy)
+fun no_order_msg table thy tl measure_funs =  
+    let 
+      fun pr_fun t i = string_of_int i ^ ") " ^ string_of_cterm (cterm_of thy t)
+
+      fun pr_goal t i = 
+          let
+            val (_, _, lhs, rhs) = dest_term t 
+            val prterm = string_of_cterm o (cterm_of thy)
+          in (* also show prems? *)
+               i ^ ") " ^ prterm lhs ^ " '<' " ^ prterm rhs
+          end
+
+      val gc = map (fn i => chr (i + 96)) (1 upto length table)
+      val mc = 1 upto length measure_funs
+      val tstr = "   " ^ concat (map (enclose " " " " o string_of_int) mc)
+                 :: map2 (fn r => fn i => i ^ ": " ^ concat (map pr_cell r)) table gc
+      val gstr = "Goals:" :: map2 pr_goal tl gc
+      val mstr = "Measures:" :: map2 pr_fun measure_funs mc
+      val ustr = "Unfinished subgoals:" :: pr_unprovable_subgoals table
     in
-      (* also show prems? *)
-        i ^ ") " ^ (prterm lhs) ^ " '<' " ^ (prterm rhs) 
-    end
-    
-fun pr_fun thy t i =
-    (string_of_int i) ^ ") " ^ (string_of_cterm (cterm_of thy t))
-                                             
-(* fun pr_err: prints the table if tactic failed *)
-fun pr_err table thy tl measure_funs =  
-    let 
-      val gc = map (fn i => chr (i + 96)) (1 upto (length table))
-      val mc = 1 upto (length measure_funs)
-      val tstr = ("   " ^ (concat (map (fn i => " " ^ (string_of_int i) ^ "  ") mc))) ::
-                 (map2 (fn r => fn i => i ^ ": " ^ (concat (map pr_cell r))) table gc)
-      val gstr = ("Goals:"::(map2 (pr_goal thy) tl gc))
-      val mstr = ("Measures:"::(map2 (pr_fun thy) measure_funs mc))   
-      val ustr = ("Unfinished subgoals:"::(pr_unprovable_subgoals table))
-    in
-      tstr @ gstr @ mstr @ ustr
+      cat_lines ("Could not find lexicographic termination order:" :: tstr @ gstr @ mstr @ ustr)
     end
       
-(* the main function: create table, search table, create relation,
-   and prove the subgoals *)  
+(** The Main Function **)
 fun lexicographic_order_tac ctxt solve_tac (st: thm) = 
     let
       val thy = theory_of_thm st
-      val premlist = prems_of st
-    in
-      case premlist of 
-            [] => error "invalid number of subgoals for this tactic - expecting at least 1 subgoal" 
-          | (wfR::tl) => let
-    val trueprop $ (wf $ rel) = wfR
-    val crel = cterm_of thy rel
-    val (domT, _) = HOLogic.dest_prodT (HOLogic.dest_setT (fastype_of rel))
-    val measure_funs = mk_all_measure_funs domT
-    val _ = writeln "Creating table"
-    val table = map (mk_row thy solve_tac measure_funs) tl
-    val _ = writeln "Searching for lexicographic order"
-    (* val _ = pr_table table *)
-    val possible_order = search_table table
-      in
-    case possible_order of 
-        NONE => error (cat_lines ("Could not find lexicographic termination order:"::(pr_err table thy tl measure_funs)))
-      | SOME order  => let
+      val ((trueprop $ (wf $ rel)) :: tl) = prems_of st
+
+      val (domT, _) = HOLogic.dest_prodT (HOLogic.dest_setT (fastype_of rel))
+
+      val measure_funs = mk_all_measure_funs thy domT (* 1: generate measures *)
+                         
+      (* 2: create table *)
+      val table = map (fn t => map (mk_cell thy solve_tac (dest_term t)) measure_funs) tl
+
+      val order = the (search_table table) (* 3: search table *)
+          handle Option => error (no_order_msg table thy tl measure_funs)
+
       val clean_table = map (fn x => map (nth x) order) table
-      val funs = map (nth measure_funs) order
-      val list = HOLogic.mk_list (domT --> HOLogic.natT) funs
-      val relterm = Const(measures, (fastype_of list) --> (fastype_of rel)) $ list
-      val crelterm = cterm_of thy relterm
-      val _ = writeln ("Instantiating R with " ^ (string_of_cterm crelterm))
-      val _ = writeln "Proving subgoals"
-        in
-      st |> cterm_instantiate [(crel, crelterm)]
-        |> SINGLE (rtac wf_measures 1) |> the
-        |> fold prove_row clean_table
-        |> Seq.single
-                    end
-            end
+
+      val relation = mk_measures domT (map (nth measure_funs) order)
+      val _ = writeln ("Found termination order: " ^ quote (ProofContext.string_of_term ctxt relation))
+
+    in (* 4: proof reconstruction *)
+      st |> (PRIMITIVE (cterm_instantiate [(cterm_of thy rel, cterm_of thy relation)])
+              THEN rtac @{thm "wf_measures"} 1
+              THEN EVERY (map prove_row clean_table))
     end
 
-(* FIXME goal addressing ?? *)
 fun lexicographic_order thms ctxt = Method.SIMPLE_METHOD (FundefCommon.apply_termination_rule ctxt 1 
                                                          THEN lexicographic_order_tac ctxt (auto_tac (local_clasimpset_of ctxt)))