gen_add_inductive_i: treat abbrevs as local defs, expand by export;
authorwenzelm
Sun, 14 Oct 2007 16:13:45 +0200
changeset 25029 3a72718c5ddd
parent 25028 e0f74efc210f
child 25030 7507f590486f
gen_add_inductive_i: treat abbrevs as local defs, expand by export; tuned;
src/HOL/Tools/inductive_package.ML
--- a/src/HOL/Tools/inductive_package.ML	Sun Oct 14 00:18:11 2007 +0200
+++ b/src/HOL/Tools/inductive_package.ML	Sun Oct 14 16:13:45 2007 +0200
@@ -777,59 +777,59 @@
 (* external interfaces *)
 
 fun gen_add_inductive_i mk_def (flags as {verbose, kind, alt_name, coind, no_elim, no_ind})
-    cnames_syn pnames pre_intros monos ctxt =
+    cnames_syn pnames spec monos lthy =
   let
-    val thy = ProofContext.theory_of ctxt;
+    val thy = ProofContext.theory_of lthy;
     val _ = Theory.requires thy "Inductive" (coind_prefix coind ^ "inductive definitions");
 
-    fun is_abbrev ((name, atts), t) =
-      can (Logic.strip_assums_concl #> Logic.dest_equals) t andalso
-      (name = "" andalso null atts orelse
-       error "Abbreviations may not have names or attributes");
 
-    fun expand_atom tab (t as Free xT) =
-          the_default t (AList.lookup op = tab xT)
-      | expand_atom tab t = t;
-    fun expand [] r = r
-      | expand tab r = Envir.beta_norm (Term.map_aterms (expand_atom tab) r);
+    (* abbrevs *)
+
+    val (_, ctxt1) = Variable.add_fixes (map (fst o fst) cnames_syn) lthy;
 
-    val (_, ctxt') = Variable.add_fixes (map (fst o fst) cnames_syn) ctxt;
+    fun get_abbrev ((name, atts), t) =
+      if can (Logic.strip_assums_concl #> Logic.dest_equals) t then
+        let
+          val _ = name = "" andalso null atts orelse
+            error "Abbreviations may not have names or attributes";
+          val ((x, T), rhs) = LocalDefs.abs_def (snd (LocalDefs.cert_def ctxt1 t));
+          val mx =
+            (case find_first (fn ((c, _), _) => c = x) cnames_syn of
+              NONE => error ("Undeclared head of abbreviation " ^ quote x)
+            | SOME ((_, T'), mx) =>
+                if T <> T' then error ("Bad type specification for abbreviation " ^ quote x)
+                else mx);
+        in SOME ((x, mx), rhs) end
+      else NONE;
 
-    fun prep_abbrevs [] abbrevs' abbrevs'' = (rev abbrevs', rev abbrevs'')
-      | prep_abbrevs ((_, abbrev) :: abbrevs) abbrevs' abbrevs'' =
-          let val ((s, T), t) =
-            LocalDefs.abs_def (snd (LocalDefs.cert_def ctxt' abbrev))
-          in case find_first (equal s o fst o fst) cnames_syn of
-              NONE => error ("Head of abbreviation " ^ quote s ^ " undeclared")
-            | SOME (_, mx) => prep_abbrevs abbrevs
-                (((s, T), expand abbrevs' t) :: abbrevs')
-                (((s, mx), expand abbrevs' t) :: abbrevs'') (* FIXME: do not expand *)
-          end;
+    val abbrevs = map_filter get_abbrev spec;
+    val bs = map (fst o fst) abbrevs;
+
 
-    val (abbrevs, pre_intros') = List.partition is_abbrev pre_intros;
-    val (abbrevs', abbrevs'') = prep_abbrevs abbrevs [] [];
-    val _ = (case gen_inter (op = o apsnd fst)
-      (fold (Term.add_frees o snd) abbrevs' [], abbrevs') of
-        [] => ()
-      | xs => error ("Bad abbreviation(s): " ^ commas (map fst xs)));
+    (* predicates *)
 
-    val params = map Free pnames;
-    val cnames_syn' = filter_out (fn ((s, _), _) =>
-      exists (equal s o fst o fst) abbrevs') cnames_syn;
+    val pre_intros = filter_out (is_some o get_abbrev) spec;
+    val cnames_syn' = filter_out (member (op =) bs o fst o fst) cnames_syn;
     val cs = map (Free o fst) cnames_syn';
-    val cnames_syn'' = map (fn ((s, _), mx) => (s, mx)) cnames_syn';
+    val ps = map Free pnames;
 
-    fun close_rule (x, r) = (x, list_all_free (rev (fold_aterms
+    val ctxt2 = lthy
+      |> Variable.add_fixes (map (fst o fst) cnames_syn') |> snd
+      |> fold (snd oo LocalDefs.add_def) abbrevs;
+    val expand = Assumption.export_term ctxt2 lthy;
+
+    fun close_rule r = list_all_free (rev (fold_aterms
       (fn t as Free (v as (s, _)) =>
-            if Variable.is_fixed ctxt' s orelse
-              member op = params t then I else insert op = v
-        | _ => I) r []), r));
+          if Variable.is_fixed ctxt1 s orelse
+            member (op =) ps t then I else insert (op =) v
+        | _ => I) r []), r);
 
-    val intros = map (close_rule ##> expand abbrevs') pre_intros';
+    val intros = map (apsnd (close_rule #> expand)) pre_intros;
+    val preds = map (fn ((c, _), mx) => (c, mx)) cnames_syn';
   in
-    ctxt
-    |> mk_def flags cs intros monos params cnames_syn''
-    ||> fold (snd oo LocalTheory.abbrev Syntax.mode_default) abbrevs''
+    lthy
+    |> mk_def flags cs intros monos ps preds
+    ||> fold (snd oo LocalTheory.abbrev Syntax.mode_default) abbrevs
   end;
 
 fun gen_add_inductive mk_def verbose coind cnames_syn pnames_syn intro_srcs raw_monos lthy =