src/HOL/Tools/type_lifting.ML
changeset 40968 a6fcd305f7dc
parent 40857 b3489aa6b63f
child 41298 aad679ca38d2
equal deleted inserted replaced
40965:54b6c9e1c157 40968:a6fcd305f7dc
       
     1 (*  Title:      HOL/Tools/type_lifting.ML
       
     2     Author:     Florian Haftmann, TU Muenchen
       
     3 
       
     4 Functorial structure of types.
       
     5 *)
       
     6 
       
     7 signature TYPE_LIFTING =
       
     8 sig
       
     9   val find_atomic: theory -> typ -> (typ * (bool * bool)) list
       
    10   val construct_mapper: theory -> (string * bool -> term)
       
    11     -> bool -> typ -> typ -> term
       
    12   val type_lifting: string option -> term -> theory -> Proof.state
       
    13   type entry
       
    14   val entries: theory -> entry Symtab.table
       
    15 end;
       
    16 
       
    17 structure Type_Lifting : TYPE_LIFTING =
       
    18 struct
       
    19 
       
    20 val compositionalityN = "compositionality";
       
    21 val identityN = "identity";
       
    22 
       
    23 (** functorial mappers and their properties **)
       
    24 
       
    25 (* bookkeeping *)
       
    26 
       
    27 type entry = { mapper: string, variances: (sort * (bool * bool)) list,
       
    28   compositionality: thm, identity: thm };
       
    29 
       
    30 structure Data = Theory_Data(
       
    31   type T = entry Symtab.table
       
    32   val empty = Symtab.empty
       
    33   fun merge (xy : T * T) = Symtab.merge (K true) xy
       
    34   val extend = I
       
    35 );
       
    36 
       
    37 val entries = Data.get;
       
    38 
       
    39 
       
    40 (* type analysis *)
       
    41 
       
    42 fun find_atomic thy T =
       
    43   let
       
    44     val variances_of = Option.map #variances o Symtab.lookup (Data.get thy);
       
    45     fun add_variance is_contra T =
       
    46       AList.map_default (op =) (T, (false, false))
       
    47         ((if is_contra then apsnd else apfst) (K true));
       
    48     fun analyze' is_contra (_, (co, contra)) T =
       
    49       (if co then analyze is_contra T else I)
       
    50       #> (if contra then analyze (not is_contra) T else I)
       
    51     and analyze is_contra (T as Type (tyco, Ts)) = (case variances_of tyco
       
    52           of NONE => add_variance is_contra T
       
    53            | SOME variances => fold2 (analyze' is_contra) variances Ts)
       
    54       | analyze is_contra T = add_variance is_contra T;
       
    55   in analyze false T [] end;
       
    56 
       
    57 fun construct_mapper thy atomic =
       
    58   let
       
    59     val lookup = the o Symtab.lookup (Data.get thy);
       
    60     fun constructs is_contra (_, (co, contra)) T T' =
       
    61       (if co then [construct is_contra T T'] else [])
       
    62       @ (if contra then [construct (not is_contra) T T'] else [])
       
    63     and construct is_contra (T as Type (tyco, Ts)) (T' as Type (_, Ts')) =
       
    64           let
       
    65             val { mapper, variances, ... } = lookup tyco;
       
    66             val args = maps (fn (arg_pattern, (T, T')) =>
       
    67               constructs is_contra arg_pattern T T')
       
    68                 (variances ~~ (Ts ~~ Ts'));
       
    69             val (U, U') = if is_contra then (T', T) else (T, T');
       
    70           in list_comb (Const (mapper, map fastype_of args ---> U --> U'), args) end
       
    71       | construct is_contra (TFree (v, _)) (TFree _) = atomic (v, is_contra);
       
    72   in construct end;
       
    73 
       
    74 
       
    75 (* mapper properties *)
       
    76 
       
    77 fun make_compositionality_prop variances (tyco, mapper) =
       
    78   let
       
    79     fun invents n k nctxt =
       
    80       let
       
    81         val names = Name.invents nctxt n k;
       
    82       in (names, fold Name.declare names nctxt) end;
       
    83     val (((vs1, vs2), vs3), _) = Name.context
       
    84       |> invents Name.aT (length variances)
       
    85       ||>> invents Name.aT (length variances)
       
    86       ||>> invents Name.aT (length variances);
       
    87     fun mk_Ts vs = map2 (fn v => fn (sort, _) => TFree (v, sort))
       
    88       vs variances;
       
    89     val (Ts1, Ts2, Ts3) = (mk_Ts vs1, mk_Ts vs2, mk_Ts vs3);
       
    90     fun mk_argT ((T, T'), (_, (co, contra))) =
       
    91       (if co then [(T --> T')] else [])
       
    92       @ (if contra then [(T' --> T)] else []);
       
    93     val contras = maps (fn (_, (co, contra)) =>
       
    94       (if co then [false] else []) @ (if contra then [true] else [])) variances;
       
    95     val Ts21 = maps mk_argT ((Ts2 ~~ Ts1) ~~ variances);
       
    96     val Ts32 = maps mk_argT ((Ts3 ~~ Ts2) ~~ variances);
       
    97     val ((names21, names32), nctxt) = Name.context
       
    98       |> invents "f" (length Ts21)
       
    99       ||>> invents "f" (length Ts32);
       
   100     val T1 = Type (tyco, Ts1);
       
   101     val T2 = Type (tyco, Ts2);
       
   102     val T3 = Type (tyco, Ts3);
       
   103     val x = Free (the_single (Name.invents nctxt (Long_Name.base_name tyco) 1), T3);
       
   104     val (args21, args32) = (names21 ~~ Ts21, names32 ~~ Ts32);
       
   105     val args31 = map2 (fn is_contra => fn ((f21, T21), (f32, T32)) =>
       
   106       if not is_contra then
       
   107         Abs ("x", domain_type T32, Free (f21, T21) $ (Free (f32, T32) $ Bound 0))
       
   108       else
       
   109         Abs ("x", domain_type T21, Free (f32, T32) $ (Free (f21, T21) $ Bound 0))
       
   110       ) contras (args21 ~~ args32)
       
   111     fun mk_mapper T T' args = list_comb (Const (mapper,
       
   112       map fastype_of args ---> T --> T'), args);
       
   113     val lhs = mk_mapper T2 T1 (map Free args21) $
       
   114       (mk_mapper T3 T2 (map Free args32) $ x);
       
   115     val rhs = mk_mapper T3 T1 args31 $ x;
       
   116   in (map Free (args21 @ args32) @ [x], (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, rhs)) end;
       
   117 
       
   118 fun make_identity_prop variances (tyco, mapper) =
       
   119   let
       
   120     val vs = Name.invents Name.context Name.aT (length variances);
       
   121     val Ts = map2 (fn v => fn (sort, _) => TFree (v, sort)) vs variances;
       
   122     fun bool_num b = if b then 1 else 0;
       
   123     fun mk_argT (T, (_, (co, contra))) =
       
   124       replicate (bool_num co + bool_num contra) (T --> T)
       
   125     val Ts' = maps mk_argT (Ts ~~ variances)
       
   126     val T = Type (tyco, Ts);
       
   127     val x = Free (Long_Name.base_name tyco, T);
       
   128     val lhs = list_comb (Const (mapper, Ts' ---> T --> T),
       
   129       map (fn T => Abs ("x", domain_type T, Bound 0)) Ts') $ x;
       
   130   in (x, (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, x)) end;
       
   131 
       
   132 
       
   133 (* analyzing and registering mappers *)
       
   134 
       
   135 fun consume eq x [] = (false, [])
       
   136   | consume eq x (ys as z :: zs) = if eq (x, z) then (true, zs) else (false, ys);
       
   137 
       
   138 fun split_mapper_typ "fun" T =
       
   139       let
       
   140         val (Ts', T') = strip_type T;
       
   141         val (Ts'', T'') = split_last Ts';
       
   142         val (Ts''', T''') = split_last Ts'';
       
   143       in (Ts''', T''', T'' --> T') end
       
   144   | split_mapper_typ tyco T =
       
   145       let
       
   146         val (Ts', T') = strip_type T;
       
   147         val (Ts'', T'') = split_last Ts';
       
   148       in (Ts'', T'', T') end;
       
   149 
       
   150 fun analyze_variances thy tyco T =
       
   151   let
       
   152     fun bad_typ () = error ("Bad mapper type: " ^ Syntax.string_of_typ_global thy T);
       
   153     val (Ts, T1, T2) = split_mapper_typ tyco T
       
   154       handle List.Empty => bad_typ ();
       
   155     val _ = pairself
       
   156       ((fn tyco' => if tyco' = tyco then () else bad_typ ()) o fst o dest_Type) (T1, T2)
       
   157     val (vs1, vs2) = pairself (map dest_TFree o snd o dest_Type) (T1, T2)
       
   158       handle TYPE _ => bad_typ ();
       
   159     val _ = if has_duplicates (eq_fst (op =)) (vs1 @ vs2)
       
   160       then bad_typ () else ();
       
   161     fun check_variance_pair (var1 as (v1, sort1), var2 as (v2, sort2)) =
       
   162       let
       
   163         val coT = TFree var1 --> TFree var2;
       
   164         val contraT = TFree var2 --> TFree var1;
       
   165         val sort = Sign.inter_sort thy (sort1, sort2);
       
   166       in
       
   167         consume (op =) coT
       
   168         ##>> consume (op =) contraT
       
   169         #>> pair sort
       
   170       end;
       
   171     val (variances, left_variances) = fold_map check_variance_pair (vs1 ~~ vs2) Ts;
       
   172     val _ = if null left_variances then () else bad_typ ();
       
   173   in variances end;
       
   174 
       
   175 fun gen_type_lifting prep_term some_prfx raw_t thy =
       
   176   let
       
   177     val (mapper, T) = case prep_term thy raw_t
       
   178      of Const cT => cT
       
   179       | t => error ("No constant: " ^ Syntax.string_of_term_global thy t);
       
   180     val prfx = the_default (Long_Name.base_name mapper) some_prfx;
       
   181     val _ = Type.no_tvars T;
       
   182     fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
       
   183       | add_tycos _ = I;
       
   184     val tycos = add_tycos T [];
       
   185     val tyco = if tycos = ["fun"] then "fun"
       
   186       else case remove (op =) "fun" tycos
       
   187        of [tyco] => tyco
       
   188         | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ_global thy T);
       
   189     val variances = analyze_variances thy tyco T;
       
   190     val compositionality_prop = uncurry (fold_rev Logic.all)
       
   191       (make_compositionality_prop variances (tyco, mapper));
       
   192     val identity_prop = uncurry Logic.all
       
   193       (make_identity_prop variances (tyco, mapper));
       
   194     val qualify = Binding.qualify true prfx o Binding.name;
       
   195     fun after_qed [single_compositionality, single_identity] lthy =
       
   196       lthy
       
   197       |> Local_Theory.note ((qualify compositionalityN, []), single_compositionality)
       
   198       ||>> Local_Theory.note ((qualify identityN, []), single_identity)
       
   199       |-> (fn ((_, [compositionality]), (_, [identity])) =>
       
   200           (Local_Theory.background_theory o Data.map)
       
   201             (Symtab.update (tyco, { mapper = mapper, variances = variances,
       
   202               compositionality = compositionality, identity = identity })));
       
   203   in
       
   204     thy
       
   205     |> Named_Target.theory_init
       
   206     |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [compositionality_prop, identity_prop])
       
   207   end
       
   208 
       
   209 val type_lifting = gen_type_lifting Sign.cert_term;
       
   210 val type_lifting_cmd = gen_type_lifting Syntax.read_term_global;
       
   211 
       
   212 val _ =
       
   213   Outer_Syntax.command "type_lifting" "register operations managing the functorial structure of a type" Keyword.thy_goal
       
   214     (Scan.option (Parse.name --| Parse.$$$ ":") -- Parse.term
       
   215       >> (fn (prfx, t) => Toplevel.print o (Toplevel.theory_to_proof (type_lifting_cmd prfx t))));
       
   216 
       
   217 end;