return also which code equation was used; tuned
authorkuncar
Fri, 05 Dec 2014 14:14:36 +0100
changeset 60227 eacf75e4da95
parent 60226 ec23f2a97ba4
child 60228 32dd7adba5a4
return also which code equation was used; tuned
src/HOL/Tools/Lifting/lifting_def.ML
--- a/src/HOL/Tools/Lifting/lifting_def.ML	Fri Dec 05 14:14:36 2014 +0100
+++ b/src/HOL/Tools/Lifting/lifting_def.ML	Fri Dec 05 14:14:36 2014 +0100
@@ -6,6 +6,7 @@
 
 signature LIFTING_DEF =
 sig
+  datatype code_eq = NONE_EQ | ABS_EQ | REP_EQ
   type lift_def
   val rty_of_lift_def: lift_def -> typ
   val qty_of_lift_def: lift_def -> typ
@@ -15,6 +16,7 @@
   val rsp_thm_of_lift_def: lift_def -> thm
   val abs_eq_of_lift_def: lift_def -> thm
   val rep_eq_of_lift_def: lift_def -> thm option
+  val code_eq_of_lift_def: lift_def -> code_eq
   val transfer_rules_of_lift_def: lift_def -> thm list
   val morph_lift_def: morphism -> lift_def -> lift_def
   val inst_of_lift_def: Proof.context -> typ -> lift_def -> lift_def
@@ -29,7 +31,19 @@
   val add_lift_def: 
     config -> binding * mixfix -> typ -> term -> thm -> thm list -> local_theory -> 
       lift_def * local_theory
-    
+  
+  val prepare_lift_def:
+    (binding * mixfix -> typ -> term -> thm -> thm list -> Proof.context -> 
+      lift_def * local_theory) -> 
+    binding * mixfix -> typ -> term -> thm list -> local_theory -> 
+    term option * (thm list list -> Proof.context -> lift_def * local_theory)
+
+  val gen_lift_def:
+    (binding * mixfix -> typ -> term -> thm -> thm list -> local_theory -> 
+      lift_def * local_theory) -> 
+    binding * mixfix -> typ -> term -> (Proof.context -> tactic) -> thm list -> 
+    local_theory -> lift_def * local_theory
+
   val lift_def: 
     config -> binding * mixfix -> typ -> term -> (Proof.context -> tactic) -> thm list -> 
     local_theory -> lift_def * local_theory
@@ -48,6 +62,8 @@
 
 infix 0 MRSL
 
+datatype code_eq = NONE_EQ | ABS_EQ | REP_EQ
+
 datatype lift_def = LIFT_DEF of {
   rty: typ,
   qty: typ,
@@ -57,6 +73,7 @@
   rsp_thm: thm,
   abs_eq: thm,
   rep_eq: thm option,
+  code_eq: code_eq,
   transfer_rules: thm list
 };
 
@@ -69,21 +86,22 @@
 val rsp_thm_of_lift_def = #rsp_thm o rep_lift_def;
 val abs_eq_of_lift_def = #abs_eq o rep_lift_def;
 val rep_eq_of_lift_def = #rep_eq o rep_lift_def;
+val code_eq_of_lift_def = #code_eq o rep_lift_def;
 val transfer_rules_of_lift_def = #transfer_rules o rep_lift_def;
 
-fun mk_lift_def rty qty rhs lift_const def_thm rsp_thm abs_eq rep_eq transfer_rules =
+fun mk_lift_def rty qty rhs lift_const def_thm rsp_thm abs_eq rep_eq code_eq transfer_rules =
   LIFT_DEF {rty = rty, qty = qty,
             rhs = rhs, lift_const = lift_const,
-            def_thm = def_thm, rsp_thm = rsp_thm, abs_eq = abs_eq, rep_eq = rep_eq,
-            transfer_rules = transfer_rules };
+            def_thm = def_thm, rsp_thm = rsp_thm, abs_eq = abs_eq, rep_eq = rep_eq, 
+            code_eq = code_eq, transfer_rules = transfer_rules };
 
