name derivations in 'primrec' for code extraction from proof terms
authorblanchet
Mon Feb 17 13:31:42 2014 +0100 (2014-02-17)
changeset 5553510194808430d
parent 55534 b18bdcbda41b
child 55536 56ebc4d4d008
name derivations in 'primrec' for code extraction from proof terms
src/HOL/Proofs/Extraction/Pigeonhole.thy
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML
src/HOL/Tools/Ctr_Sugar/ctr_sugar_util.ML
     1.1 --- a/src/HOL/Proofs/Extraction/Pigeonhole.thy	Mon Feb 17 13:31:42 2014 +0100
     1.2 +++ b/src/HOL/Proofs/Extraction/Pigeonhole.thy	Mon Feb 17 13:31:42 2014 +0100
     1.3 @@ -253,4 +253,3 @@
     1.4  ML_val "timeit @{code test''}"
     1.5  
     1.6  end
     1.7 -
     2.1 --- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Mon Feb 17 13:31:42 2014 +0100
     2.2 +++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Mon Feb 17 13:31:42 2014 +0100
     2.3 @@ -1346,9 +1346,9 @@
     2.4              in
     2.5                (((map_thms, rel_inject_thms, rel_distinct_thms, set_thmss), ctr_sugar),
     2.6                 lthy
     2.7 -               |> Spec_Rules.add Spec_Rules.Equational (`(lhs_heads_of o hd) map_thms)
     2.8 -               |> Spec_Rules.add Spec_Rules.Equational (`(lhs_heads_of o hd) rel_eq_thms)
     2.9 -               |> Spec_Rules.add Spec_Rules.Equational (`(lhs_heads_of o hd) set_thms)
    2.10 +               |> Spec_Rules.add Spec_Rules.Equational (`(single o lhs_head_of o hd) map_thms)
    2.11 +               |> Spec_Rules.add Spec_Rules.Equational (`(single o lhs_head_of o hd) rel_eq_thms)
    2.12 +               |> Spec_Rules.add Spec_Rules.Equational (`(single o lhs_head_of o hd) set_thms)
    2.13                 |> Local_Theory.notes (anonymous_notes @ notes)
    2.14                 |> snd)
    2.15              end;
     3.1 --- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Mon Feb 17 13:31:42 2014 +0100
     3.2 +++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Mon Feb 17 13:31:42 2014 +0100
     3.3 @@ -466,13 +466,13 @@
     3.4      map_idents) THEN
     3.5    HEADGOAL (rtac refl);
     3.6  
     3.7 -fun prepare_primrec fixes specs lthy =
     3.8 +fun prepare_primrec fixes specs lthy0 =
     3.9    let
    3.10 -    val thy = Proof_Context.theory_of lthy;
    3.11 +    val thy = Proof_Context.theory_of lthy0;
    3.12  
    3.13      val (bs, mxs) = map_split (apfst fst) fixes;
    3.14      val fun_names = map Binding.name_of bs;
    3.15 -    val eqns_data = map (dissect_eqn lthy fun_names) specs;
    3.16 +    val eqns_data = map (dissect_eqn lthy0 fun_names) specs;
    3.17      val funs_data = eqns_data
    3.18        |> partition_eq ((op =) o pairself #fun_name)
    3.19        |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
    3.20 @@ -488,7 +488,7 @@
    3.21        |> map (maps (map_filter (find_rec_calls has_call)));
    3.22  
    3.23      fun is_only_old_datatype (Type (s, _)) =
    3.24 -        is_none (fp_sugar_of lthy s) andalso is_some (Datatype_Data.get_info thy s)
    3.25 +        is_none (fp_sugar_of lthy0 s) andalso is_some (Datatype_Data.get_info thy s)
    3.26        | is_only_old_datatype _ = false;
    3.27  
    3.28      val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else ();
    3.29 @@ -496,35 +496,41 @@
    3.30          [] => ()
    3.31        | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort"));
    3.32  
    3.33 -    val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') =
    3.34 -      rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
    3.35 +    val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy) =
    3.36 +      rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy0;
    3.37  
    3.38      val actual_nn = length funs_data;
    3.39  
    3.40      val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
    3.41        map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
    3.42 -        primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
    3.43 +        primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
    3.44            " is not a constructor in left-hand side") user_eqn) eqns_data end;
    3.45  
    3.46 -    val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
    3.47 +    val defs = build_defs lthy bs mxs funs_data rec_specs has_call;
    3.48  
    3.49 -    fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
    3.50 +    fun prove lthy' def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
    3.51          (fun_data : eqn_data list) =
    3.52        let
    3.53 +        val js =
    3.54 +          find_indices (op = o pairself (fn {fun_name, ctr, ...} => (fun_name, ctr)))
    3.55 +            fun_data eqns_data;
    3.56 +
    3.57          val def_thms = map (snd o snd) def_thms';
    3.58 -        val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
    3.59 +        val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
    3.60            |> fst
    3.61            |> map_filter (try (fn (x, [y]) =>
    3.62 -            (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
    3.63 -          |> map (fn (user_eqn, num_extra_args, rec_thm) =>
    3.64 -            mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
    3.65 -            |> K |> Goal.prove_sorry lthy [] [] user_eqn
    3.66 -            |> Thm.close_derivation);
    3.67 -        val poss =
    3.68 -          find_indices (op = o pairself (fn {fun_name, ctr, ...} => (fun_name, ctr)))
    3.69 -            fun_data eqns_data;
    3.70 +            (#fun_name x, #user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
    3.71 +          |> map2 (fn j => fn (fun_name, user_eqn, num_extra_args, rec_thm) =>
    3.72 +              mk_primrec_tac lthy' num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
    3.73 +              |> K |> Goal.prove_sorry lthy' [] [] user_eqn
    3.74 +              (* for code extraction from proof terms: *)
    3.75 +              |> singleton (Proof_Context.export lthy' lthy)
    3.76 +              |> Thm.name_derivation (Sign.full_name thy (Binding.name fun_name) ^
    3.77 +                Long_Name.separator ^ simpsN ^
    3.78 +                (if js = [0] then "" else "_" ^ string_of_int (j + 1))))
    3.79 +            js;
    3.80        in
    3.81 -        (poss, simp_thmss)
    3.82 +        (js, simp_thms)
    3.83        end;
    3.84  
    3.85      val notes =
    3.86 @@ -546,7 +552,7 @@
    3.87      (((fun_names, defs),
    3.88        fn lthy => fn defs =>
    3.89          split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
    3.90 -      lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
    3.91 +      lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
    3.92    end;
    3.93  
    3.94  fun add_primrec_simple fixes ts lthy =
    3.95 @@ -574,9 +580,9 @@
    3.96      val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
    3.97  
    3.98      val mk_notes =
    3.99 -      flat ooo map3 (fn poss => fn prefix => fn thms =>
   3.100 +      flat ooo map3 (fn js => fn prefix => fn thms =>
   3.101          let
   3.102 -          val (bs, attrss) = map_split (fst o nth specs) poss;
   3.103 +          val (bs, attrss) = map_split (fst o nth specs) js;
   3.104            val notes =
   3.105              map3 (fn b => fn attrs => fn thm =>
   3.106                  ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs),
   3.107 @@ -588,9 +594,9 @@
   3.108    in
   3.109      lthy
   3.110      |> add_primrec_simple fixes (map snd specs)
   3.111 -    |-> (fn (names, (ts, (posss, simpss))) =>
   3.112 +    |-> (fn (names, (ts, (jss, simpss))) =>
   3.113        Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
   3.114 -      #> Local_Theory.notes (mk_notes posss names simpss)
   3.115 +      #> Local_Theory.notes (mk_notes jss names simpss)
   3.116        #>> pair ts o map snd)
   3.117    end
   3.118    handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single;
   3.119 @@ -598,19 +604,15 @@
   3.120  val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec;
   3.121  val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec;
   3.122  
   3.123 -fun add_primrec_global fixes specs thy =
   3.124 -  let
   3.125 -    val lthy = Named_Target.theory_init thy;
   3.126 -    val ((ts, simpss), lthy') = add_primrec fixes specs lthy;
   3.127 -    val simpss' = burrow (Proof_Context.export lthy' lthy) simpss;
   3.128 -  in ((ts, simpss'), Local_Theory.exit_global lthy') end;
   3.129 +fun add_primrec_global fixes specs =
   3.130 +  Named_Target.theory_init
   3.131 +  #> add_primrec fixes specs
   3.132 +  ##> Local_Theory.exit_global;
   3.133  
   3.134 -fun add_primrec_overloaded ops fixes specs thy =
   3.135 -  let
   3.136 -    val lthy = Overloading.overloading ops thy;
   3.137 -    val ((ts, simpss), lthy') = add_primrec fixes specs lthy;
   3.138 -    val simpss' = burrow (Proof_Context.export lthy' lthy) simpss;
   3.139 -  in ((ts, simpss'), Local_Theory.exit_global lthy') end;
   3.140 +fun add_primrec_overloaded ops fixes specs =
   3.141 +  Overloading.overloading ops
   3.142 +  #> add_primrec fixes specs
   3.143 +  ##> Local_Theory.exit_global;
   3.144  
   3.145  val _ = Outer_Syntax.local_theory @{command_spec "primrec"}
   3.146    "define primitive recursive functions"
     4.1 --- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Mon Feb 17 13:31:42 2014 +0100
     4.2 +++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Mon Feb 17 13:31:42 2014 +0100
     4.3 @@ -942,7 +942,7 @@
     4.4           lthy
     4.5           |> Spec_Rules.add Spec_Rules.Equational ([casex], case_thms)
     4.6           |> fold (Spec_Rules.add Spec_Rules.Equational)
     4.7 -           (AList.group (eq_list (op aconv)) (map (`lhs_heads_of) all_sel_thms))
     4.8 +           (AList.group (eq_list (op aconv)) (map (`(single o lhs_head_of)) all_sel_thms))
     4.9           |> fold (Spec_Rules.add Spec_Rules.Equational)
    4.10             (filter_out (null o snd) (map single discs ~~ nontriv_disc_eq_thmss))
    4.11           |> Local_Theory.declaration {syntax = false, pervasive = true}
     5.1 --- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar_util.ML	Mon Feb 17 13:31:42 2014 +0100
     5.2 +++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar_util.ML	Mon Feb 17 13:31:42 2014 +0100
     5.3 @@ -40,7 +40,7 @@
     5.4    val typ_subst_nonatomic: (typ * typ) list -> typ -> typ
     5.5    val subst_nonatomic_types: (typ * typ) list -> term -> term
     5.6  
     5.7 -  val lhs_heads_of : thm -> term list
     5.8 +  val lhs_head_of : thm -> term
     5.9  
    5.10    val mk_predT: typ list -> typ
    5.11    val mk_pred1T: typ -> typ
    5.12 @@ -182,8 +182,7 @@
    5.13  fun subst_nonatomic_types [] = I
    5.14    | subst_nonatomic_types inst = map_types (typ_subst_nonatomic inst);
    5.15  
    5.16 -fun lhs_heads_of thm =
    5.17 -  [Term.head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of thm))))];
    5.18 +fun lhs_head_of thm = Term.head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of thm))));
    5.19  
    5.20  fun mk_predT Ts = Ts ---> HOLogic.boolT;
    5.21  fun mk_pred1T T = mk_predT [T];