src/Pure/Tools/codegen_func.ML
changeset 22737 d87ccbcc2702
parent 22705 6199df39688d
child 22804 d3c23b90c6c6
equal deleted inserted replaced
22736:4948e2bd67e5 22737:d87ccbcc2702
     6 *)
     6 *)
     7 
     7 
     8 signature CODEGEN_FUNC =
     8 signature CODEGEN_FUNC =
       (* Exported interface of structure CodegenFunc.  Every numbered line
          occurs in pairs: the first line of a pair is the parent revision,
          the second line is this changeset. *)
     9 sig
     9 sig
    10   val assert_rew: thm -> thm
    10   val assert_rew: thm -> thm
       (* mk_rew yields a list of theorems in the parent revision and a
          single theorem in this changeset. *)
    11   val mk_rew: thm -> thm list
    11   val mk_rew: thm -> thm
       (* Lines 12-16 differ between the revisions.  The parent exports
          assert_func, mk_func (both taking a bool first), mk_head, dest_func
          and typ_func; this changeset exports mk_func (without the bool),
          head_func, bad_thm, error_thm and warning_thm instead. *)
    12   val assert_func: bool -> thm -> thm option
    12   val mk_func: thm -> thm
    13   val mk_func: bool -> thm -> (CodegenConsts.const * thm) list
    13   val head_func: thm -> CodegenConsts.const * typ
    14   val mk_head: thm -> CodegenConsts.const * thm
    14   val bad_thm: string -> 'a
    15   val dest_func: thm -> (string * typ) * term list
    15   val error_thm: (thm -> thm) -> thm -> thm
    16   val typ_func: thm -> typ
    16   val warning_thm: (thm -> thm) -> thm -> thm option
    17 
    17 
       (* utilities; identical in both revisions *)
    18   val inst_thm: sort Vartab.table -> thm -> thm
    18   val inst_thm: sort Vartab.table -> thm -> thm
    19   val expand_eta: int -> thm -> thm
    19   val expand_eta: int -> thm -> thm
    20   val rewrite_func: thm list -> thm -> thm
    20   val rewrite_func: thm list -> thm -> thm
    21   val norm_args: thm list -> thm list 
    21   val norm_args: thm list -> thm list 
    23 end;
    23 end;
    24 
    24 
    25 structure CodegenFunc : CODEGEN_FUNC =
    25 structure CodegenFunc : CODEGEN_FUNC =
    26 struct
    26 struct
    27 
    27 
    28 fun lift_thm_thy f thm = f (Thm.theory_of_thm thm) thm;  (* parent revision only: call f with the theory of thm, then with thm itself *)
    28 
    29 
    29 (* auxiliary *)
    30 fun bad_thm msg thm =
    30 
       (* Parent revision: bad_thm takes the message and the theorem, and
          aborts at once through error with the printed theorem appended. *)
    31   error (msg ^ ": " ^ string_of_thm thm);
    31 exception BAD_THM of string;
       
       (* This changeset: bad_thm takes the message only and raises BAD_THM;
          the caller decides how the failure is reported. *)
    32 fun bad_thm msg = raise BAD_THM msg;
       
       (* error_thm: run f on thm; a BAD_THM raised by f is turned into error
          with the same message. *)
    33 fun error_thm f thm = f thm handle BAD_THM msg => error msg;
       
       (* warning_thm: run f on thm and wrap the result in SOME; a BAD_THM
          raised by f is reported through warning and the result is NONE. *)
    34 fun warning_thm f thm = SOME (f thm) handle BAD_THM msg
       
    35   => (warning msg; NONE);
    32 
    36 
    33 
    37 
    34 (* making rewrite theorems *)
    38 (* making rewrite theorems *)
    35 
    39 
    36 fun assert_rew thm =
    40 fun assert_rew thm =
       (* assert_rew: return thm unchanged when it is acceptable as a rewrite
          theorem, otherwise fail through bad_thm.  Checks performed:
            - the proposition is a meta-level equation (in this changeset a
              TERM exception from Logic.dest_equals is reported via bad_thm);
            - neither side contains a Free or a TFree;
            - every Var collected from the right hand side also occurs on
              the left hand side;
            - every TVar collected from the right hand side also occurs on
              the left hand side.
          In the parent revision bad_thm receives the message and the
          theorem; in this changeset the printed theorem is part of the
          message. *)
    37   let
    41   let
    38     val thy = Thm.theory_of_thm thm;
    42     val thy = Thm.theory_of_thm thm;
    39     val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm;
    43     val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm
       
    44       handle TERM _ => bad_thm ("Not an equation: " ^ Display.string_of_thm thm);
       (* vars_of: names of the schematic variables of t, without duplicates;
          any Free is rejected while collecting. *)
    40     fun vars_of t = fold_aterms
    45     fun vars_of t = fold_aterms
    41      (fn Var (v, _) => insert (op =) v
    46      (fn Var (v, _) => insert (op =) v
    42        | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm
    47        | Free _ => bad_thm ("Illegal free variable in rewrite theorem\n"
       
    48            ^ Display.string_of_thm thm)
    43        | _ => I) t [];
    49        | _ => I) t [];
       (* tvars_of: names of the schematic type variables in the types of t,
          without duplicates; any TFree is rejected while collecting. *)
    44     fun tvars_of t = fold_term_types
    50     fun tvars_of t = fold_term_types
    45      (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v
    51      (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v
    46                           | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t [];
    52                           | TFree _ => bad_thm 
       
    53       ("Illegal free type variable in rewrite theorem\n" ^ Display.string_of_thm thm))) t [];
    47     val lhs_vs = vars_of lhs;
    54     val lhs_vs = vars_of lhs;
    48     val rhs_vs = vars_of rhs;
    55     val rhs_vs = vars_of rhs;
    49     val lhs_tvs = tvars_of lhs;
    56     val lhs_tvs = tvars_of lhs;
       (* The type variables must be taken from rhs here: collecting them from
          lhs a second time makes the subtraction below always empty, so the
          type-variable check could never fail. *)
    50     val rhs_tvs = tvars_of rhs;
    57     val rhs_tvs = tvars_of rhs;
       (* subtract removes the left hand side names from the right hand side
          names; anything left over occurs on the right hand side only. *)
    51     val _ = if null (subtract (op =) lhs_vs rhs_vs)
    58     val _ = if null (subtract (op =) lhs_vs rhs_vs)
    52       then ()
    59       then ()
    53       else bad_thm "Free variables on right hand side of rewrite theorems" thm
    60       else bad_thm ("Free variables on right hand side of rewrite theorem\n"
       
    61         ^ Display.string_of_thm thm);
    54     val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
    62     val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
    55       then ()
    63       then ()
    56       else bad_thm "Free type variables on right hand side of rewrite theorems" thm
    64       else bad_thm ("Free type variables on right hand side of rewrite theorem\n"
       
    65         ^ Display.string_of_thm thm)
    57   in thm end;
    66   in thm end;
    58 
    67 
    59 fun mk_rew thm =
    68 fun mk_rew thm =
       (* mk_rew: turn thm into checked rewrite theorem(s).
          Parent revision: applies the mk function found in the mk_rews
          component of the theory's simpset to thm and checks every resulting
          theorem with assert_rew, returning the list.
          This changeset: passes thm through LocalDefs.meta_rewrite_rule in a
          proof context initialised from the theory of thm, then through
          assert_rew, returning a single theorem. *)
    60   let
    69   let
    61     val thy = Thm.theory_of_thm thm;
    70     val thy = Thm.theory_of_thm thm;
    62     val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm;
    71     val ctxt = ProofContext.init thy;
    63   in
    72   in
    64     map assert_rew thms
    73     thm
       
    74     |> LocalDefs.meta_rewrite_rule ctxt
       
    75     |> assert_rew
    65   end;
    76   end;
    66 
    77 
    67 
    78 
    68 (* making defining equations *)
    79 (* making defining equations *)
    69 
    80 
    70 val typ_func = lift_thm_thy (fn thy => snd o dest_Const o fst o strip_comb
    81 fun assert_func thm =
       (* The two revisions are interleaved line by line in this span.

          This changeset (second line of each pair, lines 81-118):
            assert_func thm - takes the arguments of the head of the left
              hand side of the equation and returns thm unchanged if they
              pass, otherwise fails through bad_thm.  Rejected are: a Var
              occurring more than once among the arguments, any Abs, a Var
              applied to arguments, and a Const applied to a number of
              arguments different from the number of argument types of its
              type.
            mk_func - mk_rew, then beta-eta conversion, then assert_func.
            head_func thm - the head of the left hand side must match
              Const; returns the result of CodegenConsts.const_of_cexpr for
              it together with its type.

          Parent revision (first line of each pair, lines 70-117):
            typ_func - type of the head constant of the left hand side.
            dest_func - head constant as (name, type) plus the argument
              list, taken after beta-eta conversion.
            mk_head - pairs the CodegenConsts constant of the head with thm.
            assert_func strict thm - the same argument checks, reported via
              the local exception BAD; handle_bad calls error if strict,
              otherwise issues a warning and yields NONE.
            mk_func strict - mk_rew, keep the theorems accepted by
              assert_func strict, and apply mk_head to each. *)
    71   o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Thm.plain_prop_of);
    82   let
    72 
    83     val thy = Thm.theory_of_thm thm;
    73 val dest_func = lift_thm_thy (fn thy => apfst dest_Const o strip_comb
    84     val args = (snd o strip_comb o fst o Logic.dest_equals
    74   o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Thm.plain_prop_of
    85       o ObjectLogic.drop_judgment thy o Thm.plain_prop_of) thm;
    75   o Drule.fconv_rule Drule.beta_eta_conversion);
    86     val _ =
    76 
    87       if has_duplicates (op =)
    77 val mk_head = lift_thm_thy (fn thy => fn thm =>
    88         ((fold o fold_aterms) (fn Var (v, _) => cons v
    78   ((CodegenConsts.const_of_cexpr thy o fst o dest_func) thm, thm));
    89           | _ => I
    79 
    90         ) args [])
    80 local
    91       then bad_thm ("Duplicated variables on left hand side of equation\n"
    81 
    92         ^ Display.string_of_thm thm)
    82 exception BAD of string;
    93       else ()
    83 
    94     fun check _ (Abs _) = bad_thm
    84 fun handle_bad strict thm msg =
    95           ("Abstraction on left hand side of equation\n"
    85   if strict then error (msg ^ ": " ^ string_of_thm thm)
    96             ^ Display.string_of_thm thm)
    86     else (warning (msg ^ ": " ^ string_of_thm thm); NONE);
    97       | check 0 (Var _) = ()
    87 
    98       | check _ (Var _) = bad_thm
    88 in
    99           ("Variable with application on left hand side of defining equation\n"
    89 
   100             ^ Display.string_of_thm thm)
       (* The first argument of check counts how many arguments the current
          subterm is applied to: it grows by one when descending into the
          function part of an application and restarts at 0 for the
          argument part. *)
    90 fun assert_func strict thm = case try dest_func thm
   101       | check n (t1 $ t2) = (check (n+1) t1; check 0 t2)
    91  of SOME (c_ty as (c, ty), args) => (
   102       | check n (Const (_, ty)) = if n <> (length o fst o strip_type) ty
    92       let
   103           then bad_thm
    93         val thy = Thm.theory_of_thm thm;
       (* Line 104 now ends in "\n" like the other messages of this changeset,
          so the printed theorem starts on its own line instead of being
          glued to the word "equation". *)
   104             ("Partially applied constant on left hand side of equation\n"
    94         val _ =
   105                ^ Display.string_of_thm thm)
    95           if has_duplicates (op =)
   106           else ();
    96             ((fold o fold_aterms) (fn Var (v, _) => cons v
   107     val _ = map (check 0) args;
    97               | _ => I
   108   in thm end;
    98             ) args [])
   109 
    99           then raise BAD "Repeated variables on left hand side of defining equation"
   110 val mk_func = assert_func o Drule.fconv_rule Drule.beta_eta_conversion o mk_rew;
   100           else ()
   111 
   101         fun check _ (Abs _) = raise BAD
   112 fun head_func thm =
   102               "Abstraction on left hand side of defining equation"
   113   let
   103           | check 0 (Var _) = ()
   114     val thy = Thm.theory_of_thm thm;
   104           | check _ (Var _) = raise BAD
       (* NOTE(review): the binding on line 115 only matches a Const head; a
          theorem whose left hand side has another head fails the pattern
          binding rather than going through bad_thm - presumably callers
          pass theorems already accepted by mk_func; confirm. *)
   115     val (Const (c_ty as (_, ty))) = (fst o strip_comb o fst o Logic.dest_equals
   105               "Variable with application on left hand side of defining equation"
   116       o ObjectLogic.drop_judgment thy o Thm.plain_prop_of) thm;
   106           | check n (t1 $ t2) = (check (n+1) t1; check 0 t2)
   117     val const = CodegenConsts.const_of_cexpr thy c_ty;
   107           | check n (Const (_, ty)) = if n <> (length o fst o strip_type) ty
   118   in (const, ty) end;
   108               then raise BAD
       
   109                 ("Partially applied constant on left hand side of defining equation")
       
   110               else ();
       
   111         val _ = map (check 0) args;
       
   112       in SOME thm end handle BAD msg => handle_bad strict thm msg)
       
   113   | NONE => handle_bad strict thm "Not a defining equation";
       
   114 
       
   115 end;
       
   116 
       
   117 fun mk_func strict = map_filter (Option.map mk_head o assert_func strict) o mk_rew;
       
   118 
   119 
   119 
   120 
   120 (* utilities *)
   121 (* utilities *)
   121 
   122 
   122 fun inst_thm tvars' thm =
   123 fun inst_thm tvars' thm =