src/HOL/Tools/Function/mutual.ML
changeset 63011 301e631666a0
parent 63004 f507e6fe1d77
child 63012 75f488e15479
--- a/src/HOL/Tools/Function/mutual.ML	Mon Apr 18 14:26:42 2016 +0200
+++ b/src/HOL/Tools/Function/mutual.ML	Mon Apr 18 14:30:24 2016 +0200
@@ -8,7 +8,7 @@
 sig
   val prepare_function_mutual : Function_Common.function_config
     -> binding (* defname *)
-    -> ((string * typ) * mixfix) list
+    -> ((binding * typ) * mixfix) list
     -> term list
     -> local_theory
     -> ((thm (* goalstate *)
@@ -27,7 +27,8 @@
 datatype mutual_part = MutualPart of
  {i : int,
   i' : int,
-  fvar : string * typ,
+  fname : binding,
+  fT : typ,
   cargTs: typ list,
   f_def: term,
 
@@ -52,8 +53,8 @@
   if n < 5 then fst (chop n ["P","Q","R","S"])
   else map (fn i => "P" ^ string_of_int i) (1 upto n)
 
-fun get_part fname =
-  the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)
+fun get_part f =
+  the o find_first (fn (MutualPart {fname, ...}) => Binding.name_of fname = f)
 
 (* FIXME *)
 fun mk_prod_abs e (t1, t2) =
@@ -69,15 +70,13 @@
   let
     val num = length fs
     val fqgars = map (split_def ctxt (K true)) eqs
-    val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
-      |> AList.lookup (op =) #> the
+    fun arity_of fname =
+      the (get_first (fn (f, _, _, args, _) =>
+        if f = Binding.name_of fname then SOME (length args) else NONE) fqgars)
 
     fun curried_types (fname, fT) =
-      let
-        val (caTs, uaTs) = chop (arity_of fname) (binder_types fT)
-      in
-        (caTs, uaTs ---> body_type fT)
-      end
+      let val (caTs, uaTs) = chop (arity_of fname) (binder_types fT)
+      in (caTs, uaTs ---> body_type fT) end
 
     val (caTss, resultTs) = split_list (map curried_types fs)
     val argTs = map (foldr1 HOLogic.mk_prodT) caTss
@@ -94,7 +93,7 @@
       Variable.add_fixes_binding [Binding.map_name (suffix "_sum") defname] ctxt
     val fsum_var = (fsum_var_name, fsum_type)
 
-    fun define (fvar as (n, _)) caTs resultT i =
+    fun define (fname, fT) caTs resultT i =
       let
         val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
         val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1
@@ -102,9 +101,11 @@
         val f_exp = Sum_Tree.mk_proj RST n' i' (Free fsum_var $ Sum_Tree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
         val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
 
-        val rew = (n, fold_rev lambda vars f_exp)
+        val rew = (Binding.name_of fname, fold_rev lambda vars f_exp)
       in
-        (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
+        (MutualPart
+          {i = i, i' = i', fname = fname, fT = fT, cargTs = caTs,
+            f_def = def, f = NONE, f_defthm = NONE}, rew)
       end
 
     val (parts, rews) = split_list (@{map 4} define fs caTss resultTs (1 upto num))
@@ -127,19 +128,15 @@
 
 fun define_projections fixes mutual fsum lthy =
   let
-    fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
+    fun def ((MutualPart {i=i, i'=i', fname, fT, cargTs, f_def, ...}), (_, mixfix)) lthy =
       let
-        val def_binding =
-          if Config.get lthy function_internals then (Binding.name (Thm.def_name fname), [])
-          else Attrib.empty_binding
+        val def_binding = Thm.make_def_binding (Config.get lthy function_internals) fname
         val ((f, (_, f_defthm)), lthy') =
           Local_Theory.define
-            ((Binding.name fname, mixfix),
-              (def_binding, Term.subst_bound (fsum, f_def))) lthy
+            ((fname, mixfix), ((def_binding, []), Term.subst_bound (fsum, f_def))) lthy
       in
-        (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
-           f=SOME f, f_defthm=SOME f_defthm },
-         lthy')
+        (MutualPart {i = i, i' = i', fname = fname, fT = fT, cargTs = cargTs,
+            f_def = f_def, f = SOME f, f_defthm = SOME f_defthm}, lthy')
       end
 
     val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
@@ -309,7 +306,7 @@
       analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)
 
     val ((fsum, goalstate, cont), lthy') =
-      Function_Core.prepare_function config defname [((n, T), NoSyn)] qglrs lthy
+      Function_Core.prepare_function config defname [((Binding.name n, T), NoSyn)] qglrs lthy
 
     val (mutual', lthy'') = define_projections fixes mutual fsum lthy'