src/Pure/variable.ML
author boehmes
Wed, 12 May 2010 23:53:59 +0200
changeset 36895 a96f9793d9c5
parent 36620 e6bb250402b5
child 37145 01aa36932739
permissions -rw-r--r--
split monolithic Z3 proof reconstruction structure into separate structures, use one set of schematic theorems for all uncertain proof rules (to extend proof reconstruction by missing cases), added several schematic theorems, improved abstraction of goals (abstract all uninterpreted sub-terms, only leave builtin symbols)

(*  Title:      Pure/variable.ML
    Author:     Makarius

Fixed type/term variables and polymorphic term abbreviations.
*)

(*Interface of structure Variable: operations on the variable-related data
  (names, fixes, bindings, constraints, ...) stored in a Proof.context.*)
signature VARIABLE =
sig
  (*inner body mode*)
  val is_body: Proof.context -> bool
  val set_body: bool -> Proof.context -> Proof.context
  val restore_body: Proof.context -> Proof.context -> Proof.context
  (*projections of the context data*)
  val names_of: Proof.context -> Name.context
  val fixes_of: Proof.context -> (string * string) list
  val binds_of: Proof.context -> (typ * term) Vartab.table
  val maxidx_of: Proof.context -> int
  val sorts_of: Proof.context -> sort list
  val constraints_of: Proof.context -> typ Vartab.table * sort Vartab.table
  (*declared and fixed names*)
  val is_declared: Proof.context -> string -> bool
  val is_fixed: Proof.context -> string -> bool
  val newly_fixed: Proof.context -> Proof.context -> string -> bool
  val add_fixed: Proof.context -> term -> (string * typ) list -> (string * typ) list
  (*default sorts and types*)
  val default_type: Proof.context -> string -> typ option
  val def_type: Proof.context -> bool -> indexname -> typ option
  val def_sort: Proof.context -> indexname -> sort option
  (*declarations*)
  val declare_names: term -> Proof.context -> Proof.context
  val declare_constraints: term -> Proof.context -> Proof.context
  val declare_term: term -> Proof.context -> Proof.context
  val declare_typ: typ -> Proof.context -> Proof.context
  val declare_prf: Proofterm.proof -> Proof.context -> Proof.context
  val declare_thm: thm -> Proof.context -> Proof.context
  val global_thm_context: thm -> Proof.context
  (*renaming term/type frees*)
  val variant_frees: Proof.context -> term list -> (string * 'a) list -> (string * 'a) list
  (*term bindings*)
  val bind_term: indexname * term option -> Proof.context -> Proof.context
  val expand_binds: Proof.context -> term -> term
  (*consts*)
  val lookup_const: Proof.context -> string -> string option
  val is_const: Proof.context -> string -> bool
  val declare_const: string * string -> Proof.context -> Proof.context
  (*fixes*)
  val add_fixes: string list -> Proof.context -> string list * Proof.context
  val add_fixes_direct: string list -> Proof.context -> Proof.context
  val auto_fixes: term -> Proof.context -> Proof.context
  val variant_fixes: string list -> Proof.context -> string list * Proof.context
  val invent_types: sort list -> Proof.context -> (string * sort) list * Proof.context
  (*export -- first context is the inner one, second the outer one*)
  val export_terms: Proof.context -> Proof.context -> term list -> term list
  val exportT_terms: Proof.context -> Proof.context -> term list -> term list
  val exportT: Proof.context -> Proof.context -> thm list -> thm list
  val export_prf: Proof.context -> Proof.context -> Proofterm.proof -> Proofterm.proof
  val export: Proof.context -> Proof.context -> thm list -> thm list
  val export_morphism: Proof.context -> Proof.context -> morphism
  (*import -- the bool argument is the is_open flag*)
  val importT_inst: term list -> Proof.context -> ((indexname * sort) * typ) list * Proof.context
  val import_inst: bool -> term list -> Proof.context ->
    (((indexname * sort) * typ) list * ((indexname * typ) * term) list) * Proof.context
  val importT_terms: term list -> Proof.context -> term list * Proof.context
  val import_terms: bool -> term list -> Proof.context -> term list * Proof.context
  val importT: thm list -> Proof.context -> ((ctyp * ctyp) list * thm list) * Proof.context
  val import_prf: bool -> Proofterm.proof -> Proof.context -> Proofterm.proof * Proof.context
  val import: bool -> thm list -> Proof.context ->
    (((ctyp * ctyp) list * (cterm * cterm) list) * thm list) * Proof.context
  (*import, apply the given function, export again*)
  val tradeT: (Proof.context -> thm list -> thm list) -> Proof.context -> thm list -> thm list
  val trade: (Proof.context -> thm list -> thm list) -> Proof.context -> thm list -> thm list
  (*focus on outermost parameters*)
  val focus: cterm -> Proof.context -> ((string * cterm) list * cterm) * Proof.context
  val focus_subgoal: int -> thm -> Proof.context -> ((string * cterm) list * cterm) * Proof.context
  (*implicit polymorphism*)
  val warn_extra_tfrees: Proof.context -> Proof.context -> unit
  val polymorphic_types: Proof.context -> term list -> (indexname * sort) list * term list
  val polymorphic: Proof.context -> term list -> term list
