src/HOL/Tools/inductive_codegen.ML
changeset 11537 e007d35359c3
child 11539 0f17da240450
equal deleted inserted replaced
11536:6adf4d532679 11537:e007d35359c3
       
     1 (*  Title:      HOL/Tools/inductive_codegen.ML
       
     2     ID:         $Id$
       
     3     Author:     Stefan Berghofer
       
     4     Copyright   2000  TU Muenchen
       
     5 
       
     6 Code generator for inductive predicates
       
     7 *)
       
     8 
       
(* Public interface of the code generator for inductive predicates.
   The only export is the list of theory setup functions; in the
   implementing structure it registers the "inductive" code generator
   via add_codegen. *)
signature INDUCTIVE_CODEGEN =
sig
  val setup : (theory -> theory) list
end;
       
    13 
       
    14 structure InductiveCodegen : INDUCTIVE_CODEGEN =
       
    15 struct
       
    16 
       
    17 open Codegen;
       
    18 
       
    19 exception Modes of (string * int list list) list * (string * int list list) list;
       
    20 
       
    21 datatype indprem = Prem of string * term list * term list
       
    22                  | Sidecond of term;
       
    23 
       
    24 fun prod_factors p (Const ("Pair", _) $ t $ u) =
       
    25       p :: prod_factors (1::p) t @ prod_factors (2::p) u
       
    26   | prod_factors p _ = [];
       
    27 
       
    28 fun split_prod p ps t = if p mem ps then (case t of
       
    29        Const ("Pair", _) $ t $ u =>
       
    30          split_prod (1::p) ps t @ split_prod (2::p) ps u
       
    31      | _ => error "Inconsistent use of products") else [t];
       
    32 
       
    33 fun string_of_factors p ps = if p mem ps then
       
    34     "(" ^ string_of_factors (1::p) ps ^ ", " ^ string_of_factors (2::p) ps ^ ")"
       
    35   else "_";
       
    36 
       
    37 (**** check if a term contains only constructor functions ****)
       
    38 
       
    39 fun is_constrt thy =
       
    40   let
       
    41     val cnstrs = flat (flat (map
       
    42       (map (fn (_, (_, _, cs)) => map (apsnd length) cs) o #descr o snd)
       
    43       (Symtab.dest (DatatypePackage.get_datatypes thy))));
       
    44     fun check t = (case strip_comb t of
       
    45         (Var _, []) => true
       
    46       | (Const (s, _), ts) => (case assoc (cnstrs, s) of
       
    47             None => false
       
    48           | Some i => length ts = i andalso forall check ts)
       
    49       | _ => false)
       
    50   in check end;
       
    51 
       
    52 (**** check if a type is an equality type (i.e. doesn't contain fun) ****)
       
    53 
       
    54 fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts
       
    55   | is_eqT _ = true;
       
    56 
       
    57 (**** mode inference ****)
       
    58 
       
    59 val term_vs = map (fst o fst o dest_Var) o term_vars;
       
    60 val terms_vs = distinct o flat o (map term_vs);
       
    61 
       
    62 (** collect all Vars in a term (with duplicates!) **)
       
    63 fun term_vTs t = map (apfst fst o dest_Var)
       
    64   (filter is_Var (foldl_aterms (op :: o Library.swap) ([], t)));
       
    65 
       
    66 fun known_args _ _ [] = []
       
    67   | known_args vs i (t::ts) = if term_vs t subset vs then i::known_args vs (i+1) ts
       
    68       else known_args vs (i+1) ts;
       
    69 
       
    70 fun get_args _ _ [] = ([], [])
       
    71   | get_args is i (x::xs) = (if i mem is then apfst else apsnd) (cons x)
       
    72       (get_args is (i+1) xs);
       
    73 
       
    74 fun merge xs [] = xs
       
    75   | merge [] ys = ys
       
    76   | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys)
       
    77       else y::merge (x::xs) ys;
       
    78 
       
    79 fun subsets i j = if i <= j then
       
    80        let val is = subsets (i+1) j
       
    81        in merge (map (fn ks => i::ks) is) is end
       
    82      else [[]];
       
    83 
       
    84 fun select_mode_prem thy modes vs ps =
       
    85   find_first (is_some o snd) (ps ~~ map
       
    86     (fn Prem (s, us, args) => find_first (fn is =>
       
    87           let
       
    88             val (_, out_ts) = get_args is 1 us;
       
    89             val vTs = flat (map term_vTs out_ts);
       
    90             val dupTs = map snd (duplicates vTs) @
       
    91               mapfilter (curry assoc vTs) vs;
       
    92           in
       
    93             is subset known_args vs 1 us andalso
       
    94             forall (is_constrt thy) (snd (get_args is 1 us)) andalso
       
    95             terms_vs args subset vs andalso
       
    96             forall is_eqT dupTs
       
    97           end)
       
    98             (the (assoc (modes, s)))
       
    99       | Sidecond t => if term_vs t subset vs then Some [] else None) ps);
       
   100 
       
   101 fun check_mode_clause thy arg_vs modes mode (ts, ps) =
       
   102   let
       
   103     fun check_mode_prems vs [] = Some vs
       
   104       | check_mode_prems vs ps = (case select_mode_prem thy modes vs ps of
       
   105           None => None
       
   106         | Some (x, _) => check_mode_prems
       
   107             (case x of Prem (_, us, _) => vs union terms_vs us | _ => vs)
       
   108             (filter_out (equal x) ps));
       
   109     val (in_ts', _) = get_args mode 1 ts;
       
   110     val in_ts = filter (is_constrt thy) in_ts';
       
   111     val in_vs = terms_vs in_ts;
       
   112     val concl_vs = terms_vs ts
       
   113   in
       
   114     forall is_eqT (map snd (duplicates (flat (map term_vTs in_ts')))) andalso
       
   115     (case check_mode_prems (arg_vs union in_vs) ps of
       
   116        None => false
       
   117      | Some vs => concl_vs subset vs)
       
   118   end;
       
   119 
       
   120 fun check_modes_pred thy arg_vs preds modes (p, ms) =
       
   121   let val Some rs = assoc (preds, p)
       
   122   in (p, filter (fn m => forall (check_mode_clause thy arg_vs modes m) rs) ms) end
       
   123 
       
   124 fun fixp f x =
       
   125   let val y = f x
       
   126   in if x = y then x else fixp f y end;
       
   127 
       
   128 fun infer_modes thy extra_modes arg_vs preds = fixp (fn modes =>
       
   129   map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes)
       
   130     (map (fn (s, (ts, _)::_) => (s, subsets 1 (length ts))) preds);
       
   131 
       
   132 (**** code generation ****)
       
   133 
       
   134 fun mk_eq (x::xs) =
       
   135   let fun mk_eqs _ [] = []
       
   136         | mk_eqs a (b::cs) = Pretty.str (a ^ " = " ^ b) :: mk_eqs b cs
       
   137   in mk_eqs x xs end;
       
   138 
       
   139 fun mk_tuple xs = Pretty.block (Pretty.str "(" ::
       
   140   flat (separate [Pretty.str ",", Pretty.brk 1] (map single xs)) @
       
   141   [Pretty.str ")"]);
       
   142 
       
   143 fun mk_v ((names, vs), s) = (case assoc (vs, s) of
       
   144       None => ((names, (s, [s])::vs), s)
       
   145     | Some xs => let val s' = variant names s in
       
   146         ((s'::names, overwrite (vs, (s, s'::xs))), s') end);
       
   147 
       
   148 fun distinct_v (nvs, Var ((s, 0), T)) =
       
   149       apsnd (Var o rpair T o rpair 0) (mk_v (nvs, s))
       
   150   | distinct_v (nvs, t $ u) =
       
   151       let
       
   152         val (nvs', t') = distinct_v (nvs, t);
       
   153         val (nvs'', u') = distinct_v (nvs', u);
       
   154       in (nvs'', t' $ u') end
       
   155   | distinct_v x = x;
       
   156 
       
   157 fun compile_match nvs eq_ps out_ps success_p fail_p =
       
   158   let val eqs = flat (separate [Pretty.str " andalso", Pretty.brk 1]
       
   159     (map single (flat (map (mk_eq o snd) nvs) @ eq_ps)));
       
   160   in
       
   161     Pretty.block
       
   162      ([Pretty.str "(fn ", mk_tuple out_ps, Pretty.str " =>", Pretty.brk 1] @
       
   163       (Pretty.block ((if eqs=[] then [] else Pretty.str "if " ::
       
   164          [Pretty.block eqs, Pretty.brk 1, Pretty.str "then "]) @
       
   165          (success_p ::
       
   166           (if eqs=[] then [] else [Pretty.brk 1, Pretty.str "else ", fail_p]))) ::
       
   167        [Pretty.brk 1, Pretty.str "| _ => ", fail_p, Pretty.str ")"]))
       
   168   end;
       
   169 
       
   170 fun modename thy s mode = space_implode "_"
       
   171   (mk_const_id (sign_of thy) s :: map string_of_int mode);
       
   172 
       
   173 fun compile_clause thy gr dep all_vs arg_vs modes mode (ts, ps) =
       
   174   let
       
   175     fun check_constrt ((names, eqs), t) =
       
   176       if is_constrt thy t then ((names, eqs), t) else
       
   177         let val s = variant names "x";
       
   178         in ((s::names, (s, t)::eqs), Var ((s, 0), fastype_of t)) end;
       
   179 
       
   180     val (in_ts, out_ts) = get_args mode 1 ts;
       
   181     val ((all_vs', eqs), in_ts') =
       
   182       foldl_map check_constrt ((all_vs, []), in_ts);
       
   183 
       
   184     fun compile_prems out_ts' vs names gr [] =
       
   185           let
       
   186             val (gr2, out_ps) = foldl_map (fn (gr, t) =>
       
   187               invoke_codegen thy gr dep false t) (gr, out_ts);
       
   188             val (gr3, eq_ps) = foldl_map (fn (gr, (s, t)) =>
       
   189               apsnd (Pretty.block o cons (Pretty.str (s ^ " = ")) o single)
       
   190                 (invoke_codegen thy gr dep false t)) (gr2, eqs);
       
   191             val (nvs, out_ts'') = foldl_map distinct_v
       
   192               ((names, map (fn x => (x, [x])) vs), out_ts');
       
   193             val (gr4, out_ps') = foldl_map (fn (gr, t) =>
       
   194               invoke_codegen thy gr dep false t) (gr3, out_ts'');
       
   195           in
       
   196             (gr4, compile_match (snd nvs) eq_ps out_ps'
       
   197               (Pretty.block [Pretty.str "Seq.single", Pretty.brk 1, mk_tuple out_ps])
       
   198               (Pretty.str "Seq.empty"))
       
   199           end
       
   200       | compile_prems out_ts vs names gr ps =
       
   201           let
       
   202             val vs' = distinct (flat (vs :: map term_vs out_ts));
       
   203             val Some (p, Some mode') =
       
   204               select_mode_prem thy modes (arg_vs union vs') ps;
       
   205             val ps' = filter_out (equal p) ps;
       
   206           in
       
   207             (case p of
       
   208                Prem (s, us, args) =>
       
   209                  let
       
   210                    val (in_ts, out_ts') = get_args mode' 1 us;
       
   211                    val (gr1, in_ps) = foldl_map (fn (gr, t) =>
       
   212                      invoke_codegen thy gr dep false t) (gr, in_ts);
       
   213                    val (gr2, arg_ps) = foldl_map (fn (gr, t) =>
       
   214                      invoke_codegen thy gr dep true t) (gr1, args);
       
   215                    val (nvs, out_ts'') = foldl_map distinct_v
       
   216                      ((names, map (fn x => (x, [x])) vs), out_ts);
       
   217                    val (gr3, out_ps) = foldl_map (fn (gr, t) =>
       
   218                      invoke_codegen thy gr dep false t) (gr2, out_ts'')
       
   219                    val (gr4, rest) = compile_prems out_ts' vs' (fst nvs) gr3 ps';
       
   220                  in
       
   221                    (gr4, compile_match (snd nvs) [] out_ps
       
   222                       (Pretty.block (separate (Pretty.brk 1)
       
   223                         (Pretty.str (modename thy s mode') :: arg_ps) @
       
   224                          [Pretty.brk 1, mk_tuple in_ps,
       
   225                           Pretty.str " :->", Pretty.brk 1, rest]))
       
   226                       (Pretty.str "Seq.empty"))
       
   227                  end
       
   228              | Sidecond t =>
       
   229                  let
       
   230                    val (gr1, side_p) = invoke_codegen thy gr dep true t;
       
   231                    val (nvs, out_ts') = foldl_map distinct_v
       
   232                      ((names, map (fn x => (x, [x])) vs), out_ts);
       
   233                    val (gr2, out_ps) = foldl_map (fn (gr, t) =>
       
   234                      invoke_codegen thy gr dep false t) (gr1, out_ts')
       
   235                    val (gr3, rest) = compile_prems [] vs' (fst nvs) gr2 ps';
       
   236                  in
       
   237                    (gr3, compile_match (snd nvs) [] out_ps
       
   238                       (Pretty.block [Pretty.str "?? ", side_p,
       
   239                         Pretty.str " :->", Pretty.brk 1, rest])
       
   240                       (Pretty.str "Seq.empty"))
       
   241                  end)
       
   242           end;
       
   243 
       
   244     val (gr', prem_p) = compile_prems in_ts' [] all_vs' gr ps;
       
   245   in
       
   246     (gr', Pretty.block [Pretty.str "Seq.single inp :->", Pretty.brk 1, prem_p])
       
   247   end;
       
   248 
       
   249 fun compile_pred thy gr dep prfx all_vs arg_vs modes s cls mode =
       
   250   let val (gr', cl_ps) = foldl_map (fn (gr, cl) =>
       
   251     compile_clause thy gr dep all_vs arg_vs modes mode cl) (gr, cls)
       
   252   in
       
   253     ((gr', "and "), Pretty.block
       
   254       ([Pretty.block (separate (Pretty.brk 1)
       
   255          (Pretty.str (prfx ^ modename thy s mode) :: map Pretty.str arg_vs) @
       
   256          [Pretty.str " inp ="]),
       
   257         Pretty.brk 1] @
       
   258        flat (separate [Pretty.str " ++", Pretty.brk 1] (map single cl_ps))))
       
   259   end;
       
   260 
       
   261 fun compile_preds thy gr dep all_vs arg_vs modes preds =
       
   262   let val ((gr', _), prs) = foldl_map (fn ((gr, prfx), (s, cls)) =>
       
   263     foldl_map (fn ((gr', prfx'), mode) =>
       
   264       compile_pred thy gr' dep prfx' all_vs arg_vs modes s cls mode)
       
   265         ((gr, prfx), the (assoc (modes, s)))) ((gr, "fun "), preds)
       
   266   in
       
   267     (gr', space_implode "\n\n" (map Pretty.string_of (flat prs)) ^ ";\n\n")
       
   268   end;
       
   269 
       
   270 (**** processing of introduction rules ****)
       
   271 
       
   272 val string_of_mode = enclose "[" "]" o commas o map string_of_int;
       
   273 
       
   274 fun print_modes modes = message ("Inferred modes:\n" ^
       
   275   space_implode "\n" (map (fn (s, ms) => s ^ ": " ^ commas (map
       
   276     string_of_mode ms)) modes));
       
   277 
       
   278 fun print_factors factors = message ("Factors:\n" ^
       
   279   space_implode "\n" (map (fn (s, fs) => s ^ ": " ^ string_of_factors [] fs) factors));
       
   280   
       
   281 fun get_modes (Some (Modes x), _) = x
       
   282   | get_modes _ = ([], []);
       
   283 
       
   284 fun mk_ind_def thy gr dep names intrs =
       
   285   let val ids = map (mk_const_id (sign_of thy)) names
       
   286   in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ =>
       
   287     let
       
   288       fun process_prem factors (gr, t' as _ $ (Const ("op :", _) $ t $ u)) =
       
   289             (case strip_comb u of
       
   290                (Const (name, _), args) =>
       
   291                   (case InductivePackage.get_inductive thy name of
       
   292                      None => (gr, Sidecond t')
       
   293                    | Some ({names=names', ...}, {intrs=intrs', ...}) =>
       
   294                        (if names = names' then gr
       
   295                           else mk_ind_def thy gr (hd ids) names' intrs',
       
   296                         Prem (name, split_prod []
       
   297                           (the (assoc (factors, name))) t, args)))
       
   298              | _ => (gr, Sidecond t'))
       
   299         | process_prem factors (gr, _ $ (Const ("op =", _) $ t $ u)) =
       
   300             (gr, Prem ("eq", [t, u], []))
       
   301         | process_prem factors (gr, _ $ t) = (gr, Sidecond t);
       
   302 
       
   303       fun add_clause factors ((clauses, gr), intr) =
       
   304         let
       
   305           val _ $ (_ $ t $ u) = Logic.strip_imp_concl intr;
       
   306           val (Const (name, _), args) = strip_comb u;
       
   307           val (gr', prems) = foldl_map (process_prem factors)
       
   308             (gr, Logic.strip_imp_prems intr);
       
   309         in
       
   310           (overwrite (clauses, (name, if_none (assoc (clauses, name)) [] @
       
   311              [(split_prod [] (the (assoc (factors, name))) t, prems)])), gr')
       
   312         end;
       
   313 
       
   314       fun add_prod_factors (fs, x as _ $ (Const ("op :", _) $ t $ u)) =
       
   315             (case strip_comb u of
       
   316                (Const (name, _), _) =>
       
   317                  let val f = prod_factors [] t
       
   318                  in overwrite (fs, (name, f inter if_none (assoc (fs, name)) f)) end
       
   319              | _ => fs)
       
   320         | add_prod_factors (fs, _) = fs;
       
   321 
       
   322       val intrs' = map (rename_term o #prop o rep_thm o standard) intrs;
       
   323       val factors = foldl add_prod_factors ([], flat (map (fn t =>
       
   324         Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs'));
       
   325       val (clauses, gr') = foldl (add_clause factors) (([], Graph.add_edge (hd ids, dep)
       
   326         (Graph.new_node (hd ids, (None, "")) gr)), intrs');
       
   327       val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs');
       
   328       val (_, args) = strip_comb u;
       
   329       val arg_vs = flat (map term_vs args);
       
   330       val extra_modes = ("eq", [[1], [2], [1,2]]) :: (flat (map
       
   331         (fst o get_modes o Graph.get_node gr') (Graph.all_preds gr' [hd ids])));
       
   332       val modes = infer_modes thy extra_modes arg_vs clauses;
       
   333       val _ = print_modes modes;
       
   334       val _ = print_factors factors;
       
   335       val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs') arg_vs
       
   336         (modes @ extra_modes) clauses;
       
   337     in
       
   338       (Graph.map_node (hd ids) (K (Some (Modes (modes, factors)), s)) gr'')
       
   339     end      
       
   340   end;
       
   341 
       
   342 fun mk_ind_call thy gr dep t u is_query = (case strip_comb u of
       
   343   (Const (s, _), args) => (case InductivePackage.get_inductive thy s of
       
   344        None => None
       
   345      | Some ({names, ...}, {intrs, ...}) =>
       
   346          let
       
   347           fun mk_mode (((ts, mode), i), Var _) = ((ts, mode), i+1)
       
   348             | mk_mode (((ts, mode), i), Free _) = ((ts, mode), i+1)
       
   349             | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1);
       
   350 
       
   351            val gr1 = mk_ind_def thy gr dep names intrs;
       
   352            val (modes, factors) = pairself flat (ListPair.unzip
       
   353              (map (get_modes o Graph.get_node gr1) (Graph.all_preds gr1 [dep])));
       
   354            val ts = split_prod [] (the (assoc (factors, s))) t;
       
   355            val (ts', mode) = if is_query then
       
   356                fst (foldl mk_mode ((([], []), 1), ts))
       
   357              else (ts, 1 upto length ts);
       
   358            val _ = if mode mem the (assoc (modes, s)) then () else
       
   359              error ("No such mode for " ^ s ^ ": " ^ string_of_mode mode);
       
   360            val (gr2, in_ps) = foldl_map (fn (gr, t) =>
       
   361              invoke_codegen thy gr dep false t) (gr1, ts');
       
   362            val (gr3, arg_ps) = foldl_map (fn (gr, t) =>
       
   363              invoke_codegen thy gr dep true t) (gr2, args);
       
   364          in
       
   365            Some (gr3, Pretty.block (separate (Pretty.brk 1)
       
   366              (Pretty.str (modename thy s mode) :: arg_ps @ [mk_tuple in_ps])))
       
   367          end)
       
   368   | _ => None);
       
   369 
       
   370 fun inductive_codegen thy gr dep brack (Const ("op :", _) $ t $ u) =
       
   371       (case mk_ind_call thy gr dep t u false of
       
   372          None => None
       
   373        | Some (gr', call_p) => Some (gr', (if brack then parens else I)
       
   374            (Pretty.block [Pretty.str "nonempty (", call_p, Pretty.str ")"])))
       
   375   | inductive_codegen thy gr dep brack (Free ("query", _) $ (Const ("op :", _) $ t $ u)) =
       
   376       mk_ind_call thy gr dep t u true
       
   377   | inductive_codegen thy gr dep brack _ = None;
       
   378 
       
   379 val setup = [add_codegen "inductive" inductive_codegen];
       
   380 
       
   381 end;