improving code quality thanks to Florian's code review
authorbulwahn
Thu, 12 Nov 2009 09:11:41 +0100
changeset 33629 5f35cf91c6a4
parent 33628 ed2111a5c3ed
child 33630 68e058d061f5
improving code quality thanks to Florian's code review
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Thu Nov 12 09:11:36 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Thu Nov 12 09:11:41 2009 +0100
@@ -141,12 +141,12 @@
   let
     fun split_tuple' _ _ [] = ([], [])
     | split_tuple' is i (t::ts) =
-      (if i mem is then apfst else apsnd) (cons t)
+      (if member (op =) is i then apfst else apsnd) (cons t)
         (split_tuple' is (i+1) ts)
     fun split_tuple is t = split_tuple' is 1 (strip_tuple t)
     fun split_smode' _ _ [] = ([], [])
       | split_smode' smode i (t::ts) =
-        (if i mem (map fst smode) then
+        (if member (op =) (map fst smode) i then
           case (the (AList.lookup (op =) smode i)) of
             NONE => apfst (cons t)
             | SOME is =>
@@ -461,7 +461,7 @@
     val (paramTs, _) = chop nparams (binder_types (fastype_of outp_pred))
     val (param_names, ctxt'') = Variable.variant_fixes (map (fn i => "p" ^ (string_of_int i))
       (1 upto nparams)) ctxt'
-    val params = map Free (param_names ~~ paramTs)
+    val params = map2 (curry Free) param_names paramTs
     in (((outp_pred, params), []), ctxt') end
   | import_intros inp_pred nparams (th :: ths) ctxt =
     let
@@ -508,7 +508,7 @@
       let
         val (_, (_, args)) = strip_intro_concl nparams intro
         val prems = Logic.strip_imp_prems intro
-        val eqprems = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (argvs ~~ args)
+        val eqprems = map2 (HOLogic.mk_Trueprop oo (curry HOLogic.mk_eq)) argvs args
         val frees = (fold o fold_aterms)
           (fn t as Free _ =>
               if member (op aconv) params t then I else insert (op aconv) t
@@ -564,25 +564,6 @@
     Thm.equal_elim eq elimrule |> singleton (Variable.export ctxt' ctxt)
   end;
 
-(* special case: predicate with no introduction rule *)
-fun noclause thy predname elim = let
-  val T = (Logic.unvarifyT o Sign.the_const_type thy) predname
-  val Ts = binder_types T
-  val names = Name.variant_list []
-        (map (fn i => "x" ^ (string_of_int i)) (1 upto (length Ts)))
-  val vs = map2 (curry Free) names Ts
-  val clausehd = HOLogic.mk_Trueprop (list_comb (Const (predname, T), vs))
-  val intro_t = Logic.mk_implies (@{prop False}, clausehd)
-  val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))
-  val elim_t = Logic.list_implies ([clausehd, Logic.mk_implies (@{prop False}, P)], P)
-  val intro = Goal.prove (ProofContext.init thy) names [] intro_t
-        (fn _ => etac @{thm FalseE} 1)
-  val elim = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t
-        (fn _ => etac elim 1) 
-in
-  ([intro], elim)
-end
-
 fun expand_tuples_elim th = th
 
 (* updaters *)
@@ -614,7 +595,6 @@
         val elim =
           (Drule.standard o Skip_Proof.make_thm thy)
           (mk_casesrule (ProofContext.init thy) pred nparams intros)
-        val (intros, elim) = (*if null intros then noclause thy name elim else*) (intros, elim)
       in
         mk_pred_data ((intros, SOME elim, nparams), no_compilation)
       end
@@ -650,9 +630,10 @@
   end;
 *)
 
-fun add_intro thm thy = let
-   val (name, T) = dest_Const (fst (strip_intro_concl 0 (prop_of thm)))
-   fun cons_intro gr =
+fun add_intro thm thy =
+  let
+    val (name, T) = dest_Const (fst (strip_intro_concl 0 (prop_of thm)))
+    fun cons_intro gr =
      case try (Graph.get_node gr) name of
        SOME pred_data => Graph.map_node name (map_pred_data
          (apfst (fn (intros, elim, nparams) => (intros @ [thm], elim, nparams)))) gr
@@ -663,13 +644,15 @@
        in Graph.new_node (name, mk_pred_data (([thm], NONE, nparams), no_compilation)) gr end;
   in PredData.map cons_intro thy end
 
-fun set_elim thm = let
+fun set_elim thm =
+  let
     val (name, _) = dest_Const (fst 
       (strip_comb (HOLogic.dest_Trueprop (hd (prems_of thm)))))
     fun set (intros, _, nparams) = (intros, SOME thm, nparams)  
   in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end
 
-fun set_nparams name nparams = let
+fun set_nparams name nparams =
+  let
     fun set (intros, elim, _ ) = (intros, elim, nparams) 
   in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end
 
@@ -1152,10 +1135,6 @@
   | mk_Eval_of additional_arguments ((x, T), SOME mode) names =
   let
     val Ts = binder_types T
-    (*val argnames = Name.variant_list names
-        (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
-    val args = map Free (argnames ~~ Ts)
-    val (inargs, outargs) = split_smode mode args*)
     fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
       | mk_split_lambda [x] t = lambda x t
       | mk_split_lambda xs t =
@@ -1182,7 +1161,7 @@
               val vnames = Name.variant_list names
                 (map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
                   (1 upto length Ts))
-              val args = map Free (vnames ~~ Ts)
+              val args = map2 (curry Free) vnames Ts
               fun split_args (i, arg) (ins, outs) =
                 if member (op =) pis i then
                   (arg::ins, outs)
@@ -1271,18 +1250,18 @@
   let
     val names = Term.add_free_names t [];
     val Ts = binder_types (fastype_of t);
-    val vs = map Free
-      (Name.variant_list names (replicate (length Ts) "x") ~~ Ts)
+    val vs = map2 (curry Free)
+      (Name.variant_list names (replicate (length Ts) "x")) Ts
   in
     fold_rev lambda vs (f (list_comb (t, vs)))
   end;
 
-fun compile_param compilation_modifiers compfuns thy (NONE, t) = t
-  | compile_param compilation_modifiers compfuns thy (m as SOME (Mode (mode, _, ms)), t) =
+fun compile_param compilation_modifiers compfuns thy NONE t = t
+  | compile_param compilation_modifiers compfuns thy (m as SOME (Mode (mode, _, ms))) t =
    let
      val (f, args) = strip_comb (Envir.eta_contract t)
      val (params, args') = chop (length ms) args
-     val params' = map (compile_param compilation_modifiers compfuns thy) (ms ~~ params)
+     val params' = map2 (compile_param compilation_modifiers compfuns thy) ms params
      val f' =
        case f of
          Const (name, T) => Const (Comp_Mod.function_name_of compilation_modifiers thy name mode,
@@ -1298,7 +1277,7 @@
   case strip_comb t of
     (Const (name, T), params) =>
        let
-         val params' = map (compile_param compilation_modifiers compfuns thy) (ms ~~ params)
+         val params' = map2 (compile_param compilation_modifiers compfuns thy) ms params
          val name' = Comp_Mod.function_name_of compilation_modifiers thy name mode
          val T' = Comp_Mod.funT_of compilation_modifiers compfuns mode T
        in
@@ -1410,8 +1389,12 @@
                val vnames = Name.variant_list (all_vs @ param_vs)
                 (map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
                   pis)
-             in if null pis then []
-               else [HOLogic.mk_tuple (map Free (vnames ~~ map (fn j => nth Ts (j - 1)) pis))] end
+             in
+               if null pis then
+                 []
+               else
+                 [HOLogic.mk_tuple (map2 (curry Free) vnames (map (fn j => nth Ts (j - 1)) pis))]
+             end
     val in_ts = maps mk_input_term (snd mode)
     val params = map2 (fn s => fn T => Free (s, T)) param_vs Ts1'
     val additional_arguments = Comp_Mod.additional_arguments compilation_modifiers
@@ -1454,7 +1437,7 @@
     map2 (fn NONE => I | SOME is => funT_of (PredicateCompFuns.compfuns) ([], is)) iss Ts1
   val param_names = Name.variant_list []
     (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts1)));
-  val params = map Free (param_names ~~ Ts1')
+  val params = map2 (curry Free) param_names Ts1'
   fun mk_args (i, T) argnames =
     let
       val vname = Name.variant (param_names @ argnames) ("x" ^ string_of_int (length Ts1' + i))
@@ -1472,17 +1455,17 @@
               val vnames = Name.variant_list (param_names @ argnames)
                 (map (fn j => "x" ^ string_of_int (length Ts1' + i) ^ "p" ^ string_of_int j)
                   (1 upto (length Ts)))
-             in (HOLogic.mk_tuple (map Free (vnames ~~ Ts)), vnames  @ argnames) end
+             in (HOLogic.mk_tuple (map2 (curry Free) vnames Ts), vnames @ argnames) end
     end
   val (args, argnames) = fold_map mk_args (1 upto (length Ts2) ~~ Ts2) []
   val (inargs, outargs) = split_smode is args
   val param_names' = Name.variant_list (param_names @ argnames)
     (map (fn i => "p" ^ string_of_int i) (1 upto (length iss)))
-  val param_vs = map Free (param_names' ~~ Ts1)
+  val param_vs = map2 (curry Free) param_names' Ts1
   val (params', names) = fold_map (mk_Eval_of []) ((params ~~ Ts1) ~~ iss) []
   val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ args))
   val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ args))
-  val param_eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (param_vs ~~ params')
+  val param_eqs = map2 (HOLogic.mk_Trueprop oo (curry HOLogic.mk_eq)) param_vs params'
   val funargs = params @ inargs
   val funpropE = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs),
                   if null outargs then Free("y", HOLogic.unitT) else HOLogic.mk_tuple outargs))
@@ -1515,7 +1498,7 @@
   let
     fun split_tuple' _ _ [] = ([], [])
       | split_tuple' is i (T::Ts) =
-      (if i mem is then apfst else apsnd) (cons T)
+      (if member (op =) is i then apfst else apsnd) (cons T)
         (split_tuple' is (i+1) Ts)
   in
     split_tuple' is 1 (HOLogic.strip_tupleT T)
@@ -1528,7 +1511,8 @@
     fun mk_proj i j t =
       (if i = j then I else HOLogic.mk_fst)
         (funpow (i - 1) HOLogic.mk_snd t)
-    fun mk_arg' i (si, so) = if i mem pis then
+    fun mk_arg' i (si, so) =
+      if member (op =) pis i then
         (mk_proj si ni xin, (si+1, so))
       else
         (mk_proj so (n - ni) xout, (si, so+1))
@@ -1551,16 +1535,9 @@
       val funT = (Ts1' @ Us1) ---> (mk_predT compfuns (HOLogic.mk_tupleT Us2))
       val names = Name.variant_list []
         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
-      (* old *)
-      (*
-      val xs = map Free (names ~~ (Ts1' @ Ts2))
-      val (xparams, xargs) = chop (length iss) xs
-      val (xins, xouts) = split_smode is xargs
-      *)
-      (* new *)
       val param_names = Name.variant_list []
         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts1')))
-      val xparams = map Free (param_names ~~ Ts1')
+      val xparams = map2 (curry Free) param_names Ts1'
       fun mk_vars (i, T) names =
         let
           val vname = Name.variant names ("x" ^ string_of_int (length Ts1' + i))
@@ -1650,8 +1627,8 @@
 
 (* MAJOR FIXME:  prove_params should be simple
  - different form of introrule for parameters ? *)
-fun prove_param thy (NONE, t) = TRY (rtac @{thm refl} 1)
-  | prove_param thy (m as SOME (Mode (mode, is, ms)), t) =
+fun prove_param thy NONE t = TRY (rtac @{thm refl} 1)
+  | prove_param thy (m as SOME (Mode (mode, is, ms))) t =
   let
     val  (f, args) = strip_comb (Envir.eta_contract t)
     val (params, _) = chop (length ms) args
@@ -1668,7 +1645,7 @@
     THEN print_tac "prove_param"
     THEN f_tac
     THEN print_tac "after simplification in prove_args"
-    THEN (EVERY (map (prove_param thy) (ms ~~ params)))
+    THEN (EVERY (map2 (prove_param thy) ms params))
     THEN (REPEAT_DETERM (atac 1))
   end
 
@@ -1689,7 +1666,7 @@
         (* work with parameter arguments *)
         THEN (atac 1)
         THEN (print_tac "parameter goal")
-        THEN (EVERY (map (prove_param thy) (ms ~~ args1)))
+        THEN (EVERY (map2 (prove_param thy) ms args1))
         THEN (REPEAT_DETERM (atac 1))
       end
   | _ => rtac @{thm bindI} 1
@@ -1782,8 +1759,7 @@
                   THEN rtac @{thm not_predI} 1
                   THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
                   THEN (REPEAT_DETERM (atac 1))
-                  (* FIXME: work with parameter arguments *)
-                  THEN (EVERY (map (prove_param thy) (param_modes ~~ params)))
+                  THEN (EVERY (map2 (prove_param thy) param_modes params))
                 else
                   rtac @{thm not_predI'} 1)
                   THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
@@ -1866,8 +1842,9 @@
 *)
 (* TODO: remove function *)
 
-fun prove_param2 thy (NONE, t) = all_tac 
-  | prove_param2 thy (m as SOME (Mode (mode, is, ms)), t) = let
+fun prove_param2 thy NONE t = all_tac 
+  | prove_param2 thy (m as SOME (Mode (mode, is, ms))) t =
+  let
     val  (f, args) = strip_comb (Envir.eta_contract t)
     val (params, _) = chop (length ms) args
     val f_tac = case f of
@@ -1876,11 +1853,11 @@
            :: @{thm "Product_Type.split_conv"}::[])) 1
       | Free _ => all_tac
       | _ => error "prove_param2: illegal parameter term"
-  in  
+  in
     print_tac "before simplification in prove_args:"
     THEN f_tac
     THEN print_tac "after simplification in prove_args"
-    THEN (EVERY (map (prove_param2 thy) (ms ~~ params)))
+    THEN (EVERY (map2 (prove_param2 thy) ms params))
   end
 
 
@@ -1894,7 +1871,7 @@
         (prop_of (predfun_elim_of thy name mode))))
       THEN (etac (predfun_elim_of thy name mode) 1)
       THEN print_tac "prove_expr2"
-      THEN (EVERY (map (prove_param2 thy) (ms ~~ args)))
+      THEN (EVERY (map2 (prove_param2 thy) ms args))
       THEN print_tac "finished prove_expr2"      
     | _ => etac @{thm bindE} 1)
     
@@ -1965,7 +1942,7 @@
                   [predfun_definition_of thy (the name) (iss, is)]) 1
                 THEN etac @{thm not_predE} 1
                 THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
-                THEN (EVERY (map (prove_param2 thy) (param_modes ~~ params)))
+                THEN (EVERY (map2 (prove_param2 thy) param_modes params))
               else
                 etac @{thm not_predE'} 1)
             THEN rec_tac
@@ -2056,7 +2033,7 @@
 
 fun dest_prem thy params t =
   (case strip_comb t of
-    (v as Free _, ts) => if v mem params then Prem (ts, v) else Sidecond t
+    (v as Free _, ts) => if member (op =) params v then Prem (ts, v) else Sidecond t
   | (c as Const (@{const_name Not}, _), [t]) => (case dest_prem thy params t of
       Prem (ts, t) => Negprem (ts, t)
     | Negprem _ => error ("Double negation not allowed in premise: " ^
@@ -2086,7 +2063,7 @@
             val (paramTs, _) = chop nparams (binder_types (snd (hd preds)))
             val param_names = Name.variant_list [] (map (fn i => "p" ^ string_of_int i)
               (1 upto length paramTs))
-          in map Free (param_names ~~ paramTs) end
+          in map2 (curry Free) param_names paramTs end
       | intr :: _ => fst (chop nparams
         (snd (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl intr)))))
     val param_vs = maps term_vs params
@@ -2142,14 +2119,15 @@
   let
     val concl = Logic.strip_imp_concl (prop_of intro)
     val (p, args) = strip_comb (HOLogic.dest_Trueprop concl)
-    val params = List.take (args, nparams_of thy (fst (dest_Const p)))
+    val params = fst (chop (nparams_of thy (fst (dest_Const p))) args)
     fun check_arg arg = case HOLogic.strip_tupleT (fastype_of arg) of
       (Ts as _ :: _ :: _) =>
-        if (length (HOLogic.strip_tuple arg) = length Ts) then true
+        if length (HOLogic.strip_tuple arg) = length Ts then
+          true
         else
-        error ("Format of introduction rule is invalid: tuples must be expanded:"
-        ^ (Syntax.string_of_term_global thy arg) ^ " in " ^
-        (Display.string_of_thm_global thy intro)) 
+          error ("Format of introduction rule is invalid: tuples must be expanded:"
+          ^ (Syntax.string_of_term_global thy arg) ^ " in " ^
+          (Display.string_of_thm_global thy intro)) 
       | _ => true
     val prems = Logic.strip_imp_prems (prop_of intro)
     fun check_prem (Prem (args, _)) = forall check_arg args
@@ -2183,7 +2161,7 @@
 
 fun add_code_equations thy nparams preds result_thmss =
   let
-    fun add_code_equation ((predname, T), (pred, result_thms)) =
+    fun add_code_equation (predname, T) (pred, result_thms) =
       let
         val full_mode = (replicate nparams NONE,
           map (rpair NONE) (1 upto (length (binder_types T) - nparams)))
@@ -2193,7 +2171,7 @@
             val Ts = binder_types T
             val arg_names = Name.variant_list []
               (map (fn i => "x" ^ string_of_int i) (1 upto length Ts))
-            val args = map Free (arg_names ~~ Ts)
+            val args = map2 (curry Free) arg_names Ts
             val predfun = Const (predfun_name_of thy predname full_mode,
               Ts ---> PredicateCompFuns.mk_predT @{typ unit})
             val rhs = PredicateCompFuns.mk_Eval (list_comb (predfun, args), @{term "Unity"})
@@ -2210,7 +2188,7 @@
           (pred, result_thms)
       end
   in
-    map add_code_equation (preds ~~ result_thmss)
+    map2 add_code_equation preds result_thmss
   end
 
 (** main function of predicate compiler **)