adding specialisation of predicates to the predicate compiler
authorbulwahn
Mon, 29 Mar 2010 17:30:52 +0200
changeset 36032 dfd30b5b4e73
parent 36031 199fe16cdaab
child 36033 7106f079bd05
adding specialisation of predicates to the predicate compiler
src/HOL/IsaMakefile
src/HOL/Predicate_Compile.thy
src/HOL/Tools/Predicate_Compile/predicate_compile.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML
--- a/src/HOL/IsaMakefile	Mon Mar 29 17:30:50 2010 +0200
+++ b/src/HOL/IsaMakefile	Mon Mar 29 17:30:52 2010 +0200
@@ -301,6 +301,7 @@
   Tools/Predicate_Compile/predicate_compile_data.ML \
   Tools/Predicate_Compile/predicate_compile_fun.ML \
   Tools/Predicate_Compile/predicate_compile.ML \
+  Tools/Predicate_Compile/predicate_compile_specialisation.ML \
   Tools/Predicate_Compile/predicate_compile_pred.ML \
   Tools/quickcheck_generators.ML \
   Tools/Qelim/cooper_data.ML \
--- a/src/HOL/Predicate_Compile.thy	Mon Mar 29 17:30:50 2010 +0200
+++ b/src/HOL/Predicate_Compile.thy	Mon Mar 29 17:30:52 2010 +0200
@@ -12,6 +12,7 @@
   "Tools/Predicate_Compile/predicate_compile_data.ML"
   "Tools/Predicate_Compile/predicate_compile_fun.ML"
   "Tools/Predicate_Compile/predicate_compile_pred.ML"
+  "Tools/Predicate_Compile/predicate_compile_specialisation.ML"
   "Tools/Predicate_Compile/predicate_compile.ML"
 begin
 
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Mon Mar 29 17:30:50 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Mon Mar 29 17:30:52 2010 +0200
@@ -97,7 +97,7 @@
       val _ = print_step options
         ("Compiling functions (" ^ commas (map (Syntax.string_of_term_global thy) funnames) ^
           ") to predicates...")
-      val (fun_pred_specs, thy') =
+      val (fun_pred_specs, thy1) =
         (if function_flattening options andalso (not (null funnames)) then
           if fail_safe_function_flattening options then
             case try (Predicate_Compile_Fun.define_predicates (get_specs funnames)) thy of
@@ -106,24 +106,26 @@
           else Predicate_Compile_Fun.define_predicates (get_specs funnames) thy
         else ([], thy))
         (*||> Theory.checkpoint*)
-      val _ = print_specs options thy' fun_pred_specs
+      val _ = print_specs options thy1 fun_pred_specs
       val specs = (get_specs prednames) @ fun_pred_specs
-      val (intross3, thy''') = process_specification options specs thy'
-      val _ = print_intross options thy''' "Introduction rules with new constants: " intross3
+      val (intross3, thy2) = process_specification options specs thy1
+      val _ = print_intross options thy2 "Introduction rules with new constants: " intross3
       val intross4 = map_specs (maps remove_pointless_clauses) intross3
-      val _ = print_intross options thy''' "After removing pointless clauses: " intross4
-      val intross5 = map_specs (map (remove_equalities thy''')) intross4
-      val _ = print_intross options thy''' "After removing equality premises:" intross5
+      val _ = print_intross options thy2 "After removing pointless clauses: " intross4
+      val intross5 = map_specs (map (remove_equalities thy2)) intross4
+      val _ = print_intross options thy2 "After removing equality premises:" intross5
       val intross6 =
-        map (fn (s, ths) => (overload_const thy''' s, map (AxClass.overload thy''') ths)) intross5
-      val intross7 = map_specs (map (expand_tuples thy''')) intross6
-      val intross8 = map_specs (map (eta_contract_ho_arguments thy''')) intross7
-      val _ = case !intro_hook of NONE => () | SOME f => (map_specs (map (f thy''')) intross8; ())
-      val _ = print_intross options thy''' "introduction rules before registering: " intross8
+        map (fn (s, ths) => (overload_const thy2 s, map (AxClass.overload thy2) ths)) intross5
+      val intross7 = map_specs (map (expand_tuples thy2)) intross6
+      val intross8 = map_specs (map (eta_contract_ho_arguments thy2)) intross7
+      val _ = case !intro_hook of NONE => () | SOME f => (map_specs (map (f thy2)) intross8; ())
+      val _ = print_step options ("Looking for specialisations in " ^ commas (map fst intross8) ^ "...")
+      val (intross9, thy3) = Predicate_Compile_Specialisation.find_specialisations [] intross8 thy2
+      val _ = print_intross options thy3 "introduction rules before registering: " intross9
       val _ = print_step options "Registering introduction rules..."
-      val thy'''' = fold Predicate_Compile_Core.register_intros intross8 thy'''
+      val thy4 = fold Predicate_Compile_Core.register_intros intross9 thy3
     in
