New code generators for HOL.
authorberghofe
Fri Aug 31 16:49:06 2001 +0200 (2001-08-31)
changeset 11537e007d35359c3
parent 11536 6adf4d532679
child 11538 f8588786cc9c
New code generators for HOL.
src/HOL/Tools/basic_codegen.ML
src/HOL/Tools/inductive_codegen.ML
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Tools/basic_codegen.ML	Fri Aug 31 16:49:06 2001 +0200
     1.3 @@ -0,0 +1,249 @@
     1.4 +(*  Title:      Pure/HOL/basic_codegen.ML
     1.5 +    ID:         $Id$
     1.6 +    Author:     Stefan Berghofer
     1.7 +    Copyright   2000  TU Muenchen
     1.8 +
     1.9 +Code generator for inductive datatypes and recursive functions
    1.10 +*)
    1.11 +
    1.12 +signature BASIC_CODEGEN =
    1.13 +sig
    1.14 +  val setup: (theory -> theory) list
    1.15 +end;
    1.16 +
    1.17 +structure BasicCodegen : BASIC_CODEGEN =
    1.18 +struct
    1.19 +
    1.20 +open Codegen;
    1.21 +
    1.22 +fun mk_poly_id thy (s, T) = mk_const_id (sign_of thy) s ^
    1.23 +  (case get_defn thy s T of
    1.24 +     Some (_, Some i) => "_def" ^ string_of_int i
    1.25 +   | _ => "");
    1.26 +
    1.27 +fun mk_tuple [p] = p
    1.28 +  | mk_tuple ps = Pretty.block (Pretty.str "(" ::
    1.29 +      flat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @
    1.30 +        [Pretty.str ")"]);
    1.31 +
    1.32 +fun add_rec_funs thy dep (gr, eqs) =
    1.33 +  let
    1.34 +    fun dest_eq t =
    1.35 +      let val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop
    1.36 +            (Logic.strip_imp_concl (rename_term t)))
    1.37 +      in
    1.38 +        (mk_poly_id thy (dest_Const (head_of lhs)), (lhs, rhs))
    1.39 +      end;
    1.40 +    val eqs' = sort (string_ord o pairself fst) (map dest_eq eqs);
    1.41 +    val (dname, _) :: _ = eqs';
    1.42 +
    1.43 +    fun mk_fundef fname prfx gr [] = (gr, [])
    1.44 +      | mk_fundef fname prfx gr ((fname', (lhs, rhs))::xs) =
    1.45 +      let
    1.46 +        val (gr1, pl) = invoke_codegen thy gr dname false lhs;
    1.47 +        val (gr2, pr) = invoke_codegen thy gr1 dname false rhs;
    1.48 +        val (gr3, rest) = mk_fundef fname' "and " gr2 xs
    1.49 +      in
    1.50 +        (gr3, Pretty.blk (4, [Pretty.str (if fname=fname' then "  | " else prfx),
    1.51 +           pl, Pretty.str " =", Pretty.brk 1, pr]) :: rest)
    1.52 +      end
    1.53 +
    1.54 +  in
    1.55 +    (Graph.add_edge (dname, dep) gr handle Graph.UNDEF _ =>
    1.56 +       let
    1.57 +         val gr1 = Graph.add_edge (dname, dep)
    1.58 +           (Graph.new_node (dname, (None, "")) gr);
    1.59 +         val (gr2, fundef) = mk_fundef "" "fun " gr1 eqs'
    1.60 +       in
    1.61 +         Graph.map_node dname (K (None, Pretty.string_of (Pretty.blk (0,
    1.62 +           separate Pretty.fbrk fundef @ [Pretty.str ";"])) ^ "\n\n")) gr2
    1.63 +       end)
    1.64 +  end;
    1.65 +
    1.66 +
    1.67 +(**** generate functions for datatypes specified by descr ****)
    1.68 +(**** (i.e. constructors and case combinators)            ****)
    1.69 +
    1.70 +fun mk_typ _ _ (TVar ((s, i), _)) =
    1.71 +     Pretty.str (s ^ (if i=0 then "" else string_of_int i))
    1.72 +  | mk_typ _ _ (TFree (s, _)) = Pretty.str s
    1.73 +  | mk_typ sg types (Type ("fun", [T, U])) = Pretty.block [Pretty.str "(",
    1.74 +     mk_typ sg types T, Pretty.str " ->", Pretty.brk 1,
    1.75 +     mk_typ sg types U, Pretty.str ")"]
    1.76 +  | mk_typ sg types (Type (s, Ts)) = Pretty.block ((if null Ts then [] else
    1.77 +      [mk_tuple (map (mk_typ sg types) Ts), Pretty.str " "]) @
    1.78 +      [Pretty.str (if_none (assoc (types, s)) (mk_type_id sg s))]);
    1.79 +
    1.80 +fun add_dt_defs thy dep (gr, descr) =
    1.81 +  let
    1.82 +    val sg = sign_of thy;
    1.83 +    val tab = DatatypePackage.get_datatypes thy;
    1.84 +
    1.85 +    val descr' = filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
    1.86 +
    1.87 +    val (_, (_, _, (cname, _) :: _)) :: _ = descr';
    1.88 +    val dname = mk_const_id sg cname;
    1.89 +
    1.90 +    fun mk_dtdef gr prfx [] = (gr, [])
    1.91 +      | mk_dtdef gr prfx ((_, (tname, dts, cs))::xs) =
    1.92 +          let
    1.93 +            val types = get_assoc_types thy;
    1.94 +            val tvs = map DatatypeAux.dest_DtTFree dts;
    1.95 +            val sorts = map (rpair []) tvs;
    1.96 +            val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
    1.97 +            val tycons = foldr add_typ_tycons (flat (map snd cs'), []) \\
    1.98 +              ("fun" :: map fst types);
    1.99 +            val descrs = map (fn s => case Symtab.lookup (tab, s) of
   1.100 +                None => error ("Not a datatype: " ^ s ^ "\nrequired by:\n" ^
   1.101 +                  commas (Graph.all_succs gr [dep]))
   1.102 +              | Some info => #descr info) tycons;
   1.103 +            val gr' = foldl (add_dt_defs thy dname) (gr, descrs);
   1.104 +            val (gr'', rest) = mk_dtdef gr' "and " xs
   1.105 +          in
   1.106 +            (gr'',
   1.107 +             Pretty.block (Pretty.str prfx ::
   1.108 +               (if null tvs then [] else
   1.109 +                  [mk_tuple (map Pretty.str tvs), Pretty.str " "]) @
   1.110 +               [Pretty.str (mk_type_id sg tname ^ " ="), Pretty.brk 1] @
   1.111 +               flat (separate [Pretty.brk 1, Pretty.str "| "]
   1.112 +                 (map (fn (cname, cargs) => [Pretty.block
   1.113 +                   (Pretty.str (mk_const_id sg cname) ::
   1.114 +                    (if null cargs then [] else
   1.115 +                     flat ([Pretty.str " of", Pretty.brk 1] ::
   1.116 +                       separate [Pretty.str " *", Pretty.brk 1]
   1.117 +                         (map (single o mk_typ sg types) cargs))))]) cs'))) :: rest)
   1.118 +          end
   1.119 +  in
   1.120 +    ((Graph.add_edge_acyclic (dname, dep) gr
   1.121 +        handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ =>
   1.122 +         let
   1.123 +           val gr1 = Graph.add_edge (dname, dep)
   1.124 +             (Graph.new_node (dname, (None, "")) gr);
   1.125 +           val (gr2, dtdef) = mk_dtdef gr1 "datatype " descr';
   1.126 +         in
   1.127 +           Graph.map_node dname (K (None,
   1.128 +             Pretty.string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @
   1.129 +               [Pretty.str ";"])) ^ "\n\n")) gr2
   1.130 +         end)
   1.131 +  end;
   1.132 +
   1.133 +
   1.134 +(**** generate code for applications of constructors and case ****)
   1.135 +(**** combinators for datatypes                               ****)
   1.136 +
   1.137 +fun pretty_case thy gr dep brack constrs (c as Const (_, T)) ts =
   1.138 +  let val i = length constrs
   1.139 +  in if length ts <= i then
   1.140 +       invoke_codegen thy gr dep brack (eta_expand c ts (i+1))
   1.141 +    else
   1.142 +      let
   1.143 +        val ts1 = take (i, ts);
   1.144 +        val t :: ts2 = drop (i, ts);
   1.145 +        val names = foldr add_term_names (ts1,
   1.146 +          map (fst o fst o dest_Var) (foldr add_term_vars (ts1, [])));
   1.147 +        val (Ts, dT) = split_last (take (i+1, fst (strip_type T)));
   1.148 +
   1.149 +        fun pcase gr [] [] [] = ([], gr)
   1.150 +          | pcase gr ((cname, cargs)::cs) (t::ts) (U::Us) =
   1.151 +              let
   1.152 +                val j = length cargs;
   1.153 +                val (Ts, _) = strip_type (fastype_of t);
   1.154 +                val xs = variantlist (replicate j "x", names);
   1.155 +                val Us' = take (j, fst (strip_type U));
   1.156 +                val frees = map Free (xs ~~ Us');
   1.157 +                val (gr0, cp) = invoke_codegen thy gr dep false
   1.158 +                  (list_comb (Const (cname, Us' ---> dT), frees));
   1.159 +                val t' = Envir.beta_norm (list_comb (t, frees));
   1.160 +                val (gr1, p) = invoke_codegen thy gr0 dep false t';
   1.161 +                val (ps, gr2) = pcase gr1 cs ts Us;
   1.162 +              in
   1.163 +                ([Pretty.block [cp, Pretty.str " =>", Pretty.brk 1, p]] :: ps, gr2)
   1.164 +              end;
   1.165 +
   1.166 +        val (ps1, gr1) = pcase gr constrs ts1 Ts;
   1.167 +        val ps = flat (separate [Pretty.brk 1, Pretty.str "| "] ps1);
   1.168 +        val (gr2, p) = invoke_codegen thy gr1 dep false t;
   1.169 +        val (gr3, ps2) = foldl_map
   1.170 +         (fn (gr, t) => invoke_codegen thy gr dep true t) (gr2, ts2)
   1.171 +      in (gr3, (if not (null ts2) andalso brack then parens else I)
   1.172 +        (Pretty.block (separate (Pretty.brk 1)
   1.173 +          (Pretty.block ([Pretty.str "(case ", p, Pretty.str " of",
   1.174 +             Pretty.brk 1] @ ps @ [Pretty.str ")"]) :: ps2))))
   1.175 +      end
   1.176 +  end;
   1.177 +
   1.178 +
   1.179 +fun pretty_constr thy gr dep brack args (c as Const (s, _)) ts =
   1.180 +  let val i = length args
   1.181 +  in if length ts < i then
   1.182 +      invoke_codegen thy gr dep brack (eta_expand c ts i)
   1.183 +     else
   1.184 +       let
   1.185 +         val id = mk_const_id (sign_of thy) s;
   1.186 +         val (gr', ps) = foldl_map
   1.187 +           (fn (gr, t) => invoke_codegen thy gr dep (i = 1) t) (gr, ts);
   1.188 +       in (case args of
   1.189 +          [] => (gr', Pretty.str id)
   1.190 +        | [_] => (gr', mk_app brack (Pretty.str id) ps)
   1.191 +        | _ => (gr', (if brack then parens else I) (Pretty.block
   1.192 +            ([Pretty.str id, Pretty.brk 1, Pretty.str "("] @
   1.193 +             flat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @
   1.194 +             [Pretty.str ")"]))))
   1.195 +       end
   1.196 +  end;
   1.197 +
   1.198 +
   1.199 +fun mk_recfun thy gr dep brack s T ts eqns =
   1.200 +  let val (gr', ps) = foldl_map
   1.201 +    (fn (gr, t) => invoke_codegen thy gr dep true t) (gr, ts)
   1.202 +  in
   1.203 +    Some (add_rec_funs thy dep (gr', map (#prop o rep_thm) eqns),
   1.204 +      mk_app brack (Pretty.str (mk_poly_id thy (s, T))) ps)
   1.205 +  end;
   1.206 +
   1.207 +
   1.208 +fun datatype_codegen thy gr dep brack t = (case strip_comb t of
   1.209 +   (c as Const (s, T), ts) =>
   1.210 +       (case find_first (fn (_, {index, descr, case_name, rec_names, ...}) =>
   1.211 +         s = case_name orelse s mem rec_names orelse
   1.212 +           is_some (assoc (#3 (the (assoc (descr, index))), s)))
   1.213 +             (Symtab.dest (DatatypePackage.get_datatypes thy)) of
   1.214 +          None => None
   1.215 +        | Some (tname, {index, descr, case_name, rec_names, rec_rewrites, ...}) =>
   1.216 +           if is_some (get_assoc_code thy s T) then None else
   1.217 +           let
   1.218 +             val Some (_, _, constrs) = assoc (descr, index);
   1.219 +             val gr1 =
   1.220 +              if exists (equal tname o fst) (get_assoc_types thy) then gr
   1.221 +              else add_dt_defs thy dep (gr, descr);
   1.222 +           in
   1.223 +             (case assoc (constrs, s) of
   1.224 +                None => if s mem rec_names then
   1.225 +                    mk_recfun thy gr1 dep brack s T ts rec_rewrites
   1.226 +                  else Some (pretty_case thy gr1 dep brack constrs c ts)
   1.227 +              | Some args => Some (pretty_constr thy gr1 dep brack args c ts))
   1.228 +           end)
   1.229 + |  _ => None);
   1.230 +
   1.231 +
   1.232 +(**** generate code for primrec and recdef ****)
   1.233 +
   1.234 +fun recfun_codegen thy gr dep brack t = (case strip_comb t of
   1.235 +    (Const (s, T), ts) =>
   1.236 +      (case PrimrecPackage.get_primrec thy s of
   1.237 +         Some ps => (case find_first (fn (_, thm::_) =>
   1.238 +               is_instance thy T (snd (dest_Const (head_of
   1.239 +                 (fst (HOLogic.dest_eq
   1.240 +                   (HOLogic.dest_Trueprop (#prop (rep_thm thm))))))))) ps of
   1.241 +             Some (_, thms) => mk_recfun thy gr dep brack s T ts thms
   1.242 +           | None => None)
   1.243 +       | None => case RecdefPackage.get_recdef thy s of
   1.244 +            Some {simps, ...} => mk_recfun thy gr dep brack s T ts simps
   1.245 +          | None => None)
   1.246 +  | _ => None);
   1.247 +
   1.248 +
   1.249 +val setup = [add_codegen "datatype" datatype_codegen,
   1.250 +             add_codegen "primrec+recdef" recfun_codegen];
   1.251 +
   1.252 +end;
     2.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     2.2 +++ b/src/HOL/Tools/inductive_codegen.ML	Fri Aug 31 16:49:06 2001 +0200
     2.3 @@ -0,0 +1,381 @@
     2.4 +(*  Title:      Pure/HOL/inductive_codegen.ML
     2.5 +    ID:         $Id$
     2.6 +    Author:     Stefan Berghofer
     2.7 +    Copyright   2000  TU Muenchen
     2.8 +
     2.9 +Code generator for inductive predicates
    2.10 +*)
    2.11 +
    2.12 +signature INDUCTIVE_CODEGEN =
    2.13 +sig
    2.14 +  val setup : (theory -> theory) list
    2.15 +end;
    2.16 +
    2.17 +structure InductiveCodegen : INDUCTIVE_CODEGEN =
    2.18 +struct
    2.19 +
    2.20 +open Codegen;
    2.21 +
    2.22 +exception Modes of (string * int list list) list * (string * int list list) list;
    2.23 +
    2.24 +datatype indprem = Prem of string * term list * term list
    2.25 +                 | Sidecond of term;
    2.26 +
    2.27 +fun prod_factors p (Const ("Pair", _) $ t $ u) =
    2.28 +      p :: prod_factors (1::p) t @ prod_factors (2::p) u
    2.29 +  | prod_factors p _ = [];
    2.30 +
    2.31 +fun split_prod p ps t = if p mem ps then (case t of
    2.32 +       Const ("Pair", _) $ t $ u =>
    2.33 +         split_prod (1::p) ps t @ split_prod (2::p) ps u
    2.34 +     | _ => error "Inconsistent use of products") else [t];
    2.35 +
    2.36 +fun string_of_factors p ps = if p mem ps then
    2.37 +    "(" ^ string_of_factors (1::p) ps ^ ", " ^ string_of_factors (2::p) ps ^ ")"
    2.38 +  else "_";
    2.39 +
    2.40 +(**** check if a term contains only constructor functions ****)
    2.41 +
    2.42 +fun is_constrt thy =
    2.43 +  let
    2.44 +    val cnstrs = flat (flat (map
    2.45 +      (map (fn (_, (_, _, cs)) => map (apsnd length) cs) o #descr o snd)
    2.46 +      (Symtab.dest (DatatypePackage.get_datatypes thy))));
    2.47 +    fun check t = (case strip_comb t of
    2.48 +        (Var _, []) => true
    2.49 +      | (Const (s, _), ts) => (case assoc (cnstrs, s) of
    2.50 +            None => false
    2.51 +          | Some i => length ts = i andalso forall check ts)
    2.52 +      | _ => false)
    2.53 +  in check end;
    2.54 +
    2.55 +(**** check if a type is an equality type (i.e. doesn't contain fun) ****)
    2.56 +
    2.57 +fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts
    2.58 +  | is_eqT _ = true;
    2.59 +
    2.60 +(**** mode inference ****)
    2.61 +
    2.62 +val term_vs = map (fst o fst o dest_Var) o term_vars;
    2.63 +val terms_vs = distinct o flat o (map term_vs);
    2.64 +
    2.65 +(** collect all Vars in a term (with duplicates!) **)
    2.66 +fun term_vTs t = map (apfst fst o dest_Var)
    2.67 +  (filter is_Var (foldl_aterms (op :: o Library.swap) ([], t)));
    2.68 +
    2.69 +fun known_args _ _ [] = []
    2.70 +  | known_args vs i (t::ts) = if term_vs t subset vs then i::known_args vs (i+1) ts
    2.71 +      else known_args vs (i+1) ts;
    2.72 +
    2.73 +fun get_args _ _ [] = ([], [])
    2.74 +  | get_args is i (x::xs) = (if i mem is then apfst else apsnd) (cons x)
    2.75 +      (get_args is (i+1) xs);
    2.76 +
    2.77 +fun merge xs [] = xs
    2.78 +  | merge [] ys = ys
    2.79 +  | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys)
    2.80 +      else y::merge (x::xs) ys;
    2.81 +
    2.82 +fun subsets i j = if i <= j then
    2.83 +       let val is = subsets (i+1) j
    2.84 +       in merge (map (fn ks => i::ks) is) is end
    2.85 +     else [[]];
    2.86 +
    2.87 +fun select_mode_prem thy modes vs ps =
    2.88 +  find_first (is_some o snd) (ps ~~ map
    2.89 +    (fn Prem (s, us, args) => find_first (fn is =>
    2.90 +          let
    2.91 +            val (_, out_ts) = get_args is 1 us;
    2.92 +            val vTs = flat (map term_vTs out_ts);
    2.93 +            val dupTs = map snd (duplicates vTs) @
    2.94 +              mapfilter (curry assoc vTs) vs;
    2.95 +          in
    2.96 +            is subset known_args vs 1 us andalso
    2.97 +            forall (is_constrt thy) (snd (get_args is 1 us)) andalso
    2.98 +            terms_vs args subset vs andalso
    2.99 +            forall is_eqT dupTs
   2.100 +          end)
   2.101 +            (the (assoc (modes, s)))
   2.102 +      | Sidecond t => if term_vs t subset vs then Some [] else None) ps);
   2.103 +
   2.104 +fun check_mode_clause thy arg_vs modes mode (ts, ps) =
   2.105 +  let
   2.106 +    fun check_mode_prems vs [] = Some vs
   2.107 +      | check_mode_prems vs ps = (case select_mode_prem thy modes vs ps of
   2.108 +          None => None
   2.109 +        | Some (x, _) => check_mode_prems
   2.110 +            (case x of Prem (_, us, _) => vs union terms_vs us | _ => vs)
   2.111 +            (filter_out (equal x) ps));
   2.112 +    val (in_ts', _) = get_args mode 1 ts;
   2.113 +    val in_ts = filter (is_constrt thy) in_ts';
   2.114 +    val in_vs = terms_vs in_ts;
   2.115 +    val concl_vs = terms_vs ts
   2.116 +  in
   2.117 +    forall is_eqT (map snd (duplicates (flat (map term_vTs in_ts')))) andalso
   2.118 +    (case check_mode_prems (arg_vs union in_vs) ps of
   2.119 +       None => false
   2.120 +     | Some vs => concl_vs subset vs)
   2.121 +  end;
   2.122 +
   2.123 +fun check_modes_pred thy arg_vs preds modes (p, ms) =
   2.124 +  let val Some rs = assoc (preds, p)
   2.125 +  in (p, filter (fn m => forall (check_mode_clause thy arg_vs modes m) rs) ms) end
   2.126 +
   2.127 +fun fixp f x =
   2.128 +  let val y = f x
   2.129 +  in if x = y then x else fixp f y end;
   2.130 +
   2.131 +fun infer_modes thy extra_modes arg_vs preds = fixp (fn modes =>
   2.132 +  map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes)
   2.133 +    (map (fn (s, (ts, _)::_) => (s, subsets 1 (length ts))) preds);
   2.134 +
   2.135 +(**** code generation ****)
   2.136 +
   2.137 +fun mk_eq (x::xs) =
   2.138 +  let fun mk_eqs _ [] = []
   2.139 +        | mk_eqs a (b::cs) = Pretty.str (a ^ " = " ^ b) :: mk_eqs b cs
   2.140 +  in mk_eqs x xs end;
   2.141 +
   2.142 +fun mk_tuple xs = Pretty.block (Pretty.str "(" ::
   2.143 +  flat (separate [Pretty.str ",", Pretty.brk 1] (map single xs)) @
   2.144 +  [Pretty.str ")"]);
   2.145 +
   2.146 +fun mk_v ((names, vs), s) = (case assoc (vs, s) of
   2.147 +      None => ((names, (s, [s])::vs), s)
   2.148 +    | Some xs => let val s' = variant names s in
   2.149 +        ((s'::names, overwrite (vs, (s, s'::xs))), s') end);
   2.150 +
   2.151 +fun distinct_v (nvs, Var ((s, 0), T)) =
   2.152 +      apsnd (Var o rpair T o rpair 0) (mk_v (nvs, s))
   2.153 +  | distinct_v (nvs, t $ u) =
   2.154 +      let
   2.155 +        val (nvs', t') = distinct_v (nvs, t);
   2.156 +        val (nvs'', u') = distinct_v (nvs', u);
   2.157 +      in (nvs'', t' $ u') end
   2.158 +  | distinct_v x = x;
   2.159 +
   2.160 +fun compile_match nvs eq_ps out_ps success_p fail_p =
   2.161 +  let val eqs = flat (separate [Pretty.str " andalso", Pretty.brk 1]
   2.162 +    (map single (flat (map (mk_eq o snd) nvs) @ eq_ps)));
   2.163 +  in
   2.164 +    Pretty.block
   2.165 +     ([Pretty.str "(fn ", mk_tuple out_ps, Pretty.str " =>", Pretty.brk 1] @
   2.166 +      (Pretty.block ((if eqs=[] then [] else Pretty.str "if " ::
   2.167 +         [Pretty.block eqs, Pretty.brk 1, Pretty.str "then "]) @
   2.168 +         (success_p ::
   2.169 +          (if eqs=[] then [] else [Pretty.brk 1, Pretty.str "else ", fail_p]))) ::
   2.170 +       [Pretty.brk 1, Pretty.str "| _ => ", fail_p, Pretty.str ")"]))
   2.171 +  end;
   2.172 +
   2.173 +fun modename thy s mode = space_implode "_"
   2.174 +  (mk_const_id (sign_of thy) s :: map string_of_int mode);
   2.175 +
   2.176 +fun compile_clause thy gr dep all_vs arg_vs modes mode (ts, ps) =
   2.177 +  let
   2.178 +    fun check_constrt ((names, eqs), t) =
   2.179 +      if is_constrt thy t then ((names, eqs), t) else
   2.180 +        let val s = variant names "x";
   2.181 +        in ((s::names, (s, t)::eqs), Var ((s, 0), fastype_of t)) end;
   2.182 +
   2.183 +    val (in_ts, out_ts) = get_args mode 1 ts;
   2.184 +    val ((all_vs', eqs), in_ts') =
   2.185 +      foldl_map check_constrt ((all_vs, []), in_ts);
   2.186 +
   2.187 +    fun compile_prems out_ts' vs names gr [] =
   2.188 +          let
   2.189 +            val (gr2, out_ps) = foldl_map (fn (gr, t) =>
   2.190 +              invoke_codegen thy gr dep false t) (gr, out_ts);
   2.191 +            val (gr3, eq_ps) = foldl_map (fn (gr, (s, t)) =>
   2.192 +              apsnd (Pretty.block o cons (Pretty.str (s ^ " = ")) o single)
   2.193 +                (invoke_codegen thy gr dep false t)) (gr2, eqs);
   2.194 +            val (nvs, out_ts'') = foldl_map distinct_v
   2.195 +              ((names, map (fn x => (x, [x])) vs), out_ts');
   2.196 +            val (gr4, out_ps') = foldl_map (fn (gr, t) =>
   2.197 +              invoke_codegen thy gr dep false t) (gr3, out_ts'');
   2.198 +          in
   2.199 +            (gr4, compile_match (snd nvs) eq_ps out_ps'
   2.200 +              (Pretty.block [Pretty.str "Seq.single", Pretty.brk 1, mk_tuple out_ps])
   2.201 +              (Pretty.str "Seq.empty"))
   2.202 +          end
   2.203 +      | compile_prems out_ts vs names gr ps =
   2.204 +          let
   2.205 +            val vs' = distinct (flat (vs :: map term_vs out_ts));
   2.206 +            val Some (p, Some mode') =
   2.207 +              select_mode_prem thy modes (arg_vs union vs') ps;
   2.208 +            val ps' = filter_out (equal p) ps;
   2.209 +          in
   2.210 +            (case p of
   2.211 +               Prem (s, us, args) =>
   2.212 +                 let
   2.213 +                   val (in_ts, out_ts') = get_args mode' 1 us;
   2.214 +                   val (gr1, in_ps) = foldl_map (fn (gr, t) =>
   2.215 +                     invoke_codegen thy gr dep false t) (gr, in_ts);
   2.216 +                   val (gr2, arg_ps) = foldl_map (fn (gr, t) =>
   2.217 +                     invoke_codegen thy gr dep true t) (gr1, args);
   2.218 +                   val (nvs, out_ts'') = foldl_map distinct_v
   2.219 +                     ((names, map (fn x => (x, [x])) vs), out_ts);
   2.220 +                   val (gr3, out_ps) = foldl_map (fn (gr, t) =>
   2.221 +                     invoke_codegen thy gr dep false t) (gr2, out_ts'')
   2.222 +                   val (gr4, rest) = compile_prems out_ts' vs' (fst nvs) gr3 ps';
   2.223 +                 in
   2.224 +                   (gr4, compile_match (snd nvs) [] out_ps
   2.225 +                      (Pretty.block (separate (Pretty.brk 1)
   2.226 +                        (Pretty.str (modename thy s mode') :: arg_ps) @
   2.227 +                         [Pretty.brk 1, mk_tuple in_ps,
   2.228 +                          Pretty.str " :->", Pretty.brk 1, rest]))
   2.229 +                      (Pretty.str "Seq.empty"))
   2.230 +                 end
   2.231 +             | Sidecond t =>
   2.232 +                 let
   2.233 +                   val (gr1, side_p) = invoke_codegen thy gr dep true t;
   2.234 +                   val (nvs, out_ts') = foldl_map distinct_v
   2.235 +                     ((names, map (fn x => (x, [x])) vs), out_ts);
   2.236 +                   val (gr2, out_ps) = foldl_map (fn (gr, t) =>
   2.237 +                     invoke_codegen thy gr dep false t) (gr1, out_ts')
   2.238 +                   val (gr3, rest) = compile_prems [] vs' (fst nvs) gr2 ps';
   2.239 +                 in
   2.240 +                   (gr3, compile_match (snd nvs) [] out_ps
   2.241 +                      (Pretty.block [Pretty.str "?? ", side_p,
   2.242 +                        Pretty.str " :->", Pretty.brk 1, rest])
   2.243 +                      (Pretty.str "Seq.empty"))
   2.244 +                 end)
   2.245 +          end;
   2.246 +
   2.247 +    val (gr', prem_p) = compile_prems in_ts' [] all_vs' gr ps;
   2.248 +  in
   2.249 +    (gr', Pretty.block [Pretty.str "Seq.single inp :->", Pretty.brk 1, prem_p])
   2.250 +  end;
   2.251 +
   2.252 +fun compile_pred thy gr dep prfx all_vs arg_vs modes s cls mode =
   2.253 +  let val (gr', cl_ps) = foldl_map (fn (gr, cl) =>
   2.254 +    compile_clause thy gr dep all_vs arg_vs modes mode cl) (gr, cls)
   2.255 +  in
   2.256 +    ((gr', "and "), Pretty.block
   2.257 +      ([Pretty.block (separate (Pretty.brk 1)
   2.258 +         (Pretty.str (prfx ^ modename thy s mode) :: map Pretty.str arg_vs) @
   2.259 +         [Pretty.str " inp ="]),
   2.260 +        Pretty.brk 1] @
   2.261 +       flat (separate [Pretty.str " ++", Pretty.brk 1] (map single cl_ps))))
   2.262 +  end;
   2.263 +
   2.264 +fun compile_preds thy gr dep all_vs arg_vs modes preds =
   2.265 +  let val ((gr', _), prs) = foldl_map (fn ((gr, prfx), (s, cls)) =>
   2.266 +    foldl_map (fn ((gr', prfx'), mode) =>
   2.267 +      compile_pred thy gr' dep prfx' all_vs arg_vs modes s cls mode)
   2.268 +        ((gr, prfx), the (assoc (modes, s)))) ((gr, "fun "), preds)
   2.269 +  in
   2.270 +    (gr', space_implode "\n\n" (map Pretty.string_of (flat prs)) ^ ";\n\n")
   2.271 +  end;
   2.272 +
   2.273 +(**** processing of introduction rules ****)
   2.274 +
   2.275 +val string_of_mode = enclose "[" "]" o commas o map string_of_int;
   2.276 +
   2.277 +fun print_modes modes = message ("Inferred modes:\n" ^
   2.278 +  space_implode "\n" (map (fn (s, ms) => s ^ ": " ^ commas (map
   2.279 +    string_of_mode ms)) modes));
   2.280 +
   2.281 +fun print_factors factors = message ("Factors:\n" ^
   2.282 +  space_implode "\n" (map (fn (s, fs) => s ^ ": " ^ string_of_factors [] fs) factors));
   2.283 +  
   2.284 +fun get_modes (Some (Modes x), _) = x
   2.285 +  | get_modes _ = ([], []);
   2.286 +
   2.287 +fun mk_ind_def thy gr dep names intrs =
   2.288 +  let val ids = map (mk_const_id (sign_of thy)) names
   2.289 +  in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ =>
   2.290 +    let
   2.291 +      fun process_prem factors (gr, t' as _ $ (Const ("op :", _) $ t $ u)) =
   2.292 +            (case strip_comb u of
   2.293 +               (Const (name, _), args) =>
   2.294 +                  (case InductivePackage.get_inductive thy name of
   2.295 +                     None => (gr, Sidecond t')
   2.296 +                   | Some ({names=names', ...}, {intrs=intrs', ...}) =>
   2.297 +                       (if names = names' then gr
   2.298 +                          else mk_ind_def thy gr (hd ids) names' intrs',
   2.299 +                        Prem (name, split_prod []
   2.300 +                          (the (assoc (factors, name))) t, args)))
   2.301 +             | _ => (gr, Sidecond t'))
   2.302 +        | process_prem factors (gr, _ $ (Const ("op =", _) $ t $ u)) =
   2.303 +            (gr, Prem ("eq", [t, u], []))
   2.304 +        | process_prem factors (gr, _ $ t) = (gr, Sidecond t);
   2.305 +
   2.306 +      fun add_clause factors ((clauses, gr), intr) =
   2.307 +        let
   2.308 +          val _ $ (_ $ t $ u) = Logic.strip_imp_concl intr;
   2.309 +          val (Const (name, _), args) = strip_comb u;
   2.310 +          val (gr', prems) = foldl_map (process_prem factors)
   2.311 +            (gr, Logic.strip_imp_prems intr);
   2.312 +        in
   2.313 +          (overwrite (clauses, (name, if_none (assoc (clauses, name)) [] @
   2.314 +             [(split_prod [] (the (assoc (factors, name))) t, prems)])), gr')
   2.315 +        end;
   2.316 +
   2.317 +      fun add_prod_factors (fs, x as _ $ (Const ("op :", _) $ t $ u)) =
   2.318 +            (case strip_comb u of
   2.319 +               (Const (name, _), _) =>
   2.320 +                 let val f = prod_factors [] t
   2.321 +                 in overwrite (fs, (name, f inter if_none (assoc (fs, name)) f)) end
   2.322 +             | _ => fs)
   2.323 +        | add_prod_factors (fs, _) = fs;
   2.324 +
   2.325 +      val intrs' = map (rename_term o #prop o rep_thm o standard) intrs;
   2.326 +      val factors = foldl add_prod_factors ([], flat (map (fn t =>
   2.327 +        Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs'));
   2.328 +      val (clauses, gr') = foldl (add_clause factors) (([], Graph.add_edge (hd ids, dep)
   2.329 +        (Graph.new_node (hd ids, (None, "")) gr)), intrs');
   2.330 +      val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs');
   2.331 +      val (_, args) = strip_comb u;
   2.332 +      val arg_vs = flat (map term_vs args);
   2.333 +      val extra_modes = ("eq", [[1], [2], [1,2]]) :: (flat (map
   2.334 +        (fst o get_modes o Graph.get_node gr') (Graph.all_preds gr' [hd ids])));
   2.335 +      val modes = infer_modes thy extra_modes arg_vs clauses;
   2.336 +      val _ = print_modes modes;
   2.337 +      val _ = print_factors factors;
   2.338 +      val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs') arg_vs
   2.339 +        (modes @ extra_modes) clauses;
   2.340 +    in
   2.341 +      (Graph.map_node (hd ids) (K (Some (Modes (modes, factors)), s)) gr'')
   2.342 +    end      
   2.343 +  end;
   2.344 +
   2.345 +fun mk_ind_call thy gr dep t u is_query = (case strip_comb u of
   2.346 +  (Const (s, _), args) => (case InductivePackage.get_inductive thy s of
   2.347 +       None => None
   2.348 +     | Some ({names, ...}, {intrs, ...}) =>
   2.349 +         let
   2.350 +          fun mk_mode (((ts, mode), i), Var _) = ((ts, mode), i+1)
   2.351 +            | mk_mode (((ts, mode), i), Free _) = ((ts, mode), i+1)
   2.352 +            | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1);
   2.353 +
   2.354 +           val gr1 = mk_ind_def thy gr dep names intrs;
   2.355 +           val (modes, factors) = pairself flat (ListPair.unzip
   2.356 +             (map (get_modes o Graph.get_node gr1) (Graph.all_preds gr1 [dep])));
   2.357 +           val ts = split_prod [] (the (assoc (factors, s))) t;
   2.358 +           val (ts', mode) = if is_query then
   2.359 +               fst (foldl mk_mode ((([], []), 1), ts))
   2.360 +             else (ts, 1 upto length ts);
   2.361 +           val _ = if mode mem the (assoc (modes, s)) then () else
   2.362 +             error ("No such mode for " ^ s ^ ": " ^ string_of_mode mode);
   2.363 +           val (gr2, in_ps) = foldl_map (fn (gr, t) =>
   2.364 +             invoke_codegen thy gr dep false t) (gr1, ts');
   2.365 +           val (gr3, arg_ps) = foldl_map (fn (gr, t) =>
   2.366 +             invoke_codegen thy gr dep true t) (gr2, args);
   2.367 +         in
   2.368 +           Some (gr3, Pretty.block (separate (Pretty.brk 1)
   2.369 +             (Pretty.str (modename thy s mode) :: arg_ps @ [mk_tuple in_ps])))
   2.370 +         end)
   2.371 +  | _ => None);
   2.372 +
   2.373 +fun inductive_codegen thy gr dep brack (Const ("op :", _) $ t $ u) =
   2.374 +      (case mk_ind_call thy gr dep t u false of
   2.375 +         None => None
   2.376 +       | Some (gr', call_p) => Some (gr', (if brack then parens else I)
   2.377 +           (Pretty.block [Pretty.str "nonempty (", call_p, Pretty.str ")"])))
   2.378 +  | inductive_codegen thy gr dep brack (Free ("query", _) $ (Const ("op :", _) $ t $ u)) =
   2.379 +      mk_ind_call thy gr dep t u true
   2.380 +  | inductive_codegen thy gr dep brack _ = None;
   2.381 +
   2.382 +val setup = [add_codegen "inductive" inductive_codegen];
   2.383 +
   2.384 +end;