src/HOL/Matrix/eq_codegen.ML
changeset 15178 5f621aa35c25
child 15531 08c8dad8e399
equal deleted inserted replaced
15177:e7616269fdca 15178:5f621aa35c25
       
     1 fun inst_cterm inst ct = fst (Drule.dest_equals
       
     2   (Thm.cprop_of (Thm.instantiate inst (reflexive ct))));
       
     3 fun tyinst_cterm tyinst = inst_cterm (tyinst, []);
       
     4 
       
     5 val bla = ref ([] : term list);
       
     6 
       
     7 (******************************************************)
       
     8 (*        Code generator for equational proofs        *)
       
     9 (******************************************************)
       
    10 fun my_mk_meta_eq thm =
       
    11   let
       
    12     val (_, eq) = Thm.dest_comb (cprop_of thm);
       
    13     val (ct, rhs) = Thm.dest_comb eq;
       
    14     val (_, lhs) = Thm.dest_comb ct
       
    15   in Thm.implies_elim (Drule.instantiate' [Some (ctyp_of_term lhs)]
       
    16     [Some lhs, Some rhs] eq_reflection) thm
       
    17   end; 
       
    18 
       
    19 structure SimprocsCodegen =
       
    20 struct
       
    21 
       
    22 val simp_thms = ref ([] : thm list);
       
    23 
       
    24 fun parens b = if b then Pretty.enclose "(" ")" else Pretty.block;
       
    25 
       
    26 fun gen_mk_val f xs ps = Pretty.block ([Pretty.str "val ",
       
    27   f (length xs > 1) (flat
       
    28     (separate [Pretty.str ",", Pretty.brk 1] (map (single o Pretty.str) xs))),
       
    29   Pretty.str " =", Pretty.brk 1] @ ps @ [Pretty.str ";"]);
       
    30 
       
    31 val mk_val = gen_mk_val parens;
       
    32 val mk_vall = gen_mk_val (K (Pretty.enclose "[" "]"));
       
    33 
       
    34 fun rename s = if s mem ThmDatabase.ml_reserved then s ^ "'" else s;
       
    35 
       
    36 fun mk_decomp_name (Var ((s, i), _)) = rename (if i=0 then s else s ^ string_of_int i)
       
    37   | mk_decomp_name (Const (s, _)) = rename (Codegen.mk_id (Sign.base_name s))
       
    38   | mk_decomp_name _ = "ct";
       
    39 
       
(* Generate ML code that takes apart the cterm named v, whose logical
   structure is t, using Thm.dest_comb and Thm.dest_abs.  The goal is to bind
   an ML name to every Var of t, and also to every Const if cn is true.
   State triple (vs, bs, ps):
     vs: ML identifiers already in use
     bs: (term, ML name) bindings found so far
     ps: generated "val" declarations, in order
   Terms that are already bound are skipped.  An application whose
   decomposition yields no new binding is dropped entirely, so no useless
   dest_comb is emitted.  Free terms have no clause and raise Match. *)
