src/HOL/Tools/inductive_codegen.ML
changeset 22642 bfdb29f11eb4
parent 22556 b067fdca022d
child 22661 f3ba63a2663e
equal deleted inserted replaced
22641:a5dc96fad632 22642:bfdb29f11eb4
    28   end;
    28   end;
    29 
    29 
    30 (**** theory data ****)
    30 (**** theory data ****)
    31 
    31 
    32 fun merge_rules tabs =
    32 fun merge_rules tabs =
    33   Symtab.join (fn _ => fn (ths, ths') =>
    33   Symtab.join (fn _ => AList.merge (Thm.eq_thm_prop) (K true)) tabs;
    34     gen_merge_lists (Thm.eq_thm_prop o pairself fst) ths ths') tabs;
       
    35 
    34 
    36 structure CodegenData = TheoryDataFun
    35 structure CodegenData = TheoryDataFun
    37 (struct
    36 (struct
    38   val name = "HOL/inductive_codegen";
    37   val name = "HOL/inductive_codegen";
    39   type T =
    38   type T =
    65       NONE => thyname_of_const s thy | SOME s => s);
    64       NONE => thyname_of_const s thy | SOME s => s);
    66   in (case Option.map strip_comb (try HOLogic.dest_Trueprop (concl_of thm)) of
    65   in (case Option.map strip_comb (try HOLogic.dest_Trueprop (concl_of thm)) of
    67       SOME (Const ("op =", _), [t, _]) => (case head_of t of
    66       SOME (Const ("op =", _), [t, _]) => (case head_of t of
    68         Const (s, _) =>
    67         Const (s, _) =>
    69           CodegenData.put {intros = intros, graph = graph,
    68           CodegenData.put {intros = intros, graph = graph,
    70              eqns = eqns |> Symtab.update
    69              eqns = eqns |> Symtab.map_default (s, [])
    71                (s, Symtab.lookup_list eqns s @ [(thm, thyname_of s)])} thy
    70                (AList.update Thm.eq_thm_prop (thm, thyname_of s))} thy
    72       | _ => (warn thm; thy))
    71       | _ => (warn thm; thy))
    73     | SOME (Const (s, _), _) =>
    72     | SOME (Const (s, _), _) =>
    74         let
    73         let
    75           val cs = foldr add_term_consts [] (prems_of thm);
    74           val cs = foldr add_term_consts [] (prems_of thm);
    76           val rules = Symtab.lookup_list intros s;
    75           val rules = Symtab.lookup_list intros s;
    81                  SOME (_, {raw_induct, ...}) => length (params_of raw_induct)
    80                  SOME (_, {raw_induct, ...}) => length (params_of raw_induct)
    82                | NONE => 0)
    81                | NONE => 0)
    83             | xs => snd (snd (snd (split_last xs)))))
    82             | xs => snd (snd (snd (split_last xs)))))
    84         in CodegenData.put
    83         in CodegenData.put
    85           {intros = intros |>
    84           {intros = intros |>
    86            Symtab.update (s, rules @ [(thm, (thyname_of s, nparms))]),
    85            Symtab.update (s, (AList.update Thm.eq_thm_prop
       
    86              (thm, (thyname_of s, nparms)) rules)),
    87            graph = foldr (uncurry (Graph.add_edge o pair s))
    87            graph = foldr (uncurry (Graph.add_edge o pair s))
    88              (Library.foldl add_node (graph, s :: cs)) cs,
    88              (Library.foldl add_node (graph, s :: cs)) cs,
    89            eqns = eqns} thy
    89            eqns = eqns} thy
    90         end
    90         end
    91     | _ => (warn thm; thy))
    91     | _ => (warn thm; thy))
    96   in case Symtab.lookup intros s of
    96   in case Symtab.lookup intros s of
    97       NONE => (case try (InductivePackage.the_inductive (ProofContext.init thy)) s of
    97       NONE => (case try (InductivePackage.the_inductive (ProofContext.init thy)) s of
    98         NONE => NONE
    98         NONE => NONE
    99       | SOME ({names, ...}, {intrs, raw_induct, ...}) =>
    99       | SOME ({names, ...}, {intrs, raw_induct, ...}) =>
   100           SOME (names, thyname_of_const s thy, length (params_of raw_induct),
   100           SOME (names, thyname_of_const s thy, length (params_of raw_induct),
   101             preprocess thy intrs))
   101             preprocess thy (rev intrs)))
   102     | SOME _ =>
   102     | SOME _ =>
   103         let
   103         let
   104           val SOME names = find_first
   104           val SOME names = find_first
   105             (fn xs => s mem xs) (Graph.strong_conn graph);
   105             (fn xs => member (op =) xs s) (Graph.strong_conn graph);
   106           val intrs = List.concat (map
   106           val intrs as (_, (thyname, nparms)) :: _ =
   107             (fn s => the (Symtab.lookup intros s)) names);
   107             maps (the o Symtab.lookup intros) names;
   108           val (_, (_, (thyname, nparms))) = split_last intrs
   108         in SOME (names, thyname, nparms, preprocess thy (map fst (rev intrs))) end
   109         in SOME (names, thyname, nparms, preprocess thy (map fst intrs)) end
       
   110   end;
   109   end;
   111 
   110 
   112 
   111 
   113 (**** check if a term contains only constructor functions ****)
   112 (**** check if a term contains only constructor functions ****)
   114 
   113 
   253       ~1 => true
   252       ~1 => true
   254     | i => (message ("Clause " ^ string_of_int (i+1) ^ " of " ^
   253     | i => (message ("Clause " ^ string_of_int (i+1) ^ " of " ^
   255       p ^ " violates mode " ^ string_of_mode m); false)) ms)
   254       p ^ " violates mode " ^ string_of_mode m); false)) ms)
   256   end;
   255   end;
   257 
   256 
   258 fun fixp f x =
   257 fun fixp f (x : (string * (int list option list * int list) list) list) =
   259   let val y = f x
   258   let val y = f x
   260   in if x = y then x else fixp f y end;
   259   in if x = y then x else fixp f y end;
   261 
   260 
   262 fun infer_modes thy extra_modes arities arg_vs preds = fixp (fn modes =>
   261 fun infer_modes thy extra_modes arities arg_vs preds = fixp (fn modes =>
   263   map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes)
   262   map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes)
   486         (ks @ [SOME k]))) arities));
   485         (ks @ [SOME k]))) arities));
   487 
   486 
   488 fun prep_intrs intrs = map (rename_term o #prop o rep_thm o standard) intrs;
   487 fun prep_intrs intrs = map (rename_term o #prop o rep_thm o standard) intrs;
   489 
   488 
   490 fun constrain cs [] = []
   489 fun constrain cs [] = []
   491   | constrain cs ((s, xs) :: ys) = (s, case AList.lookup (op =) cs s of
   490   | constrain cs ((s, xs) :: ys) = (s, case AList.lookup (op =) cs (s : string) of
   492       NONE => xs
   491       NONE => xs
   493     | SOME xs' => xs inter xs') :: constrain cs ys;
   492     | SOME xs' => xs inter xs') :: constrain cs ys;
   494 
   493 
   495 fun mk_extra_defs thy defs gr dep names module ts =
   494 fun mk_extra_defs thy defs gr dep names module ts =
   496   Library.foldl (fn (gr, name) =>
   495   Library.foldl (fn (gr, name) =>
   650              end handle TERM _ => mk_ind_call thy defs gr dep module true
   649              end handle TERM _ => mk_ind_call thy defs gr dep module true
   651                s T ts names thyname k intrs)
   650                s T ts names thyname k intrs)
   652       | _ => NONE)
   651       | _ => NONE)
   653     | SOME eqns =>
   652     | SOME eqns =>
   654         let
   653         let
   655           val (_, (_, thyname)) = split_last eqns;
   654           val (_, thyname) :: _ = eqns;
   656           val (gr', id) = mk_fun thy defs s (preprocess thy (map fst eqns))
   655           val (gr', id) = mk_fun thy defs s (preprocess thy (map fst (rev eqns)))
   657             dep module (if_library thyname module) gr;
   656             dep module (if_library thyname module) gr;
   658           val (gr'', ps) = foldl_map
   657           val (gr'', ps) = foldl_map
   659             (invoke_codegen thy defs dep module true) (gr', ts);
   658             (invoke_codegen thy defs dep module true) (gr', ts);
   660         in SOME (gr'', mk_app brack (Pretty.str id) ps)
   659         in SOME (gr'', mk_app brack (Pretty.str id) ps)
   661         end)
   660         end)