src/HOL/Tools/inductive_codegen.ML
changeset 16645 a152d6b21c31
parent 16424 18a07ad8fea8
child 16861 7446b4be013b
--- a/src/HOL/Tools/inductive_codegen.ML	Fri Jul 01 14:11:06 2005 +0200
+++ b/src/HOL/Tools/inductive_codegen.ML	Fri Jul 01 14:13:40 2005 +0200
@@ -7,7 +7,7 @@
 
 signature INDUCTIVE_CODEGEN =
 sig
-  val add : theory attribute
+  val add : string option -> theory attribute
   val setup : (theory -> theory) list
 end;
 
@@ -22,18 +22,20 @@
 (struct
   val name = "HOL/inductive_codegen";
   type T =
-    {intros : thm list Symtab.table,
+    {intros : (thm * string) list Symtab.table,
      graph : unit Graph.T,
-     eqns : thm list Symtab.table};
+     eqns : (thm * string) list Symtab.table};
   val empty =
     {intros = Symtab.empty, graph = Graph.empty, eqns = Symtab.empty};
   val copy = I;
   val extend = I;
   fun merge _ ({intros=intros1, graph=graph1, eqns=eqns1},
     {intros=intros2, graph=graph2, eqns=eqns2}) =
-    {intros = Symtab.merge_multi Drule.eq_thm_prop (intros1, intros2),
+    {intros = Symtab.merge_multi (Drule.eq_thm_prop o pairself fst)
+       (intros1, intros2),
      graph = Graph.merge (K true) (graph1, graph2),
-     eqns = Symtab.merge_multi Drule.eq_thm_prop (eqns1, eqns2)};
+     eqns = Symtab.merge_multi (Drule.eq_thm_prop o pairself fst)
+       (eqns1, eqns2)};
   fun print _ _ = ();
 end);
 
@@ -43,15 +45,19 @@
 
 fun add_node (g, x) = Graph.new_node (x, ()) g handle Graph.DUP _ => g;
 
-fun add (p as (thy, thm)) =
-  let val {intros, graph, eqns} = CodegenData.get thy;
+fun add optmod (p as (thy, thm)) =
+  let
+    val {intros, graph, eqns} = CodegenData.get thy;
+    fun thyname_of s = (case optmod of
+      NONE => thyname_of_const s thy | SOME s => s);
   in (case concl_of thm of
       _ $ (Const ("op :", _) $ _ $ t) => (case head_of t of
         Const (s, _) =>
           let val cs = foldr add_term_consts [] (prems_of thm)
           in (CodegenData.put
             {intros = Symtab.update ((s,
-               getOpt (Symtab.lookup (intros, s), []) @ [thm]), intros),
+               getOpt (Symtab.lookup (intros, s), []) @
+                 [(thm, thyname_of s)]), intros),
              graph = foldr (uncurry (Graph.add_edge o pair s))
                (Library.foldl add_node (graph, s :: cs)) cs,
              eqns = eqns} thy, thm)
@@ -61,7 +67,8 @@
         Const (s, _) =>
           (CodegenData.put {intros = intros, graph = graph,
              eqns = Symtab.update ((s,
-               getOpt (Symtab.lookup (eqns, s), []) @ [thm]), eqns)} thy, thm)
+               getOpt (Symtab.lookup (eqns, s), []) @
+                 [(thm, thyname_of s)]), eqns)} thy, thm)
       | _ => (warn thm; p))
     | _ => (warn thm; p))
   end;
@@ -71,13 +78,17 @@
   in case Symtab.lookup (intros, s) of
       NONE => (case InductivePackage.get_inductive thy s of
         NONE => NONE
-      | SOME ({names, ...}, {intrs, ...}) => SOME (names, preprocess thy intrs))
+      | SOME ({names, ...}, {intrs, ...}) =>
+          SOME (names, thyname_of_const s thy,
+            preprocess thy intrs))
     | SOME _ =>