fun decomp_term_code cn ((vs, bs, ps), (v, t)) =
  if exists (equal t o fst) bs then (vs, bs, ps)
  else (case t of
      Var _ => (vs, bs @ [(t, v)], ps)
    | Const _ => (vs, if cn then bs @ [(t, v)] else bs, ps)
    | Bound _ => (vs, bs, ps)
    | Abs (s, T, t) =>
      (* v1 names the bound variable, recorded as the binding of Free (s, T);
         v2 names the body *)
      let
        val v1 = variant vs s;
        val v2 = variant (v1 :: vs) (mk_decomp_name t)
      in
        decomp_term_code cn ((v1 :: v2 :: vs,
          bs @ [(Free (s, T), v1)],
          ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_abs", Pretty.brk 1,
            Pretty.str "None", Pretty.brk 1, Pretty.str v]]), (v2, t))
      end
    | t $ u =>
      let
        val v1 = variant vs (mk_decomp_name t);
        val v2 = variant (v1 :: vs) (mk_decomp_name u);
        val (vs', bs', ps') = decomp_term_code cn ((v1 :: v2 :: vs, bs,
          ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_comb", Pretty.brk 1,
            Pretty.str v]]), (v1, t));
        val (vs'', bs'', ps'') = decomp_term_code cn ((vs', bs', ps'), (v2, u))
      in
        (* nothing new was bound below this node: discard the dest_comb *)
        if bs'' = bs then (vs, bs, ps) else (vs'', bs'', ps'')
      end);
       
    67 
       
    68 val strip_tv = implode o tl o explode;
       
    69 
       
    70 fun mk_decomp_tname (TVar ((s, i), _)) =
       
    71       strip_tv ((if i=0 then s else s ^ string_of_int i) ^ "T")
       
    72   | mk_decomp_tname (Type (s, _)) = Codegen.mk_id (Sign.base_name s) ^ "T"
       
    73   | mk_decomp_tname _ = "cT";
       
    74 
       
    75 fun decomp_type_code ((vs, bs, ps), (v, TVar (ixn, _))) =
       
    76       if exists (equal ixn o fst) bs then (vs, bs, ps)
       
    77       else (vs, bs @ [(ixn, v)], ps)
       
    78   | decomp_type_code ((vs, bs, ps), (v, Type (_, Ts))) =
       
    79       let
       
    80         val vs' = variantlist (map mk_decomp_tname Ts, vs);
       
    81         val (vs'', bs', ps') =
       
    82           foldl decomp_type_code ((vs @ vs', bs, ps @
       
    83             [mk_vall vs' [Pretty.str "Thm.dest_ctyp", Pretty.brk 1,
       
    84               Pretty.str v]]), vs' ~~ Ts)
       
    85       in
       
    86         if bs' = bs then (vs, bs, ps) else (vs'', bs', ps')
       
    87       end;
       
    88 
       
    89 fun gen_mk_bindings s dest decomp ((vs, bs, ps), (v, x)) =
       
    90   let
       
    91     val s' = variant vs s;
       
    92     val (vs', bs', ps') = decomp ((s' :: vs, bs, ps @
       
    93       [mk_val [s'] (dest v)]), (s', x))
       
    94   in
       
    95     if bs' = bs then (vs, bs, ps) else (vs', bs', ps')
       
    96   end;
       
    97 
       
    98 val mk_term_bindings = gen_mk_bindings "ct"
       
    99   (fn s => [Pretty.str "cprop_of", Pretty.brk 1, Pretty.str s])
       
   100   (decomp_term_code true);
       
   101 
       
   102 val mk_type_bindings = gen_mk_bindings "cT"
       
   103   (fn s => [Pretty.str "Thm.ctyp_of_term", Pretty.brk 1, Pretty.str s])
       
   104   decomp_type_code;
       
   105 
       
   106 fun pretty_pattern b (Const (s, _)) = Pretty.block [Pretty.str "Const",
       
   107       Pretty.brk 1, Pretty.str ("(\"" ^ s ^ "\", _)")]
       
   108   | pretty_pattern b (t as _ $ _) = parens b
       
   109       (flat (separate [Pretty.str " $", Pretty.brk 1]
       
   110         (map (single o pretty_pattern true) (op :: (strip_comb t)))))
       
   111   | pretty_pattern b _ = Pretty.str "_";
       
   112 
       
   113 fun term_consts' t = foldl_aterms
       
   114   (fn (cs, c as Const _) => c ins cs | (cs, _) => cs) ([], t);
       
   115 
       
   116 fun mk_apps s b p [] = p
       
   117   | mk_apps s b p (q :: qs) = 
       
   118       mk_apps s b (parens (b orelse not (null qs))
       
   119         [Pretty.str s, Pretty.brk 1, p, Pretty.brk 1, q]) qs;
       
   120 
       
   121 fun mk_refleq eq ct = mk_val [eq] [Pretty.str ("Thm.reflexive " ^ ct)];
       
   122 
       
   123 fun mk_tyinst ((s, i), s') =
       
   124   Pretty.block [Pretty.str ("((" ^ quote s ^ ","), Pretty.brk 1,
       
   125     Pretty.str (string_of_int i ^ "),"), Pretty.brk 1,
       
   126     Pretty.str (s' ^ ")")];
       
   127 
       
   128 fun inst_ty b ty_bs t s = (case term_tvars t of
       
   129     [] => Pretty.str s
       
   130   | Ts => parens b [Pretty.str "tyinst_cterm", Pretty.brk 1,
       
   131       Pretty.list "[" "]" (map (fn (ixn, _) => mk_tyinst
       
   132         (ixn, the (assoc (ty_bs, ixn)))) Ts),
       
   133       Pretty.brk 1, Pretty.str s]);
       
   134 
       
(* Generate ML code (a Pretty.T) that rebuilds the term given as second
   component as a cterm, using Thm.capply and Thm.cabs over named atoms.
     b:     parenthesise the result if it is an application/abstraction
     ty_bs: type variable indexname |-> ML name of its ctyp
     ts:    atomic term |-> ML name of its cterm
     xs:    ML names of the enclosing bound variables; Bound i selects the
            i-th element
     vals:  bindings already generated, of the form (t, (("", name), ps)).
            Declarations needed to type-instantiate a constant or a bound
            variable are appended here.
   Returns the extended vals and the code.  An atom missing from ts raises
   Bind. *)
fun mk_cterm_code b ty_bs ts xs (vals, t $ u) =
      let
        val (vals', p1) = mk_cterm_code true ty_bs ts xs (vals, t);
        val (vals'', p2) = mk_cterm_code true ty_bs ts xs (vals', u)
      in
        (vals'', parens b [Pretty.str "Thm.capply", Pretty.brk 1,
          p1, Pretty.brk 1, p2])
      end
  | mk_cterm_code b ty_bs ts xs (vals, Abs (s, T, t)) =
      (* the bound variable is represented by the cterm bound to Free (s, T);
         if its type is polymorphic, its name is rebound to an instantiated
         copy first *)
      let
        val u = Free (s, T);
        val Some s' = assoc (ts, u);
        val p = Pretty.str s';
        val (vals', p') = mk_cterm_code true ty_bs ts (p :: xs)
          (if null (typ_tvars T) then vals
           else vals @ [(u, (("", s'), [mk_val [s'] [inst_ty true ty_bs u s']]))], t)
      in (vals',
        parens b [Pretty.str "Thm.cabs", Pretty.brk 1, p, Pretty.brk 1, p'])
      end
  | mk_cterm_code b ty_bs ts xs (vals, Bound i) = (vals, nth_elem (i, xs))
  | mk_cterm_code b ty_bs ts xs (vals, t) = (case assoc (vals, t) of
        None =>
          (* first use of this atom: polymorphic constants get an
             instantiating declaration *)
          let val Some s = assoc (ts, t)
          in (if is_Const t andalso not (null (term_tvars t)) then
              vals @ [(t, (("", s), [mk_val [s] [inst_ty true ty_bs t s]]))]
            else vals, Pretty.str s)
          end
      | Some ((_, s), _) => (vals, Pretty.str s));
       
   163 
       
   164 fun get_cases sg =
       
   165   Symtab.foldl (fn (tab, (k, {case_rewrites, ...})) => Symtab.update_new
       
   166     ((fst (dest_Const (head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop
       
   167       (prop_of (hd case_rewrites))))))), map my_mk_meta_eq case_rewrites), tab))
       
   168         (Symtab.empty, DatatypePackage.get_datatypes_sg sg);
       
   169 
       
   170 fun decomp_case th =
       
   171   let
       
   172     val (lhs, _) = Logic.dest_equals (prop_of th);
       
   173     val (f, ts) = strip_comb lhs;
       
   174     val (us, u) = split_last ts;
       
   175     val (Const (s, _), vs) = strip_comb u
       
   176   in (us, s, vs, u) end;
       
   177 
       
   178 fun rename vs t =
       
   179   let
       
   180     fun mk_subst ((vs, subs), Var ((s, i), T)) =
       
   181       let val s' = variant vs s
       
   182       in if s = s' then (vs, subs)
       
   183         else (s' :: vs, ((s, i), Var ((s', i), T)) :: subs)
       
   184       end;
       
   185     val (vs', subs) = foldl mk_subst ((vs, []), term_vars t)
       
   186   in (vs', subst_Vars subs t) end;
       
   187 
       
   188 fun is_instance sg t u = t = subst_TVars_Vartab
       
   189   (Type.typ_match (Sign.tsig_of sg) (Vartab.empty,
       
   190     (fastype_of u, fastype_of t))) u handle Type.TYPE_MATCH => false;
       
   191 
       
   192 (*
       
   193 fun lookup sg fs t = apsome snd (Library.find_first
       
   194   (is_instance sg t o fst) fs);
       
   195 *)
       
   196 
       
   197 fun lookup sg fs t = (case Library.find_first (is_instance sg t o fst) fs of
       
   198     None => (bla := (t ins !bla); None)
       
   199   | Some (_, x) => Some x);
       
   200 
       
   201 fun unint sg fs t = forall (is_none o lookup sg fs) (term_consts' t);
       
   202 
       
   203 fun mk_let s i xs ys =
       
   204   Pretty.blk (0, [Pretty.blk (i, separate Pretty.fbrk (Pretty.str s :: xs)),
       
   205     Pretty.fbrk,
       
   206     Pretty.blk (i, ([Pretty.str "in", Pretty.fbrk] @ ys)),
       
   207     Pretty.fbrk, Pretty.str "end"]);
       
   208 
       
(*****************************************************************************)
(* Generate bindings for simplifying term t                                  *)
(* sg:       signature, used when matching constants against fs              *)
(* case_tab: case constant name |-> its case rewrite rules (see get_cases)   *)
(* mkeq:     whether to generate reflexivity theorem for uninterpreted terms *)
(* fs:       interpreted functions                                           *)
(* ts:       atomic terms, paired with the ML names of their cterms          *)
(* ty_bs:    type variable indexname |-> ML name of its ctyp                 *)
(* thm_bs:   term |-> ML name of its cterm, obtained from the theorems       *)
(* vs:       used identifiers                                                *)
(* vals:     list of bindings of the form (t, ((eq, ct), ps)) where          *)
(*           eq: name of equational theorem, or empty if none was generated  *)
(*           ct: name of simplified cterm                                    *)
(*           ps: ML code for creating the above two items                    *)
(* Result:   ((vs', vals'), (eq, ct)) for the term t                         *)
(*****************************************************************************)

fun mk_simpl_code sg case_tab mkeq fs ts ty_bs thm_bs ((vs, vals), t) =
  (case assoc (vals, t) of
    Some ((eq, ct), ps) =>  (* binding already generated *)
      (* add a reflexivity theorem if one is now required but was not made *)
      if mkeq andalso eq="" then
        let val eq' = variant vs "eq"
        in ((eq' :: vs, overwrite (vals,
          (t, ((eq', ct), ps @ [mk_refleq eq' ct])))), (eq', ct))
        end
      else ((vs, vals), (eq, ct))
  | None => (case assoc (ts, t) of
      Some v =>  (* atomic term *)
        let val xs = if not (null (term_tvars t)) andalso is_Const t then
          [mk_val [v] [inst_ty false ty_bs t v]] else []
        in
          if mkeq then
            let val eq = variant vs "eq"
            in ((eq :: vs, vals @
              [(t, ((eq, v), xs @ [mk_refleq eq v]))]), (eq, v))
            end
          else ((vs, if null xs then vals else vals @
            [(t, (("", v), xs))]), ("", v))
        end
    | None =>  (* complex term *)
        (* the head must be a constant; otherwise Bind is raised *)
        let val (f as Const (cname, _), us) = strip_comb t
        in case Symtab.lookup (case_tab, cname) of
            Some cases =>  (* case expression *)
              let
                (* u: the scrutinee; us': the branch arguments *)
                val (us', u) = split_last us;
                (* b: no constant of u is an interpreted function *)
                val b = unint sg fs u;
                val ((vs1, vals1), (eq, ct)) =
                  mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs ((vs, vals), u);
                (* xs: ML names for the cterms of the branch arguments *)
                val xs = variantlist (replicate (length us') "f", vs1);
                val (vals2, ps) = foldl_map
                  (mk_cterm_code false ty_bs ts []) (vals1, us');
                val fvals = map (fn (x, p) => mk_val [x] [p]) (xs ~~ ps);
                val uT = fastype_of u;
                (* bind the ctyps needed to instantiate the case theorems:
                   the type of the first branch argument and the type of
                   the constructor term *)
                val (us'', _, _, u') = decomp_case (hd cases);
                val (vs2, ty_bs', ty_vals) = mk_type_bindings
                  (mk_type_bindings ((vs1 @ xs, [], []),
                    (hd xs, fastype_of (hd us''))), (ct, fastype_of u'));
                val insts1 = map mk_tyinst ty_bs';
                (* bindings beyond the first i are local to one branch *)
                val i = length vals2;

                (* Generate one branch of the ML case expression for the case
                   rewrite eqn, whose branch function is f.  The branch
                   pattern-matches the constructor application, simplifies
                   the beta-reduced right-hand side, and returns the pair
                   (simplified cterm, equation) with the instantiated case
                   rewrite name chained by Thm.transitive. *)
                fun mk_case_code ((vs, vals), (f, (name, eqn))) =
                  let
                    val (fvs, cname, cvs, _) = decomp_case eqn;
                    val Ts = binder_types (fastype_of f);
                    (* fresh Vars for the constructor arguments *)
                    val ys = variantlist (map (fst o fst o dest_Var) cvs, vs);
                    val cvs' = map Var (map (rpair 0) ys ~~ Ts);
                    val rs = cvs' ~~ cvs;
                    val lhs = list_comb (Const (cname, Ts ---> uT), cvs');
                    val rhs = foldl betapply (f, cvs');
                    val (vs', tm_bs, tm_vals) = decomp_term_code false
                      ((vs @ ys, [], []), (ct, lhs));
                    val ((vs'', all_vals), (eq', ct')) = mk_simpl_code sg case_tab
                      false fs (tm_bs @ ts) ty_bs thm_bs ((vs', vals), rhs);
                    (* bindings created for this branch only go into its let *)
                    val (old_vals, eq_vals) = splitAt (i, all_vals);
                    val vs''' = vs @ filter (fn v => exists
                      (fn (_, ((v', _), _)) => v = v') old_vals) (vs'' \\ vs');
                    val insts2 = map (fn (t, s) => Pretty.block [Pretty.str "(",
                      inst_ty false ty_bs' t (the (assoc (thm_bs, t))), Pretty.str ",",
                      Pretty.brk 1, Pretty.str (s ^ ")")]) ((fvs ~~ xs) @
                        (map (fn (v, s) => (the (assoc (rs, v)), s)) tm_bs));
                    val eq'' = if null insts1 andalso null insts2 then Pretty.str name
                      else parens (eq' <> "") [Pretty.str
                          (if null cvs then "Thm.instantiate" else "Drule.instantiate"),
                        Pretty.brk 1, Pretty.str "(", Pretty.list "[" "]" insts1,
                        Pretty.str ",", Pretty.brk 1, Pretty.list "[" "]" insts2,
                        Pretty.str ")", Pretty.brk 1, Pretty.str name];
                    val eq''' = if eq' = "" then eq'' else
                      Pretty.block [Pretty.str "Thm.transitive", Pretty.brk 1,
                        eq'', Pretty.brk 1, Pretty.str eq']
                  in
                    ((vs''', old_vals), Pretty.block [pretty_pattern false lhs,
                      Pretty.str " =>",
                      Pretty.brk 1, mk_let "let" 2 (tm_vals @ flat (map (snd o snd) eq_vals))
                        [Pretty.str ("(" ^ ct' ^ ","), Pretty.brk 1, eq''', Pretty.str ")"]])
                  end;

                val case_names = map (fn i => Sign.base_name cname ^ "_" ^
                  string_of_int i) (1 upto length cases);
                val ((vs3, vals3), case_ps) = foldl_map mk_case_code
                  ((vs2, vals2), us' ~~ (case_names ~~ cases));
                val eq' = variant vs3 "eq";
                val ct' = variant (eq' :: vs3) "ct";
                val eq'' = variant (eq' :: ct' :: vs3) "eq";
                val case_vals =
                  fvals @ ty_vals @
                  [mk_val [ct', eq'] ([Pretty.str "(case", Pretty.brk 1,
                    Pretty.str ("term_of " ^ ct ^ " of"), Pretty.brk 1] @
                    flat (separate [Pretty.brk 1, Pretty.str "| "]
                      (map single case_ps)) @ [Pretty.str ")"])]
              in
                (* if b holds, the generated case expression gives the result
                   directly.  Otherwise the scrutinee's equation eq is lifted
                   through the case application (Thm.combination) and chained
                   with the branch equation (Thm.transitive). *)
                if b then
                  ((eq' :: ct' :: vs3, vals3 @
                     [(t, ((eq', ct'), case_vals))]), (eq', ct'))
                else
                  let val ((vs4, vals4), (_, ctcase)) = mk_simpl_code sg case_tab false
                    fs ts ty_bs thm_bs ((eq' :: eq'' :: ct' :: vs3, vals3), f)
                  in
                    ((vs4, vals4 @ [(t, ((eq'', ct'), case_vals @
                       [mk_val [eq''] [Pretty.str "Thm.transitive", Pretty.brk 1,
                          Pretty.str "(Thm.combination", Pretty.brk 1,
                          Pretty.str "(Thm.reflexive", Pretty.brk 1,
                          mk_apps "Thm.capply" true (Pretty.str ctcase)
                            (map Pretty.str xs),
                          Pretty.str ")", Pretty.brk 1, Pretty.str (eq ^ ")"),
                          Pretty.brk 1, Pretty.str eq']]))]), (eq'', ct'))
                  end
              end

          | None =>
            (* ordinary application f us: simplify the arguments.  The head
               is also simplified unless it is interpreted and all arguments
               are uninterpreted. *)
            let
              val b = forall (unint sg fs) us;
              val (q, eqs) = foldl_map
                (mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs) ((vs, vals), us);
              val ((vs', vals'), (eqf, ctf)) = if is_some (lookup sg fs f) andalso b
                then (q, ("", ""))
                else mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs (q, f);
              val ct = variant vs' "ct";
              val eq = variant (ct :: vs') "eq";
              val ctv = mk_val [ct] [mk_apps "Thm.capply" false
                (Pretty.str ctf) (map (Pretty.str o snd) eqs)];
              fun combp b = mk_apps "Thm.combination" b
                (Pretty.str eqf) (map (Pretty.str o fst) eqs)
            in
              case (lookup sg fs f, b) of
                (None, true) =>  (* completely uninterpreted *)
                  if mkeq then ((ct :: eq :: vs', vals' @
                    [(t, ((eq, ct), [ctv, mk_refleq eq ct]))]), (eq, ct))
                  else ((ct :: vs', vals' @ [(t, (("", ct), [ctv]))]), ("", ct))
              | (None, false) =>  (* function uninterpreted *)
                  ((eq :: ct :: vs', vals' @
                     [(t, ((eq, ct), [ctv, mk_val [eq] [combp false]]))]), (eq, ct))
              | (Some (s, _, _), true) =>  (* arguments uninterpreted *)
                  (* call the generated ML function s on the argument cterms *)
                  ((eq :: ct :: vs', vals' @
                     [(t, ((eq, ct), [mk_val [ct, eq] (separate (Pretty.brk 1)
                       (Pretty.str s :: map (Pretty.str o snd) eqs))]))]), (eq, ct))
              | (Some (s, _, _), false) =>  (* function and arguments interpreted *)
                  let val eq' = variant (eq :: ct :: vs') "eq"
                  in ((eq' :: eq :: ct :: vs', vals' @ [(t, ((eq', ct),
                    [mk_val [ct, eq] (separate (Pretty.brk 1)
                       (Pretty.str s :: map (Pretty.str o snd) eqs)),
                     mk_val [eq'] [Pretty.str "Thm.transitive", Pretty.brk 1,
                       combp true, Pretty.brk 1, Pretty.str eq]]))]), (eq', ct))
                  end
            end
        end));
       
   369 
       
   370 fun lhs_of thm = fst (Logic.dest_equals (prop_of thm));
       
   371 fun rhs_of thm = snd (Logic.dest_equals (prop_of thm));
       
   372 
       
(* Generate one  local ... in fun ... and ... end  declaration for the
   group fs' of functions, which are emitted together as one
   fun ... and ... group.
     fs:  all interpreted functions, as (head term, (name, d, eqns)) pairs
     fs': the group to emit, in the same form
   eqns is a list of (theorem name, meta-equality).  Each generated function
   takes one cterm per argument of its equations' left-hand side.  It
   dispatches by an ML case on term_of of its arguments and returns the pair
   (simplified cterm, equation theorem).  When d is set, a "handle Match"
   default returns the reflexivity theorem of the unsimplified
   application. *)
fun mk_funs_code sg case_tab fs fs' =
  let
    (* for every case constant occurring in the equations: the name of its
       ".cases" theorem list, and names <case>_1, <case>_2, ... for its
       rules *)
    val case_thms = mapfilter (fn s => (case Symtab.lookup (case_tab, s) of
        None => None
      | Some thms => Some (unsuffix "_case" (Sign.base_name s) ^ ".cases",
          map (fn i => Sign.base_name s ^ "_" ^ string_of_int i)
            (1 upto length thms) ~~ thms)))
      (foldr add_term_consts (map (prop_of o snd)
        (flat (map (#3 o snd) fs')), []));
    val case_vals = map (fn (s, cs) => mk_vall (map fst cs)
      [Pretty.str "map my_mk_meta_eq", Pretty.brk 1,
       Pretty.str ("(thms \"" ^ s ^ "\")")]) case_thms;
    (* bind ML names to the Vars and Consts of all equations and case rules *)
    val (vs, thm_bs, thm_vals) = foldl mk_term_bindings (([], [], []),
      flat (map (map (apsnd prop_of) o #3 o snd) fs') @
      map (apsnd prop_of) (flat (map snd case_thms)));

    (* Code for one function; prfx is the keyword "fun " or "and ". *)
    fun mk_fun_code (prfx, (fname, d, eqns)) =
      let
        val (f, ts) = strip_comb (lhs_of (snd (hd eqns)));
        (* one fresh argument name per argument of the first equation *)
        val args = variantlist (replicate (length ts) "ct", vs);
        val (vs', ty_bs, ty_vals) = foldl mk_type_bindings
          ((vs @ args, [], []), args ~~ map fastype_of ts);
        val insts1 = map mk_tyinst ty_bs;

        (* One clause  pattern => let ... in (ct, eq) end  for the equation
           eqn, whose theorem is referred to by the ML name name. *)
        fun mk_eqn_code (name, eqn) =
          let
            val (_, argts) = strip_comb (lhs_of eqn);
            val (vs'', tm_bs, tm_vals) = foldl (decomp_term_code false)
              ((vs', [], []), args ~~ argts);
            val ((vs''', eq_vals), (eq, ct)) = mk_simpl_code sg case_tab false fs
              (tm_bs @ filter_out (is_Var o fst) thm_bs) ty_bs thm_bs
              ((vs'', []), rhs_of eqn);
            val insts2 = map (fn (t, s) => Pretty.block [Pretty.str "(",
              inst_ty false ty_bs t (the (assoc (thm_bs, t))), Pretty.str ",", Pretty.brk 1,
              Pretty.str (s ^ ")")]) tm_bs
            (* instantiate the equation at the matched argument cterms *)
            val eq' = if null insts1 andalso null insts2 then Pretty.str name
              else parens (eq <> "") [Pretty.str "Thm.instantiate",
                Pretty.brk 1, Pretty.str "(", Pretty.list "[" "]" insts1,
                Pretty.str ",", Pretty.brk 1, Pretty.list "[" "]" insts2,
                Pretty.str ")", Pretty.brk 1, Pretty.str name];
            (* chain with the equation for the simplified right-hand side *)
            val eq'' = if eq = "" then eq' else
              Pretty.block [Pretty.str "Thm.transitive", Pretty.brk 1,
                eq', Pretty.brk 1, Pretty.str eq]
          in
            Pretty.block [parens (length argts > 1)
                (Pretty.commas (map (pretty_pattern false) argts)),
              Pretty.str " =>",
              Pretty.brk 1, mk_let "let" 2 (ty_vals @ tm_vals @ flat (map (snd o snd) eq_vals))
                [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1, eq'', Pretty.str ")"]]
          end;

        (* fallback when no equation matches: the reflexivity theorem of
           f applied to the argument cterms *)
        val default = if d then
            let
              val Some s = assoc (thm_bs, f);
              val ct = variant vs' "ct"
            in [Pretty.brk 1, Pretty.str "handle", Pretty.brk 1,
              Pretty.str "Match =>", Pretty.brk 1, mk_let "let" 2
                (ty_vals @ (if null (term_tvars f) then [] else
                   [mk_val [s] [inst_ty false ty_bs f s]]) @
                 [mk_val [ct] [mk_apps "Thm.capply" false (Pretty.str s)
                    (map Pretty.str args)]])
                [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1,
                 Pretty.str "Thm.reflexive", Pretty.brk 1, Pretty.str (ct ^ ")")]]
            end
          else []
      in
        ("and ", Pretty.block (separate (Pretty.brk 1)
            (Pretty.str (prfx ^ fname) :: map Pretty.str args) @
          [Pretty.str " =", Pretty.brk 1, Pretty.str "(case", Pretty.brk 1,
           Pretty.list "(" ")" (map (fn s => Pretty.str ("term_of " ^ s)) args),
           Pretty.str " of", Pretty.brk 1] @
          flat (separate [Pretty.brk 1, Pretty.str "| "]
            (map (single o mk_eqn_code) eqns)) @ [Pretty.str ")"] @ default))
      end;

    (* the first function is introduced by "fun ", the rest by "and " *)
    val (_, decls) = foldl_map mk_fun_code ("fun ", map snd fs')
  in
    mk_let "local" 2 (case_vals @ thm_vals) (separate Pretty.fbrk decls)
  end;
       
   452 
       
(* Generate ML code for all function groups in eqns.  Each element has the
   form (ML function name, d, [(theorem name, equation), ...]).
   A dependency graph is built: the dummy node "" points to every function,
   and each function points to every function whose head constant occurs in
   its equations.  The graph is split into strongly connected components,
   and one mk_funs_code declaration is emitted per component, separated by
   ";".  Components are ordered by their position in keys, presumably so
   that callees are declared before their callers; this depends on the
   order returned by Graph.all_succs (TODO confirm). *)
fun mk_simprocs_code sg eqns =
  let
    val case_tab = get_cases sg;
    (* head of the left-hand side of a meta-equality *)
    fun get_head th = head_of (fst (Logic.dest_equals (prop_of th)));
    (* key every function by the head of its first equation *)
    fun attach_term (x as (_, _, (_, th) :: _)) = (get_head th, x);
    val eqns' = map attach_term eqns;
    fun mk_node (s, _, (_, th) :: _) = (s, get_head th);
    (* edges from s to the functions whose heads occur in its equations *)
    fun mk_edges (s, _, ths) = map (pair s) (distinct
      (mapfilter (fn t => apsome #1 (lookup sg eqns' t))
        (flat (map (term_consts' o prop_of o snd) ths))));
    val gr = foldr (uncurry Graph.add_edge)
      (map (pair "" o #1) eqns @ flat (map mk_edges eqns),
       foldr (uncurry Graph.new_node)
         (("", Bound 0) :: map mk_node eqns, Graph.empty));
    val keys = rev (Graph.all_succs gr [""] \ "");
    fun gr_ord (x :: _, y :: _) =
      int_ord (find_index (equal x) keys, find_index (equal y) keys);
    val scc = map (fn xs => filter (fn (_, (s, _, _)) => s mem xs) eqns')
      (sort gr_ord (Graph.strong_conn gr \ [""]));
  in
    flat (separate [Pretty.str ";", Pretty.fbrk, Pretty.str " ", Pretty.fbrk]
      (map (fn eqns'' => [mk_funs_code sg case_tab eqns' eqns'']) scc)) @
    [Pretty.str ";", Pretty.fbrk]
  end;
       
   477 
       
   478 fun use_simprocs_code sg eqns =
       
   479   let
       
   480     fun attach_name (i, x) = (i+1, ("simp_thm_" ^ string_of_int i, x));
       
   481     fun attach_names (i, (s, b, eqs)) =
       
   482       let val (i', eqs') = foldl_map attach_name (i, eqs)
       
   483       in (i', (s, b, eqs')) end;
       
   484     val (_, eqns') = foldl_map attach_names (1, eqns);
       
   485     val (names, thms) = split_list (flat (map #3 eqns'));
       
   486     val s = setmp print_mode [] Pretty.string_of
       
   487       (mk_let "local" 2 [mk_vall names [Pretty.str "!SimprocsCodegen.simp_thms"]]
       
   488         (mk_simprocs_code sg eqns'))
       
   489   in
       
   490     (simp_thms := thms; use_text Context.ml_output false s)
       
   491   end;
       
   492 
       
   493 end;