src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
changeset 55575 a5e33e18fb5c
parent 55574 4a940ebceef8
child 55576 315dd5920114
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Wed Feb 19 08:34:32 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Wed Feb 19 08:34:33 2014 +0100
@@ -19,12 +19,15 @@
      rec_thms: thm list};
 
   type lfp_rec_extension =
-    {is_new_datatype: Proof.context -> string -> bool,
+    {nested_simps: thm list,
+     is_new_datatype: Proof.context -> string -> bool,
      get_basic_lfp_sugars: binding list -> typ list -> (term -> int list) ->
       (term * term list list) list list -> local_theory ->
       typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory,
-     massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
-       typ list -> term -> term -> term -> term};
+     rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list -> term ->
+       term -> term -> term};
+
+  exception PRIMREC of string * term list;
 
   val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory
 
@@ -45,11 +48,13 @@
 struct
 
 open Ctr_Sugar
+open Ctr_Sugar_Util
 open Ctr_Sugar_General_Tactics
-open BNF_Util
-open BNF_FP_Util
 open BNF_FP_Rec_Sugar_Util
 
+val inductN = "induct"
+val simpsN = "simps"
+
 val nitpicksimp_attrs = @{attributes [nitpick_simp]};
 val simp_attrs = @{attributes [simp]};
 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;
@@ -57,10 +62,6 @@
 exception OLD_PRIMREC of unit;
 exception PRIMREC of string * term list;
 
-fun primrec_error str = raise PRIMREC (str, []);
-fun primrec_error_eqn str eqn = raise PRIMREC (str, [eqn]);
-fun primrec_error_eqns str eqns = raise PRIMREC (str, eqns);
-
 datatype rec_call =
   No_Rec of int * typ |
   Mutual_Rec of (int * typ) * (int * typ) |
@@ -89,12 +90,13 @@
    rec_thms: thm list};
 
 type lfp_rec_extension =
-  {is_new_datatype: Proof.context -> string -> bool,
+  {nested_simps: thm list,
+   is_new_datatype: Proof.context -> string -> bool,
    get_basic_lfp_sugars: binding list -> typ list -> (term -> int list) ->
     (term * term list list) list list -> local_theory ->
     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory,
-   massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
-     typ list -> term -> term -> term -> term};
+   rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list -> term ->
+     term -> term -> term};
 
 structure Data = Theory_Data
 (
@@ -106,6 +108,11 @@
 
 val register_lfp_rec_extension = Data.put o SOME;
 
+fun nested_simps ctxt =
+  (case Data.get (Proof_Context.theory_of ctxt) of
+    SOME {nested_simps, ...} => nested_simps
+  | NONE => []);
+
 fun is_new_datatype ctxt =
   (case Data.get (Proof_Context.theory_of ctxt) of
     SOME {is_new_datatype, ...} => is_new_datatype ctxt
@@ -116,9 +123,9 @@
     SOME {get_basic_lfp_sugars, ...} => get_basic_lfp_sugars bs arg_Ts get_indices callssss lthy
   | NONE => error "Not implemented yet");
 
-fun massage_nested_rec_call ctxt =
+fun rewrite_nested_rec_call ctxt =
   (case Data.get (Proof_Context.theory_of ctxt) of
-    SOME {massage_nested_rec_call, ...} => massage_nested_rec_call ctxt);
+    SOME {rewrite_nested_rec_call, ...} => rewrite_nested_rec_call ctxt);
 
 fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy0 =
   let
@@ -193,9 +200,6 @@
 
 val undef_const = Const (@{const_name undefined}, dummyT);
 
