encapsulating records with datatype constructors and adding type annotations to make SML/NJ happy
authorbulwahn
Thu, 29 Oct 2009 13:59:37 +0100
changeset 33330 d6eb7f19bfc6
parent 33329 b129e4c476d6
child 33331 d8bfa9564a52
encapsulating records with datatype constructors and adding type annotations to make SML/NJ happy
src/HOL/Tools/Predicate_Compile/predicate_compile.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Wed Oct 28 12:29:03 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Thu Oct 29 13:59:37 2009 +0100
@@ -170,7 +170,8 @@
     (fn (i, Argmode s) => if s = "i" then [(i + 1, NONE)] else []
       | (i, Argmode_Tuple ss) => [(i + 1, SOME (mk_numeral_mode ss))]) m))
 
-val parse_smode = (P.$$$ "[" |-- P.enum "," P.nat --| P.$$$ "]") >> map (rpair NONE)
+val parse_smode = (P.$$$ "[" |-- P.enum "," P.nat --| P.$$$ "]")
+  >> map (rpair (NONE : int list option))
 
 fun gen_parse_mode smode_parser =
   (Scan.optional
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Wed Oct 28 12:29:03 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Thu Oct 29 13:59:37 2009 +0100
@@ -156,16 +156,16 @@
         (split_smode' smode (i+1) ts)
   in split_smode' smode 1 ts end
 
-val split_smode = gen_split_smode (HOLogic.mk_tuple, HOLogic.strip_tuple)   
-val split_smodeT = gen_split_smode (HOLogic.mk_tupleT, HOLogic.strip_tupleT)
+fun split_smode smode ts = gen_split_smode (HOLogic.mk_tuple, HOLogic.strip_tuple) smode ts
+fun split_smodeT smode ts = gen_split_smode (HOLogic.mk_tupleT, HOLogic.strip_tupleT) smode ts
 
 fun gen_split_mode split_smode (iss, is) ts =
   let
     val (t1, t2) = chop (length iss) ts 
   in (t1, split_smode is t2) end
 
-val split_mode = gen_split_mode split_smode
-val split_modeT = gen_split_mode split_smodeT
+fun split_mode (iss, is) ts = gen_split_mode split_smode (iss, is) ts
+fun split_modeT (iss, is) ts = gen_split_mode split_smodeT (iss, is) ts
 
 datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term
   | Generator of (string * typ);
@@ -1166,6 +1166,28 @@
     (t, names)
   end;
 
+structure Comp_Mod =
+struct
+
+datatype comp_modifiers = Comp_Modifiers of
+{
+  const_name_of : theory -> string -> Predicate_Compile_Aux.mode -> string,
+  funT_of : compilation_funs -> mode -> typ -> typ,
+  additional_arguments : string list -> term list,
+  wrap_compilation : compilation_funs -> string -> typ -> mode -> term list -> term -> term,
+  transform_additional_arguments : indprem -> term list -> term list
+}
+
+fun dest_comp_modifiers (Comp_Modifiers c) = c
+
+val const_name_of = #const_name_of o dest_comp_modifiers
+val funT_of = #funT_of o dest_comp_modifiers
+val additional_arguments = #additional_arguments o dest_comp_modifiers
+val wrap_compilation = #wrap_compilation o dest_comp_modifiers
+val transform_additional_arguments = #transform_additional_arguments o dest_comp_modifiers
+
+end;
+
 fun compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss arg = 
   let
     fun map_params (t as Free (f, T)) =
@@ -1173,7 +1195,7 @@
         case (the (AList.lookup (op =) (param_vs ~~ iss) f)) of
           SOME is =>
             let
-              val T' = #funT_of compilation_modifiers compfuns ([], is) T
+              val T' = Comp_Mod.funT_of compilation_modifiers compfuns ([], is) T
             in fst (mk_Eval_of additional_arguments ((Free (f, T'), T), SOME is) []) end
         | NONE => t
       else t
@@ -1223,9 +1245,9 @@
      val params' = map (compile_param compilation_modifiers compfuns thy) (ms ~~ params)
      val f' =
        case f of
-         Const (name, T) => Const (#const_name_of compilation_modifiers thy name mode,
-           #funT_of compilation_modifiers compfuns mode T)
-       | Free (name, T) => Free (name, #funT_of compilation_modifiers compfuns mode T)
+         Const (name, T) => Const (Comp_Mod.const_name_of compilation_modifiers thy name mode,
+           Comp_Mod.funT_of compilation_modifiers compfuns mode T)
+       | Free (name, T) => Free (name, Comp_Mod.funT_of compilation_modifiers compfuns mode T)
        | _ => error ("PredicateCompiler: illegal parameter term")
    in
      list_comb (f', params' @ args')
@@ -1237,13 +1259,13 @@
        let
          val params' = map (compile_param compilation_modifiers compfuns thy) (ms ~~ params)
            (*val mk_fun_of = if depth_limited then mk_depth_limited_fun_of else mk_fun_of*)
-         val name' = #const_name_of compilation_modifiers thy name mode
-         val T' = #funT_of compilation_modifiers compfuns mode T
+         val name' = Comp_Mod.const_name_of compilation_modifiers thy name mode
+         val T' = Comp_Mod.funT_of compilation_modifiers compfuns mode T
        in
          (list_comb (Const (name', T'), params' @ inargs @ additional_arguments))
        end
   | (Free (name, T), params) =>
-    list_comb (Free (name, #funT_of compilation_modifiers compfuns mode T), params @ inargs @ additional_arguments)
+    list_comb (Free (name, Comp_Mod.funT_of compilation_modifiers compfuns mode T), params @ inargs @ additional_arguments)
 
 fun compile_clause compilation_modifiers compfuns thy all_vs param_vs additional_arguments (iss, is) inp (ts, moded_ps) =
   let
@@ -1277,7 +1299,7 @@
             val (out_ts'', (names'', constr_vs')) = fold_map distinct_v
               out_ts' ((names', map (rpair []) vs))
             val additional_arguments' =
-              #transform_additional_arguments compilation_modifiers p additional_arguments
+              Comp_Mod.transform_additional_arguments compilation_modifiers p additional_arguments
             val (compiled_clause, rest) = case p of
                Prem (us, t) =>
                  let
@@ -1331,7 +1353,7 @@
     val (Ts1, Ts2) = chop (length (fst mode)) (binder_types T)
     val (Us1, Us2) = split_smodeT (snd mode) Ts2
     val Ts1' =
-      map2 (fn NONE => I | SOME is => #funT_of compilation_modifiers compfuns ([], is)) (fst mode) Ts1
+      map2 (fn NONE => I | SOME is => Comp_Mod.funT_of compilation_modifiers compfuns ([], is)) (fst mode) Ts1
     fun mk_input_term (i, NONE) =
         [Free (Name.variant (all_vs @ param_vs) ("x" ^ string_of_int i), nth Ts2 (i - 1))]
       | mk_input_term (i, SOME pis) = case HOLogic.strip_tupleT (nth Ts2 (i - 1)) of
@@ -1345,17 +1367,17 @@
                else [HOLogic.mk_tuple (map Free (vnames ~~ map (fn j => nth Ts (j - 1)) pis))] end
     val in_ts = maps mk_input_term (snd mode)
     val params = map2 (fn s => fn T => Free (s, T)) param_vs Ts1'
-    val additional_arguments = #additional_arguments compilation_modifiers (all_vs @ param_vs)
+    val additional_arguments = Comp_Mod.additional_arguments compilation_modifiers (all_vs @ param_vs)
     val cl_ts =
       map (compile_clause compilation_modifiers compfuns
         thy all_vs param_vs additional_arguments mode (HOLogic.mk_tuple in_ts)) moded_cls;
-    val compilation = #wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
+    val compilation = Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
       (if null cl_ts then
         mk_bot compfuns (HOLogic.mk_tupleT Us2)
       else foldr1 (mk_sup compfuns) cl_ts)
     val fun_const =
-      Const (#const_name_of compilation_modifiers thy s mode,
-        #funT_of compilation_modifiers compfuns mode T)
+      Const (Comp_Mod.const_name_of compilation_modifiers thy s mode,
+        Comp_Mod.funT_of compilation_modifiers compfuns mode T)
   in
     HOLogic.mk_Trueprop
       (HOLogic.mk_eq (list_comb (fun_const, params @ in_ts @ additional_arguments), compilation))
@@ -2114,30 +2136,47 @@
 
 (** main function of predicate compiler **)
 
+datatype steps = Steps of
+  {
+  compile_preds : theory -> string list -> string list -> (string * typ) list
+    -> (moded_clause list) pred_mode_table -> term pred_mode_table,
+  create_definitions: (string * typ) list -> string * mode list -> theory -> theory,
+  infer_modes : options -> theory -> (string * mode list) list -> (string * mode list) list
+    -> string list -> (string * (term list * indprem list) list) list
+    -> moded_clause list pred_mode_table,
+  prove : options -> theory -> (string * (term list * indprem list) list) list
+    -> (string * typ) list -> (string * mode list) list
+    -> moded_clause list pred_mode_table -> term pred_mode_table -> thm pred_mode_table,
+  are_not_defined : theory -> string list -> bool,
+  qname : bstring
+  }
+
+
 fun add_equations_of steps options prednames thy =
   let
+    fun dest_steps (Steps s) = s
     val _ = print_step options ("Starting predicate compiler for predicates " ^ commas prednames ^ "...")
       (*val _ = check_intros_elim_match thy prednames*)
       (*val _ = map (check_format_of_intro_rule thy) (maps (intros_of thy) prednames)*)
     val (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) =
       prepare_intrs thy prednames (maps (intros_of thy) prednames)
     val _ = print_step options "Infering modes..."
-    val moded_clauses = #infer_modes steps options thy extra_modes all_modes param_vs clauses 
+    val moded_clauses = #infer_modes (dest_steps steps) options thy extra_modes all_modes param_vs clauses 
     val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
     val _ = check_expected_modes options modes
     val _ = print_modes options modes
       (*val _ = print_moded_clauses thy moded_clauses*)
     val _ = print_step options "Defining executable functions..."
-    val thy' = fold (#create_definitions steps preds) modes thy
+    val thy' = fold (#create_definitions (dest_steps steps) preds) modes thy
       |> Theory.checkpoint
     val _ = print_step options "Compiling equations..."
     val compiled_terms =
-      (#compile_preds steps) thy' all_vs param_vs preds moded_clauses
+      #compile_preds (dest_steps steps) thy' all_vs param_vs preds moded_clauses
     val _ = print_compiled_terms options thy' compiled_terms
     val _ = print_step options "Proving equations..."
-    val result_thms = #prove steps options thy' clauses preds (extra_modes @ modes)
+    val result_thms = #prove (dest_steps steps) options thy' clauses preds (extra_modes @ modes)
       moded_clauses compiled_terms
-    val qname = #qname steps
+    val qname = #qname (dest_steps steps)
     val attrib = fn thy => Attrib.attribute_i thy (Attrib.internal (K (Thm.declaration_attribute
       (fn thm => Context.mapping (Code.add_eqn thm) I))))
     val thy'' = fold (fn (name, result_thms) => fn thy => snd (PureThy.add_thmss
@@ -2164,6 +2203,7 @@
   
 fun gen_add_equations steps options names thy =
   let
+    fun dest_steps (Steps s) = s
     val thy' = PredData.map (fold (extend (fetch_pred_data thy) (depending_preds_of thy)) names) thy
       |> Theory.checkpoint;
     fun strong_conn_of gr keys =
@@ -2171,24 +2211,25 @@
     val scc = strong_conn_of (PredData.get thy') names
     val thy'' = fold_rev
       (fn preds => fn thy =>
-        if #are_not_defined steps thy preds then
+        if #are_not_defined (dest_steps steps) thy preds then
           add_equations_of steps options preds thy else thy)
       scc thy' |> Theory.checkpoint
   in thy'' end
 
 (* different instantiantions of the predicate compiler *)
 
-val predicate_comp_modifiers =
-  {const_name_of = predfun_name_of,
-  funT_of = funT_of,
+val predicate_comp_modifiers = Comp_Mod.Comp_Modifiers
+  {const_name_of = predfun_name_of : (theory -> string -> mode -> string),
+  funT_of = funT_of : (compilation_funs -> mode -> typ -> typ),
   additional_arguments = K [],
-  wrap_compilation = K (K (K (K (K I)))),
-  transform_additional_arguments = K I
+  wrap_compilation = K (K (K (K (K I))))
+   : (compilation_funs -> string -> typ -> mode -> term list -> term -> term),
+  transform_additional_arguments = K I : (indprem -> term list -> term list)
   }
 
-val depth_limited_comp_modifiers =
+val depth_limited_comp_modifiers = Comp_Mod.Comp_Modifiers
   {const_name_of = depth_limited_function_name_of,
-  funT_of = depth_limited_funT_of,
+  funT_of = depth_limited_funT_of : (compilation_funs -> mode -> typ -> typ),
   additional_arguments = fn names =>
     let
       val [depth_name, polarity_name] = Name.variant_list names ["depth", "polarity"]
@@ -2219,38 +2260,38 @@
     in [polarity', depth'] end
   }
 
-val rpred_comp_modifiers =
+val rpred_comp_modifiers = Comp_Mod.Comp_Modifiers
   {const_name_of = generator_name_of,
-  funT_of = K generator_funT_of,
+  funT_of = K generator_funT_of : (compilation_funs -> mode -> typ -> typ),
   additional_arguments = fn names => [Free (Name.variant names "size", @{typ code_numeral})],
-  wrap_compilation = K (K (K (K (K I)))),
-  transform_additional_arguments = K I
+  wrap_compilation = K (K (K (K (K I))))
+    : (compilation_funs -> string -> typ -> mode -> term list -> term -> term),
+  transform_additional_arguments = K I : (indprem -> term list -> term list)
   }
 
-
 val add_equations = gen_add_equations
-  {infer_modes = infer_modes,
+  (Steps {infer_modes = infer_modes,
   create_definitions = create_definitions,
   compile_preds = compile_preds predicate_comp_modifiers PredicateCompFuns.compfuns,
   prove = prove,
   are_not_defined = fn thy => forall (null o modes_of thy),
-  qname = "equation"}
+  qname = "equation"})
 
 val add_depth_limited_equations = gen_add_equations
-  {infer_modes = infer_modes,
+  (Steps {infer_modes = infer_modes,
   create_definitions = create_definitions_of_depth_limited_functions,
   compile_preds = compile_preds depth_limited_comp_modifiers PredicateCompFuns.compfuns,
   prove = prove_by_skip,
   are_not_defined = fn thy => forall (null o depth_limited_modes_of thy),
-  qname = "depth_limited_equation"}
+  qname = "depth_limited_equation"})
 
 val add_quickcheck_equations = gen_add_equations
-  {infer_modes = infer_modes_with_generator,
+  (Steps {infer_modes = infer_modes_with_generator,
   create_definitions = rpred_create_definitions,
   compile_preds = compile_preds rpred_comp_modifiers RandomPredCompFuns.compfuns,
   prove = prove_by_skip,
   are_not_defined = fn thy => forall (null o rpred_modes_of thy),
-  qname = "rpred_equation"}
+  qname = "rpred_equation"})
 
 (** user interface **)