src/HOL/Tools/functor.ML
author wenzelm
Tue Sep 26 20:54:40 2017 +0200 (23 months ago)
changeset 66695 91500c024c7f
parent 63568 e63c8f2fbd28
child 67149 e61557884799
permissions -rw-r--r--
tuned;
blanchet@55467
     1
(*  Title:      HOL/Tools/functor.ML
haftmann@40582
     2
    Author:     Florian Haftmann, TU Muenchen
haftmann@40582
     3
haftmann@40968
     4
Functorial structure of types.
haftmann@40582
     5
*)
haftmann@40582
     6
blanchet@55467
     7
(* Public interface of the functor package: type analysis, mapper construction,
   the registration entry points, and access to the registered data. *)
signature FUNCTOR =
haftmann@40582
     8
sig
haftmann@41388
     9
  (* Collects the atomic subtypes of a type, each paired with two flags
     recording whether it was reached covariantly (first flag) and/or
     contravariantly (second flag). *)
  val find_atomic: Proof.context -> typ -> (typ * (bool * bool)) list
haftmann@41388
    10
  (* Builds a mapper term between two types from the registered mappers; the
     function argument supplies the term for a type variable name together
     with its contravariance flag. *)
  val construct_mapper: Proof.context -> (string * bool -> term)
haftmann@40582
    11
    -> bool -> typ -> typ -> term
wenzelm@60488
    12
  (* Registers a mapper given as a term (passed through Syntax.check_term);
     the optional string is the name prefix for the resulting facts. *)
  val functor_: string option -> term -> local_theory -> Proof.state
blanchet@55467
    13
  (* Same as functor_, but the mapper is given as source text and read with
     Syntax.read_term. *)
  val functor_cmd: string option -> string -> Proof.context -> Proof.state
haftmann@40582
    14
  (* Registered information about one mapper. *)
  type entry
haftmann@41390
    15
  (* Table of registered entries of a proof context, keyed by type constructor name. *)
  val entries: Proof.context -> entry list Symtab.table
haftmann@40582
    16
end;
haftmann@40582
    17
blanchet@55467
    18
structure Functor : FUNCTOR =
haftmann@40582
    19
struct
haftmann@40582
    20
haftmann@41395
    21
(* bookkeeping *)
haftmann@41395
    22
haftmann@41371
    23
(*base names under which the proven and derived facts are noted*)
val (compN, idN) = ("comp", "id");
val (compositionalityN, identityN) = ("compositionality", "identity");
haftmann@40594
    27
haftmann@41387
    28
(*one registered mapper: the mapper term, one (sort, (co, contra)) triple per
  type argument, and the composition and identity theorems*)
type entry =
  {mapper: term,
   variances: (sort * (bool * bool)) list,
   comp: thm,
   id: thm};
haftmann@40582
    30
wenzelm@41472
    31
