src/HOL/Tools/primrec_package.ML
changeset 31262 580510315dda
parent 31177 c39994cb152a
child 31269 dcbe1f9fe2cd
equal deleted inserted replaced
31261:900ebbc35e30 31262:580510315dda
    14   val add_primrec_global: (binding * typ option * mixfix) list ->
    14   val add_primrec_global: (binding * typ option * mixfix) list ->
    15     (Attrib.binding * term) list -> theory -> thm list * theory
    15     (Attrib.binding * term) list -> theory -> thm list * theory
    16   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    16   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    17     (binding * typ option * mixfix) list ->
    17     (binding * typ option * mixfix) list ->
    18     (Attrib.binding * term) list -> theory -> thm list * theory
    18     (Attrib.binding * term) list -> theory -> thm list * theory
       
    19   val add_primrec_simple: ((binding * typ) * mixfix) list -> (binding * term) list ->
       
    20     local_theory -> (string * thm list list) * local_theory
    19 end;
    21 end;
    20 
    22 
    21 structure PrimrecPackage : PRIMREC_PACKAGE =
    23 structure PrimrecPackage : PRIMREC_PACKAGE =
    22 struct
    24 struct
    23 
    25 
   209             if tnames' subset (map (#1 o snd) (#descr dt)) then
   211             if tnames' subset (map (#1 o snd) (#descr dt)) then
   210               (tname, dt)::(find_dts dt_info tnames' tnames)
   212               (tname, dt)::(find_dts dt_info tnames' tnames)
   211             else find_dts dt_info tnames' tnames);
   213             else find_dts dt_info tnames' tnames);
   212 
   214 
   213 
   215 
   214 (* primrec definition *)
   216 (* distill primitive definition(s) from primrec specification *)
   215 
   217 
   216 local
   218 fun distill lthy fixes eqs = 
   217 
   219   let
   218 fun prove_spec ctxt names rec_rewrites defs eqs =
       
   219   let
       
   220     val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
       
   221     fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
       
   222     val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
       
   223   in map (fn (a, t) => (a, [Goal.prove ctxt [] [] t tac])) eqs end;
       
   224 
       
   225 fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
       
   226   let
       
   227     val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
       
   228     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
   220     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
   229       orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes) o snd) spec [];
   221       orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
   230     val tnames = distinct (op =) (map (#1 o snd) eqns);
   222     val tnames = distinct (op =) (map (#1 o snd) eqns);
   231     val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames;
   223     val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames;
   232     val main_fns = map (fn (tname, {index, ...}) =>
   224     val main_fns = map (fn (tname, {index, ...}) =>
   233       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   225       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   234     val {descr, rec_names, rec_rewrites, ...} =
   226     val {descr, rec_names, rec_rewrites, ...} =
   235       if null dts then primrec_error
   227       if null dts then primrec_error
   236         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   228         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   237       else snd (hd dts);
   229       else snd (hd dts);
   238     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   230     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   239     val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   231     val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   240     val names1 = map snd fnames;
   232     val defs = map (make_def lthy fixes fs) raw_defs;
   241     val names2 = map fst eqns;
   233     val names = map snd fnames;
   242     val _ = if gen_eq_set (op =) (names1, names2) then ()
   234     val names_eqns = map fst eqns;
   243       else primrec_error ("functions " ^ commas_quote names2 ^
   235     val _ = if gen_eq_set (op =) (names, names_eqns) then ()
       
   236       else primrec_error ("functions " ^ commas_quote names_eqns ^
   244         "\nare not mutually recursive");
   237         "\nare not mutually recursive");
   245     val prefix = space_implode "_" (map (Long_Name.base_name o #1) defs);
   238     val rec_rewrites' = map mk_meta_eq rec_rewrites;
   246     val qualify = Binding.qualify false prefix;
   239     val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
   247     val spec' = (map o apfst)
   240     fun prove lthy defs =
   248       (fn (b, attrs) => (qualify b, Code.add_default_eqn_attrib :: attrs)) spec;
   241       let
   249     val simp_atts = map (Attrib.internal o K)
   242         val rewrites = rec_rewrites' @ map (snd o snd) defs;
   250       [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add];
   243         fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
       
   244         val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
       
   245       in map (fn eq => [Goal.prove lthy [] [] eq tac]) eqs end;
       
   246   in ((prefix, (fs, defs)), prove) end
       
   247   handle PrimrecError (msg, some_eqn) =>
       
   248     error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
       
   249      of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
       
   250       | NONE => ""));
       
   251 
       
   252 
       
   253 (* primrec definition *)
       
   254 
       
   255 fun add_primrec_simple fixes spec lthy =
       
   256   let
       
   257     val ((prefix, (fs, defs)), prove) = distill lthy fixes (map snd spec);
       
   258   in
       
   259     lthy
       
   260     |> fold_map (LocalTheory.define Thm.definitionK) defs
       
   261     |-> (fn defs => `(fn lthy => (prefix, prove lthy defs)))
       
   262   end;
       
   263 
       
   264 local
       
   265 
       
   266 fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
       
   267   let
       
   268     val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
       
   269     fun attr_bindings prefix = map (fn ((b, attrs), _) =>
       
   270       (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec;
       
   271     fun simp_attr_binding prefix = (Binding.qualify false prefix (Binding.name "simps"),
       
   272       map (Attrib.internal o K)
       
   273         [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add]);
   251   in
   274   in
   252     lthy
   275     lthy
   253     |> set_group ? LocalTheory.set_group (serial_string ())
   276     |> set_group ? LocalTheory.set_group (serial_string ())
   254     |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
   277     |> add_primrec_simple fixes spec
   255     |-> (fn defs => `(fn ctxt => prove_spec ctxt names1 rec_rewrites defs spec'))
   278     |-> (fn (prefix, simps) => fold_map (LocalTheory.note Thm.generatedK)
   256     |-> (fn simps => fold_map (LocalTheory.note Thm.generatedK) simps)
   279           (attr_bindings prefix ~~ simps)
   257     |-> (fn simps' => LocalTheory.note Thm.generatedK
   280     #-> (fn simps' => LocalTheory.note Thm.generatedK
   258           ((qualify (Binding.qualified_name "simps"), simp_atts), maps snd simps'))
   281           (simp_attr_binding prefix, maps snd simps')))
   259     |>> snd
   282     |>> snd
   260   end handle PrimrecError (msg, some_eqn) =>
   283   end;
   261     error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
       
   262      of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
       
   263       | NONE => ""));
       
   264 
   284 
   265 in
   285 in
   266 
   286 
   267 val add_primrec = gen_primrec false Specification.check_spec;
   287 val add_primrec = gen_primrec false Specification.check_spec;
   268 val add_primrec_cmd = gen_primrec true Specification.read_spec;
   288 val add_primrec_cmd = gen_primrec true Specification.read_spec;