-        let val SOME names = find_first
-          (fn xs => s mem xs) (Graph.strong_conn graph)
-        in SOME (names, preprocess thy
-          (List.concat (map (fn s => valOf (Symtab.lookup (intros, s))) names)))
-        end
+        let
+          val SOME names = find_first
+            (fn xs => s mem xs) (Graph.strong_conn graph);
+          val intrs = List.concat (map
+            (fn s => valOf (Symtab.lookup (intros, s))) names);
+          val (_, (_, thyname)) = split_last intrs
+        in SOME (names, thyname, preprocess thy (map fst intrs)) end
   end;
 
 
@@ -364,26 +375,30 @@
         else [Pretty.str ")"])))
   end;
 
-fun modename thy s (iss, is) = space_implode "__"
-  (mk_const_id (sign_of thy) s ::
+fun strip_spaces s = implode (fst (take_suffix (equal " ") (explode s)));
+
+fun modename thy thyname thyname' s (iss, is) = space_implode "__"
+  (mk_const_id (sign_of thy) thyname thyname' (strip_spaces s) ::
     map (space_implode "_" o map string_of_int) (List.mapPartial I iss @ [is]));
 
-fun compile_expr thy dep brack (gr, (NONE, t)) =
-      apsnd single (invoke_codegen thy dep brack (gr, t))
-  | compile_expr _ _ _ (gr, (SOME _, Var ((name, _), _))) =
+fun compile_expr thy defs dep thyname brack thynames (gr, (NONE, t)) =
+      apsnd single (invoke_codegen thy defs dep thyname brack (gr, t))
+  | compile_expr _ _ _ _ _ _ (gr, (SOME _, Var ((name, _), _))) =
       (gr, [Pretty.str name])
-  | compile_expr thy dep brack (gr, (SOME (Mode (mode, ms)), t)) =
+  | compile_expr thy defs dep thyname brack thynames (gr, (SOME (Mode (mode, ms)), t)) =
       let
         val (Const (name, _), args) = strip_comb t;
         val (gr', ps) = foldl_map
-          (compile_expr thy dep true) (gr, ms ~~ args);
+          (compile_expr thy defs dep thyname true thynames) (gr, ms ~~ args);
       in (gr', (if brack andalso not (null ps) then
         single o parens o Pretty.block else I)
           (List.concat (separate [Pretty.brk 1]
-            ([Pretty.str (modename thy name mode)] :: ps))))
+            ([Pretty.str (modename thy thyname
+                (if name = "op =" then ""
+                 else the (assoc (thynames, name))) name mode)] :: ps))))
       end;
 
