src/HOL/Tools/inductive_codegen.ML
changeset 16645 a152d6b21c31
parent 16424 18a07ad8fea8
child 16861 7446b4be013b
equal deleted inserted replaced
16644:701218c1301c 16645:a152d6b21c31
     5 Code generator for inductive predicates.
     5 Code generator for inductive predicates.
     6 *)
     6 *)
     7 
     7 
     8 signature INDUCTIVE_CODEGEN =
     8 signature INDUCTIVE_CODEGEN =
     9 sig
     9 sig
    10   val add : theory attribute
    10   val add : string option -> theory attribute
    11   val setup : (theory -> theory) list
    11   val setup : (theory -> theory) list
    12 end;
    12 end;
    13 
    13 
    14 structure InductiveCodegen : INDUCTIVE_CODEGEN =
    14 structure InductiveCodegen : INDUCTIVE_CODEGEN =
    15 struct
    15 struct
    20 
    20 
    21 structure CodegenData = TheoryDataFun
    21 structure CodegenData = TheoryDataFun
    22 (struct
    22 (struct
    23   val name = "HOL/inductive_codegen";
    23   val name = "HOL/inductive_codegen";
    24   type T =
    24   type T =
    25     {intros : thm list Symtab.table,
    25     {intros : (thm * string) list Symtab.table,
    26      graph : unit Graph.T,
    26      graph : unit Graph.T,
    27      eqns : thm list Symtab.table};
    27      eqns : (thm * string) list Symtab.table};
    28   val empty =
    28   val empty =
    29     {intros = Symtab.empty, graph = Graph.empty, eqns = Symtab.empty};
    29     {intros = Symtab.empty, graph = Graph.empty, eqns = Symtab.empty};
    30   val copy = I;
    30   val copy = I;
    31   val extend = I;
    31   val extend = I;
    32   fun merge _ ({intros=intros1, graph=graph1, eqns=eqns1},
    32   fun merge _ ({intros=intros1, graph=graph1, eqns=eqns1},
    33     {intros=intros2, graph=graph2, eqns=eqns2}) =
    33     {intros=intros2, graph=graph2, eqns=eqns2}) =
    34     {intros = Symtab.merge_multi Drule.eq_thm_prop (intros1, intros2),
    34     {intros = Symtab.merge_multi (Drule.eq_thm_prop o pairself fst)
       
    35        (intros1, intros2),
    35      graph = Graph.merge (K true) (graph1, graph2),
    36      graph = Graph.merge (K true) (graph1, graph2),
    36      eqns = Symtab.merge_multi Drule.eq_thm_prop (eqns1, eqns2)};
    37      eqns = Symtab.merge_multi (Drule.eq_thm_prop o pairself fst)
       
    38        (eqns1, eqns2)};
    37   fun print _ _ = ();
    39   fun print _ _ = ();
    38 end);
    40 end);
    39 
    41 
    40 
    42 
    41 fun warn thm = warning ("InductiveCodegen: Not a proper clause:\n" ^
    43 fun warn thm = warning ("InductiveCodegen: Not a proper clause:\n" ^
    42   string_of_thm thm);
    44   string_of_thm thm);
    43 
    45 
    44 fun add_node (g, x) = Graph.new_node (x, ()) g handle Graph.DUP _ => g;
    46 fun add_node (g, x) = Graph.new_node (x, ()) g handle Graph.DUP _ => g;
    45 
    47 
    46 fun add (p as (thy, thm)) =
    48 fun add optmod (p as (thy, thm)) =
    47   let val {intros, graph, eqns} = CodegenData.get thy;
    49   let
       
    50     val {intros, graph, eqns} = CodegenData.get thy;
       
    51     fun thyname_of s = (case optmod of
       
    52       NONE => thyname_of_const s thy | SOME s => s);
    48   in (case concl_of thm of
    53   in (case concl_of thm of
    49       _ $ (Const ("op :", _) $ _ $ t) => (case head_of t of
    54       _ $ (Const ("op :", _) $ _ $ t) => (case head_of t of
    50         Const (s, _) =>
    55         Const (s, _) =>
    51           let val cs = foldr add_term_consts [] (prems_of thm)
    56           let val cs = foldr add_term_consts [] (prems_of thm)
    52           in (CodegenData.put
    57           in (CodegenData.put
    53             {intros = Symtab.update ((s,
    58             {intros = Symtab.update ((s,
    54                getOpt (Symtab.lookup (intros, s), []) @ [thm]), intros),
    59                getOpt (Symtab.lookup (intros, s), []) @
       
    60                  [(thm, thyname_of s)]), intros),
    55              graph = foldr (uncurry (Graph.add_edge o pair s))
    61              graph = foldr (uncurry (Graph.add_edge o pair s))
    56                (Library.foldl add_node (graph, s :: cs)) cs,
    62                (Library.foldl add_node (graph, s :: cs)) cs,
    57              eqns = eqns} thy, thm)
    63              eqns = eqns} thy, thm)
    58           end
    64           end
    59       | _ => (warn thm; p))
    65       | _ => (warn thm; p))
    60     | _ $ (Const ("op =", _) $ t $ _) => (case head_of t of
    66     | _ $ (Const ("op =", _) $ t $ _) => (case head_of t of
    61         Const (s, _) =>
    67         Const (s, _) =>
    62           (CodegenData.put {intros = intros, graph = graph,
    68           (CodegenData.put {intros = intros, graph = graph,
    63              eqns = Symtab.update ((s,
    69              eqns = Symtab.update ((s,
    64                getOpt (Symtab.lookup (eqns, s), []) @ [thm]), eqns)} thy, thm)
    70                getOpt (Symtab.lookup (eqns, s), []) @
       
    71                  [(thm, thyname_of s)]), eqns)} thy, thm)
    65       | _ => (warn thm; p))
    72       | _ => (warn thm; p))
    66     | _ => (warn thm; p))
    73     | _ => (warn thm; p))
    67   end;
    74   end;
    68 
    75 
    69 fun get_clauses thy s =
    76 fun get_clauses thy s =
    70   let val {intros, graph, ...} = CodegenData.get thy
    77   let val {intros, graph, ...} = CodegenData.get thy
    71   in case Symtab.lookup (intros, s) of
    78   in case Symtab.lookup (intros, s) of
    72       NONE => (case InductivePackage.get_inductive thy s of
    79       NONE => (case InductivePackage.get_inductive thy s of
    73         NONE => NONE
    80         NONE => NONE
    74       | SOME ({names, ...}, {intrs, ...}) => SOME (names, preprocess thy intrs))
    81       | SOME ({names, ...}, {intrs, ...}) =>
       
    82           SOME (names, thyname_of_const s thy,
       
    83             preprocess thy intrs))
    75     | SOME _ =>
    84     | SOME _ =>
    76         let val SOME names = find_first
    85         let
    77           (fn xs => s mem xs) (Graph.strong_conn graph)
    86           val SOME names = find_first
    78         in SOME (names, preprocess thy
    87             (fn xs => s mem xs) (Graph.strong_conn graph);
    79           (List.concat (map (fn s => valOf (Symtab.lookup (intros, s))) names)))
    88           val intrs = List.concat (map
    80         end
    89             (fn s => valOf (Symtab.lookup (intros, s))) names);
       
    90           val (_, (_, thyname)) = split_last intrs
       
    91         in SOME (names, thyname, preprocess thy (map fst intrs)) end
    81   end;
    92   end;
    82 
    93 
    83 
    94 
    84 (**** improper tuples ****)
    95 (**** improper tuples ****)
    85 
    96 
   362        (if can_fail then
   373        (if can_fail then
   363           [Pretty.brk 1, Pretty.str "| _ => Seq.empty)"]
   374           [Pretty.brk 1, Pretty.str "| _ => Seq.empty)"]
   364         else [Pretty.str ")"])))
   375         else [Pretty.str ")"])))
   365   end;
   376   end;
   366 
   377 
   367 fun modename thy s (iss, is) = space_implode "__"
   378 fun strip_spaces s = implode (fst (take_suffix (equal " ") (explode s)));
   368   (mk_const_id (sign_of thy) s ::
   379 
       
   380 fun modename thy thyname thyname' s (iss, is) = space_implode "__"
       
   381   (mk_const_id (sign_of thy) thyname thyname' (strip_spaces s) ::
   369     map (space_implode "_" o map string_of_int) (List.mapPartial I iss @ [is]));
   382     map (space_implode "_" o map string_of_int) (List.mapPartial I iss @ [is]));
   370 
   383 
   371 fun compile_expr thy dep brack (gr, (NONE, t)) =
   384 fun compile_expr thy defs dep thyname brack thynames (gr, (NONE, t)) =
   372       apsnd single (invoke_codegen thy dep brack (gr, t))
   385       apsnd single (invoke_codegen thy defs dep thyname brack (gr, t))
   373   | compile_expr _ _ _ (gr, (SOME _, Var ((name, _), _))) =
   386   | compile_expr _ _ _ _ _ _ (gr, (SOME _, Var ((name, _), _))) =
   374       (gr, [Pretty.str name])
   387       (gr, [Pretty.str name])
   375   | compile_expr thy dep brack (gr, (SOME (Mode (mode, ms)), t)) =
   388   | compile_expr thy defs dep thyname brack thynames (gr, (SOME (Mode (mode, ms)), t)) =
   376       let
   389       let
   377         val (Const (name, _), args) = strip_comb t;
   390         val (Const (name, _), args) = strip_comb t;
   378         val (gr', ps) = foldl_map
   391         val (gr', ps) = foldl_map
   379           (compile_expr thy dep true) (gr, ms ~~ args);
   392           (compile_expr thy defs dep thyname true thynames) (gr, ms ~~ args);
   380       in (gr', (if brack andalso not (null ps) then
   393       in (gr', (if brack andalso not (null ps) then
   381         single o parens o Pretty.block else I)
   394         single o parens o Pretty.block else I)
   382           (List.concat (separate [Pretty.brk 1]
   395           (List.concat (separate [Pretty.brk 1]
   383             ([Pretty.str (modename thy name mode)] :: ps))))
   396             ([Pretty.str (modename thy thyname
       
   397                 (if name = "op =" then ""
       
   398                  else the (assoc (thynames, name))) name mode)] :: ps))))
   384       end;
   399       end;
   385 
   400 
   386 fun compile_clause thy gr dep all_vs arg_vs modes (iss, is) (ts, ps) =
   401 fun compile_clause thy defs gr dep thyname all_vs arg_vs modes thynames (iss, is) (ts, ps) =
   387   let
   402   let
   388     val modes' = modes @ List.mapPartial
   403     val modes' = modes @ List.mapPartial
   389       (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
   404       (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
   390         (arg_vs ~~ iss);
   405         (arg_vs ~~ iss);
   391 
   406 
   394         let val s = variant names "x";
   409         let val s = variant names "x";
   395         in ((s::names, (s, t)::eqs), Var ((s, 0), fastype_of t)) end;
   410         in ((s::names, (s, t)::eqs), Var ((s, 0), fastype_of t)) end;
   396 
   411 
   397     fun compile_eq (gr, (s, t)) =
   412     fun compile_eq (gr, (s, t)) =
   398       apsnd (Pretty.block o cons (Pretty.str (s ^ " = ")) o single)
   413       apsnd (Pretty.block o cons (Pretty.str (s ^ " = ")) o single)
   399         (invoke_codegen thy dep false (gr, t));
   414         (invoke_codegen thy defs dep thyname false (gr, t));
   400 
   415 
   401     val (in_ts, out_ts) = get_args is 1 ts;
   416     val (in_ts, out_ts) = get_args is 1 ts;
   402     val ((all_vs', eqs), in_ts') =
   417     val ((all_vs', eqs), in_ts') =
   403       foldl_map check_constrt ((all_vs, []), in_ts);
   418       foldl_map check_constrt ((all_vs, []), in_ts);
   404 
   419 
   407         | Var ((s, _), _) => s mem arg_vs);
   422         | Var ((s, _), _) => s mem arg_vs);
   408 
   423 
   409     fun compile_prems out_ts' vs names gr [] =
   424     fun compile_prems out_ts' vs names gr [] =
   410           let
   425           let
   411             val (gr2, out_ps) = foldl_map
   426             val (gr2, out_ps) = foldl_map
   412               (invoke_codegen thy dep false) (gr, out_ts);
   427               (invoke_codegen thy defs dep thyname false) (gr, out_ts);
   413             val (gr3, eq_ps) = foldl_map compile_eq (gr2, eqs);
   428             val (gr3, eq_ps) = foldl_map compile_eq (gr2, eqs);
   414             val ((names', eqs'), out_ts'') =
   429             val ((names', eqs'), out_ts'') =
   415               foldl_map check_constrt ((names, []), out_ts');
   430               foldl_map check_constrt ((names, []), out_ts');
   416             val (nvs, out_ts''') = foldl_map distinct_v
   431             val (nvs, out_ts''') = foldl_map distinct_v
   417               ((names', map (fn x => (x, [x])) vs), out_ts'');
   432               ((names', map (fn x => (x, [x])) vs), out_ts'');
   418             val (gr4, out_ps') = foldl_map
   433             val (gr4, out_ps') = foldl_map
   419               (invoke_codegen thy dep false) (gr3, out_ts''');
   434               (invoke_codegen thy defs dep thyname false) (gr3, out_ts''');
   420             val (gr5, eq_ps') = foldl_map compile_eq (gr4, eqs')
   435             val (gr5, eq_ps') = foldl_map compile_eq (gr4, eqs')
   421           in
   436           in
   422             (gr5, compile_match (snd nvs) (eq_ps @ eq_ps') out_ps'
   437             (gr5, compile_match (snd nvs) (eq_ps @ eq_ps') out_ps'
   423               (Pretty.block [Pretty.str "Seq.single", Pretty.brk 1, mk_tuple out_ps])
   438               (Pretty.block [Pretty.str "Seq.single", Pretty.brk 1, mk_tuple out_ps])
   424               (exists (not o is_exhaustive) out_ts'''))
   439               (exists (not o is_exhaustive) out_ts'''))
   432             val ((names', eqs), out_ts') =
   447             val ((names', eqs), out_ts') =
   433               foldl_map check_constrt ((names, []), out_ts);
   448               foldl_map check_constrt ((names, []), out_ts);
   434             val (nvs, out_ts'') = foldl_map distinct_v
   449             val (nvs, out_ts'') = foldl_map distinct_v
   435               ((names', map (fn x => (x, [x])) vs), out_ts');
   450               ((names', map (fn x => (x, [x])) vs), out_ts');
   436             val (gr0, out_ps) = foldl_map
   451             val (gr0, out_ps) = foldl_map
   437               (invoke_codegen thy dep false) (gr, out_ts'');
   452               (invoke_codegen thy defs dep thyname false) (gr, out_ts'');
   438             val (gr1, eq_ps) = foldl_map compile_eq (gr0, eqs)
   453             val (gr1, eq_ps) = foldl_map compile_eq (gr0, eqs)
   439           in
   454           in
   440             (case p of
   455             (case p of
   441                Prem (us, t) =>
   456                Prem (us, t) =>
   442                  let
   457                  let
   443                    val (in_ts, out_ts''') = get_args js 1 us;
   458                    val (in_ts, out_ts''') = get_args js 1 us;
   444                    val (gr2, in_ps) = foldl_map
   459                    val (gr2, in_ps) = foldl_map
   445                      (invoke_codegen thy dep false) (gr1, in_ts);
   460                      (invoke_codegen thy defs dep thyname false) (gr1, in_ts);
   446                    val (gr3, ps) = if is_ind t then
   461                    val (gr3, ps) = if is_ind t then
   447                        apsnd (fn ps => ps @ [Pretty.brk 1, mk_tuple in_ps])
   462                        apsnd (fn ps => ps @ [Pretty.brk 1, mk_tuple in_ps])
   448                          (compile_expr thy dep false (gr2, (mode, t)))
   463                          (compile_expr thy defs dep thyname false thynames
       
   464                            (gr2, (mode, t)))
   449                      else
   465                      else
   450                        apsnd (fn p => conv_ntuple us t
   466                        apsnd (fn p => conv_ntuple us t
   451                          [Pretty.str "Seq.of_list", Pretty.brk 1, p])
   467                          [Pretty.str "Seq.of_list", Pretty.brk 1, p])
   452                            (invoke_codegen thy dep true (gr2, t));
   468                            (invoke_codegen thy defs dep thyname true (gr2, t));
   453                    val (gr4, rest) = compile_prems out_ts''' vs' (fst nvs) gr3 ps';
   469                    val (gr4, rest) = compile_prems out_ts''' vs' (fst nvs) gr3 ps';
   454                  in
   470                  in
   455                    (gr4, compile_match (snd nvs) eq_ps out_ps
   471                    (gr4, compile_match (snd nvs) eq_ps out_ps
   456                       (Pretty.block (ps @
   472                       (Pretty.block (ps @
   457                          [Pretty.str " :->", Pretty.brk 1, rest]))
   473                          [Pretty.str " :->", Pretty.brk 1, rest]))
   458                       (exists (not o is_exhaustive) out_ts''))
   474                       (exists (not o is_exhaustive) out_ts''))
   459                  end
   475                  end
   460              | Sidecond t =>
   476              | Sidecond t =>
   461                  let
   477                  let
   462                    val (gr2, side_p) = invoke_codegen thy dep true (gr1, t);
   478                    val (gr2, side_p) = invoke_codegen thy defs dep thyname true (gr1, t);
   463                    val (gr3, rest) = compile_prems [] vs' (fst nvs) gr2 ps';
   479                    val (gr3, rest) = compile_prems [] vs' (fst nvs) gr2 ps';
   464                  in
   480                  in
   465                    (gr3, compile_match (snd nvs) eq_ps out_ps
   481                    (gr3, compile_match (snd nvs) eq_ps out_ps
   466                       (Pretty.block [Pretty.str "?? ", side_p,
   482                       (Pretty.block [Pretty.str "?? ", side_p,
   467                         Pretty.str " :->", Pretty.brk 1, rest])
   483                         Pretty.str " :->", Pretty.brk 1, rest])
   472     val (gr', prem_p) = compile_prems in_ts' arg_vs all_vs' gr ps;
   488     val (gr', prem_p) = compile_prems in_ts' arg_vs all_vs' gr ps;
   473   in
   489   in
   474     (gr', Pretty.block [Pretty.str "Seq.single inp :->", Pretty.brk 1, prem_p])
   490     (gr', Pretty.block [Pretty.str "Seq.single inp :->", Pretty.brk 1, prem_p])
   475   end;
   491   end;
   476 
   492 
   477 fun compile_pred thy gr dep prfx all_vs arg_vs modes s cls mode =
   493 fun compile_pred thy defs gr dep thyname prfx all_vs arg_vs modes thynames s cls mode =
   478   let val (gr', cl_ps) = foldl_map (fn (gr, cl) =>
   494   let val (gr', cl_ps) = foldl_map (fn (gr, cl) => compile_clause thy defs
   479     compile_clause thy gr dep all_vs arg_vs modes mode cl) (gr, cls)
   495     gr dep thyname all_vs arg_vs modes thynames mode cl) (gr, cls)
   480   in
   496   in
   481     ((gr', "and "), Pretty.block
   497     ((gr', "and "), Pretty.block
   482       ([Pretty.block (separate (Pretty.brk 1)
   498       ([Pretty.block (separate (Pretty.brk 1)
   483          (Pretty.str (prfx ^ modename thy s mode) :: map Pretty.str arg_vs) @
   499          (Pretty.str (prfx ^ modename thy thyname thyname s mode) ::
       
   500            map Pretty.str arg_vs) @
   484          [Pretty.str " inp ="]),
   501          [Pretty.str " inp ="]),
   485         Pretty.brk 1] @
   502         Pretty.brk 1] @
   486        List.concat (separate [Pretty.str " ++", Pretty.brk 1] (map single cl_ps))))
   503        List.concat (separate [Pretty.str " ++", Pretty.brk 1] (map single cl_ps))))
   487   end;
   504   end;
   488 
   505 
   489 fun compile_preds thy gr dep all_vs arg_vs modes preds =
   506 fun compile_preds thy defs gr dep thyname all_vs arg_vs modes thynames preds =
   490   let val ((gr', _), prs) = foldl_map (fn ((gr, prfx), (s, cls)) =>
   507   let val ((gr', _), prs) = foldl_map (fn ((gr, prfx), (s, cls)) =>
   491     foldl_map (fn ((gr', prfx'), mode) =>
   508     foldl_map (fn ((gr', prfx'), mode) => compile_pred thy defs gr'
   492       compile_pred thy gr' dep prfx' all_vs arg_vs modes s cls mode)
   509       dep thyname prfx' all_vs arg_vs modes thynames s cls mode)
   493         ((gr, prfx), valOf (assoc (modes, s)))) ((gr, "fun "), preds)
   510         ((gr, prfx), valOf (assoc (modes, s)))) ((gr, "fun "), preds)
   494   in
   511   in
   495     (gr', space_implode "\n\n" (map Pretty.string_of (List.concat prs)) ^ ";\n\n")
   512     (gr', space_implode "\n\n" (map Pretty.string_of (List.concat prs)) ^ ";\n\n")
   496   end;
   513   end;
   497 
   514 
   498 (**** processing of introduction rules ****)
   515 (**** processing of introduction rules ****)
   499 
   516 
   500 exception Modes of
   517 exception Modes of
   501   (string * (int list option list * int list) list) list *
   518   (string * (int list option list * int list) list) list *
   502   (string * (int list list option list * int list list)) list;
   519   (string * (int list list option list * int list list)) list *
   503 
   520   string;
   504 fun lookup_modes gr dep = apfst List.concat (apsnd List.concat (ListPair.unzip
   521 
   505   (map ((fn (SOME (Modes x), _) => x | _ => ([], [])) o Graph.get_node gr)
   522 fun lookup_modes gr dep = foldl (fn ((xs, ys, z), (xss, yss, zss)) =>
   506     (Graph.all_preds gr [dep]))));
   523     (xss @ xs, yss @ ys, zss @ map (rpair z o fst) ys)) ([], [], [])
       
   524   (map ((fn (SOME (Modes x), _, _) => x | _ => ([], [], "")) o Graph.get_node gr)
       
   525     (Graph.all_preds gr [dep]));
   507 
   526 
   508 fun print_factors factors = message ("Factors:\n" ^
   527 fun print_factors factors = message ("Factors:\n" ^
   509   space_implode "\n" (map (fn (s, (fs, f)) => s ^ ": " ^
   528   space_implode "\n" (map (fn (s, (fs, f)) => s ^ ": " ^
   510     space_implode " -> " (map
   529     space_implode " -> " (map
   511       (fn NONE => "X" | SOME f' => string_of_factors [] f')
   530       (fn NONE => "X" | SOME f' => string_of_factors [] f')
   516 fun constrain cs [] = []
   535 fun constrain cs [] = []
   517   | constrain cs ((s, xs) :: ys) = (s, case assoc (cs, s) of
   536   | constrain cs ((s, xs) :: ys) = (s, case assoc (cs, s) of
   518       NONE => xs
   537       NONE => xs
   519     | SOME xs' => xs inter xs') :: constrain cs ys;
   538     | SOME xs' => xs inter xs') :: constrain cs ys;
   520 
   539 
   521 fun mk_extra_defs thy gr dep names ts =
   540 fun mk_extra_defs thy defs gr dep names ts =
   522   Library.foldl (fn (gr, name) =>
   541   Library.foldl (fn (gr, name) =>
   523     if name mem names then gr
   542     if name mem names then gr
   524     else (case get_clauses thy name of
   543     else (case get_clauses thy name of
   525         NONE => gr
   544         NONE => gr
   526       | SOME (names, intrs) =>
   545       | SOME (names, thyname, intrs) =>
   527           mk_ind_def thy gr dep names [] [] (prep_intrs intrs)))
   546           mk_ind_def thy defs gr dep names thyname [] [] (prep_intrs intrs)))
   528             (gr, foldr add_term_consts [] ts)
   547             (gr, foldr add_term_consts [] ts)
   529 
   548 
   530 and mk_ind_def thy gr dep names modecs factorcs intrs =
   549 and mk_ind_def thy defs gr dep names thyname modecs factorcs intrs =
   531   let val ids = map (mk_const_id (sign_of thy)) names
   550   Graph.add_edge (hd names, dep) gr handle Graph.UNDEF _ =>
   532   in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ =>
       
   533     let
   551     let
   534       val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs);
   552       val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs);
   535       val (_, args) = strip_comb u;
   553       val (_, args) = strip_comb u;
   536       val arg_vs = List.concat (map term_vs args);
   554       val arg_vs = List.concat (map term_vs args);
   537 
   555 
   563             then infer_factors (sign_of thy) extra_fs
   581             then infer_factors (sign_of thy) extra_fs
   564               (fs, (SOME (FVar (prod_factors [] t)), u))
   582               (fs, (SOME (FVar (prod_factors [] t)), u))
   565             else fs
   583             else fs
   566         | add_prod_factors _ (fs, _) = fs;
   584         | add_prod_factors _ (fs, _) = fs;
   567 
   585 
   568       val gr' = mk_extra_defs thy
   586       val gr' = mk_extra_defs thy defs
   569         (Graph.add_edge (hd ids, dep)
   587         (Graph.add_edge (hd names, dep)
   570           (Graph.new_node (hd ids, (NONE, "")) gr)) (hd ids) names intrs;
   588           (Graph.new_node (hd names, (NONE, "", "")) gr)) (hd names) names intrs;
   571       val (extra_modes, extra_factors) = lookup_modes gr' (hd ids);
   589       val (extra_modes, extra_factors, extra_thynames) = lookup_modes gr' (hd names);
   572       val fs = constrain factorcs (map (apsnd dest_factors)
   590       val fs = constrain factorcs (map (apsnd dest_factors)
   573         (Library.foldl (add_prod_factors extra_factors) ([], List.concat (map (fn t =>
   591         (Library.foldl (add_prod_factors extra_factors) ([], List.concat (map (fn t =>
   574           Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs))));
   592           Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs))));
   575       val factors = List.mapPartial (fn (name, f) =>
   593       val factors = List.mapPartial (fn (name, f) =>
   576         if name mem arg_vs then NONE
   594         if name mem arg_vs then NONE
   579         Library.foldl (add_clause (fs @ map (apsnd snd) extra_factors)) ([], intrs);
   597         Library.foldl (add_clause (fs @ map (apsnd snd) extra_factors)) ([], intrs);
   580       val modes = constrain modecs
   598       val modes = constrain modecs
   581         (infer_modes thy extra_modes factors arg_vs clauses);
   599         (infer_modes thy extra_modes factors arg_vs clauses);
   582       val _ = print_factors factors;
   600       val _ = print_factors factors;
   583       val _ = print_modes modes;
   601       val _ = print_modes modes;
   584       val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs) arg_vs
   602       val (gr'', s) = compile_preds thy defs gr' (hd names) thyname (terms_vs intrs)
   585         (modes @ extra_modes) clauses;
   603         arg_vs (modes @ extra_modes)
       
   604         (map (rpair thyname o fst) factors @ extra_thynames) clauses;
   586     in
   605     in
   587       (Graph.map_node (hd ids) (K (SOME (Modes (modes, factors)), s)) gr'')
   606       (Graph.map_node (hd names)
   588     end      
   607         (K (SOME (Modes (modes, factors, thyname)), thyname, s)) gr'')
   589   end;
   608     end;
   590 
   609 
   591 fun find_mode s u modes is = (case find_first (fn Mode ((_, js), _) => is=js)
   610 fun find_mode s u modes is = (case find_first (fn Mode ((_, js), _) => is=js)
   592   (modes_of modes u handle Option => []) of
   611   (modes_of modes u handle Option => []) of
   593      NONE => error ("No such mode for " ^ s ^ ": " ^ string_of_mode ([], is))
   612      NONE => error ("No such mode for " ^ s ^ ": " ^ string_of_mode ([], is))
   594    | mode => mode);
   613    | mode => mode);
   595 
   614 
   596 fun mk_ind_call thy gr dep t u is_query = (case head_of u of
   615 fun mk_ind_call thy defs gr dep thyname t u is_query = (case head_of u of
   597   Const (s, T) => (case (get_clauses thy s, get_assoc_code thy s T) of
   616   Const (s, T) => (case (get_clauses thy s, get_assoc_code thy s T) of
   598        (NONE, _) => NONE
   617        (NONE, _) => NONE
   599      | (SOME (names, intrs), NONE) =>
   618      | (SOME (names, thyname', intrs), NONE) =>
   600          let
   619          let
   601           fun mk_mode (((ts, mode), i), Const ("dummy_pattern", _)) =
   620           fun mk_mode (((ts, mode), i), Const ("dummy_pattern", _)) =
   602                 ((ts, mode), i+1)
   621                 ((ts, mode), i+1)
   603             | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1);
   622             | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1);
   604 
   623 
   605            val gr1 = mk_extra_defs thy
   624            val gr1 = mk_extra_defs thy defs
   606              (mk_ind_def thy gr dep names [] [] (prep_intrs intrs)) dep names [u];
   625              (mk_ind_def thy defs gr dep names thyname' [] [] (prep_intrs intrs)) dep names [u];
   607            val (modes, factors) = lookup_modes gr1 dep;
   626            val (modes, factors, thynames) = lookup_modes gr1 dep;
   608            val ts = split_prod [] (snd (valOf (assoc (factors, s)))) t;
   627            val ts = split_prod [] (snd (valOf (assoc (factors, s)))) t;
   609            val (ts', is) = if is_query then
   628            val (ts', is) = if is_query then
   610                fst (Library.foldl mk_mode ((([], []), 1), ts))
   629                fst (Library.foldl mk_mode ((([], []), 1), ts))
   611              else (ts, 1 upto length ts);
   630              else (ts, 1 upto length ts);
   612            val mode = find_mode s u modes is;
   631            val mode = find_mode s u modes is;
   613            val (gr2, in_ps) = foldl_map
   632            val (gr2, in_ps) = foldl_map
   614              (invoke_codegen thy dep false) (gr1, ts');
   633              (invoke_codegen thy defs dep thyname false) (gr1, ts');
   615            val (gr3, ps) = compile_expr thy dep false (gr2, (mode, u))
   634            val (gr3, ps) =
       
   635              compile_expr thy defs dep thyname false thynames (gr2, (mode, u))
   616          in
   636          in
   617            SOME (gr3, Pretty.block
   637            SOME (gr3, Pretty.block
   618              (ps @ [Pretty.brk 1, mk_tuple in_ps]))
   638              (ps @ [Pretty.brk 1, mk_tuple in_ps]))
   619          end
   639          end
   620      | _ => NONE)
   640      | _ => NONE)
   621   | _ => NONE);
   641   | _ => NONE);
   622 
   642 
   623 fun list_of_indset thy gr dep brack u = (case head_of u of
   643 fun list_of_indset thy defs gr dep thyname brack u = (case head_of u of
   624   Const (s, T) => (case (get_clauses thy s, get_assoc_code thy s T) of
   644   Const (s, T) => (case (get_clauses thy s, get_assoc_code thy s T) of
   625        (NONE, _) => NONE
   645        (NONE, _) => NONE
   626      | (SOME (names, intrs), NONE) =>
   646      | (SOME (names, thyname', intrs), NONE) =>
   627          let
   647          let
   628            val gr1 = mk_extra_defs thy
   648            val gr1 = mk_extra_defs thy defs
   629              (mk_ind_def thy gr dep names [] [] (prep_intrs intrs)) dep names [u];
   649              (mk_ind_def thy defs gr dep names thyname' [] [] (prep_intrs intrs)) dep names [u];
   630            val (modes, factors) = lookup_modes gr1 dep;
   650            val (modes, factors, thynames) = lookup_modes gr1 dep;
   631            val mode = find_mode s u modes [];
   651            val mode = find_mode s u modes [];
   632            val (gr2, ps) = compile_expr thy dep false (gr1, (mode, u))
   652            val (gr2, ps) =
       
   653              compile_expr thy defs dep thyname false thynames (gr1, (mode, u))
   633          in
   654          in
   634            SOME (gr2, (if brack then parens else I)
   655            SOME (gr2, (if brack then parens else I)
   635              (Pretty.block ([Pretty.str "Seq.list_of", Pretty.brk 1,
   656              (Pretty.block ([Pretty.str "Seq.list_of", Pretty.brk 1,
   636                Pretty.str "("] @
   657                Pretty.str "("] @
   637                conv_ntuple' (snd (valOf (assoc (factors, s))))
   658                conv_ntuple' (snd (valOf (assoc (factors, s))))
   648     val (Const (s, T), ts) = strip_comb t;
   669     val (Const (s, T), ts) = strip_comb t;
   649     val (Ts, U) = strip_type T
   670     val (Ts, U) = strip_type T
   650   in
   671   in
   651     rename_term
   672     rename_term
   652       (Logic.list_implies (prems_of eqn, HOLogic.mk_Trueprop (HOLogic.mk_mem
   673       (Logic.list_implies (prems_of eqn, HOLogic.mk_Trueprop (HOLogic.mk_mem
   653         (foldr1 HOLogic.mk_prod (ts @ [u]), Const (Sign.base_name s ^ "_aux",
   674         (foldr1 HOLogic.mk_prod (ts @ [u]), Const (s ^ " ",
   654           HOLogic.mk_setT (foldr1 HOLogic.mk_prodT (Ts @ [U])))))))
   675           HOLogic.mk_setT (foldr1 HOLogic.mk_prodT (Ts @ [U])))))))
   655   end;
   676   end;
   656 
   677 
   657 fun mk_fun thy name eqns dep gr = 
   678 fun mk_fun thy defs name eqns dep thyname thyname' gr =
   658   let val id = mk_const_id (sign_of thy) name
   679   let
   659   in Graph.add_edge (id, dep) gr handle Graph.UNDEF _ =>
   680     val fun_id = mk_const_id (sign_of thy) thyname' thyname' name;
       
   681     val call_id = mk_const_id (sign_of thy) thyname thyname' name
       
   682   in (Graph.add_edge (name, dep) gr handle Graph.UNDEF _ =>
   660     let
   683     let
   661       val clauses = map clause_of_eqn eqns;
   684       val clauses = map clause_of_eqn eqns;
   662       val pname = mk_const_id (sign_of thy) (Sign.base_name name ^ "_aux");
   685       val pname = name ^ " ";
   663       val arity = length (snd (strip_comb (fst (HOLogic.dest_eq
   686       val arity = length (snd (strip_comb (fst (HOLogic.dest_eq
   664         (HOLogic.dest_Trueprop (concl_of (hd eqns)))))));
   687         (HOLogic.dest_Trueprop (concl_of (hd eqns)))))));
   665       val mode = 1 upto arity;
   688       val mode = 1 upto arity;
   666       val vars = map (fn i => Pretty.str ("x" ^ string_of_int i)) mode;
   689       val vars = map (fn i => Pretty.str ("x" ^ string_of_int i)) mode;
   667       val s = Pretty.string_of (Pretty.block
   690       val s = Pretty.string_of (Pretty.block
   668         [mk_app false (Pretty.str ("fun " ^ id)) vars, Pretty.str " =",
   691         [mk_app false (Pretty.str ("fun " ^ fun_id)) vars, Pretty.str " =",
   669          Pretty.brk 1, Pretty.str "Seq.hd", Pretty.brk 1,
   692          Pretty.brk 1, Pretty.str "Seq.hd", Pretty.brk 1,
   670          parens (Pretty.block [Pretty.str (modename thy pname ([], mode)),
   693          parens (Pretty.block [Pretty.str (modename thy thyname' thyname' pname ([], mode)),
   671            Pretty.brk 1, mk_tuple vars])]) ^ ";\n\n";
   694            Pretty.brk 1, mk_tuple vars])]) ^ ";\n\n";
   672       val gr' = mk_ind_def thy (Graph.add_edge (id, dep)
   695       val gr' = mk_ind_def thy defs (Graph.add_edge (name, dep)
   673         (Graph.new_node (id, (NONE, s)) gr)) id [pname]
   696         (Graph.new_node (name, (NONE, thyname', s)) gr)) name [pname] thyname'
   674         [(pname, [([], mode)])]
   697         [(pname, [([], mode)])]
   675         [(pname, map (fn i => replicate i 2) (0 upto arity-1))]
   698         [(pname, map (fn i => replicate i 2) (0 upto arity-1))]
   676         clauses;
   699         clauses;
   677       val (modes, _) = lookup_modes gr' dep;
   700       val (modes, _, _) = lookup_modes gr' dep;
   678       val _ = find_mode pname (snd (HOLogic.dest_mem (HOLogic.dest_Trueprop
   701       val _ = find_mode pname (snd (HOLogic.dest_mem (HOLogic.dest_Trueprop
   679         (Logic.strip_imp_concl (hd clauses))))) modes mode
   702         (Logic.strip_imp_concl (hd clauses))))) modes mode
   680     in gr' end
   703     in gr' end, call_id)
   681   end;
   704   end;
   682 
   705 
   683 fun inductive_codegen thy gr dep brack (Const ("op :", _) $ t $ u) =
   706 fun inductive_codegen thy defs gr dep thyname brack (Const ("op :", _) $ t $ u) =
   684       ((case mk_ind_call thy gr dep (Term.no_dummy_patterns t) u false of
   707       ((case mk_ind_call thy defs gr dep thyname (Term.no_dummy_patterns t) u false of
   685          NONE => NONE
   708          NONE => NONE
   686        | SOME (gr', call_p) => SOME (gr', (if brack then parens else I)
   709        | SOME (gr', call_p) => SOME (gr', (if brack then parens else I)
   687            (Pretty.block [Pretty.str "?! (", call_p, Pretty.str ")"])))
   710            (Pretty.block [Pretty.str "?! (", call_p, Pretty.str ")"])))
   688         handle TERM _ => mk_ind_call thy gr dep t u true)
   711         handle TERM _ => mk_ind_call thy defs gr dep thyname t u true)
   689   | inductive_codegen thy gr dep brack t = (case strip_comb t of
   712   | inductive_codegen thy defs gr dep thyname brack t = (case strip_comb t of
   690       (Const (s, _), ts) => (case Symtab.lookup (#eqns (CodegenData.get thy), s) of
   713       (Const (s, _), ts) => (case Symtab.lookup (#eqns (CodegenData.get thy), s) of
   691         NONE => list_of_indset thy gr dep brack t
   714         NONE => list_of_indset thy defs gr dep thyname brack t
   692       | SOME eqns =>
   715       | SOME eqns =>
   693           let
   716           let
   694             val gr' = mk_fun thy s (preprocess thy eqns) dep gr
   717             val (_, (_, thyname')) = split_last eqns;
   695             val (gr'', ps) = foldl_map (invoke_codegen thy dep true) (gr', ts);
   718             val (gr', id) = mk_fun thy defs s (preprocess thy (map fst eqns))
   696           in SOME (gr'', mk_app brack (Pretty.str (mk_const_id
   719               dep thyname thyname' gr;
   697             (sign_of thy) s)) ps)
   720             val (gr'', ps) = foldl_map
       
   721               (invoke_codegen thy defs dep thyname true) (gr', ts);
       
   722           in SOME (gr'', mk_app brack (Pretty.str id) ps)
   698           end)
   723           end)
   699     | _ => NONE);
   724     | _ => NONE);
   700 
   725 
   701 val setup =
   726 val setup =
   702   [add_codegen "inductive" inductive_codegen,
   727   [add_codegen "inductive" inductive_codegen,
   703    CodegenData.init,
   728    CodegenData.init,
   704    add_attribute "ind" (Scan.succeed add)];
   729    add_attribute "ind"
       
   730      (Scan.option (Args.$$$ "target" |-- Args.colon |-- Args.name) >> add)];
   705 
   731 
   706 end;
   732 end;
   707 
   733 
   708 
   734 
   709 (**** combinators for code generated from inductive predicates ****)
   735 (**** combinators for code generated from inductive predicates ****)