src/HOL/Tools/inductive_codegen.ML
changeset 28537 1e84256d1a8a
parent 27809 a1e409db516b
child 29265 5b4247055bd7
--- a/src/HOL/Tools/inductive_codegen.ML	Thu Oct 09 08:47:26 2008 +0200
+++ b/src/HOL/Tools/inductive_codegen.ML	Thu Oct 09 08:47:27 2008 +0200
@@ -302,11 +302,11 @@
   end;
 
 fun modename module s (iss, is) gr =
-  let val (gr', id) = if s = "op =" then (gr, ("", "equal"))
+  let val (id, gr') = if s = @{const_name "op ="} then (("", "equal"), gr)
     else mk_const_id module s gr
-  in (gr', space_implode "__"
+  in (space_implode "__"
     (mk_qual_id module id ::
-      map (space_implode "_" o map string_of_int) (List.mapPartial I iss @ [is])))
+      map (space_implode "_" o map string_of_int) (List.mapPartial I iss @ [is])), gr')
   end;
 
 fun mk_funcomp brack s k p = (if brack then parens else I)
@@ -314,34 +314,34 @@
     separate (Pretty.brk 1) (str s :: replicate k (str "|> ???")) @
     (if k = 0 then [] else [str ")"])), Pretty.brk 1, p]);
 
