merged
authorbulwahn
Wed, 07 Sep 2011 14:58:40 +0200
changeset 44797 e0da66339e47
parent 44796 7f1f164696a4 (diff)
parent 44787 3c0741556e19 (current diff)
child 44798 9900c0069ae6
child 44821 a92f65e174cf
merged
--- a/src/HOL/Imperative_HOL/Heap_Monad.thy	Wed Sep 07 13:50:17 2011 +0200
+++ b/src/HOL/Imperative_HOL/Heap_Monad.thy	Wed Sep 07 14:58:40 2011 +0200
@@ -628,7 +628,9 @@
     val dummy_case_term = IVar NONE;
     (*assumption: dummy values are not relevant for serialization*)
     val (unitt, unitT) = case lookup_const naming @{const_name Unity}
-     of SOME unit' => (IConst (unit', (([], []), [])), the (lookup_tyco naming @{type_name unit}) `%% [])
+     of SOME unit' =>
+        let val unitT = the (lookup_tyco naming @{type_name unit}) `%% []
+        in (IConst (unit', ((([], []), ([], unitT)), false)), unitT) end
       | NONE => error ("Must include " ^ @{const_name Unity} ^ " in generated constants.");
     fun dest_abs ((v, ty) `|=> t, _) = ((v, ty), t)
       | dest_abs (t, ty) =
@@ -645,13 +647,13 @@
         val ((v, ty), t) = dest_abs (t2, ty2);
       in ICase (((force t1, ty), [(IVar v, tr_bind' t)]), dummy_case_term) end
     and tr_bind' t = case unfold_app t
-     of (IConst (c, (_, ty1 :: ty2 :: _)), [x1, x2]) => if is_bind c
+     of (IConst (c, ((_, (ty1 :: ty2 :: _, _)), _)), [x1, x2]) => if is_bind c
           then tr_bind'' [(x1, ty1), (x2, ty2)]
           else force t
       | _ => force t;
     fun imp_monad_bind'' ts = (SOME dummy_name, unitT) `|=> ICase (((IVar (SOME dummy_name), unitT),
       [(unitt, tr_bind'' ts)]), dummy_case_term)
-    fun imp_monad_bind' (const as (c, (_, tys))) ts = if is_bind c then case (ts, tys)
+    fun imp_monad_bind' (const as (c, ((_, (tys, _)), _))) ts = if is_bind c then case (ts, tys)
        of ([t1, t2], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)]
         | ([t1, t2, t3], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)] `$ t3
         | (ts, _) => imp_monad_bind (eta_expand 2 (const, ts))
--- a/src/Tools/Code/code_haskell.ML	Wed Sep 07 13:50:17 2011 +0200
+++ b/src/Tools/Code/code_haskell.ML	Wed Sep 07 14:58:40 2011 +0200
@@ -25,7 +25,7 @@
 (** Haskell serializer **)
 
 fun print_haskell_stmt labelled_name class_syntax tyco_syntax const_syntax
-    reserved deresolve contr_classparam_typs deriving_show =
+    reserved deresolve deriving_show =
   let
     fun class_name class = case class_syntax class
      of NONE => deresolve class
@@ -75,20 +75,14 @@
                 then print_case tyvars some_thm vars fxy cases
                 else print_app tyvars some_thm vars fxy c_ts
             | NONE => print_case tyvars some_thm vars fxy cases)
-    and print_app_expr tyvars some_thm vars ((c, (_, function_typs)), ts) = case contr_classparam_typs c
-     of [] => (str o deresolve) c :: map (print_term tyvars some_thm vars BR) ts
-      | fingerprint => let
-          val ts_fingerprint = ts ~~ take (length ts) fingerprint;
-          val needs_annotation = forall (fn (_, NONE) => true | (t, SOME _) =>
-            (not o Code_Thingol.locally_monomorphic) t) ts_fingerprint;
-          fun print_term_anno (t, NONE) _ = print_term tyvars some_thm vars BR t
-            | print_term_anno (t, SOME _) ty =
-                brackets [print_term tyvars some_thm vars NOBR t, str "::", print_typ tyvars NOBR ty];
-        in
-          if needs_annotation then
-            (str o deresolve) c :: map2 print_term_anno ts_fingerprint (take (length ts) function_typs)
-          else (str o deresolve) c :: map (print_term tyvars some_thm vars BR) ts
-        end
+    and print_app_expr tyvars some_thm vars ((c, ((_, (function_typs, body_typ)), annotate)), ts) =
+      let
+        val ty = Library.foldr (fn (ty1, ty2) => Code_Thingol.fun_tyco `%% [ty1, ty2]) (function_typs, body_typ)
+        fun put_annotation c = brackets [c, str "::", print_typ tyvars NOBR ty]
+      in 
+        ((if annotate then put_annotation else I)
+          ((str o deresolve) c)) :: map (print_term tyvars some_thm vars BR) ts
+      end
     and print_app tyvars = gen_print_app (print_app_expr tyvars) (print_term tyvars) const_syntax
     and print_bind tyvars some_thm fxy p = gen_print_bind (print_term tyvars) some_thm fxy p
     and print_case tyvars some_thm vars fxy (cases as ((_, [_]), _)) =
