# HG changeset patch # User haftmann # Date 1204894387 -3600 # Node ID c30bb8182da29fa53e7e12ae9d3149628f6084ea # Parent 4bc6e3ff8b7814543fc44533a8cf85d4279857fe generic improvable syntax for targets diff -r 4bc6e3ff8b78 -r c30bb8182da2 src/Pure/Isar/class.ML --- a/src/Pure/Isar/class.ML Fri Mar 07 13:53:06 2008 +0100 +++ b/src/Pure/Isar/class.ML Fri Mar 07 13:53:07 2008 +0100 @@ -70,15 +70,6 @@ (Method.Basic (K (Method.SIMPLE_METHOD tac), Position.none), NONE) #> ProofContext.theory_of; -fun get_remove_global_constraint c thy = - let - val ty = Sign.the_const_constraint thy c; - in - thy - |> Sign.add_const_constraint (c, NONE) - |> pair (c, Logic.unvarifyT ty) - end; - (** primitive axclass and instance commands **) @@ -345,15 +336,20 @@ fun class_interpretation class facts defs thy = let val params = these_params thy [class]; + val consts = map (fst o snd) params; + val constraints = map (fn c => map_atyps (K (TFree (Name.aT, + [the (AxClass.class_of_param thy c)]))) (Sign.the_const_type thy c)) consts; + val no_constraints = map (map_atyps (K (TFree (Name.aT, [])))) constraints; + fun add_constraint c T = Sign.add_const_constraint (c, SOME T); val inst = (#inst o the_class_data thy) class; val tac = ALLGOALS (ProofContext.fact_tac facts); val prfx = class_prefix class; in thy - |> fold_map (get_remove_global_constraint o fst o snd) params - ||> prove_interpretation tac ((false, prfx), []) (Locale.Locale class) + |> fold2 add_constraint consts no_constraints + |> prove_interpretation tac ((false, prfx), []) (Locale.Locale class) (inst, map (fn def => (("", []), def)) defs) - |-> (fn cs => fold (Sign.add_const_constraint o apsnd SOME) cs) + |> fold2 add_constraint consts constraints end; fun prove_subclass (sub, sup) thm thy = @@ -398,26 +394,7 @@ (* class context syntax *) -structure ClassSyntax = ProofDataFun( - type T = { - local_constraints: (string * typ) list, - global_constraints: (string * typ) list, - base_sort: sort, - operations: (string * (typ * term)) list, - unchecks: (term * term) list, - passed: bool - }; - fun init _ = { - local_constraints = [], - global_constraints = [], - base_sort = [], - operations = [], - unchecks = [], - passed = true - };; -); - -fun synchronize_syntax sups base_sort ctxt = +fun synchronize_class_syntax sups base_sort ctxt = let val thy = ProofContext.theory_of ctxt; fun subst_class_typ sort = map_atyps @@ -430,80 +407,44 @@ fun declare_const (c, _) = let val b = Sign.base_name c in Sign.intern_const thy b = c ? Variable.declare_const (b, c) end; + fun improve (c, ty) = (case AList.lookup (op =) local_constraints c + of SOME ty' => (case try (Type.raw_match (ty', ty)) Vartab.empty + of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0) + of SOME (_, ty' as TVar (tvar as (vi, _))) => + if TypeInfer.is_param vi + then SOME (ty', TFree (Name.aT, base_sort)) + else NONE + | _ => NONE) + | NONE => NONE) + | NONE => NONE) + fun subst (c, ty) = Option.map snd (AList.lookup (op =) operations c); val unchecks = map (fn (c, (_, (ty, t))) => (t, Const (c, ty))) operations; in ctxt |> fold declare_const local_constraints |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints - |> ClassSyntax.put { + |> Overloading.map_improvable_syntax (K { local_constraints = local_constraints, global_constraints = global_constraints, - base_sort = base_sort, - operations = (map o apsnd) snd operations, + improve = improve, + subst = subst, unchecks = unchecks, passed = false - } + }) end; fun refresh_syntax class ctxt = let val thy = ProofContext.theory_of ctxt; val base_sort = (#base_sort o the_class_data thy) class; - in synchronize_syntax [class] base_sort ctxt end; - -val mark_passed = ClassSyntax.map - (fn { local_constraints, global_constraints, base_sort, operations, unchecks, passed } => - { local_constraints = local_constraints, global_constraints = global_constraints, - base_sort = base_sort, operations = operations, unchecks = unchecks, passed = true }); - -fun sort_term_check ts ctxt = - let - val { local_constraints, global_constraints, base_sort, operations, passed, ... } = - ClassSyntax.get ctxt; - fun check_improve (Const (c, ty)) = (case AList.lookup (op =) local_constraints c - of SOME ty0 => (case try (Type.raw_match (ty0, ty)) Vartab.empty - of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0) - of SOME (_, TVar (tvar as (vi, _))) => - if TypeInfer.is_param vi then cons tvar else I - | _ => I) - | NONE => I) - | NONE => I) - | check_improve _ = I; - val improvements = (fold o fold_aterms) check_improve ts []; - val ts' = (map o map_types o map_atyps) (fn ty as TVar tvar => - if member (op =) improvements tvar - then TFree (Name.aT, base_sort) else ty | ty => ty) ts; - fun check t0 = Envir.expand_term (fn Const (c, ty) => (case AList.lookup (op =) operations c - of SOME (ty0, t) => - if Type.typ_instance (ProofContext.tsig_of ctxt) (ty, ty0) - then SOME (ty0, check t) else NONE - | NONE => NONE) - | _ => NONE) t0; - val ts'' = map check ts'; - in if eq_list (op aconv) (ts, ts'') andalso passed then NONE - else - ctxt - |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints - |> mark_passed - |> pair ts'' - |> SOME - end; - -fun sort_term_uncheck ts ctxt = - let - val thy = ProofContext.theory_of ctxt; - val unchecks = (#unchecks o ClassSyntax.get) ctxt; - val ts' = map (Pattern.rewrite_term thy unchecks []) ts; - in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end; + in synchronize_class_syntax [class] base_sort ctxt end; fun init_ctxt sups base_sort ctxt = ctxt |> Variable.declare_term (Logic.mk_type (TFree (Name.aT, base_sort))) - |> synchronize_syntax sups base_sort - |> Context.proof_map ( - Syntax.add_term_check 0 "class" sort_term_check - #> Syntax.add_term_uncheck 0 "class" sort_term_uncheck) + |> synchronize_class_syntax sups base_sort + |> Overloading.add_improvable_syntax; fun init class thy = thy @@ -710,46 +651,52 @@ |> find_first (fn (_, (v', _)) => v = v') |> Option.map (fst o fst); -fun confirm_declaration c = (map_instantiation o apsnd) - (filter_out (fn (_, (c', _)) => c' = c)); - (* syntax *) -fun inst_term_check ts lthy = +fun synchronize_inst_syntax ctxt = let - val params = instantiation_params lthy; - val tsig = ProofContext.tsig_of lthy; - val thy = ProofContext.theory_of lthy; - - fun check_improve (Const (c, ty)) = (case AxClass.inst_tyco_of thy (c, ty) - of SOME tyco => (case AList.lookup (op =) params (c, tyco) - of SOME (_, ty') => perhaps (try (Type.typ_match tsig (ty, ty'))) - | NONE => I) - | NONE => I) - | check_improve _ = I; - val subst_param = map_aterms (fn t as Const (c, ty) => - (case AxClass.inst_tyco_of thy (c, ty) + val Instantiation { arities = (_, _, sorts), params = params } = Instantiation.get ctxt; + val thy = ProofContext.theory_of ctxt; + val operations = these_operations thy sorts; + fun subst_class_typ sort = map_atyps + (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty'); + val local_constraints = + (map o apsnd) (subst_class_typ [] o fst o snd) operations; + val global_constraints = map_filter (fn (c, (class, (ty, _))) => + if exists (fn ((c', _), _) => c = c') params + then SOME (c, subst_class_typ [class] ty) + else NONE) operations; + fun improve (c, ty) = case AxClass.inst_tyco_of thy (c, ty) of SOME tyco => (case AList.lookup (op =) params (c, tyco) - of SOME v_ty => Free v_ty - | NONE => t) - | NONE => t) - | t => t); - - val improvement = (fold o fold_aterms) check_improve ts Vartab.empty; - val ts' = (map o map_types) (Envir.typ_subst_TVars improvement) ts; - val ts'' = map subst_param ts'; - in if eq_list (op aconv) (ts, ts'') then NONE else SOME (ts'', lthy) end; - -fun inst_term_uncheck ts lthy = - let - val params = instantiation_params lthy; - val ts' = (map o map_aterms) (fn t as Free (v, ty) => - (case get_first (fn ((c, _), (v', _)) => if v = v' then SOME c else NONE) params - of SOME c => Const (c, ty) - | NONE => t) - | t => t) ts; - in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end; + of SOME (_, ty') => SOME (ty, ty') + | NONE => NONE) + | NONE => NONE; + (*| NONE => (case map_filter + (fn ((c', _), (_, ty')) => if c' = c then SOME ty' else NONE) params + of [ty'] => (case Sign.const_typargs thy (c, ty) + of [TVar (vi, _)] => if TypeInfer.is_param vi then SOME (ty, ty') else NONE + | _ => NONE) + | _ => NONE*); + fun subst (c, ty) = case AxClass.inst_tyco_of thy (c, ty) + of SOME tyco => (case AList.lookup (op =) params (c, tyco) + of SOME (v_ty as (_, ty)) => SOME (ty, Free v_ty) + | NONE => NONE) + | NONE => NONE; + val unchecks = + map (fn ((c, _), v_ty as (_, ty)) => (Free v_ty, Const (c, ty))) params; + in + ctxt + |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints + |> Overloading.map_improvable_syntax (K { + local_constraints = local_constraints, + global_constraints = global_constraints, + improve = improve, + subst = subst, + unchecks = unchecks, + passed = false + }) + end; (* target *) @@ -786,12 +733,14 @@ |> Instantiation.put (mk_instantiation ((tycos, vs, sort), params)) |> fold (Variable.declare_term o Logic.mk_type o TFree) vs |> fold (Variable.declare_names o Free o snd) params - |> fold (fn tyco => ProofContext.add_arity (tyco, map snd vs, sort)) tycos - |> Context.proof_map ( - Syntax.add_term_check 0 "instance" inst_term_check - #> Syntax.add_term_uncheck 0 "instance" inst_term_uncheck) + |> synchronize_inst_syntax + |> Overloading.add_improvable_syntax end; +fun confirm_declaration c = (map_instantiation o apsnd) + (filter_out (fn (_, (c', _)) => c' = c)) + #> LocalTheory.target synchronize_inst_syntax + fun gen_instantiation_instance do_proof after_qed lthy = let val (tycos, vs, sort) = (#arities o the_instantiation) lthy; diff -r 4bc6e3ff8b78 -r c30bb8182da2 src/Pure/Isar/overloading.ML --- a/src/Pure/Isar/overloading.ML Fri Mar 07 13:53:06 2008 +0100 +++ b/src/Pure/Isar/overloading.ML Fri Mar 07 13:53:07 2008 +0100 @@ -14,6 +14,11 @@ val define: bool -> string -> string * term -> theory -> thm * theory val operation: Proof.context -> string -> (string * bool) option val pretty: Proof.context -> Pretty.T + + type improvable_syntax + val add_improvable_syntax: Proof.context -> Proof.context + val map_improvable_syntax: (improvable_syntax -> improvable_syntax) + -> Proof.context -> Proof.context end; structure Overloading: OVERLOADING = @@ -44,45 +49,101 @@ Thm.add_def (not checked) true (name, Logic.mk_equals (Const (c, Term.fastype_of t), t)); -(* syntax *) +(* generic check/uncheck combinators for improvable constants *) + +type improvable_syntax = { + local_constraints: (string * typ) list, + global_constraints: (string * typ) list, + improve: string * typ -> (typ * typ) option, + subst: string * typ -> (typ * term) option, + unchecks: (term * term) list, + passed: bool +}; -fun subst_operation overloading = map_aterms (fn t as Const (c, ty) => - (case AList.lookup (op =) overloading (c, ty) - of SOME (v, _) => Free (v, ty) - | NONE => t) - | t => t); +structure ImprovableSyntax = ProofDataFun( + type T = improvable_syntax; + fun init _ = { + local_constraints = [], + global_constraints = [], + improve = K NONE, + subst = K NONE, + unchecks = [], + passed = true + }; +); -fun term_check ts lthy = +val map_improvable_syntax = ImprovableSyntax.map; + +val mark_passed = map_improvable_syntax + (fn { local_constraints, global_constraints, improve, subst, unchecks, passed } => + { local_constraints = local_constraints, global_constraints = global_constraints, + improve = improve, subst = subst, unchecks = unchecks, passed = true }); + +fun improve_term_check ts ctxt = let - val overloading = get_overloading lthy; - val ts' = map (subst_operation overloading) ts; - in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end; + val { local_constraints, global_constraints, improve, subst, passed, ... } = + ImprovableSyntax.get ctxt; + val tsig = (Sign.tsig_of o ProofContext.theory_of) ctxt; + fun accumulate_improvements (Const (c, ty)) = (case improve (c, ty) + of SOME ty_ty' => (perhaps o try o Type.typ_match tsig) ty_ty' + | _ => I) + | accumulate_improvements _ = I; + val improvements = (fold o fold_aterms) accumulate_improvements ts Vartab.empty; + val ts' = (map o map_types) (Envir.typ_subst_TVars improvements) ts; + fun apply_subst t = Envir.expand_term (fn Const (c, ty) => (case subst (c, ty) + of SOME (ty', t') => + if Type.typ_instance tsig (ty, ty') + then SOME (ty', apply_subst t') else NONE + | NONE => NONE) + | _ => NONE) t; + val ts'' = map apply_subst ts'; + in if eq_list (op aconv) (ts, ts'') andalso passed then NONE else + if passed then SOME (ts'', ctxt) + else SOME (ts'', ctxt + |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints + |> mark_passed) + end; -fun term_uncheck ts lthy = +fun improve_term_uncheck ts ctxt = let - val overloading = get_overloading lthy; - fun subst (t as Free (v, ty)) = (case get_first (fn ((c, _), (v', _)) => - if v = v' then SOME c else NONE) overloading - of SOME c => Const (c, ty) - | NONE => t) - | subst t = t; - val ts' = (map o map_aterms) subst ts; - in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end; + val thy = ProofContext.theory_of ctxt; + val unchecks = (#unchecks o ImprovableSyntax.get) ctxt; + val ts' = map (Pattern.rewrite_term thy unchecks []) ts; + in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end; + +fun add_improvable_syntax ctxt = ctxt + |> Context.proof_map + (Syntax.add_term_check 0 "improvement" improve_term_check + #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck) + |> fold (ProofContext.add_const_constraint o apsnd SOME) + ((#local_constraints o ImprovableSyntax.get) ctxt); (* target *) -fun init overloading thy = +fun init raw_overloading thy = let - val _ = if null overloading then error "At least one parameter must be given" else (); + val _ = if null raw_overloading then error "At least one parameter must be given" else (); + val overloading = map (fn (v, c_ty, checked) => (c_ty, (v, checked))) raw_overloading; + fun subst (c, ty) = case AList.lookup (op =) overloading (c, ty) + of SOME (v, _) => SOME (ty, Free (v, ty)) + | NONE => NONE; + val unchecks = + map (fn (c_ty as (_, ty), (v, _)) => (Free (v, ty), Const c_ty)) overloading; in thy |> ProofContext.init - |> OverloadingData.put (map (fn (v, c_ty, checked) => (c_ty, (v, checked))) overloading) - |> fold (fn (v, (_, ty), _) => Variable.declare_term (Free (v, ty))) overloading - |> Context.proof_map ( - Syntax.add_term_check 0 "overloading" term_check - #> Syntax.add_term_uncheck 0 "overloading" term_uncheck) + |> OverloadingData.put overloading + |> fold (fn (v, (_, ty), _) => Variable.declare_term (Free (v, ty))) raw_overloading + |> map_improvable_syntax (K { + local_constraints = [], + global_constraints = [], + improve = K NONE, + subst = subst, + unchecks = unchecks, + passed = false + }) + |> add_improvable_syntax end; fun conclude lthy =