-fun compile_clause thy gr dep all_vs arg_vs modes (iss, is) (ts, ps) =
+fun compile_clause thy defs gr dep thyname all_vs arg_vs modes thynames (iss, is) (ts, ps) =
   let
     val modes' = modes @ List.mapPartial
       (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
@@ -396,7 +411,7 @@
 
     fun compile_eq (gr, (s, t)) =
       apsnd (Pretty.block o cons (Pretty.str (s ^ " = ")) o single)
-        (invoke_codegen thy dep false (gr, t));
+        (invoke_codegen thy defs dep thyname false (gr, t));
 
     val (in_ts, out_ts) = get_args is 1 ts;
     val ((all_vs', eqs), in_ts') =
@@ -409,14 +424,14 @@
     fun compile_prems out_ts' vs names gr [] =
           let
             val (gr2, out_ps) = foldl_map
-              (invoke_codegen thy dep false) (gr, out_ts);
+              (invoke_codegen thy defs dep thyname false) (gr, out_ts);
             val (gr3, eq_ps) = foldl_map compile_eq (gr2, eqs);
             val ((names', eqs'), out_ts'') =
               foldl_map check_constrt ((names, []), out_ts');
             val (nvs, out_ts''') = foldl_map distinct_v
               ((names', map (fn x => (x, [x])) vs), out_ts'');
             val (gr4, out_ps') = foldl_map
-              (invoke_codegen thy dep false) (gr3, out_ts''');
+              (invoke_codegen thy defs dep thyname false) (gr3, out_ts''');
             val (gr5, eq_ps') = foldl_map compile_eq (gr4, eqs')
           in
             (gr5, compile_match (snd nvs) (eq_ps @ eq_ps') out_ps'
@@ -434,7 +449,7 @@
             val (nvs, out_ts'') = foldl_map distinct_v
               ((names', map (fn x => (x, [x])) vs), out_ts');
             val (gr0, out_ps) = foldl_map
-              (invoke_codegen thy dep false) (gr, out_ts'');
+              (invoke_codegen thy defs dep thyname false) (gr, out_ts'');
             val (gr1, eq_ps) = foldl_map compile_eq (gr0, eqs)
           in
             (case p of
@@ -442,14 +457,15 @@
                  let
                    val (in_ts, out_ts''') = get_args js 1 us;
                    val (gr2, in_ps) = foldl_map
-                     (invoke_codegen thy dep false) (gr1, in_ts);
+                     (invoke_codegen thy defs dep thyname false) (gr1, in_ts);
                    val (gr3, ps) = if is_ind t then
                        apsnd (fn ps => ps @ [Pretty.brk 1, mk_tuple in_ps])
-                         (compile_expr thy dep false (gr2, (mode, t)))
+                         (compile_expr thy defs dep thyname false thynames
+                           (gr2, (mode, t)))
                      else
                        apsnd (fn p => conv_ntuple us t
                          [Pretty.str "Seq.of_list", Pretty.brk 1, p])
-                           (invoke_codegen thy dep true (gr2, t));
+                           (invoke_codegen thy defs dep thyname true (gr2, t));
                    val (gr4, rest) = compile_prems out_ts''' vs' (fst nvs) gr3 ps';
                  in
                    (gr4, compile_match (snd nvs) eq_ps out_ps
@@ -459,7 +475,7 @@
                  end
              | Sidecond t =>
                  let
-                   val (gr2, side_p) = invoke_codegen thy dep true (gr1, t);
+                   val (gr2, side_p) = invoke_codegen thy defs dep thyname true (gr1, t);
                    val (gr3, rest) = compile_prems [] vs' (fst nvs) gr2 ps';
                  in
                    (gr3, compile_match (snd nvs) eq_ps out_ps
@@ -474,22 +490,23 @@
     (gr', Pretty.block [Pretty.str "Seq.single inp :->", Pretty.brk 1, prem_p])
   end;
 
-fun compile_pred thy gr dep prfx all_vs arg_vs modes s cls mode =
-  let val (gr', cl_ps) = foldl_map (fn (gr, cl) =>
-    compile_clause thy gr dep all_vs arg_vs modes mode cl) (gr, cls)
+fun compile_pred thy defs gr dep thyname prfx all_vs arg_vs modes thynames s cls mode =
+  let val (gr', cl_ps) = foldl_map (fn (gr, cl) => compile_clause thy defs
+    gr dep thyname all_vs arg_vs modes thynames mode cl) (gr, cls)
   in
     ((gr', "and "), Pretty.block
       ([Pretty.block (separate (Pretty.brk 1)
-         (Pretty.str (prfx ^ modename thy s mode) :: map Pretty.str arg_vs) @
+         (Pretty.str (prfx ^ modename thy thyname thyname s mode) ::
+           map Pretty.str arg_vs) @
          [Pretty.str " inp ="]),
         Pretty.brk 1] @
        List.concat (separate [Pretty.str " ++", Pretty.brk 1] (map single cl_ps))))
   end;
 
-fun compile_preds thy gr dep all_vs arg_vs modes preds =
+fun compile_preds thy defs gr dep thyname all_vs arg_vs modes thynames preds =
   let val ((gr', _), prs) = foldl_map (fn ((gr, prfx), (s, cls)) =>
-    foldl_map (fn ((gr', prfx'), mode) =>
-      compile_pred thy gr' dep prfx' all_vs arg_vs modes s cls mode)
+    foldl_map (fn ((gr', prfx'), mode) => compile_pred thy defs gr'
+      dep thyname prfx' all_vs arg_vs modes thynames s cls mode)
         ((gr, prfx), valOf (assoc (modes, s)))) ((gr, "fun "), preds)
   in
     (gr', space_implode "\n\n" (map Pretty.string_of (List.concat prs)) ^ ";\n\n")
@@ -499,11 +516,13 @@
 
 exception Modes of
   (string * (int list option list * int list) list) list *
-  (string * (int list list option list * int list list)) list;
+  (string * (int list list option list * int list list)) list *
+  string;
 
-fun lookup_modes gr dep = apfst List.concat (apsnd List.concat (ListPair.unzip
-  (map ((fn (SOME (Modes x), _) => x | _ => ([], [])) o Graph.get_node gr)
-    (Graph.all_preds gr [dep]))));
+fun lookup_modes gr dep = foldl (fn ((xs, ys, z), (xss, yss, zss)) =>
+    (xss @ xs, yss @ ys, zss @ map (rpair z o fst) ys)) ([], [], [])
+  (map ((fn (SOME (Modes x), _, _) => x | _ => ([], [], "")) o Graph.get_node gr)
+    (Graph.all_preds gr [dep]));
 
 fun print_factors factors = message ("Factors:\n" ^
   space_implode "\n" (map (fn (s, (fs, f)) => s ^ ": " ^
@@ -518,18 +537,17 @@
       NONE => xs
     | SOME xs' => xs inter xs') :: constrain cs ys;
 
-fun mk_extra_defs thy gr dep names ts =
+fun mk_extra_defs thy defs gr dep names ts =
   Library.foldl (fn (gr, name) =>
     if name mem names then gr
     else (case get_clauses thy name of
         NONE => gr
-      | SOME (names, intrs) =>
-          mk_ind_def thy gr dep names [] [] (prep_intrs intrs)))
+      | SOME (names, thyname, intrs) =>
+          mk_ind_def thy defs gr dep names thyname [] [] (prep_intrs intrs)))
             (gr, foldr add_term_consts [] ts)
 
-and mk_ind_def thy gr dep names modecs factorcs intrs =
-  let val ids = map (mk_const_id (sign_of thy)) names
-  in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ =>
+and mk_ind_def thy defs gr dep names thyname modecs factorcs intrs =
+  Graph.add_edge (hd names, dep) gr handle Graph.UNDEF _ =>
     let
       val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs);
       val (_, args) = strip_comb u;
@@ -565,10 +583,10 @@
             else fs
         | add_prod_factors _ (fs, _) = fs;
 
-      val gr' = mk_extra_defs thy
-        (Graph.add_edge (hd ids, dep)
-          (Graph.new_node (hd ids, (NONE, "")) gr)) (hd ids) names intrs;
-      val (extra_modes, extra_factors) = lookup_modes gr' (hd ids);
+      val gr' = mk_extra_defs thy defs
+        (Graph.add_edge (hd names, dep)
+          (Graph.new_node (hd names, (NONE, "", "")) gr)) (hd names) names intrs;
+      val (extra_modes, extra_factors, extra_thynames) = lookup_modes gr' (hd names);
       val fs = constrain factorcs (map (apsnd dest_factors)
         (Library.foldl (add_prod_factors extra_factors) ([], List.concat (map (fn t =>
           Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs))));
@@ -581,38 +599,40 @@
         (infer_modes thy extra_modes factors arg_vs clauses);
       val _ = print_factors factors;
       val _ = print_modes modes;
-      val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs) arg_vs
-        (modes @ extra_modes) clauses;
+      val (gr'', s) = compile_preds thy defs gr' (hd names) thyname (terms_vs intrs)
+        arg_vs (modes @ extra_modes)
+        (map (rpair thyname o fst) factors @ extra_thynames) clauses;
     in
-      (Graph.map_node (hd ids) (K (SOME (Modes (modes, factors)), s)) gr'')
-    end      
-  end;
+      (Graph.map_node (hd names)
+        (K (SOME (Modes (modes, factors, thyname)), thyname, s)) gr'')
+    end;
 
 fun find_mode s u modes is = (case find_first (fn Mode ((_, js), _) => is=js)
   (modes_of modes u handle Option => []) of
      NONE => error ("No such mode for " ^ s ^ ": " ^ string_of_mode ([], is))
    | mode => mode);
 
-fun mk_ind_call thy gr dep t u is_query = (case head_of u of
+fun mk_ind_call thy defs gr dep thyname t u is_query = (case head_of u of
   Const (s, T) => (case (get_clauses thy s, get_assoc_code thy s T) of
        (NONE, _) => NONE
-     | (SOME (names, intrs), NONE) =>
+     | (SOME (names, thyname', intrs), NONE) =>
          let
           fun mk_mode (((ts, mode), i), Const ("dummy_pattern", _)) =
                 ((ts, mode), i+1)
             | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1);
 
-           val gr1 = mk_extra_defs thy
-             (mk_ind_def thy gr dep names [] [] (prep_intrs intrs)) dep names [u];
-           val (modes, factors) = lookup_modes gr1 dep;
+           val gr1 = mk_extra_defs thy defs
+             (mk_ind_def thy defs gr dep names thyname' [] [] (prep_intrs intrs)) dep names [u];
+           val (modes, factors, thynames) = lookup_modes gr1 dep;
            val ts = split_prod [] (snd (valOf (assoc (factors, s)))) t;
            val (ts', is) = if is_query then
                fst (Library.foldl mk_mode ((([], []), 1), ts))
              else (ts, 1 upto length ts);
            val mode = find_mode s u modes is;
            val (gr2, in_ps) = foldl_map
-             (invoke_codegen thy dep false) (gr1, ts');
-           val (gr3, ps) = compile_expr thy dep false (gr2, (mode, u))
+             (invoke_codegen thy defs dep thyname false) (gr1, ts');
+           val (gr3, ps) =
+             compile_expr thy defs dep thyname false thynames (gr2, (mode, u))
          in
            SOME (gr3, Pretty.block
              (ps @ [Pretty.brk 1, mk_tuple in_ps]))
@@ -620,16 +640,17 @@
      | _ => NONE)
   | _ => NONE);
 
-fun list_of_indset thy gr dep brack u = (case head_of u of
+fun list_of_indset thy defs gr dep thyname brack u = (case head_of u of
   Const (s, T) => (case (get_clauses thy s, get_assoc_code thy s T) of
        (NONE, _) => NONE
-     | (SOME (names, intrs), NONE) =>
+     | (SOME (names, thyname', intrs), NONE) =>
          let
-           val gr1 = mk_extra_defs thy
-             (mk_ind_def thy gr dep names [] [] (prep_intrs intrs)) dep names [u];
-           val (modes, factors) = lookup_modes gr1 dep;
+           val gr1 = mk_extra_defs thy defs
+             (mk_ind_def thy defs gr dep names thyname' [] [] (prep_intrs intrs)) dep names [u];
+           val (modes, factors, thynames) = lookup_modes gr1 dep;
            val mode = find_mode s u modes [];
-           val (gr2, ps) = compile_expr thy dep false (gr1, (mode, u))
+           val (gr2, ps) =
+             compile_expr thy defs dep thyname false thynames (gr1, (mode, u))
          in
            SOME (gr2, (if brack then parens else I)
              (Pretty.block ([Pretty.str "Seq.list_of", Pretty.brk 1,
@@ -650,58 +671,63 @@
   in
     rename_term
       (Logic.list_implies (prems_of eqn, HOLogic.mk_Trueprop (HOLogic.mk_mem
-        (foldr1 HOLogic.mk_prod (ts @ [u]), Const (Sign.base_name s ^ "_aux",
+        (foldr1 HOLogic.mk_prod (ts @ [u]), Const (s ^ " ",
           HOLogic.mk_setT (foldr1 HOLogic.mk_prodT (Ts @ [U])))))))
   end;
 
-fun mk_fun thy name eqns dep gr = 
-  let val id = mk_const_id (sign_of thy) name
-  in Graph.add_edge (id, dep) gr handle Graph.UNDEF _ =>
+fun mk_fun thy defs name eqns dep thyname thyname' gr =
+  let
+    val fun_id = mk_const_id (sign_of thy) thyname' thyname' name;
+    val call_id = mk_const_id (sign_of thy) thyname thyname' name
+  in (Graph.add_edge (name, dep) gr handle Graph.UNDEF _ =>
     let
       val clauses = map clause_of_eqn eqns;
-      val pname = mk_const_id (sign_of thy) (Sign.base_name name ^ "_aux");
+      val pname = name ^ " ";
       val arity = length (snd (strip_comb (fst (HOLogic.dest_eq
         (HOLogic.dest_Trueprop (concl_of (hd eqns)))))));
       val mode = 1 upto arity;
       val vars = map (fn i => Pretty.str ("x" ^ string_of_int i)) mode;
       val s = Pretty.string_of (Pretty.block
-        [mk_app false (Pretty.str ("fun " ^ id)) vars, Pretty.str " =",
+        [mk_app false (Pretty.str ("fun " ^ fun_id)) vars, Pretty.str " =",
          Pretty.brk 1, Pretty.str "Seq.hd", Pretty.brk 1,
-         parens (Pretty.block [Pretty.str (modename thy pname ([], mode)),
+         parens (Pretty.block [Pretty.str (modename thy thyname' thyname' pname ([], mode)),
            Pretty.brk 1, mk_tuple vars])]) ^ ";\n\n";
-      val gr' = mk_ind_def thy (Graph.add_edge (id, dep)
-        (Graph.new_node (id, (NONE, s)) gr)) id [pname]
+      val gr' = mk_ind_def thy defs (Graph.add_edge (name, dep)
+        (Graph.new_node (name, (NONE, thyname', s)) gr)) name [pname] thyname'
         [(pname, [([], mode)])]
         [(pname, map (fn i => replicate i 2) (0 upto arity-1))]
         clauses;
-      val (modes, _) = lookup_modes gr' dep;
+      val (modes, _, _) = lookup_modes gr' dep;
       val _ = find_mode pname (snd (HOLogic.dest_mem (HOLogic.dest_Trueprop
         (Logic.strip_imp_concl (hd clauses))))) modes mode
-    in gr' end
+    in gr' end, call_id)
   end;
 
-fun inductive_codegen thy gr dep brack (Const ("op :", _) $ t $ u) =
-      ((case mk_ind_call thy gr dep (Term.no_dummy_patterns t) u false of
+fun inductive_codegen thy defs gr dep thyname brack (Const ("op :", _) $ t $ u) =
+      ((case mk_ind_call thy defs gr dep thyname (Term.no_dummy_patterns t) u false of
          NONE => NONE
        | SOME (gr', call_p) => SOME (gr', (if brack then parens else I)
            (Pretty.block [Pretty.str "?! (", call_p, Pretty.str ")"])))
-        handle TERM _ => mk_ind_call thy gr dep t u true)
-  | inductive_codegen thy gr dep brack t = (case strip_comb t of
+        handle TERM _ => mk_ind_call thy defs gr dep thyname t u true)
+  | inductive_codegen thy defs gr dep thyname brack t = (case strip_comb t of
       (Const (s, _), ts) => (case Symtab.lookup (#eqns (CodegenData.get thy), s) of
-        NONE => list_of_indset thy gr dep brack t
+        NONE => list_of_indset thy defs gr dep thyname brack t
       | SOME eqns =>
           let
-            val gr' = mk_fun thy s (preprocess thy eqns) dep gr
-            val (gr'', ps) = foldl_map (invoke_codegen thy dep true) (gr', ts);
-          in SOME (gr'', mk_app brack (Pretty.str (mk_const_id
-            (sign_of thy) s)) ps)
+            val (_, (_, thyname')) = split_last eqns;
+            val (gr', id) = mk_fun thy defs s (preprocess thy (map fst eqns))
+              dep thyname thyname' gr;
+            val (gr'', ps) = foldl_map
+              (invoke_codegen thy defs dep thyname true) (gr', ts);
+          in SOME (gr'', mk_app brack (Pretty.str id) ps)
           end)
     | _ => NONE);
 
 val setup =
   [add_codegen "inductive" inductive_codegen,
    CodegenData.init,
-   add_attribute "ind" (Scan.succeed add)];
+   add_attribute "ind"
+     (Scan.option (Args.$$$ "target" |-- Args.colon |-- Args.name) >> add)];
 
 end;