src/HOL/Tools/inductive_codegen.ML
changeset 12557 bb2e4689347e
parent 12453 806502073957
child 12562 323ce5a89695
--- a/src/HOL/Tools/inductive_codegen.ML	Thu Dec 20 14:57:54 2001 +0100
+++ b/src/HOL/Tools/inductive_codegen.ML	Thu Dec 20 14:58:18 2001 +0100
@@ -8,6 +8,7 @@
 
 signature INDUCTIVE_CODEGEN =
 sig
+  val add : theory attribute
   val setup : (theory -> theory) list
 end;
 
@@ -16,10 +17,48 @@
 
 open Codegen;
 
-exception Modes of (string * int list list) list * (string * int list list) list;
+(**** theory data ****)
+
+structure CodegenArgs =
+struct
+  val name = "HOL/inductive_codegen";
+  type T = thm list Symtab.table;
+  val empty = Symtab.empty;
+  val copy = I;
+  val prep_ext = I;
+  val merge = Symtab.merge_multi eq_thm;
+  fun print _ _ = ();
+end;
+
+structure CodegenData = TheoryDataFun(CodegenArgs);
+
+fun warn thm = warning ("InductiveCodegen: Not a proper clause:\n" ^
+  string_of_thm thm);
 
-datatype indprem = Prem of string * term list * term list
-                 | Sidecond of term;
+fun add (p as (thy, thm)) =
+  let
+    val tsig = Sign.tsig_of (sign_of thy);
+    val tab = CodegenData.get thy;
+    val matches = curry (Pattern.matches tsig o pairself concl_of);
+
+  in (case concl_of thm of
+      _ $ (Const ("op :", _) $ _ $ t) => (case head_of t of
+        Const (s, _) => (CodegenData.put (Symtab.update ((s,
+          filter_out (matches thm) (if_none (Symtab.lookup (tab, s)) []) @
+            [thm]), tab)) thy, thm)
+      | _ => (warn thm; p))
+    | _ => (warn thm; p))
+  end handle Pattern.Pattern => (warn thm; p);
+
+fun get_clauses thy s =
+  (case Symtab.lookup (CodegenData.get thy, s) of
+     None => (case InductivePackage.get_inductive thy s of
+       None => None
+     | Some ({names, ...}, {intrs, ...}) => Some (names, intrs))
+   | Some thms => Some ([s], thms));
+
+
+(**** improper tuples ****)
 
 fun prod_factors p (Const ("Pair", _) $ t $ u) =
       p :: prod_factors (1::p) t @ prod_factors (2::p) u
@@ -30,10 +69,44 @@
          split_prod (1::p) ps t @ split_prod (2::p) ps u
      | _ => error "Inconsistent use of products") else [t];
 
