src/Tools/code/code_funcgr.ML
changeset 24283 8ca96f4e49cd
parent 24219 e558fe311376
child 24423 ae9cd0e92423
equal deleted inserted replaced
24282:9b64aa297524 24283:8ca96f4e49cd
    10 sig
    10 sig
    11   type T
    11   type T
    12   val timing: bool ref
    12   val timing: bool ref
    13   val funcs: T -> CodeUnit.const -> thm list
    13   val funcs: T -> CodeUnit.const -> thm list
    14   val typ: T -> CodeUnit.const -> typ
    14   val typ: T -> CodeUnit.const -> typ
    15   val deps: T -> CodeUnit.const list -> CodeUnit.const list list
       
    16   val all: T -> CodeUnit.const list
    15   val all: T -> CodeUnit.const list
    17   val pretty: theory -> T -> Pretty.T
    16   val pretty: theory -> T -> Pretty.T
       
    17   val make: theory -> CodeUnit.const list -> T
       
    18   val make_consts: theory -> CodeUnit.const list -> CodeUnit.const list * T
       
    19   val eval_conv: theory -> (T -> cterm -> thm) -> cterm -> thm
       
    20   val eval_term: theory -> (T -> cterm -> 'a) -> cterm -> 'a
       
    21   val intervene: theory -> T -> T
       
    22     (*FIXME drop intervene as soon as possible*)
    18   structure Constgraph : GRAPH
    23   structure Constgraph : GRAPH
    19 end
    24 end
    20 
    25 
    21 signature CODE_FUNCGR_RETRIEVAL =
    26 structure CodeFuncgr : CODE_FUNCGR =
    22 sig
       
    23   type T (* = CODE_FUNCGR.T *)
       
    24   val make: theory -> CodeUnit.const list -> T
       
    25   val make_consts: theory -> CodeUnit.const list -> CodeUnit.const list * T
       
    26   val make_term: theory -> (T -> (thm -> thm) -> cterm -> thm -> 'a) -> cterm -> 'a * T
       
    27     (*FIXME drop make_term as soon as possible*)
       
    28   val eval_conv: theory -> (T -> cterm -> thm) -> cterm -> thm
       
    29   val intervene: theory -> T -> T
       
    30     (*FIXME drop intervene as soon as possible*)
       
    31 end;
       
    32 
       
    33 structure CodeFuncgr = (*signature is added later*)
       
    34 struct
    27 struct
    35 
    28 
    36 (** the graph type **)
    29 (** the graph type **)
    37 
    30 
    38 structure Constgraph = GraphFun (
    31 structure Constgraph = GraphFun (
    45 fun funcs funcgr =
    38 fun funcs funcgr =
    46   these o Option.map snd o try (Constgraph.get_node funcgr);
    39   these o Option.map snd o try (Constgraph.get_node funcgr);
    47 
    40 
    48 fun typ funcgr =
    41 fun typ funcgr =
    49   fst o Constgraph.get_node funcgr;
    42   fst o Constgraph.get_node funcgr;
    50 
       
    51 fun deps funcgr cs =
       
    52   let
       
    53     val conn = Constgraph.strong_conn funcgr;
       
    54     val order = rev conn;
       
    55   in
       
    56     (map o filter) (member (op =) (Constgraph.all_succs funcgr cs)) order
       
    57     |> filter_out null
       
    58   end;
       
    59 
    43 
    60 fun all funcgr = Constgraph.keys funcgr;
    44 fun all funcgr = Constgraph.keys funcgr;
    61 
    45 
    62 fun pretty thy funcgr =
    46 fun pretty thy funcgr =
    63   AList.make (snd o Constgraph.get_node funcgr) (Constgraph.keys funcgr)
    47   AList.make (snd o Constgraph.get_node funcgr) (Constgraph.keys funcgr)
   206     []
   190     []
   207     |> fold (fold (insert (op =)) o inst) consts
   191     |> fold (fold (insert (op =)) o inst) consts
   208     |> instances_of thy algebra
   192     |> instances_of thy algebra
   209   end;
   193   end;
   210 
   194 
   211 fun ensure_const' rewrites thy algebra funcgr const auxgr =
   195 fun ensure_const' thy algebra funcgr const auxgr =
   212   if can (Constgraph.get_node funcgr) const
   196   if can (Constgraph.get_node funcgr) const
   213     then (NONE, auxgr)
   197     then (NONE, auxgr)
   214   else if can (Constgraph.get_node auxgr) const
   198   else if can (Constgraph.get_node auxgr) const
   215     then (SOME const, auxgr)
   199     then (SOME const, auxgr)
   216   else if is_some (Code.get_datatype_of_constr thy const) then
   200   else if is_some (Code.get_datatype_of_constr thy const) then
   217     auxgr
   201     auxgr
   218     |> Constgraph.new_node (const, [])
   202     |> Constgraph.new_node (const, [])
   219     |> pair (SOME const)
   203     |> pair (SOME const)
   220   else let
   204   else let
   221     val thms = Code.these_funcs thy const
   205     val thms = Code.these_funcs thy const
   222       |> map (CodeUnit.rewrite_func (rewrites thy))
       
   223       |> CodeUnit.norm_args
   206       |> CodeUnit.norm_args
   224       |> CodeUnit.norm_varnames CodeName.purify_tvar CodeName.purify_var;
   207       |> CodeUnit.norm_varnames CodeName.purify_tvar CodeName.purify_var;
   225     val rhs = consts_of (const, thms);
   208     val rhs = consts_of (const, thms);
   226   in
   209   in
   227     auxgr
   210     auxgr
   228     |> Constgraph.new_node (const, thms)
   211     |> Constgraph.new_node (const, thms)
   229     |> fold_map (ensure_const rewrites thy algebra funcgr) rhs
   212     |> fold_map (ensure_const thy algebra funcgr) rhs
   230     |-> (fn rhs' => fold (fn SOME const' => Constgraph.add_edge (const, const')
   213     |-> (fn rhs' => fold (fn SOME const' => Constgraph.add_edge (const, const')
   231                            | NONE => I) rhs')
   214                            | NONE => I) rhs')
   232     |> pair (SOME const)
   215     |> pair (SOME const)
   233   end
   216   end
   234 and ensure_const rewrites thy algebra funcgr const =
   217 and ensure_const thy algebra funcgr const =
   235   let
   218   let
   236     val timeap = if !timing
   219     val timeap = if !timing
   237       then Output.timeap_msg ("time for " ^ CodeUnit.string_of_const thy const)
   220       then Output.timeap_msg ("time for " ^ CodeUnit.string_of_const thy const)
   238       else I;
   221       else I;
   239   in timeap (ensure_const' rewrites thy algebra funcgr const) end;
   222   in timeap (ensure_const' thy algebra funcgr const) end;
   240 
   223 
   241 fun merge_funcss rewrites thy algebra raw_funcss funcgr =
   224 fun merge_funcss thy algebra raw_funcss funcgr =
   242   let
   225   let
   243     val funcss = raw_funcss
   226     val funcss = raw_funcss
   244       |> resort_funcss thy algebra funcgr
   227       |> resort_funcss thy algebra funcgr
   245       |> filter_out (can (Constgraph.get_node funcgr) o fst);
   228       |> filter_out (can (Constgraph.get_node funcgr) o fst);
   246     fun typ_func const [] = Code.default_typ thy const
   229     fun typ_func const [] = Code.default_typ thy const
   265         val deps = consts_of funcs;
   248         val deps = consts_of funcs;
   266         val insts = instances_of_consts thy algebra funcgr
   249         val insts = instances_of_consts thy algebra funcgr
   267           (fold_consts (insert (op =)) thms []);
   250           (fold_consts (insert (op =)) thms []);
   268       in
   251       in
   269         funcgr
   252         funcgr
   270         |> ensure_consts' rewrites thy algebra insts
   253         |> ensure_consts' thy algebra insts
   271         |> fold (curry Constgraph.add_edge const) deps
   254         |> fold (curry Constgraph.add_edge const) deps
   272         |> fold (curry Constgraph.add_edge const) insts
   255         |> fold (curry Constgraph.add_edge const) insts
   273        end;
   256        end;
   274   in
   257   in
   275     funcgr
   258     funcgr
   276     |> fold add_funcs funcss
   259     |> fold add_funcs funcss
   277     |> fold add_deps funcss
   260     |> fold add_deps funcss
   278   end
   261   end
   279 and ensure_consts' rewrites thy algebra cs funcgr =
   262 and ensure_consts' thy algebra cs funcgr =
   280   let
   263   let
   281     val auxgr = Constgraph.empty
   264     val auxgr = Constgraph.empty
   282       |> fold (snd oo ensure_const rewrites thy algebra funcgr) cs;
   265       |> fold (snd oo ensure_const thy algebra funcgr) cs;
   283   in
   266   in
   284     funcgr
   267     funcgr
   285     |> fold (merge_funcss rewrites thy algebra)
   268     |> fold (merge_funcss thy algebra)
   286          (map (AList.make (Constgraph.get_node auxgr))
   269          (map (AList.make (Constgraph.get_node auxgr))
   287          (rev (Constgraph.strong_conn auxgr)))
   270          (rev (Constgraph.strong_conn auxgr)))
   288   end handle INVALID (cs', msg)
   271   end handle INVALID (cs', msg)
   289     => raise INVALID (fold (insert CodeUnit.eq_const) cs' cs, msg);
   272     => raise INVALID (fold (insert CodeUnit.eq_const) cs' cs, msg);
   290 
   273 
   291 fun ensure_consts rewrites thy consts funcgr =
   274 fun ensure_consts thy consts funcgr =
   292   let
   275   let
   293     val algebra = Code.coregular_algebra thy
   276     val algebra = Code.coregular_algebra thy
   294   in ensure_consts' rewrites thy algebra consts funcgr
   277   in ensure_consts' thy algebra consts funcgr
   295     handle INVALID (cs', msg) => error (msg ^ ",\nwhile preprocessing equations for constant(s) "
   278     handle INVALID (cs', msg) => error (msg ^ ",\nwhile preprocessing equations for constant(s) "
   296     ^ commas (map (CodeUnit.string_of_const thy) cs'))
   279     ^ commas (map (CodeUnit.string_of_const thy) cs'))
   297   end;
   280   end;
   298 
   281 
   299 in
   282 in
   300 
   283 
   301 (** retrieval interfaces **)
   284 (** retrieval interfaces **)
   302 
   285 
   303 val ensure_consts = ensure_consts;
   286 val ensure_consts = ensure_consts;
   304 
   287 
   305 fun check_consts rewrites thy consts funcgr =
   288 fun check_consts thy consts funcgr =
   306   let
   289   let
   307     val algebra = Code.coregular_algebra thy;
   290     val algebra = Code.coregular_algebra thy;
   308     fun try_const const funcgr =
   291     fun try_const const funcgr =
   309       (SOME const, ensure_consts' rewrites thy algebra [const] funcgr)
   292       (SOME const, ensure_consts' thy algebra [const] funcgr)
   310       handle INVALID (cs', msg) => (NONE, funcgr);
   293       handle INVALID (cs', msg) => (NONE, funcgr);
   311     val (consts', funcgr') = fold_map try_const consts funcgr;
   294     val (consts', funcgr') = fold_map try_const consts funcgr;
   312   in (map_filter I consts', funcgr') end;
   295   in (map_filter I consts', funcgr') end;
   313 
   296 
   314 fun ensure_consts_term rewrites thy f ct funcgr =
   297 fun ensure_consts_term_proto thy f ct funcgr =
   315   let
   298   let
   316     fun consts_of thy t =
   299     fun consts_of thy t =
   317       fold_aterms (fn Const c => cons (CodeUnit.const_of_cexpr thy c) | _ => I) t []
   300       fold_aterms (fn Const c => cons (CodeUnit.const_of_cexpr thy c) | _ => I) t []
   318     fun rhs_conv conv thm =
   301     fun rhs_conv conv thm =
   319       let
   302       let
   320         val thm' = (conv o Thm.rhs_of) thm;
   303         val thm' = (conv o Thm.rhs_of) thm;
   321       in Thm.transitive thm thm' end
   304       in Thm.transitive thm thm' end
   322     val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
   305     val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
   323     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
   306     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
   324     val thm1 = Code.preprocess_conv ct
   307     val thm1 = Code.preprocess_conv ct;
   325       |> fold (rhs_conv o MetaSimplifier.rewrite false o single) (rewrites thy);
       
   326     val ct' = Thm.rhs_of thm1;
   308     val ct' = Thm.rhs_of thm1;
   327     val consts = consts_of thy (Thm.term_of ct');
   309     val consts = consts_of thy (Thm.term_of ct');
   328     val funcgr' = ensure_consts rewrites thy consts funcgr;
   310     val funcgr' = ensure_consts thy consts funcgr;
   329     val algebra = Code.coregular_algebra thy;
   311     val algebra = Code.coregular_algebra thy;
   330     val (_, thm2) = Thm.varifyT' [] thm1;
   312     val (_, thm2) = Thm.varifyT' [] thm1;
   331     val thm3 = Thm.reflexive (Thm.rhs_of thm2);
   313     val thm3 = Thm.reflexive (Thm.rhs_of thm2);
   332     val typ_funcgr = try (fst o Constgraph.get_node funcgr' o CodeUnit.const_of_cexpr thy);
   314     val typ_funcgr = try (fst o Constgraph.get_node funcgr' o CodeUnit.const_of_cexpr thy);
   333     val [thm4] = resort_thms algebra typ_funcgr [thm3];
   315     val [thm4] = resort_thms algebra typ_funcgr [thm3];
   342     val thm6 = inst thm4;
   324     val thm6 = inst thm4;
   343     val ct'' = Thm.rhs_of thm6;
   325     val ct'' = Thm.rhs_of thm6;
   344     val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') [];
   326     val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') [];
   345     val drop = drop_classes thy tfrees;
   327     val drop = drop_classes thy tfrees;
   346     val instdefs = instances_of_consts thy algebra funcgr' cs;
   328     val instdefs = instances_of_consts thy algebra funcgr' cs;
   347     val funcgr'' = ensure_consts rewrites thy instdefs funcgr';
   329     val funcgr'' = ensure_consts thy instdefs funcgr';
   348   in (f funcgr'' drop ct'' thm5, funcgr'') end;
   330   in (f funcgr'' drop ct'' thm5, funcgr'') end;
   349 
   331 
   350 fun ensure_consts_eval rewrites thy conv =
   332 fun ensure_consts_eval thy conv =
   351   let
   333   let
   352     fun conv' funcgr drop_classes ct thm1 =
   334     fun conv' funcgr drop_classes ct thm1 =
   353       let
   335       let
   354         val thm2 = conv funcgr ct;
   336         val thm2 = conv funcgr ct;
   355         val thm3 = Code.postprocess_conv (Thm.rhs_of thm2);
   337         val thm3 = Code.postprocess_conv (Thm.rhs_of thm2);
   357       in
   339       in
   358         Thm.transitive thm1 thm23 handle THM _ =>
   340         Thm.transitive thm1 thm23 handle THM _ =>
   359           error ("eval_conv - could not construct proof:\n"
   341           error ("eval_conv - could not construct proof:\n"
   360           ^ (cat_lines o map string_of_thm) [thm1, thm2, thm3])
   342           ^ (cat_lines o map string_of_thm) [thm1, thm2, thm3])
   361       end;
   343       end;
   362   in ensure_consts_term rewrites thy conv' end;
   344   in ensure_consts_term_proto thy conv' end;
       
   345 
       
   346 fun ensure_consts_term thy f =
       
   347   let
       
   348     fun f' funcgr drop_classes ct thm1 = f funcgr ct;
       
   349   in ensure_consts_term_proto thy f' end;
   363 
   350 
   364 end; (*local*)
   351 end; (*local*)
   365 
       
   366 end; (*struct*)
       
   367 
       
   368 functor CodeFuncgrRetrieval (val rewrites: theory -> thm list) : CODE_FUNCGR_RETRIEVAL =
       
   369 struct
       
   370 
       
   371 (** code data **)
       
   372 
       
   373 type T = CodeFuncgr.T;
       
   374 
   352 
   375 structure Funcgr = CodeDataFun
   353 structure Funcgr = CodeDataFun
   376 (struct
   354 (struct
   377   type T = T;
   355   type T = T;
   378   val empty = CodeFuncgr.Constgraph.empty;
   356   val empty = Constgraph.empty;
   379   fun merge _ _ = CodeFuncgr.Constgraph.empty;
   357   fun merge _ _ = Constgraph.empty;
   380   fun purge _ NONE _ = CodeFuncgr.Constgraph.empty
   358   fun purge _ NONE _ = Constgraph.empty
   381     | purge _ (SOME cs) funcgr =
   359     | purge _ (SOME cs) funcgr =
   382         CodeFuncgr.Constgraph.del_nodes ((CodeFuncgr.Constgraph.all_preds funcgr 
   360         Constgraph.del_nodes ((Constgraph.all_preds funcgr 
   383           o filter (can (CodeFuncgr.Constgraph.get_node funcgr))) cs) funcgr;
   361           o filter (can (Constgraph.get_node funcgr))) cs) funcgr;
   384 end);
   362 end);
   385 
   363 
   386 fun make thy =
   364 fun make thy =
   387   Funcgr.change thy o CodeFuncgr.ensure_consts rewrites thy;
   365   Funcgr.change thy o ensure_consts thy;
   388 
   366 
   389 fun make_consts thy =
   367 fun make_consts thy =
   390   Funcgr.change_yield thy o CodeFuncgr.check_consts rewrites thy;
   368   Funcgr.change_yield thy o check_consts thy;
   391 
       
   392 fun make_term thy f =
       
   393   Funcgr.change_yield thy o CodeFuncgr.ensure_consts_term rewrites thy f;
       
   394 
   369 
   395 fun eval_conv thy f =
   370 fun eval_conv thy f =
   396   fst o Funcgr.change_yield thy o CodeFuncgr.ensure_consts_eval rewrites thy f;
   371   fst o Funcgr.change_yield thy o ensure_consts_eval thy f;
       
   372 
       
   373 fun eval_term thy f =
       
   374   fst o Funcgr.change_yield thy o ensure_consts_term thy f;
   397 
   375 
   398 fun intervene thy funcgr = Funcgr.change thy (K funcgr);
   376 fun intervene thy funcgr = Funcgr.change thy (K funcgr);
   399 
   377 
   400 end; (*functor*)
       
   401 
       
   402 structure CodeFuncgr : CODE_FUNCGR =
       
   403 struct
       
   404 
       
   405 open CodeFuncgr;
       
   406 
       
   407 end; (*struct*)
   378 end; (*struct*)