end;

structure Variable: VARIABLE =
struct

(** local context data **)

(*Variable-related information kept per proof context.  Values are assembled
  by make_data and taken apart by rep_data.*)
datatype data = Data of
 {is_body: bool,                        (*inner body mode*)
  names: Name.context,                  (*type/term variable names*)
  consts: string Symtab.table,          (*consts within the local scope*)
  fixes: (string * string) list,        (*term fixes -- extern/intern*)
  binds: (typ * term) Vartab.table,     (*term bindings*)
  type_occs: string list Symtab.table,  (*type variables -- possibly within term variables*)
  maxidx: int,                          (*maximum var index*)
  sorts: sort OrdList.T,                (*declared sort occurrences*)
  constraints:
    typ Vartab.table *                  (*type constraints*)
    sort Vartab.table};                 (*default sorts*)

(*assemble the Data record from its nine components, given positionally*)
fun make_data (body, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
  Data {is_body = body, names = nms, consts = cs, fixes = fxs, binds = bnds,
    type_occs = occs, maxidx = idx, sorts = srts, constraints = cstrs};

structure Data = Proof_Data
(
  type T = data;
  (*initial data: not in body mode, all tables and lists empty, maxidx ~1*)
  fun init _ =
    Data {is_body = false, names = Name.context, consts = Symtab.empty, fixes = [],
      binds = Vartab.empty, type_occs = Symtab.empty, maxidx = ~1, sorts = [],
      constraints = (Vartab.empty, Vartab.empty)};
);

(*update the context data: f maps the 9-tuple of components to a new 9-tuple*)
fun map_data f =
  Data.map (fn Data r =>
    make_data (f (#is_body r, #names r, #consts r, #fixes r, #binds r, #type_occs r,
      #maxidx r, #sorts r, #constraints r)));

(*apply f to the names component only*)
fun map_names f =
  let
    fun upd (body, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
      (body, f nms, cs, fxs, bnds, occs, idx, srts, cstrs);
  in map_data upd end;

(*apply f to the consts component only*)
fun map_consts f =
  let
    fun upd (body, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
      (body, nms, f cs, fxs, bnds, occs, idx, srts, cstrs);
  in map_data upd end;

(*apply f to the fixes component only*)
fun map_fixes f =
  let
    fun upd (body, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
      (body, nms, cs, f fxs, bnds, occs, idx, srts, cstrs);
  in map_data upd end;

(*apply f to the binds component only*)
fun map_binds f =
  let
    fun upd (body, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
      (body, nms, cs, fxs, f bnds, occs, idx, srts, cstrs);
  in map_data upd end;

(*apply f to the type_occs component only*)
fun map_type_occs f =
  let
    fun upd (body, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
      (body, nms, cs, fxs, bnds, f occs, idx, srts, cstrs);
  in map_data upd end;

(*apply f to the maxidx component only*)
fun map_maxidx f =
  let
    fun upd (body, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
      (body, nms, cs, fxs, bnds, occs, f idx, srts, cstrs);
  in map_data upd end;

(*apply f to the sorts component only*)
fun map_sorts f =
  let
    fun upd (body, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
      (body, nms, cs, fxs, bnds, occs, idx, f srts, cstrs);
  in map_data upd end;

(*apply f to the constraints component (pair of type and sort tables) only*)
fun map_constraints f =
  let
    fun upd (body, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
      (body, nms, cs, fxs, bnds, occs, idx, srts, f cstrs);
  in map_data upd end;

(*the record of variable data stored in ctxt, with the Data constructor stripped*)
fun rep_data ctxt =
  (case Data.get ctxt of Data args => args);

(*the is_body flag of the context data*)
fun is_body ctxt = #is_body (rep_data ctxt);

(*replace the is_body flag by b, keeping all other components*)
fun set_body b =
  let
    fun upd (_, nms, cs, fxs, bnds, occs, idx, srts, cstrs) =
      (b, nms, cs, fxs, bnds, occs, idx, srts, cstrs);
  in map_data upd end;

(*copy the is_body flag of the first context into the second one*)
val restore_body = set_body o is_body;

(*projections of the individual components of the context data*)
fun names_of ctxt = #names (rep_data ctxt);
fun fixes_of ctxt = #fixes (rep_data ctxt);
fun binds_of ctxt = #binds (rep_data ctxt);
fun type_occs_of ctxt = #type_occs (rep_data ctxt);
fun maxidx_of ctxt = #maxidx (rep_data ctxt);
fun sorts_of ctxt = #sorts (rep_data ctxt);
fun constraints_of ctxt = #constraints (rep_data ctxt);

(*membership test against the name context of ctxt*)
fun is_declared ctxt = Name.is_declared (names_of ctxt);
(*x is fixed iff it is the second (internal) component of some entry of the fixes*)
fun is_fixed ctxt x = member (op =) (map snd (fixes_of ctxt)) x;
(*x is fixed in the inner context but not in the outer one*)
fun newly_fixed inner outer x =
  if is_fixed inner x then not (is_fixed outer x) else false;

(*insert (name, type) of every Free of the term that is fixed in ctxt into the
  accumulated list, without duplicates*)
fun add_fixed ctxt =
  let
    fun add (Free (x, T)) = is_fixed ctxt x ? insert (op =) (x, T)
      | add _ = I;
  in Term.fold_aterms add end;



(** declarations **)

(* default sorts and types *)

(*type constraint recorded for name x under index ~1, if any*)
fun default_type ctxt x =
  let val (types, _) = constraints_of ctxt
  in Vartab.lookup types (x, ~1) end;

(*A type constraint for xi takes precedence.  Without one, and unless pattern
  is set, fall back on the type of a term binding for xi, passed through
  TypeInfer.polymorphicT.*)
fun def_type ctxt pattern xi =
  let
    val {binds, constraints = (types, _), ...} = rep_data ctxt;
    val constrained = Vartab.lookup types xi;
  in
    if is_some constrained then constrained
    else if pattern then NONE
    else Option.map (TypeInfer.polymorphicT o fst) (Vartab.lookup binds xi)
  end;

(*lookup in the table of default sorts (second component of the constraints)*)
fun def_sort ctxt = Vartab.lookup (snd (constraints_of ctxt));


(* names *)

(*declare the names of the atomic types of t and account for the indexes of
  its types in maxidx*)
fun declare_type_names t ctxt =
  ctxt
  |> map_names (fold_types (fold_atyps Term.declare_typ_names) t)
  |> map_maxidx (fold_types Term.maxidx_typ t);

(*declare type names and free term variable names of t, and update maxidx*)
fun declare_names t ctxt =
  ctxt
  |> declare_type_names t
  |> map_names (fold_aterms Term.declare_term_frees t)
  |> map_maxidx (Term.maxidx_term t);


(* type occurrences *)

(*give every TFree of T an entry in the table, defaulting to the empty list*)
fun decl_type_occsT T =
  let
    fun decl (TFree (a, _)) = Symtab.default (a, [])
      | decl _ = I;
  in fold_atyps decl T end;

(*For the type of a Free x, record x under each TFree occurring in it; types
  attached to any other term are merely declared via decl_type_occsT.*)
fun decl_type_occs t =
  let
    fun occs_of (Free (x, _)) =
          fold_atyps (fn TFree (a, _) => Symtab.insert_list (op =) (a, x) | _ => I)
      | occs_of _ = decl_type_occsT;
  in fold_term_types occs_of t end;

(*context versions of decl_type_occsT (over all types of a term) and decl_type_occs*)
fun declare_type_occsT t = map_type_occs (fold_types decl_type_occsT t);
fun declare_type_occs t = map_type_occs (decl_type_occs t);


(* constraints *)

(*record sort S for xi; the dummy sort removes any existing entry instead*)
fun constrain_tvar (xi, S) =
  if S <> dummyS then Vartab.update (xi, S)
  else Vartab.delete_safe xi;

(*Record type constraints for the Frees (under index ~1) and Vars of t, and
  sort constraints for the TFrees (under index ~1) and TVars of its types;
  then declare the type occurrences and type names of t.*)
fun declare_constraints t =
  let
    fun add_type (Free (x, T)) = Vartab.update ((x, ~1), T)
      | add_type (Var v) = Vartab.update v
      | add_type _ = I;
    fun add_sort (TFree (x, S)) = constrain_tvar ((x, ~1), S)
      | add_sort (TVar v) = constrain_tvar v
      | add_sort _ = I;
    fun update (types, sorts) =
      let
        val types' = fold_aterms add_type t types;
        val sorts' = fold_types (fold_atyps add_sort) t sorts;
      in (types', sorts') end;
  in
    map_constraints update #>
    declare_type_occsT t #>
    declare_type_names t
  end;


(* common declarations *)

(*declare names, type occurrences and sorts of t -- no constraints*)
fun declare_internal t ctxt =
  ctxt
  |> declare_names t
  |> declare_type_occs t
  |> map_sorts (Sorts.insert_term t);

(*declare_internal followed by declare_constraints*)
fun declare_term t ctxt =
  ctxt
  |> declare_internal t
  |> declare_constraints t;

(*declare a type by way of its term representation Logic.mk_type*)
fun declare_typ T = declare_term (Logic.mk_type T);

(*declare_internal for all terms and (via Logic.mk_type) all types of a proof*)
fun declare_prf prf =
  Proofterm.fold_proof_terms declare_internal
    (fn T => declare_internal (Logic.mk_type T)) prf;

(*declare_internal for all terms of a theorem, as visited by Thm.fold_terms*)
fun declare_thm th = Thm.fold_terms declare_internal th;
(*global context of the theory of th, with th declared in it*)
fun global_thm_context th =
  Thm.theory_of_thm th
  |> ProofContext.init_global
  |> declare_thm th;


(* renaming term/type frees *)

(*Rename the first components of frees to variants wrt. the names of ctxt
  extended by the names of ts; second components are passed through.*)
fun variant_frees ctxt ts frees =
  let
    val (xs, Ts) = split_list frees;
    val used = names_of (fold declare_names ts ctxt);
    val (xs', _) = Name.variants xs used;
  in xs' ~~ Ts end;



(** term bindings **)

(*NONE removes the binding of xi; SOME t binds xi to the schematically closed
  term together with its type, after declaring that term*)
fun bind_term (xi, binding) =
  (case binding of
    NONE => map_binds (Vartab.delete_safe xi)
  | SOME t =>
      let val u = Term.close_schematic_term t
      in declare_term u #> map_binds (Vartab.update (xi, (Term.fastype_of u, u))) end);

(*expand Vars according to the term bindings of ctxt, then beta-normalize*)
fun expand_binds ctxt =
  let
    val binds = binds_of ctxt;
    fun get (Var (xi, _)) = Vartab.lookup binds xi
      | get _ = NONE;
    val expand = Envir.expand_term get;
  in fn t => Envir.beta_norm (expand t) end;



(** consts **)

(*lookup in the consts table of the context data; is_const tests for presence*)
fun lookup_const ctxt = Symtab.lookup (#consts (rep_data ctxt));
fun is_const ctxt c = is_some (lookup_const ctxt c);

(*declare_fixed removes a name from the consts table; declare_const updates
  the table with the given pair*)
fun declare_fixed x = map_consts (Symtab.delete_safe x);
fun declare_const decl = map_consts (Symtab.update decl);



(** fixes **)

local

(*raise an error naming the given variables, unless the list is empty*)
fun no_dups dups =
  if null dups then ()
  else error ("Duplicate fixed variable(s): " ^ commas_quote dups);

(*Install names' as name context, remove xs from the consts, prepend the pairs
  of xs and xs' (reversed) to the fixes, declare constraints for the xs' as
  free variables; the result is xs' paired with the new context.*)
fun new_fixes names' xs xs' ctxt =
  let
    val ctxt' = ctxt
      |> map_names (fn _ => names')
      |> fold declare_fixed xs
      |> map_fixes (fn fixes => rev (xs ~~ xs') @ fixes)
      |> fold (fn x' => declare_constraints (Syntax.free x')) xs';
  in (xs', ctxt') end;

in

(*Fix the variables xs.  Skolem names and duplicates within xs are rejected.
  In body mode, variants of xs are chosen and turned into Skolem names;
  otherwise xs are used as given, provided they clash with neither component
  of the existing fixes.*)
fun add_fixes xs ctxt =
  let
    val skolems = filter (can Name.dest_skolem) xs;
    val _ =
      if null skolems then ()
      else error ("Illegal internal Skolem constant(s): " ^ commas_quote skolems);
    val _ = no_dups (duplicates (op =) xs);
    val names = names_of ctxt;
    val (xs', names') =
      if is_body ctxt then
        let val (ys, names'') = Name.variants xs names
        in (map Name.skolem ys, names'') end
      else
        let
          val (externs, interns) = split_list (fixes_of ctxt);
          val _ = no_dups (inter (op =) xs externs);
          val _ = no_dups (inter (op =) xs interns);
        in (xs, fold Name.declare xs names) end;
  in new_fixes names' xs xs' ctxt end;

(*Fix variants of the given names.  Each name is cleaned, and made internal
  again if the raw name was internal; in body mode the variants are turned
  into Skolem names.*)
fun variant_fixes raw_xs ctxt =
  let
    fun prep x =
      let val y = Name.clean x
      in if can Name.dest_internal x then Name.internal y else y end;
    val xs = map prep raw_xs;
    val (ys, names') = Name.variants xs (names_of ctxt);
    val xs' = if is_body ctxt then map Name.skolem ys else ys;
  in new_fixes names' xs xs' ctxt end;

end;


(*add_fixes with body mode switched off; afterwards the body flag of the
  original context is restored*)
fun add_fixes_direct xs ctxt =
  let val (_, ctxt') = add_fixes xs (set_body false ctxt)
  in restore_body ctxt ctxt' end;

(*fix directly all Frees of t that are not yet fixed in ctxt*)
fun fix_frees t ctxt =
  let
    fun add (Free (x, _)) = if is_fixed ctxt x then I else insert (op =) x
      | add _ = I;
    val new_frees = rev (fold_aterms add t []);
  in add_fixes_direct new_frees ctxt end;

(*outside body mode the Frees of t are fixed first; t is declared in any case*)
fun auto_fixes t ctxt =
  let val ctxt' = if is_body ctxt then ctxt else fix_frees t ctxt
  in declare_term t ctxt' end;

(*invent one type variable name per given sort, starting from Name.aT, and
  declare constraints for the resulting TFrees*)
fun invent_types Ss ctxt =
  let
    val names = Name.invents (names_of ctxt) Name.aT (length Ss);
    val tfrees = names ~~ Ss;
    fun declare tfree = declare_constraints (Logic.mk_type (TFree tfree));
  in (tfrees, fold declare tfrees ctxt) end;



(** export -- generalize type/term variables (beware of closure sizes) **)

(*Determine what to generalize when passing from the inner to the outer context.
  Result: (gen_fixesT, gen_fixes), where gen_fixes are names of term variables
  and gen_fixesT maps a list of terms to names of type variables.  Only the
  pieces of the contexts extracted below are referenced by gen_fixesT.*)
fun export_inst inner outer =
  let
    val declared_outer = is_declared outer;
    val fixes_inner = fixes_of inner;
    val fixes_outer = fixes_of outer;

    (*second components of the leading entries of the inner fixes, as many as
      inner has more than outer -- presumably inner extends outer, with new
      fixes prepended; to be confirmed against callers*)
    val gen_fixes = map #2 (take (length fixes_inner - length fixes_outer) fixes_inner);
    val still_fixed = not o member (op =) gen_fixes;

    val type_occs_inner = type_occs_of inner;
    (*a type variable is collected unless it is declared in outer or occurs
      in some term variable that is not among gen_fixes; occurrences within
      ts are added to those of inner first*)
    fun gen_fixesT ts =
      Symtab.fold (fn (a, xs) =>
        if declared_outer a orelse exists still_fixed xs
        then I else cons a) (fold decl_type_occs ts type_occs_inner) [];
  in (gen_fixesT, gen_fixes) end;

(*type variable part of export_inst only*)
fun exportT_inst inner outer =
  let val (mk_tfrees, _) = export_inst inner outer
  in mk_tfrees end;

(*Generalize type variables of the terms (no term variables), using as index
  one above the maximal index of the types within the terms.  The export
  information is computed once, before the terms are supplied.*)
fun exportT_terms inner outer =
  let
    val mk_tfrees = exportT_inst inner outer;
    fun gen ts =
      let
        val tfrees = mk_tfrees ts;
        val idx = fold (Term.fold_types Term.maxidx_typ) ts ~1 + 1;
      in map (Term_Subst.generalize (tfrees, []) idx) ts end;
  in gen end;

(*Generalize type and term variables of the terms, using as index one above
  the maximal index of the terms.  The export information is computed once,
  before the terms are supplied.*)
fun export_terms inner outer =
  let
    val (mk_tfrees, frees) = export_inst inner outer;
    fun gen ts =
      let
        val tfrees = mk_tfrees ts;
        val idx = fold Term.maxidx_term ts ~1 + 1;
      in map (Term_Subst.generalize (tfrees, frees) idx) ts end;
  in gen end;

(*Generalize the terms and types of a proof.  The proof is declared in the
  inner context before the export information is determined.*)
fun export_prf inner outer prf =
  let
    val idx = Proofterm.maxidx_proof prf ~1 + 1;
    val (mk_tfrees, frees) = export_inst (declare_prf prf inner) outer;
    val tfrees = mk_tfrees [];
    val generalize_proof =
      Proofterm.map_proof_terms_same
        (Term_Subst.generalize_same (tfrees, frees) idx)
        (Term_Subst.generalizeT_same tfrees idx);
  in Same.commit generalize_proof prf end;


(*Generalize theorems: type variables as determined by mk_tfrees from the
  full propositions, term variables as given; the index is one above the
  maximal index of the theorems.*)
fun gen_export (mk_tfrees, frees) ths =
  let
    val idx = fold Thm.maxidx_thm ths ~1 + 1;
    val tfrees = mk_tfrees (map Thm.full_prop_of ths);
  in map (Thm.generalize (tfrees, frees) idx) ths end;

(*export of type variables only; the theorem list is deliberately not made a
  formal parameter, so the result does not refer to the two contexts directly*)
fun exportT inner outer =
  let val mk_tfrees = exportT_inst inner outer
  in gen_export (mk_tfrees, []) end;
(*export of type and term variables; the theorem list is deliberately not made
  a formal parameter, so the result does not refer to the two contexts directly*)
fun export inner outer =
  let val insts = export_inst inner outer
  in gen_export insts end;

(*morphism that exports facts and terms; types are handled via the term
  export through Logic.type_map; bindings are left unchanged*)
fun export_morphism inner outer =
  let
    val export_fact = export inner outer;
    val export_term = singleton (export_terms inner outer);
    val export_typ = Logic.type_map export_term;
  in
    Morphism.morphism
      {binding = I, typ = export_typ, term = export_term, fact = export_fact}
  end;



(** import -- fix schematic type/term variables **)

(*instantiation of the schematic type variables of ts by newly invented
  TFrees of the same sorts*)
fun importT_inst ts ctxt =
  let val tvars = rev (fold Term.add_tvars ts [])
  in
    ctxt
    |> invent_types (map snd tvars)
    |>> (fn tfrees => tvars ~~ map TFree tfrees)
  end;

(*Instantiation of schematic type and term variables of ts.  Term variables
  become Frees with names obtained via variant_fixes from the cleaned base
  names, which are made internal unless is_open holds.*)
fun import_inst is_open ts ctxt =
  let
    fun ren x =
      let val y = Name.clean x
      in if is_open then y else Name.internal y end;
    val (instT, ctxt') = importT_inst ts ctxt;
    val instantiateT = Term_Subst.instantiateT instT;
    val vars = map (fn (xi, T) => (xi, instantiateT T)) (rev (fold Term.add_vars ts []));
    val raw_names = map (fn ((x, _), _) => ren x) vars;
    val (xs, ctxt'') = variant_fixes raw_names ctxt';
    val frees = map Free (xs ~~ map snd vars);
  in ((instT, vars ~~ frees), ctxt'') end;

(*instantiate the schematic type variables of the terms according to importT_inst*)
fun importT_terms ts ctxt =
  importT_inst ts ctxt
  |>> (fn instT => map (Term_Subst.instantiate (instT, [])) ts);

(*instantiate schematic type and term variables of the terms according to import_inst*)
fun import_terms is_open ts ctxt =
  import_inst is_open ts ctxt
  |>> (fn inst => map (Term_Subst.instantiate inst) ts);

(*instantiate the schematic type variables of theorems; the certified type
  instantiation is returned alongside the instantiated theorems*)
fun importT ths ctxt =
  let
    val props = map Thm.full_prop_of ths;
    val (instT, ctxt') = importT_inst props ctxt;
    val cinsts = Thm.certify_inst (ProofContext.theory_of ctxt) (instT, []);
  in ((fst cinsts, map (Thm.instantiate cinsts) ths), ctxt') end;

(*instantiate a proof according to import_inst applied to its terms and
  (as terms via Logic.mk_type) its types*)
fun import_prf is_open prf ctxt =
  let
    val ts =
      rev (Proofterm.fold_proof_terms cons (fn T => cons (Logic.mk_type T)) prf []);
  in
    import_inst is_open ts ctxt
    |>> (fn insts => Proofterm.instantiate insts prf)
  end;

(*instantiate schematic type and term variables of theorems; the certified
  instantiation is returned alongside the instantiated theorems*)
fun import is_open ths ctxt =
  let
    val props = map Thm.full_prop_of ths;
    val (insts, ctxt') = import_inst is_open props ctxt;
    val cinsts = Thm.certify_inst (ProofContext.theory_of ctxt) insts;
  in ((cinsts, map (Thm.instantiate cinsts) ths), ctxt') end;


(* import/export *)

(*import ths into ctxt, apply f within the resulting context, and export the
  outcome back to ctxt*)
fun gen_trade imp exp f ctxt ths =
  let val (imported, ctxt') = imp ths ctxt
  in exp ctxt' ctxt (f ctxt' (snd imported)) end;

(*tradeT: type variables only; trade: type and term variables, open import*)
fun tradeT f = gen_trade importT exportT f;
fun trade f = gen_trade (import true) export f;


(* focus on outermost parameters *)

(*apply the argument of prop to t, beta-convert, and take the argument (right
  hand side) of the proposition of the resulting conversion theorem*)
fun forall_elim_prop t prop =
  let
    val body = Thm.dest_arg prop;
    val eq = Thm.beta_conversion false (Thm.capply body t);
  in Thm.dest_arg (Thm.cprop_of eq) end;

(*Fix the outermost parameters of goal as variants in ctxt and strip them
  from the goal.  Result: the original parameter names paired with the
  certified fixed variables, the remaining goal, and the extended context.*)
fun focus goal ctxt =
  let
    val thy = Thm.theory_of_cterm goal;
    val prop = Thm.term_of goal;
    (*parameter names as they are printed*)
    val (xs, Ts) = split_list (Term.variant_frees prop (Term.strip_all_vars prop));
    val (xs', ctxt') = variant_fixes xs ctxt;
    val params = ListPair.map (fn p => Thm.cterm_of thy (Free p)) (xs', Ts);
    val goal' = fold forall_elim_prop params goal;
    val ctxt'' = fold (fn p => declare_constraints (Thm.term_of p)) params ctxt';
  in ((xs ~~ params, goal'), ctxt'') end;

(*Focus on premise i of st.  Beforehand, term bindings for all Vars of st are
  removed and constraints for these Vars are declared.*)
fun focus_subgoal i st =
  let
    val all_vars = Thm.fold_terms Term.add_vars st [];
    fun unbind (xi, _) = bind_term (xi, NONE);
    val subgoal = Thm.cprem_of st i;
  in
    fn ctxt =>
      ctxt
      |> fold unbind all_vars
      |> fold (fn v => declare_constraints (Var v)) all_vars
      |> focus subgoal
  end;



(** implicit polymorphism **)

(* warn_extra_tfrees *)

(*Warn about type variables with occurrences recorded in ctxt2 but not in
  ctxt1.  A pair (a, x) is reported for each term variable x recorded under
  such a type variable a, unless the default type of x in ctxt1 already
  contains a.*)
fun warn_extra_tfrees ctxt1 ctxt2 =
  let
    val old_occs = type_occs_of ctxt1;
    val new_occs = type_occs_of ctxt2;

    fun mentions a T = Term.exists_subtype (fn TFree (b, _) => a = b | _ => false) T;
    fun known a x =
      (case def_type ctxt1 false (x, ~1) of
        SOME T => mentions a T
      | NONE => false);
    fun add_extra a x = if known a x then I else cons (a, x);
    fun add_extras (a, xs) =
      if Symtab.defined old_occs a then I else fold (add_extra a) xs;

    val extras = Symtab.fold add_extras new_occs [];
  in
    (case extras of
      [] => ()
    | _ =>
        let
          val tfrees = sort_distinct string_ord (map fst extras);
          val frees = sort_distinct string_ord (map snd extras);
        in
          warning ("Introduced fixed type variable(s): " ^ commas tfrees ^ " in " ^
            space_implode " or " (map quote frees))
        end)
  end;


(* polymorphic terms *)

(*Generalize those type variables of ts whose occurrences are recorded only
  after declaring ts in ctxt.  Result: the schematic type variables that
  arise, and the generalized terms.*)
fun polymorphic_types ctxt ts =
  let
    val ctxt' = fold declare_term ts ctxt;
    val old_occs = type_occs_of ctxt;
    fun add_new (a, _) = if Symtab.defined old_occs a then I else cons a;
    val new_tfrees = Symtab.fold add_new (type_occs_of ctxt') [];
    val idx = maxidx_of ctxt' + 1;

    fun add_tvar (T as TFree _) =
          (case Term_Subst.generalizeT new_tfrees idx T of
            TVar v => insert (op =) v
          | _ => I)
      | add_tvar _ = I;
    val tvars = fold (fold_types (fold_atyps add_tvar)) ts [];
    val ts' = map (Term_Subst.generalize (new_tfrees, []) idx) ts;
  in (rev tvars, ts') end;

(*term part of polymorphic_types only*)
fun polymorphic ctxt ts =
  let val (_, ts') = polymorphic_types ctxt ts
  in ts' end;

end;