src/Tools/nbe.ML
changeset 55147 bce3dbc11f95
parent 55043 acefda71629b
child 55150 0940309ed8f1
     1.1 --- a/src/Tools/nbe.ML	Sat Jan 25 23:50:49 2014 +0100
     1.2 +++ b/src/Tools/nbe.ML	Sat Jan 25 23:50:49 2014 +0100
     1.3 @@ -249,11 +249,11 @@
     1.4  
     1.5  val univs_cookie = (Univs.get, put_result, name_put);
     1.6  
     1.7 -val sloppy_name = Long_Name.base_name o Long_Name.qualifier
     1.8 +val sloppy_name = Code_Symbol.base_name;
     1.9  
    1.10 -fun nbe_fun idx_of 0 "" = "nbe_value"
    1.11 -  | nbe_fun idx_of i c = "c_" ^ string_of_int (idx_of c)
    1.12 -      ^ "_" ^ sloppy_name c ^ "_" ^ string_of_int i;
    1.13 +fun nbe_fun idx_of 0 (Code_Symbol.Constant "") = "nbe_value"
    1.14 +  | nbe_fun idx_of i sym = "c_" ^ string_of_int (idx_of sym)
    1.15 +      ^ "_" ^ sloppy_name sym ^ "_" ^ string_of_int i;
    1.16  fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
    1.17  fun nbe_bound v = "v_" ^ v;
    1.18  fun nbe_bound_optional NONE = "_"
    1.19 @@ -291,24 +291,24 @@
    1.20        in (c, (num_args, (dicts, eqns))) end;
    1.21      val eqnss' = map prep_eqns eqnss;
    1.22  
    1.23 -    fun assemble_constapp c dss ts = 
    1.24 +    fun assemble_constapp sym dss ts = 
    1.25        let
    1.26          val ts' = (maps o map) assemble_dict dss @ ts;
    1.27 -      in case AList.lookup (op =) eqnss' c
    1.28 +      in case AList.lookup (op =) eqnss' sym
    1.29         of SOME (num_args, _) => if num_args <= length ts'
    1.30              then let val (ts1, ts2) = chop num_args ts'
    1.31 -            in nbe_apps (nbe_apps_local idx_of 0 c ts1) ts2
    1.32 -            end else nbe_apps (nbe_abss num_args (nbe_fun idx_of 0 c)) ts'
    1.33 -        | NONE => if member (op =) deps c
    1.34 -            then nbe_apps (nbe_fun idx_of 0 c) ts'
    1.35 -            else nbe_apps_constr idx_of c ts'
    1.36 +            in nbe_apps (nbe_apps_local idx_of 0 sym ts1) ts2
    1.37 +            end else nbe_apps (nbe_abss num_args (nbe_fun idx_of 0 sym)) ts'
    1.38 +        | NONE => if member (op =) deps sym
    1.39 +            then nbe_apps (nbe_fun idx_of 0 sym) ts'
    1.40 +            else nbe_apps_constr idx_of sym ts'
    1.41        end
    1.42      and assemble_classrels classrels =
    1.43 -      fold_rev (fn classrel => assemble_constapp classrel [] o single) classrels
    1.44 +      fold_rev (fn classrel => assemble_constapp (Code_Symbol.Class_Relation classrel) [] o single) classrels
    1.45      and assemble_dict (Dict (classrels, x)) =
    1.46            assemble_classrels classrels (assemble_plain_dict x)
    1.47      and assemble_plain_dict (Dict_Const (inst, dss)) =
    1.48 -          assemble_constapp inst dss []
    1.49 +          assemble_constapp (Code_Symbol.Class_Instance inst) dss []
    1.50        | assemble_plain_dict (Dict_Var (v, (n, _))) =
    1.51            nbe_dict v n
    1.52  
    1.53 @@ -318,7 +318,7 @@
    1.54            let
    1.55              val (t', ts) = Code_Thingol.unfold_app t
    1.56            in of_iapp match_cont t' (fold_rev (cons o of_iterm NONE) ts []) end
    1.57 -        and of_iapp match_cont (IConst { name = c, dicts = dss, ... }) ts = constapp c dss ts
    1.58 +        and of_iapp match_cont (IConst { sym, dicts = dss, ... }) ts = constapp sym dss ts
    1.59            | of_iapp match_cont (IVar v) ts = nbe_apps (nbe_bound_optional v) ts
    1.60            | of_iapp match_cont ((v, _) `|=> t) ts =
    1.61                nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (of_iterm NONE t))) ts
    1.62 @@ -353,12 +353,12 @@
    1.63          val (args', _) = fold_map subst_vars args samepairs;
    1.64        in (samepairs, args') end;
    1.65  
    1.66 -    fun assemble_eqn c dicts default_args (i, (args, rhs)) =
    1.67 +    fun assemble_eqn sym dicts default_args (i, (args, rhs)) =
    1.68        let
    1.69 -        val match_cont = if c = "" then NONE
    1.70 -          else SOME (nbe_apps_local idx_of (i + 1) c (dicts @ default_args));
    1.71 +        val match_cont = if Code_Symbol.is_value sym then NONE
    1.72 +          else SOME (nbe_apps_local idx_of (i + 1) sym (dicts @ default_args));
    1.73          val assemble_arg = assemble_iterm
    1.74 -          (fn c => fn dss => fn ts => nbe_apps_constr idx_of c ((maps o map) (K "_") dss @ ts)) NONE;
    1.75 +          (fn sym' => fn dss => fn ts => nbe_apps_constr idx_of sym' ((maps o map) (K "_") dss @ ts)) NONE;
    1.76          val assemble_rhs = assemble_iterm assemble_constapp match_cont;
    1.77          val (samepairs, args') = subst_nonlin_vars args;
    1.78          val s_args = map assemble_arg args';
    1.79 @@ -370,17 +370,17 @@
    1.80            | SOME default_rhs =>
    1.81                [([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs),
    1.82                  ([ml_list (rev (dicts @ default_args))], default_rhs)]
    1.83 -      in (nbe_fun idx_of i c, eqns) end;
    1.84 +      in (nbe_fun idx_of i sym, eqns) end;
    1.85  
    1.86 -    fun assemble_eqns (c, (num_args, (dicts, eqns))) =
    1.87 +    fun assemble_eqns (sym, (num_args, (dicts, eqns))) =
    1.88        let
    1.89          val default_args = map nbe_default
    1.90            (Name.invent Name.context "a" (num_args - length dicts));
    1.91 -        val eqns' = map_index (assemble_eqn c dicts default_args) eqns
    1.92 -          @ (if c = "" then [] else [(nbe_fun idx_of (length eqns) c,
    1.93 +        val eqns' = map_index (assemble_eqn sym dicts default_args) eqns
    1.94 +          @ (if Code_Symbol.is_value sym then [] else [(nbe_fun idx_of (length eqns) sym,
    1.95              [([ml_list (rev (dicts @ default_args))],
    1.96 -              nbe_apps_constr idx_of c (dicts @ default_args))])]);
    1.97 -      in (eqns', nbe_abss num_args (nbe_fun idx_of 0 c)) end;
    1.98 +              nbe_apps_constr idx_of sym (dicts @ default_args))])]);
    1.99 +      in (eqns', nbe_abss num_args (nbe_fun idx_of 0 sym)) end;
   1.100  
   1.101      val (fun_vars, fun_vals) = map_split assemble_eqns eqnss';
   1.102      val deps_vars = ml_list (map (nbe_fun idx_of 0) deps);
   1.103 @@ -394,9 +394,9 @@
   1.104        let
   1.105          val ctxt = Proof_Context.init_global thy;
   1.106          val (deps, deps_vals) = split_list (map_filter
   1.107 -          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node nbe_program dep)))) raw_deps);
   1.108 +          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Code_Symbol.Graph.get_node nbe_program dep)))) raw_deps);
   1.109          val idx_of = raw_deps
   1.110 -          |> map (fn dep => (dep, snd (Graph.get_node nbe_program dep)))
   1.111 +          |> map (fn dep => (dep, snd (Code_Symbol.Graph.get_node nbe_program dep)))
   1.112            |> AList.lookup (op =)
   1.113            |> (fn f => the o f);
   1.114          val s = assemble_eqnss idx_of deps eqnss;
   1.115 @@ -413,45 +413,45 @@
   1.116  
   1.117  (* extract equations from statements *)
   1.118  
   1.119 -fun dummy_const c dss =
   1.120 -  IConst { name = c, typargs = [], dicts = dss,
   1.121 +fun dummy_const sym dss =
   1.122 +  IConst { sym = sym, typargs = [], dicts = dss,
   1.123      dom = [], range = ITyVar "", annotate = false };
   1.124  
   1.125 -fun eqns_of_stmt (_, Code_Thingol.NoStmt _) =
   1.126 +fun eqns_of_stmt (_, Code_Thingol.NoStmt) =
   1.127        []
   1.128 -  | eqns_of_stmt (_, Code_Thingol.Fun (_, ((_, []), _))) =
   1.129 +  | eqns_of_stmt (_, Code_Thingol.Fun ((_, []), _)) =
   1.130        []
   1.131 -  | eqns_of_stmt (const, Code_Thingol.Fun (_, (((vs, _), eqns), _))) =
   1.132 -      [(const, (vs, map fst eqns))]
   1.133 +  | eqns_of_stmt (sym_const, Code_Thingol.Fun (((vs, _), eqns), _)) =
   1.134 +      [(sym_const, (vs, map fst eqns))]
   1.135    | eqns_of_stmt (_, Code_Thingol.Datatypecons _) =
   1.136        []
   1.137    | eqns_of_stmt (_, Code_Thingol.Datatype _) =
   1.138        []
   1.139 -  | eqns_of_stmt (class, Code_Thingol.Class (_, (v, (super_classes, classparams)))) =
   1.140 +  | eqns_of_stmt (sym_class, Code_Thingol.Class (v, (classrels, classparams))) =
   1.141        let
   1.142 -        val names = map snd super_classes @ map fst classparams;
   1.143 -        val params = Name.invent Name.context "d" (length names);
   1.144 -        fun mk (k, name) =
   1.145 -          (name, ([(v, [])],
   1.146 -            [([dummy_const class [] `$$ map (IVar o SOME) params],
   1.147 +        val syms = map Code_Symbol.Class_Relation classrels @ map (Code_Symbol.Constant o fst) classparams;
   1.148 +        val params = Name.invent Name.context "d" (length syms);
   1.149 +        fun mk (k, sym) =
   1.150 +          (sym, ([(v, [])],
   1.151 +            [([dummy_const sym_class [] `$$ map (IVar o SOME) params],
   1.152                IVar (SOME (nth params k)))]));
   1.153 -      in map_index mk names end
   1.154 +      in map_index mk syms end
   1.155    | eqns_of_stmt (_, Code_Thingol.Classrel _) =
   1.156        []
   1.157    | eqns_of_stmt (_, Code_Thingol.Classparam _) =
   1.158        []
   1.159 -  | eqns_of_stmt (inst, Code_Thingol.Classinst { class, vs, superinsts, inst_params, ... }) =
   1.160 -      [(inst, (vs, [([], dummy_const class [] `$$
   1.161 -        map (fn (_, (_, (inst, dss))) => dummy_const inst dss) superinsts
   1.162 +  | eqns_of_stmt (sym_inst, Code_Thingol.Classinst { class, tyco, vs, superinsts, inst_params, ... }) =
   1.163 +      [(sym_inst, (vs, [([], dummy_const (Code_Symbol.Type_Class class) [] `$$
   1.164 +        map (fn (class, dss) => dummy_const (Code_Symbol.Class_Instance (tyco, class)) dss) superinsts
   1.165          @ map (IConst o fst o snd o fst) inst_params)]))];
   1.166  
   1.167  
   1.168  (* compile whole programs *)
   1.169  
   1.170  fun ensure_const_idx name (nbe_program, (maxidx, idx_tab)) =
   1.171 -  if can (Graph.get_node nbe_program) name
   1.172 +  if can (Code_Symbol.Graph.get_node nbe_program) name
   1.173    then (nbe_program, (maxidx, idx_tab))
   1.174 -  else (Graph.new_node (name, (NONE, maxidx)) nbe_program,
   1.175 +  else (Code_Symbol.Graph.new_node (name, (NONE, maxidx)) nbe_program,
   1.176      (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
   1.177  
   1.178  fun compile_stmts thy stmts_deps =
   1.179 @@ -468,20 +468,20 @@
   1.180        |> rpair nbe_program;
   1.181    in
   1.182      fold ensure_const_idx refl_deps
   1.183 -    #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
   1.184 +    #> apfst (fold (fn (name, deps) => fold (curry Code_Symbol.Graph.add_edge name) deps) names_deps
   1.185        #> compile
   1.186 -      #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
   1.187 +      #-> fold (fn (name, univ) => (Code_Symbol.Graph.map_node name o apfst) (K (SOME univ))))
   1.188    end;
   1.189  
   1.190  fun compile_program thy program =
   1.191    let
   1.192 -    fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) nbe_program) names
   1.193 +    fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
   1.194        then (nbe_program, (maxidx, idx_tab))
   1.195        else (nbe_program, (maxidx, idx_tab))
   1.196 -        |> compile_stmts thy (map (fn name => ((name, Graph.get_node program name),
   1.197 -          Graph.immediate_succs program name)) names);
   1.198 +        |> compile_stmts thy (map (fn name => ((name, Code_Symbol.Graph.get_node program name),
   1.199 +          Code_Symbol.Graph.immediate_succs program name)) names);
   1.200    in
   1.201 -    fold_rev add_stmts (Graph.strong_conn program)
   1.202 +    fold_rev add_stmts (Code_Symbol.Graph.strong_conn program)
   1.203    end;
   1.204  
   1.205  
   1.206 @@ -493,7 +493,7 @@
   1.207    let 
   1.208      val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
   1.209    in
   1.210 -    ("", (vs, [([], t)]))
   1.211 +    (Code_Symbol.value, (vs, [([], t)]))
   1.212      |> singleton (compile_eqnss thy nbe_program deps)
   1.213      |> snd
   1.214      |> (fn t => apps t (rev dict_frees))
   1.215 @@ -502,43 +502,35 @@
   1.216  
   1.217  (* reconstruction *)
   1.218  
   1.219 -fun typ_of_itype program vs (ityco `%% itys) =
   1.220 -      let
   1.221 -        val Code_Thingol.Datatype (tyco, _) = Graph.get_node program ityco;
   1.222 -      in Type (tyco, map (typ_of_itype program vs) itys) end
   1.223 -  | typ_of_itype program vs (ITyVar v) =
   1.224 -      let
   1.225 -        val sort = (the o AList.lookup (op =) vs) v;
   1.226 -      in TFree ("'" ^ v, sort) end;
   1.227 +fun typ_of_itype vs (tyco `%% itys) =
   1.228 +      Type (tyco, map (typ_of_itype vs) itys)
   1.229 +  | typ_of_itype vs (ITyVar v) =
   1.230 +      TFree ("'" ^ v, (the o AList.lookup (op =) vs) v);
   1.231  
   1.232 -fun term_of_univ thy program idx_tab t =
   1.233 +fun term_of_univ thy idx_tab t =
   1.234    let
   1.235      fun take_until f [] = []
   1.236 -      | take_until f (x::xs) = if f x then [] else x :: take_until f xs;
   1.237 -    fun is_dict (Const (idx, _)) = (case (Graph.get_node program o the o Inttab.lookup idx_tab) idx
   1.238 -         of Code_Thingol.Class _ => true
   1.239 -          | Code_Thingol.Classrel _ => true
   1.240 -          | Code_Thingol.Classinst _ => true
   1.241 -          | _ => false)
   1.242 +      | take_until f (x :: xs) = if f x then [] else x :: take_until f xs;
   1.243 +    fun is_dict (Const (idx, _)) =
   1.244 +          (case Inttab.lookup idx_tab idx of
   1.245 +            SOME (Code_Symbol.Constant _) => false
   1.246 +          | _ => true)
   1.247        | is_dict (DFree _) = true
   1.248        | is_dict _ = false;
   1.249 -    fun const_of_idx idx = (case (Graph.get_node program o the o Inttab.lookup idx_tab) idx
   1.250 -     of Code_Thingol.NoStmt c => c
   1.251 -      | Code_Thingol.Fun (c, _) => c
   1.252 -      | Code_Thingol.Datatypecons (c, _) => c
   1.253 -      | Code_Thingol.Classparam (c, _) => c);
   1.254 +    fun const_of_idx idx =
   1.255 +      case Inttab.lookup idx_tab idx of SOME (Code_Symbol.Constant const) => const;
   1.256      fun of_apps bounds (t, ts) =
   1.257        fold_map (of_univ bounds) ts
   1.258        #>> (fn ts' => list_comb (t, rev ts'))
   1.259      and of_univ bounds (Const (idx, ts)) typidx =
   1.260            let
   1.261              val ts' = take_until is_dict ts;
   1.262 -            val c = const_of_idx idx;
   1.263 +            val const = const_of_idx idx;
   1.264              val T = map_type_tvar (fn ((v, i), _) =>
   1.265                Type_Infer.param typidx (v ^ string_of_int i, []))
   1.266 -                (Sign.the_const_type thy c);
   1.267 +                (Sign.the_const_type thy const);
   1.268              val typidx' = typidx + 1;
   1.269 -          in of_apps bounds (Term.Const (c, T), ts') typidx' end
   1.270 +          in of_apps bounds (Term.Const (const, T), ts') typidx' end
   1.271        | of_univ bounds (BVar (n, ts)) typidx =
   1.272            of_apps bounds (Bound (bounds - n - 1), ts) typidx
   1.273        | of_univ bounds (t as Abs _) typidx =
   1.274 @@ -550,11 +542,11 @@
   1.275  
   1.276  (* evaluation with type reconstruction *)
   1.277  
   1.278 -fun eval_term thy program (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
   1.279 +fun eval_term thy (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
   1.280    let
   1.281      val ctxt = Syntax.init_pretty_global thy;
   1.282      val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
   1.283 -    val ty' = typ_of_itype program vs0 ty;
   1.284 +    val ty' = typ_of_itype vs0 ty;
   1.285      fun type_infer t =
   1.286        Syntax.check_term (Config.put Type_Infer_Context.const_sorts false ctxt)
   1.287          (Type.constraint ty' t);
   1.288 @@ -563,7 +555,7 @@
   1.289        else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
   1.290    in
   1.291      compile_term thy nbe_program deps (vs, t)
   1.292 -    |> term_of_univ thy program idx_tab
   1.293 +    |> term_of_univ thy idx_tab
   1.294      |> traced (fn t => "Normalized:\n" ^ string_of_term t)
   1.295      |> type_infer
   1.296      |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
   1.297 @@ -576,8 +568,8 @@
   1.298  
   1.299  structure Nbe_Functions = Code_Data
   1.300  (
   1.301 -  type T = (Univ option * int) Graph.T * (int * string Inttab.table);
   1.302 -  val empty = (Graph.empty, (0, Inttab.empty));
   1.303 +  type T = (Univ option * int) Code_Symbol.Graph.T * (int * Code_Symbol.symbol Inttab.table);
   1.304 +  val empty = (Code_Symbol.Graph.empty, (0, Inttab.empty));
   1.305  );
   1.306  
   1.307  fun compile ignore_cache thy program =
   1.308 @@ -599,26 +591,23 @@
   1.309  
   1.310  val (_, raw_oracle) = Context.>>> (Context.map_theory_result
   1.311    (Thm.add_oracle (@{binding normalization_by_evaluation},
   1.312 -    fn (thy, program, nbe_program_idx_tab, vsp_ty_t, deps, ct) =>
   1.313 -      mk_equals thy ct (eval_term thy program nbe_program_idx_tab vsp_ty_t deps))));
   1.314 +    fn (thy, nbe_program_idx_tab, vsp_ty_t, deps, ct) =>
   1.315 +      mk_equals thy ct (eval_term thy nbe_program_idx_tab vsp_ty_t deps))));
   1.316  
   1.317 -fun oracle thy program nbe_program_idx_tab vsp_ty_t deps ct =
   1.318 -  raw_oracle (thy, program, nbe_program_idx_tab, vsp_ty_t, deps, ct);
   1.319 +fun oracle thy nbe_program_idx_tab vsp_ty_t deps ct =
   1.320 +  raw_oracle (thy, nbe_program_idx_tab, vsp_ty_t, deps, ct);
   1.321  
   1.322 -fun dynamic_conv thy = lift_triv_classes_conv thy (Code_Thingol.dynamic_conv thy
   1.323 -    (K (fn program => oracle thy program (compile false thy program))));
   1.324 +fun dynamic_conv thy = lift_triv_classes_conv thy
   1.325 +  (Code_Thingol.dynamic_conv thy (oracle thy o compile false thy));
   1.326  
   1.327  fun dynamic_value thy = lift_triv_classes_rew thy
   1.328 -  (Code_Thingol.dynamic_value thy I
   1.329 -    (K (fn program => eval_term thy program (compile false thy program))));
   1.330 +  (Code_Thingol.dynamic_value thy I (eval_term thy o compile false thy));
   1.331  
   1.332 -fun static_conv thy consts =
   1.333 -  lift_triv_classes_conv thy (Code_Thingol.static_conv thy consts
   1.334 -    (K (fn program => fn _ => oracle thy program (compile true thy program))));
   1.335 +fun static_conv thy consts = lift_triv_classes_conv thy
   1.336 +  (Code_Thingol.static_conv thy consts (K o oracle thy o compile true thy));
   1.337  
   1.338  fun static_value thy consts = lift_triv_classes_rew thy
   1.339 -  (Code_Thingol.static_value thy I consts
   1.340 -    (K (fn program => fn _ => eval_term thy program (compile true thy program))));
   1.341 +  (Code_Thingol.static_value thy I consts (K o eval_term thy o compile true thy));
   1.342  
   1.343  
   1.344  (** setup **)