--- a/src/Pure/Isar/code.ML Fri Feb 19 11:06:21 2010 +0100
+++ b/src/Pure/Isar/code.ML Fri Feb 19 11:06:21 2010 +0100
@@ -22,6 +22,8 @@
(*constructor sets*)
val constrset_of_consts: theory -> (string * typ) list
-> string * ((string * sort) list * (string * typ list) list)
+ val abstype_cert: theory -> string * typ -> string
+ -> string * ((string * sort) list * ((string * typ) * (string * term)))
(*code equations and certificates*)
val mk_eqn: theory -> thm * bool -> thm * bool
@@ -34,9 +36,10 @@
val empty_cert: theory -> string -> cert
val cert_of_eqns: theory -> string -> (thm * bool) list -> cert
val constrain_cert: theory -> sort list -> cert -> cert
- val typscheme_cert: theory -> cert -> (string * sort) list * typ
- val equations_cert: theory -> cert -> ((string * sort) list * typ) * (term list * term) list
- val equations_thms_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
+ val typargs_deps_of_cert: theory -> cert -> (string * sort) list * (string * typ list) list
+ val equations_of_cert: theory -> cert ->
+ ((string * sort) list * typ) * ((string option * (term list * term)) * (thm option * bool)) list
+ val bare_thms_of_cert: theory -> cert -> thm list
val pretty_cert: theory -> cert -> Pretty.T list
(*executable code*)
@@ -46,6 +49,8 @@
val add_signature_cmd: string * string -> theory -> theory
val add_datatype: (string * typ) list -> theory -> theory
val add_datatype_cmd: string list -> theory -> theory
+ val add_abstype: string * typ -> string * typ -> theory -> Proof.state
+ val add_abstype_cmd: string -> string -> theory -> Proof.state
val type_interpretation:
(string * ((string * sort) list * (string * typ list) list)
-> theory -> theory) -> theory -> theory
@@ -59,7 +64,9 @@
val add_case: thm -> theory -> theory
val add_undefined: string -> theory -> theory
val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
- val get_datatype_of_constr: theory -> string -> string option
+ val get_datatype_of_constr_or_abstr: theory -> string -> (string * bool) option
+ val is_constr: theory -> string -> bool
+ val is_abstr: theory -> string -> bool
val get_cert: theory -> ((thm * bool) list -> (thm * bool) list) -> string -> cert
val get_case_scheme: theory -> string -> (int * (int * string list)) option
val undefineds: theory -> string list
@@ -122,34 +129,31 @@
fun read_const thy = AxClass.unoverload_const thy o read_bare_const thy;
-
(** data store **)
-(* code equations *)
+(* datatypes *)
+
+datatype typ_spec = Constructors of (string * typ list) list
+ | Abstractor of (string * typ) * (string * thm);
-type eqns = bool * (thm * bool) list;
- (*default flag, theorems with proper flag *)
+fun constructors_of (Constructors cos) = (cos, false)
+ | constructors_of (Abstractor ((co, ty), _)) = ([(co, [ty])], true);
+
+
+(* functions *)
-fun add_drop_redundant thy (thm, proper) thms =
- let
- val args_of = snd o strip_comb o map_types Type.strip_sorts
- o fst o Logic.dest_equals o Thm.plain_prop_of;
- val args = args_of thm;
- val incr_idx = Logic.incr_indexes ([], Thm.maxidx_of thm + 1);
- fun matches_args args' = length args <= length args' andalso
- Pattern.matchess thy (args, (map incr_idx o take (length args)) args');
- fun drop (thm', proper') = if (proper orelse not proper')
- andalso matches_args (args_of thm') then
- (warning ("Code generator: dropping redundant code equation\n" ^
- Display.string_of_thm_global thy thm'); true)
- else false;
- in (thm, proper) :: filter_out drop thms end;
+datatype fun_spec = Default of (thm * bool) list
+ | Eqns of (thm * bool) list
+ | Proj of term * string
+ | Abstr of thm * string;
-fun add_thm thy _ thm (false, thms) = (false, add_drop_redundant thy thm thms)
- | add_thm thy true thm (true, thms) = (true, thms @ [thm])
- | add_thm thy false thm (true, thms) = (false, [thm]);
+val empty_fun_spec = Default [];
-fun del_thm thm = apsnd (remove (eq_fst Thm.eq_thm_prop) (thm, true));
+fun is_default (Default _) = true
+ | is_default _ = false;
+
+fun associated_abstype (Abstr (_, tyco)) = SOME tyco
+ | associated_abstype _ = NONE;
(* executable code data *)
@@ -157,49 +161,49 @@
datatype spec = Spec of {
history_concluded: bool,
signatures: int Symtab.table * typ Symtab.table,
- eqns: ((bool * eqns) * (serial * eqns) list) Symtab.table
+ functions: ((bool * fun_spec) * (serial * fun_spec) list) Symtab.table
(*with explicit history*),
- dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table
+ datatypes: ((serial * ((string * sort) list * typ_spec)) list) Symtab.table
(*with explicit history*),
cases: (int * (int * string list)) Symtab.table * unit Symtab.table
};
-fun make_spec (history_concluded, ((signatures, eqns), (dtyps, cases))) =
+fun make_spec (history_concluded, ((signatures, functions), (datatypes, cases))) =
Spec { history_concluded = history_concluded,
- signatures = signatures, eqns = eqns, dtyps = dtyps, cases = cases };
+ signatures = signatures, functions = functions, datatypes = datatypes, cases = cases };
fun map_spec f (Spec { history_concluded = history_concluded, signatures = signatures,
- eqns = eqns, dtyps = dtyps, cases = cases }) =
- make_spec (f (history_concluded, ((signatures, eqns), (dtyps, cases))));
-fun merge_spec (Spec { history_concluded = _, signatures = (tycos1, sigs1), eqns = eqns1,
- dtyps = dtyps1, cases = (cases1, undefs1) },
- Spec { history_concluded = _, signatures = (tycos2, sigs2), eqns = eqns2,
- dtyps = dtyps2, cases = (cases2, undefs2) }) =
+ functions = functions, datatypes = datatypes, cases = cases }) =
+ make_spec (f (history_concluded, ((signatures, functions), (datatypes, cases))));
+fun merge_spec (Spec { history_concluded = _, signatures = (tycos1, sigs1), functions = functions1,
+ datatypes = datatypes1, cases = (cases1, undefs1) },
+ Spec { history_concluded = _, signatures = (tycos2, sigs2), functions = functions2,
+ datatypes = datatypes2, cases = (cases2, undefs2) }) =
let
val signatures = (Symtab.merge (op =) (tycos1, tycos2),
Symtab.merge typ_equiv (sigs1, sigs2));
- fun merge_eqns ((_, history1), (_, history2)) =
+ fun merge_functions ((_, history1), (_, history2)) =
let
val raw_history = AList.merge (op = : serial * serial -> bool)
- (K true) (history1, history2)
- val filtered_history = filter_out (fst o snd) raw_history
+ (K true) (history1, history2);
+ val filtered_history = filter_out (is_default o snd) raw_history;
val history = if null filtered_history
then raw_history else filtered_history;
in ((false, (snd o hd) history), history) end;
- val eqns = Symtab.join (K merge_eqns) (eqns1, eqns2);
- val dtyps = Symtab.join (K (AList.merge (op =) (K true))) (dtyps1, dtyps2);
+ val functions = Symtab.join (K merge_functions) (functions1, functions2);
+ val datatypes = Symtab.join (K (AList.merge (op =) (K true))) (datatypes1, datatypes2);
val cases = (Symtab.merge (K true) (cases1, cases2),
Symtab.merge (K true) (undefs1, undefs2));
- in make_spec (false, ((signatures, eqns), (dtyps, cases))) end;
+ in make_spec (false, ((signatures, functions), (datatypes, cases))) end;
fun history_concluded (Spec { history_concluded, ... }) = history_concluded;
fun the_signatures (Spec { signatures, ... }) = signatures;
-fun the_eqns (Spec { eqns, ... }) = eqns;
-fun the_dtyps (Spec { dtyps, ... }) = dtyps;
+fun the_functions (Spec { functions, ... }) = functions;
+fun the_datatypes (Spec { datatypes, ... }) = datatypes;
fun the_cases (Spec { cases, ... }) = cases;
val map_history_concluded = map_spec o apfst;
val map_signatures = map_spec o apsnd o apfst o apfst;
-val map_eqns = map_spec o apsnd o apfst o apsnd;
-val map_dtyps = map_spec o apsnd o apsnd o apfst;
+val map_functions = map_spec o apsnd o apfst o apsnd;
+val map_typs = map_spec o apsnd o apsnd o apfst;
val map_cases = map_spec o apsnd o apsnd o apsnd;
@@ -251,6 +255,7 @@
in
+
(* access to executable code *)
val the_exec = fst o Code_Data.get;
@@ -259,9 +264,9 @@
val purge_data = (Code_Data.map o apsnd) (fn _ => empty_dataref ());
-fun change_eqns delete c f = (map_exec_purge o map_eqns
- o (if delete then Symtab.map_entry c else Symtab.map_default (c, ((false, (true, [])), [])))
- o apfst) (fn (_, eqns) => (true, f eqns));
+fun change_fun_spec delete c f = (map_exec_purge o map_functions
+ o (if delete then Symtab.map_entry c else Symtab.map_default (c, ((false, empty_fun_spec), [])))
+ o apfst) (fn (_, spec) => (true, f spec));
(* tackling equation history *)
@@ -276,7 +281,7 @@
then NONE
else thy
|> (Code_Data.map o apfst)
- ((map_eqns o Symtab.map) (fn ((changed, current), history) =>
+ ((map_functions o Symtab.map) (fn ((changed, current), history) =>
((false, current),
if changed then (serial (), current) :: history else history))
#> map_history_concluded (K true))
@@ -359,29 +364,32 @@
(* datatypes *)
-fun constrset_of_consts thy cs =
+fun no_constr thy s (c, ty) = error ("Not a datatype constructor:\n" ^ string_of_const thy c
+ ^ " :: " ^ string_of_typ thy ty ^ "\n" ^ enclose "(" ")" s);
+
+fun ty_sorts thy (c, raw_ty) =
let
- val _ = map (fn (c, _) => if (is_some o AxClass.class_of_param thy) c
- then error ("Is a class parameter: " ^ string_of_const thy c) else ()) cs;
- fun no_constr s (c, ty) = error ("Not a datatype constructor:\n" ^ string_of_const thy c
- ^ " :: " ^ string_of_typ thy ty ^ "\n" ^ enclose "(" ")" s);
+ val _ = Thm.cterm_of thy (Const (c, raw_ty));
+ val ty = subst_signature thy c raw_ty;
+ val ty_decl = (Logic.unvarifyT o const_typ thy) c;
fun last_typ c_ty ty =
let
val tfrees = Term.add_tfreesT ty [];
val (tyco, vs) = ((apsnd o map) (dest_TFree) o dest_Type o snd o strip_type) ty
- handle TYPE _ => no_constr "bad type" c_ty
+ handle TYPE _ => no_constr thy "bad type" c_ty
val _ = if has_duplicates (eq_fst (op =)) vs
- then no_constr "duplicate type variables in datatype" c_ty else ();
+ then no_constr thy "duplicate type variables in datatype" c_ty else ();
val _ = if length tfrees <> length vs
- then no_constr "type variables missing in datatype" c_ty else ();
+ then no_constr thy "type variables missing in datatype" c_ty else ();
in (tyco, vs) end;
- fun ty_sorts (c, raw_ty) =
- let
- val ty = subst_signature thy c raw_ty;
- val ty_decl = (Logic.unvarifyT o const_typ thy) c;
- val (tyco, _) = last_typ (c, ty) ty_decl;
- val (_, vs) = last_typ (c, ty) ty;
- in ((tyco, map snd vs), (c, (map fst vs, ty))) end;
+ val (tyco, _) = last_typ (c, ty) ty_decl;
+ val (_, vs) = last_typ (c, ty) ty;
+ in ((tyco, map snd vs), (c, (map fst vs, ty))) end;
+
+fun constrset_of_consts thy cs =
+ let
+ val _ = map (fn (c, _) => if (is_some o AxClass.class_of_param thy) c
+ then error ("Is a class parameter: " ^ string_of_const thy c) else ()) cs;
fun add ((tyco', sorts'), c) ((tyco, sorts), cs) =
let
val _ = if (tyco' : string) <> tyco
@@ -394,31 +402,68 @@
val the_v = the o AList.lookup (op =) (vs ~~ vs');
val ty' = map_atyps (fn TFree (v, _) => TFree (the_v v)) ty;
in (c, (fst o strip_type) ty') end;
- val c' :: cs' = map ty_sorts cs;
+ val c' :: cs' = map (ty_sorts thy) cs;
val ((tyco, sorts), cs'') = fold add cs' (apsnd single c');
val vs = Name.names Name.context Name.aT sorts;
val cs''' = map (inst vs) cs'';
in (tyco, (vs, rev cs''')) end;
-fun get_datatype thy tyco =
- case these (Symtab.lookup ((the_dtyps o the_exec) thy) tyco)
- of (_, spec) :: _ => spec
- | [] => arity_number thy tyco
- |> Name.invents Name.context Name.aT
- |> map (rpair [])
- |> rpair [];
+fun abstype_cert thy abs_ty rep =
+ let
+ val _ = pairself (fn c => if (is_some o AxClass.class_of_param thy) c
+ then error ("Is a class parameter: " ^ string_of_const thy c) else ()) (fst abs_ty, rep);
+ val ((tyco, sorts), (abs, (vs, ty'))) = ty_sorts thy abs_ty;
+ val (ty, ty_abs) = case ty'
+ of Type ("fun", [ty, ty_abs]) => (ty, ty_abs)
+ | _ => error ("Not a datatype abstractor:\n" ^ string_of_const thy abs
+ ^ " :: " ^ string_of_typ thy ty');
+ val _ = Thm.cterm_of thy (Const (rep, ty_abs --> ty)) handle CTERM _ =>
+ error ("Not a projection:\n" ^ string_of_const thy rep);
+ val cert = Logic.mk_equals (Const (abs, ty --> ty_abs) $ (Const (rep, ty_abs --> ty)
+ $ Free ("x", ty_abs)), Free ("x", ty_abs));
+ in (tyco, (vs ~~ sorts, ((fst abs_ty, ty), (rep, cert)))) end;
+
+fun get_datatype_entry thy tyco = case these (Symtab.lookup ((the_datatypes o the_exec) thy) tyco)
+ of (_, entry) :: _ => SOME entry
+ | _ => NONE;
-fun get_datatype_of_constr thy c =
+fun get_datatype_spec thy tyco = case get_datatype_entry thy tyco
+ of SOME (vs, spec) => apfst (pair vs) (constructors_of spec)
+ | NONE => arity_number thy tyco
+ |> Name.invents Name.context Name.aT
+ |> map (rpair [])
+ |> rpair []
+ |> rpair false;
+
+fun get_abstype_spec thy tyco = case get_datatype_entry thy tyco
+ of SOME (vs, Abstractor spec) => (vs, spec)
+ | NONE => error ("Not an abstract type: " ^ tyco);
+
+fun get_datatype thy = fst o get_datatype_spec thy;
+
+fun get_datatype_of_constr_or_abstr thy c =
case (snd o strip_type o const_typ thy) c
- of Type (tyco, _) => if member (op =) ((map fst o snd o get_datatype thy) tyco) c
- then SOME tyco else NONE
+ of Type (tyco, _) => let val ((vs, cos), abstract) = get_datatype_spec thy tyco
+ in if member (op =) (map fst cos) c then SOME (tyco, abstract) else NONE end
| _ => NONE;
-fun is_constr thy = is_some o get_datatype_of_constr thy;
+fun is_constr thy c = case get_datatype_of_constr_or_abstr thy c
+ of SOME (_, false) => true
+ | _ => false;
+
+fun is_abstr thy c = case get_datatype_of_constr_or_abstr thy c
+ of SOME (_, true) => true
+ | _ => false;
(* bare code equations *)
+(* convention for variables:
+ ?x ?'a for free-floating theorems (e.g. in the data store)
+ ?x 'a for certificates
+ x 'a for final representation of equations
+*)
+
exception BAD_THM of string;
fun bad_thm msg = raise BAD_THM msg;
fun error_thm f thm = f thm handle BAD_THM msg => error msg;
@@ -430,12 +475,9 @@
in not (has_duplicates (op =) ((fold o fold_aterms)
(fn Var (v, _) => cons v | _ => I) args [])) end;
-fun gen_assert_eqn thy check_patterns (thm, proper) =
+fun check_eqn thy { allow_nonlinear, allow_consts, allow_pats } thm (lhs, rhs) =
let
fun bad s = bad_thm (s ^ ":\n" ^ Display.string_of_thm_global thy thm);
- val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm
- handle TERM _ => bad "Not an equation"
- | THM _ => bad "Not an equation";
fun vars_of t = fold_aterms (fn Var (v, _) => insert (op =) v
| Free _ => bad "Illegal free variable in equation"
| _ => I) t [];
@@ -461,21 +503,23 @@
| check _ (Var _) = bad "Variable with application on left hand side of equation"
| check n (t1 $ t2) = (check (n+1) t1; check 0 t2)
| check n (Const (c_ty as (c, ty))) =
- let
+ if allow_pats then let
val c' = AxClass.unoverload_const thy c_ty
in if n = (length o fst o strip_type o subst_signature thy c') ty
- then if not proper orelse not check_patterns orelse is_constr thy c'
+ then if allow_consts orelse is_constr thy c'
then ()
else bad (quote c ^ " is not a constructor, on left hand side of equation")
else bad ("Partially applied constant " ^ quote c ^ " on left hand side of equation")
- end;
+ end else bad ("Pattern not allowed, but constant " ^ quote c ^ " encountered on left hand side")
val _ = map (check 0) args;
- val _ = if not proper orelse is_linear thm then ()
+ val _ = if allow_nonlinear orelse is_linear thm then ()
else bad "Duplicate variables on left hand side of equation";
val _ = if (is_none o AxClass.class_of_param thy) c then ()
else bad "Overloaded constant as head in equation";
val _ = if not (is_constr thy c) then ()
else bad "Constructor as head in equation";
+ val _ = if not (is_abstr thy c) then ()
+ else bad "Abstractor as head in equation";
val ty_decl = Sign.the_const_type thy c;
val _ = if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
then () else bad_thm ("Type\n" ^ string_of_typ thy ty
@@ -483,8 +527,39 @@
^ Display.string_of_thm_global thy thm
^ "\nis incompatible with declared function type\n"
^ string_of_typ thy ty_decl)
+ in () end;
+
+fun gen_assert_eqn thy check_patterns (thm, proper) =
+ let
+ fun bad s = bad_thm (s ^ ":\n" ^ Display.string_of_thm_global thy thm);
+ val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm
+ handle TERM _ => bad "Not an equation"
+ | THM _ => bad "Not a proper equation";
+ val _ = check_eqn thy { allow_nonlinear = not proper,
+ allow_consts = not (proper andalso check_patterns), allow_pats = true } thm (lhs, rhs);
in (thm, proper) end;
+fun assert_abs_eqn thy some_tyco thm =
+ let
+ fun bad s = bad_thm (s ^ ":\n" ^ Display.string_of_thm_global thy thm);
+ val (full_lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm
+ handle TERM _ => bad "Not an equation"
+ | THM _ => bad "Not a proper equation";
+ val (rep, lhs) = dest_comb full_lhs
+ handle TERM _ => bad "Not an abstract equation";
+ val tyco = (fst o dest_Type o domain_type o snd o dest_Const) rep
+ handle TERM _ => bad "Not an abstract equation";
+ val _ = case some_tyco of SOME tyco' => if tyco = tyco' then ()
+ else bad ("Abstract type mismatch:" ^ quote tyco ^ " vs. " ^ quote tyco')
+ | NONE => ();
+ val (_, (_, (rep', _))) = get_abstype_spec thy tyco;
+ val rep_const = (fst o dest_Const) rep;
+ val _ = if rep_const = rep' then ()
+ else bad ("Projection mismatch: " ^ quote rep_const ^ " vs. " ^ quote rep');
+ val _ = check_eqn thy { allow_nonlinear = false,
+ allow_consts = false, allow_pats = false } thm (lhs, rhs);
+ in (thm, tyco) end;
+
fun assert_eqn thy = error_thm (gen_assert_eqn thy true);
fun meta_rewrite thy = LocalDefs.meta_rewrite_rule (ProofContext.init thy);
@@ -498,6 +573,8 @@
fun mk_eqn_liberal thy = Option.map (fn (thm, _) => (thm, is_linear thm))
o try_thm (gen_assert_eqn thy false) o rpair false o meta_rewrite thy;
+fun mk_abs_eqn thy = error_thm (assert_abs_eqn thy NONE) o meta_rewrite thy;
+
val head_eqn = dest_Const o fst o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
fun const_typ_eqn thy thm =
@@ -509,11 +586,20 @@
fun const_eqn thy = fst o const_typ_eqn thy;
+fun const_abs_eqn thy = AxClass.unoverload_const thy o dest_Const o fst o strip_comb o snd
+ o dest_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
+
fun logical_typscheme thy (c, ty) =
(map dest_TFree (Sign.const_typargs thy (c, ty)), Type.strip_sorts ty);
fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty);
+fun mk_proj tyco vs ty abs rep =
+ let
+ val ty_abs = Type (tyco, map TFree vs);
+ val xarg = Var (("x", 0), ty);
+ in Logic.mk_equals (Const (rep, ty_abs --> ty) $ (Const (abs, ty --> ty_abs) $ xarg), xarg) end;
+
(* technical transformations of code equations *)
@@ -578,7 +664,50 @@
val (_, Const (c, ty)) = (Logic.dest_equals o Thm.term_of) head;
in (typscheme thy (c, ty), head) end;
-abstype cert = Cert of thm * bool list with
+fun typscheme_projection thy =
+ typscheme thy o dest_Const o fst o dest_comb o fst o Logic.dest_equals;
+
+fun typscheme_abs thy =
+ typscheme thy o dest_Const o fst o strip_comb o snd o dest_comb o fst o Logic.dest_equals o Thm.prop_of;
+
+fun constrain_thm thy vs sorts thm =
+ let
+ val mapping = map2 (fn (v, sort) => fn sort' =>
+ (v, Sorts.inter_sort (Sign.classes_of thy) (sort, sort'))) vs sorts;
+ val inst = map2 (fn (v, sort) => fn (_, sort') =>
+ (((v, 0), sort), TFree (v, sort'))) vs mapping;
+ val subst = (map_types o map_atyps)
+ (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) mapping v)));
+ in
+ thm
+ |> Thm.varifyT
+ |> Thm.certify_instantiate (inst, [])
+ |> pair subst
+ end;
+
+fun concretify_abs thy tyco abs_thm =
+ let
+ val (vs, ((c, _), (_, cert))) = get_abstype_spec thy tyco;
+ val lhs = (fst o Logic.dest_equals o Thm.prop_of) abs_thm
+ val ty = fastype_of lhs;
+ val ty_abs = (fastype_of o snd o dest_comb) lhs;
+ val abs = Thm.cterm_of thy (Const (c, ty --> ty_abs));
+ val raw_concrete_thm = Drule.transitive_thm OF [Thm.symmetric cert, Thm.combination (Thm.reflexive abs) abs_thm];
+ in (c, (Thm.varifyT o zero_var_indexes) raw_concrete_thm) end;
+
+fun add_rhss_of_eqn thy t =
+ let
+ val (args, rhs) = (apfst (snd o strip_comb) o Logic.dest_equals o subst_signatures thy) t;
+ fun add_const (Const (c, ty)) = insert (op =) (c, Sign.const_typargs thy (c, ty))
+ | add_const _ = I
+ in fold_aterms add_const t end;
+
+fun dest_eqn thy = apfst (snd o strip_comb) o Logic.dest_equals o subst_signatures thy o Logic.unvarify;
+
+abstype cert = Equations of thm * bool list
+ | Projection of term * string
+ | Abstract of thm * string
+with
fun empty_cert thy c =
let
@@ -590,7 +719,7 @@
|> map (fn v => TFree (v, []));
val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
val chead = build_head thy (c, ty);
- in Cert (Thm.weaken chead Drule.dummy_thm, []) end;
+ in Equations (Thm.weaken chead Drule.dummy_thm, []) end;
fun cert_of_eqns thy c [] = empty_cert thy c
| cert_of_eqns thy c raw_eqns =
@@ -615,65 +744,127 @@
else Conv.rewr_conv head_thm ct;
val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv);
val cert_thm = Conjunction.intr_balanced (map rewrite_head thms);
- in Cert (cert_thm, propers) end;
+ in Equations (cert_thm, propers) end;
-fun constrain_cert thy sorts (Cert (cert_thm, propers)) =
+fun cert_of_proj thy c tyco =
+ let
+ val (vs, ((abs, ty), (rep, cert))) = get_abstype_spec thy tyco;
+ val _ = if c = rep then () else
+ error ("Wrong head of projection,\nexpected constant " ^ string_of_const thy rep);
+ in Projection (mk_proj tyco vs ty abs rep, tyco) end;
+
+fun cert_of_abs thy tyco c raw_abs_thm =
let
- val ((vs, _), head) = get_head thy cert_thm;
- val subst = map2 (fn (v, sort) => fn sort' =>
- (v, Sorts.inter_sort (Sign.classes_of thy) (sort, sort'))) vs sorts;
- val head' = Thm.term_of head
- |> (map_types o map_atyps)
- (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v)))
- |> Thm.cterm_of thy;
- val inst = map2 (fn (v, sort) => fn (_, sort') =>
- (((v, 0), sort), TFree (v, sort'))) vs subst;
- val cert_thm' = cert_thm
- |> Thm.implies_intr head
- |> Thm.varifyT
- |> Thm.certify_instantiate (inst, [])
- |> Thm.elim_implies (Thm.assume head');
- in (Cert (cert_thm', propers)) end;
+ val abs_thm = singleton (canonize_thms thy) raw_abs_thm;
+ val _ = assert_abs_eqn thy (SOME tyco) abs_thm;
+ val _ = if c = const_abs_eqn thy abs_thm then ()
+ else error ("Wrong head of abstract code equation,\nexpected constant "
+ ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy abs_thm);
+ in Abstract (Thm.freezeT abs_thm, tyco) end;
-fun typscheme_cert thy (Cert (cert_thm, _)) =
- fst (get_head thy cert_thm);
+fun constrain_cert thy sorts (Equations (cert_thm, propers)) =
+ let
+ val ((vs, _), head) = get_head thy cert_thm;
+ val (subst, cert_thm') = cert_thm
+ |> Thm.implies_intr head
+ |> constrain_thm thy vs sorts;
+ val head' = Thm.term_of head
+ |> subst
+ |> Thm.cterm_of thy;
+ val cert_thm'' = cert_thm'
+ |> Thm.elim_implies (Thm.assume head');
+ in Equations (cert_thm'', propers) end
+ | constrain_cert thy _ (cert as Projection _) =
+ cert
+ | constrain_cert thy sorts (Abstract (abs_thm, tyco)) =
+ Abstract (snd (constrain_thm thy (fst (typscheme_abs thy abs_thm)) sorts abs_thm), tyco);
+
+fun typscheme_of_cert thy (Equations (cert_thm, _)) =
+ fst (get_head thy cert_thm)
+ | typscheme_of_cert thy (Projection (proj, _)) =
+ typscheme_projection thy proj
+ | typscheme_of_cert thy (Abstract (abs_thm, _)) =
+ typscheme_abs thy abs_thm;
-fun equations_cert thy (cert as Cert (cert_thm, propers)) =
- let
- val tyscm = typscheme_cert thy cert;
- val equations = if null propers then [] else
- Thm.prop_of cert_thm
- |> Logic.dest_conjunction_balanced (length propers)
- |> map Logic.dest_equals
- |> (map o apfst) (snd o strip_comb)
- in (tyscm, equations) end;
+fun typargs_deps_of_cert thy (Equations (cert_thm, propers)) =
+ let
+ val vs = (fst o fst) (get_head thy cert_thm);
+ val equations = if null propers then [] else
+ Thm.prop_of cert_thm
+ |> Logic.dest_conjunction_balanced (length propers);
+ in (vs, fold (add_rhss_of_eqn thy) equations []) end
+ | typargs_deps_of_cert thy (Projection (t, tyco)) =
+ (fst (typscheme_projection thy t), add_rhss_of_eqn thy t [])
+ | typargs_deps_of_cert thy (Abstract (abs_thm, tyco)) =
+ let
+ val vs = fst (typscheme_abs thy abs_thm);
+ val (_, concrete_thm) = concretify_abs thy tyco abs_thm;
+ in (vs, add_rhss_of_eqn thy (Thm.prop_of abs_thm) []) end;
-fun equations_thms_cert thy (cert as Cert (cert_thm, propers)) =
- let
- val (tyscm, equations) = equations_cert thy cert;
- val thms = if null propers then [] else
- cert_thm
- |> LocalDefs.expand [snd (get_head thy cert_thm)]
- |> Thm.varifyT
- |> Conjunction.elim_balanced (length propers)
- in (tyscm, equations ~~ (thms ~~ propers)) end;
+fun equations_of_cert thy (cert as Equations (cert_thm, propers)) =
+ let
+ val tyscm = typscheme_of_cert thy cert;
+ val thms = if null propers then [] else
+ cert_thm
+ |> LocalDefs.expand [snd (get_head thy cert_thm)]
+ |> Thm.varifyT
+ |> Conjunction.elim_balanced (length propers);
+ in (tyscm, map (pair NONE o dest_eqn thy o Thm.prop_of) thms ~~ (map SOME thms ~~ propers)) end
+ | equations_of_cert thy (Projection (t, tyco)) =
+ let
+ val (_, ((abs, _), _)) = get_abstype_spec thy tyco;
+ val tyscm = typscheme_projection thy t;
+ val t' = map_types Logic.varifyT t;
+ in (tyscm, [((SOME abs, dest_eqn thy t'), (NONE, true))]) end
+ | equations_of_cert thy (Abstract (abs_thm, tyco)) =
+ let
+ val tyscm = typscheme_abs thy abs_thm;
+ val (abs, concrete_thm) = concretify_abs thy tyco abs_thm;
+ val _ = fold_aterms (fn Const (c, _) => if c = abs
+ then error ("Abstraction violation in abstract code equation\n" ^ Display.string_of_thm_global thy abs_thm)
+ else I | _ => I) (Thm.prop_of abs_thm);
+ in (tyscm, [((SOME abs, dest_eqn thy (Thm.prop_of concrete_thm)), (SOME (Thm.varifyT abs_thm), true))]) end;
-fun pretty_cert thy = map (Display.pretty_thm_global thy o AxClass.overload thy o fst o snd)
- o snd o equations_thms_cert thy;
+fun pretty_cert thy (cert as Equations _) =
+ (map_filter (Option.map (Display.pretty_thm_global thy o AxClass.overload thy) o fst o snd)
+ o snd o equations_of_cert thy) cert
+ | pretty_cert thy (Projection (t, _)) =
+ [Syntax.pretty_term_global thy (map_types Logic.varifyT t)]
+ | pretty_cert thy (Abstract (abs_thm, tyco)) =
+ [(Display.pretty_thm_global thy o AxClass.overload thy o Thm.varifyT) abs_thm];
+
+fun bare_thms_of_cert thy (cert as Equations _) =
+ (map_filter (fn (_, (some_thm, proper)) => if proper then some_thm else NONE)
+ o snd o equations_of_cert thy) cert
+ | bare_thms_of_cert thy _ = [];
end;
-(* code equation access *)
+(* code certificate access *)
+
+fun retrieve_raw thy c =
+ Symtab.lookup ((the_functions o the_exec) thy) c
+ |> Option.map (snd o fst)
+ |> the_default (Default [])
-fun get_cert thy f c =
- Symtab.lookup ((the_eqns o the_exec) thy) c
- |> Option.map (snd o snd o fst)
- |> these
- |> (map o apfst) (Thm.transfer thy)
- |> f
- |> (map o apfst) (AxClass.unoverload thy)
- |> cert_of_eqns thy c;
+fun get_cert thy f c = case retrieve_raw thy c
+ of Default eqns => eqns
+ |> (map o apfst) (Thm.transfer thy)
+ |> f
+ |> (map o apfst) (AxClass.unoverload thy)
+ |> cert_of_eqns thy c
+ | Eqns eqns => eqns
+ |> (map o apfst) (Thm.transfer thy)
+ |> f
+ |> (map o apfst) (AxClass.unoverload thy)
+ |> cert_of_eqns thy c
+ | Proj (_, tyco) =>
+ cert_of_proj thy c tyco
+ | Abstr (abs_thm, tyco) => abs_thm
+ |> Thm.transfer thy
+ |> AxClass.unoverload thy
+ |> cert_of_abs thy tyco c;
(* cases *)
@@ -729,48 +920,54 @@
let
val ctxt = ProofContext.init thy;
val exec = the_exec thy;
- fun pretty_eqns (s, (_, eqns)) =
+ fun pretty_equations const thms =
(Pretty.block o Pretty.fbreaks) (
- Pretty.str s :: map (Display.pretty_thm ctxt o fst) eqns
+ Pretty.str (string_of_const thy const) :: map (Display.pretty_thm ctxt) thms
);
- fun pretty_dtyp (s, []) =
- Pretty.str s
- | pretty_dtyp (s, cos) =
- (Pretty.block o Pretty.breaks) (
- Pretty.str s
- :: Pretty.str "="
- :: separate (Pretty.str "|") (map (fn (c, []) => Pretty.str (string_of_const thy c)
- | (c, tys) =>
- (Pretty.block o Pretty.breaks)
- (Pretty.str (string_of_const thy c)
- :: Pretty.str "of"
- :: map (Pretty.quote o Syntax.pretty_typ_global thy) tys)) cos)
- );
+ fun pretty_function (const, Default eqns) = pretty_equations const (map fst eqns)
+ | pretty_function (const, Eqns eqns) = pretty_equations const (map fst eqns)
+ | pretty_function (const, Proj (proj, _)) = Pretty.block
+ [Pretty.str (string_of_const thy const), Pretty.fbrk, Syntax.pretty_term ctxt proj]
+ | pretty_function (const, Abstr (thm, _)) = pretty_equations const [thm];
+ fun pretty_typ (tyco, vs) = Pretty.str
+ (string_of_typ thy (Type (tyco, map TFree vs)));
+ fun pretty_typspec (typ, (cos, abstract)) = if null cos
+ then pretty_typ typ
+ else (Pretty.block o Pretty.breaks) (
+ pretty_typ typ
+ :: Pretty.str "="
+ :: (if abstract then [Pretty.str "(abstract)"] else [])
+ @ separate (Pretty.str "|") (map (fn (c, []) => Pretty.str (string_of_const thy c)
+ | (c, tys) =>
+ (Pretty.block o Pretty.breaks)
+ (Pretty.str (string_of_const thy c)
+ :: Pretty.str "of"
+ :: map (Pretty.quote o Syntax.pretty_typ_global thy) tys)) cos)
+ );
fun pretty_case (const, (_, (_, []))) = Pretty.str (string_of_const thy const)
| pretty_case (const, (_, (_, cos))) = (Pretty.block o Pretty.breaks) [
Pretty.str (string_of_const thy const), Pretty.str "with",
(Pretty.block o Pretty.commas o map (Pretty.str o string_of_const thy)) cos];
- val eqns = the_eqns exec
+ val functions = the_functions exec
|> Symtab.dest
- |> (map o apfst) (string_of_const thy)
|> (map o apsnd) (snd o fst)
|> sort (string_ord o pairself fst);
- val dtyps = the_dtyps exec
+ val datatypes = the_datatypes exec
|> Symtab.dest
- |> map (fn (dtco, (_, (vs, cos)) :: _) =>
- (string_of_typ thy (Type (dtco, map TFree vs)), cos))
- |> sort (string_ord o pairself fst);
+ |> map (fn (tyco, (_, (vs, spec)) :: _) =>
+ ((tyco, vs), constructors_of spec))
+ |> sort (string_ord o pairself (fst o fst));
val cases = Symtab.dest ((fst o the_cases o the_exec) thy);
val undefineds = Symtab.keys ((snd o the_cases o the_exec) thy);
in
(Pretty.writeln o Pretty.chunks) [
Pretty.block (
Pretty.str "code equations:" :: Pretty.fbrk
- :: (Pretty.fbreaks o map pretty_eqns) eqns
+ :: (Pretty.fbreaks o map pretty_function) functions
),
Pretty.block (
Pretty.str "datatypes:" :: Pretty.fbrk
- :: (Pretty.fbreaks o map pretty_dtyp) dtyps
+ :: (Pretty.fbreaks o map pretty_typspec) datatypes
),
Pretty.block (
Pretty.str "cases:" :: Pretty.fbrk
@@ -816,11 +1013,27 @@
(* code equations *)
-fun gen_add_eqn default (thm, proper) thy =
+fun gen_add_eqn default (raw_thm, proper) thy =
let
- val thm' = Thm.close_derivation thm;
- val c = const_eqn thy thm';
- in change_eqns false c (add_thm thy default (thm', proper)) thy end;
+ val thm = Thm.close_derivation raw_thm;
+ val c = const_eqn thy thm;
+ fun add_eqn' true (Default eqns) = Default (eqns @ [(thm, proper)])
+ | add_eqn' _ (Eqns eqns) =
+ let
+ val args_of = snd o strip_comb o map_types Type.strip_sorts
+ o fst o Logic.dest_equals o Thm.plain_prop_of;
+ val args = args_of thm;
+ val incr_idx = Logic.incr_indexes ([], Thm.maxidx_of thm + 1);
+ fun matches_args args' = length args <= length args' andalso
+ Pattern.matchess thy (args, (map incr_idx o take (length args)) args');
+ fun drop (thm', proper') = if (proper orelse not proper')
+ andalso matches_args (args_of thm') then
+ (warning ("Code generator: dropping redundant code equation\n" ^
+ Display.string_of_thm_global thy thm'); true)
+ else false;
+ in Eqns ((thm, proper) :: filter_out drop eqns) end
+ | add_eqn' false _ = Eqns [(thm, proper)];
+ in change_fun_spec false c (add_eqn' default) thy end;
fun add_eqn thm thy =
gen_add_eqn false (mk_eqn thy (thm, true)) thy;
@@ -842,11 +1055,22 @@
(fn thm => Context.mapping (add_default_eqn thm) I);
val add_default_eqn_attrib = Attrib.internal (K add_default_eqn_attribute);
+fun add_abs_eqn raw_thm thy =
+ let
+ val (abs_thm, tyco) = (apfst Thm.close_derivation o mk_abs_eqn thy) raw_thm;
+ val c = const_abs_eqn thy abs_thm;
+ in change_fun_spec false c (K (Abstr (abs_thm, tyco))) thy end;
+
fun del_eqn thm thy = case mk_eqn_liberal thy thm
- of SOME (thm, _) => change_eqns true (const_eqn thy thm) (del_thm thm) thy
+ of SOME (thm, _) => let
+ fun del_eqn' (Default eqns) = empty_fun_spec
+ | del_eqn' (Eqns eqns) =
+ Eqns (filter_out (fn (thm', _) => Thm.eq_thm_prop (thm, thm')) eqns)
+ | del_eqn' spec = spec
+ in change_fun_spec true (const_eqn thy thm) del_eqn' thy end
| NONE => thy;
-fun del_eqns c = change_eqns true c (K (false, []));
+fun del_eqns c = change_fun_spec true c (K empty_fun_spec);
(* cases *)
@@ -869,32 +1093,69 @@
structure Type_Interpretation =
Interpretation(type T = string * serial val eq = eq_snd (op =) : T * T -> bool);
-fun add_datatype raw_cs thy =
+fun register_datatype (tyco, vs_spec) thy =
let
- val cs = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) raw_cs;
- val (tyco, vs_cos) = constrset_of_consts thy cs;
- val old_cs = (map fst o snd o get_datatype thy) tyco;
+ val (old_constrs, some_old_proj) =
+ case these (Symtab.lookup ((the_datatypes o the_exec) thy) tyco)
+ of (_, (_, Constructors cos)) :: _ => (map fst cos, NONE)
+ | (_, (_, Abstractor (_, (co, _)))) :: _ => ([], SOME co)
+ | [] => ([], NONE)
+ val outdated_funs = case some_old_proj
+ of NONE => []
+ | SOME old_proj => Symtab.fold
+ (fn (c, ((_, spec), _)) => if member (op =) (the_list (associated_abstype spec)) tyco
+ then insert (op =) c else I)
+ ((the_functions o the_exec) thy) [old_proj];
fun drop_outdated_cases cases = fold Symtab.delete_safe
(Symtab.fold (fn (c, (_, (_, cos))) =>
- if exists (member (op =) old_cs) cos
+ if exists (member (op =) old_constrs) cos
then insert (op =) c else I) cases []) cases;
in
thy
- |> fold (del_eqns o fst) cs
+ |> fold del_eqns outdated_funs
|> map_exec_purge
- ((map_dtyps o Symtab.map_default (tyco, [])) (cons (serial (), vs_cos))
+ ((map_typs o Symtab.map_default (tyco, [])) (cons (serial (), vs_spec))
#> (map_cases o apfst) drop_outdated_cases)
|> Type_Interpretation.data (tyco, serial ())
end;
-fun type_interpretation f = Type_Interpretation.interpretation
+fun type_interpretation f = Type_Interpretation.interpretation
(fn (tyco, _) => fn thy => f (tyco, get_datatype thy tyco) thy);
-fun add_datatype_cmd raw_cs thy =
+fun unoverload_const_typ thy (c, ty) = (AxClass.unoverload_const thy (c, ty), ty);
+
+fun add_datatype proto_constrs thy =
+ let
+ val constrs = map (unoverload_const_typ thy) proto_constrs;
+ val (tyco, (vs, cos)) = constrset_of_consts thy constrs;
+ in
+ thy
+ |> fold (del_eqns o fst) constrs
+ |> register_datatype (tyco, (vs, Constructors cos))
+ end;
+
+fun add_datatype_cmd raw_constrs thy =
+ add_datatype (map (read_bare_const thy) raw_constrs) thy;
+
+fun add_abstype proto_abs proto_rep thy =
let
- val cs = map (read_bare_const thy) raw_cs;
- in add_datatype cs thy end;
+ val (abs, rep) = pairself (unoverload_const_typ thy) (proto_abs, proto_rep);
+ val (tyco, (vs, (abs_ty as (abs, ty), (rep, cert_prop)))) = abstype_cert thy abs (fst rep);
+ fun after_qed [[cert]] = ProofContext.theory
+ (register_datatype (tyco, (vs, Abstractor (abs_ty, (rep, cert))))
+ #> change_fun_spec false rep ((K o Proj)
+ (map_types Logic.varifyT (mk_proj tyco vs ty abs rep), tyco)));
+ in
+ thy
+ |> ProofContext.init
+ |> Proof.theorem_i NONE after_qed [[(cert_prop, [])]]
+ end;
+fun add_abstype_cmd raw_abs raw_rep thy =
+ add_abstype (read_bare_const thy raw_abs) (read_bare_const thy raw_rep) thy;
+
+
+(** infrastructure **)
(* c.f. src/HOL/Tools/recfun_codegen.ML *)
@@ -912,6 +1173,8 @@
let
val attr = the_default ((K o K) I) (Code_Target_Attr.get thy);
in thy |> add_warning_eqn thm |> attr prefix thm end;
+
+
(* setup *)
val _ = Context.>> (Context.map_theory
@@ -920,6 +1183,7 @@
val code_attribute_parser =
Args.del |-- Scan.succeed (mk_attribute del_eqn)
|| Args.$$$ "nbe" |-- Scan.succeed (mk_attribute add_nbe_eqn)
+ || Args.$$$ "abstract" |-- Scan.succeed (mk_attribute add_abs_eqn)
|| (Args.$$$ "target" |-- Args.colon |-- Args.name >>
(mk_attribute o code_target_attr))
|| Scan.succeed (mk_attribute add_warning_eqn);
@@ -932,7 +1196,7 @@
end; (*struct*)
-(** type-safe interfaces for data dependent on executable code **)
+(* type-safe interfaces for data dependent on executable code *)
functor Code_Data(Data: CODE_DATA_ARGS): CODE_DATA =
struct