improvement in devarifications
authorhaftmann
Thu, 02 Feb 2006 18:04:10 +0100
changeset 18912 dd168daf172d
parent 18911 74edab16166f
child 18913 57f19fad8c2a
improvement in devarifications
src/Pure/Tools/codegen_package.ML
src/Pure/Tools/codegen_serializer.ML
src/Pure/Tools/codegen_thingol.ML
--- a/src/Pure/Tools/codegen_package.ML	Thu Feb 02 18:03:35 2006 +0100
+++ b/src/Pure/Tools/codegen_package.ML	Thu Feb 02 18:04:10 2006 +0100
@@ -36,12 +36,7 @@
   val set_defgen_datatype: defgen -> theory -> theory;
   val set_int_tyco: string -> theory -> theory;
 
-  val exprgen_type: theory -> auxtab
-    -> typ -> CodegenThingol.transact -> CodegenThingol.itype * CodegenThingol.transact;
-  val exprgen_term: theory -> auxtab
-    -> term -> CodegenThingol.transact -> CodegenThingol.iexpr * CodegenThingol.transact;
   val appgen_default: appgen;
-
   val appgen_let: (int -> term -> term list * term)
     -> appgen;
   val appgen_split: (int -> term -> term list * term)
@@ -85,19 +80,6 @@
 infixr 3 `|->;
 infixr 3 `|-->;
 
-(* auxiliary *)
-
-fun devarify_type ty = (fst o Type.freeze_thaw_type o Term.zero_var_indexesT) ty;
-fun devarify_term t = (fst o Type.freeze_thaw o Term.zero_var_indexes) t;
-
-val is_number = is_some o Int.fromString;
-
-fun merge_opt _ (x1, NONE) = x1
-  | merge_opt _ (NONE, x2) = x2
-  | merge_opt eq (SOME x1, SOME x2) =
-      if eq (x1, x2) then SOME x1 else error ("incompatible options during merge");
-
-
 (* shallow name spaces *)
 
 val nsp_module = ""; (* a dummy by convention *)
@@ -205,6 +187,11 @@
 
 (* theory data for code generator *)
 
+fun merge_opt _ (x1, NONE) = x1
+  | merge_opt _ (NONE, x2) = x2
+  | merge_opt eq (SOME x1, SOME x2) =
+      if eq (x1, x2) then SOME x1 else error ("incompatible options during merge");
+
 type gens = {
   appconst: ((int * int) * (appgen * stamp)) Symtab.table,
   eqextrs: (string * (eqextr * stamp)) list
@@ -506,6 +493,75 @@
     ); thy);
 
 
