src/HOL/Tools/Predicate_Compile/core_data.ML
changeset 55543 f0ef75c6c0d8
parent 55471 198498f861ee
child 58823 513268cb2178
equal deleted inserted replaced
55537:6ec3c2c38650 55543:f0ef75c6c0d8
    16     elim : thm,
    16     elim : thm,
    17     neg_intro : thm option
    17     neg_intro : thm option
    18   };
    18   };
    19 
    19 
    20   datatype pred_data = PredData of {
    20   datatype pred_data = PredData of {
       
    21     pos : Position.T,
    21     intros : (string option * thm) list,
    22     intros : (string option * thm) list,
    22     elim : thm option,
    23     elim : thm option,
    23     preprocessed : bool,
    24     preprocessed : bool,
    24     function_names : (compilation * (mode * string) list) list,
    25     function_names : (compilation * (mode * string) list) list,
    25     predfun_data : (mode * predfun_data) list,
    26     predfun_data : (mode * predfun_data) list,
    98 
    99 
    99 fun mk_predfun_data (definition, ((intro, elim), neg_intro)) =
   100 fun mk_predfun_data (definition, ((intro, elim), neg_intro)) =
   100   PredfunData {definition = definition, intro = intro, elim = elim, neg_intro = neg_intro}
   101   PredfunData {definition = definition, intro = intro, elim = elim, neg_intro = neg_intro}
   101 
   102 
   102 datatype pred_data = PredData of {
   103 datatype pred_data = PredData of {
       
   104   pos: Position.T,
   103   intros : (string option * thm) list,
   105   intros : (string option * thm) list,
   104   elim : thm option,
   106   elim : thm option,
   105   preprocessed : bool,
   107   preprocessed : bool,
   106   function_names : (compilation * (mode * string) list) list,
   108   function_names : (compilation * (mode * string) list) list,
   107   predfun_data : (mode * predfun_data) list,
   109   predfun_data : (mode * predfun_data) list,
   108   needs_random : mode list
   110   needs_random : mode list
   109 };
   111 };
   110 
   112 
   111 fun rep_pred_data (PredData data) = data;
   113 fun rep_pred_data (PredData data) = data;
   112 
   114 val pos_of = #pos o rep_pred_data;
   113 fun mk_pred_data (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random))) =
   115 
   114   PredData {intros = intros, elim = elim, preprocessed = preprocessed,
   116 fun mk_pred_data
       
   117     (pos, (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random)))) =
       
   118   PredData {pos = pos, intros = intros, elim = elim, preprocessed = preprocessed,
   115     function_names = function_names, predfun_data = predfun_data, needs_random = needs_random}
   119     function_names = function_names, predfun_data = predfun_data, needs_random = needs_random}
   116 
   120 
   117 fun map_pred_data f (PredData {intros, elim, preprocessed, function_names, predfun_data, needs_random}) =
   121 fun map_pred_data f
   118   mk_pred_data (f (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random))))
   122     (PredData {pos, intros, elim, preprocessed, function_names, predfun_data, needs_random}) =
       
   123   mk_pred_data
       
   124     (f (pos, (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random)))))
   119 
   125 
   120 fun eq_option eq (NONE, NONE) = true
   126 fun eq_option eq (NONE, NONE) = true
   121   | eq_option eq (SOME x, SOME y) = eq (x, y)
   127   | eq_option eq (SOME x, SOME y) = eq (x, y)
   122   | eq_option eq _ = false
   128   | eq_option eq _ = false
   123 
   129 
   128 structure PredData = Theory_Data
   134 structure PredData = Theory_Data
   129 (
   135 (
   130   type T = pred_data Graph.T;
   136   type T = pred_data Graph.T;
   131   val empty = Graph.empty;
   137   val empty = Graph.empty;
   132   val extend = I;
   138   val extend = I;
   133   val merge = Graph.merge eq_pred_data;
   139   val merge =
       
   140     Graph.join (fn key => fn (x, y) =>
       
   141       if eq_pred_data (x, y)
       
   142       then raise Graph.SAME
       
   143       else
       
   144         error ("Duplicate predicate declarations for " ^ quote key ^
       
   145           Position.here (pos_of x) ^ Position.here (pos_of y)));
   134 );
   146 );
   135 
   147 
   136 
   148 
   137 (* queries *)
   149 (* queries *)
   138 
   150 
   258 
   270 
   259 fun fetch_pred_data ctxt name =
   271 fun fetch_pred_data ctxt name =
   260   (case try (Inductive.the_inductive ctxt) name of
   272   (case try (Inductive.the_inductive ctxt) name of
   261     SOME (info as (_, result)) => 
   273     SOME (info as (_, result)) => 
   262       let
   274       let
       
   275         val thy = Proof_Context.theory_of ctxt
       
   276 
       
   277         val pos = Position.thread_data ()
   263         fun is_intro_of intro =
   278         fun is_intro_of intro =
   264           let
   279           let
   265             val (const, _) = strip_comb (HOLogic.dest_Trueprop (concl_of intro))
   280             val (const, _) = strip_comb (HOLogic.dest_Trueprop (concl_of intro))
   266           in (fst (dest_Const const) = name) end;
   281           in (fst (dest_Const const) = name) end;
   267         val thy = Proof_Context.theory_of ctxt
       
   268         val intros = map (preprocess_intro thy) (filter is_intro_of (#intrs result))
   282         val intros = map (preprocess_intro thy) (filter is_intro_of (#intrs result))
   269         val index = find_index (fn s => s = name) (#names (fst info))
   283         val index = find_index (fn s => s = name) (#names (fst info))
   270         val pre_elim = nth (#elims result) index
   284         val pre_elim = nth (#elims result) index
   271         val pred = nth (#preds result) index
   285         val pred = nth (#preds result) index
   272         val elim_t = mk_casesrule ctxt pred intros
   286         val elim_t = mk_casesrule ctxt pred intros
   273         val nparams = length (Inductive.params_of (#raw_induct result))
   287         val nparams = length (Inductive.params_of (#raw_induct result))
   274         val elim = prove_casesrule ctxt (pred, (pre_elim, nparams)) elim_t
   288         val elim = prove_casesrule ctxt (pred, (pre_elim, nparams)) elim_t
   275       in
   289       in
   276         mk_pred_data (((map (pair NONE) intros, SOME elim), true), no_compilation)
   290         mk_pred_data (pos, (((map (pair NONE) intros, SOME elim), true), no_compilation))
   277       end
   291       end
   278   | NONE => error ("No such predicate: " ^ quote name))
   292   | NONE => error ("No such predicate: " ^ quote name))
   279 
   293 
   280 fun add_predfun_data name mode data =
   294 fun add_predfun_data name mode data =
   281   let
   295   let
   282     val add = (apsnd o apsnd o apfst) (cons (mode, mk_predfun_data data))
   296     val add = (apsnd o apsnd o apsnd o apfst) (cons (mode, mk_predfun_data data))
   283   in PredData.map (Graph.map_node name (map_pred_data add)) end
   297   in PredData.map (Graph.map_node name (map_pred_data add)) end
   284 
   298 
   285 fun is_inductive_predicate ctxt name =
   299 fun is_inductive_predicate ctxt name =
   286   is_some (try (Inductive.the_inductive ctxt) name)
   300   is_some (try (Inductive.the_inductive ctxt) name)
   287 
   301 
   303     val (name, _) = dest_Const (fst (strip_intro_concl thm))
   317     val (name, _) = dest_Const (fst (strip_intro_concl thm))
   304     fun cons_intro gr =
   318     fun cons_intro gr =
   305       (case try (Graph.get_node gr) name of
   319       (case try (Graph.get_node gr) name of
   306         SOME _ =>
   320         SOME _ =>
   307           Graph.map_node name (map_pred_data
   321           Graph.map_node name (map_pred_data
   308             (apfst (apfst (apfst (fn intros => intros @ [(opt_case_name, thm)]))))) gr
   322             (apsnd (apfst (apfst (apfst (fn intros => intros @ [(opt_case_name, thm)])))))) gr
   309       | NONE =>
   323       | NONE =>
   310           Graph.new_node
   324           Graph.new_node
   311             (name, mk_pred_data ((([(opt_case_name, thm)], NONE), false), no_compilation)) gr)
   325             (name,
       
   326               mk_pred_data (Position.thread_data (),
       
   327                 (((([(opt_case_name, thm)], NONE), false), no_compilation)))) gr)
   312   in PredData.map cons_intro thy end
   328   in PredData.map cons_intro thy end
   313 
   329 
   314 fun set_elim thm =
   330 fun set_elim thm =
   315   let
   331   let
   316     val (name, _) =
   332     val (name, _) =
   317       dest_Const (fst (strip_comb (HOLogic.dest_Trueprop (hd (prems_of thm)))))
   333       dest_Const (fst (strip_comb (HOLogic.dest_Trueprop (hd (prems_of thm)))))
   318   in PredData.map (Graph.map_node name (map_pred_data (apfst (apfst (apsnd (K (SOME thm))))))) end
   334   in
       
   335     PredData.map (Graph.map_node name (map_pred_data (apsnd (apfst (apfst (apsnd (K (SOME thm))))))))
       
   336   end
   319 
   337 
   320 fun register_predicate (constname, intros, elim) thy =
   338 fun register_predicate (constname, intros, elim) thy =
   321   let
   339   let
   322     val named_intros = map (pair NONE) intros
   340     val named_intros = map (pair NONE) intros
   323   in
   341   in
   324     if not (member (op =) (Graph.keys (PredData.get thy)) constname) then
   342     if not (member (op =) (Graph.keys (PredData.get thy)) constname) then
   325       PredData.map
   343       PredData.map
   326         (Graph.new_node (constname,
   344         (Graph.new_node (constname,
   327           mk_pred_data (((named_intros, SOME elim), false), no_compilation))) thy
   345           mk_pred_data (Position.thread_data (),
       
   346             (((named_intros, SOME elim), false), no_compilation)))) thy
   328     else thy
   347     else thy
   329   end
   348   end
   330 
   349 
   331 fun register_intros (constname, pre_intros) thy =
   350 fun register_intros (constname, pre_intros) thy =
   332   let
   351   let
   343       (mk_casesrule (Proof_Context.init_global thy) pred pre_intros)
   362       (mk_casesrule (Proof_Context.init_global thy) pred pre_intros)
   344   in register_predicate (constname, pre_intros, pre_elim) thy end
   363   in register_predicate (constname, pre_intros, pre_elim) thy end
   345 
   364 
   346 fun defined_function_of compilation pred =
   365 fun defined_function_of compilation pred =
   347   let
   366   let
   348     val set = (apsnd o apfst) (cons (compilation, []))
   367     val set = (apsnd o apsnd o apfst) (cons (compilation, []))
   349   in
   368   in
   350     PredData.map (Graph.map_node pred (map_pred_data set))
   369     PredData.map (Graph.map_node pred (map_pred_data set))
   351   end
   370   end
   352 
   371 
   353 fun set_function_name compilation pred mode name =
   372 fun set_function_name compilation pred mode name =
   354   let
   373   let
   355     val set = (apsnd o apfst)
   374     val set = (apsnd o apsnd o apfst)
   356       (AList.map_default (op =) (compilation, [(mode, name)]) (cons (mode, name)))
   375       (AList.map_default (op =) (compilation, [(mode, name)]) (cons (mode, name)))
   357   in
   376   in
   358     PredData.map (Graph.map_node pred (map_pred_data set))
   377     PredData.map (Graph.map_node pred (map_pred_data set))
   359   end
   378   end
   360 
   379 
   361 fun set_needs_random name modes =
   380 fun set_needs_random name modes =
   362   let
   381   let
   363     val set = (apsnd o apsnd o apsnd) (K modes)
   382     val set = (apsnd o apsnd o apsnd o apsnd) (K modes)
   364   in
   383   in
   365     PredData.map (Graph.map_node name (map_pred_data set))
   384     PredData.map (Graph.map_node name (map_pred_data set))
   366   end  
   385   end  
   367 
   386 
   368 fun extend' value_of edges_of key (G, visited) =
   387 fun extend' value_of edges_of key (G, visited) =
   387   in
   406   in
   388     PredData.map (fold (extend (fetch_pred_data ctxt) (depending_preds_of ctxt)) names) thy
   407     PredData.map (fold (extend (fetch_pred_data ctxt) (depending_preds_of ctxt)) names) thy
   389   end
   408   end
   390 
   409 
   391 fun preprocess_intros name thy =
   410 fun preprocess_intros name thy =
   392   PredData.map (Graph.map_node name (map_pred_data (apfst (fn (rules, preprocessed) =>
   411   PredData.map (Graph.map_node name (map_pred_data (apsnd (apfst (fn (rules, preprocessed) =>
   393     if preprocessed then (rules, preprocessed)
   412     if preprocessed then (rules, preprocessed)
   394     else
   413     else
   395       let
   414       let
   396         val (named_intros, SOME elim) = rules
   415         val (named_intros, SOME elim) = rules
   397         val named_intros' = map (apsnd (preprocess_intro thy)) named_intros
   416         val named_intros' = map (apsnd (preprocess_intro thy)) named_intros
   399         val ctxt = Proof_Context.init_global thy
   418         val ctxt = Proof_Context.init_global thy
   400         val elim_t = mk_casesrule ctxt pred (map snd named_intros')
   419         val elim_t = mk_casesrule ctxt pred (map snd named_intros')
   401         val elim' = prove_casesrule ctxt (pred, (elim, 0)) elim_t
   420         val elim' = prove_casesrule ctxt (pred, (elim, 0)) elim_t
   402       in
   421       in
   403         ((named_intros', SOME elim'), true)
   422         ((named_intros', SOME elim'), true)
   404       end))))
   423       end)))))
   405     thy  
   424     thy  
   406 
   425 
   407 
   426 
   408 (* registration of alternative function names *)
   427 (* registration of alternative function names *)
   409 
   428 
   420 
   439 
   421 fun alternative_compilation_of ctxt pred_name mode =
   440 fun alternative_compilation_of ctxt pred_name mode =
   422   AList.lookup eq_mode
   441   AList.lookup eq_mode
   423     (Symtab.lookup_list (Alt_Compilations_Data.get (Proof_Context.theory_of ctxt)) pred_name) mode
   442     (Symtab.lookup_list (Alt_Compilations_Data.get (Proof_Context.theory_of ctxt)) pred_name) mode
   424 
   443 
   425 fun force_modes_and_compilations pred_name compilations =
   444 fun force_modes_and_compilations pred_name compilations thy =
   426   let
   445   let
   427     (* thm refl is a dummy thm *)
   446     (* thm refl is a dummy thm *)
   428     val modes = map fst compilations
   447     val modes = map fst compilations
   429     val (needs_random, non_random_modes) = pairself (map fst)
   448     val (needs_random, non_random_modes) = pairself (map fst)
   430       (List.partition (fn (_, (_, random)) => random) compilations)
   449       (List.partition (fn (_, (_, random)) => random) compilations)
   433     val dummy_function_names =
   452     val dummy_function_names =
   434       map (rpair all_dummys) Predicate_Compile_Aux.random_compilations @
   453       map (rpair all_dummys) Predicate_Compile_Aux.random_compilations @
   435       map (rpair non_random_dummys) Predicate_Compile_Aux.non_random_compilations
   454       map (rpair non_random_dummys) Predicate_Compile_Aux.non_random_compilations
   436     val alt_compilations = map (apsnd fst) compilations
   455     val alt_compilations = map (apsnd fst) compilations
   437   in
   456   in
       
   457     thy |>
   438     PredData.map
   458     PredData.map
   439       (Graph.new_node
   459       (Graph.new_node
   440         (pred_name,
   460         (pred_name,
   441           mk_pred_data ((([], SOME @{thm refl}), true), (dummy_function_names, ([], needs_random)))))
   461           mk_pred_data
   442     #> Alt_Compilations_Data.map (Symtab.insert (K false) (pred_name, alt_compilations))
   462             (Position.thread_data (),
       
   463               ((([], SOME @{thm refl}), true), (dummy_function_names, ([], needs_random))))))
       
   464     |> Alt_Compilations_Data.map (Symtab.insert (K false) (pred_name, alt_compilations))
   443   end
   465   end
   444 
   466 
   445 fun functional_compilation fun_name mode compfuns T =
   467 fun functional_compilation fun_name mode compfuns T =
   446   let
   468   let
   447     val (inpTs, outpTs) = split_map_modeT (fn _ => fn T => (SOME T, NONE)) mode (binder_types T)
   469     val (inpTs, outpTs) = split_map_modeT (fn _ => fn T => (SOME T, NONE)) mode (binder_types T)