-fun permute_args n t =
-  list_comb (t, map Bound (0 :: (n downto 1))) |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n);
-
 type eqn_data = {
   fun_name: string,
   rec_type: typ,
@@ -212,30 +216,30 @@
   let
     val eqn = drop_all eqn' |> HOLogic.dest_Trueprop
       handle TERM _ =>
-             primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
+             raise PRIMREC ("malformed function equation (expected \"lhs = rhs\")", [eqn']);
     val (lhs, rhs) = HOLogic.dest_eq eqn
         handle TERM _ =>
-               primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
+               raise PRIMREC ("malformed function equation (expected \"lhs = rhs\")", [eqn']);
     val (fun_name, args) = strip_comb lhs
       |>> (fn x => if is_Free x then fst (dest_Free x)
-          else primrec_error_eqn "malformed function equation (does not start with free)" eqn);
+          else raise PRIMREC ("malformed function equation (does not start with free)", [eqn]));
     val (left_args, rest) = take_prefix is_Free args;
     val (nonfrees, right_args) = take_suffix is_Free rest;
     val num_nonfrees = length nonfrees;
     val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then
-      primrec_error_eqn "constructor pattern missing in left-hand side" eqn else
-      primrec_error_eqn "more than one non-variable argument in left-hand side" eqn;
+      raise PRIMREC ("constructor pattern missing in left-hand side", [eqn]) else
+      raise PRIMREC ("more than one non-variable argument in left-hand side", [eqn]);
     val _ = member (op =) fun_names fun_name orelse
-      primrec_error_eqn "malformed function equation (does not start with function name)" eqn
+      raise PRIMREC ("malformed function equation (does not start with function name)", [eqn]);
 
     val (ctr, ctr_args) = strip_comb (the_single nonfrees);
     val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
-      primrec_error_eqn "partially applied constructor in pattern" eqn;
+      raise PRIMREC ("partially applied constructor in pattern", [eqn]);
     val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse
-      primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
-        "\" in left-hand side") eqn end;
+      raise PRIMREC ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
+        "\" in left-hand side", [eqn]) end;
     val _ = forall is_Free ctr_args orelse
-      primrec_error_eqn "non-primitive pattern in left-hand side" eqn;
+      raise PRIMREC ("non-primitive pattern in left-hand side", [eqn]);
     val _ =
       let val b = fold_aterms (fn x as Free (v, _) =>
         if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
@@ -243,8 +247,8 @@
         not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs []
       in
         null b orelse
-        primrec_error_eqn ("extra variable(s) in right-hand side: " ^
-          commas (map (Syntax.string_of_term lthy) b)) eqn
+        raise PRIMREC ("extra variable(s) in right-hand side: " ^
+          commas (map (Syntax.string_of_term lthy) b), [eqn])
       end;
   in
     {fun_name = fun_name,
@@ -258,39 +262,11 @@
      user_eqn = eqn'}
   end;
 
-fun rewrite_map_arg get_ctr_pos rec_type res_type =
-  let
-    val pT = HOLogic.mk_prodT (rec_type, res_type);
-
-    fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
-      | subst d (Abs (v, T, b)) =
-        Abs (v, if d = SOME ~1 then pT else T, subst (Option.map (Integer.add 1) d) b)
-      | subst d t =
-        let
-          val (u, vs) = strip_comb t;
-          val ctr_pos = try (get_ctr_pos o fst o dest_Free) u |> the_default ~1;
-        in
-          if ctr_pos >= 0 then
-            if d = SOME ~1 andalso length vs = ctr_pos then
-              list_comb (permute_args ctr_pos (snd_const pT), vs)
-            else if length vs > ctr_pos andalso is_some d
-                andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
-              list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
-            else
-              primrec_error_eqn ("recursive call not directly applied to constructor argument") t
-          else
-            list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
-        end
-  in
-    subst (SOME ~1)
-  end;
-
 fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls =
   let
     fun try_nested_rec bound_Ts y t =
       AList.lookup (op =) nested_calls y
-      |> Option.map (fn y' =>
-        massage_nested_rec_call lthy has_call (rewrite_map_arg get_ctr_pos) bound_Ts y y' t);
+      |> Option.map (fn y' => rewrite_nested_rec_call lthy has_call get_ctr_pos bound_Ts y y' t);
 
     fun subst bound_Ts (t as g' $ y) =
         let
@@ -307,7 +283,7 @@
                 (case try (get_ctr_pos o fst o dest_Free) g of
                   SOME ctr_pos =>
                   (length g_args >= ctr_pos orelse
-                   primrec_error_eqn "too few arguments in recursive call" t;
+                   raise PRIMREC ("too few arguments in recursive call", [t]);
                    (case AList.lookup (op =) mutual_calls y of
                      SOME y' => list_comb (y', g_args)
                    | NONE => subst_rec ()))
@@ -320,7 +296,7 @@
     fun subst' t =
       if has_call t then
         (* FIXME detect this case earlier? *)
-        primrec_error_eqn "recursive call not directly applied to constructor argument" t
+        raise PRIMREC ("recursive call not directly applied to constructor argument", [t])
       else
         try_nested_rec [] (head_of t) t |> the_default t
   in
@@ -378,9 +354,9 @@
       (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
       |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
           ##> (fn x => null x orelse
-            primrec_error_eqns "excess equations in definition" (map #rhs_term x)) #> fst);
+            raise PRIMREC ("excess equations in definition", map #rhs_term x)) #> fst);
     val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
-      primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));
+      raise PRIMREC ("multiple equations for constructor", map #user_eqn x));
 
     val ctr_spec_eqn_data_list =
       ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
@@ -391,8 +367,8 @@
       |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
     val ctr_poss = map (fn x =>
       if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
-        primrec_error ("inconstant constructor pattern position for function " ^
-          quote (#fun_name (hd x)))
+        raise PRIMREC ("inconstant constructor pattern position for function " ^
+          quote (#fun_name (hd x)), [])
       else
         hd x |> #left_args |> length) funs_data;
   in
@@ -435,8 +411,7 @@
 fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx =
   unfold_thms_tac ctxt fun_defs THEN
   HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
-  unfold_thms_tac ctxt (@{thms id_def split comp_def fst_conv snd_conv} @ map_comps @
-    map_idents) THEN
+  unfold_thms_tac ctxt (nested_simps ctxt @ map_comps @ map_idents) THEN
   HEADGOAL (rtac refl);
 
 fun prepare_primrec fixes specs lthy0 =
@@ -450,7 +425,7 @@
       |> partition_eq ((op =) o pairself #fun_name)
       |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
       |> map (fn (x, y) => the_single y
-          handle List.Empty => primrec_error ("missing equations for function " ^ quote x));
+          handle List.Empty => raise PRIMREC ("missing equations for function " ^ quote x, []));
 
     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
     val arg_Ts = map (#rec_type o hd) funs_data;
@@ -466,7 +441,7 @@
     val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else ();
     val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, HOLogic.typeS)) (bs ~~ res_Ts) of
         [] => ()
-      | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort"));
+      | (b, _) :: _ => raise PRIMREC ("type of " ^ Binding.print b ^ " contains top sort", []));
 
     val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy) =
       rec_specs_of bs arg_Ts res_Ts (get_free_indices fixes) callssss lthy0;
@@ -476,8 +451,8 @@
     val ctrs = maps (map #ctr o #ctr_specs) rec_specs;
     val _ =
       map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
-        primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
-          " is not a constructor in left-hand side") user_eqn) eqns_data;
+        raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
+          " is not a constructor in left-hand side", [user_eqn])) eqns_data;
 
     val defs = build_defs lthy bs mxs funs_data rec_specs has_call;
 
@@ -508,8 +483,8 @@
 
     val notes =
       (if n2m then
-         map2 (fn name => fn thm =>
-           (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms)
+         map2 (fn name => fn thm => (name, inductN, [thm], [])) fun_names
+           (take actual_nn induct_thms)
        else
          [])
       |> map (fn (prefix, thmN, thms, attrs) =>
@@ -531,7 +506,7 @@
 fun add_primrec_simple fixes ts lthy =
   let
     val (((names, defs), prove), lthy') = prepare_primrec fixes ts lthy
-      handle ERROR str => primrec_error str;
+      handle ERROR str => raise PRIMREC (str, []);
   in
     lthy'
     |> fold_map Local_Theory.define defs
@@ -548,7 +523,7 @@
     lthy =
   let
     val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
-    val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
+    val _ = null d orelse raise PRIMREC ("duplicate function name(s): " ^ commas d, []);
 
     val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);