src/HOL/Matrix/eq_codegen.ML
changeset 15178 5f621aa35c25
child 15531 08c8dad8e399
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Matrix/eq_codegen.ML	Fri Sep 03 17:10:36 2004 +0200
     1.3 @@ -0,0 +1,493 @@
     1.4 +fun inst_cterm inst ct = fst (Drule.dest_equals
     1.5 +  (Thm.cprop_of (Thm.instantiate inst (reflexive ct))));
     1.6 +fun tyinst_cterm tyinst = inst_cterm (tyinst, []);
     1.7 +
     1.8 +val bla = ref ([] : term list);
     1.9 +
    1.10 +(******************************************************)
    1.11 +(*        Code generator for equational proofs        *)
    1.12 +(******************************************************)
    1.13 +fun my_mk_meta_eq thm =
    1.14 +  let
    1.15 +    val (_, eq) = Thm.dest_comb (cprop_of thm);
    1.16 +    val (ct, rhs) = Thm.dest_comb eq;
    1.17 +    val (_, lhs) = Thm.dest_comb ct
    1.18 +  in Thm.implies_elim (Drule.instantiate' [Some (ctyp_of_term lhs)]
    1.19 +    [Some lhs, Some rhs] eq_reflection) thm
    1.20 +  end; 
    1.21 +
    1.22 +structure SimprocsCodegen =
    1.23 +struct
    1.24 +
    1.25 +val simp_thms = ref ([] : thm list);
    1.26 +
    1.27 +fun parens b = if b then Pretty.enclose "(" ")" else Pretty.block;
    1.28 +
    1.29 +fun gen_mk_val f xs ps = Pretty.block ([Pretty.str "val ",
    1.30 +  f (length xs > 1) (flat
    1.31 +    (separate [Pretty.str ",", Pretty.brk 1] (map (single o Pretty.str) xs))),
    1.32 +  Pretty.str " =", Pretty.brk 1] @ ps @ [Pretty.str ";"]);
    1.33 +
    1.34 +val mk_val = gen_mk_val parens;
    1.35 +val mk_vall = gen_mk_val (K (Pretty.enclose "[" "]"));
    1.36 +
    1.37 +fun rename s = if s mem ThmDatabase.ml_reserved then s ^ "'" else s;
    1.38 +
    1.39 +fun mk_decomp_name (Var ((s, i), _)) = rename (if i=0 then s else s ^ string_of_int i)
    1.40 +  | mk_decomp_name (Const (s, _)) = rename (Codegen.mk_id (Sign.base_name s))
    1.41 +  | mk_decomp_name _ = "ct";
    1.42 +
    1.43 +fun decomp_term_code cn ((vs, bs, ps), (v, t)) =
    1.44 +  if exists (equal t o fst) bs then (vs, bs, ps)
    1.45 +  else (case t of
    1.46 +      Var _ => (vs, bs @ [(t, v)], ps)
    1.47 +    | Const _ => (vs, if cn then bs @ [(t, v)] else bs, ps)
    1.48 +    | Bound _ => (vs, bs, ps)
    1.49 +    | Abs (s, T, t) =>
    1.50 +      let
    1.51 +        val v1 = variant vs s;
    1.52 +        val v2 = variant (v1 :: vs) (mk_decomp_name t)
    1.53 +      in
    1.54 +        decomp_term_code cn ((v1 :: v2 :: vs,
    1.55 +          bs @ [(Free (s, T), v1)],
    1.56 +          ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_abs", Pretty.brk 1,
    1.57 +            Pretty.str "None", Pretty.brk 1, Pretty.str v]]), (v2, t))
    1.58 +      end
    1.59 +    | t $ u =>
    1.60 +      let
    1.61 +        val v1 = variant vs (mk_decomp_name t);
    1.62 +        val v2 = variant (v1 :: vs) (mk_decomp_name u);
    1.63 +        val (vs', bs', ps') = decomp_term_code cn ((v1 :: v2 :: vs, bs,
    1.64 +          ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_comb", Pretty.brk 1,
    1.65 +            Pretty.str v]]), (v1, t));
    1.66 +        val (vs'', bs'', ps'') = decomp_term_code cn ((vs', bs', ps'), (v2, u))
    1.67 +      in
    1.68 +        if bs'' = bs then (vs, bs, ps) else (vs'', bs'', ps'')
    1.69 +      end);
    1.70 +
    1.71 +val strip_tv = implode o tl o explode;
    1.72 +
    1.73 +fun mk_decomp_tname (TVar ((s, i), _)) =
    1.74 +      strip_tv ((if i=0 then s else s ^ string_of_int i) ^ "T")
    1.75 +  | mk_decomp_tname (Type (s, _)) = Codegen.mk_id (Sign.base_name s) ^ "T"
    1.76 +  | mk_decomp_tname _ = "cT";
    1.77 +
    1.78 +fun decomp_type_code ((vs, bs, ps), (v, TVar (ixn, _))) =
    1.79 +      if exists (equal ixn o fst) bs then (vs, bs, ps)
    1.80 +      else (vs, bs @ [(ixn, v)], ps)
    1.81 +  | decomp_type_code ((vs, bs, ps), (v, Type (_, Ts))) =
    1.82 +      let
    1.83 +        val vs' = variantlist (map mk_decomp_tname Ts, vs);
    1.84 +        val (vs'', bs', ps') =
    1.85 +          foldl decomp_type_code ((vs @ vs', bs, ps @
    1.86 +            [mk_vall vs' [Pretty.str "Thm.dest_ctyp", Pretty.brk 1,
    1.87 +              Pretty.str v]]), vs' ~~ Ts)
    1.88 +      in
    1.89 +        if bs' = bs then (vs, bs, ps) else (vs'', bs', ps')
    1.90 +      end;
    1.91 +
    1.92 +fun gen_mk_bindings s dest decomp ((vs, bs, ps), (v, x)) =
    1.93 +  let
    1.94 +    val s' = variant vs s;
    1.95 +    val (vs', bs', ps') = decomp ((s' :: vs, bs, ps @
    1.96 +      [mk_val [s'] (dest v)]), (s', x))
    1.97 +  in
    1.98 +    if bs' = bs then (vs, bs, ps) else (vs', bs', ps')
    1.99 +  end;
   1.100 +
   1.101 +val mk_term_bindings = gen_mk_bindings "ct"
   1.102 +  (fn s => [Pretty.str "cprop_of", Pretty.brk 1, Pretty.str s])
   1.103 +  (decomp_term_code true);
   1.104 +
   1.105 +val mk_type_bindings = gen_mk_bindings "cT"
   1.106 +  (fn s => [Pretty.str "Thm.ctyp_of_term", Pretty.brk 1, Pretty.str s])
   1.107 +  decomp_type_code;
   1.108 +
   1.109 +fun pretty_pattern b (Const (s, _)) = Pretty.block [Pretty.str "Const",
   1.110 +      Pretty.brk 1, Pretty.str ("(\"" ^ s ^ "\", _)")]
   1.111 +  | pretty_pattern b (t as _ $ _) = parens b
   1.112 +      (flat (separate [Pretty.str " $", Pretty.brk 1]
   1.113 +        (map (single o pretty_pattern true) (op :: (strip_comb t)))))
   1.114 +  | pretty_pattern b _ = Pretty.str "_";
   1.115 +
   1.116 +fun term_consts' t = foldl_aterms
   1.117 +  (fn (cs, c as Const _) => c ins cs | (cs, _) => cs) ([], t);
   1.118 +
   1.119 +fun mk_apps s b p [] = p
   1.120 +  | mk_apps s b p (q :: qs) = 
   1.121 +      mk_apps s b (parens (b orelse not (null qs))
   1.122 +        [Pretty.str s, Pretty.brk 1, p, Pretty.brk 1, q]) qs;
   1.123 +
   1.124 +fun mk_refleq eq ct = mk_val [eq] [Pretty.str ("Thm.reflexive " ^ ct)];
   1.125 +
   1.126 +fun mk_tyinst ((s, i), s') =
   1.127 +  Pretty.block [Pretty.str ("((" ^ quote s ^ ","), Pretty.brk 1,
   1.128 +    Pretty.str (string_of_int i ^ "),"), Pretty.brk 1,
   1.129 +    Pretty.str (s' ^ ")")];
   1.130 +
   1.131 +fun inst_ty b ty_bs t s = (case term_tvars t of
   1.132 +    [] => Pretty.str s
   1.133 +  | Ts => parens b [Pretty.str "tyinst_cterm", Pretty.brk 1,
   1.134 +      Pretty.list "[" "]" (map (fn (ixn, _) => mk_tyinst
   1.135 +        (ixn, the (assoc (ty_bs, ixn)))) Ts),
   1.136 +      Pretty.brk 1, Pretty.str s]);
   1.137 +
   1.138 +fun mk_cterm_code b ty_bs ts xs (vals, t $ u) =
   1.139 +      let
   1.140 +        val (vals', p1) = mk_cterm_code true ty_bs ts xs (vals, t);
   1.141 +        val (vals'', p2) = mk_cterm_code true ty_bs ts xs (vals', u)
   1.142 +      in
   1.143 +        (vals'', parens b [Pretty.str "Thm.capply", Pretty.brk 1,
   1.144 +          p1, Pretty.brk 1, p2])
   1.145 +      end
   1.146 +  | mk_cterm_code b ty_bs ts xs (vals, Abs (s, T, t)) =
   1.147 +      let
   1.148 +        val u = Free (s, T);
   1.149 +        val Some s' = assoc (ts, u);
   1.150 +        val p = Pretty.str s';
   1.151 +        val (vals', p') = mk_cterm_code true ty_bs ts (p :: xs)
   1.152 +          (if null (typ_tvars T) then vals
   1.153 +           else vals @ [(u, (("", s'), [mk_val [s'] [inst_ty true ty_bs u s']]))], t)
   1.154 +      in (vals',
   1.155 +        parens b [Pretty.str "Thm.cabs", Pretty.brk 1, p, Pretty.brk 1, p'])
   1.156 +      end
   1.157 +  | mk_cterm_code b ty_bs ts xs (vals, Bound i) = (vals, nth_elem (i, xs))
   1.158 +  | mk_cterm_code b ty_bs ts xs (vals, t) = (case assoc (vals, t) of
   1.159 +        None =>
   1.160 +          let val Some s = assoc (ts, t)
   1.161 +          in (if is_Const t andalso not (null (term_tvars t)) then
   1.162 +              vals @ [(t, (("", s), [mk_val [s] [inst_ty true ty_bs t s]]))]
   1.163 +            else vals, Pretty.str s)
   1.164 +          end
   1.165 +      | Some ((_, s), _) => (vals, Pretty.str s));
   1.166 +
   1.167 +fun get_cases sg =
   1.168 +  Symtab.foldl (fn (tab, (k, {case_rewrites, ...})) => Symtab.update_new
   1.169 +    ((fst (dest_Const (head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop
   1.170 +      (prop_of (hd case_rewrites))))))), map my_mk_meta_eq case_rewrites), tab))
   1.171 +        (Symtab.empty, DatatypePackage.get_datatypes_sg sg);
   1.172 +
   1.173 +fun decomp_case th =
   1.174 +  let
   1.175 +    val (lhs, _) = Logic.dest_equals (prop_of th);
   1.176 +    val (f, ts) = strip_comb lhs;
   1.177 +    val (us, u) = split_last ts;
   1.178 +    val (Const (s, _), vs) = strip_comb u
   1.179 +  in (us, s, vs, u) end;
   1.180 +
   1.181 +fun rename vs t =
   1.182 +  let
   1.183 +    fun mk_subst ((vs, subs), Var ((s, i), T)) =
   1.184 +      let val s' = variant vs s
   1.185 +      in if s = s' then (vs, subs)
   1.186 +        else (s' :: vs, ((s, i), Var ((s', i), T)) :: subs)
   1.187 +      end;
   1.188 +    val (vs', subs) = foldl mk_subst ((vs, []), term_vars t)
   1.189 +  in (vs', subst_Vars subs t) end;
   1.190 +
   1.191 +fun is_instance sg t u = t = subst_TVars_Vartab
   1.192 +  (Type.typ_match (Sign.tsig_of sg) (Vartab.empty,
   1.193 +    (fastype_of u, fastype_of t))) u handle Type.TYPE_MATCH => false;
   1.194 +
   1.195 +(*
   1.196 +fun lookup sg fs t = apsome snd (Library.find_first
   1.197 +  (is_instance sg t o fst) fs);
   1.198 +*)
   1.199 +
   1.200 +fun lookup sg fs t = (case Library.find_first (is_instance sg t o fst) fs of
   1.201 +    None => (bla := (t ins !bla); None)
   1.202 +  | Some (_, x) => Some x);
   1.203 +
   1.204 +fun unint sg fs t = forall (is_none o lookup sg fs) (term_consts' t);
   1.205 +
   1.206 +fun mk_let s i xs ys =
   1.207 +  Pretty.blk (0, [Pretty.blk (i, separate Pretty.fbrk (Pretty.str s :: xs)),
   1.208 +    Pretty.fbrk,
   1.209 +    Pretty.blk (i, ([Pretty.str "in", Pretty.fbrk] @ ys)),
   1.210 +    Pretty.fbrk, Pretty.str "end"]);
   1.211 +
   1.212 +(*****************************************************************************)
   1.213 +(* Generate bindings for simplifying term t                                  *)
   1.214 +(* mkeq: whether to generate reflexivity theorem for uninterpreted terms     *)
   1.215 +(* fs:   interpreted functions                                               *)
   1.216 +(* ts:   atomic terms                                                        *)
   1.217 +(* vs:   used identifiers                                                    *)
   1.218 +(* vals: list of bindings of the form ((eq, ct), ps) where                   *)
   1.219 +(*       eq: name of equational theorem                                      *)
   1.220 +(*       ct: name of simplified cterm                                        *)
   1.221 +(*       ps: ML code for creating the above two items                        *)
   1.222 +(*****************************************************************************)
   1.223 +
   1.224 +fun mk_simpl_code sg case_tab mkeq fs ts ty_bs thm_bs ((vs, vals), t) =
   1.225 +  (case assoc (vals, t) of
   1.226 +    Some ((eq, ct), ps) =>  (* binding already generated *) 
   1.227 +      if mkeq andalso eq="" then
   1.228 +        let val eq' = variant vs "eq"
   1.229 +        in ((eq' :: vs, overwrite (vals,
   1.230 +          (t, ((eq', ct), ps @ [mk_refleq eq' ct])))), (eq', ct))
   1.231 +        end
   1.232 +      else ((vs, vals), (eq, ct))
   1.233 +  | None => (case assoc (ts, t) of
   1.234 +      Some v =>  (* atomic term *)
   1.235 +        let val xs = if not (null (term_tvars t)) andalso is_Const t then
   1.236 +          [mk_val [v] [inst_ty false ty_bs t v]] else []
   1.237 +        in
   1.238 +          if mkeq then
   1.239 +            let val eq = variant vs "eq"
   1.240 +            in ((eq :: vs, vals @
   1.241 +              [(t, ((eq, v), xs @ [mk_refleq eq v]))]), (eq, v))
   1.242 +            end
   1.243 +          else ((vs, if null xs then vals else vals @
   1.244 +            [(t, (("", v), xs))]), ("", v))
   1.245 +        end
   1.246 +    | None =>  (* complex term *)
   1.247 +        let val (f as Const (cname, _), us) = strip_comb t
   1.248 +        in case Symtab.lookup (case_tab, cname) of
   1.249 +            Some cases =>  (* case expression *)
   1.250 +              let
   1.251 +                val (us', u) = split_last us;
   1.252 +                val b = unint sg fs u;
   1.253 +                val ((vs1, vals1), (eq, ct)) =
   1.254 +                  mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs ((vs, vals), u);
   1.255 +                val xs = variantlist (replicate (length us') "f", vs1);
   1.256 +                val (vals2, ps) = foldl_map
   1.257 +                  (mk_cterm_code false ty_bs ts []) (vals1, us');
   1.258 +                val fvals = map (fn (x, p) => mk_val [x] [p]) (xs ~~ ps);
   1.259 +                val uT = fastype_of u;
   1.260 +                val (us'', _, _, u') = decomp_case (hd cases);
   1.261 +                val (vs2, ty_bs', ty_vals) = mk_type_bindings
   1.262 +                  (mk_type_bindings ((vs1 @ xs, [], []),
   1.263 +                    (hd xs, fastype_of (hd us''))), (ct, fastype_of u'));
   1.264 +                val insts1 = map mk_tyinst ty_bs';
   1.265 +                val i = length vals2;
   1.266 +   
   1.267 +                fun mk_case_code ((vs, vals), (f, (name, eqn))) =
   1.268 +                  let
   1.269 +                    val (fvs, cname, cvs, _) = decomp_case eqn;
   1.270 +                    val Ts = binder_types (fastype_of f);
   1.271 +                    val ys = variantlist (map (fst o fst o dest_Var) cvs, vs);
   1.272 +                    val cvs' = map Var (map (rpair 0) ys ~~ Ts);
   1.273 +                    val rs = cvs' ~~ cvs;
   1.274 +                    val lhs = list_comb (Const (cname, Ts ---> uT), cvs');
   1.275 +                    val rhs = foldl betapply (f, cvs');
   1.276 +                    val (vs', tm_bs, tm_vals) = decomp_term_code false
   1.277 +                      ((vs @ ys, [], []), (ct, lhs));
   1.278 +                    val ((vs'', all_vals), (eq', ct')) = mk_simpl_code sg case_tab
   1.279 +                      false fs (tm_bs @ ts) ty_bs thm_bs ((vs', vals), rhs);
   1.280 +                    val (old_vals, eq_vals) = splitAt (i, all_vals);
   1.281 +                    val vs''' = vs @ filter (fn v => exists
   1.282 +                      (fn (_, ((v', _), _)) => v = v') old_vals) (vs'' \\ vs');
   1.283 +                    val insts2 = map (fn (t, s) => Pretty.block [Pretty.str "(",
   1.284 +                      inst_ty false ty_bs' t (the (assoc (thm_bs, t))), Pretty.str ",",
   1.285 +                      Pretty.brk 1, Pretty.str (s ^ ")")]) ((fvs ~~ xs) @
   1.286 +                        (map (fn (v, s) => (the (assoc (rs, v)), s)) tm_bs));
   1.287 +                    val eq'' = if null insts1 andalso null insts2 then Pretty.str name
   1.288 +                      else parens (eq' <> "") [Pretty.str
   1.289 +                          (if null cvs then "Thm.instantiate" else "Drule.instantiate"),
   1.290 +                        Pretty.brk 1, Pretty.str "(", Pretty.list "[" "]" insts1,
   1.291 +                        Pretty.str ",", Pretty.brk 1, Pretty.list "[" "]" insts2,
   1.292 +                        Pretty.str ")", Pretty.brk 1, Pretty.str name];
   1.293 +                    val eq''' = if eq' = "" then eq'' else
   1.294 +                      Pretty.block [Pretty.str "Thm.transitive", Pretty.brk 1,
   1.295 +                        eq'', Pretty.brk 1, Pretty.str eq']
   1.296 +                  in
   1.297 +                    ((vs''', old_vals), Pretty.block [pretty_pattern false lhs,
   1.298 +                      Pretty.str " =>",
   1.299 +                      Pretty.brk 1, mk_let "let" 2 (tm_vals @ flat (map (snd o snd) eq_vals))
   1.300 +                        [Pretty.str ("(" ^ ct' ^ ","), Pretty.brk 1, eq''', Pretty.str ")"]])
   1.301 +                  end;
   1.302 +
   1.303 +                val case_names = map (fn i => Sign.base_name cname ^ "_" ^
   1.304 +                  string_of_int i) (1 upto length cases);
   1.305 +                val ((vs3, vals3), case_ps) = foldl_map mk_case_code
   1.306 +                  ((vs2, vals2), us' ~~ (case_names ~~ cases));
   1.307 +                val eq' = variant vs3 "eq";
   1.308 +                val ct' = variant (eq' :: vs3) "ct";
   1.309 +                val eq'' = variant (eq' :: ct' :: vs3) "eq";
   1.310 +                val case_vals =
   1.311 +                  fvals @ ty_vals @
   1.312 +                  [mk_val [ct', eq'] ([Pretty.str "(case", Pretty.brk 1,
   1.313 +                    Pretty.str ("term_of " ^ ct ^ " of"), Pretty.brk 1] @
   1.314 +                    flat (separate [Pretty.brk 1, Pretty.str "| "]
   1.315 +                      (map single case_ps)) @ [Pretty.str ")"])]
   1.316 +              in
   1.317 +                if b then
   1.318 +                  ((eq' :: ct' :: vs3, vals3 @
   1.319 +                     [(t, ((eq', ct'), case_vals))]), (eq', ct'))
   1.320 +                else
   1.321 +                  let val ((vs4, vals4), (_, ctcase)) = mk_simpl_code sg case_tab false
   1.322 +                    fs ts ty_bs thm_bs ((eq' :: eq'' :: ct' :: vs3, vals3), f)
   1.323 +                  in
   1.324 +                    ((vs4, vals4 @ [(t, ((eq'', ct'), case_vals @
   1.325 +                       [mk_val [eq''] [Pretty.str "Thm.transitive", Pretty.brk 1,
   1.326 +                          Pretty.str "(Thm.combination", Pretty.brk 1,
   1.327 +                          Pretty.str "(Thm.reflexive", Pretty.brk 1,
   1.328 +                          mk_apps "Thm.capply" true (Pretty.str ctcase)
   1.329 +                            (map Pretty.str xs),
   1.330 +                          Pretty.str ")", Pretty.brk 1, Pretty.str (eq ^ ")"),
   1.331 +                          Pretty.brk 1, Pretty.str eq']]))]), (eq'', ct'))
   1.332 +                  end
   1.333 +              end
   1.334 +          
   1.335 +          | None =>
   1.336 +            let
   1.337 +              val b = forall (unint sg fs) us;
   1.338 +              val (q, eqs) = foldl_map
   1.339 +                (mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs) ((vs, vals), us);
   1.340 +              val ((vs', vals'), (eqf, ctf)) = if is_some (lookup sg fs f) andalso b
   1.341 +                then (q, ("", ""))
   1.342 +                else mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs (q, f);
   1.343 +              val ct = variant vs' "ct";
   1.344 +              val eq = variant (ct :: vs') "eq";
   1.345 +              val ctv = mk_val [ct] [mk_apps "Thm.capply" false
   1.346 +                (Pretty.str ctf) (map (Pretty.str o snd) eqs)];
   1.347 +              fun combp b = mk_apps "Thm.combination" b
   1.348 +                (Pretty.str eqf) (map (Pretty.str o fst) eqs)
   1.349 +            in
   1.350 +              case (lookup sg fs f, b) of
   1.351 +                (None, true) =>  (* completely uninterpreted *)
   1.352 +                  if mkeq then ((ct :: eq :: vs', vals' @
   1.353 +                    [(t, ((eq, ct), [ctv, mk_refleq eq ct]))]), (eq, ct))
   1.354 +                  else ((ct :: vs', vals' @ [(t, (("", ct), [ctv]))]), ("", ct))
   1.355 +              | (None, false) =>  (* function uninterpreted *)
   1.356 +                  ((eq :: ct :: vs', vals' @
   1.357 +                     [(t, ((eq, ct), [ctv, mk_val [eq] [combp false]]))]), (eq, ct))
   1.358 +              | (Some (s, _, _), true) =>  (* arguments uninterpreted *)
   1.359 +                  ((eq :: ct :: vs', vals' @
   1.360 +                     [(t, ((eq, ct), [mk_val [ct, eq] (separate (Pretty.brk 1)
   1.361 +                       (Pretty.str s :: map (Pretty.str o snd) eqs))]))]), (eq, ct))
   1.362 +              | (Some (s, _, _), false) =>  (* function and arguments interpreted *)
   1.363 +                  let val eq' = variant (eq :: ct :: vs') "eq"
   1.364 +                  in ((eq' :: eq :: ct :: vs', vals' @ [(t, ((eq', ct),
   1.365 +                    [mk_val [ct, eq] (separate (Pretty.brk 1)
   1.366 +                       (Pretty.str s :: map (Pretty.str o snd) eqs)),
   1.367 +                     mk_val [eq'] [Pretty.str "Thm.transitive", Pretty.brk 1,
   1.368 +                       combp true, Pretty.brk 1, Pretty.str eq]]))]), (eq', ct))
   1.369 +                  end
   1.370 +            end
   1.371 +        end));
   1.372 +
   1.373 +fun lhs_of thm = fst (Logic.dest_equals (prop_of thm));
   1.374 +fun rhs_of thm = snd (Logic.dest_equals (prop_of thm));
   1.375 +
   1.376 +fun mk_funs_code sg case_tab fs fs' =
   1.377 +  let
   1.378 +    val case_thms = mapfilter (fn s => (case Symtab.lookup (case_tab, s) of
   1.379 +        None => None
   1.380 +      | Some thms => Some (unsuffix "_case" (Sign.base_name s) ^ ".cases",
   1.381 +          map (fn i => Sign.base_name s ^ "_" ^ string_of_int i)
   1.382 +            (1 upto length thms) ~~ thms)))
   1.383 +      (foldr add_term_consts (map (prop_of o snd)
   1.384 +        (flat (map (#3 o snd) fs')), []));
   1.385 +    val case_vals = map (fn (s, cs) => mk_vall (map fst cs)
   1.386 +      [Pretty.str "map my_mk_meta_eq", Pretty.brk 1,
   1.387 +       Pretty.str ("(thms \"" ^ s ^ "\")")]) case_thms;
   1.388 +    val (vs, thm_bs, thm_vals) = foldl mk_term_bindings (([], [], []),
   1.389 +      flat (map (map (apsnd prop_of) o #3 o snd) fs') @
   1.390 +      map (apsnd prop_of) (flat (map snd case_thms)));
   1.391 +
   1.392 +    fun mk_fun_code (prfx, (fname, d, eqns)) =
   1.393 +      let
   1.394 +        val (f, ts) = strip_comb (lhs_of (snd (hd eqns)));
   1.395 +        val args = variantlist (replicate (length ts) "ct", vs);
   1.396 +        val (vs', ty_bs, ty_vals) = foldl mk_type_bindings
   1.397 +          ((vs @ args, [], []), args ~~ map fastype_of ts);
   1.398 +        val insts1 = map mk_tyinst ty_bs;
   1.399 +
   1.400 +        fun mk_eqn_code (name, eqn) =
   1.401 +          let
   1.402 +            val (_, argts) = strip_comb (lhs_of eqn);
   1.403 +            val (vs'', tm_bs, tm_vals) = foldl (decomp_term_code false)
   1.404 +              ((vs', [], []), args ~~ argts);
   1.405 +            val ((vs''', eq_vals), (eq, ct)) = mk_simpl_code sg case_tab false fs
   1.406 +              (tm_bs @ filter_out (is_Var o fst) thm_bs) ty_bs thm_bs
   1.407 +              ((vs'', []), rhs_of eqn);
   1.408 +            val insts2 = map (fn (t, s) => Pretty.block [Pretty.str "(",
   1.409 +              inst_ty false ty_bs t (the (assoc (thm_bs, t))), Pretty.str ",", Pretty.brk 1,
   1.410 +              Pretty.str (s ^ ")")]) tm_bs
   1.411 +            val eq' = if null insts1 andalso null insts2 then Pretty.str name
   1.412 +              else parens (eq <> "") [Pretty.str "Thm.instantiate",
   1.413 +                Pretty.brk 1, Pretty.str "(", Pretty.list "[" "]" insts1,
   1.414 +                Pretty.str ",", Pretty.brk 1, Pretty.list "[" "]" insts2,
   1.415 +                Pretty.str ")", Pretty.brk 1, Pretty.str name];
   1.416 +            val eq'' = if eq = "" then eq' else
   1.417 +              Pretty.block [Pretty.str "Thm.transitive", Pretty.brk 1,
   1.418 +                eq', Pretty.brk 1, Pretty.str eq]
   1.419 +          in
   1.420 +            Pretty.block [parens (length argts > 1)
   1.421 +                (Pretty.commas (map (pretty_pattern false) argts)),
   1.422 +              Pretty.str " =>",
   1.423 +              Pretty.brk 1, mk_let "let" 2 (ty_vals @ tm_vals @ flat (map (snd o snd) eq_vals))
   1.424 +                [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1, eq'', Pretty.str ")"]]
   1.425 +          end;
   1.426 +
   1.427 +        val default = if d then
   1.428 +            let
   1.429 +              val Some s = assoc (thm_bs, f);
   1.430 +              val ct = variant vs' "ct"
   1.431 +            in [Pretty.brk 1, Pretty.str "handle", Pretty.brk 1,
   1.432 +              Pretty.str "Match =>", Pretty.brk 1, mk_let "let" 2
   1.433 +                (ty_vals @ (if null (term_tvars f) then [] else
   1.434 +                   [mk_val [s] [inst_ty false ty_bs f s]]) @
   1.435 +                 [mk_val [ct] [mk_apps "Thm.capply" false (Pretty.str s)
   1.436 +                    (map Pretty.str args)]])
   1.437 +                [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1,
   1.438 +                 Pretty.str "Thm.reflexive", Pretty.brk 1, Pretty.str (ct ^ ")")]]
   1.439 +            end
   1.440 +          else []
   1.441 +      in
   1.442 +        ("and ", Pretty.block (separate (Pretty.brk 1)
   1.443 +            (Pretty.str (prfx ^ fname) :: map Pretty.str args) @
   1.444 +          [Pretty.str " =", Pretty.brk 1, Pretty.str "(case", Pretty.brk 1,
   1.445 +           Pretty.list "(" ")" (map (fn s => Pretty.str ("term_of " ^ s)) args),
   1.446 +           Pretty.str " of", Pretty.brk 1] @
   1.447 +          flat (separate [Pretty.brk 1, Pretty.str "| "]
   1.448 +            (map (single o mk_eqn_code) eqns)) @ [Pretty.str ")"] @ default))
   1.449 +      end;
   1.450 +
   1.451 +    val (_, decls) = foldl_map mk_fun_code ("fun ", map snd fs')
   1.452 +  in
   1.453 +    mk_let "local" 2 (case_vals @ thm_vals) (separate Pretty.fbrk decls)
   1.454 +  end;
   1.455 +
   1.456 +fun mk_simprocs_code sg eqns =
   1.457 +  let
   1.458 +    val case_tab = get_cases sg;
   1.459 +    fun get_head th = head_of (fst (Logic.dest_equals (prop_of th)));
   1.460 +    fun attach_term (x as (_, _, (_, th) :: _)) = (get_head th, x);
   1.461 +    val eqns' = map attach_term eqns;
   1.462 +    fun mk_node (s, _, (_, th) :: _) = (s, get_head th);
   1.463 +    fun mk_edges (s, _, ths) = map (pair s) (distinct
   1.464 +      (mapfilter (fn t => apsome #1 (lookup sg eqns' t))
   1.465 +        (flat (map (term_consts' o prop_of o snd) ths))));
   1.466 +    val gr = foldr (uncurry Graph.add_edge)
   1.467 +      (map (pair "" o #1) eqns @ flat (map mk_edges eqns),
   1.468 +       foldr (uncurry Graph.new_node)
   1.469 +         (("", Bound 0) :: map mk_node eqns, Graph.empty));
   1.470 +    val keys = rev (Graph.all_succs gr [""] \ "");
   1.471 +    fun gr_ord (x :: _, y :: _) =
   1.472 +      int_ord (find_index (equal x) keys, find_index (equal y) keys);
   1.473 +    val scc = map (fn xs => filter (fn (_, (s, _, _)) => s mem xs) eqns')
   1.474 +      (sort gr_ord (Graph.strong_conn gr \ [""]));
   1.475 +  in
   1.476 +    flat (separate [Pretty.str ";", Pretty.fbrk, Pretty.str " ", Pretty.fbrk]
   1.477 +      (map (fn eqns'' => [mk_funs_code sg case_tab eqns' eqns'']) scc)) @
   1.478 +    [Pretty.str ";", Pretty.fbrk]
   1.479 +  end;
   1.480 +
   1.481 +fun use_simprocs_code sg eqns =
   1.482 +  let
   1.483 +    fun attach_name (i, x) = (i+1, ("simp_thm_" ^ string_of_int i, x));
   1.484 +    fun attach_names (i, (s, b, eqs)) =
   1.485 +      let val (i', eqs') = foldl_map attach_name (i, eqs)
   1.486 +      in (i', (s, b, eqs')) end;
   1.487 +    val (_, eqns') = foldl_map attach_names (1, eqns);
   1.488 +    val (names, thms) = split_list (flat (map #3 eqns'));
   1.489 +    val s = setmp print_mode [] Pretty.string_of
   1.490 +      (mk_let "local" 2 [mk_vall names [Pretty.str "!SimprocsCodegen.simp_thms"]]
   1.491 +        (mk_simprocs_code sg eqns'))
   1.492 +  in
   1.493 +    (simp_thms := thms; use_text Context.ml_output false s)
   1.494 +  end;
   1.495 +
   1.496 +end;