+datatype factors = FVar of int list list | FFix of int list list;
+
+exception Factors;
+
+fun mg_factor (FVar f) (FVar f') = FVar (f inter f')
+  | mg_factor (FVar f) (FFix f') =
+      if f' subset f then FFix f' else raise Factors
+  | mg_factor (FFix f) (FVar f') =
+      if f subset f' then FFix f else raise Factors
+  | mg_factor (FFix f) (FFix f') =
+      if f subset f' andalso f' subset f then FFix f else raise Factors;
+
+fun dest_factors (FVar f) = f
+  | dest_factors (FFix f) = f;
+
+fun infer_factors sg extra_fs (fs, (optf, t)) =
+  let fun err s = error (s ^ "\n" ^ Sign.string_of_term sg t)
+  in (case (optf, strip_comb t) of
+      (Some f, (Const (name, _), args)) =>
+        (case assoc (extra_fs, name) of
+           None => overwrite (fs, (name, if_none
+             (apsome (mg_factor f) (assoc (fs, name))) f))
+         | Some (fs', f') => (mg_factor f (FFix f');
+             foldl (infer_factors sg extra_fs)
+               (fs, map (apsome FFix) fs' ~~ args)))
+    | (Some f, (Var ((name, _), _), [])) =>
+        overwrite (fs, (name, if_none
+          (apsome (mg_factor f) (assoc (fs, name))) f))
+    | (None, _) => fs
+    | _ => err "Illegal term")
+      handle Factors => err "Product factor mismatch in"
+  end;
+
 fun string_of_factors p ps = if p mem ps then
     "(" ^ string_of_factors (1::p) ps ^ ", " ^ string_of_factors (2::p) ps ^ ")"
   else "_";
 
+
 (**** check if a term contains only constructor functions ****)
 
 fun is_constrt thy =
@@ -81,9 +154,32 @@
        in merge (map (fn ks => i::ks) is) is end
      else [[]];
 
+fun cprod ([], ys) = []
+  | cprod (x :: xs, ys) = map (pair x) ys @ cprod (xs, ys);
+
+fun cprods xss = foldr (map op :: o cprod) (xss, [[]]);
+
+datatype mode = Mode of (int list option list * int list) * mode option list;
+
+fun modes_of modes t =
+  let
+    fun mk_modes name args = flat
+      (map (fn (m as (iss, is)) => map (Mode o pair m) (cprods (map
+        (fn (None, _) => [None]
+          | (Some js, arg) => map Some
+              (filter (fn Mode ((_, js'), _) => js=js') (modes_of modes arg)))
+                (iss ~~ args)))) (the (assoc (modes, name))))
+
+  in (case strip_comb t of
+      (Const (name, _), args) => mk_modes name args
+    | (Var ((name, _), _), args) => mk_modes name args)
+  end;
+
+datatype indprem = Prem of term list * term | Sidecond of term;
+
 fun select_mode_prem thy modes vs ps =
   find_first (is_some o snd) (ps ~~ map
-    (fn Prem (s, us, args) => find_first (fn is =>
+    (fn Prem (us, t) => find_first (fn Mode ((_, is), _) =>
           let
             val (_, out_ts) = get_args is 1 us;
             val vTs = flat (map term_vTs out_ts);
@@ -92,21 +188,25 @@
           in
             is subset known_args vs 1 us andalso
             forall (is_constrt thy) (snd (get_args is 1 us)) andalso
-            terms_vs args subset vs andalso
+            term_vs t subset vs andalso
             forall is_eqT dupTs
           end)
-            (the (assoc (modes, s)))
-      | Sidecond t => if term_vs t subset vs then Some [] else None) ps);
+            (modes_of modes t)
+      | Sidecond t => if term_vs t subset vs then Some (Mode (([], []), []))
+          else None) ps);
 
-fun check_mode_clause thy arg_vs modes mode (ts, ps) =
+fun check_mode_clause thy arg_vs modes (iss, is) (ts, ps) =
   let
+    val modes' = modes @ mapfilter
+      (fn (_, None) => None | (v, Some js) => Some (v, [([], js)]))
+        (arg_vs ~~ iss);
     fun check_mode_prems vs [] = Some vs
-      | check_mode_prems vs ps = (case select_mode_prem thy modes vs ps of
+      | check_mode_prems vs ps = (case select_mode_prem thy modes' vs ps of
           None => None
         | Some (x, _) => check_mode_prems
-            (case x of Prem (_, us, _) => vs union terms_vs us | _ => vs)
+            (case x of Prem (us, _) => vs union terms_vs us | _ => vs)
             (filter_out (equal x) ps));
-    val (in_ts', _) = get_args mode 1 ts;
+    val (in_ts', _) = get_args is 1 ts;
     val in_ts = filter (is_constrt thy) in_ts';
     val in_vs = terms_vs in_ts;
     val concl_vs = terms_vs ts
@@ -125,9 +225,12 @@
   let val y = f x
   in if x = y then x else fixp f y end;
 
-fun infer_modes thy extra_modes arg_vs preds = fixp (fn modes =>
+fun infer_modes thy extra_modes factors arg_vs preds = fixp (fn modes =>
   map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes)
-    (map (fn (s, (ts, _)::_) => (s, subsets 1 (length ts))) preds);
+    (map (fn (s, (fs, f)) => (s, cprod (cprods (map
+      (fn None => [None]
+        | Some f' => map Some (subsets 1 (length f' + 1))) fs),
+      subsets 1 (length f + 1)))) factors);
 
 (**** code generation ****)
 
@@ -167,17 +270,37 @@
        [Pretty.brk 1, Pretty.str "| _ => ", fail_p, Pretty.str ")"]))
   end;
 
-fun modename thy s mode = space_implode "_"
-  (mk_const_id (sign_of thy) s :: map string_of_int mode);
+fun modename thy s (iss, is) = space_implode "__"
+  (mk_const_id (sign_of thy) s ::
+    map (space_implode "_" o map string_of_int) (mapfilter I iss @ [is]));
 
-fun compile_clause thy gr dep all_vs arg_vs modes mode (ts, ps) =
+fun compile_expr thy dep brack (gr, (None, t)) =
+      apsnd single (invoke_codegen thy dep brack (gr, t))
+  | compile_expr _ _ _ (gr, (Some _, Var ((name, _), _))) =
+      (gr, [Pretty.str name])
+  | compile_expr thy dep brack (gr, (Some (Mode (mode, ms)), t)) =
+      let
+        val (Const (name, _), args) = strip_comb t;
+        val (gr', ps) = foldl_map
+          (compile_expr thy dep true) (gr, ms ~~ args);
+      in (gr', (if brack andalso not (null ps) then
+        single o parens o Pretty.block else I)
+          (flat (separate [Pretty.brk 1]
+            ([Pretty.str (modename thy name mode)] :: ps))))
+      end;
+
+fun compile_clause thy gr dep all_vs arg_vs modes (iss, is) (ts, ps) =
   let
+    val modes' = modes @ mapfilter
+      (fn (_, None) => None | (v, Some js) => Some (v, [([], js)]))
+        (arg_vs ~~ iss);
+
     fun check_constrt ((names, eqs), t) =
       if is_constrt thy t then ((names, eqs), t) else
         let val s = variant names "x";
         in ((s::names, (s, t)::eqs), Var ((s, 0), fastype_of t)) end;
 
-    val (in_ts, out_ts) = get_args mode 1 ts;
+    val (in_ts, out_ts) = get_args is 1 ts;
     val ((all_vs', eqs), in_ts') =
       foldl_map check_constrt ((all_vs, []), in_ts);
 
@@ -200,27 +323,25 @@
       | compile_prems out_ts vs names gr ps =
           let
             val vs' = distinct (flat (vs :: map term_vs out_ts));
-            val Some (p, Some mode') =
-              select_mode_prem thy modes (arg_vs union vs') ps;
+            val Some (p, mode as Some (Mode ((_, js), _))) =
+              select_mode_prem thy modes' (arg_vs union vs') ps;
             val ps' = filter_out (equal p) ps;
           in
             (case p of
-               Prem (s, us, args) =>
+               Prem (us, t) =>
                  let
-                   val (in_ts, out_ts') = get_args mode' 1 us;
+                   val (in_ts, out_ts') = get_args js 1 us;
                    val (gr1, in_ps) = foldl_map
                      (invoke_codegen thy dep false) (gr, in_ts);
-                   val (gr2, arg_ps) = foldl_map
-                     (invoke_codegen thy dep true) (gr1, args);
                    val (nvs, out_ts'') = foldl_map distinct_v
                      ((names, map (fn x => (x, [x])) vs), out_ts);
-                   val (gr3, out_ps) = foldl_map
-                     (invoke_codegen thy dep false) (gr2, out_ts'')
+                   val (gr2, out_ps) = foldl_map
+                     (invoke_codegen thy dep false) (gr1, out_ts'');
+                   val (gr3, ps) = compile_expr thy dep false (gr2, (mode, t));
                    val (gr4, rest) = compile_prems out_ts' vs' (fst nvs) gr3 ps';
                  in
                    (gr4, compile_match (snd nvs) [] out_ps
-                      (Pretty.block (separate (Pretty.brk 1)
-                        (Pretty.str (modename thy s mode') :: arg_ps) @
+                      (Pretty.block (ps @
                          [Pretty.brk 1, mk_tuple in_ps,
                           Pretty.str " :->", Pretty.brk 1, rest]))
                       (Pretty.str "Seq.empty"))
@@ -269,69 +390,91 @@
 
 (**** processing of introduction rules ****)
 
-val string_of_mode = enclose "[" "]" o commas o map string_of_int;
+exception Modes of
+  (string * (int list option list * int list) list) list *
+  (string * (int list list option list * int list list)) list;
+
+fun lookup_modes gr dep = apfst flat (apsnd flat (ListPair.unzip
+  (map ((fn (Some (Modes x), _) => x | _ => ([], [])) o Graph.get_node gr)
+    (Graph.all_preds gr [dep]))));
+
+fun string_of_mode (iss, is) = space_implode " -> " (map
+  (fn None => "X"
+    | Some js => enclose "[" "]" (commas (map string_of_int js)))
+       (iss @ [Some is]));
 
 fun print_modes modes = message ("Inferred modes:\n" ^
   space_implode "\n" (map (fn (s, ms) => s ^ ": " ^ commas (map
     string_of_mode ms)) modes));
 
 fun print_factors factors = message ("Factors:\n" ^
-  space_implode "\n" (map (fn (s, fs) => s ^ ": " ^ string_of_factors [] fs) factors));
-  
-fun get_modes (Some (Modes x), _) = x
-  | get_modes _ = ([], []);
+  space_implode "\n" (map (fn (s, (fs, f)) => s ^ ": " ^
+    space_implode " -> " (map
+      (fn None => "X" | Some f' => string_of_factors [] f')
+        (fs @ [Some f]))) factors));
 
-fun mk_ind_def thy gr dep names intrs =
+fun mk_extra_defs thy gr dep names ts =
+  foldl (fn (gr, name) =>
+    if name mem names then gr
+    else (case get_clauses thy name of
+        None => gr
+      | Some (names, intrs) =>
+          mk_ind_def thy gr dep names intrs))
+            (gr, foldr add_term_consts (ts, []))
+
+and mk_ind_def thy gr dep names intrs =
   let val ids = map (mk_const_id (sign_of thy)) names
   in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ =>
     let
-      fun process_prem factors (gr, t' as _ $ (Const ("op :", _) $ t $ u)) =
-            (case strip_comb u of
-               (Const (name, _), args) =>
-                  (case InductivePackage.get_inductive thy name of
-                     None => (gr, Sidecond t')
-                   | Some ({names=names', ...}, {intrs=intrs', ...}) =>
-                       (if names = names' then gr
-                          else mk_ind_def thy gr (hd ids) names' intrs',
-                        Prem (name, split_prod []
-                          (the (assoc (factors, name))) t, args)))
-             | _ => (gr, Sidecond t'))
-        | process_prem factors (gr, _ $ (Const ("op =", _) $ t $ u)) =
-            (gr, Prem ("eq", [t, u], []))
-        | process_prem factors (gr, _ $ t) = (gr, Sidecond t);
+      fun dest_prem factors (_ $ (Const ("op :", _) $ t $ u)) =
+            (case head_of u of
+               Const (name, _) => Prem (split_prod []
+                 (the (assoc (factors, name))) t, u)
+             | Var ((name, _), _) => Prem (split_prod []
+                 (the (assoc (factors, name))) t, u))
+        | dest_prem factors (_ $ ((eq as Const ("op =", _)) $ t $ u)) =
+            Prem ([t, u], eq)
+        | dest_prem factors (_ $ t) = Sidecond t;
 
-      fun add_clause factors ((clauses, gr), intr) =
+      fun add_clause factors (clauses, intr) =
         let
           val _ $ (_ $ t $ u) = Logic.strip_imp_concl intr;
-          val (Const (name, _), args) = strip_comb u;
-          val (gr', prems) = foldl_map (process_prem factors)
-            (gr, Logic.strip_imp_prems intr);
+          val Const (name, _) = head_of u;
+          val prems = map (dest_prem factors) (Logic.strip_imp_prems intr);
         in
           (overwrite (clauses, (name, if_none (assoc (clauses, name)) [] @
-             [(split_prod [] (the (assoc (factors, name))) t, prems)])), gr')
+             [(split_prod [] (the (assoc (factors, name))) t, prems)])))
         end;
 
-      fun add_prod_factors (fs, x as _ $ (Const ("op :", _) $ t $ u)) =
-            (case strip_comb u of
-               (Const (name, _), _) =>
-                 let val f = prod_factors [] t
-                 in overwrite (fs, (name, f inter if_none (assoc (fs, name)) f)) end
-             | _ => fs)
-        | add_prod_factors (fs, _) = fs;
+      fun add_prod_factors extra_fs (fs, _ $ (Const ("op :", _) $ t $ u)) =
+            infer_factors (sign_of thy) extra_fs
+              (fs, (Some (FVar (prod_factors [] t)), u))
+        | add_prod_factors _ (fs, _) = fs;
 
       val intrs' = map (rename_term o #prop o rep_thm o standard) intrs;
-      val factors = foldl add_prod_factors ([], flat (map (fn t =>
-        Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs'));
-      val (clauses, gr') = foldl (add_clause factors) (([], Graph.add_edge (hd ids, dep)
-        (Graph.new_node (hd ids, (None, "")) gr)), intrs');
       val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs');
       val (_, args) = strip_comb u;
       val arg_vs = flat (map term_vs args);
-      val extra_modes = ("eq", [[1], [2], [1,2]]) :: (flat (map
-        (fst o get_modes o Graph.get_node gr') (Graph.all_preds gr' [hd ids])));
-      val modes = infer_modes thy extra_modes arg_vs clauses;
+      val gr' = mk_extra_defs thy
+        (Graph.add_edge (hd ids, dep)
+          (Graph.new_node (hd ids, (None, "")) gr)) (hd ids) names intrs';
+      val (extra_modes', extra_factors) = lookup_modes gr' (hd ids);
+      val extra_modes =
+        ("op =", [([], [1]), ([], [2]), ([], [1, 2])]) :: extra_modes';
+      val fs = map (apsnd dest_factors)
+        (foldl (add_prod_factors extra_factors) ([], flat (map (fn t =>
+          Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs')));
+      val _ = (case map fst fs \\ names \\ arg_vs of
+          [] => ()
+        | xs => error ("Non-inductive sets: " ^ commas_quote xs));
+      val factors = mapfilter (fn (name, f) =>
+        if name mem arg_vs then None
+        else Some (name, (map (curry assoc fs) arg_vs, f))) fs;
+      val clauses =
+        foldl (add_clause (fs @ map (apsnd snd) extra_factors)) ([], intrs');
+      val modes = infer_modes thy extra_modes factors arg_vs clauses;
+      val _ = print_factors factors;
       val _ = print_modes modes;
-      val _ = print_factors factors;
       val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs') arg_vs
         (modes @ extra_modes) clauses;
     in
@@ -339,31 +482,33 @@
     end      
   end;
 
-fun mk_ind_call thy gr dep t u is_query = (case strip_comb u of
-  (Const (s, _), args) => (case InductivePackage.get_inductive thy s of
+fun mk_ind_call thy gr dep t u is_query = (case head_of u of
+  Const (s, _) => (case get_clauses thy s of
        None => None
-     | Some ({names, ...}, {intrs, ...}) =>
+     | Some (names, intrs) =>
          let
           fun mk_mode (((ts, mode), i), Var _) = ((ts, mode), i+1)
             | mk_mode (((ts, mode), i), Free _) = ((ts, mode), i+1)
             | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1);
 
-           val gr1 = mk_ind_def thy gr dep names intrs;
-           val (modes, factors) = pairself flat (ListPair.unzip
-             (map (get_modes o Graph.get_node gr1) (Graph.all_preds gr1 [dep])));
-           val ts = split_prod [] (the (assoc (factors, s))) t;
-           val (ts', mode) = if is_query then
+           val gr1 = mk_extra_defs thy
+             (mk_ind_def thy gr dep names intrs) dep names [u];
+           val (modes, factors) = lookup_modes gr1 dep;
+           val ts = split_prod [] (snd (the (assoc (factors, s)))) t;
+           val (ts', is) = if is_query then
                fst (foldl mk_mode ((([], []), 1), ts))
              else (ts, 1 upto length ts);
-           val _ = if mode mem the (assoc (modes, s)) then () else
-             error ("No such mode for " ^ s ^ ": " ^ string_of_mode mode);
+           val mode = (case find_first (fn Mode ((_, js), _) => is=js)
+                  (modes_of modes u) of
+                None => error ("No such mode for " ^ s ^ ": " ^
+                  string_of_mode ([], is))
+              | mode => mode);
            val (gr2, in_ps) = foldl_map
              (invoke_codegen thy dep false) (gr1, ts');
-           val (gr3, arg_ps) = foldl_map
-             (invoke_codegen thy dep true) (gr2, args);
+           val (gr3, ps) = compile_expr thy dep false (gr2, (mode, u))
          in
-           Some (gr3, Pretty.block (separate (Pretty.brk 1)
-             (Pretty.str (modename thy s mode) :: arg_ps @ [mk_tuple in_ps])))
+           Some (gr3, Pretty.block
+             (ps @ [Pretty.brk 1, mk_tuple in_ps]))
          end)
   | _ => None);
 
@@ -376,7 +521,10 @@
       mk_ind_call thy gr dep t u true
   | inductive_codegen thy gr dep brack _ = None;
 
-val setup = [add_codegen "inductive" inductive_codegen];
+val setup =
+  [add_codegen "inductive" inductive_codegen,
+   CodegenData.init,
+   add_attribute "ind" add];
 
 end;
 
@@ -394,6 +542,8 @@
 
 fun ?! s = is_some (Seq.pull s);    
 
-fun eq_1 x = Seq.single x;
+fun op__61__1 x = Seq.single x;
 
-val eq_2 = eq_1;
+val op__61__2 = op__61__1;
+
+fun op__61__1_2 (x, y) = ?? (x = y);