+(* sophisticated devarification *)
+
+fun assert f msg x =
+  if f x then x
+    else error msg;
+
+val _ : ('a -> bool) -> string -> 'a -> 'a = assert;
+
+fun devarify_typs tys =
+  let
+    fun add_rename (var as ((v, _), sort)) used = 
+      let
+        val v' = variant used v
+      in (((var, TFree (v', sort)), (v', TVar var)), v' :: used) end;
+    fun typ_names (Type (tyco, tys)) (vars, names) =
+          (vars, names |> insert (op =) (NameSpace.base tyco))
+          |> fold typ_names tys
+      | typ_names (TFree (v, _)) (vars, names) =
+          (vars, names |> insert (op =) v)
+      | typ_names (TVar (v, sort)) (vars, names) =
+          (vars |> AList.update (op =) (v, sort), names);
+    val (vars, used) = fold typ_names tys ([], []);
+    val (renames, reverse) = fold_map add_rename vars used |> fst |> split_list;
+  in
+    (reverse, (map o map_atyps) (Term.instantiateT renames) tys)
+  end;
+
+fun burrow_typs_yield f ts =
+  let
+    val typtab =
+      fold (fold_types (fn ty => Typtab.update (ty, dummyT)))
+        ts Typtab.empty;
+    val typs = Typtab.keys typtab;
+    val (x, typs') = f typs;
+    val typtab' = fold2 (Typtab.update oo pair) typs typs' typtab;
+  in
+    (x, (map o map_term_types) (the o Typtab.lookup typtab') ts)
+  end;
+
+fun devarify_terms ts =
+  let
+    fun add_rename (var as ((v, _), ty)) used = 
+      let
+        val v' = variant used v
+      in (((var, Free (v', ty)), (v', Var var)), v' :: used) end;
+    fun term_names (Const (c, _)) (vars, names) =
+          (vars, names |> insert (op =) (NameSpace.base c))
+      | term_names (Free (v, _)) (vars, names) =
+          (vars, names |> insert (op =) v)
+      | term_names (Var (v, sort)) (vars, names) =
+          (vars |> AList.update (op =) (v, sort), names)
+      | term_names (Bound _) vars_names =
+          vars_names
+      | term_names (Abs (v, _, _)) (vars, names) =
+          (vars, names |> insert (op =) v)
+      | term_names (t1 $ t2) vars_names =
+          vars_names |> term_names t1 |> term_names t2
+    val (vars, used) = fold term_names ts ([], []);
+    val (renames, reverse) = fold_map add_rename vars used |> fst |> split_list;
+  in
+    (reverse, (map o map_aterms) (Term.instantiate ([], renames)) ts)
+  end;
+
+fun devarify_term_typs ts =
+  ts
+  |> devarify_terms
+  |-> (fn reverse => burrow_typs_yield devarify_typs
+  #-> (fn reverseT => pair (reverseT, reverse)));
+
 (* definition and expression generators *)
 
 fun ensure_def_class thy tabs cls trns =
@@ -521,7 +577,7 @@
               trns
               |> debug 5 (fn _ => "trying defgen class declaration for " ^ quote cls)
               |> fold_map (ensure_def_class thy tabs) (ClassPackage.the_superclasses thy cls)
-              ||>> fold_map (exprgen_type thy tabs o devarify_type o snd) cs
+              ||>> (codegen_type thy tabs o map snd) cs
               ||>> (fold_map o fold_map) (exprgen_tyvar_sort thy tabs) sortctxts
               |-> (fn ((supcls, memtypes), sortctxts) => succeed
                 (Class ((supcls, ("a", idfs ~~ (sortctxts ~~ memtypes))), [])))
@@ -564,7 +620,9 @@
       trns
       |> ensure_def_tyco thy tabs tyco
       ||>> fold_map (exprgen_type thy tabs) tys
-      |-> (fn (tyco, tys) => pair (tyco `%% tys));
+      |-> (fn (tyco, tys) => pair (tyco `%% tys))
+and codegen_type thy tabs =
+  fold_map (exprgen_type thy tabs) o snd o devarify_typs;
 
 fun exprgen_classlookup thy tabs (ClassPackage.Instance (inst, ls)) trns =
       trns
@@ -590,17 +648,12 @@
                    ^ ", actually defining " ^ quote c')
               | _ => error ("illegal function equation for " ^ quote c)
             end;
-          fun mk_eq (args, rhs) trns =
-            trns
-            |> fold_map (exprgen_term thy tabs o devarify_term) args
-            ||>> (exprgen_term thy tabs o devarify_term) rhs
-            |-> (fn (args, rhs) => pair (args, rhs))
         in
           trns
-          |> fold_map (mk_eq o dest_eqthm) eq_thms
-          ||>> (exprgen_type thy tabs o devarify_type) ty
+          |> (codegen_eqs thy tabs o map dest_eqthm) eq_thms
+          ||>> codegen_type thy tabs [ty]
           ||>> fold_map (exprgen_tyvar_sort thy tabs) sortctxt
-          |-> (fn ((eqs, ty), sortctxt) => (pair o SOME) (eqs, (sortctxt, ty)))
+          |-> (fn ((eqs, [ty]), sortctxt) => (pair o SOME) (eqs, (sortctxt, ty)))
         end
     | NONE => (NONE, trns)
 and ensure_def_inst thy (tabs as (_, (insttab, _, _))) (cls, tyco) trns =
@@ -690,19 +743,21 @@
       trns
       |> appgen thy tabs ((f, ty), [])
       |-> (fn e => pair e)
-  | exprgen_term thy tabs (Var ((v, 0), ty)) trns =
+  (* | exprgen_term thy tabs (Var ((v, 0), ty)) trns =
       trns
-      |> (exprgen_type thy tabs o devarify_type) ty
+      |> (exprgen_type thy tabs) ty
       |-> (fn ty => pair (IVarE (v, ty)))
   | exprgen_term thy tabs (Var ((_, _), _)) trns =
-      error "Var with index greater 0 encountered during code generation"
+      error "Var with index greater 0 encountered during code generation" *)
+  | exprgen_term thy tabs (Var _) trns =
+      error "Var encountered during code generation"
   | exprgen_term thy tabs (Free (v, ty)) trns =
       trns
-      |> (exprgen_type thy tabs o devarify_type) ty
+      |> exprgen_type thy tabs ty
       |-> (fn ty => pair (IVarE (v, ty)))
   | exprgen_term thy tabs (Abs (v, ty, t)) trns =
       trns
-      |> (exprgen_type thy tabs o devarify_type) ty
+      |> exprgen_type thy tabs ty
       ||>> exprgen_term thy tabs (subst_bound (Free (v, ty), t))
       |-> (fn (ty, e) => pair ((v, ty) `|-> e))
   | exprgen_term thy tabs (t as t1 $ t2) trns =
@@ -719,14 +774,20 @@
             ||>> fold_map (exprgen_term thy tabs) ts
             |-> (fn (e, es) => pair (e `$$ es))
       end
+and codegen_term thy tabs =
+  fold_map (exprgen_term thy tabs) o snd o devarify_term_typs
+and codegen_eqs thy tabs =
+  apfst (map (fn (rhs::args) => (args, rhs)))
+    oo fold_burrow (codegen_term thy tabs)
+    o map (fn (args, rhs) => (rhs :: args))
 and appgen_default thy tabs ((c, ty), ts) trns =
   trns
   |> ensure_def_const thy tabs (c, ty)
   ||>> (fold_map o fold_map) (exprgen_classlookup thy tabs)
          (ClassPackage.extract_classlookup thy (c, ty))
-  ||>> (exprgen_type thy tabs o devarify_type) ty
-  ||>> fold_map (exprgen_term thy tabs o devarify_term) ts
-  |-> (fn (((c, ls), ty), es) =>
+  ||>> codegen_type thy tabs [ty]
+  ||>> fold_map (exprgen_term thy tabs) ts
+  |-> (fn (((c, ls), [ty]), es) =>
          pair (IConst ((c, ty), ls) `$$ es))
 and appgen thy tabs ((f, ty), ts) trns =
   case Symtab.lookup ((#appconst o #gens o CodegenData.get) thy) f
@@ -739,7 +800,7 @@
           in
             trns
             |> debug 10 (fn _ => "eta-expanding")
-            |> fold_map (exprgen_type thy tabs o devarify_type) tys
+            |> fold_map (exprgen_type thy tabs) tys
             ||>> ag thy tabs ((f, ty), ts @ map2 (curry Free) vs tys)
             |-> (fn (tys, e) => pair ((vs ~~ tys) `|--> e))
           end
@@ -842,7 +903,7 @@
   Type (_, [_, ty as Type (tyco, [])])), [bin]) trns =
     if tyco = tyco_int then
       trns
-      |> (exprgen_type thy tabs o devarify_type) ty
+      |> exprgen_type thy tabs ty
       |-> (fn ty => pair (CodegenThingol.IConst (((IntInf.toString o bin_to_int) bin, ty), [])))
     else if tyco = tyco_nat then
       trns
@@ -902,7 +963,7 @@
                 trns
                 |> debug 5 (fn _ => "trying defgen datatype for " ^ quote dtco)
                 |> fold_map (exprgen_tyvar_sort thy tabs) vars
-                ||>> (fold_map o fold_map) (exprgen_type thy tabs o devarify_type) cotys
+                ||>> fold_map (codegen_type thy tabs) cotys
                 |-> (fn (sorts, tys) => succeed (Datatype
                      ((sorts, coidfs ~~ tys), [])))
               end
@@ -1056,9 +1117,10 @@
   in (c, ty) end;
 
 fun read_quote reader gen raw thy =
-  expand_module
-    (fn thy => fn tabs => gen thy tabs (reader thy raw))
-    thy;
+  thy
+  |> expand_module
+       (fn thy => fn tabs => (gen thy tabs o single o reader thy) raw)
+  |-> (fn [x] => pair x);
 
 fun gen_add_prim prep_name prep_primdef raw_name deps (target, raw_primdef) thy =
   let
@@ -1133,8 +1195,7 @@
               logic_data)))
       end;
   in
-    CodegenSerializer.parse_syntax
-      (read_quote read_typ (fn thy => fn tabs => exprgen_type thy tabs o devarify_type))
+    CodegenSerializer.parse_syntax (read_quote read_typ codegen_type)
     #-> (fn reader => pair (mk reader))
   end;
 
@@ -1165,7 +1226,7 @@
         |-> (fn pretty => add_pretty_syntax_const c target pretty)
       end;
   in
-    CodegenSerializer.parse_syntax (read_quote Sign.read_term exprgen_term)
+    CodegenSerializer.parse_syntax (read_quote Sign.read_term codegen_term)
     #-> (fn reader => pair (mk reader))
   end;
 
--- a/src/Pure/Tools/codegen_serializer.ML	Thu Feb 02 18:03:35 2006 +0100
+++ b/src/Pure/Tools/codegen_serializer.ML	Thu Feb 02 18:04:10 2006 +0100
@@ -499,7 +499,7 @@
               :: (lss
               @ map (ml_from_expr BR) es)
             );
-    fun ml_from_funs (ds as d::ds_tl) =
+    fun ml_from_funs (defs as def::defs_tl) =
       let
         fun mk_definer [] = "val"
           | mk_definer _ = "fun";
@@ -511,37 +511,33 @@
               else error ("mixing simultaneous vals and funs not implemented")
           | check_args _ _ =
               error ("function definition block containing other definitions than functions")
-        val definer = the (fold check_args ds NONE);
-        fun mk_eq definer sortctxt f ty (pats, expr) =
+        fun mk_fun definer (name, Fun (eqs as eq::eq_tl, (sortctxt, ty))) =
           let
-            val args = map (str o fst) sortctxt @ map (ml_from_expr BR) pats;
-            val lhs = [str (definer ^ " " ^ f)]
-                       @ (if null args
-                          then [str ":", ml_from_type NOBR ty]
-                          else args)
-            val rhs = [str "=", ml_from_expr NOBR expr]
+            val shift = if null eq_tl then I else map (Pretty.block o single);
+            fun mk_eq definer (pats, expr) =
+              (Pretty.block o Pretty.breaks) (
+                [str definer, (str o resolv) name]
+                @ (if null pats
+                   then [str ":", ml_from_type NOBR ty]
+                   else map (str o fst) sortctxt @ map (ml_from_expr BR) pats)
+                @ [str "=", ml_from_expr NOBR expr]
+              )
           in
-            Pretty.block (separate (Pretty.brk 1) (lhs @ rhs))
-          end
-        fun mk_fun definer (f, Fun (eqs as eq::eq_tl, (sortctxt, ty))) =
-          let
-            val (pats_hd::pats_tl) = (fst o split_list) eqs;
-            val shift = if null eq_tl then I else map (Pretty.block o single);
-          in (Pretty.block o Pretty.fbreaks o shift) (
-               mk_eq definer sortctxt f ty eq
-               :: map (mk_eq "|" sortctxt f ty) eq_tl
-             )
+            (Pretty.block o Pretty.fbreaks o shift) (
+              mk_eq definer eq
+              :: map (mk_eq "|") eq_tl
+            )
           end;
       in
         chunk_defs (
-          mk_fun definer d
-          :: map (mk_fun "and") ds_tl
+          mk_fun (the (fold check_args defs NONE)) def
+          :: map (mk_fun "and") defs_tl
         ) |> SOME
       end;
     fun ml_from_datatypes defs =
       let
         val defs' = List.mapPartial
-          (fn (name, Datatype info) => SOME (name, info)
+          (fn (name, Datatype info) => SOME (resolv name, info)
             | (name, Datatypecons _) => NONE
             | (name, def) => error ("datatype block containing illegal def: "
                 ^ (Pretty.output o pretty_def) def)
@@ -557,19 +553,18 @@
                      (map (ml_from_type NOBR) tys)
               )
         fun mk_datatype definer (t, ((vs, cs), _)) =
-          Pretty.block (
+          (Pretty.block o Pretty.breaks) (
             str definer
             :: ml_from_type NOBR (t `%% map IVarT vs)
-            :: str " ="
-            :: Pretty.brk 1
-            :: separate (Pretty.block [Pretty.brk 1, str "| "]) (map mk_cons cs)
+            :: str "="
+            :: separate (str "|") (map mk_cons cs)
           )
       in
         case defs'
-         of d::ds_tl =>
+         of (def::defs_tl) =>
             chunk_defs (
-              mk_datatype "datatype " d
-              :: map (mk_datatype "and ") ds_tl
+              mk_datatype "datatype " def
+              :: map (mk_datatype "and ") defs_tl
             ) |> SOME
           | _ => NONE
       end
@@ -661,7 +656,7 @@
             Pretty.block [
               (Pretty.block o Pretty.breaks) (
                 str definer
-                :: str name
+                :: (str o resolv) name
                 :: map (str o fst) arity
               ),
               Pretty.brk 1,
@@ -867,7 +862,7 @@
       let
         fun from_eq name (args, rhs) =
           Pretty.block [
-            str (lower_first name),
+            (str o lower_first o resolv) name,
             Pretty.block (map (fn p => Pretty.block [Pretty.brk 1, hs_from_expr BR p]) args),
             Pretty.brk 1,
             str ("="),
@@ -880,26 +875,14 @@
       | hs_from_def (name, Prim prim) =
           from_prim (name, prim)
       | hs_from_def (name, Fun (eqs, (sctxt, ty))) =
-          let
-            fun from_eq name (args, rhs) =
-              Pretty.block [
-                str (lower_first name),
-                Pretty.block (map (fn p => Pretty.block [Pretty.brk 1, hs_from_expr BR p]) args),
-                Pretty.brk 1,
-                str ("="),
-                Pretty.brk 1,
-                hs_from_expr NOBR rhs
-              ]
-          in
-            Pretty.chunks [
-              Pretty.block [
-                str (lower_first name ^ " ::"),
-                Pretty.brk 1,
-                hs_from_sctxt_type (sctxt, ty)
-              ],
-              hs_from_funeqs (name, eqs)
-            ] |> SOME
-          end
+          Pretty.chunks [
+            Pretty.block [
+              (str o lower_first o resolv) (name ^ " ::"),
+              Pretty.brk 1,
+              hs_from_sctxt_type (sctxt, ty)
+            ],
+            hs_from_funeqs (name, eqs)
+          ] |> SOME
       | hs_from_def (name, Typesyn (vs, ty)) =
           Pretty.block [
             str "type ",
@@ -941,7 +924,7 @@
             Pretty.block [
               str "class ",
               hs_from_sctxt (map (fn class => (v, [class])) supclasss),
-              str ((upper_first name) ^ " " ^ v),
+              str ((upper_first o resolv) name ^ " " ^ v),
               str " where",
               Pretty.fbrk,
               Pretty.chunks (map mk_member membrs)
@@ -957,7 +940,7 @@
             hs_from_sctxt_type (arity, IType (tyco, map (IVarT o rpair [] o fst) arity)),
             str " where",
             Pretty.fbrk,
-            Pretty.chunks (map (fn (m, (eqs, _)) => hs_from_funeqs (resolv m, eqs)) memdefs)
+            Pretty.chunks (map (fn (m, (eqs, _)) => hs_from_funeqs (m, eqs)) memdefs)
           ] |> SOME
   in
     case List.mapPartial (fn (name, def) => hs_from_def (name, def)) defs
--- a/src/Pure/Tools/codegen_thingol.ML	Thu Feb 02 18:03:35 2006 +0100
+++ b/src/Pure/Tools/codegen_thingol.ML	Thu Feb 02 18:04:10 2006 +0100
@@ -33,6 +33,7 @@
   val unfold_let: iexpr -> (iexpr * iexpr) list * iexpr;
   val unfold_const_app: iexpr ->
     (((string * itype) * classlookup list list) * iexpr list) option;
+  val ensure_pat: iexpr -> iexpr;
   val itype_of_iexpr: iexpr -> itype;
 
   val `%% : string * itype list -> itype;
@@ -379,6 +380,13 @@
   | itype_of_iexpr (IAbs ((_, ty1), e2)) = ty1 `-> itype_of_iexpr e2
   | itype_of_iexpr (ICase ((_, [(_, e)]))) = itype_of_iexpr e;
 
+fun ensure_pat (e as IConst (_, [])) = e
+  | ensure_pat (e as IVarE _) = e
+  | ensure_pat (e as IApp (e1, e2)) =
+      (ensure_pat e1 `$ ensure_pat e2; e)
+  | ensure_pat e =
+      error ("illegal expression for pattern: " ^ (Pretty.output o pretty_iexpr) e);
+
 fun type_vnames ty = 
   let
     fun extr (IVarT (v, _)) =
@@ -1163,8 +1171,8 @@
 
 fun serialize seri_defs seri_module validate nsp_conn name_root module =
   let
-(*     val resolver = mk_deresolver module nsp_conn snd validate;  *)
-    val resolver = mk_resolv (mk_resolvtab' nsp_conn validate module);
+    val resolver = mk_deresolver module nsp_conn snd validate;
+(*     val resolver = mk_resolv (mk_resolvtab' nsp_conn validate module);  *)
     fun mk_name prfx name =
       let
         val name_qual = NameSpace.pack (prfx @ [name])
@@ -1177,7 +1185,7 @@
             (mk_name prfx name, mk_contents (prfx @ [name]) modl)
       | seri prfx ds =
           seri_defs (resolver prfx)
-            (map (fn (name, Def def) => (snd (mk_name prfx name), def)) ds)
+            (map (fn (name, Def def) => (fst (mk_name prfx name), def)) ds)
   in
     seri_module (map (resolver []) (Graph.strong_conn module |> List.concat |> rev))
       (("", name_root), (mk_contents [] module))