@@ -230,14 +224,14 @@
                     ]
                 | SOME k =>
                     let
-                      val (c, (_, tys)) = const;
+                      val (c, ((_, tys), _)) = const; (* FIXME: pass around the need annotation flag here? *)
                       val (vs, rhs) = (apfst o map) fst
                         (Code_Thingol.unfold_abs (Code_Thingol.eta_expand k (const, [])));
                       val s = if (is_some o const_syntax) c
                         then NONE else (SOME o Long_Name.base_name o deresolve) c;
                       val vars = reserved
                         |> intro_vars (map_filter I (s :: vs));
-                      val lhs = IConst (classparam, (([], []), tys)) `$$ map IVar vs;
+                      val lhs = IConst (classparam, ((([], []), tys), false (* FIXME *))) `$$ map IVar vs;
                         (*dictionaries are not relevant at this late stage*)
                     in
                       semicolon [
@@ -304,7 +298,6 @@
       labelled_name module_alias module_prefix (Name.make_context reserved) program;
 
     (* print statements *)
-    val contr_classparam_typs = Code_Thingol.contr_classparam_typs program;
     fun deriving_show tyco =
       let
         fun deriv _ "fun" = false
@@ -320,7 +313,7 @@
       in deriv [] tyco end;
     fun print_stmt deresolve = print_haskell_stmt labelled_name
       class_syntax tyco_syntax const_syntax (make_vars reserved)
-      deresolve contr_classparam_typs
+      deresolve
       (if string_classes then deriving_show else K false);
 
     (* print modules *)
--- a/src/Tools/Code/code_ml.ML	Wed Sep 07 13:50:17 2011 +0200
+++ b/src/Tools/Code/code_ml.ML	Wed Sep 07 14:58:40 2011 +0200
@@ -117,7 +117,7 @@
                 then print_case is_pseudo_fun some_thm vars fxy cases
                 else print_app is_pseudo_fun some_thm vars fxy c_ts
             | NONE => print_case is_pseudo_fun some_thm vars fxy cases)
-    and print_app_expr is_pseudo_fun some_thm vars (app as ((c, ((_, iss), function_typs)), ts)) =
+    and print_app_expr is_pseudo_fun some_thm vars (app as ((c, (((_, iss), (function_typs, _)), _)), ts)) =
       if is_cons c then
         let val k = length function_typs in
           if k < 2 orelse length ts = k
@@ -417,7 +417,7 @@
                 then print_case is_pseudo_fun some_thm vars fxy cases
                 else print_app is_pseudo_fun some_thm vars fxy c_ts
             | NONE => print_case is_pseudo_fun some_thm vars fxy cases)
-    and print_app_expr is_pseudo_fun some_thm vars (app as ((c, ((_, iss), tys)), ts)) =
+    and print_app_expr is_pseudo_fun some_thm vars (app as ((c, (((_, iss), (tys, _)), _)), ts)) =
       if is_cons c then
         let val k = length tys in
           if length ts = k
--- a/src/Tools/Code/code_printer.ML	Wed Sep 07 13:50:17 2011 +0200
+++ b/src/Tools/Code/code_printer.ML	Wed Sep 07 14:58:40 2011 +0200
@@ -315,7 +315,7 @@
       |-> (fn cs' => pair (Complex_const_syntax (n, f literals cs')));
 
 fun gen_print_app print_app_expr print_term const_syntax some_thm vars fxy
-    (app as ((c, (_, function_typs)), ts)) =
+    (app as ((c, ((_, (function_typs, _)), _)), ts)) =
   case const_syntax c of
     NONE => brackify fxy (print_app_expr some_thm vars app)
   | SOME (Plain_const_syntax (_, s)) =>
