src/HOL/Tools/Datatype/datatype_codegen.ML
author haftmann
Mon Sep 28 14:54:15 2009 +0200 (2009-09-28)
changeset 32730 3958b45b29e4
parent 31784 bd3486c57ba3
child 32896 99cd75a18b78
permissions -rw-r--r--
tuned
(*  Title:      HOL/Tools/Datatype/datatype_codegen.ML
     2     Author:     Stefan Berghofer and Florian Haftmann, TU Muenchen
     3 
     4 Code generator facilities for inductive datatypes.
     5 *)
     6 
     7 signature DATATYPE_CODEGEN =
     8 sig
     9   val find_shortest_path: Datatype.descr -> int -> (string * int) option
    10   val mk_eq_eqns: theory -> string -> (thm * bool) list
    11   val mk_case_cert: theory -> string -> thm
    12   val setup: theory -> theory
    13 end;
    14 
    15 structure DatatypeCodegen : DATATYPE_CODEGEN =
    16 struct
    17 
    18 (** find shortest path to constructor with no recursive arguments **)
    19 
    20 fun find_nonempty (descr: Datatype.descr) is i =
    21   let
    22     val (_, _, constrs) = the (AList.lookup (op =) descr i);
    23     fun arg_nonempty (_, DatatypeAux.DtRec i) = if member (op =) is i
    24           then NONE
    25           else Option.map (curry op + 1 o snd) (find_nonempty descr (i::is) i)
    26       | arg_nonempty _ = SOME 0;
    27     fun max xs = Library.foldl
    28       (fn (NONE, _) => NONE
    29         | (SOME i, SOME j) => SOME (Int.max (i, j))
    30         | (_, NONE) => NONE) (SOME 0, xs);
    31     val xs = sort (int_ord o pairself snd)
    32       (map_filter (fn (s, dts) => Option.map (pair s)
    33         (max (map (arg_nonempty o DatatypeAux.strip_dtyp) dts))) constrs)
    34   in case xs of [] => NONE | x :: _ => SOME x end;
    35 
    36 fun find_shortest_path descr i = find_nonempty descr [i] i;
    37 
    38 
    39 (** SML code generator **)
    40 
    41 open Codegen;
    42 
    43 (* datatype definition *)
    44 
    45 fun add_dt_defs thy defs dep module (descr: Datatype.descr) sorts gr =
    46   let
    47     val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
    48     val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) =>
    49       exists (exists DatatypeAux.is_rec_type o snd) cs) descr');
    50 
    51     val (_, (tname, _, _)) :: _ = descr';
    52     val node_id = tname ^ " (type)";
    53     val module' = if_library (thyname_of_type thy tname) module;
    54 
    55     fun mk_dtdef prfx [] gr = ([], gr)
    56       | mk_dtdef prfx ((_, (tname, dts, cs))::xs) gr =
    57           let
    58             val tvs = map DatatypeAux.dest_DtTFree dts;
    59             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
    60             val ((_, type_id), gr') = mk_type_id module' tname gr;
    61             val (ps, gr'') = gr' |>
    62               fold_map (fn (cname, cargs) =>
    63                 fold_map (invoke_tycodegen thy defs node_id module' false)
    64                   cargs ##>>
    65                 mk_const_id module' cname) cs';
    66             val (rest, gr''') = mk_dtdef "and " xs gr''
    67           in
    68             (Pretty.block (str prfx ::
    69                (if null tvs then [] else
    70                   [mk_tuple (map str tvs), str " "]) @
    71                [str (type_id ^ " ="), Pretty.brk 1] @
    72                List.concat (separate [Pretty.brk 1, str "| "]
    73                  (map (fn (ps', (_, cname)) => [Pretty.block
    74                    (str cname ::
    75                     (if null ps' then [] else
    76                      List.concat ([str " of", Pretty.brk 1] ::
    77                        separate [str " *", Pretty.brk 1]
    78                          (map single ps'))))]) ps))) :: rest, gr''')
    79           end;
    80 
    81     fun mk_constr_term cname Ts T ps =
    82       List.concat (separate [str " $", Pretty.brk 1]
    83         ([str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
    84           mk_type false (Ts ---> T), str ")"] :: ps));
    85 
    86     fun mk_term_of_def gr prfx [] = []
    87       | mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) =
    88           let
    89             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
    90             val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
    91             val T = Type (tname, dts');
    92             val rest = mk_term_of_def gr "and " xs;
    93             val (eqs, _) = fold_map (fn (cname, Ts) => fn prfx =>
    94               let val args = map (fn i =>
    95                 str ("x" ^ string_of_int i)) (1 upto length Ts)
    96               in (Pretty.blk (4,
    97                 [str prfx, mk_term_of gr module' false T, Pretty.brk 1,
    98                  if null Ts then str (snd (get_const_id gr cname))
    99                  else parens (Pretty.block
   100                    [str (snd (get_const_id gr cname)),
   101                     Pretty.brk 1, mk_tuple args]),
   102                  str " =", Pretty.brk 1] @
   103                  mk_constr_term cname Ts T
   104                    (map2 (fn x => fn U => [Pretty.block [mk_term_of gr module' false U,
   105                       Pretty.brk 1, x]]) args Ts)), "  | ")
   106               end) cs' prfx
   107           in eqs @ rest end;
   108 
   109     fun mk_gen_of_def gr prfx [] = []
   110       | mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) =
   111           let
   112             val tvs = map DatatypeAux.dest_DtTFree dts;
   113             val Us = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
   114             val T = Type (tname, Us);
   115             val (cs1, cs2) =
   116               List.partition (exists DatatypeAux.is_rec_type o snd) cs;
   117             val SOME (cname, _) = find_shortest_path descr i;
   118 
   119             fun mk_delay p = Pretty.block
   120               [str "fn () =>", Pretty.brk 1, p];
   121 
   122             fun mk_force p = Pretty.block [p, Pretty.brk 1, str "()"];
   123 
   124             fun mk_constr s b (cname, dts) =
   125               let
   126                 val gs = map (fn dt => mk_app false (mk_gen gr module' false rtnames s
   127                     (DatatypeAux.typ_of_dtyp descr sorts dt))
   128                   [str (if b andalso DatatypeAux.is_rec_type dt then "0"
   129                      else "j")]) dts;
   130                 val Ts = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
   131                 val xs = map str
   132                   (DatatypeProp.indexify_names (replicate (length dts) "x"));
   133                 val ts = map str
   134                   (DatatypeProp.indexify_names (replicate (length dts) "t"));
   135                 val (_, id) = get_const_id gr cname
   136               in
   137                 mk_let
   138                   (map2 (fn p => fn q => mk_tuple [p, q]) xs ts ~~ gs)
   139                   (mk_tuple
   140                     [case xs of
   141                        _ :: _ :: _ => Pretty.block
   142                          [str id, Pretty.brk 1, mk_tuple xs]
   143                      | _ => mk_app false (str id) xs,
   144                      mk_delay (Pretty.block (mk_constr_term cname Ts T
   145                        (map (single o mk_force) ts)))])
   146               end;
   147 
   148             fun mk_choice [c] = mk_constr "(i-1)" false c
   149               | mk_choice cs = Pretty.block [str "one_of",
   150                   Pretty.brk 1, Pretty.blk (1, str "[" ::
   151                   List.concat (separate [str ",", Pretty.fbrk]
   152                     (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @
   153                   [str "]"]), Pretty.brk 1, str "()"];
   154 
   155             val gs = maps (fn s =>
   156               let val s' = strip_tname s
   157               in [str (s' ^ "G"), str (s' ^ "T")] end) tvs;
   158             val gen_name = "gen_" ^ snd (get_type_id gr tname)
   159 
   160           in
   161             Pretty.blk (4, separate (Pretty.brk 1) 
   162                 (str (prfx ^ gen_name ^
   163                    (if null cs1 then "" else "'")) :: gs @
   164                  (if null cs1 then [] else [str "i"]) @
   165                  [str "j"]) @
   166               [str " =", Pretty.brk 1] @
   167               (if not (null cs1) andalso not (null cs2)
   168                then [str "frequency", Pretty.brk 1,
   169                  Pretty.blk (1, [str "[",
   170                    mk_tuple [str "i", mk_delay (mk_choice cs1)],
   171                    str ",", Pretty.fbrk,
   172                    mk_tuple [str "1", mk_delay (mk_choice cs2)],
   173                    str "]"]), Pretty.brk 1, str "()"]
   174                else if null cs2 then
   175                  [Pretty.block [str "(case", Pretty.brk 1,
   176                    str "i", Pretty.brk 1, str "of",
   177                    Pretty.brk 1, str "0 =>", Pretty.brk 1,
   178                    mk_constr "0" true (cname, valOf (AList.lookup (op =) cs cname)),
   179                    Pretty.brk 1, str "| _ =>", Pretty.brk 1,
   180                    mk_choice cs1, str ")"]]
   181                else [mk_choice cs2])) ::
   182             (if null cs1 then []
   183              else [Pretty.blk (4, separate (Pretty.brk 1) 
   184                  (str ("and " ^ gen_name) :: gs @ [str "i"]) @
   185                [str " =", Pretty.brk 1] @
   186                separate (Pretty.brk 1) (str (gen_name ^ "'") :: gs @
   187                  [str "i", str "i"]))]) @
   188             mk_gen_of_def gr "and " xs
   189           end
   190 
   191   in
   192     (module', (add_edge_acyclic (node_id, dep) gr
   193         handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ =>
   194          let
   195            val gr1 = add_edge (node_id, dep)
   196              (new_node (node_id, (NONE, "", "")) gr);
   197            val (dtdef, gr2) = mk_dtdef "datatype " descr' gr1 ;
   198          in
   199            map_node node_id (K (NONE, module',
   200              string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @
   201                [str ";"])) ^ "\n\n" ^
   202              (if "term_of" mem !mode then
   203                 string_of (Pretty.blk (0, separate Pretty.fbrk
   204                   (mk_term_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n"
   205               else "") ^
   206              (if "test" mem !mode then
   207                 string_of (Pretty.blk (0, separate Pretty.fbrk
   208                   (mk_gen_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n"
   209               else ""))) gr2
   210          end)
   211   end;
   212 
   213 
   214 (* case expressions *)
   215 
   216 fun pretty_case thy defs dep module brack constrs (c as Const (_, T)) ts gr =
   217   let val i = length constrs
   218   in if length ts <= i then
   219        invoke_codegen thy defs dep module brack (eta_expand c ts (i+1)) gr
   220     else
   221       let
   222         val ts1 = Library.take (i, ts);
   223         val t :: ts2 = Library.drop (i, ts);
   224         val names = List.foldr OldTerm.add_term_names
   225           (map (fst o fst o dest_Var) (List.foldr OldTerm.add_term_vars [] ts1)) ts1;
   226         val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T)));
   227 
   228         fun pcase [] [] [] gr = ([], gr)
   229           | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr =
   230               let
   231                 val j = length cargs;
   232                 val xs = Name.variant_list names (replicate j "x");
   233                 val Us' = Library.take (j, fst (strip_type U));
   234                 val frees = map Free (xs ~~ Us');
   235                 val (cp, gr0) = invoke_codegen thy defs dep module false
   236                   (list_comb (Const (cname, Us' ---> dT), frees)) gr;
   237                 val t' = Envir.beta_norm (list_comb (t, frees));
   238                 val (p, gr1) = invoke_codegen thy defs dep module false t' gr0;
   239                 val (ps, gr2) = pcase cs ts Us gr1;
   240               in
   241                 ([Pretty.block [cp, str " =>", Pretty.brk 1, p]] :: ps, gr2)
   242               end;
   243 
   244         val (ps1, gr1) = pcase constrs ts1 Ts gr ;
   245         val ps = List.concat (separate [Pretty.brk 1, str "| "] ps1);
   246         val (p, gr2) = invoke_codegen thy defs dep module false t gr1;
   247         val (ps2, gr3) = fold_map (invoke_codegen thy defs dep module true) ts2 gr2;
   248       in ((if not (null ts2) andalso brack then parens else I)
   249         (Pretty.block (separate (Pretty.brk 1)
   250           (Pretty.block ([str "(case ", p, str " of",
   251              Pretty.brk 1] @ ps @ [str ")"]) :: ps2))), gr3)
   252       end
   253   end;
   254 
   255 
   256 (* constructors *)
   257 
   258 fun pretty_constr thy defs dep module brack args (c as Const (s, T)) ts gr =
   259   let val i = length args
   260   in if i > 1 andalso length ts < i then
   261       invoke_codegen thy defs dep module brack (eta_expand c ts i) gr
   262      else
   263        let
   264          val id = mk_qual_id module (get_const_id gr s);
   265          val (ps, gr') = fold_map
   266            (invoke_codegen thy defs dep module (i = 1)) ts gr;
   267        in (case args of
   268           _ :: _ :: _ => (if brack then parens else I)
   269             (Pretty.block [str id, Pretty.brk 1, mk_tuple ps])
   270         | _ => (mk_app brack (str id) ps), gr')
   271        end
   272   end;
   273 
   274 
   275 (* code generators for terms and types *)
   276 
   277 fun datatype_codegen thy defs dep module brack t gr = (case strip_comb t of
   278    (c as Const (s, T), ts) =>
   279      (case Datatype.info_of_case thy s of
   280         SOME {index, descr, ...} =>
   281           if is_some (get_assoc_code thy (s, T)) then NONE else
   282           SOME (pretty_case thy defs dep module brack
   283             (#3 (the (AList.lookup op = descr index))) c ts gr )
   284       | NONE => case (Datatype.info_of_constr thy s, strip_type T) of
   285         (SOME {index, descr, ...}, (_, U as Type (tyname, _))) =>
   286           if is_some (get_assoc_code thy (s, T)) then NONE else
   287           let
   288             val SOME (tyname', _, constrs) = AList.lookup op = descr index;
   289             val SOME args = AList.lookup op = constrs s
   290           in
   291             if tyname <> tyname' then NONE
   292             else SOME (pretty_constr thy defs
   293               dep module brack args c ts (snd (invoke_tycodegen thy defs dep module false U gr)))
   294           end
   295       | _ => NONE)
   296  | _ => NONE);
   297 
   298 fun datatype_tycodegen thy defs dep module brack (Type (s, Ts)) gr =
   299       (case Datatype.get_info thy s of
   300          NONE => NONE
   301        | SOME {descr, sorts, ...} =>
   302            if is_some (get_assoc_type thy s) then NONE else
   303            let
   304              val (ps, gr') = fold_map
   305                (invoke_tycodegen thy defs dep module false) Ts gr;
   306              val (module', gr'') = add_dt_defs thy defs dep module descr sorts gr' ;
   307              val (tyid, gr''') = mk_type_id module' s gr''
   308            in SOME (Pretty.block ((if null Ts then [] else
   309                [mk_tuple ps, str " "]) @
   310                [str (mk_qual_id module tyid)]), gr''')
   311            end)
   312   | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
   313 
   314 
   315 (** generic code generator **)
   316 
   317 (* liberal addition of code data for datatypes *)
   318 
   319 fun mk_constr_consts thy vs dtco cos =
   320   let
   321     val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
   322     val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs;
   323   in if is_some (try (Code.constrset_of_consts thy) cs')
   324     then SOME cs
   325     else NONE
   326   end;
   327 
   328 
   329 (* case certificates *)
   330 
   331 fun mk_case_cert thy tyco =
   332   let
   333     val raw_thms =
   334       (#case_rewrites o Datatype.the_info thy) tyco;
   335     val thms as hd_thm :: _ = raw_thms
   336       |> Conjunction.intr_balanced
   337       |> Thm.unvarify
   338       |> Conjunction.elim_balanced (length raw_thms)
   339       |> map Simpdata.mk_meta_eq
   340       |> map Drule.zero_var_indexes
   341     val params = fold_aterms (fn (Free (v, _)) => insert (op =) v
   342       | _ => I) (Thm.prop_of hd_thm) [];
   343     val rhs = hd_thm
   344       |> Thm.prop_of
   345       |> Logic.dest_equals
   346       |> fst
   347       |> Term.strip_comb
   348       |> apsnd (fst o split_last)
   349       |> list_comb;
   350     val lhs = Free (Name.variant params "case", Term.fastype_of rhs);
   351     val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs);
   352   in
   353     thms
   354     |> Conjunction.intr_balanced
   355     |> MetaSimplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm]
   356     |> Thm.implies_intr asm
   357     |> Thm.generalize ([], params) 0
   358     |> AxClass.unoverload thy
   359     |> Thm.varifyT
   360   end;
   361 
   362 
   363 (* equality *)
   364 
   365 fun mk_eq_eqns thy dtco =
   366   let
   367     val (vs, cos) = Datatype.the_spec thy dtco;
   368     val { descr, index, inject = inject_thms, ... } = Datatype.the_info thy dtco;
   369     val ty = Type (dtco, map TFree vs);
   370     fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT)
   371       $ t1 $ t2;
   372     fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const);
   373     fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const);
   374     val triv_injects = map_filter
   375      (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
   376        | _ => NONE) cos;
   377     fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
   378       trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
   379     val injects = map prep_inject (nth (DatatypeProp.make_injs [descr] vs) index);
   380     fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
   381       [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
   382     val distincts = maps prep_distinct (snd (nth (DatatypeProp.make_distincts [descr] vs) index));
   383     val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
   384     val simpset = Simplifier.context (ProofContext.init thy) (HOL_basic_ss
   385       addsimps (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms))
   386       addsimprocs [Datatype.distinct_simproc]);
   387     fun prove prop = SkipProof.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
   388       |> Simpdata.mk_eq;
   389   in map (rpair true o prove) (triv_injects @ injects @ distincts) @ [(prove refl, false)] end;
   390 
   391 fun add_equality vs dtcos thy =
   392   let
   393     fun add_def dtco lthy =
   394       let
   395         val ty = Type (dtco, map TFree vs);
   396         fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
   397           $ Free ("x", ty) $ Free ("y", ty);
   398         val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
   399           (mk_side @{const_name eq_class.eq}, mk_side @{const_name "op ="}));
   400         val def' = Syntax.check_term lthy def;
   401         val ((_, (_, thm)), lthy') = Specification.definition
   402           (NONE, (Attrib.empty_binding, def')) lthy;
   403         val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy);
   404         val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
   405       in (thm', lthy') end;
   406     fun tac thms = Class.intro_classes_tac []
   407       THEN ALLGOALS (ProofContext.fact_tac thms);
   408     fun add_eq_thms dtco thy =
   409       let
   410         val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco);
   411         val thy_ref = Theory.check_thy thy;
   412         fun mk_thms () = rev ((mk_eq_eqns (Theory.deref thy_ref) dtco));
   413       in
   414         Code.add_eqnl (const, Lazy.lazy mk_thms) thy
   415       end;
   416   in
   417     thy
   418     |> TheoryTarget.instantiation (dtcos, vs, [HOLogic.class_eq])
   419     |> fold_map add_def dtcos
   420     |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
   421          (fn _ => fn def_thms => tac def_thms) def_thms)
   422     |-> (fn def_thms => fold Code.del_eqn def_thms)
   423     |> fold add_eq_thms dtcos
   424   end;
   425 
   426 
   427 (* register a datatype etc. *)
   428 
   429 fun add_all_code config dtcos thy =
   430   let
   431     val (vs :: _, coss) = (split_list o map (Datatype.the_spec thy)) dtcos;
   432     val any_css = map2 (mk_constr_consts thy vs) dtcos coss;
   433     val css = if exists is_none any_css then []
   434       else map_filter I any_css;
   435     val case_rewrites = maps (#case_rewrites o Datatype.the_info thy) dtcos;
   436     val certs = map (mk_case_cert thy) dtcos;
   437   in
   438     if null css then thy
   439     else thy
   440       |> tap (fn _ => DatatypeAux.message config "Registering datatype for code generator ...")
   441       |> fold Code.add_datatype css
   442       |> fold_rev Code.add_default_eqn case_rewrites
   443       |> fold Code.add_case certs
   444       |> add_equality vs dtcos
   445    end;
   446 
   447 
   448 (** theory setup **)
   449 
   450 val setup = 
   451   add_codegen "datatype" datatype_codegen
   452   #> add_tycodegen "datatype" datatype_tycodegen
   453   #> Datatype.interpretation add_all_code
   454 
   455 end;