src/Tools/nbe.ML
changeset 55147 bce3dbc11f95
parent 55043 acefda71629b
child 55150 0940309ed8f1
--- a/src/Tools/nbe.ML	Sat Jan 25 23:50:49 2014 +0100
+++ b/src/Tools/nbe.ML	Sat Jan 25 23:50:49 2014 +0100
@@ -249,11 +249,11 @@
 
 val univs_cookie = (Univs.get, put_result, name_put);
 
-val sloppy_name = Long_Name.base_name o Long_Name.qualifier
+val sloppy_name = Code_Symbol.base_name;
 
-fun nbe_fun idx_of 0 "" = "nbe_value"
-  | nbe_fun idx_of i c = "c_" ^ string_of_int (idx_of c)
-      ^ "_" ^ sloppy_name c ^ "_" ^ string_of_int i;
+fun nbe_fun idx_of 0 (Code_Symbol.Constant "") = "nbe_value"
+  | nbe_fun idx_of i sym = "c_" ^ string_of_int (idx_of sym)
+      ^ "_" ^ sloppy_name sym ^ "_" ^ string_of_int i;
 fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
 fun nbe_bound v = "v_" ^ v;
 fun nbe_bound_optional NONE = "_"
@@ -291,24 +291,24 @@
       in (c, (num_args, (dicts, eqns))) end;
     val eqnss' = map prep_eqns eqnss;
 
-    fun assemble_constapp c dss ts = 
+    fun assemble_constapp sym dss ts = 
       let
         val ts' = (maps o map) assemble_dict dss @ ts;
-      in case AList.lookup (op =) eqnss' c
+      in case AList.lookup (op =) eqnss' sym
        of SOME (num_args, _) => if num_args <= length ts'
             then let val (ts1, ts2) = chop num_args ts'
-            in nbe_apps (nbe_apps_local idx_of 0 c ts1) ts2
-            end else nbe_apps (nbe_abss num_args (nbe_fun idx_of 0 c)) ts'
-        | NONE => if member (op =) deps c
-            then nbe_apps (nbe_fun idx_of 0 c) ts'
-            else nbe_apps_constr idx_of c ts'
+            in nbe_apps (nbe_apps_local idx_of 0 sym ts1) ts2
+            end else nbe_apps (nbe_abss num_args (nbe_fun idx_of 0 sym)) ts'
+        | NONE => if member (op =) deps sym
+            then nbe_apps (nbe_fun idx_of 0 sym) ts'
+            else nbe_apps_constr idx_of sym ts'
       end
     and assemble_classrels classrels =
-      fold_rev (fn classrel => assemble_constapp classrel [] o single) classrels
+      fold_rev (fn classrel => assemble_constapp (Code_Symbol.Class_Relation classrel) [] o single) classrels
     and assemble_dict (Dict (classrels, x)) =
           assemble_classrels classrels (assemble_plain_dict x)
     and assemble_plain_dict (Dict_Const (inst, dss)) =
-          assemble_constapp inst dss []
+          assemble_constapp (Code_Symbol.Class_Instance inst) dss []
       | assemble_plain_dict (Dict_Var (v, (n, _))) =
           nbe_dict v n
 