-fun map_lift_def f1 f2 f3 f4 f5 f6 f7 f8 f9
+fun map_lift_def f1 f2 f3 f4 f5 f6 f7 f8 f9 f10
   (LIFT_DEF {rty = rty, qty = qty, rhs = rhs, lift_const = lift_const,
-  def_thm = def_thm, rsp_thm = rsp_thm, abs_eq = abs_eq, rep_eq = rep_eq, 
+  def_thm = def_thm, rsp_thm = rsp_thm, abs_eq = abs_eq, rep_eq = rep_eq, code_eq = code_eq,
   transfer_rules = transfer_rules }) =
   LIFT_DEF {rty = f1 rty, qty = f2 qty, rhs = f3 rhs, lift_const = f4 lift_const,
             def_thm = f5 def_thm, rsp_thm = f6 rsp_thm, abs_eq = f7 abs_eq, rep_eq = f8 rep_eq,
-            transfer_rules = f9 transfer_rules }
+            code_eq = f9 code_eq, transfer_rules = f10 transfer_rules }
 
 fun morph_lift_def phi =
   let
@@ -91,7 +109,7 @@
     val mterm = Morphism.term phi
     val mthm = Morphism.thm phi
   in
-    map_lift_def mtyp mtyp mterm mterm mthm mthm mthm (Option.map mthm) (map mthm)
+    map_lift_def mtyp mtyp mterm mterm mthm mthm mthm (Option.map mthm) I (map mthm)
   end
 
 fun mk_inst_of_lift_def qty lift_def = Vartab.empty |> Type.raw_match (qty_of_lift_def lift_def, qty)
@@ -438,55 +456,23 @@
 
   in
     if is_valid_eq abs_eq_thm then
-      Code.add_default_eqn abs_eq_thm thy
+      (ABS_EQ, Code.add_default_eqn abs_eq_thm thy)
     else
       let
         val (rty_body, qty_body) = get_body_types (rty, qty)
       in
         if rty_body = qty_body then
-         Code.add_default_eqn (the opt_rep_eq_thm) thy
+          (REP_EQ, Code.add_default_eqn (the opt_rep_eq_thm) thy)
         else
           if is_some opt_rep_eq_thm andalso is_valid_abs_eq (the opt_rep_eq_thm)
           then
-            Code.add_abs_default_eqn (the opt_rep_eq_thm) thy
+            (REP_EQ, Code.add_abs_default_eqn (the opt_rep_eq_thm) thy)
           else
-            thy
+            (NONE_EQ, thy)
       end
   end
 
 local
-  fun encode_code_eq ctxt abs_eq opt_rep_eq (rty, qty) = 
-    let
-      fun mk_type typ = typ |> Logic.mk_type |> Thm.cterm_of ctxt |> Drule.mk_term
-    in
-      Conjunction.intr_balanced [abs_eq, (the_default TrueI opt_rep_eq), mk_type rty, mk_type qty]
-    end
-  
-  exception DECODE
-    
-  fun decode_code_eq thm =
-    if Thm.nprems_of thm > 0 then raise DECODE 
-    else
-      let
-        val [abs_eq, rep_eq, rty, qty] = Conjunction.elim_balanced 4 thm
-        val opt_rep_eq = if Thm.eq_thm_prop (rep_eq, TrueI) then NONE else SOME rep_eq
-        fun dest_type typ = typ |> Drule.dest_term |> Thm.term_of |> Logic.dest_type
-      in
-        (abs_eq, opt_rep_eq, (dest_type rty, dest_type qty)) 
-      end
-  
-  fun register_encoded_code_eq thm thy =
-    let
-      val (abs_eq_thm, opt_rep_eq_thm, (rty, qty)) = decode_code_eq thm
-    in
-      register_code_eq_thy abs_eq_thm opt_rep_eq_thm (rty, qty) thy
-    end
-    handle DECODE => thy
-  
-  val register_code_eq_attribute = Thm.declaration_attribute
-    (fn thm => Context.mapping (register_encoded_code_eq thm) I)
-  val register_code_eq_attrib = Attrib.internal (K register_code_eq_attribute)
-
   fun no_no_code ctxt (rty, qty) =
     if same_type_constrs (rty, qty) then
       forall (no_no_code ctxt) (Targs rty ~~ Targs qty)