--- a/src/Tools/Code/code_scala.ML	Wed Sep 07 13:50:17 2011 +0200
+++ b/src/Tools/Code/code_scala.ML	Wed Sep 07 14:58:40 2011 +0200
@@ -72,7 +72,7 @@
                 else print_app tyvars is_pat some_thm vars fxy c_ts
             | NONE => print_case tyvars some_thm vars fxy cases)
     and print_app tyvars is_pat some_thm vars fxy
-        (app as ((c, ((arg_typs, _), function_typs)), ts)) =
+        (app as ((c, (((arg_typs, _), (function_typs, _)), _)), ts)) =
       let
         val k = length ts;
         val arg_typs' = if is_pat orelse
@@ -265,7 +265,7 @@
           let
             val tyvars = intro_tyvars vs reserved;
             val classtyp = (class, tyco `%% map (ITyVar o fst) vs);
-            fun print_classparam_instance ((classparam, const as (_, (_, tys))), (thm, _)) =
+            fun print_classparam_instance ((classparam, const as (_, ((_, (tys, _)), _))), (thm, _)) =
               let
                 val aux_tys = Name.invent_names (snd reserved) "a" tys;
                 val auxs = map fst aux_tys;
--- a/src/Tools/Code/code_thingol.ML	Wed Sep 07 13:50:17 2011 +0200
+++ b/src/Tools/Code/code_thingol.ML	Wed Sep 07 14:58:40 2011 +0200
@@ -22,7 +22,7 @@
   datatype itype =
       `%% of string * itype list
     | ITyVar of vname;
-  type const = string * ((itype list * dict list list) * itype list)
+  type const = string * (((itype list * dict list list) * (itype list * itype)) * bool)
     (* f [T1..Tn] {dicts} (_::S1) .. (_..Sm) =^= (f, (([T1..Tn], dicts), [S1..Sm]) *)
   datatype iterm =
       IConst of const
@@ -55,7 +55,6 @@
   val is_IAbs: iterm -> bool
   val eta_expand: int -> const * iterm list -> iterm
   val contains_dict_var: iterm -> bool
-  val locally_monomorphic: iterm -> bool
   val add_constnames: iterm -> string list -> string list
   val add_tyconames: iterm -> string list -> string list
   val fold_varnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a
@@ -88,7 +87,6 @@
   val map_terms_stmt: (iterm -> iterm) -> stmt -> stmt
   val is_cons: program -> string -> bool
   val is_case: stmt -> bool
-  val contr_classparam_typs: program -> string -> itype option list
   val labelled_name: theory -> program -> string -> string
   val group_stmts: theory -> program
     -> ((string * stmt) list * (string * stmt) list
@@ -145,7 +143,8 @@
     `%% of string * itype list
   | ITyVar of vname;
 
-type const = string * ((itype list * dict list list) * itype list (*types of arguments*))
+type const = string * (((itype list * dict list list) *
+  (itype list (*types of arguments*) * itype (*body type*))) * bool (*requires type annotation*))
 
 datatype iterm =
     IConst of const
@@ -198,7 +197,7 @@
 fun add_tycos (tyco `%% tys) = insert (op =) tyco #> fold add_tycos tys
   | add_tycos (ITyVar _) = I;
 
-val add_tyconames = fold_constexprs (fn (_, ((tys, _), _)) => fold add_tycos tys);
+val add_tyconames = fold_constexprs (fn (_, (((tys, _), _), _)) => fold add_tycos tys);
 
 fun fold_varnames f =
   let
@@ -240,7 +239,7 @@
         val vs_tys = (map o apfst) SOME (Name.invent_names ctxt "a" tys);
       in (vs_tys, t `$$ map (IVar o fst) vs_tys) end;
 
-fun eta_expand k (c as (name, (_, tys)), ts) =
+fun eta_expand k (c as (name, ((_, (tys, _)), _)), ts) =
   let
     val j = length ts;
     val l = k - j;
@@ -256,18 +255,12 @@
     fun cont_dict (Dict (_, d)) = cont_plain_dict d
     and cont_plain_dict (Dict_Const (_, dss)) = (exists o exists) cont_dict dss
       | cont_plain_dict (Dict_Var _) = true;
-    fun cont_term (IConst (_, ((_, dss), _))) = (exists o exists) cont_dict dss
+    fun cont_term (IConst (_, (((_, dss), _), _))) = (exists o exists) cont_dict dss
       | cont_term (IVar _) = false
       | cont_term (t1 `$ t2) = cont_term t1 orelse cont_term t2
       | cont_term (_ `|=> t) = cont_term t
       | cont_term (ICase (_, t)) = cont_term t;
   in cont_term t end;
-  
-fun locally_monomorphic (IConst _) = false
-  | locally_monomorphic (IVar _) = true
-  | locally_monomorphic (t `$ _) = locally_monomorphic t
-  | locally_monomorphic (_ `|=> t) = locally_monomorphic t
-  | locally_monomorphic (ICase ((_, ds), _)) = exists (locally_monomorphic o snd) ds;
 
 
 (** namings **)
@@ -480,28 +473,6 @@
   (fn (_, (Classinst ((class, _), (_, (param_insts, _))), _)) =>
     Option.map (fn ((const, _), _) => (class, const))
       (find_first (fn ((_, (inst_const, _)), _) => inst_const = name) param_insts) | _ => NONE)
-
-fun contr_classparam_typs program name = 
-  let
-    fun contr_classparam_typs' (class, name) =
-      let
-        val Class (_, (_, (_, params))) = Graph.get_node program class;
-        val SOME ty = AList.lookup (op =) params name;
-        val (tys, res_ty) = unfold_fun ty;
-        fun no_tyvar (_ `%% tys) = forall no_tyvar tys
-          | no_tyvar (ITyVar _) = false;
-      in if no_tyvar res_ty
-        then map (fn ty => if no_tyvar ty then NONE else SOME ty) tys
-        else []
-      end
- in 
-   case Graph.get_node program name
-   of Classparam (_, class) => contr_classparam_typs' (class, name)
-    | Fun (c, _) => (case lookup_classparam_instance program name
-      of NONE => []
-       | SOME (class, name) => the_default [] (try contr_classparam_typs' (class, name)))
-    | _ => []
-  end;
   
 fun labelled_name thy program name =
   let val ctxt = Proof_Context.init_global thy in
@@ -608,6 +579,42 @@
       (err_typ ^ "\n" ^ err_class)
   end;
 
+(* inference of type annotations for disambiguation with type classes *)
+
+fun annotate_term (Const (c', T'), Const (c, T)) tvar_names =
+    let
+      val tvar_names' = Term.add_tvar_namesT T' tvar_names
+    in
+      (Const (c, if eq_set (op =) (tvar_names, tvar_names') then T else Type("", [T])), tvar_names')
+    end
+  | annotate_term (t1 $ u1, t $ u) tvar_names =
+    let
+      val (u', tvar_names') = annotate_term (u1, u) tvar_names
+      val (t', tvar_names'') = annotate_term (t1, t) tvar_names'    
+    in
+      (t' $ u', tvar_names'')
+    end
+  | annotate_term (Abs (_, _, t1) , Abs (x, T, t)) tvar_names =
+    apfst (fn t => Abs (x, T, t)) (annotate_term (t1, t) tvar_names)
+  | annotate_term (_, t) tvar_names = (t, tvar_names)
+
+fun annotate_eqns thy eqns = 
+  let
+    val ctxt = ProofContext.init_global thy |> Config.put Type_Infer_Context.const_sorts false
+    val erase = map_types (fn _ => Type_Infer.anyT [])
+    val reinfer = singleton (Type_Infer_Context.infer_types ctxt)
+    fun add_annotations ((args, (rhs, some_abs)), (SOME th, proper)) =
+      let
+        val (lhs, drhs) = Logic.dest_equals (prop_of (Thm.unvarify_global th))
+        val drhs' = snd (Logic.dest_equals (reinfer (Logic.mk_equals (lhs, erase drhs))))
+        val (rhs', _) = annotate_term (drhs', rhs) []
+     in
+        ((args, (rhs', some_abs)), (SOME th, proper))
+     end
+     | add_annotations eqn = eqn
+  in
+    map add_annotations eqns
+  end;
 
 (* translation *)
 
@@ -633,11 +640,12 @@
     fun stmt_fun cert =
       let
         val ((vs, ty), eqns) = Code.equations_of_cert thy cert;
+        val eqns' = annotate_eqns thy eqns
         val some_case_cong = Code.get_case_cong thy c;
       in
         fold_map (translate_tyvar_sort thy algbr eqngr permissive) vs
         ##>> translate_typ thy algbr eqngr permissive ty
-        ##>> translate_eqns thy algbr eqngr permissive eqns
+        ##>> translate_eqns thy algbr eqngr permissive eqns'
         #>> (fn info => Fun (c, (info, some_case_cong)))
       end;
     val stmt_const = case Code.get_type_of_constr_or_abstr thy c
@@ -748,15 +756,17 @@
         then translation_error thy permissive some_thm
           "Abstraction violation" ("constant " ^ Code.string_of_const thy c)
       else ()
-    val arg_typs = Sign.const_typargs thy (c, ty);
+    val (annotate, ty') = (case ty of Type("", [ty']) => (true, ty') | ty' => (false, ty'))
+    val arg_typs = Sign.const_typargs thy (c, ty');
     val sorts = Code_Preproc.sortargs eqngr c;
-    val function_typs = Term.binder_types ty;
+    val (function_typs, body_typ) = Term.strip_type ty';
   in
     ensure_const thy algbr eqngr permissive c
     ##>> fold_map (translate_typ thy algbr eqngr permissive) arg_typs
     ##>> fold_map (translate_dicts thy algbr eqngr permissive some_thm) (arg_typs ~~ sorts)
-    ##>> fold_map (translate_typ thy algbr eqngr permissive) function_typs
-    #>> (fn (((c, arg_typs), dss), function_typs) => IConst (c, ((arg_typs, dss), function_typs)))
+    ##>> fold_map (translate_typ thy algbr eqngr permissive) (body_typ :: function_typs)
+    #>> (fn (((c, arg_typs), dss), body_typ :: function_typs) =>
+      IConst (c, (((arg_typs, dss), (function_typs, body_typ)), annotate)))
   end
 and translate_app_const thy algbr eqngr permissive some_thm ((c_ty, ts), some_abs) =
   translate_const thy algbr eqngr permissive some_thm (c_ty, some_abs)
@@ -801,7 +811,7 @@
         val ts_clause = nth_drop t_pos ts;
         val clauses = if null case_pats
           then mk_clause (fn ([t], body) => (t, body)) [ty] (the_single ts_clause)
-          else maps (fn ((constr as IConst (_, (_, tys)), n), t) =>
+          else maps (fn ((constr as IConst (_, ((_, (tys, _)), _)), n), t) =>
             mk_clause (fn (ts, body) => (constr `$$ ts, body)) (take n tys) t)
               (constrs ~~ ts_clause);
       in ((t, ty), clauses) end;
--- a/src/Tools/nbe.ML	Wed Sep 07 13:50:17 2011 +0200
+++ b/src/Tools/nbe.ML	Wed Sep 07 14:58:40 2011 +0200
@@ -315,7 +315,7 @@
           let
             val (t', ts) = Code_Thingol.unfold_app t
           in of_iapp match_cont t' (fold_rev (cons o of_iterm NONE) ts []) end
-        and of_iapp match_cont (IConst (c, ((_, dss), _))) ts = constapp c dss ts
+        and of_iapp match_cont (IConst (c, (((_, dss), _), _))) ts = constapp c dss ts
           | of_iapp match_cont (IVar v) ts = nbe_apps (nbe_bound_optional v) ts
           | of_iapp match_cont ((v, _) `|=> t) ts =
               nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (of_iterm NONE t))) ts
@@ -425,7 +425,7 @@
         val params = Name.invent Name.context "d" (length names);
         fun mk (k, name) =
           (name, ([(v, [])],
-            [([IConst (class, (([], []), [])) `$$ map (IVar o SOME) params],
+            [([IConst (class, ((([], []), ([], ITyVar "")), false)) `$$ map (IVar o SOME) params],
               IVar (SOME (nth params k)))]));
       in map_index mk names end
   | eqns_of_stmt (_, Code_Thingol.Classrel _) =
@@ -433,8 +433,8 @@
   | eqns_of_stmt (_, Code_Thingol.Classparam _) =
       []
   | eqns_of_stmt (inst, Code_Thingol.Classinst ((class, (_, arity_args)), (super_instances, (classparam_instances, _)))) =
-      [(inst, (arity_args, [([], IConst (class, (([], []), [])) `$$
-        map (fn (_, (_, (inst, dss))) => IConst (inst, (([], dss), []))) super_instances
+      [(inst, (arity_args, [([], IConst (class, ((([], []), ([], ITyVar "")), false)) `$$
+        map (fn (_, (_, (inst, dss))) => IConst (inst, ((([], dss), ([], ITyVar "")), false))) super_instances
         @ map (IConst o snd o fst) classparam_instances)]))];