src/HOL/Matrix/eq_codegen.ML
changeset 15531 08c8dad8e399
parent 15178 5f621aa35c25
child 15570 8d8c70b41bab
equal deleted inserted replaced
15530:6f43714517ee 15531:08c8dad8e399
    10 fun my_mk_meta_eq thm =
    10 fun my_mk_meta_eq thm =
    11   let
    11   let
    12     val (_, eq) = Thm.dest_comb (cprop_of thm);
    12     val (_, eq) = Thm.dest_comb (cprop_of thm);
    13     val (ct, rhs) = Thm.dest_comb eq;
    13     val (ct, rhs) = Thm.dest_comb eq;
    14     val (_, lhs) = Thm.dest_comb ct
    14     val (_, lhs) = Thm.dest_comb ct
    15   in Thm.implies_elim (Drule.instantiate' [Some (ctyp_of_term lhs)]
    15   in Thm.implies_elim (Drule.instantiate' [SOME (ctyp_of_term lhs)]
    16     [Some lhs, Some rhs] eq_reflection) thm
    16     [SOME lhs, SOME rhs] eq_reflection) thm
    17   end; 
    17   end; 
    18 
    18 
    19 structure SimprocsCodegen =
    19 structure SimprocsCodegen =
    20 struct
    20 struct
    21 
    21 
    49         val v2 = variant (v1 :: vs) (mk_decomp_name t)
    49         val v2 = variant (v1 :: vs) (mk_decomp_name t)
    50       in
    50       in
    51         decomp_term_code cn ((v1 :: v2 :: vs,
    51         decomp_term_code cn ((v1 :: v2 :: vs,
    52           bs @ [(Free (s, T), v1)],
    52           bs @ [(Free (s, T), v1)],
    53           ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_abs", Pretty.brk 1,
    53           ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_abs", Pretty.brk 1,
    54             Pretty.str "None", Pretty.brk 1, Pretty.str v]]), (v2, t))
    54             Pretty.str "NONE", Pretty.brk 1, Pretty.str v]]), (v2, t))
    55       end
    55       end
    56     | t $ u =>
    56     | t $ u =>
    57       let
    57       let
    58         val v1 = variant vs (mk_decomp_name t);
    58         val v1 = variant vs (mk_decomp_name t);
    59         val v2 = variant (v1 :: vs) (mk_decomp_name u);
    59         val v2 = variant (v1 :: vs) (mk_decomp_name u);
   141           p1, Pretty.brk 1, p2])
   141           p1, Pretty.brk 1, p2])
   142       end
   142       end
   143   | mk_cterm_code b ty_bs ts xs (vals, Abs (s, T, t)) =
   143   | mk_cterm_code b ty_bs ts xs (vals, Abs (s, T, t)) =
   144       let
   144       let
   145         val u = Free (s, T);
   145         val u = Free (s, T);
   146         val Some s' = assoc (ts, u);
   146         val SOME s' = assoc (ts, u);
   147         val p = Pretty.str s';
   147         val p = Pretty.str s';
   148         val (vals', p') = mk_cterm_code true ty_bs ts (p :: xs)
   148         val (vals', p') = mk_cterm_code true ty_bs ts (p :: xs)
   149           (if null (typ_tvars T) then vals
   149           (if null (typ_tvars T) then vals
   150            else vals @ [(u, (("", s'), [mk_val [s'] [inst_ty true ty_bs u s']]))], t)
   150            else vals @ [(u, (("", s'), [mk_val [s'] [inst_ty true ty_bs u s']]))], t)
   151       in (vals',
   151       in (vals',
   152         parens b [Pretty.str "Thm.cabs", Pretty.brk 1, p, Pretty.brk 1, p'])
   152         parens b [Pretty.str "Thm.cabs", Pretty.brk 1, p, Pretty.brk 1, p'])
   153       end
   153       end
   154   | mk_cterm_code b ty_bs ts xs (vals, Bound i) = (vals, nth_elem (i, xs))
   154   | mk_cterm_code b ty_bs ts xs (vals, Bound i) = (vals, nth_elem (i, xs))
   155   | mk_cterm_code b ty_bs ts xs (vals, t) = (case assoc (vals, t) of
   155   | mk_cterm_code b ty_bs ts xs (vals, t) = (case assoc (vals, t) of
   156         None =>
   156         NONE =>
   157           let val Some s = assoc (ts, t)
   157           let val SOME s = assoc (ts, t)
   158           in (if is_Const t andalso not (null (term_tvars t)) then
   158           in (if is_Const t andalso not (null (term_tvars t)) then
   159               vals @ [(t, (("", s), [mk_val [s] [inst_ty true ty_bs t s]]))]
   159               vals @ [(t, (("", s), [mk_val [s] [inst_ty true ty_bs t s]]))]
   160             else vals, Pretty.str s)
   160             else vals, Pretty.str s)
   161           end
   161           end
   162       | Some ((_, s), _) => (vals, Pretty.str s));
   162       | SOME ((_, s), _) => (vals, Pretty.str s));
   163 
   163 
   164 fun get_cases sg =
   164 fun get_cases sg =
   165   Symtab.foldl (fn (tab, (k, {case_rewrites, ...})) => Symtab.update_new
   165   Symtab.foldl (fn (tab, (k, {case_rewrites, ...})) => Symtab.update_new
   166     ((fst (dest_Const (head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop
   166     ((fst (dest_Const (head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop
   167       (prop_of (hd case_rewrites))))))), map my_mk_meta_eq case_rewrites), tab))
   167       (prop_of (hd case_rewrites))))))), map my_mk_meta_eq case_rewrites), tab))
   193 fun lookup sg fs t = apsome snd (Library.find_first
   193 fun lookup sg fs t = apsome snd (Library.find_first
   194   (is_instance sg t o fst) fs);
   194   (is_instance sg t o fst) fs);
   195 *)
   195 *)
   196 
   196 
   197 fun lookup sg fs t = (case Library.find_first (is_instance sg t o fst) fs of
   197 fun lookup sg fs t = (case Library.find_first (is_instance sg t o fst) fs of
   198     None => (bla := (t ins !bla); None)
   198     NONE => (bla := (t ins !bla); NONE)
   199   | Some (_, x) => Some x);
   199   | SOME (_, x) => SOME x);
   200 
   200 
   201 fun unint sg fs t = forall (is_none o lookup sg fs) (term_consts' t);
   201 fun unint sg fs t = forall (is_none o lookup sg fs) (term_consts' t);
   202 
   202 
   203 fun mk_let s i xs ys =
   203 fun mk_let s i xs ys =
   204   Pretty.blk (0, [Pretty.blk (i, separate Pretty.fbrk (Pretty.str s :: xs)),
   204   Pretty.blk (0, [Pretty.blk (i, separate Pretty.fbrk (Pretty.str s :: xs)),
   218 (*       ps: ML code for creating the above two items                        *)
   218 (*       ps: ML code for creating the above two items                        *)
   219 (*****************************************************************************)
   219 (*****************************************************************************)
   220 
   220 
   221 fun mk_simpl_code sg case_tab mkeq fs ts ty_bs thm_bs ((vs, vals), t) =
   221 fun mk_simpl_code sg case_tab mkeq fs ts ty_bs thm_bs ((vs, vals), t) =
   222   (case assoc (vals, t) of
   222   (case assoc (vals, t) of
   223     Some ((eq, ct), ps) =>  (* binding already generated *) 
   223     SOME ((eq, ct), ps) =>  (* binding already generated *) 
   224       if mkeq andalso eq="" then
   224       if mkeq andalso eq="" then
   225         let val eq' = variant vs "eq"
   225         let val eq' = variant vs "eq"
   226         in ((eq' :: vs, overwrite (vals,
   226         in ((eq' :: vs, overwrite (vals,
   227           (t, ((eq', ct), ps @ [mk_refleq eq' ct])))), (eq', ct))
   227           (t, ((eq', ct), ps @ [mk_refleq eq' ct])))), (eq', ct))
   228         end
   228         end
   229       else ((vs, vals), (eq, ct))
   229       else ((vs, vals), (eq, ct))
   230   | None => (case assoc (ts, t) of
   230   | NONE => (case assoc (ts, t) of
   231       Some v =>  (* atomic term *)
   231       SOME v =>  (* atomic term *)
   232         let val xs = if not (null (term_tvars t)) andalso is_Const t then
   232         let val xs = if not (null (term_tvars t)) andalso is_Const t then
   233           [mk_val [v] [inst_ty false ty_bs t v]] else []
   233           [mk_val [v] [inst_ty false ty_bs t v]] else []
   234         in
   234         in
   235           if mkeq then
   235           if mkeq then
   236             let val eq = variant vs "eq"
   236             let val eq = variant vs "eq"
   238               [(t, ((eq, v), xs @ [mk_refleq eq v]))]), (eq, v))
   238               [(t, ((eq, v), xs @ [mk_refleq eq v]))]), (eq, v))
   239             end
   239             end
   240           else ((vs, if null xs then vals else vals @
   240           else ((vs, if null xs then vals else vals @
   241             [(t, (("", v), xs))]), ("", v))
   241             [(t, (("", v), xs))]), ("", v))
   242         end
   242         end
   243     | None =>  (* complex term *)
   243     | NONE =>  (* complex term *)
   244         let val (f as Const (cname, _), us) = strip_comb t
   244         let val (f as Const (cname, _), us) = strip_comb t
   245         in case Symtab.lookup (case_tab, cname) of
   245         in case Symtab.lookup (case_tab, cname) of
   246             Some cases =>  (* case expression *)
   246             SOME cases =>  (* case expression *)
   247               let
   247               let
   248                 val (us', u) = split_last us;
   248                 val (us', u) = split_last us;
   249                 val b = unint sg fs u;
   249                 val b = unint sg fs u;
   250                 val ((vs1, vals1), (eq, ct)) =
   250                 val ((vs1, vals1), (eq, ct)) =
   251                   mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs ((vs, vals), u);
   251                   mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs ((vs, vals), u);
   327                           Pretty.str ")", Pretty.brk 1, Pretty.str (eq ^ ")"),
   327                           Pretty.str ")", Pretty.brk 1, Pretty.str (eq ^ ")"),
   328                           Pretty.brk 1, Pretty.str eq']]))]), (eq'', ct'))
   328                           Pretty.brk 1, Pretty.str eq']]))]), (eq'', ct'))
   329                   end
   329                   end
   330               end
   330               end
   331           
   331           
   332           | None =>
   332           | NONE =>
   333             let
   333             let
   334               val b = forall (unint sg fs) us;
   334               val b = forall (unint sg fs) us;
   335               val (q, eqs) = foldl_map
   335               val (q, eqs) = foldl_map
   336                 (mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs) ((vs, vals), us);
   336                 (mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs) ((vs, vals), us);
   337               val ((vs', vals'), (eqf, ctf)) = if is_some (lookup sg fs f) andalso b
   337               val ((vs', vals'), (eqf, ctf)) = if is_some (lookup sg fs f) andalso b
   343                 (Pretty.str ctf) (map (Pretty.str o snd) eqs)];
   343                 (Pretty.str ctf) (map (Pretty.str o snd) eqs)];
   344               fun combp b = mk_apps "Thm.combination" b
   344               fun combp b = mk_apps "Thm.combination" b
   345                 (Pretty.str eqf) (map (Pretty.str o fst) eqs)
   345                 (Pretty.str eqf) (map (Pretty.str o fst) eqs)
   346             in
   346             in
   347               case (lookup sg fs f, b) of
   347               case (lookup sg fs f, b) of
   348                 (None, true) =>  (* completely uninterpreted *)
   348                 (NONE, true) =>  (* completely uninterpreted *)
   349                   if mkeq then ((ct :: eq :: vs', vals' @
   349                   if mkeq then ((ct :: eq :: vs', vals' @
   350                     [(t, ((eq, ct), [ctv, mk_refleq eq ct]))]), (eq, ct))
   350                     [(t, ((eq, ct), [ctv, mk_refleq eq ct]))]), (eq, ct))
   351                   else ((ct :: vs', vals' @ [(t, (("", ct), [ctv]))]), ("", ct))
   351                   else ((ct :: vs', vals' @ [(t, (("", ct), [ctv]))]), ("", ct))
   352               | (None, false) =>  (* function uninterpreted *)
   352               | (NONE, false) =>  (* function uninterpreted *)
   353                   ((eq :: ct :: vs', vals' @
   353                   ((eq :: ct :: vs', vals' @
   354                      [(t, ((eq, ct), [ctv, mk_val [eq] [combp false]]))]), (eq, ct))
   354                      [(t, ((eq, ct), [ctv, mk_val [eq] [combp false]]))]), (eq, ct))
   355               | (Some (s, _, _), true) =>  (* arguments uninterpreted *)
   355               | (SOME (s, _, _), true) =>  (* arguments uninterpreted *)
   356                   ((eq :: ct :: vs', vals' @
   356                   ((eq :: ct :: vs', vals' @
   357                      [(t, ((eq, ct), [mk_val [ct, eq] (separate (Pretty.brk 1)
   357                      [(t, ((eq, ct), [mk_val [ct, eq] (separate (Pretty.brk 1)
   358                        (Pretty.str s :: map (Pretty.str o snd) eqs))]))]), (eq, ct))
   358                        (Pretty.str s :: map (Pretty.str o snd) eqs))]))]), (eq, ct))
   359               | (Some (s, _, _), false) =>  (* function and arguments interpreted *)
   359               | (SOME (s, _, _), false) =>  (* function and arguments interpreted *)
   360                   let val eq' = variant (eq :: ct :: vs') "eq"
   360                   let val eq' = variant (eq :: ct :: vs') "eq"
   361                   in ((eq' :: eq :: ct :: vs', vals' @ [(t, ((eq', ct),
   361                   in ((eq' :: eq :: ct :: vs', vals' @ [(t, ((eq', ct),
   362                     [mk_val [ct, eq] (separate (Pretty.brk 1)
   362                     [mk_val [ct, eq] (separate (Pretty.brk 1)
   363                        (Pretty.str s :: map (Pretty.str o snd) eqs)),
   363                        (Pretty.str s :: map (Pretty.str o snd) eqs)),
   364                      mk_val [eq'] [Pretty.str "Thm.transitive", Pretty.brk 1,
   364                      mk_val [eq'] [Pretty.str "Thm.transitive", Pretty.brk 1,
   371 fun rhs_of thm = snd (Logic.dest_equals (prop_of thm));
   371 fun rhs_of thm = snd (Logic.dest_equals (prop_of thm));
   372 
   372 
   373 fun mk_funs_code sg case_tab fs fs' =
   373 fun mk_funs_code sg case_tab fs fs' =
   374   let
   374   let
   375     val case_thms = mapfilter (fn s => (case Symtab.lookup (case_tab, s) of
   375     val case_thms = mapfilter (fn s => (case Symtab.lookup (case_tab, s) of
   376         None => None
   376         NONE => NONE
   377       | Some thms => Some (unsuffix "_case" (Sign.base_name s) ^ ".cases",
   377       | SOME thms => SOME (unsuffix "_case" (Sign.base_name s) ^ ".cases",
   378           map (fn i => Sign.base_name s ^ "_" ^ string_of_int i)
   378           map (fn i => Sign.base_name s ^ "_" ^ string_of_int i)
   379             (1 upto length thms) ~~ thms)))
   379             (1 upto length thms) ~~ thms)))
   380       (foldr add_term_consts (map (prop_of o snd)
   380       (foldr add_term_consts (map (prop_of o snd)
   381         (flat (map (#3 o snd) fs')), []));
   381         (flat (map (#3 o snd) fs')), []));
   382     val case_vals = map (fn (s, cs) => mk_vall (map fst cs)
   382     val case_vals = map (fn (s, cs) => mk_vall (map fst cs)
   421                 [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1, eq'', Pretty.str ")"]]
   421                 [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1, eq'', Pretty.str ")"]]
   422           end;
   422           end;
   423 
   423 
   424         val default = if d then
   424         val default = if d then
   425             let
   425             let
   426               val Some s = assoc (thm_bs, f);
   426               val SOME s = assoc (thm_bs, f);
   427               val ct = variant vs' "ct"
   427               val ct = variant vs' "ct"
   428             in [Pretty.brk 1, Pretty.str "handle", Pretty.brk 1,
   428             in [Pretty.brk 1, Pretty.str "handle", Pretty.brk 1,
   429               Pretty.str "Match =>", Pretty.brk 1, mk_let "let" 2
   429               Pretty.str "Match =>", Pretty.brk 1, mk_let "let" 2
   430                 (ty_vals @ (if null (term_tvars f) then [] else
   430                 (ty_vals @ (if null (term_tvars f) then [] else
   431                    [mk_val [s] [inst_ty false ty_bs f s]]) @
   431                    [mk_val [s] [inst_ty false ty_bs f s]]) @