@@ -506,12 +492,15 @@
 
 fun register_code_eq abs_eq_thm opt_rep_eq_thm (rty, qty) lthy =
   let
-    val encoded_code_eq = encode_code_eq lthy abs_eq_thm opt_rep_eq_thm (rty, qty)
+    val mthm = Morphism.thm (Local_Theory.target_morphism lthy)
+    val abs_eq_thm =  mthm abs_eq_thm
+    val opt_rep_eq_thm = Option.map mthm opt_rep_eq_thm
   in
     if no_no_code lthy (rty, qty) then 
-      (snd oo Local_Theory.note) ((Binding.empty, [register_code_eq_attrib]), [encoded_code_eq]) lthy
+      Local_Theory.background_theory_result 
+        (register_code_eq_thy abs_eq_thm opt_rep_eq_thm (rty, qty)) lthy
     else
-      lthy
+      (NONE_EQ, lthy)
   end
 end
             
@@ -568,12 +557,12 @@
         else map_filter (fn (_, attrs, thms) => if null attrs then NONE 
           else SOME ((Binding.empty, []), [(thms, attrs)])) notes
       end
+    val (code_eq, lthy) = register_code_eq abs_eq_thm opt_rep_eq_thm (rty_forced, qty) lthy
     val lift_def = mk_lift_def rty_forced qty newrhs lift_const def_thm rsp_thm abs_eq_thm 
-          opt_rep_eq_thm transfer_rules
+          opt_rep_eq_thm code_eq transfer_rules
   in
     lthy
       |> Local_Theory.notes (notes (#notes config)) |> snd
-      |> register_code_eq abs_eq_thm opt_rep_eq_thm (rty_forced, qty)
       |> ` (fn lthy => morph_lift_def (Local_Theory.target_morphism lthy) lift_def)
       ||> Local_Theory.restore
   end
@@ -697,7 +686,7 @@
     Symtab.fold (fn (_, data) => fn l => collect data l) table []
   end
 
-fun prepare_lift_def config var qty rhs par_thms lthy =
+fun prepare_lift_def add_lift_def var qty rhs par_thms lthy =
   let
     val rsp_rel = Lifting_Term.equiv_relation lthy (fastype_of rhs, qty)
     val rty_forced = (domain_type o fastype_of) rsp_rel;
@@ -714,7 +703,7 @@
     val opt_proven_rsp_thm = try_prove_reflexivity lthy prsp_tm
     
     fun after_qed internal_rsp_thm lthy = 
-      add_lift_def config var qty rhs (internal_rsp_thm RS to_rsp) par_thms lthy
+      add_lift_def var qty rhs (internal_rsp_thm RS to_rsp) par_thms lthy
   in
     case opt_proven_rsp_thm of
       SOME thm => (NONE, K (after_qed thm))
@@ -737,9 +726,9 @@
         end 
   end
 
-fun lift_def config var qty rhs tac par_thms lthy =
+fun gen_lift_def add_lift_def var qty rhs tac par_thms lthy =
   let
-    val (goal, after_qed) = prepare_lift_def config var qty rhs par_thms lthy
+    val (goal, after_qed) = prepare_lift_def add_lift_def var qty rhs par_thms lthy
   in
     case goal of
       SOME goal => 
@@ -752,6 +741,9 @@
       | NONE => after_qed [[Drule.dummy_thm]] lthy
   end
 
+fun lift_def config var qty rhs tac par_thms lthy = gen_lift_def (add_lift_def config)
+  var qty rhs tac par_thms lthy
+
 (*
 
   lifting_definition command. It opens a proof of a corresponding respectfulness 
@@ -765,7 +757,7 @@
     val var = (binding, mx)
     val rhs = (Syntax.check_term lthy o Syntax.parse_term lthy) rhs_raw
     val par_thms = Attrib.eval_thms lthy par_xthms
-    val (goal, after_qed) = prepare_lift_def default_config var qty rhs par_thms lthy
+    val (goal, after_qed) = prepare_lift_def (add_lift_def default_config) var qty rhs par_thms lthy
   in
     Proof.theorem NONE (snd oo after_qed) [map (rpair []) (the_list goal)] lthy
   end