simplified and improved compilation of depth-limited search in the predicate compiler
authorbulwahn
Sat, 24 Oct 2009 16:55:42 +0200
changeset 33133 2eb7dfcf3bc3
parent 33132 07efd452a698
child 33134 88c9c3460fe7
simplified and improved compilation of depth-limited search in the predicate compiler
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -101,8 +101,6 @@
   val string_of_moded_prem : theory -> (indprem * tmode) -> string
   val all_modes_of : theory -> (string * mode list) list
   val all_generator_modes_of : theory -> (string * mode list) list
-  val compile_clause : compilation_funs -> term option -> (term list -> term) ->
-    theory -> string list -> string list -> mode -> term -> moded_clause -> term
   val preprocess_intro : theory -> thm -> thm
   val is_constrt : theory -> term -> bool
   val is_predT : typ -> bool
@@ -936,7 +934,7 @@
     val (paramTs, (inargTs, outargTs)) = split_modeT (iss, is) Ts
     val paramTs' = map2 (fn SOME is => sizelim_funT_of PredicateCompFuns.compfuns ([], is) | NONE => I) iss paramTs 
   in
-    (paramTs' @ inargTs @ [@{typ "code_numeral"}]) ---> (mk_predT compfuns (mk_tupleT outargTs))
+    (paramTs' @ inargTs @ [@{typ bool}, @{typ "code_numeral"}]) ---> (mk_predT compfuns (mk_tupleT outargTs))
   end;  
 
 fun mk_sizelim_fun_of compfuns thy (name, T) mode =
@@ -1279,76 +1277,29 @@
   in
     fold_rev lambda vs (f (list_comb (t, vs)))
   end;
-(*
-fun compile_param_ext thy compfuns modes (NONE, t) = t
-  | compile_param_ext thy compfuns modes (m as SOME (Mode ((iss, is'), is, ms)), t) =
-      let
-        val (vs, u) = strip_abs t
-        val (ivs, ovs) = split_mode is vs    
-        val (f, args) = strip_comb u
-        val (params, args') = chop (length ms) args
-        val (inargs, outargs) = split_mode is' args'
-        val b = length vs
-        val perm = map (fn i => (find_index_eq (Bound (b - i)) args') + 1) (1 upto b)
-        val outp_perm =
-          snd (split_mode is perm)
-          |> map (fn i => i - length (filter (fn x => x < i) is'))
-        val names = [] -- TODO
-        val out_names = Name.variant_list names (replicate (length outargs) "x")
-        val f' = case f of
-            Const (name, T) =>
-              if AList.defined op = modes name then
-                mk_predfun_of thy compfuns (name, T) (iss, is')
-              else error "compile param: Not an inductive predicate with correct mode"
-          | Free (name, T) => Free (name, param_funT_of compfuns T (SOME is'))
-        val outTs = dest_tupleT (dest_predT compfuns (body_type (fastype_of f')))
-        val out_vs = map Free (out_names ~~ outTs)
-        val params' = map (compile_param thy modes) (ms ~~ params)
-        val f_app = list_comb (f', params' @ inargs)
-        val single_t = (mk_single compfuns (mk_tuple (map (fn i => nth out_vs (i - 1)) outp_perm)))
-        val match_t = compile_match thy compfuns [] [] out_vs single_t
-      in list_abs (ivs,
-        mk_bind compfuns (f_app, match_t))
-      end
-  | compile_param_ext _ _ _ _ = error "compile params"
-*)
 
-fun compile_param neg_in_sizelim size thy compfuns (NONE, t) = t
-  | compile_param neg_in_sizelim size thy compfuns (m as SOME (Mode ((iss, is'), is, ms)), t) =
+fun compile_param size thy compfuns (NONE, t) = t
+  | compile_param size thy compfuns (m as SOME (Mode ((iss, is'), is, ms)), t) =
    let
      val (f, args) = strip_comb (Envir.eta_contract t)
      val (params, args') = chop (length ms) args
-     val params' = map (compile_param neg_in_sizelim size thy compfuns) (ms ~~ params)
+     val params' = map (compile_param size thy compfuns) (ms ~~ params)
      val mk_fun_of = case size of NONE => mk_fun_of | SOME _ => mk_sizelim_fun_of
      val funT_of = case size of NONE => funT_of | SOME _ => sizelim_funT_of
      val f' =
        case f of
-         Const (name, T) =>
-           mk_fun_of compfuns thy (name, T) (iss, is')
-       | Free (name, T) =>
-         case neg_in_sizelim of
-           SOME _ =>  Free (name, sizelim_funT_of compfuns (iss, is') T)
-         | NONE => Free (name, funT_of compfuns (iss, is') T)
-           
+         Const (name, T) => mk_fun_of compfuns thy (name, T) (iss, is')
+       | Free (name, T) => Free (name, funT_of compfuns (iss, is') T)
        | _ => error ("PredicateCompiler: illegal parameter term")
    in
-     (case neg_in_sizelim of SOME size_t =>
-       (fn t =>
-       let
-         val Ts = fst (split_last (binder_types (fastype_of t)))
-         val names = map (fn i => "x" ^ string_of_int i) (1 upto length Ts)
-       in
-         list_abs (names ~~ Ts, list_comb (t, (map Bound ((length Ts) - 1 downto 0)) @ [size_t]))
-       end)
-     | NONE => I)
-     (list_comb (f', params' @ args'))
+     list_comb (f', params' @ args')
    end
 
-fun compile_expr neg_in_sizelim size thy ((Mode (mode, is, ms)), t) =
+fun compile_expr size thy ((Mode (mode, is, ms)), t) =
   case strip_comb t of
     (Const (name, T), params) =>
        let
-         val params' = map (compile_param neg_in_sizelim size thy PredicateCompFuns.compfuns) (ms ~~ params)
+         val params' = map (compile_param size thy PredicateCompFuns.compfuns) (ms ~~ params)
          val mk_fun_of = case size of NONE => mk_fun_of | SOME _ => mk_sizelim_fun_of
        in
          list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
@@ -1364,7 +1315,7 @@
   case strip_comb t of
     (Const (name, T), params) =>
       let
-        val params' = map (compile_param NONE size thy PredicateCompFuns.compfuns) (ms ~~ params)
+        val params' = map (compile_param size thy PredicateCompFuns.compfuns) (ms ~~ params)
       in
         list_comb (mk_generator_of compfuns thy (name, T) mode, params' @ inargs)
       end
@@ -1421,7 +1372,7 @@
 			end
 		val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
     val (inargs, outargs) = pairself flat (split_list inoutargs)
-    val size_t = case size of NONE => [] | SOME size_t => [size_t]
+    val size_t = case size of NONE => [] | SOME (polarity, size_t) => [polarity, size_t]
 		val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs @ size_t), mk_tuple outargs)
     val t = fold_rev mk_split_lambda args r
   in
@@ -1441,7 +1392,7 @@
       | map_params t = t
     in map_aterms map_params arg end
   
-fun compile_clause compfuns size final_term thy all_vs param_vs (iss, is) inp (ts, moded_ps) =
+fun compile_clause compfuns size thy all_vs param_vs (iss, is) inp (ts, moded_ps) =
   let
     fun check_constrt t (names, eqs) =
       if is_constrt thy t then (t, (names, eqs)) else
@@ -1461,12 +1412,8 @@
             val (out_ts''', (names'', constr_vs)) = fold_map distinct_v
               out_ts'' (names', map (rpair []) vs);
           in
-          (* termify code:
             compile_match thy compfuns constr_vs (eqs @ eqs') out_ts'''
-              (mk_single compfuns (mk_tuple (map mk_valtermify_term out_ts)))
-           *)
-            compile_match thy compfuns constr_vs (eqs @ eqs') out_ts'''
-              (final_term out_ts)
+              (mk_single compfuns (mk_tuple out_ts))
           end
       | compile_prems out_ts vs names ((p, mode as Mode ((_, is), _, _)) :: ps) =
           let
@@ -1482,9 +1429,9 @@
                    val in_ts = map (compile_arg size thy param_vs iss) in_ts
                    val args = case size of
                      NONE => in_ts
-                   | SOME size_t => in_ts @ [size_t]
+                   | SOME (polarity, size_t) => in_ts @ [polarity, size_t]
                    val u = lift_pred compfuns
-                     (list_comb (compile_expr NONE size thy (mode, t), args))
+                     (list_comb (compile_expr size thy (mode, t), args))
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -1492,8 +1439,12 @@
              | Negprem (us, t) =>
                  let
                    val (in_ts, out_ts''') = split_smode is us
-                   val u = lift_pred compfuns
-                     (mk_not PredicateCompFuns.compfuns (list_comb (compile_expr size NONE thy (mode, t), in_ts)))
+                   val size' = Option.map (apfst HOLogic.mk_not) size
+                   val args = case size' of
+                     NONE => in_ts
+                   | SOME (polarity, size_t) => in_ts @ [polarity, size_t]
+                   val u = lift_pred compfuns (mk_not PredicateCompFuns.compfuns
+                     (list_comb (compile_expr size' thy (mode, t), args)))
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -1509,7 +1460,7 @@
                    val (in_ts, out_ts''') = split_smode is us;
                    val args = case size of
                      NONE => in_ts
-                   | SOME size_t => in_ts @ [size_t]
+                     | SOME (polarity, size_t) => in_ts @ [polarity, size_t]
                    val u = compile_gen_expr size thy compfuns (mode, t) args
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
@@ -1517,7 +1468,7 @@
                  end
              | Generator (v, T) =>
                  let
-                   val u = lift_random (HOLogic.mk_random T (the size))
+                 val u = lift_random (HOLogic.mk_random T (snd (the size)))
                    val rest = compile_prems [Free (v, T)]  vs' names'' ps;
                  in
                    (u, rest)
@@ -1537,7 +1488,6 @@
     val (Us1, Us2) = split_smodeT (snd mode) Ts2
     val funT_of = if use_size then sizelim_funT_of else funT_of
     val Ts1' = map2 (fn NONE => I | SOME is => funT_of PredicateCompFuns.compfuns ([], is)) (fst mode) Ts1
-    val size_name = Name.variant (all_vs @ param_vs) "size"
   	fun mk_input_term (i, NONE) =
 		    [Free (Name.variant (all_vs @ param_vs) ("x" ^ string_of_int i), nth Ts2 (i - 1))]
 		  | mk_input_term (i, SOME pis) = case HOLogic.strip_tupleT (nth Ts2 (i - 1)) of
@@ -1551,24 +1501,32 @@
 						   else [HOLogic.mk_tuple (map Free (vnames ~~ map (fn j => nth Ts (j - 1)) pis))] end
 		val in_ts = maps mk_input_term (snd mode)
     val params = map2 (fn s => fn T => Free (s, T)) param_vs Ts1'
-    val size = Free (size_name, @{typ "code_numeral"})
+    val [depth_name, polarity_name] = Name.variant_list (all_vs @ param_vs) ["depth", "polarity"]
+    val size = Free (depth_name, @{typ "code_numeral"})
+    val polarity = Free (polarity_name, @{typ "bool"})
     val decr_size =
       if use_size then
-        SOME (Const ("HOL.minus_class.minus", @{typ "code_numeral => code_numeral => code_numeral"})
+        SOME (polarity, Const ("HOL.minus_class.minus", @{typ "code_numeral => code_numeral => code_numeral"})
           $ size $ Const ("HOL.one_class.one", @{typ "Code_Numeral.code_numeral"}))
       else
         NONE
     val cl_ts =
-      map (compile_clause compfuns decr_size (fn out_ts => mk_single compfuns (mk_tuple out_ts))
+      map (compile_clause compfuns decr_size
         thy all_vs param_vs mode (mk_tuple in_ts)) moded_cls;
     val t = foldr1 (mk_sup compfuns) cl_ts
     val T' = mk_predT compfuns (mk_tupleT Us2)
-    val size_t = Const (@{const_name "If"}, @{typ bool} --> T' --> T' --> T')
-      $ HOLogic.mk_eq (size, @{term "0 :: code_numeral"})
-      $ mk_bot compfuns (dest_predT compfuns T') $ t
+    val if_const = Const (@{const_name "If"}, @{typ bool} --> T' --> T' --> T')
+    val full_mode = null Us2
+    val size_t = 
+      if_const $ HOLogic.mk_eq (size, @{term "0 :: code_numeral"})
+      $ (if full_mode then 
+          if_const $ polarity $ mk_bot compfuns (dest_predT compfuns T') $ mk_single compfuns HOLogic.unit
+        else
+          mk_bot compfuns (dest_predT compfuns T'))
+      $ t
     val fun_const = mk_fun_of compfuns thy (s, T) mode
     val eq = if use_size then
-      (list_comb (fun_const, params @ in_ts @ [size]), size_t)
+      (list_comb (fun_const, params @ in_ts @ [polarity, size]), size_t)
     else
       (list_comb (fun_const, params @ in_ts), t)
   in
@@ -2500,7 +2458,7 @@
       | m :: _ :: _ => (warning ("Multiple modes possible for comprehension "
                 ^ Syntax.string_of_term_global thy t_compr); m);
     val (inargs, outargs) = split_smode user_mode' args;
-    val t_pred = list_comb (compile_expr NONE NONE thy (m, list_comb (pred, params)), inargs);
+    val t_pred = list_comb (compile_expr NONE thy (m, list_comb (pred, params)), inargs);
     val t_eval = if null outargs then t_pred else
       let
         val outargs_bounds = map (fn Bound i => i) outargs;