more robust registration of code equations
authorkuncar
Mon, 24 Feb 2014 18:12:41 +0100
changeset 55724 7572fc374f80
parent 55723 f66371633e13
child 55729 3244957ca236
more robust registration of code equations
src/HOL/Tools/Lifting/lifting_def.ML
--- a/src/HOL/Tools/Lifting/lifting_def.ML	Mon Feb 24 18:12:40 2014 +0100
+++ b/src/HOL/Tools/Lifting/lifting_def.ML	Mon Feb 24 18:12:41 2014 +0100
@@ -324,68 +324,74 @@
     simplify_code_eq ctxt abs_eq
   end
 
-fun define_code_using_abs_eq abs_eq_thm lthy =
-  if null (Logic.strip_imp_prems(prop_of abs_eq_thm)) then
-    (snd oo Local_Theory.note) ((Binding.empty, [Code.add_default_eqn_attrib]), [abs_eq_thm]) lthy
-  else
-    lthy
-  
-fun define_code_using_rep_eq opt_rep_eq_thm lthy = 
-  case opt_rep_eq_thm of
-    SOME rep_eq_thm =>   
-      let
-        val add_abs_eqn_attribute = 
-          Thm.declaration_attribute (fn thm => Context.mapping (Code.add_abs_eqn thm) I)
-        val add_abs_eqn_attrib = Attrib.internal (K add_abs_eqn_attribute);
-      in
-        (snd oo Local_Theory.note) ((Binding.empty, [add_abs_eqn_attrib]), [rep_eq_thm]) lthy
-      end
-    | NONE => lthy
 
-fun has_constr ctxt quot_thm =
+fun register_code_eq_thy abs_eq_thm opt_rep_eq_thm (rty, qty) thy =
   let
-    val thy = Proof_Context.theory_of ctxt
-    val abs_fun = quot_thm_abs quot_thm
-  in
-    if is_Const abs_fun then
-      Code.is_constr thy ((fst o dest_Const) abs_fun)
-    else
-      false
-  end
+    fun no_abstr (t $ u) = no_abstr t andalso no_abstr u
+      | no_abstr (Abs (_, _, t)) = no_abstr t
+      | no_abstr (Const (name, _)) = not (Code.is_abstr thy name)
+      | no_abstr _ = true
+    fun is_valid_eq eqn = can (Code.assert_eqn thy) (mk_meta_eq eqn, true) 
+      andalso no_abstr (prop_of eqn)
+    fun is_valid_abs_eq abs_eq = can (Code.assert_abs_eqn thy NONE) (mk_meta_eq abs_eq)
 
-fun has_abstr ctxt quot_thm =
-  let
-    val thy = Proof_Context.theory_of ctxt
-    val abs_fun = quot_thm_abs quot_thm
   in
-    if is_Const abs_fun then
-      Code.is_abstr thy ((fst o dest_Const) abs_fun)
+    if is_valid_eq abs_eq_thm then
+      Code.add_default_eqn abs_eq_thm thy
     else
-      false
-  end
-
-fun define_code abs_eq_thm opt_rep_eq_thm (rty, qty) lthy =
-  let
-    val (rty_body, qty_body) = get_body_types (rty, qty)
-  in
-    if rty_body = qty_body then
-      if null (Logic.strip_imp_prems(prop_of abs_eq_thm)) then
-        (snd oo Local_Theory.note) ((Binding.empty, [Code.add_default_eqn_attrib]), [abs_eq_thm]) lthy
-      else
-        (snd oo Local_Theory.note) ((Binding.empty, [Code.add_default_eqn_attrib]), [the opt_rep_eq_thm]) lthy
-    else
-      let 
-        val body_quot_thm = Lifting_Term.prove_quot_thm lthy (rty_body, qty_body)
+      let
+        val (rty_body, qty_body) = get_body_types (rty, qty)
       in
-        if has_constr lthy body_quot_thm then
-          define_code_using_abs_eq abs_eq_thm lthy
-        else if has_abstr lthy body_quot_thm then
-          define_code_using_rep_eq opt_rep_eq_thm lthy
+        if rty_body = qty_body then
+         Code.add_default_eqn (the opt_rep_eq_thm) thy
         else
-          lthy
+          if is_some opt_rep_eq_thm andalso is_valid_abs_eq (the opt_rep_eq_thm)
+          then
+            Code.add_abs_eqn (the opt_rep_eq_thm) thy
+          else
+            thy
       end
   end
 
+local
+  fun encode_code_eq thy abs_eq opt_rep_eq (rty, qty) = 
+    let
+      fun mk_type typ = typ |> Logic.mk_type |> cterm_of thy |> Drule.mk_term
+    in
+      Conjunction.intr_balanced [abs_eq, (the_default TrueI opt_rep_eq), mk_type rty, mk_type qty]
+    end
+  
+  fun decode_code_eq thm =
+    let
+      val [abs_eq, rep_eq, rty, qty] = Conjunction.elim_balanced 4 thm
+      val opt_rep_eq = if Thm.eq_thm_prop (rep_eq, TrueI) then NONE else SOME rep_eq
+      fun dest_type typ = typ |> Drule.dest_term |> term_of |> Logic.dest_type
+    in
+      (abs_eq, opt_rep_eq, (dest_type rty, dest_type qty)) 
+    end
+  
+  fun register_encoded_code_eq thm thy =
+    let
+      val (abs_eq_thm, opt_rep_eq_thm, (rty, qty)) = decode_code_eq thm
+    in
+      register_code_eq_thy abs_eq_thm opt_rep_eq_thm (rty, qty) thy
+    end
+  
+  val register_code_eq_attribute = Thm.declaration_attribute
+    (fn thm => Context.mapping (register_encoded_code_eq thm) I)
+  val register_code_eq_attrib = Attrib.internal (K register_code_eq_attribute)
+in
+
+fun register_code_eq abs_eq_thm opt_rep_eq_thm (rty, qty) lthy =
+  let
+    val thy = Proof_Context.theory_of lthy
+    val encoded_code_eq = encode_code_eq thy abs_eq_thm opt_rep_eq_thm (rty, qty)
+  in
+    (snd oo Local_Theory.note) ((Binding.empty, [register_code_eq_attrib]), 
+      [encoded_code_eq]) lthy
+  end
+end
+            
 (*
   Defines an operation on an abstract type in terms of a corresponding operation 
     on a representation type.
@@ -434,7 +440,7 @@
       |> (case opt_rep_eq_thm of 
             SOME rep_eq_thm => (snd oo Local_Theory.note) ((rep_eq_thm_name, []), [rep_eq_thm])
             | NONE => I)
-      |> define_code abs_eq_thm opt_rep_eq_thm (rty_forced, qty)
+      |> register_code_eq abs_eq_thm opt_rep_eq_thm (rty_forced, qty)
   end
 
 fun mk_readable_rsp_thm_eq tm lthy =