src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 53831 80423b9080cf
parent 53830 ed2eb7df2aac
child 53835 687116951569
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Sep 24 17:54:09 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Sep 24 18:07:09 2013 +0200
@@ -10,11 +10,11 @@
   val add_primrec_cmd: (binding * string option * mixfix) list ->
     (Attrib.binding * string) list -> local_theory -> local_theory;
   val add_primcorecursive_cmd: bool ->
-    (binding * string option * mixfix) list * (Attrib.binding * string) list -> Proof.context ->
-    Proof.state
+    (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
+    Proof.context -> Proof.state
   val add_primcorec_cmd: bool ->
-    (binding * string option * mixfix) list * (Attrib.binding * string) list -> local_theory ->
-    local_theory
+    (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
+    local_theory -> local_theory
 end;
 
 structure BNF_FP_Rec_Sugar : BNF_FP_REC_SUGAR =
@@ -479,7 +479,7 @@
     }, matchedsss')
   end;
 
-fun co_dissect_eqn_sel fun_names corec_specs eqn' eqn =
+fun co_dissect_eqn_sel fun_names corec_specs eqn' of_spec eqn =
   let
     val (lhs, rhs) = HOLogic.dest_eq eqn
       handle TERM _ =>
@@ -490,9 +490,11 @@
         primrec_error_eqn "malformed selector argument in left-hand side" eqn;
     val corec_spec = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name)
       handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn;
-    val (ctr_spec, sel) = #ctr_specs corec_spec
-      |> the o get_index (try (the o find_first (equal sel) o #sels))
-      |>> nth (#ctr_specs corec_spec);
+    val ctr_spec =
+      if is_some of_spec
+      then the (find_first (equal (the of_spec) o #ctr) (#ctr_specs corec_spec))
+      else #ctr_specs corec_spec |> filter (exists (equal sel) o #sels) |> the_single
+        handle List.Empty => primrec_error_eqn "ambiguous selector - use \"of\"" eqn;
     val user_eqn = drop_All eqn';
   in
     Sel {
@@ -529,12 +531,12 @@
  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));
 
     val eqns_data_sel =
-      map (co_dissect_eqn_sel fun_names corec_specs eqn') sel_imp_rhss;
+      map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_imp_rhss;
   in
     (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
   end;
 
-fun co_dissect_eqn sequential fun_names corec_specs eqn' matchedsss =
+fun co_dissect_eqn sequential fun_names corec_specs eqn' of_spec matchedsss =
   let
     val eqn = drop_All eqn'
       handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
@@ -557,7 +559,7 @@
       co_dissect_eqn_disc sequential fun_names corec_specs imp_prems imp_rhs matchedsss
       |>> single
     else if member (op =) sels head then
-      ([co_dissect_eqn_sel fun_names corec_specs eqn' imp_rhs], matchedsss)
+      ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec imp_rhs], matchedsss)
     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
       co_dissect_eqn_ctr sequential fun_names corec_specs eqn' imp_prems imp_rhs matchedsss
     else
@@ -693,7 +695,7 @@
       chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
     end;
 
-fun add_primcorec simple sequential fixes specs lthy =
+fun add_primcorec simple sequential fixes specs of_specs lthy =
   let
     val (bs, mxs) = map_split (apfst fst) fixes;
     val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
@@ -708,9 +710,9 @@
     val fun_names = map Binding.name_of bs;
     val corec_specs = take actual_nn corec_specs'; (*###*)
 
-    val (eqns_data, _) =
-      fold_map (co_dissect_eqn sequential fun_names corec_specs) (map snd specs) []
-      |>> flat;
+    val eqns_data =
+      fold_map2 (co_dissect_eqn sequential fun_names corec_specs) (map snd specs) of_specs []
+      |> flat o fst;
 
     val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
       |> partition_eq ((op =) o pairself #fun_name)
@@ -820,7 +822,7 @@
                 ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, []))
                 |> the o merge_options;
               val m = length prems;
-              val t = sel_eqns
+              val t = filter (equal ctr o #ctr) sel_eqns
                 |> fst o finds ((op =) o apsnd #sel) sels
                 |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
                 |> curry list_comb ctr
@@ -899,11 +901,12 @@
       |> rpair NONE o SOME
   end;
 
-fun add_primcorec_ursive_cmd simple seq (raw_fixes, raw_specs) lthy =
+fun add_primcorec_ursive_cmd simple seq (raw_fixes, raw_specs') lthy =
   let
-    val (fixes, specs) = fst (Specification.read_spec raw_fixes raw_specs lthy);
+    val (raw_specs, of_specs) = split_list raw_specs' ||> map (Option.map (Syntax.read_term lthy));
+    val ((fixes, specs), _) = Specification.read_spec raw_fixes raw_specs lthy;
   in
-    add_primcorec simple seq fixes specs lthy
+    add_primcorec simple seq fixes specs of_specs lthy
     handle ERROR str => primrec_error str
   end
   handle Primrec_Error (str, eqns) =>