src/HOL/Tools/datatype_package.ML
changeset 21045 66d6d1b0ddfa
parent 20820 58693343905f
child 21127 c8e862897d13
--- a/src/HOL/Tools/datatype_package.ML	Mon Oct 16 14:07:20 2006 +0200
+++ b/src/HOL/Tools/datatype_package.ML	Mon Oct 16 14:07:21 2006 +0200
@@ -167,8 +167,8 @@
   usually signals a mistake.  But calls the tactic either way!*)
 fun occs_in_prems tacf vars =
   SUBGOAL (fn (Bi, i) =>
-           (if  exists (fn Free (a, _) => a mem vars)
-                      (foldr add_term_frees [] (#2 (strip_context Bi)))
+           (if exists (fn (a, _) => member (op =) vars a)
+                      (fold Term.add_frees (#2 (strip_context Bi)) [])
              then warning "Induction variable occurs also among premises!"
              else ();
             tacf i));
@@ -181,9 +181,9 @@
 fun prep_var (Var (ixn, _), SOME x) = SOME (ixn, x)
   | prep_var _ = NONE;
 
-fun prep_inst (concl, xs) =	(*exception Library.UnequalLengths *)
+fun prep_inst (concl, xs) = (*exception Library.UnequalLengths*)
   let val vs = InductAttrib.vars_of concl
-  in List.mapPartial prep_var (Library.drop (length vs - length xs, vs) ~~ xs) end;
+  in map_filter prep_var (Library.drop (length vs - length xs, vs) ~~ xs) end;
 
 in
 
@@ -192,21 +192,21 @@
   let
     val (rule, rule_name) =
       case opt_rule of
-	  SOME r => (r, "Induction rule")
-	| NONE =>
-	    let val tn = find_tname (hd (List.mapPartial I (List.concat varss))) Bi
+          SOME r => (r, "Induction rule")
+        | NONE =>
+            let val tn = find_tname (hd (map_filter I (flat varss))) Bi
                 val {sign, ...} = Thm.rep_thm state
-	    in (#induction (the_datatype sign tn), "Induction rule for type " ^ tn) 
-	    end
+            in (#induction (the_datatype sign tn), "Induction rule for type " ^ tn) 
+            end
     val concls = HOLogic.dest_concls (Thm.concl_of rule);
-    val insts = List.concat (map prep_inst (concls ~~ varss)) handle Library.UnequalLengths =>
+    val insts = maps prep_inst (concls ~~ varss) handle Library.UnequalLengths =>
       error (rule_name ^ " has different numbers of variables");
   in occs_in_prems (inst_tac insts rule) (map #2 insts) i end)
   i state;
 
 fun induct_tac s =
   gen_induct_tac Tactic.res_inst_tac'
-    (map (Library.single o SOME) (Syntax.read_idents s), NONE);
+    (map (single o SOME) (Syntax.read_idents s), NONE);
 
 fun induct_thm_tac th s =
   gen_induct_tac Tactic.res_inst_tac'
@@ -273,20 +273,20 @@
 local
 
 fun dt_recs (DtTFree _) = []
-  | dt_recs (DtType (_, dts)) = List.concat (map dt_recs dts)
+  | dt_recs (DtType (_, dts)) = maps dt_recs dts
   | dt_recs (DtRec i) = [i];
 
 fun dt_cases (descr: descr) (_, args, constrs) =
   let
-    fun the_bname i = Sign.base_name (#1 (valOf (AList.lookup (op =) descr i)));
-    val bnames = map the_bname (distinct (op =) (List.concat (map dt_recs args)));
+    fun the_bname i = Sign.base_name (#1 (the (AList.lookup (op =) descr i)));
+    val bnames = map the_bname (distinct (op =) (maps dt_recs args));
   in map (fn (c, _) => space_implode "_" (Sign.base_name c :: bnames)) constrs end;
 
 
 fun induct_cases descr =
-  DatatypeProp.indexify_names (List.concat (map (dt_cases descr) (map #2 descr)));
+  DatatypeProp.indexify_names (maps (dt_cases descr) (map #2 descr));
 
-fun exhaust_cases descr i = dt_cases descr (valOf (AList.lookup (op =) descr i));
+fun exhaust_cases descr i = dt_cases descr (the (AList.lookup (op =) descr i));
 
 in
 
@@ -294,7 +294,7 @@
 
 fun mk_case_names_exhausts descr new =
   map (RuleCases.case_names o exhaust_cases descr o #1)
-    (List.filter (fn ((_, (name, _, _))) => name mem_string new) descr);
+    (filter (fn ((_, (name, _, _))) => member (op =) new name) descr);
 
 end;
 
@@ -306,11 +306,11 @@
 fun add_rules simps case_thms size_thms rec_thms inject distinct
                   weak_case_congs cong_att =
   (snd o PureThy.add_thmss [(("simps", simps), []),
-    (("", List.concat case_thms @ size_thms @ 
-          List.concat distinct  @ rec_thms), [Simplifier.simp_add]),
+    (("", flat case_thms @ size_thms @ 
+          flat distinct @ rec_thms), [Simplifier.simp_add]),
     (("", size_thms @ rec_thms),             [setmp CodegenData.strict_functyp false (RecfunCodegen.add NONE)]),
-    (("", List.concat inject),               [iff_add]),
-    (("", map name_notE (List.concat distinct)),  [Classical.safe_elim NONE]),
+    (("", flat inject),               [iff_add]),
+    (("", map name_notE (flat distinct)),  [Classical.safe_elim NONE]),
     (("", weak_case_congs),                  [cong_att])]);
 
 
@@ -327,7 +327,7 @@
       (("", nth inducts i), [PureThy.kind_internal, InductAttrib.induct_type ""]);
   in
     thy |> PureThy.add_thms
-      (List.concat (map named_rules infos) @
+      (maps named_rules infos @
         map unnamed_rule (length infos upto length inducts - 1)) |> snd
     |> PureThy.add_thmss [(("inducts", inducts), [])] |> snd
   end;
@@ -396,8 +396,6 @@
 
 (**** translation rules for case ****)
 
-fun find_first f = Library.find_first f;
-
 fun case_tr context [t, u] =
     let
       val thy = Context.theory_of context;
@@ -415,34 +413,34 @@
       val (cases', default) = (case split_last cases of
           (cases', (("dummy_pattern", []), t)) => (cases', SOME t)
         | _ => (cases, NONE))
-      fun abstr (Free (x, T), body) = Term.absfree (x, T, body)
-        | abstr (Const ("_constrain", _) $ Free (x, T) $ tT, body) =
+      fun abstr (Free (x, T)) body = Term.absfree (x, T, body)
+        | abstr (Const ("_constrain", _) $ Free (x, T) $ tT) body =
             Syntax.const Syntax.constrainAbsC $ Term.absfree (x, T, body) $ tT
-        | abstr (Const ("Pair", _) $ x $ y, body) =
-            Syntax.const "split" $ abstr (x, abstr (y, body))
-        | abstr (t, _) = case_error "Illegal pattern" NONE [t];
+        | abstr (Const ("Pair", _) $ x $ y) body =
+            Syntax.const "split" $ (abstr x o abstr y) body
+        | abstr t _ = case_error "Illegal pattern" NONE [t];
     in case find_first (fn (_, {descr, index, ...}) =>
-      exists (equal cname o fst) (#3 (snd (List.nth (descr, index))))) tab of
+      exists (equal cname o fst) (#3 (snd (nth descr index)))) tab of
         NONE => case_error ("Not a datatype constructor: " ^ cname) NONE [u]
       | SOME (tname, {descr, sorts, case_name, index, ...}) =>
         let
           val _ = if exists (equal "dummy_pattern" o fst o fst) cases' then
             case_error "Illegal occurrence of '_' dummy pattern" (SOME tname) [u] else ();
-          val (_, (_, dts, constrs)) = List.nth (descr, index);
-          fun find_case (cases, (s, dt)) =
+          val (_, (_, dts, constrs)) = nth descr index;
+          fun find_case (s, dt) cases =
             (case find_first (equal s o fst o fst) cases' of
-               NONE => (cases, list_abs (map (rpair dummyT)
+               NONE => (list_abs (map (rpair dummyT)
                  (DatatypeProp.make_tnames (map (typ_of_dtyp descr sorts) dt)),
                  case default of
                    NONE => (warning ("No clause for constructor " ^ s ^
                      " in case expression"); Const ("undefined", dummyT))
-                 | SOME t => t))
+                 | SOME t => t), cases)
              | SOME (c as ((_, vs), t)) =>
                  if length dt <> length vs then
                     case_error ("Wrong number of arguments for constructor " ^ s)
                       (SOME tname) vs
-                 else (cases \ c, foldr abstr t vs))
-          val (cases'', fs) = foldl_map find_case (cases', constrs)
+                 else (fold_rev abstr vs t, remove (op =) c cases))
+          val (fs, cases'') = fold_map find_case constrs cases'
         in case (cases'', length constrs = length cases', default) of
             ([], true, SOME _) =>
               case_error "Extra '_' dummy pattern" (SOME tname) [u]
@@ -478,13 +476,12 @@
       (Sign.extern_const (Context.theory_of context) cname,
        strip_abs (length dts) t, is_dependent (length dts) t))
       (constrs ~~ fs);
-    fun count_cases (cs, (_, _, true)) = cs
-      | count_cases (cs, (cname, (_, body), false)) =
-          case AList.lookup (op = : term * term -> bool) cs body
-           of NONE => (body, [cname]) :: cs
-            | SOME cnames => AList.update (op =) (body, cnames @ [cname]) cs;
-    val cases' = sort (int_ord o Library.swap o pairself (length o snd))
-      (Library.foldl count_cases ([], cases));
+    fun count_cases (_, _, true) = I
+      | count_cases (cname, (_, body), false) =
+          AList.map_default (op = : term * term -> bool)
+            (body, []) (cons cname)
+    val cases' = sort (int_ord o swap o pairself (length o snd))
+      (fold_rev count_cases cases []);
     fun mk_case1 (cname, (vs, body), _) = Syntax.const "_case1" $
       list_comb (Syntax.const cname, vs) $ body;
     fun is_undefined (Const ("undefined", _)) = true
@@ -503,13 +500,14 @@
            else if length cnames = length constrs then
              [hd cases, ("dummy_pattern", ([], default), false)]
            else
-             filter_out (fn (cname, _, _) => cname mem cnames) cases @
+             filter_out (fn (cname, _, _) => member (op =) cnames cname) cases @
              [("dummy_pattern", ([], default), false)]))
   end;
 
-fun make_case_tr' case_names descr = List.concat (map
-  (fn ((_, (_, _, constrs)), case_name) => map (rpair (case_tr' constrs))
-    (NameSpace.accesses' case_name)) (descr ~~ case_names));
+fun make_case_tr' case_names descr = maps
+  (fn ((_, (_, _, constrs)), case_name) =>
+    map (rpair (case_tr' constrs)) (NameSpace.accesses' case_name))
+      (descr ~~ case_names);
 
 val trfun_setup =
   Theory.add_advanced_trfuns ([], [("_case_syntax", case_tr)], [], []);
@@ -599,9 +597,9 @@
 fun add_datatype_axm flat_names new_type_names descr sorts types_syntax constr_syntax dt_info
     case_names_induct case_names_exhausts thy =
   let
-    val descr' = List.concat descr;
+    val descr' = flat descr;
     val recTs = get_rec_types descr' sorts;
-    val used = foldr add_typ_tfree_names [] recTs;
+    val used = map fst (fold Term.add_tfreesT recTs []);
     val newTs = Library.take (length (hd descr), recTs);
 
     val no_size = exists (fn (_, (_, _, constrs)) => exists (fn (_, cargs) => exists
@@ -635,46 +633,44 @@
 
     val case_names = map (fn s => (s ^ "_case")) new_type_names;
 
-    val thy2' = thy |>
+    val thy2' = thy
 
       (** new types **)
-
-      curry (Library.foldr (fn (((name, mx), tvs), thy') => thy' |>
-          TypedefPackage.add_typedecls [(name, tvs, mx)]))
-        (types_syntax ~~ tyvars) |>
-      add_path flat_names (space_implode "_" new_type_names) |>
+      |> fold (fn ((name, mx), tvs) => TypedefPackage.add_typedecls [(name, tvs, mx)])
+           (types_syntax ~~ tyvars)
+      |> add_path flat_names (space_implode "_" new_type_names)
 
       (** primrec combinators **)
 
-      specify_consts (map (fn ((name, T), T') =>
-        (name, reccomb_fn_Ts @ [T] ---> T', NoSyn)) (reccomb_names ~~ recTs ~~ rec_result_Ts)) |>
+      |> specify_consts (map (fn ((name, T), T') =>
+           (name, reccomb_fn_Ts @ [T] ---> T', NoSyn)) (reccomb_names ~~ recTs ~~ rec_result_Ts))
 
       (** case combinators **)
 
-      specify_consts (map (fn ((name, T), Ts) =>
-        (name, Ts @ [T] ---> freeT, NoSyn)) (case_names ~~ newTs ~~ case_fn_Ts));
+      |> specify_consts (map (fn ((name, T), Ts) =>
+           (name, Ts @ [T] ---> freeT, NoSyn)) (case_names ~~ newTs ~~ case_fn_Ts));
 
     val reccomb_names' = map (Sign.full_name thy2') reccomb_names;
     val case_names' = map (Sign.full_name thy2') case_names;
 
-    val thy2 = thy2' |>
+    val thy2 = thy2'
 
       (** size functions **)
 
-      (if no_size then I else specify_consts (map (fn (s, T) =>
+      |> (if no_size then I else specify_consts (map (fn (s, T) =>
         (Sign.base_name s, T --> HOLogic.natT, NoSyn))
-          (size_names ~~ Library.drop (length (hd descr), recTs)))) |>
+          (size_names ~~ Library.drop (length (hd descr), recTs))))
 
       (** constructors **)
 
-      parent_path flat_names |>
-      curry (Library.foldr (fn (((((_, (_, _, constrs)), T), tname),
-        constr_syntax'), thy') => thy' |>
-          add_path flat_names tname |>
+      |> parent_path flat_names
+      |> fold (fn ((((_, (_, _, constrs)), T), tname),
+        constr_syntax') =>
+          add_path flat_names tname #>
             specify_consts (map (fn ((_, cargs), (cname, mx)) =>
               (cname, map (typ_of_dtyp descr' sorts) cargs ---> T, mx))
-                (constrs ~~ constr_syntax')) |>
-          parent_path flat_names))
+                (constrs ~~ constr_syntax')) #>
+          parent_path flat_names)
             (hd descr ~~ newTs ~~ new_type_names ~~ constr_syntax);
 
     (**** introduction of axioms ****)
@@ -717,7 +713,7 @@
         exhaustion ~~ replicate (length (hd descr)) QuickAndDirty ~~ inject ~~
           nchotomys ~~ case_congs ~~ weak_case_congs);
 
-    val simps = List.concat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
+    val simps = flat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
     val split_thms = split ~~ split_asm;
 
     val thy12 =
@@ -775,11 +771,11 @@
     val (size_thms, thy11) = DatatypeAbsProofs.prove_size_thms flat_names new_type_names
       descr sorts reccomb_names rec_thms thy10;
 
-    val dt_infos = map (make_dt_info (List.concat descr) sorts induct reccomb_names rec_thms)
+    val dt_infos = map (make_dt_info (flat descr) sorts induct reccomb_names rec_thms)
       ((0 upto length (hd descr) - 1) ~~ (hd descr) ~~ case_names ~~ case_thms ~~
         casedist_thms ~~ simproc_dists ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
 
-    val simps = List.concat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
+    val simps = flat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
 
     val thy12 =
       thy11
@@ -852,7 +848,7 @@
     val (case_names_induct, case_names_exhausts) = case RuleCases.get induction
      of (("1", _) :: _, _) => (mk_case_names_induct descr, mk_case_names_exhausts descr (map #1 dtnames))
       | (cases, _) => (RuleCases.case_names (map fst cases),
-          replicate (length ((List.filter (fn ((_, (name, _, _))) => member (op =)
+          replicate (length ((filter (fn ((_, (name, _, _))) => member (op =)
             (map #1 dtnames) name) descr)))
             (RuleCases.case_names (map fst cases)));
     
@@ -892,7 +888,7 @@
       ((0 upto length descr - 1) ~~ descr ~~ case_names ~~ case_thms ~~ casedist_thms ~~
         map FewConstrs distinct ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
 
-    val simps = List.concat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
+    val simps = flat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
 
     val thy11 =
       thy10
@@ -949,12 +945,12 @@
     val _ = (case duplicates (op =) (map fst new_dts) @ duplicates (op =) new_type_names of
       [] => () | dups => error ("Duplicate datatypes: " ^ commas dups));
 
-    fun prep_dt_spec ((dts', constr_syntax, sorts, i), (tvs, tname, mx, constrs)) =
+    fun prep_dt_spec (tvs, tname, mx, constrs) (dts', constr_syntax, sorts, i) =
       let
-        fun prep_constr ((constrs, constr_syntax', sorts'), (cname, cargs, mx')) =
+        fun prep_constr (cname, cargs, mx') (constrs, constr_syntax', sorts') =
           let
             val (cargs', sorts'') = Library.foldl (prep_typ tmp_thy) (([], sorts'), cargs);
-            val _ = (case foldr add_typ_tfree_names [] cargs' \\ tvs of
+            val _ = (case fold (curry add_typ_tfree_names) cargs' [] \\ tvs of
                 [] => ()
               | vs => error ("Extra type variables on rhs: " ^ commas vs))
           in (constrs @ [((if flat_names then Sign.full_name tmp_thy else
@@ -966,7 +962,7 @@
               " of datatype " ^ tname);
 
         val (constrs', constr_syntax', sorts') =
-          Library.foldl prep_constr (([], [], sorts), constrs)
+          fold prep_constr constrs ([], [], sorts)
 
       in
         case duplicates (op =) (map fst constrs') of
@@ -978,7 +974,7 @@
              " in datatype " ^ tname)
       end;
 
-    val (dts', constr_syntax, sorts', i) = Library.foldl prep_dt_spec (([], [], [], 0), dts);
+    val (dts', constr_syntax, sorts', i) = fold prep_dt_spec dts ([], [], [], 0);
     val sorts = sorts' @ (map (rpair (Sign.defaultS tmp_thy)) (tyvars \\ map fst sorts'));
     val dt_info = get_datatypes thy;
     val (descr, _) = unfold_datatypes tmp_thy dts' sorts dt_info dts' i;
@@ -986,7 +982,7 @@
       if err then error ("Nonemptiness check failed for datatype " ^ s)
       else raise exn;
 
-    val descr' = List.concat descr;
+    val descr' = flat descr;
     val case_names_induct = mk_case_names_induct descr';
     val case_names_exhausts = mk_case_names_exhausts descr' (map #1 new_dts);
   in