-fun compile_expr thy defs dep module brack modes (gr, (NONE, t)) =
-      apsnd single (invoke_codegen thy defs dep module brack (gr, t))
-  | compile_expr _ _ _ _ _ _ (gr, (SOME _, Var ((name, _), _))) =
-      (gr, [str name])
-  | compile_expr thy defs dep module brack modes (gr, (SOME (Mode (mode, _, ms)), t)) =
+fun compile_expr thy defs dep module brack modes (NONE, t) gr =
+      apfst single (invoke_codegen thy defs dep module brack t gr)
+  | compile_expr _ _ _ _ _ _ (SOME _, Var ((name, _), _)) gr =
+      ([str name], gr)
+  | compile_expr thy defs dep module brack modes (SOME (Mode (mode, _, ms)), t) gr =
       (case strip_comb t of
          (Const (name, _), args) =>
-           if name = "op =" orelse AList.defined op = modes name then
+           if name = @{const_name "op ="} orelse AList.defined op = modes name then
              let
                val (args1, args2) = chop (length ms) args;
-               val (gr', (ps, mode_id)) = foldl_map
-                   (compile_expr thy defs dep module true modes) (gr, ms ~~ args1) |>>>
-                 modename module name mode;
-               val (gr'', ps') = (case mode of
-                   ([], []) => (gr', [str "()"])
-                 | _ => foldl_map
-                     (invoke_codegen thy defs dep module true) (gr', args2))
-             in (gr', (if brack andalso not (null ps andalso null ps') then
+               val ((ps, mode_id), gr') = gr |> fold_map
+                   (compile_expr thy defs dep module true modes) (ms ~~ args1)
+                   ||>> modename module name mode;
+               val (ps', gr'') = (case mode of
+                   ([], []) => ([str "()"], gr')
+                 | _ => fold_map
+                     (invoke_codegen thy defs dep module true) args2 gr')
+             in ((if brack andalso not (null ps andalso null ps') then
                single o parens o Pretty.block else I)
                  (List.concat (separate [Pretty.brk 1]
-                   ([str mode_id] :: ps @ map single ps'))))
+                   ([str mode_id] :: ps @ map single ps'))), gr')
              end
-           else apsnd (single o mk_funcomp brack "??" (length (binder_types (fastype_of t))))
-             (invoke_codegen thy defs dep module true (gr, t))
-       | _ => apsnd (single o mk_funcomp brack "??" (length (binder_types (fastype_of t))))
-           (invoke_codegen thy defs dep module true (gr, t)));
+           else apfst (single o mk_funcomp brack "??" (length (binder_types (fastype_of t))))
+             (invoke_codegen thy defs dep module true t gr)
+       | _ => apfst (single o mk_funcomp brack "??" (length (binder_types (fastype_of t))))
+           (invoke_codegen thy defs dep module true t gr));
 
-fun compile_clause thy defs gr dep module all_vs arg_vs modes (iss, is) (ts, ps) inp =
+fun compile_clause thy defs dep module all_vs arg_vs modes (iss, is) (ts, ps) inp gr =
   let
     val modes' = modes @ List.mapPartial
       (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
@@ -352,32 +352,32 @@
         let val s = Name.variant names "x";
         in ((s::names, (s, t)::eqs), Var ((s, 0), fastype_of t)) end;
 
-    fun compile_eq (gr, (s, t)) =
-      apsnd (Pretty.block o cons (str (s ^ " = ")) o single)
-        (invoke_codegen thy defs dep module false (gr, t));
+    fun compile_eq (s, t) gr =
+      apfst (Pretty.block o cons (str (s ^ " = ")) o single)
+        (invoke_codegen thy defs dep module false t gr);
 
     val (in_ts, out_ts) = get_args is 1 ts;
     val ((all_vs', eqs), in_ts') =
       foldl_map check_constrt ((all_vs, []), in_ts);
 
-    fun compile_prems out_ts' vs names gr [] =
+    fun compile_prems out_ts' vs names [] gr =
           let
-            val (gr2, out_ps) = foldl_map
-              (invoke_codegen thy defs dep module false) (gr, out_ts);
-            val (gr3, eq_ps) = foldl_map compile_eq (gr2, eqs);
+            val (out_ps, gr2) = fold_map
+              (invoke_codegen thy defs dep module false) out_ts gr;
+            val (eq_ps, gr3) = fold_map compile_eq eqs gr2;
             val ((names', eqs'), out_ts'') =
               foldl_map check_constrt ((names, []), out_ts');
             val (nvs, out_ts''') = foldl_map distinct_v
               ((names', map (fn x => (x, [x])) vs), out_ts'');
-            val (gr4, out_ps') = foldl_map
-              (invoke_codegen thy defs dep module false) (gr3, out_ts''');
-            val (gr5, eq_ps') = foldl_map compile_eq (gr4, eqs')
+            val (out_ps', gr4) = fold_map
+              (invoke_codegen thy defs dep module false) (out_ts''') gr3;
+            val (eq_ps', gr5) = fold_map compile_eq eqs' gr4;
           in
-            (gr5, compile_match (snd nvs) (eq_ps @ eq_ps') out_ps'
+            (compile_match (snd nvs) (eq_ps @ eq_ps') out_ps'
               (Pretty.block [str "DSeq.single", Pretty.brk 1, mk_tuple out_ps])
-              (exists (not o is_exhaustive) out_ts'''))
+              (exists (not o is_exhaustive) out_ts'''), gr5)
           end
-      | compile_prems out_ts vs names gr ps =
+      | compile_prems out_ts vs names ps gr =
           let
             val vs' = distinct (op =) (List.concat (vs :: map term_vs out_ts));
             val SOME (p, mode as SOME (Mode (_, js, _))) =
@@ -387,77 +387,77 @@
               foldl_map check_constrt ((names, []), out_ts);
             val (nvs, out_ts'') = foldl_map distinct_v
               ((names', map (fn x => (x, [x])) vs), out_ts');
-            val (gr0, out_ps) = foldl_map
-              (invoke_codegen thy defs dep module false) (gr, out_ts'');
-            val (gr1, eq_ps) = foldl_map compile_eq (gr0, eqs)
+            val (out_ps, gr0) = fold_map
+              (invoke_codegen thy defs dep module false) out_ts'' gr;
+            val (eq_ps, gr1) = fold_map compile_eq eqs gr0;
           in
             (case p of
                Prem (us, t, is_set) =>
                  let
                    val (in_ts, out_ts''') = get_args js 1 us;
-                   val (gr2, in_ps) = foldl_map
-                     (invoke_codegen thy defs dep module true) (gr1, in_ts);
-                   val (gr3, ps) =
+                   val (in_ps, gr2) = fold_map
+                     (invoke_codegen thy defs dep module true) in_ts gr1;
+                   val (ps, gr3) =
                      if not is_set then
-                       apsnd (fn ps => ps @
+                       apfst (fn ps => ps @
                            (if null in_ps then [] else [Pretty.brk 1]) @
                            separate (Pretty.brk 1) in_ps)
                          (compile_expr thy defs dep module false modes
-                           (gr2, (mode, t)))
+                           (mode, t) gr2)
                      else
-                       apsnd (fn p => [str "DSeq.of_list", Pretty.brk 1, p])
-                           (invoke_codegen thy defs dep module true (gr2, t));
-                   val (gr4, rest) = compile_prems out_ts''' vs' (fst nvs) gr3 ps';
+                       apfst (fn p => [str "DSeq.of_list", Pretty.brk 1, p])
+                           (invoke_codegen thy defs dep module true t gr2);
+                   val (rest, gr4) = compile_prems out_ts''' vs' (fst nvs) ps' gr3;
                  in
-                   (gr4, compile_match (snd nvs) eq_ps out_ps
+                   (compile_match (snd nvs) eq_ps out_ps
                       (Pretty.block (ps @
                          [str " :->", Pretty.brk 1, rest]))
-                      (exists (not o is_exhaustive) out_ts''))
+                      (exists (not o is_exhaustive) out_ts''), gr4)
                  end
              | Sidecond t =>
                  let
-                   val (gr2, side_p) = invoke_codegen thy defs dep module true (gr1, t);
-                   val (gr3, rest) = compile_prems [] vs' (fst nvs) gr2 ps';
+                   val (side_p, gr2) = invoke_codegen thy defs dep module true t gr1;
+                   val (rest, gr3) = compile_prems [] vs' (fst nvs) ps' gr2;
                  in
-                   (gr3, compile_match (snd nvs) eq_ps out_ps
+                   (compile_match (snd nvs) eq_ps out_ps
                       (Pretty.block [str "?? ", side_p,
                         str " :->", Pretty.brk 1, rest])
-                      (exists (not o is_exhaustive) out_ts''))
+                      (exists (not o is_exhaustive) out_ts''), gr3)
                  end)
           end;
 
-    val (gr', prem_p) = compile_prems in_ts' arg_vs all_vs' gr ps;
+    val (prem_p, gr') = compile_prems in_ts' arg_vs all_vs' ps gr ;
   in
-    (gr', Pretty.block [str "DSeq.single", Pretty.brk 1, inp,
-       str " :->", Pretty.brk 1, prem_p])
+    (Pretty.block [str "DSeq.single", Pretty.brk 1, inp,
+       str " :->", Pretty.brk 1, prem_p], gr')
   end;
 
-fun compile_pred thy defs gr dep module prfx all_vs arg_vs modes s cls mode =
+fun compile_pred thy defs dep module prfx all_vs arg_vs modes s cls mode gr =
   let
     val xs = map str (Name.variant_list arg_vs
       (map (fn i => "x" ^ string_of_int i) (snd mode)));
-    val (gr', (cl_ps, mode_id)) =
-      foldl_map (fn (gr, cl) => compile_clause thy defs
-        gr dep module all_vs arg_vs modes mode cl (mk_tuple xs)) (gr, cls) |>>>
+    val ((cl_ps, mode_id), gr') = gr |>
+      fold_map (fn cl => compile_clause thy defs
+        dep module all_vs arg_vs modes mode cl (mk_tuple xs)) cls ||>>
       modename module s mode
   in
-    ((gr', "and "), Pretty.block
+    (Pretty.block
       ([Pretty.block (separate (Pretty.brk 1)
          (str (prfx ^ mode_id) ::
            map str arg_vs @
            (case mode of ([], []) => [str "()"] | _ => xs)) @
          [str " ="]),
         Pretty.brk 1] @
-       List.concat (separate [str " ++", Pretty.brk 1] (map single cl_ps))))
+       List.concat (separate [str " ++", Pretty.brk 1] (map single cl_ps))), (gr', "and "))
   end;
 
-fun compile_preds thy defs gr dep module all_vs arg_vs modes preds =
-  let val ((gr', _), prs) = foldl_map (fn ((gr, prfx), (s, cls)) =>
-    foldl_map (fn ((gr', prfx'), mode) => compile_pred thy defs gr'
-      dep module prfx' all_vs arg_vs modes s cls mode)
-        ((gr, prfx), ((the o AList.lookup (op =) modes) s))) ((gr, "fun "), preds)
+fun compile_preds thy defs dep module all_vs arg_vs modes preds gr =
+  let val (prs, (gr', _)) = fold_map (fn (s, cls) =>
+    fold_map (fn mode => fn (gr', prfx') => compile_pred thy defs
+      dep module prfx' all_vs arg_vs modes s cls mode gr')
+        (((the o AList.lookup (op =) modes) s))) preds (gr, "fun ")
   in
-    (gr', space_implode "\n\n" (map string_of (List.concat prs)) ^ ";\n\n")
+    (space_implode "\n\n" (map string_of (List.concat prs)) ^ ";\n\n", gr')
   end;
 
 (**** processing of introduction rules ****)
@@ -543,8 +543,8 @@
         (infer_modes thy extra_modes arities arg_vs clauses);
       val _ = print_arities arities;
       val _ = print_modes modes;
-      val (gr'', s) = compile_preds thy defs gr' (hd names) module (terms_vs intrs)
-        arg_vs (modes @ extra_modes) clauses;
+      val (s, gr'') = compile_preds thy defs (hd names) module (terms_vs intrs)
+        arg_vs (modes @ extra_modes) clauses gr';
     in
       (map_node (hd names)
         (K (SOME (Modes (modes, arities)), module, s)) gr'')
@@ -556,7 +556,7 @@
        ("No such mode for " ^ s ^ ": " ^ string_of_mode ([], is))
    | mode => mode);
 
-fun mk_ind_call thy defs gr dep module is_query s T ts names thyname k intrs =
+fun mk_ind_call thy defs dep module is_query s T ts names thyname k intrs gr =
   let
     val (ts1, ts2) = chop k ts;
     val u = list_comb (Const (s, T), ts1);
@@ -574,13 +574,11 @@
         fst (Library.foldl mk_mode ((([], []), 1), ts2))
       else (ts2, 1 upto length (binder_types T) - k);
     val mode = find_mode gr1 dep s u modes is;
-    val (gr2, in_ps) = foldl_map
-      (invoke_codegen thy defs dep module true) (gr1, ts');
-    val (gr3, ps) =
-      compile_expr thy defs dep module false modes (gr2, (mode, u))
+    val (in_ps, gr2) = fold_map (invoke_codegen thy defs dep module true) ts' gr1;
+    val (ps, gr3) = compile_expr thy defs dep module false modes (mode, u) gr2;
   in
-    (gr3, Pretty.block (ps @ (if null in_ps then [] else [Pretty.brk 1]) @
-       separate (Pretty.brk 1) in_ps))
+    (Pretty.block (ps @ (if null in_ps then [] else [Pretty.brk 1]) @
+       separate (Pretty.brk 1) in_ps), gr3)
   end;
 
 fun clause_of_eqn eqn =
@@ -602,8 +600,8 @@
       val arity = length (snd (strip_comb (fst (HOLogic.dest_eq
         (HOLogic.dest_Trueprop (concl_of (hd eqns)))))));
       val mode = 1 upto arity;
-      val (gr', (fun_id, mode_id)) = gr |>
-        mk_const_id module' name |>>>
+      val ((fun_id, mode_id), gr') = gr |>
+        mk_const_id module' name ||>>
         modename module' pname ([], mode);
       val vars = map (fn i => str ("x" ^ string_of_int i)) mode;
       val s = string_of (Pretty.block
@@ -617,9 +615,9 @@
       val (modes, _) = lookup_modes gr'' dep;
       val _ = find_mode gr'' dep pname (head_of (HOLogic.dest_Trueprop
         (Logic.strip_imp_concl (hd clauses)))) modes mode
-    in (gr'', mk_qual_id module fun_id) end
+    in (mk_qual_id module fun_id, gr'') end
   | SOME _ =>
-      (add_edge (name, dep) gr, mk_qual_id module (get_const_id name gr));
+      (mk_qual_id module (get_const_id gr name), add_edge (name, dep) gr);
 
 (* convert n-tuple to nested pairs *)
 
@@ -644,7 +642,7 @@
     else p
   end;
 
-fun inductive_codegen thy defs gr dep module brack t = (case strip_comb t of
+fun inductive_codegen thy defs dep module brack t gr  = (case strip_comb t of
     (Const ("Collect", _), [u]) =>
       let val (r, Ts, fs) = HOLogic.strip_split u
       in case strip_comb r of
@@ -661,11 +659,11 @@
                   if null (duplicates op = ots) andalso
                     no_loose ts1 andalso no_loose its
                   then
-                    let val (gr', call_p) = mk_ind_call thy defs gr dep module true
-                      s T (ts1 @ ts2') names thyname k intrs
-                    in SOME (gr', (if brack then parens else I) (Pretty.block
+                    let val (call_p, gr') = mk_ind_call thy defs dep module true
+                      s T (ts1 @ ts2') names thyname k intrs gr 
+                    in SOME ((if brack then parens else I) (Pretty.block
                       [str "DSeq.list_of", Pretty.brk 1, str "(",
-                       conv_ntuple fs ots call_p, str ")"]))
+                       conv_ntuple fs ots call_p, str ")"]), gr')
                     end
                   else NONE
                 end
@@ -676,21 +674,21 @@
       NONE => (case (get_clauses thy s, get_assoc_code thy (s, T)) of
         (SOME (names, thyname, k, intrs), NONE) =>
           if length ts < k then NONE else SOME
-            (let val (gr', call_p) = mk_ind_call thy defs gr dep module false
-               s T (map Term.no_dummy_patterns ts) names thyname k intrs
-             in (gr', mk_funcomp brack "?!"
-               (length (binder_types T) - length ts) (parens call_p))
-             end handle TERM _ => mk_ind_call thy defs gr dep module true
-               s T ts names thyname k intrs)
+            (let val (call_p, gr') = mk_ind_call thy defs dep module false
+               s T (map Term.no_dummy_patterns ts) names thyname k intrs gr
+             in (mk_funcomp brack "?!"
+               (length (binder_types T) - length ts) (parens call_p), gr')
+             end handle TERM _ => mk_ind_call thy defs dep module true
+               s T ts names thyname k intrs gr )
       | _ => NONE)
     | SOME eqns =>
         let
           val (_, thyname) :: _ = eqns;
-          val (gr', id) = mk_fun thy defs s (preprocess thy (map fst (rev eqns)))
+          val (id, gr') = mk_fun thy defs s (preprocess thy (map fst (rev eqns)))
             dep module (if_library thyname module) gr;
-          val (gr'', ps) = foldl_map
-            (invoke_codegen thy defs dep module true) (gr', ts);
-        in SOME (gr'', mk_app brack (str id) ps)
+          val (ps, gr'') = fold_map
+            (invoke_codegen thy defs dep module true) ts gr';
+        in SOME (mk_app brack (str id) ps, gr'')
         end)
   | _ => NONE);