src/HOL/Tools/type_lifting.ML
changeset 41371 35d2241c169c
parent 41298 aad679ca38d2
child 41373 48503e4e96b6
equal deleted inserted replaced
41366:ea73e74ec827 41371:35d2241c169c
    15 end;
    15 end;
    16 
    16 
    17 structure Type_Lifting : TYPE_LIFTING =
    17 structure Type_Lifting : TYPE_LIFTING =
    18 struct
    18 struct
    19 
    19 
       
    20 val compN = "comp";
       
    21 val idN = "id";
    20 val compositionalityN = "compositionality";
    22 val compositionalityN = "compositionality";
    21 val identityN = "identity";
    23 val identityN = "identity";
    22 
    24 
    23 (** functorial mappers and their properties **)
    25 (** functorial mappers and their properties **)
    24 
    26 
    25 (* bookkeeping *)
    27 (* bookkeeping *)
    26 
    28 
    27 type entry = { mapper: string, variances: (sort * (bool * bool)) list,
    29 type entry = { mapper: string, variances: (sort * (bool * bool)) list,
    28   compositionality: thm, identity: thm };
    30   comp: thm, id: thm };
    29 
    31 
    30 structure Data = Theory_Data(
    32 structure Data = Theory_Data(
    31   type T = entry Symtab.table
    33   type T = entry Symtab.table
    32   val empty = Symtab.empty
    34   val empty = Symtab.empty
    33   fun merge (xy : T * T) = Symtab.merge (K true) xy
    35   fun merge (xy : T * T) = Symtab.merge (K true) xy
    72   in construct end;
    74   in construct end;
    73 
    75 
    74 
    76 
    75 (* mapper properties *)
    77 (* mapper properties *)
    76 
    78 
    77 fun make_compositionality_prop variances (tyco, mapper) =
    79 fun make_comp_prop ctxt variances (tyco, mapper) =
    78   let
    80   let
    79     fun invents n k nctxt =
    81     val sorts = map fst variances
    80       let
    82     val (((vs3, vs2), vs1), _) = ctxt
    81         val names = Name.invents nctxt n k;
    83       |> Variable.invent_types sorts
    82       in (names, fold Name.declare names nctxt) end;
    84       ||>> Variable.invent_types sorts
    83     val (((vs3, vs2), vs1), _) = Name.context
    85       ||>> Variable.invent_types sorts
    84       |> invents Name.aT (length variances)
    86     val (Ts1, Ts2, Ts3) = (map TFree vs1, map TFree vs2, map TFree vs3);
    85       ||>> invents Name.aT (length variances)
       
    86       ||>> invents Name.aT (length variances);
       
    87     fun mk_Ts vs = map2 (fn v => fn (sort, _) => TFree (v, sort))
       
    88       vs variances;
       
    89     val (Ts1, Ts2, Ts3) = (mk_Ts vs1, mk_Ts vs2, mk_Ts vs3);
       
    90     fun mk_argT ((T, T'), (_, (co, contra))) =
    87     fun mk_argT ((T, T'), (_, (co, contra))) =
    91       (if co then [(T --> T')] else [])
    88       (if co then [(T --> T')] else [])
    92       @ (if contra then [(T' --> T)] else []);
    89       @ (if contra then [(T' --> T)] else []);
    93     val contras = maps (fn (_, (co, contra)) =>
    90     val contras = maps (fn (_, (co, contra)) =>
    94       (if co then [false] else []) @ (if contra then [true] else [])) variances;
    91       (if co then [false] else []) @ (if contra then [true] else [])) variances;
    95     val Ts21 = maps mk_argT ((Ts2 ~~ Ts1) ~~ variances);
    92     val Ts21 = maps mk_argT ((Ts2 ~~ Ts1) ~~ variances);
    96     val Ts32 = maps mk_argT ((Ts3 ~~ Ts2) ~~ variances);
    93     val Ts32 = maps mk_argT ((Ts3 ~~ Ts2) ~~ variances);
    97     val ((names21, names32), nctxt) = Name.context
    94     fun invents n k nctxt =
       
    95       let
       
    96         val names = Name.invents nctxt n k;
       
    97       in (names, fold Name.declare names nctxt) end;
       
    98     val ((names21, names32), nctxt) = Variable.names_of ctxt
    98       |> invents "f" (length Ts21)
    99       |> invents "f" (length Ts21)
    99       ||>> invents "f" (length Ts32);
   100       ||>> invents "f" (length Ts32);
   100     val T1 = Type (tyco, Ts1);
   101     val T1 = Type (tyco, Ts1);
   101     val T2 = Type (tyco, Ts2);
   102     val T2 = Type (tyco, Ts2);
   102     val T3 = Type (tyco, Ts3);
   103     val T3 = Type (tyco, Ts3);
   103     val x = Free (the_single (Name.invents nctxt (Long_Name.base_name tyco) 1), T3);
       
   104     val (args21, args32) = (names21 ~~ Ts21, names32 ~~ Ts32);
   104     val (args21, args32) = (names21 ~~ Ts21, names32 ~~ Ts32);
   105     val args31 = map2 (fn is_contra => fn ((f21, T21), (f32, T32)) =>
   105     val args31 = map2 (fn is_contra => fn ((f21, T21), (f32, T32)) =>
   106       if not is_contra then
   106       if not is_contra then
   107         Abs ("x", domain_type T32, Free (f21, T21) $ (Free (f32, T32) $ Bound 0))
   107         HOLogic.mk_comp (Free (f21, T21), Free (f32, T32))
   108       else
   108       else
   109         Abs ("x", domain_type T21, Free (f32, T32) $ (Free (f21, T21) $ Bound 0))
   109         HOLogic.mk_comp (Free (f32, T32), Free (f21, T21))
   110       ) contras (args21 ~~ args32)
   110       ) contras (args21 ~~ args32)
   111     fun mk_mapper T T' args = list_comb (Const (mapper,
   111     fun mk_mapper T T' args = list_comb (Const (mapper,
   112       map fastype_of args ---> T --> T'), args);
   112       map fastype_of args ---> T --> T'), args);
   113     val lhs = mk_mapper T2 T1 (map Free args21) $
   113     val lhs = HOLogic.mk_comp (mk_mapper T2 T1 (map Free args21), mk_mapper T3 T2 (map Free args32));
   114       (mk_mapper T3 T2 (map Free args32) $ x);
   114     val rhs = mk_mapper T3 T1 args31;
   115     val rhs = mk_mapper T3 T1 args31 $ x;
   115   in fold_rev Logic.all (map Free (args21 @ args32)) ((HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, rhs)) end;
   116   in (map Free (args21 @ args32) @ [x], (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, rhs)) end;
   116 
   117 
   117 fun make_id_prop ctxt variances (tyco, mapper) =
   118 fun make_identity_prop variances (tyco, mapper) =
   118   let
   119   let
   119     val (vs, ctxt') = Variable.invent_types (map fst variances) ctxt;
   120     val vs = Name.invents Name.context Name.aT (length variances);
   120     val Ts = map TFree vs;
   121     val Ts = map2 (fn v => fn (sort, _) => TFree (v, sort)) vs variances;
       
   122     fun bool_num b = if b then 1 else 0;
   121     fun bool_num b = if b then 1 else 0;
   123     fun mk_argT (T, (_, (co, contra))) =
   122     fun mk_argT (T, (_, (co, contra))) =
   124       replicate (bool_num co + bool_num contra) (T --> T)
   123       replicate (bool_num co + bool_num contra) (T --> T)
   125     val Ts' = maps mk_argT (Ts ~~ variances)
   124     val Ts' = maps mk_argT (Ts ~~ variances)
   126     val T = Type (tyco, Ts);
   125     val T = Type (tyco, Ts);
   127     val x = Free (Long_Name.base_name tyco, T);
       
   128     val lhs = list_comb (Const (mapper, Ts' ---> T --> T),
   126     val lhs = list_comb (Const (mapper, Ts' ---> T --> T),
   129       map (fn T => Abs ("x", domain_type T, Bound 0)) Ts') $ x;
   127       map (HOLogic.mk_id o domain_type) Ts');
   130   in (x, (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, x)) end;
   128   in (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, HOLogic.mk_id T) end;
       
   129 
       
   130 val comp_apply = Simpdata.mk_eq @{thm o_apply};
       
   131 val id_def = Simpdata.mk_eq @{thm id_def};
       
   132 
       
   133 fun make_compositionality ctxt thm =
       
   134   let
       
   135     val ((_, [thm']), ctxt') = Variable.import false [thm] ctxt;
       
   136     val thm'' = @{thm fun_cong} OF [thm'];
       
   137     val thm''' =
       
   138       (Conv.fconv_rule o Conv.arg_conv o Conv.arg1_conv o Conv.rewr_conv) comp_apply thm'';
       
   139   in singleton (Variable.export ctxt' ctxt) thm''' end;
       
   140 
       
   141 fun args_conv k conv =
       
   142   if k <= 0 then Conv.all_conv
       
   143   else Conv.combination_conv (args_conv (k - 1) conv) conv;
       
   144 
       
   145 fun make_identity ctxt variances thm =
       
   146   let
       
   147     val ((_, [thm']), ctxt') = Variable.import false [thm] ctxt;
       
   148     fun bool_num b = if b then 1 else 0;
       
   149     val num_args = Integer.sum
       
   150       (map (fn (_, (co, contra)) => bool_num co + bool_num contra) variances);
       
   151     val thm'' =
       
   152       (Conv.fconv_rule o Conv.arg_conv o Conv.arg1_conv o args_conv num_args o Conv.rewr_conv) id_def thm';
       
   153   in singleton (Variable.export ctxt' ctxt) thm'' end;
   131 
   154 
   132 
   155 
   133 (* analyzing and registering mappers *)
   156 (* analyzing and registering mappers *)
   134 
   157 
   135 fun consume eq x [] = (false, [])
   158 fun consume eq x [] = (false, [])
   175 fun gen_type_lifting prep_term some_prfx raw_t thy =
   198 fun gen_type_lifting prep_term some_prfx raw_t thy =
   176   let
   199   let
   177     val (mapper, T) = case prep_term thy raw_t
   200     val (mapper, T) = case prep_term thy raw_t
   178      of Const cT => cT
   201      of Const cT => cT
   179       | t => error ("No constant: " ^ Syntax.string_of_term_global thy t);
   202       | t => error ("No constant: " ^ Syntax.string_of_term_global thy t);
   180     val prfx = the_default (Long_Name.base_name mapper) some_prfx;
       
   181     val _ = Type.no_tvars T;
   203     val _ = Type.no_tvars T;
   182     fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
   204     fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
   183       | add_tycos _ = I;
   205       | add_tycos _ = I;
   184     val tycos = add_tycos T [];
   206     val tycos = add_tycos T [];
   185     val tyco = if tycos = ["fun"] then "fun"
   207     val tyco = if tycos = ["fun"] then "fun"
   186       else case remove (op =) "fun" tycos
   208       else case remove (op =) "fun" tycos
   187        of [tyco] => tyco
   209        of [tyco] => tyco
   188         | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ_global thy T);
   210         | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ_global thy T);
       
   211     val prfx = the_default (Long_Name.base_name tyco) some_prfx;
   189     val variances = analyze_variances thy tyco T;
   212     val variances = analyze_variances thy tyco T;
   190     val compositionality_prop = uncurry (fold_rev Logic.all)
   213     val ctxt = ProofContext.init_global thy;
   191       (make_compositionality_prop variances (tyco, mapper));
   214     val comp_prop = make_comp_prop ctxt variances (tyco, mapper);
   192     val identity_prop = uncurry Logic.all
   215     val id_prop = make_id_prop ctxt variances (tyco, mapper);
   193       (make_identity_prop variances (tyco, mapper));
       
   194     val qualify = Binding.qualify true prfx o Binding.name;
   216     val qualify = Binding.qualify true prfx o Binding.name;
   195     fun after_qed [single_compositionality, single_identity] lthy =
   217     fun after_qed [single_comp, single_id] lthy =
   196       lthy
   218       lthy
   197       |> Local_Theory.note ((qualify compositionalityN, []), single_compositionality)
   219       |> Local_Theory.note ((qualify compN, []), single_comp)
   198       ||>> Local_Theory.note ((qualify identityN, []), single_identity)
   220       ||>> Local_Theory.note ((qualify idN, []), single_id)
   199       |-> (fn ((_, [compositionality]), (_, [identity])) =>
   221       |-> (fn ((_, [comp]), (_, [id])) => fn lthy =>
   200           (Local_Theory.background_theory o Data.map)
   222         lthy
       
   223         |> Local_Theory.note ((qualify compositionalityN, []), [make_compositionality lthy comp])
       
   224         |> snd
       
   225         |> Local_Theory.note ((qualify identityN, []), [make_identity lthy variances id])
       
   226         |> snd
       
   227         |> (Local_Theory.background_theory o Data.map)
   201             (Symtab.update (tyco, { mapper = mapper, variances = variances,
   228             (Symtab.update (tyco, { mapper = mapper, variances = variances,
   202               compositionality = compositionality, identity = identity })));
   229               comp = comp, id = id })));
   203   in
   230   in
   204     thy
   231     thy
   205     |> Named_Target.theory_init
   232     |> Named_Target.theory_init
   206     |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [compositionality_prop, identity_prop])
   233     |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [comp_prop, id_prop])
   207   end
   234   end
   208 
   235 
   209 val type_lifting = gen_type_lifting Sign.cert_term;
   236 val type_lifting = gen_type_lifting Sign.cert_term;
   210 val type_lifting_cmd = gen_type_lifting Syntax.read_term_global;
   237 val type_lifting_cmd = gen_type_lifting Syntax.read_term_global;
   211 
   238