src/HOL/Tools/primrec_package.ML
changeset 25566 33f740c5e022
parent 25562 f0fc8531c909
child 25570 fdfbbb92dadf
equal deleted inserted replaced
25565:33d30a53fae7 25566:33f740c5e022
    27 
    27 
    28 (* preprocessing of equations *)
    28 (* preprocessing of equations *)
    29 
    29 
    30 fun process_eqn is_fixed spec rec_fns =
    30 fun process_eqn is_fixed spec rec_fns =
    31   let
    31   let
    32     val vars = strip_qnt_vars "all" spec;
    32     val (vs, Ts) = split_list (strip_qnt_vars "all" spec);
    33     val body = strip_qnt_body "all" spec;
    33     val body = strip_qnt_body "all" spec;
    34     (*FIXME not necessarily correct*)
    34     val (vs', _) = Name.variants vs (Name.make_context (fold_aterms
    35     val eqn = curry subst_bounds (map Free (rev vars)) body;
    35       (fn Free (v, _) => insert (op =) v | _ => I) body []));
       
    36     val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    36     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    37     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    37       handle TERM _ => primrec_error "not a proper equation";
    38       handle TERM _ => primrec_error "not a proper equation";
    38     val (recfun, args) = strip_comb lhs;
    39     val (recfun, args) = strip_comb lhs;
    39     val fname = case recfun of Free (v, _) => if is_fixed v then v
    40     val fname = case recfun of Free (v, _) => if is_fixed v then v
    40           else primrec_error "illegal head of function equation"
    41           else primrec_error "illegal head of function equation"
   225   let
   226   let
   226     val ((fixes, spec), _) = prep_spec
   227     val ((fixes, spec), _) = prep_spec
   227       raw_fixes (map (single o apsnd single) raw_spec) ctxt
   228       raw_fixes (map (single o apsnd single) raw_spec) ctxt
   228   in (fixes, map (apsnd the_single) spec) end;
   229   in (fixes, map (apsnd the_single) spec) end;
   229 
   230 
   230 fun prove_spec ctxt rec_rewrites defs =
   231 fun prove_spec ctxt names rec_rewrites defs =
   231   let
   232   let
   232     val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
   233     val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
   233     fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   234     fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   234     val _ = message "Proving equations for primrec function";
   235     val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
   235   in map (fn (name_attr, t) => (name_attr, [Goal.prove ctxt [] [] t tac])) end;
   236   in map (fn (name_attr, t) => (name_attr, [Goal.prove ctxt [] [] t tac])) end;
   236 
   237 
   237 fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   238 fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   238   let
   239   let
   239     val (fixes, spec) = prepare_spec prep_spec lthy raw_fixes raw_spec;
   240     val (fixes, spec) = prepare_spec prep_spec lthy raw_fixes raw_spec;
   248       if null dts then primrec_error
   249       if null dts then primrec_error
   249         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   250         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   250       else snd (hd dts);
   251       else snd (hd dts);
   251     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   252     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   252     val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   253     val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   253     val nameTs1 = map snd fnames;
   254     val names1 = map snd fnames;
   254     val nameTs2 = map fst eqns;
   255     val names2 = map fst eqns;
   255     val _ = if gen_eq_set (op =) (nameTs1, nameTs2) then ()
   256     val _ = if gen_eq_set (op =) (names1, names2) then ()
   256       else primrec_error ("functions " ^ commas_quote nameTs2 ^
   257       else primrec_error ("functions " ^ commas_quote names2 ^
   257         "\nare not mutually recursive");
   258         "\nare not mutually recursive");
   258     val qualify = NameSpace.qualified
   259     val qualify = NameSpace.qualified
   259       (space_implode "_" (map (Sign.base_name o #1) defs));
   260       (space_implode "_" (map (Sign.base_name o #1) defs));
   260     val simp_atts = map (Attrib.internal o K) [Simplifier.simp_add]
   261     val simp_atts = map (Attrib.internal o K) [Simplifier.simp_add]
   261       @ [Code.add_default_func_attr (*FIXME*)] (*RecfunCodegen.add NONE*);
   262       @ [Code.add_default_func_attr (*FIXME*)] (*RecfunCodegen.add NONE*);
   262   in
   263   in
   263     lthy
   264     lthy
   264     |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
   265     |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
   265     |-> (fn defs => `(fn ctxt => prove_spec ctxt rec_rewrites defs spec))
   266     |-> (fn defs => `(fn ctxt => prove_spec ctxt names1 rec_rewrites defs spec))
   266     |-> (fn simps => fold_map (LocalTheory.note Thm.theoremK) simps)
   267     |-> (fn simps => fold_map (LocalTheory.note Thm.theoremK) simps)
   267     |-> (fn simps' => LocalTheory.note Thm.theoremK
   268     |-> (fn simps' => LocalTheory.note Thm.theoremK
   268           ((qualify "simps", simp_atts), maps snd simps'))
   269           ((qualify "simps", simp_atts), maps snd simps'))
   269     ||>> LocalTheory.note Thm.theoremK
   270     ||>> LocalTheory.note Thm.theoremK
   270           ((qualify "induct", []), [prepare_induct (#2 (hd dts)) eqns])
   271           ((qualify "induct", []), [prepare_induct (#2 (hd dts)) eqns])