(*generic context data: registered entries, keyed by type constructor name*)
structure Data = Generic_Data
(
  type T = entry list Symtab.table;
  val empty: T = Symtab.empty;
  fun extend (tab: T) = tab;
  (*every pair of values under the same key counts as equal, so merging is
    never rejected because of differing entry lists;
    NOTE(review): which side's entry list survives is decided by Symtab.merge -- confirm*)
  fun merge (tabs: T * T) = Symtab.merge (fn _ => true) tabs;
);
haftmann@40582
    38
haftmann@41388
    39
(*registered entries as seen from a proof context*)
fun entries ctxt = Data.get (Context.Proof ctxt);
haftmann@40582
    40
haftmann@40582
    41
haftmann@40582
    42
(* type analysis *)
haftmann@40582
    43
wenzelm@59838
    44
(*instantiate the types in term t by the substitution obtained from matching
  the type of t against the requested type T*)
fun term_with_typ ctxt T t =
  let
    val thy = Proof_Context.theory_of ctxt;
    val tyenv = Sign.typ_match thy (fastype_of t, T) Vartab.empty;
  in Envir.subst_term_types tyenv t end;
haftmann@41389
    47
haftmann@41388
    48
(*collect the atomic subtypes of T: type variables and applications of type
  constructors without a registered mapper; each is paired with (co, contra)
  flags telling in which polarity it occurs*)
fun find_atomic ctxt T =
  let
    val tab = entries ctxt;
    (*variances of the first entry registered for tyco, if any*)
    fun variances_of tyco =
      (case Symtab.lookup_list tab tyco of
        [] => NONE
      | entry :: _ => SOME (#variances entry));
    (*mark U as seen in the given polarity, adding it to the list if new*)
    fun add_variance is_contra U =
      AList.map_default (op =) (U, (false, false))
        (if is_contra then apsnd (K true) else apfst (K true));
    fun analyze is_contra U =
      (case U of
        Type (tyco, Us) =>
          (case variances_of tyco of
            SOME variances => fold2 (analyze_arg is_contra) variances Us
          | NONE => add_variance is_contra U)
      | _ => add_variance is_contra U)
    (*descend into one type argument, once per polarity its position allows*)
    and analyze_arg is_contra (_, (co, contra)) U =
      (if co then analyze is_contra U else I)
      #> (if contra then analyze (not is_contra) U else I);
  in analyze false T [] end;
haftmann@40582
    62
haftmann@41388
    63
(*build a mapper term between two types of the same shape: type constructors
  use the first mapper registered for them, type variables are handled by the
  given function atomic, which receives the variable name and the current
  contravariance flag*)
fun construct_mapper ctxt atomic =
  let
    val tab = entries ctxt;
    (*first registered entry; hd fails on a constructor without any entry*)
    fun lookup tyco = hd (Symtab.lookup_list tab tyco);
    fun construct is_contra T T' =
      (case (T, T') of
        (Type (tyco, Ts), Type (_, Ts')) =>
          let
            val {mapper = raw_mapper, variances, ...} = lookup tyco;
            (*argument mappers for one type argument: covariant use first,
              then contravariant use with flipped polarity*)
            fun constructs (_, (co, contra)) (U, U') =
              (if co then [construct is_contra U U'] else []) @
              (if contra then [construct (not is_contra) U U'] else []);
            val args =
              maps (fn (arg_pattern, UU') => constructs arg_pattern UU')
                (variances ~~ (Ts ~~ Ts'));
            (*under contravariance the mapper goes from T' to T*)
            val (dom, cod) = if is_contra then (T', T) else (T, T');
            val mapper = term_with_typ ctxt (map fastype_of args ---> (dom --> cod)) raw_mapper;
          in list_comb (mapper, args) end
      | (TFree (v, _), TFree _) => atomic (v, is_contra));
  in construct end;
haftmann@40582
    80
haftmann@40582
    81
haftmann@40582
    82
(* mapper properties *)
haftmann@40582
    83
wenzelm@51717
    84
(*simpset for deriving the pointwise compositionality fact: HOL_basic_ss
  extended by the definition of function composition*)
val compositionality_ss =
  let
    val comp_eq = Simpdata.mk_eq @{thm comp_def};
    val ctxt = put_simpset HOL_basic_ss @{context} addsimps [comp_eq];
  in simpset_of ctxt end;
haftmann@41387
    86
haftmann@41371
    87
(* Builds the composition statement for a mapper of type constructor tyco.
   Returns the proposition comp_prop (stated with function composition) and a
   function that derives the pointwise variant from a theorem for comp_prop. *)
fun make_comp_prop ctxt variances (tyco, mapper) =
haftmann@40582
    88
  let
haftmann@41371
    89
    val sorts = map fst variances
haftmann@41371
    90
    (* three families of invented type variables, one per type argument each *)
    val (((vs3, vs2), vs1), _) = ctxt
haftmann@41371
    91
      |> Variable.invent_types sorts
haftmann@41371
    92
      ||>> Variable.invent_types sorts
haftmann@41371
    93
      ||>> Variable.invent_types sorts
haftmann@41371
    94
    val (Ts1, Ts2, Ts3) = (map TFree vs1, map TFree vs2, map TFree vs3);
haftmann@40582
    95
    (* function types for one type argument: T --> T' if covariant,
       followed by T' --> T if contravariant *)
    fun mk_argT ((T, T'), (_, (co, contra))) =
haftmann@40582
    96
      (if co then [(T --> T')] else [])
haftmann@40582
    97
      @ (if contra then [(T' --> T)] else []);
haftmann@40582
    98
    (* one flag per function argument, in the same order as mk_argT produces
       them: false for a covariant slot, true for a contravariant slot *)
    val contras = maps (fn (_, (co, contra)) =>
haftmann@40582
    99
      (if co then [false] else []) @ (if contra then [true] else [])) variances;
haftmann@40582
   100
    val Ts21 = maps mk_argT ((Ts2 ~~ Ts1) ~~ variances);
haftmann@40582
   101
    val Ts32 = maps mk_argT ((Ts3 ~~ Ts2) ~~ variances);
haftmann@41371
   102
    (* invent k names from base n and declare them in the name context *)
    fun invents n k nctxt =
haftmann@41371
   103
      let
wenzelm@43329
   104
        val names = Name.invent nctxt n k;
haftmann@41371
   105
      in (names, fold Name.declare names nctxt) end;
haftmann@41371
   106
    val ((names21, names32), nctxt) = Variable.names_of ctxt
haftmann@40582
   107
      |> invents "f" (length Ts21)
haftmann@40582
   108
      ||>> invents "f" (length Ts32);
haftmann@40582
   109
    val T1 = Type (tyco, Ts1);
haftmann@40582
   110
    val T2 = Type (tyco, Ts2);
haftmann@40582
   111
    val T3 = Type (tyco, Ts3);
haftmann@40582
   112
    val (args21, args32) = (names21 ~~ Ts21, names32 ~~ Ts32);
haftmann@40582
   113
    (* composed arguments: f21 o f32 for a covariant slot, f32 o f21 for a
       contravariant slot *)
    val args31 = map2 (fn is_contra => fn ((f21, T21), (f32, T32)) =>
haftmann@40582
   114
      if not is_contra then
haftmann@41371
   115
        HOLogic.mk_comp (Free (f21, T21), Free (f32, T32))
haftmann@40582
   116
      else
haftmann@41371
   117
        HOLogic.mk_comp (Free (f32, T32), Free (f21, T21))
haftmann@40582
   118
      ) contras (args21 ~~ args32)
haftmann@41395
   119
    (* the mapper, instantiated to map from T to T', applied to args *)
    fun mk_mapper T T' args = list_comb
haftmann@41395
   120
      (term_with_typ ctxt (map fastype_of args ---> T --> T') mapper, args);
haftmann@41387
   121
    val mapper21 = mk_mapper T2 T1 (map Free args21);
haftmann@41387
   122
    val mapper32 = mk_mapper T3 T2 (map Free args32);
haftmann@41387
   123
    val mapper31 = mk_mapper T3 T1 args31;
haftmann@41395
   124
    (* eq1: mapper21 o mapper32 = mapper31 *)
    val eq1 = (HOLogic.mk_Trueprop o HOLogic.mk_eq)
haftmann@41395
   125
      (HOLogic.mk_comp (mapper21, mapper32), mapper31);
wenzelm@43329
   126
    (* fresh variable of type T3, named after the base name of tyco *)
    val x = Free (the_single (Name.invent nctxt (Long_Name.base_name tyco) 1), T3)
haftmann@41395
   127
    (* eq2: the same statement applied pointwise to x *)
    val eq2 = (HOLogic.mk_Trueprop o HOLogic.mk_eq)
haftmann@41395
   128
      (mapper21 $ (mapper32 $ x), mapper31 $ x);
haftmann@41387
   129
    (* both statements are universally quantified over all function arguments;
       the pointwise one additionally over x *)
    val comp_prop = fold_rev Logic.all (map Free (args21 @ args32)) eq1;
haftmann@41387
   130
    val compositionality_prop = fold_rev Logic.all (map Free (args21 @ args32) @ [x]) eq2;
wenzelm@61841
   131
    (* proves compositionality_prop: inserts comp_thm combined with fun_cong,
       simplifies with compositionality_ss, then applies
       Goal.assume_rule_tac to all resulting subgoals *)
    fun prove_compositionality ctxt comp_thm =
wenzelm@61841
   132
      Goal.prove_sorry ctxt [] [] compositionality_prop
wenzelm@61841
   133
        (K (ALLGOALS (Method.insert_tac ctxt [@{thm fun_cong} OF [comp_thm]]
wenzelm@61841
   134
          THEN' Simplifier.asm_lr_simp_tac (put_simpset compositionality_ss ctxt)
wenzelm@61841
   135
          THEN_ALL_NEW (Goal.assume_rule_tac ctxt))));
haftmann@41387
   136
  in (comp_prop, prove_compositionality) end;
haftmann@41387
   137
wenzelm@51717
   138
(*simpset for deriving the pointwise identity fact: HOL_basic_ss extended by
  the definition of the identity function*)
val identity_ss =
  let
    val id_eq = Simpdata.mk_eq @{thm id_def};
    val ctxt = put_simpset HOL_basic_ss @{context} addsimps [id_eq];
  in simpset_of ctxt end;
haftmann@40582
   140
haftmann@41371
   141
(*build the identity statement for a mapper of type constructor tyco: returns
  the proposition id_prop (stated with the constant id) and a function that
  derives the variant stated with explicit abstractions from a theorem for
  id_prop*)
fun make_id_prop ctxt variances (tyco, mapper) =
  let
    val sorts = map fst variances;
    val (vs, _) = Variable.invent_types sorts ctxt;
    val Ts = map TFree vs;
    (*one argument slot per enabled variance flag*)
    fun mk_argT (U, (_, (co, contra))) =
      (if co then [U] else []) @ (if contra then [U] else []);
    val arg_Ts = maps mk_argT (Ts ~~ variances);
    val T = Type (tyco, Ts);
    val mapper_T = map (fn U => U --> U) arg_Ts ---> (T --> T);
    val head = term_with_typ ctxt mapper_T mapper;
    val rhs = HOLogic.id_const T;
    fun mk_prop lhs = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs));
    (*all arguments are the constant id*)
    val id_prop = mk_prop (list_comb (head, map HOLogic.id_const arg_Ts));
    (*all arguments are explicit identity abstractions*)
    val identity_prop =
      mk_prop (list_comb (head, map (fn arg_T => Abs ("x", arg_T, Bound 0)) arg_Ts));
    fun prove_identity ctxt id_thm =
      let
        val simp_tac = Simplifier.asm_lr_simp_tac (put_simpset identity_ss ctxt);
        val tac = ALLGOALS (Method.insert_tac ctxt [id_thm] THEN' simp_tac);
      in Goal.prove_sorry ctxt [] [] identity_prop (K tac) end;
  in (id_prop, prove_identity) end;
haftmann@40582
   161
haftmann@40582
   162
haftmann@40597
   163
(* analyzing and registering mappers *)
haftmann@40582
   164
blanchet@55467
   165
(*drop the head of the list if it equals x wrt. eq; the boolean tells whether
  an element was dropped*)
fun consume eq x ys =
  (case ys of
    z :: zs => if eq (x, z) then (true, zs) else (false, ys)
  | [] => (false, []));
haftmann@40594
   167
haftmann@40587
   168
(*split a mapper type into (types of the function arguments, source type,
  target type); for the type constructor "fun" the source is the
  second-to-last argument and the target is rebuilt from the last argument
  and the result type*)
fun split_mapper_typ tyco T =
  let
    val (arg_Ts, res_T) = strip_type T;
    val (init_Ts, last_T) = split_last arg_Ts;
  in
    if tyco = "fun" then
      let val (fun_Ts, src_T) = split_last init_Ts
      in (fun_Ts, src_T, last_T --> res_T) end
    else (init_Ts, last_T, res_T)
  end;
haftmann@40587
   179
haftmann@46852
   180
(* Validates a user-supplied mapper term and determines the type constructor
   it belongs to.  Returns the polymorphic version of the mapper, the type of
   the input mapper, and the name of the type constructor. *)
fun analyze_mapper ctxt input_mapper =
haftmann@46852
   181
  let
haftmann@46852
   182
    val T = fastype_of input_mapper;
haftmann@46852
   183
    (* NOTE(review): presumably rejects schematic type variables in T by
       raising an exception -- confirm against Type.no_tvars *)
    val _ = Type.no_tvars T;
haftmann@46852
   184
    (* every free type variable of the term must also occur in its type *)
    val _ =
haftmann@46852
   185
      if null (subtract (op =) (Term.add_tfreesT T []) (Term.add_tfrees input_mapper []))
haftmann@46852
   186
      then ()
haftmann@46852
   187
      else error ("Illegal additional type variable(s) in term: " ^ Syntax.string_of_term ctxt input_mapper);
haftmann@46852
   188
    (* exporting the term from a context with its free variables auto-fixed
       back into ctxt must not produce any schematic variable *)
    val _ =
haftmann@46852
   189
      if null (Term.add_vars (singleton
haftmann@46852
   190
        (Variable.export_terms (Variable.auto_fixes input_mapper ctxt) ctxt) input_mapper) [])
haftmann@46852
   191
      then ()
haftmann@46852
   192
      else error ("Illegal locally free variable(s) in term: "
wenzelm@63568
   193
        ^ Syntax.string_of_term ctxt input_mapper);
haftmann@46852
   194
    val mapper = singleton (Variable.polymorphic ctxt) input_mapper;
haftmann@46852
   195
    (* free type variables left over after generalization are rejected *)
    val _ =
haftmann@46852
   196
      if null (Term.add_tfreesT (fastype_of mapper) []) then ()
haftmann@46852
   197
      else error ("Illegal locally fixed type variable(s) in type: " ^ Syntax.string_of_typ ctxt T);
haftmann@46852
   198
    (* collect the names of all type constructors occurring in a type *)
    fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
haftmann@46852
   199
      | add_tycos _ = I;
haftmann@46852
   200
    val tycos = add_tycos T [];
haftmann@46852
   201
    (* the mapper belongs to "fun" if that is the only constructor in T;
       otherwise exactly one constructor other than "fun" must occur *)
    val tyco = if tycos = ["fun"] then "fun"
haftmann@46852
   202
      else case remove (op =) "fun" tycos
haftmann@46852
   203
       of [tyco] => tyco
haftmann@46852
   204
        | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ ctxt T);
haftmann@46852
   205
  in (mapper, T, tyco) end;
haftmann@46852
   206
haftmann@41390
   207
(* Derives the variances of the type arguments of tyco from the mapper type T.
   Returns one (sort, (co, contra)) triple per type argument; any deviation
   from the expected shape of T is reported as "Bad mapper type". *)
fun analyze_variances ctxt tyco T =
haftmann@40587
   208
  let
haftmann@41390
   209
    fun bad_typ () = error ("Bad mapper type: " ^ Syntax.string_of_typ ctxt T);
haftmann@40587
   210
    (* Ts: types of the function arguments; T1, T2: source and target type *)
    val (Ts, T1, T2) = split_mapper_typ tyco T
haftmann@40587
   211
      handle List.Empty => bad_typ ();
wenzelm@59058
   212
    (* source and target type must both be applications of tyco *)
    val _ =
wenzelm@59058
   213
      apply2 ((fn tyco' => if tyco' = tyco then () else bad_typ ()) o fst o dest_Type) (T1, T2)
wenzelm@59058
   214
        handle TYPE _ => bad_typ ();
wenzelm@59058
   215
    (* their type arguments must all be free type variables ... *)
    val (vs1, vs2) =
wenzelm@59058
   216
      apply2 (map dest_TFree o snd o dest_Type) (T1, T2)
wenzelm@59058
   217
        handle TYPE _ => bad_typ ();
haftmann@40587
   218
    (* ... with pairwise distinct names across both types *)
    val _ = if has_duplicates (eq_fst (op =)) (vs1 @ vs2)
haftmann@40587
   219
      then bad_typ () else ();
haftmann@46810
   220
    (* for one pair of corresponding type variables: take a leading covariant
       function type var1 --> var2 and then a leading contravariant function
       type var2 --> var1 off the remaining argument types, recording which
       of them were present, together with the intersection of both sorts *)
    fun check_variance_pair (var1 as (_, sort1), var2 as (_, sort2)) =
haftmann@40594
   221
      let
haftmann@40594
   222
        val coT = TFree var1 --> TFree var2;
haftmann@40594
   223
        val contraT = TFree var2 --> TFree var1;
wenzelm@42361
   224
        val sort = Sign.inter_sort (Proof_Context.theory_of ctxt) (sort1, sort2);
haftmann@40594
   225
      in
haftmann@40594
   226
        consume (op =) coT
haftmann@40594
   227
        ##>> consume (op =) contraT
haftmann@40594
   228
        #>> pair sort
haftmann@40594
   229
      end;
haftmann@40594
   230
    val (variances, left_variances) = fold_map check_variance_pair (vs1 ~~ vs2) Ts;
haftmann@40594
   231
    (* argument types that were not consumed make the mapper type invalid *)
    val _ = if null left_variances then () else bad_typ ();
haftmann@40594
   232
  in variances end;
haftmann@40587
   233
blanchet@55467
   234
(* Common implementation of the registration entry points: prepares and
   analyzes the mapper, opens a proof of the composition and identity
   statements, and after the proof notes the resulting facts and registers
   the mapper.  prep_term turns the raw mapper into a term; some_prfx
   optionally overrides the name prefix of the noted facts. *)
fun gen_functor prep_term some_prfx raw_mapper lthy =
haftmann@40583
   235
  let
haftmann@46852
   236
    val (mapper, T, tyco) = analyze_mapper lthy (prep_term lthy raw_mapper);
haftmann@41371
   237
    (* the prefix defaults to the base name of the type constructor *)
    val prfx = the_default (Long_Name.base_name tyco) some_prfx;
haftmann@41390
   238
    val variances = analyze_variances lthy tyco T;
haftmann@41390
   239
    val (comp_prop, prove_compositionality) = make_comp_prop lthy variances (tyco, mapper);
haftmann@41390
   240
    val (id_prop, prove_identity) = make_id_prop lthy variances (tyco, mapper);
haftmann@40856
   241
    val qualify = Binding.qualify true prfx o Binding.name;
haftmann@41389
   242
    (* declaration adding the entry to the context data; the entry is only
       added if the mapper transformed by the morphism has no schematic
       variables and its type and the original mapper type are instances of
       each other, otherwise the context is returned unchanged *)
    fun mapper_declaration comp_thm id_thm phi context =
haftmann@41389
   243
      let
wenzelm@42388
   244
        val typ_instance = Sign.typ_instance (Context.theory_of context);
haftmann@41389
   245
        val mapper' = Morphism.term phi mapper;
wenzelm@59058
   246
        val T_T' = apply2 fastype_of (mapper, mapper');
haftmann@46852
   247
        val vars = Term.add_vars mapper' [];
wenzelm@42388
   248
      in
haftmann@46852
   249
        if null vars andalso typ_instance T_T' andalso typ_instance (swap T_T')
haftmann@41390
   250
        then (Data.map o Symtab.cons_list) (tyco,
haftmann@41389
   251
          { mapper = mapper', variances = variances,
haftmann@41390
   252
            comp = Morphism.thm phi comp_thm, id = Morphism.thm phi id_thm }) context
haftmann@41389
   253
        else context
haftmann@41389
   254
      end;
haftmann@41387
   255
    (* continuation after the proof: notes the two proven facts, then notes
       the derived pointwise variants, and finally issues the declaration *)
    fun after_qed [single_comp_thm, single_id_thm] lthy =
haftmann@40587
   256
      lthy
haftmann@41387
   257
      |> Local_Theory.note ((qualify compN, []), single_comp_thm)
haftmann@41387
   258
      ||>> Local_Theory.note ((qualify idN, []), single_id_thm)
haftmann@41387
   259
      |-> (fn ((_, [comp_thm]), (_, [id_thm])) => fn lthy =>
haftmann@41371
   260
        lthy
haftmann@41388
   261
        |> Local_Theory.note ((qualify compositionalityN, []),
haftmann@41388
   262
            [prove_compositionality lthy comp_thm])
haftmann@41371
   263
        |> snd
haftmann@41388
   264
        |> Local_Theory.note ((qualify identityN, []),
haftmann@41388
   265
            [prove_identity lthy id_thm])
haftmann@41388
   266
        |> snd
wenzelm@45291
   267
        |> Local_Theory.declaration {syntax = false, pervasive = false}
wenzelm@45291
   268
          (mapper_declaration comp_thm id_thm))
haftmann@40583
   269
  in
haftmann@41390
   270
    lthy
haftmann@41371
   271
    (* one goal per statement: composition first, identity second *)
    |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [comp_prop, id_prop])
haftmann@40583
   272
  end
haftmann@40583
   273
wenzelm@60488
   274
(*registration of a mapper given as a term, checked via Syntax.check_term*)
fun functor_ some_prfx raw_mapper lthy =
  gen_functor Syntax.check_term some_prfx raw_mapper lthy;
(*registration of a mapper given as source text, read via Syntax.read_term*)
fun functor_cmd some_prfx raw_mapper lthy =
  gen_functor Syntax.read_term some_prfx raw_mapper lthy;
haftmann@40583
   276
wenzelm@46961
   277
(*outer syntax: command "functor" with an optional name prefix followed by ":"
  and then the mapper term*)
val _ =
  let
    val parse_prefix = Scan.option (Parse.name --| @{keyword ":"});
    val parse_spec = parse_prefix -- Parse.term;
  in
    Outer_Syntax.local_theory_to_proof @{command_keyword functor}
      "register operations managing the functorial structure of a type"
      (parse_spec >> (fn (some_prfx, raw_mapper) => functor_cmd some_prfx raw_mapper))
  end;
haftmann@40583
   281
haftmann@40582
   282
end;