src/HOL/Tools/primrec_package.ML
changeset 25559 f14305fb698c
parent 25557 ea6b11021e79
child 25562 f0fc8531c909
equal deleted inserted replaced
25558:5c317e8f5673 25559:f14305fb698c
    25 fun message s = if ! Toplevel.debug then () else writeln s;
    25 fun message s = if ! Toplevel.debug then () else writeln s;
    26 
    26 
    27 
    27 
    28 (* preprocessing of equations *)
    28 (* preprocessing of equations *)
    29 
    29 
    30 fun process_eqn is_fixed is_const spec rec_fns =
    30 fun process_eqn is_fixed spec rec_fns =
    31   let
    31   let
    32     val vars = strip_qnt_vars "all" spec;
    32     val vars = strip_qnt_vars "all" spec;
    33     val body = strip_qnt_body "all" spec;
    33     val body = strip_qnt_body "all" spec;
       
    34     (*FIXME not necessarily correct*)
    34     val eqn = curry subst_bounds (map Free (rev vars)) body;
    35     val eqn = curry subst_bounds (map Free (rev vars)) body;
    35     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    36     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    36       handle TERM _ => primrec_error "not a proper equation";
    37       handle TERM _ => primrec_error "not a proper equation";
    37     val (recfun, args) = strip_comb lhs;
    38     val (recfun, args) = strip_comb lhs;
    38     val fname = case recfun of Free (v, _) => if is_fixed v then v
    39     val fname = case recfun of Free (v, _) => if is_fixed v then v
    62       primrec_error "more than one non-variable in pattern"
    63       primrec_error "more than one non-variable in pattern"
    63     else
    64     else
    64      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    65      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    65       check_vars "extra variables on rhs: "
    66       check_vars "extra variables on rhs: "
    66         (map dest_Free (term_frees rhs) |> subtract (op =) lfrees
    67         (map dest_Free (term_frees rhs) |> subtract (op =) lfrees
    67           |> filter_out (is_const o fst) |> filter_out (is_fixed o fst));
    68           |> filter_out (is_fixed o fst));
    68       case AList.lookup (op =) rec_fns fname of
    69       case AList.lookup (op =) rec_fns fname of
    69         NONE =>
    70         NONE =>
    70           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns
    71           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns
    71       | SOME (_, rpos', eqns) =>
    72       | SOME (_, rpos', eqns) =>
    72           if AList.defined (op =) eqns cname then
    73           if AList.defined (op =) eqns cname then
   220 
   221 
   221 local
   222 local
   222 
   223 
   223 fun prepare_spec prep_spec ctxt raw_fixes raw_spec =
   224 fun prepare_spec prep_spec ctxt raw_fixes raw_spec =
   224   let
   225   let
   225     val ((fixes, spec), _) = prep_spec raw_fixes [(map o apsnd) single raw_spec] ctxt
   226     val ((fixes, spec), _) = prep_spec
   226   in (fixes, (map o apsnd) the_single spec) end;
   227       raw_fixes (map (single o apsnd single) raw_spec) ctxt
       
   228   in (fixes, map (apsnd the_single) spec) end;
   227 
   229 
   228 fun prove_spec ctxt rec_rewrites defs =
   230 fun prove_spec ctxt rec_rewrites defs =
   229   let
   231   let
   230     val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
   232     val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
   231     fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   233     fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   233   in map (fn (name_attr, t) => (name_attr, [Goal.prove ctxt [] [] t tac])) end;
   235   in map (fn (name_attr, t) => (name_attr, [Goal.prove ctxt [] [] t tac])) end;
   234 
   236 
   235 fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   237 fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   236   let
   238   let
   237     val (fixes, spec) = prepare_spec prep_spec lthy raw_fixes raw_spec;
   239     val (fixes, spec) = prepare_spec prep_spec lthy raw_fixes raw_spec;
   238     val eqns = fold_rev (process_eqn (member (op =) (map (fst o fst) fixes))
   240     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
   239       (Variable.is_const lthy) o snd) spec [];
   241       orelse exists (fn ((w, _), _) => v = w) fixes) o snd) spec [];
   240     val tnames = distinct (op =) (map (#1 o snd) eqns);
   242     val tnames = distinct (op =) (map (#1 o snd) eqns);
   241     val dts = find_dts (DatatypePackage.get_datatypes
   243     val dts = find_dts (DatatypePackage.get_datatypes
   242       (ProofContext.theory_of lthy)) tnames tnames;
   244       (ProofContext.theory_of lthy)) tnames tnames;
   243     val main_fns = map (fn (tname, {index, ...}) =>
   245     val main_fns = map (fn (tname, {index, ...}) =>
   244       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   246       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   253     val _ = if gen_eq_set (op =) (nameTs1, nameTs2) then ()
   255     val _ = if gen_eq_set (op =) (nameTs1, nameTs2) then ()
   254       else primrec_error ("functions " ^ commas_quote nameTs2 ^
   256       else primrec_error ("functions " ^ commas_quote nameTs2 ^
   255         "\nare not mutually recursive");
   257         "\nare not mutually recursive");
   256     val qualify = NameSpace.qualified
   258     val qualify = NameSpace.qualified
   257       (space_implode "_" (map (Sign.base_name o #1) defs));
   259       (space_implode "_" (map (Sign.base_name o #1) defs));
   258     val simp_atts = [Attrib.internal (K Simplifier.simp_add),
   260     val simp_atts = map (Attrib.internal o K) [Simplifier.simp_add, RecfunCodegen.add NONE];
   259       Code.add_default_func_attr (*FIXME*)];
       
   260   in
   261   in
   261     lthy
   262     lthy
   262     |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
   263     |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
   263     |-> (fn defs => `(fn ctxt => prove_spec ctxt rec_rewrites defs spec))
   264     |-> (fn defs => `(fn ctxt => prove_spec ctxt rec_rewrites defs spec))
   264     |-> (fn simps => fold_map (LocalTheory.note Thm.theoremK) simps)
   265     |-> (fn simps => fold_map (LocalTheory.note Thm.theoremK) simps)