gen_add_inductive_i: treat abbrevs as local defs, expand by export;
authorwenzelm
Sun Oct 14 16:13:45 2007 +0200 (2007-10-14)
changeset 250293a72718c5ddd
parent 25028 e0f74efc210f
child 25030 7507f590486f
gen_add_inductive_i: treat abbrevs as local defs, expand by export;
tuned;
src/HOL/Tools/inductive_package.ML
     1.1 --- a/src/HOL/Tools/inductive_package.ML	Sun Oct 14 00:18:11 2007 +0200
     1.2 +++ b/src/HOL/Tools/inductive_package.ML	Sun Oct 14 16:13:45 2007 +0200
     1.3 @@ -777,59 +777,59 @@
     1.4  (* external interfaces *)
     1.5  
     1.6  fun gen_add_inductive_i mk_def (flags as {verbose, kind, alt_name, coind, no_elim, no_ind})
     1.7 -    cnames_syn pnames pre_intros monos ctxt =
     1.8 +    cnames_syn pnames spec monos lthy =
     1.9    let
    1.10 -    val thy = ProofContext.theory_of ctxt;
    1.11 +    val thy = ProofContext.theory_of lthy;
    1.12      val _ = Theory.requires thy "Inductive" (coind_prefix coind ^ "inductive definitions");
    1.13  
    1.14 -    fun is_abbrev ((name, atts), t) =
    1.15 -      can (Logic.strip_assums_concl #> Logic.dest_equals) t andalso
    1.16 -      (name = "" andalso null atts orelse
    1.17 -       error "Abbreviations may not have names or attributes");
    1.18  
    1.19 -    fun expand_atom tab (t as Free xT) =
    1.20 -          the_default t (AList.lookup op = tab xT)
    1.21 -      | expand_atom tab t = t;
    1.22 -    fun expand [] r = r
    1.23 -      | expand tab r = Envir.beta_norm (Term.map_aterms (expand_atom tab) r);
    1.24 +    (* abbrevs *)
    1.25 +
    1.26 +    val (_, ctxt1) = Variable.add_fixes (map (fst o fst) cnames_syn) lthy;
    1.27  
    1.28 -    val (_, ctxt') = Variable.add_fixes (map (fst o fst) cnames_syn) ctxt;
    1.29 +    fun get_abbrev ((name, atts), t) =
    1.30 +      if can (Logic.strip_assums_concl #> Logic.dest_equals) t then
    1.31 +        let
    1.32 +          val _ = name = "" andalso null atts orelse
    1.33 +            error "Abbreviations may not have names or attributes";
    1.34 +          val ((x, T), rhs) = LocalDefs.abs_def (snd (LocalDefs.cert_def ctxt1 t));
    1.35 +          val mx =
    1.36 +            (case find_first (fn ((c, _), _) => c = x) cnames_syn of
    1.37 +              NONE => error ("Undeclared head of abbreviation " ^ quote x)
    1.38 +            | SOME ((_, T'), mx) =>
    1.39 +                if T <> T' then error ("Bad type specification for abbreviation " ^ quote x)
    1.40 +                else mx);
    1.41 +        in SOME ((x, mx), rhs) end
    1.42 +      else NONE;
    1.43  
    1.44 -    fun prep_abbrevs [] abbrevs' abbrevs'' = (rev abbrevs', rev abbrevs'')
    1.45 -      | prep_abbrevs ((_, abbrev) :: abbrevs) abbrevs' abbrevs'' =
    1.46 -          let val ((s, T), t) =
    1.47 -            LocalDefs.abs_def (snd (LocalDefs.cert_def ctxt' abbrev))
    1.48 -          in case find_first (equal s o fst o fst) cnames_syn of
    1.49 -              NONE => error ("Head of abbreviation " ^ quote s ^ " undeclared")
    1.50 -            | SOME (_, mx) => prep_abbrevs abbrevs
    1.51 -                (((s, T), expand abbrevs' t) :: abbrevs')
    1.52 -                (((s, mx), expand abbrevs' t) :: abbrevs'') (* FIXME: do not expand *)
    1.53 -          end;
    1.54 +    val abbrevs = map_filter get_abbrev spec;
    1.55 +    val bs = map (fst o fst) abbrevs;
    1.56 +
    1.57  
    1.58 -    val (abbrevs, pre_intros') = List.partition is_abbrev pre_intros;
    1.59 -    val (abbrevs', abbrevs'') = prep_abbrevs abbrevs [] [];
    1.60 -    val _ = (case gen_inter (op = o apsnd fst)
    1.61 -      (fold (Term.add_frees o snd) abbrevs' [], abbrevs') of
    1.62 -        [] => ()
    1.63 -      | xs => error ("Bad abbreviation(s): " ^ commas (map fst xs)));
    1.64 +    (* predicates *)
    1.65  
    1.66 -    val params = map Free pnames;
    1.67 -    val cnames_syn' = filter_out (fn ((s, _), _) =>
    1.68 -      exists (equal s o fst o fst) abbrevs') cnames_syn;
    1.69 +    val pre_intros = filter_out (is_some o get_abbrev) spec;
    1.70 +    val cnames_syn' = filter_out (member (op =) bs o fst o fst) cnames_syn;
    1.71      val cs = map (Free o fst) cnames_syn';
    1.72 -    val cnames_syn'' = map (fn ((s, _), mx) => (s, mx)) cnames_syn';
    1.73 +    val ps = map Free pnames;
    1.74  
    1.75 -    fun close_rule (x, r) = (x, list_all_free (rev (fold_aterms
    1.76 +    val ctxt2 = lthy
    1.77 +      |> Variable.add_fixes (map (fst o fst) cnames_syn') |> snd
    1.78 +      |> fold (snd oo LocalDefs.add_def) abbrevs;
    1.79 +    val expand = Assumption.export_term ctxt2 lthy;
    1.80 +
    1.81 +    fun close_rule r = list_all_free (rev (fold_aterms
    1.82        (fn t as Free (v as (s, _)) =>
    1.83 -            if Variable.is_fixed ctxt' s orelse
    1.84 -              member op = params t then I else insert op = v
    1.85 -        | _ => I) r []), r));
    1.86 +          if Variable.is_fixed ctxt1 s orelse
    1.87 +            member (op =) ps t then I else insert (op =) v
    1.88 +        | _ => I) r []), r);
    1.89  
    1.90 -    val intros = map (close_rule ##> expand abbrevs') pre_intros';
    1.91 +    val intros = map (apsnd (close_rule #> expand)) pre_intros;
    1.92 +    val preds = map (fn ((c, _), mx) => (c, mx)) cnames_syn';
    1.93    in
    1.94 -    ctxt
    1.95 -    |> mk_def flags cs intros monos params cnames_syn''
    1.96 -    ||> fold (snd oo LocalTheory.abbrev Syntax.mode_default) abbrevs''
    1.97 +    lthy
    1.98 +    |> mk_def flags cs intros monos ps preds
    1.99 +    ||> fold (snd oo LocalTheory.abbrev Syntax.mode_default) abbrevs
   1.100    end;
   1.101  
   1.102  fun gen_add_inductive mk_def verbose coind cnames_syn pnames_syn intro_srcs raw_monos lthy =