src/Pure/Tools/codegen_funcgr.ML
changeset 20600 6d75e02ed285
child 20705 da71d46b8b2f
equal deleted inserted replaced
20599:65bd267ae23f 20600:6d75e02ed285
       
     1 (*  Title:      Pure/Tools/codegen_funcgr.ML
       
     2     ID:         $Id$
       
     3     Author:     Florian Haftmann, TU Muenchen
       
     4 
       
     5 Retrieving and structuring code function theorems.
       
     6 *)
       
     7 
       
     8 signature CODEGEN_FUNCGR =
       
     9 sig
       
    10   type T;
       
    11   val mk_funcgr: theory -> CodegenConsts.const list -> (string * typ) list -> T
       
    12   val get_funcs: T -> CodegenConsts.const -> thm list
       
    13   val get_func_typs: T -> (CodegenConsts.const * typ) list
       
    14   val preprocess: theory -> thm list -> thm list
       
    15   val print_codethms: theory -> CodegenConsts.const list -> unit
       
    16 end;
       
    17 
       
    18 structure CodegenFuncgr: CODEGEN_FUNCGR =
       
    19 struct
       
    20 
       
    21 (** code data **)
       
    22 
       
    23 structure Consttab = CodegenConsts.Consttab;
       
    24 structure Constgraph = GraphFun (
       
    25   type key = CodegenConsts.const;
       
    26   val ord = CodegenConsts.const_ord;
       
    27 );
       
    28 
       
    29 type T = (typ * thm list) Constgraph.T;
       
    30 
       
    31 structure Funcgr = CodeDataFun
       
    32 (struct
       
    33   val name = "Pure/codegen_funcgr";
       
    34   type T = T;
       
    35   val empty = Constgraph.empty;
       
    36   fun merge _ _ = Constgraph.empty;
       
    37   fun purge _ _ = Constgraph.empty;
       
    38 end);
       
    39 
       
    40 val _ = Context.add_setup Funcgr.init;
       
    41 
       
    42 
       
    43 (** theorem purification **)
       
    44 
       
    45 fun abs_norm thy thm =
       
    46   let
       
    47     fun expvars t =
       
    48       let
       
    49         val lhs = (fst o Logic.dest_equals) t;
       
    50         val tys = (fst o strip_type o fastype_of) lhs;
       
    51         val used = fold_aterms (fn Var ((v, _), _) => insert (op =) v | _ => I) lhs [];
       
    52         val vs = Name.invent_list used "x" (length tys);
       
    53       in
       
    54         map2 (fn v => fn ty => Var ((v, 0), ty)) vs tys
       
    55       end;
       
    56     fun expand ct thm =
       
    57       Thm.combination thm (Thm.reflexive ct);
       
    58     fun beta_norm thm =
       
    59       thm
       
    60       |> prop_of
       
    61       |> Logic.dest_equals
       
    62       |> fst
       
    63       |> cterm_of thy
       
    64       |> Thm.beta_conversion true
       
    65       |> Thm.symmetric
       
    66       |> (fn thm' => Thm.transitive thm' thm);
       
    67   in
       
    68     thm
       
    69     |> fold (expand o cterm_of thy) ((expvars o prop_of) thm)
       
    70     |> beta_norm
       
    71   end;
       
    72 
       
    73 fun canonical_tvars thy thm =
       
    74   let
       
    75     fun mk_inst (v_i as (v, i), (v', sort)) (s as (maxidx, set, acc)) =
       
    76       if v = v' orelse member (op =) set v then s
       
    77         else let
       
    78           val ty = TVar (v_i, sort)
       
    79         in
       
    80           (maxidx + 1, v :: set,
       
    81             (ctyp_of thy ty, ctyp_of thy (TVar ((v', maxidx), sort))) :: acc)
       
    82         end;
       
    83     fun tvars_of thm = (fold_types o fold_atyps)
       
    84       (fn TVar (v_i as (v, i), sort) => cons (v_i, (CodegenNames.purify_var v, sort))
       
    85         | _ => I) (prop_of thm) [];
       
    86     val maxidx = Thm.maxidx_of thm + 1;
       
    87     val (_, _, inst) = fold mk_inst (tvars_of thm) (maxidx + 1, [], []);
       
    88   in Thm.instantiate (inst, []) thm end;
       
    89 
       
    90 fun canonical_vars thy thm =
       
    91   let
       
    92     fun mk_inst (v_i as (v, i), (v', ty)) (s as (maxidx, set, acc)) =
       
    93       if v = v' orelse member (op =) set v then s
       
    94         else let
       
    95           val t = if i = ~1 then Free (v, ty) else Var (v_i, ty)
       
    96         in
       
    97           (maxidx + 1,  v :: set,
       
    98             (cterm_of thy t, cterm_of thy (Var ((v', maxidx), ty))) :: acc)
       
    99         end;
       
   100     fun vars_of thm = fold_aterms
       
   101       (fn Var (v_i as (v, i), ty) => cons (v_i, (CodegenNames.purify_var v, ty))
       
   102         | _ => I) (prop_of thm) [];
       
   103     val maxidx = Thm.maxidx_of thm + 1;
       
   104     val (_, _, inst) = fold mk_inst (vars_of thm) (maxidx + 1, [], []);
       
   105   in Thm.instantiate ([], inst) thm end;
       
   106 
       
   107 fun preprocess thy thms =
       
   108   let
       
   109     fun burrow_thms f [] = []
       
   110       | burrow_thms f thms =
       
   111           thms
       
   112           |> Conjunction.intr_list
       
   113           |> f
       
   114           |> Conjunction.elim_list;
       
   115     fun unvarify thms =
       
   116       #2 (#1 (Variable.import true thms (ProofContext.init thy)));
       
   117   in
       
   118     thms
       
   119     |> CodegenData.preprocess thy
       
   120     |> map (abs_norm thy)
       
   121     |> burrow_thms (
       
   122         canonical_tvars thy
       
   123         #> canonical_vars thy
       
   124         #> Drule.zero_var_indexes
       
   125        )
       
   126   end;
       
   127 
       
   128 fun check_thms c thms =
       
   129   let
       
   130     fun check_head_lhs thm (lhs, rhs) =
       
   131       case strip_comb lhs
       
   132        of (Const (c', _), _) => if c' = c then ()
       
   133            else error ("Illegal function equation for " ^ quote c
       
   134              ^ ", actually defining " ^ quote c' ^ ": " ^ Display.string_of_thm thm)
       
   135         | _ => error ("Illegal function equation: " ^ Display.string_of_thm thm);
       
   136     fun check_vars_lhs thm (lhs, rhs) =
       
   137       if has_duplicates (op =)
       
   138           (fold_aterms (fn Free (v, _) => cons v | _ => I) lhs [])
       
   139       then error ("Repeated variables on left hand side of function equation:"
       
   140         ^ Display.string_of_thm thm)
       
   141       else ();
       
   142     fun check_vars_rhs thm (lhs, rhs) =
       
   143       if null (subtract (op =)
       
   144         (fold_aterms (fn Free (v, _) => cons v | _ => I) lhs [])
       
   145         (fold_aterms (fn Free (v, _) => cons v | _ => I) rhs []))
       
   146       then ()
       
   147       else error ("Free variables on right hand side of function equation:"
       
   148         ^ Display.string_of_thm thm)
       
   149     val tts = map (Logic.dest_equals o Logic.unvarify o Thm.prop_of) thms;
       
   150   in
       
   151     (map2 check_head_lhs thms tts; map2 check_vars_lhs thms tts;
       
   152       map2 check_vars_rhs thms tts; thms)
       
   153   end;
       
   154 
       
   155 
       
   156 
       
   157 (** retrieval **)
       
   158 
       
   159 fun get_funcs funcgr (c_tys as (c, _)) =
       
   160   (check_thms c o these o Option.map snd o try (Constgraph.get_node funcgr)) c_tys;
       
   161 
       
   162 fun get_func_typs funcgr =
       
   163   AList.make (fst o Constgraph.get_node funcgr) (Constgraph.keys funcgr);
       
   164 
       
   165 local
       
   166 
       
   167 fun add_things_of thy f (c, thms) =
       
   168   (fold o fold_aterms)
       
   169      (fn Const c_ty => let
       
   170             val c' = CodegenConsts.norm_of_typ thy c_ty
       
   171           in if CodegenConsts.eq_const (c, c') then I
       
   172           else f (c', c_ty) end
       
   173        | _ => I) (maps (op :: o swap o apfst (snd o strip_comb)
       
   174             o Logic.dest_equals o Drule.plain_prop_of) thms)
       
   175 
       
   176 fun rhs_of thy (c, thms) =
       
   177   Consttab.empty
       
   178   |> add_things_of thy (Consttab.update o rpair () o fst) (c, thms)
       
   179   |> Consttab.keys;
       
   180 
       
   181 fun rhs_of' thy (c, thms) =
       
   182   add_things_of thy (cons o snd) (c, thms) [];
       
   183 
       
   184 fun insts_of thy funcgr (c, ty) =
       
   185   let
       
   186     val tys = Sign.const_typargs thy (c, ty);
       
   187     val c' = CodegenConsts.norm thy (c, tys);
       
   188     val ty_decl = if (is_none o AxClass.class_of_param thy) c
       
   189       then (fst o Constgraph.get_node funcgr) (CodegenConsts.norm thy (c, tys))
       
   190       else CodegenConsts.typ_of_classop thy (c, tys);
       
   191     val tys_decl = Sign.const_typargs thy (c, ty_decl);
       
   192     val pp = Sign.pp thy;
       
   193     val algebra = Sign.classes_of thy;
       
   194     fun classrel (x, _) _ = x;
       
   195     fun constructor tyco xs class =
       
   196       (tyco, class) :: maps (maps fst) xs;
       
   197     fun variable (TVar (_, sort)) = map (pair []) sort
       
   198       | variable (TFree (_, sort)) = map (pair []) sort;
       
   199     fun mk_inst ty (TVar (_, sort)) = cons (ty, sort)
       
   200       | mk_inst ty (TFree (_, sort)) = cons (ty, sort)
       
   201       | mk_inst (Type (tyco1, tys1)) (Type (tyco2, tys2)) =
       
   202           if tyco1 <> tyco2 then error "bad instance"
       
   203           else fold2 mk_inst tys1 tys2;
       
   204   in
       
   205     flat (maps (Sorts.of_sort_derivation pp algebra
       
   206       { classrel = classrel, constructor = constructor, variable = variable })
       
   207       (fold2 mk_inst tys tys_decl []))
       
   208   end;
       
   209 
       
   210 fun all_classops thy tyco class =
       
   211   maps (AxClass.params_of thy)
       
   212       (Graph.all_succs ((#classes o Sorts.rep_algebra o Sign.classes_of) thy) [class])
       
   213   |> AList.make (fn c => CodegenConsts.typ_of_classop thy (c, [Type (tyco, [])]))
       
   214         (*typ_of_classop is very liberal in its type arguments*)
       
   215   |> map (CodegenConsts.norm_of_typ thy);
       
   216 
       
   217 fun instdefs_of thy insts =
       
   218   let
       
   219     val thy_classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
       
   220   in
       
   221     Symtab.empty
       
   222     |> fold (fn (tyco, class) =>
       
   223         Symtab.map_default (tyco, []) (insert (op =) class)) insts
       
   224     |> (fn tab => Symtab.fold (fn (tyco, classes) => append (maps (all_classops thy tyco)
       
   225          (Graph.all_succs thy_classes classes))) tab [])
       
   226   end;
       
   227 
       
   228 fun insts_of_thms thy funcgr c_thms =
       
   229   let
       
   230     val insts = add_things_of thy (fn (_, c_ty) => fold (insert (op =))
       
   231       (insts_of thy funcgr c_ty)) c_thms [];
       
   232   in instdefs_of thy insts end;
       
   233 
       
   234 fun ensure_const thy funcgr c auxgr =
       
   235   if can (Constgraph.get_node funcgr) c
       
   236     then (NONE, auxgr)
       
   237   else if can (Constgraph.get_node auxgr) c
       
   238     then (SOME c, auxgr)
       
   239   else if is_some (CodegenData.get_datatype_of_constr thy c) then
       
   240     auxgr
       
   241     |> Constgraph.new_node (c, [])
       
   242     |> pair (SOME c)
       
   243   else let
       
   244     val thms = preprocess thy (CodegenData.these_funcs thy c);
       
   245     val rhs = rhs_of thy (c, thms);
       
   246   in
       
   247     auxgr
       
   248     |> Constgraph.new_node (c, thms)
       
   249     |> fold_map (ensure_const thy funcgr) rhs
       
   250     |-> (fn rhs' => fold (fn SOME c' => Constgraph.add_edge (c, c')
       
   251                            | NONE => I) rhs')
       
   252     |> pair (SOME c)
       
   253   end;
       
   254 
       
   255 fun specialize_typs thy funcgr eqss =
       
   256   let
       
   257     fun max k [] = k
       
   258       | max k (l::ls) = max (if k < l then l else k) ls;
       
   259     fun typscheme_of (c, ty) =
       
   260       try (Constgraph.get_node funcgr) (CodegenConsts.norm_of_typ thy (c, ty))
       
   261       |> Option.map fst;
       
   262     fun incr_indices (c, thms) maxidx =
       
   263       let
       
   264         val thms' = map (Thm.incr_indexes maxidx) thms;
       
   265         val maxidx' = Int.max
       
   266           (maxidx, max ~1 (map Thm.maxidx_of thms') + 1);
       
   267       in ((c, thms'), maxidx') end;
       
   268     val tsig = Sign.tsig_of thy;
       
   269     fun unify_const thms (c, ty) (env, maxidx) =
       
   270       case typscheme_of (c, ty)
       
   271        of SOME ty_decl => let
       
   272             val ty_decl' = Logic.incr_tvar maxidx ty_decl;
       
   273             val maxidx' = Int.max (Term.maxidx_of_typ ty_decl' + 1, maxidx);
       
   274           in Type.unify tsig (ty_decl', ty) (env, maxidx')
       
   275           handle TUNIFY => error ("Failed to instantiate\n"
       
   276             ^ (Sign.string_of_typ thy o Envir.norm_type env) ty_decl' ^ "\nto\n"
       
   277             ^ (Sign.string_of_typ thy o Envir.norm_type env) ty ^ ",\n"
       
   278             ^ "in function theorems\n"
       
   279             ^ cat_lines (map string_of_thm thms))
       
   280           end
       
   281         | NONE => (env, maxidx);
       
   282     fun apply_unifier unif (c, []) = (c, [])
       
   283       | apply_unifier unif (c, thms as thm :: _) =
       
   284           let
       
   285             val ty = CodegenData.typ_func thy thm;
       
   286             val ty' = Envir.norm_type unif ty;
       
   287             val env = Type.typ_match (Sign.tsig_of thy) (ty, ty') Vartab.empty;
       
   288             val inst = Thm.instantiate (Vartab.fold (fn (x_i, (sort, ty)) =>
       
   289               cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [], []);
       
   290           in (c, map (Drule.zero_var_indexes o inst) thms) end;
       
   291     val (eqss', maxidx) =
       
   292       fold_map incr_indices eqss 0;
       
   293     val (unif, _) =
       
   294       fold (fn (c, thms) => fold (unify_const thms) (rhs_of' thy (c, thms)))
       
   295         eqss' (Vartab.empty, maxidx);
       
   296     val eqss'' =
       
   297       map (apply_unifier unif) eqss';
       
   298   in eqss'' end;
       
   299 
       
   300 fun merge_eqsyss thy raw_eqss funcgr =
       
   301   let
       
   302     val eqss = specialize_typs thy funcgr raw_eqss;
       
   303     val tys = map (fn (c as (name, _), []) => (case AxClass.class_of_param thy name
       
   304          of SOME class => (case ClassPackage.the_consts_sign thy class of (v, cs) =>
       
   305               (Logic.varifyT o map_type_tfree (fn u as (w, _) =>
       
   306                 if w = v then TFree (v, [class]) else TFree u))
       
   307               ((the o AList.lookup (op =) cs) name))
       
   308           | NONE => Sign.the_const_type thy name)
       
   309                    | (_, eq :: _) => CodegenData.typ_func thy eq) eqss;
       
   310     val rhss = map (rhs_of thy) eqss;
       
   311   in
       
   312     funcgr
       
   313     |> fold2 (fn (c, thms) => fn ty => Constgraph.new_node (c, (ty, thms))) eqss tys
       
   314     |> `(fn funcgr => map (insts_of_thms thy funcgr) eqss)
       
   315     |-> (fn rhs_insts => fold2 (fn (c, _) => fn rhs_inst =>
       
   316           ensure_consts thy rhs_inst #> fold (curry Constgraph.add_edge c) rhs_inst) eqss rhs_insts)
       
   317     |> fold2 (fn (c, _) => fn rhs => fold (curry Constgraph.add_edge c) rhs) eqss rhss
       
   318   end
       
   319 and ensure_consts thy cs funcgr =
       
   320   fold (snd oo ensure_const thy funcgr) cs Constgraph.empty
       
   321   |> (fn auxgr => fold (merge_eqsyss thy)
       
   322        (map (AList.make (Constgraph.get_node auxgr))
       
   323        (rev (Constgraph.strong_conn auxgr))) funcgr);
       
   324 
       
   325 in
       
   326 
       
   327 val ensure_consts = ensure_consts;
       
   328 
       
   329 fun mk_funcgr thy consts cs =
       
   330   Funcgr.change thy (
       
   331     ensure_consts thy consts
       
   332     #> (fn funcgr => ensure_consts thy
       
   333          (instdefs_of thy (fold (fold (insert (op =)) o insts_of thy funcgr) cs [])) funcgr)
       
   334   );
       
   335 
       
   336 end; (*local*)
       
   337 
       
   338 fun print_funcgr thy funcgr =
       
   339   AList.make (snd o Constgraph.get_node funcgr) (Constgraph.keys funcgr)
       
   340   |> (map o apfst) (CodegenConsts.string_of_const thy)
       
   341   |> sort (string_ord o pairself fst)
       
   342   |> map (fn (s, thms) =>
       
   343        (Pretty.block o Pretty.fbreaks) (
       
   344          Pretty.str s
       
   345          :: map Display.pretty_thm thms
       
   346        ))
       
   347   |> Pretty.chunks
       
   348   |> Pretty.writeln;
       
   349 
       
   350 fun print_codethms thy consts =
       
   351   mk_funcgr thy consts [] |> print_funcgr thy;
       
   352 
       
   353 fun print_codethms_e thy cs =
       
   354   print_codethms thy (map (CodegenConsts.read_const thy) cs);
       
   355 
       
   356 
       
   357 (** Isar **)
       
   358 
       
   359 structure P = OuterParse;
       
   360 
       
   361 val print_codethmsK = "print_codethms";
       
   362 
       
   363 val print_codethmsP =
       
   364   OuterSyntax.improper_command print_codethmsK "print code theorems of this theory" OuterKeyword.diag
       
   365     (Scan.option (P.$$$ "(" |-- Scan.repeat P.term --| P.$$$ ")")
       
   366       >> (fn NONE => CodegenData.print_thms
       
   367            | SOME cs => fn thy => print_codethms_e thy cs)
       
   368       >> (fn f => Toplevel.no_timing o Toplevel.unknown_theory
       
   369       o Toplevel.keep (f o Toplevel.theory_of)));
       
   370 
       
   371 val _ = OuterSyntax.add_parsers [print_codethmsP];
       
   372 
       
   373 end; (*struct*)