@@ -318,7 +318,7 @@
           let
             val (t', ts) = Code_Thingol.unfold_app t
           in of_iapp match_cont t' (fold_rev (cons o of_iterm NONE) ts []) end
-        and of_iapp match_cont (IConst { name = c, dicts = dss, ... }) ts = constapp c dss ts
+        and of_iapp match_cont (IConst { sym, dicts = dss, ... }) ts = constapp sym dss ts
           | of_iapp match_cont (IVar v) ts = nbe_apps (nbe_bound_optional v) ts
           | of_iapp match_cont ((v, _) `|=> t) ts =
               nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (of_iterm NONE t))) ts
@@ -353,12 +353,12 @@
         val (args', _) = fold_map subst_vars args samepairs;
       in (samepairs, args') end;
 
-    fun assemble_eqn c dicts default_args (i, (args, rhs)) =
+    fun assemble_eqn sym dicts default_args (i, (args, rhs)) =
       let
-        val match_cont = if c = "" then NONE
-          else SOME (nbe_apps_local idx_of (i + 1) c (dicts @ default_args));
+        val match_cont = if Code_Symbol.is_value sym then NONE
+          else SOME (nbe_apps_local idx_of (i + 1) sym (dicts @ default_args));
         val assemble_arg = assemble_iterm
-          (fn c => fn dss => fn ts => nbe_apps_constr idx_of c ((maps o map) (K "_") dss @ ts)) NONE;
+          (fn sym' => fn dss => fn ts => nbe_apps_constr idx_of sym' ((maps o map) (K "_") dss @ ts)) NONE;
         val assemble_rhs = assemble_iterm assemble_constapp match_cont;
         val (samepairs, args') = subst_nonlin_vars args;
         val s_args = map assemble_arg args';
@@ -370,17 +370,17 @@
           | SOME default_rhs =>
               [([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs),
                 ([ml_list (rev (dicts @ default_args))], default_rhs)]
-      in (nbe_fun idx_of i c, eqns) end;
+      in (nbe_fun idx_of i sym, eqns) end;
 
-    fun assemble_eqns (c, (num_args, (dicts, eqns))) =
+    fun assemble_eqns (sym, (num_args, (dicts, eqns))) =
       let
         val default_args = map nbe_default
           (Name.invent Name.context "a" (num_args - length dicts));
-        val eqns' = map_index (assemble_eqn c dicts default_args) eqns
-          @ (if c = "" then [] else [(nbe_fun idx_of (length eqns) c,
+        val eqns' = map_index (assemble_eqn sym dicts default_args) eqns
+          @ (if Code_Symbol.is_value sym then [] else [(nbe_fun idx_of (length eqns) sym,
             [([ml_list (rev (dicts @ default_args))],
-              nbe_apps_constr idx_of c (dicts @ default_args))])]);
-      in (eqns', nbe_abss num_args (nbe_fun idx_of 0 c)) end;
+              nbe_apps_constr idx_of sym (dicts @ default_args))])]);
+      in (eqns', nbe_abss num_args (nbe_fun idx_of 0 sym)) end;
 
     val (fun_vars, fun_vals) = map_split assemble_eqns eqnss';
     val deps_vars = ml_list (map (nbe_fun idx_of 0) deps);
@@ -394,9 +394,9 @@
       let
         val ctxt = Proof_Context.init_global thy;
         val (deps, deps_vals) = split_list (map_filter
-          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node nbe_program dep)))) raw_deps);
+          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Code_Symbol.Graph.get_node nbe_program dep)))) raw_deps);
         val idx_of = raw_deps
-          |> map (fn dep => (dep, snd (Graph.get_node nbe_program dep)))
+          |> map (fn dep => (dep, snd (Code_Symbol.Graph.get_node nbe_program dep)))
           |> AList.lookup (op =)
           |> (fn f => the o f);
         val s = assemble_eqnss idx_of deps eqnss;
@@ -413,45 +413,45 @@
 
 (* extract equations from statements *)
 
-fun dummy_const c dss =
-  IConst { name = c, typargs = [], dicts = dss,
+fun dummy_const sym dss =
+  IConst { sym = sym, typargs = [], dicts = dss,
     dom = [], range = ITyVar "", annotate = false };
 
-fun eqns_of_stmt (_, Code_Thingol.NoStmt _) =
+fun eqns_of_stmt (_, Code_Thingol.NoStmt) =
       []
-  | eqns_of_stmt (_, Code_Thingol.Fun (_, ((_, []), _))) =
+  | eqns_of_stmt (_, Code_Thingol.Fun ((_, []), _)) =
       []
-  | eqns_of_stmt (const, Code_Thingol.Fun (_, (((vs, _), eqns), _))) =
-      [(const, (vs, map fst eqns))]
+  | eqns_of_stmt (sym_const, Code_Thingol.Fun (((vs, _), eqns), _)) =
+      [(sym_const, (vs, map fst eqns))]
   | eqns_of_stmt (_, Code_Thingol.Datatypecons _) =
       []
   | eqns_of_stmt (_, Code_Thingol.Datatype _) =
       []
-  | eqns_of_stmt (class, Code_Thingol.Class (_, (v, (super_classes, classparams)))) =
+  | eqns_of_stmt (sym_class, Code_Thingol.Class (v, (classrels, classparams))) =
       let
-        val names = map snd super_classes @ map fst classparams;
-        val params = Name.invent Name.context "d" (length names);
-        fun mk (k, name) =
-          (name, ([(v, [])],
-            [([dummy_const class [] `$$ map (IVar o SOME) params],
+        val syms = map Code_Symbol.Class_Relation classrels @ map (Code_Symbol.Constant o fst) classparams;
+        val params = Name.invent Name.context "d" (length syms);
+        fun mk (k, sym) =
+          (sym, ([(v, [])],
+            [([dummy_const sym_class [] `$$ map (IVar o SOME) params],
               IVar (SOME (nth params k)))]));
-      in map_index mk names end
+      in map_index mk syms end
   | eqns_of_stmt (_, Code_Thingol.Classrel _) =
       []
   | eqns_of_stmt (_, Code_Thingol.Classparam _) =
       []
-  | eqns_of_stmt (inst, Code_Thingol.Classinst { class, vs, superinsts, inst_params, ... }) =
-      [(inst, (vs, [([], dummy_const class [] `$$
-        map (fn (_, (_, (inst, dss))) => dummy_const inst dss) superinsts
+  | eqns_of_stmt (sym_inst, Code_Thingol.Classinst { class, tyco, vs, superinsts, inst_params, ... }) =
+      [(sym_inst, (vs, [([], dummy_const (Code_Symbol.Type_Class class) [] `$$
+        map (fn (class, dss) => dummy_const (Code_Symbol.Class_Instance (tyco, class)) dss) superinsts
         @ map (IConst o fst o snd o fst) inst_params)]))];
 
 
 (* compile whole programs *)
 
 fun ensure_const_idx name (nbe_program, (maxidx, idx_tab)) =
-  if can (Graph.get_node nbe_program) name
+  if can (Code_Symbol.Graph.get_node nbe_program) name
   then (nbe_program, (maxidx, idx_tab))
-  else (Graph.new_node (name, (NONE, maxidx)) nbe_program,
+  else (Code_Symbol.Graph.new_node (name, (NONE, maxidx)) nbe_program,
     (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
 
 fun compile_stmts thy stmts_deps =
@@ -468,20 +468,20 @@
       |> rpair nbe_program;
   in
     fold ensure_const_idx refl_deps
-    #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
+    #> apfst (fold (fn (name, deps) => fold (curry Code_Symbol.Graph.add_edge name) deps) names_deps
       #> compile
-      #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
+      #-> fold (fn (name, univ) => (Code_Symbol.Graph.map_node name o apfst) (K (SOME univ))))
   end;
 
 fun compile_program thy program =
   let
-    fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) nbe_program) names
+    fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
       then (nbe_program, (maxidx, idx_tab))
       else (nbe_program, (maxidx, idx_tab))
-        |> compile_stmts thy (map (fn name => ((name, Graph.get_node program name),
-          Graph.immediate_succs program name)) names);
+        |> compile_stmts thy (map (fn name => ((name, Code_Symbol.Graph.get_node program name),
+          Code_Symbol.Graph.immediate_succs program name)) names);
   in
-    fold_rev add_stmts (Graph.strong_conn program)
+    fold_rev add_stmts (Code_Symbol.Graph.strong_conn program)
   end;
 
 
@@ -493,7 +493,7 @@
   let 
     val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
   in
-    ("", (vs, [([], t)]))
+    (Code_Symbol.value, (vs, [([], t)]))
     |> singleton (compile_eqnss thy nbe_program deps)
     |> snd
     |> (fn t => apps t (rev dict_frees))
@@ -502,43 +502,35 @@
 
 (* reconstruction *)
 
-fun typ_of_itype program vs (ityco `%% itys) =
-      let
-        val Code_Thingol.Datatype (tyco, _) = Graph.get_node program ityco;
-      in Type (tyco, map (typ_of_itype program vs) itys) end
-  | typ_of_itype program vs (ITyVar v) =
-      let
-        val sort = (the o AList.lookup (op =) vs) v;
-      in TFree ("'" ^ v, sort) end;
+fun typ_of_itype vs (tyco `%% itys) =
+      Type (tyco, map (typ_of_itype vs) itys)
+  | typ_of_itype vs (ITyVar v) =
+      TFree ("'" ^ v, (the o AList.lookup (op =) vs) v);
 
-fun term_of_univ thy program idx_tab t =
+fun term_of_univ thy idx_tab t =
   let
     fun take_until f [] = []
-      | take_until f (x::xs) = if f x then [] else x :: take_until f xs;
-    fun is_dict (Const (idx, _)) = (case (Graph.get_node program o the o Inttab.lookup idx_tab) idx
-         of Code_Thingol.Class _ => true
-          | Code_Thingol.Classrel _ => true
-          | Code_Thingol.Classinst _ => true
-          | _ => false)
+      | take_until f (x :: xs) = if f x then [] else x :: take_until f xs;
+    fun is_dict (Const (idx, _)) =
+          (case Inttab.lookup idx_tab idx of
+            SOME (Code_Symbol.Constant _) => false
+          | _ => true)
       | is_dict (DFree _) = true
       | is_dict _ = false;
-    fun const_of_idx idx = (case (Graph.get_node program o the o Inttab.lookup idx_tab) idx
-     of Code_Thingol.NoStmt c => c
-      | Code_Thingol.Fun (c, _) => c
-      | Code_Thingol.Datatypecons (c, _) => c
-      | Code_Thingol.Classparam (c, _) => c);
+    fun const_of_idx idx =
+      case Inttab.lookup idx_tab idx of SOME (Code_Symbol.Constant const) => const;
     fun of_apps bounds (t, ts) =
       fold_map (of_univ bounds) ts
       #>> (fn ts' => list_comb (t, rev ts'))
     and of_univ bounds (Const (idx, ts)) typidx =
           let
             val ts' = take_until is_dict ts;
-            val c = const_of_idx idx;
+            val const = const_of_idx idx;
             val T = map_type_tvar (fn ((v, i), _) =>
               Type_Infer.param typidx (v ^ string_of_int i, []))
-                (Sign.the_const_type thy c);
+                (Sign.the_const_type thy const);
             val typidx' = typidx + 1;
-          in of_apps bounds (Term.Const (c, T), ts') typidx' end
+          in of_apps bounds (Term.Const (const, T), ts') typidx' end
       | of_univ bounds (BVar (n, ts)) typidx =
           of_apps bounds (Bound (bounds - n - 1), ts) typidx
       | of_univ bounds (t as Abs _) typidx =
@@ -550,11 +542,11 @@
 
 (* evaluation with type reconstruction *)
 
-fun eval_term thy program (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
+fun eval_term thy (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
   let
     val ctxt = Syntax.init_pretty_global thy;
     val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
-    val ty' = typ_of_itype program vs0 ty;
+    val ty' = typ_of_itype vs0 ty;
     fun type_infer t =
       Syntax.check_term (Config.put Type_Infer_Context.const_sorts false ctxt)
         (Type.constraint ty' t);
@@ -563,7 +555,7 @@
       else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
   in
     compile_term thy nbe_program deps (vs, t)
-    |> term_of_univ thy program idx_tab
+    |> term_of_univ thy idx_tab
     |> traced (fn t => "Normalized:\n" ^ string_of_term t)
     |> type_infer
     |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
@@ -576,8 +568,8 @@
 
 structure Nbe_Functions = Code_Data
 (
-  type T = (Univ option * int) Graph.T * (int * string Inttab.table);
-  val empty = (Graph.empty, (0, Inttab.empty));
+  type T = (Univ option * int) Code_Symbol.Graph.T * (int * Code_Symbol.symbol Inttab.table);
+  val empty = (Code_Symbol.Graph.empty, (0, Inttab.empty));
 );
 
 fun compile ignore_cache thy program =
@@ -599,26 +591,23 @@
 
 val (_, raw_oracle) = Context.>>> (Context.map_theory_result
   (Thm.add_oracle (@{binding normalization_by_evaluation},
-    fn (thy, program, nbe_program_idx_tab, vsp_ty_t, deps, ct) =>
-      mk_equals thy ct (eval_term thy program nbe_program_idx_tab vsp_ty_t deps))));
+    fn (thy, nbe_program_idx_tab, vsp_ty_t, deps, ct) =>
+      mk_equals thy ct (eval_term thy nbe_program_idx_tab vsp_ty_t deps))));
 
-fun oracle thy program nbe_program_idx_tab vsp_ty_t deps ct =
-  raw_oracle (thy, program, nbe_program_idx_tab, vsp_ty_t, deps, ct);
+fun oracle thy nbe_program_idx_tab vsp_ty_t deps ct =
+  raw_oracle (thy, nbe_program_idx_tab, vsp_ty_t, deps, ct);
 
-fun dynamic_conv thy = lift_triv_classes_conv thy (Code_Thingol.dynamic_conv thy
-    (K (fn program => oracle thy program (compile false thy program))));
+fun dynamic_conv thy = lift_triv_classes_conv thy
+  (Code_Thingol.dynamic_conv thy (oracle thy o compile false thy));
 
 fun dynamic_value thy = lift_triv_classes_rew thy
-  (Code_Thingol.dynamic_value thy I
-    (K (fn program => eval_term thy program (compile false thy program))));
+  (Code_Thingol.dynamic_value thy I (eval_term thy o compile false thy));
 
-fun static_conv thy consts =
-  lift_triv_classes_conv thy (Code_Thingol.static_conv thy consts
-    (K (fn program => fn _ => oracle thy program (compile true thy program))));
+fun static_conv thy consts = lift_triv_classes_conv thy
+  (Code_Thingol.static_conv thy consts (K o oracle thy o compile true thy));
 
 fun static_value thy consts = lift_triv_classes_rew thy
-  (Code_Thingol.static_value thy I consts
-    (K (fn program => fn _ => eval_term thy program (compile true thy program))));
+  (Code_Thingol.static_value thy I consts (K o eval_term thy o compile true thy));
 
 
 (** setup **)