-      thy''''
+      thy4
     end;
 
 fun preprocess options t thy =
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 29 17:30:50 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 29 17:30:52 2010 +0200
@@ -295,12 +295,13 @@
       (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
       (Symtab.dest (Datatype.get_all thy)));
     fun check t = (case strip_comb t of
-        (Free _, []) => true
+        (Var _, []) => true
+      | (Free _, []) => true
       | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
             (SOME (i, Tname), Type (Tname', _)) => length ts = i andalso Tname = Tname' andalso forall check ts
           | _ => false)
       | _ => false)
-  in check end;  
+  in check end;
 
 fun is_funtype (Type ("fun", [_, _])) = true
   | is_funtype _ = false;
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML	Mon Mar 29 17:30:52 2010 +0200
@@ -0,0 +1,200 @@
+(*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML
+    Author:     Lukas Bulwahn, TU Muenchen
+
+Deriving specialised predicates and their intro rules
+*)
+
+signature PREDICATE_COMPILE_SPECIALISATION =
+sig
+  val find_specialisations : string list -> (string * thm list) list -> theory -> (string * thm list) list * theory
+end;
+
+structure Predicate_Compile_Specialisation : PREDICATE_COMPILE_SPECIALISATION =
+struct
+
+open Predicate_Compile_Aux;
+
+(* table of specialisations *)
+structure Specialisations = Theory_Data
+(
+  type T = (term * term) Item_Net.T;
+  val empty = Item_Net.init ((op aconv o pairself fst) : (term * term) * (term * term) -> bool)
+    (single o fst);
+  val extend = I;
+  val merge = Item_Net.merge;
+)
+
+fun specialisation_of thy atom =
+  Item_Net.retrieve (Specialisations.get thy) atom
+
+fun print_specialisations thy =
+  tracing (cat_lines (map (fn (t, spec_t) =>
+      Syntax.string_of_term_global thy t ^ " ~~~> " ^ Syntax.string_of_term_global thy spec_t)
+    (Item_Net.content (Specialisations.get thy))))
+
+fun import (pred, intros) args ctxt =
+  let
+    val thy = ProofContext.theory_of ctxt
+    val ((Tinst, intros'), ctxt') = Variable.importT intros ctxt
+    val pred' = fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of (hd intros')))))
+    val Ts = binder_types (fastype_of pred')
+    val argTs = map fastype_of args
+    val Tsubst = Type.raw_matches (argTs, Ts) Vartab.empty
+    val args' = map (Envir.subst_term_types Tsubst) args
+  in
+    (((pred', intros'), args'), ctxt')
+  end
+
+
+
+
+fun specialise_intros black_list (pred, intros) pats thy =
+  let
+    val ctxt = ProofContext.init thy
+    val maxidx = fold (Term.maxidx_term o prop_of) intros ~1
+    val pats = map (Logic.incr_indexes ([],  maxidx + 1)) pats
+    val (((pred, intros), pats), ctxt') = import (pred, intros) pats ctxt
+    val intros_t = map prop_of intros
+    val result_pats = map Var (fold_rev Term.add_vars pats [])
+    fun mk_fresh_name names =
+      let
+        val name =
+          Name.variant names ("specialised_" ^ Long_Name.base_name (fst (dest_Const pred)))
+        val bname = Sign.full_bname thy name
+      in
+        if Sign.declared_const thy bname then
+          mk_fresh_name (name :: names)
+        else
+          bname
+      end
+    val constname = mk_fresh_name []
+    val constT = map fastype_of result_pats ---> @{typ bool}
+    val specialised_const = Const (constname, constT)
+    val specialisation =
+      [(HOLogic.mk_Trueprop (list_comb (pred, pats)),
+        HOLogic.mk_Trueprop (list_comb (specialised_const, result_pats)))]
+    fun specialise_intro intro =
+      (let
+        val (prems, concl) = Logic.strip_horn (prop_of intro)
+        val env = Pattern.unify thy
+          (HOLogic.mk_Trueprop (list_comb (pred, pats)), concl) (Envir.empty 0)
+        val prems = map (Envir.norm_term env) prems
+        val args = map (Envir.norm_term env) result_pats
+        val concl = HOLogic.mk_Trueprop (list_comb (specialised_const, args))
+        val intro = Logic.list_implies (prems, concl)
+      in
+        SOME intro
+      end handle Pattern.Unif => NONE)
+    val specialised_intros_t = map_filter I (map specialise_intro intros)
+    val thy' = Sign.add_consts_i [(Binding.name (Long_Name.base_name constname), constT, NoSyn)] thy
+    val specialised_intros = map (Skip_Proof.make_thm thy') specialised_intros_t
+    val exported_intros = Variable.exportT ctxt' ctxt specialised_intros
+    val [t, specialised_t] = Variable.exportT_terms ctxt' ctxt
+      [list_comb (pred, pats), list_comb (specialised_const, result_pats)]
+    val thy'' = Specialisations.map (Item_Net.update (t, specialised_t)) thy'
+    val ([spec], thy''') = find_specialisations black_list [(constname, exported_intros)] thy''
+    val thy'''' = Predicate_Compile_Core.register_intros spec thy'''
+  in
+    thy''''
+  end
+
+and find_specialisations black_list specs thy =
+  let
+    val add_vars = fold_aterms (fn Var v => cons v | _ => I);
+    fun is_nontrivial_constrt thy t = not (is_Var t) andalso (is_constrt thy t)
+    fun fresh_free T free_names =
+      let
+        val free_name = Name.variant free_names "x"
+      in
+        (Free (free_name, T), free_name :: free_names)
+      end
+    fun replace_term_and_restrict thy T t Tts free_names =
+      let
+        val (free, free_names') = fresh_free T free_names
+        val Tts' = map (apsnd (Pattern.rewrite_term thy [(t, free)] [])) Tts
+        val (ts', free_names'') = restrict_pattern' thy Tts' free_names'
+      in
+        (free :: ts', free_names'')
+      end
+    and restrict_pattern' thy [] free_names = ([], free_names)
+      | restrict_pattern' thy ((T, Free (x, _)) :: Tts) free_names =
+      let
+        val (ts', free_names') = restrict_pattern' thy Tts free_names
+      in
+        (Free (x, T) :: ts', free_names')
+      end
+      | restrict_pattern' thy ((T as TFree _, t) :: Tts) free_names =
+        replace_term_and_restrict thy T t Tts free_names
+      | restrict_pattern' thy ((T as Type (Tcon, Ts), t) :: Tts) free_names =
+        case Datatype_Data.get_constrs thy Tcon of
+          NONE => replace_term_and_restrict thy T t Tts free_names
+        | SOME constrs => (case strip_comb t of
+          (Const (s, _), ats) => (case AList.lookup (op =) constrs s of
+            SOME constr_T =>
+              let
+                val (Ts', T') = strip_type constr_T
+                val Tsubst = Type.raw_match (T', T) Vartab.empty
+                val Ts = map (Envir.subst_type Tsubst) Ts'
+                val (bts', free_names') = restrict_pattern' thy ((Ts ~~ ats) @ Tts) free_names
+                val (ats', ts') = chop (length ats) bts'
+              in
+                (list_comb (Const (s, map fastype_of ats' ---> T), ats') :: ts', free_names')
+              end
+            | NONE => replace_term_and_restrict thy T t Tts free_names))
+    fun restrict_pattern thy Ts args =
+      let
+        val args = map Logic.unvarify_global args
+        val Ts = map Logic.unvarifyT_global Ts
+        val free_names = fold Term.add_free_names args []
+        val (pat, _) = restrict_pattern' thy (Ts ~~ args) free_names
+      in map Logic.varify_global pat end
+    fun detect' atom thy =
+      case strip_comb atom of
+        (pred as Const (pred_name, _), args) =>
+          let
+          val Ts = binder_types (Sign.the_const_type thy pred_name)
+          val vnames = map fst (fold Term.add_var_names args [])
+          val pats = restrict_pattern thy Ts args
+        in
+          if (exists (is_nontrivial_constrt thy) pats)
+            orelse (has_duplicates (op =) (fold add_vars pats [])) then
+            let
+              val thy' =
+                case specialisation_of thy atom of
+                  [] =>
+                    if member (op =) ((map fst specs) @ black_list) pred_name then
+                      thy
+                    else
+                      (case try (Predicate_Compile_Core.intros_of thy) pred_name of
+                        NONE => thy
+                      | SOME intros =>
+                          specialise_intros ((map fst specs) @ (pred_name :: black_list))
+                            (pred, intros) pats thy)
+                  | (t, specialised_t) :: _ => thy
+                val atom' =
+                  case specialisation_of thy' atom of
+                    [] => atom
+                  | (t, specialised_t) :: _ =>
+                    let
+                      val subst = Pattern.match thy' (t, atom) (Vartab.empty, Vartab.empty)
+                    in Envir.subst_term subst specialised_t end handle Pattern.MATCH => atom
+                    (*FIXME: this exception could be caught earlier in specialisation_of *)
+            in
+              (atom', thy')
+            end
+          else (atom, thy)
+        end
+      | _ => (atom, thy)
+    fun specialise' (constname, intros) thy =
+      let
+        (* FIXME: only necessary because of sloppy Logic.unvarify in restrict_pattern *)
+        val intros = Drule.zero_var_indexes_list intros
+        val (intros_t', thy') = (fold_map o fold_map_atoms) detect' (map prop_of intros) thy
+      in
+        ((constname, map (Skip_Proof.make_thm thy') intros_t'), thy')
+      end
+  in
+    fold_map specialise' specs thy
+  end
+
+end;
\ No newline at end of file