src/HOL/Tools/type_lifting.ML
changeset 41390 207ee8f8a19c
parent 41389 d06a6d15a958
child 41395 cf5ab80b6717
equal deleted inserted replaced
41389:d06a6d15a958 41390:207ee8f8a19c
     7 signature TYPE_LIFTING =
     7 signature TYPE_LIFTING =
     8 sig
     8 sig
     9   val find_atomic: Proof.context -> typ -> (typ * (bool * bool)) list
     9   val find_atomic: Proof.context -> typ -> (typ * (bool * bool)) list
    10   val construct_mapper: Proof.context -> (string * bool -> term)
    10   val construct_mapper: Proof.context -> (string * bool -> term)
    11     -> bool -> typ -> typ -> term
    11     -> bool -> typ -> typ -> term
    12   val type_lifting: string option -> term -> theory -> Proof.state
    12   val type_lifting: string option -> term -> local_theory -> Proof.state
    13   type entry
    13   type entry
    14   val entries: Proof.context -> entry Symtab.table
    14   val entries: Proof.context -> entry list Symtab.table
    15 end;
    15 end;
    16 
    16 
    17 structure Type_Lifting : TYPE_LIFTING =
    17 structure Type_Lifting : TYPE_LIFTING =
    18 struct
    18 struct
    19 
    19 
    28 
    28 
    29 type entry = { mapper: term, variances: (sort * (bool * bool)) list,
    29 type entry = { mapper: term, variances: (sort * (bool * bool)) list,
    30   comp: thm, id: thm };
    30   comp: thm, id: thm };
    31 
    31 
    32 structure Data = Generic_Data(
    32 structure Data = Generic_Data(
    33   type T = entry Symtab.table
    33   type T = entry list Symtab.table
    34   val empty = Symtab.empty
    34   val empty = Symtab.empty
    35   fun merge (xy : T * T) = Symtab.merge (K true) xy
    35   fun merge (xy : T * T) = Symtab.merge (K true) xy
    36   val extend = I
    36   val extend = I
    37 );
    37 );
    38 
    38 
    44 fun term_with_typ ctxt T t = Envir.subst_term_types
    44 fun term_with_typ ctxt T t = Envir.subst_term_types
    45   (Type.typ_match (ProofContext.tsig_of ctxt) (fastype_of t, T) Vartab.empty) t;
    45   (Type.typ_match (ProofContext.tsig_of ctxt) (fastype_of t, T) Vartab.empty) t;
    46 
    46 
    47 fun find_atomic ctxt T =
    47 fun find_atomic ctxt T =
    48   let
    48   let
    49     val variances_of = Option.map #variances o Symtab.lookup (entries ctxt);
    49     val variances_of = Option.map #variances o try hd o Symtab.lookup_list (entries ctxt);
    50     fun add_variance is_contra T =
    50     fun add_variance is_contra T =
    51       AList.map_default (op =) (T, (false, false))
    51       AList.map_default (op =) (T, (false, false))
    52         ((if is_contra then apsnd else apfst) (K true));
    52         ((if is_contra then apsnd else apfst) (K true));
    53     fun analyze' is_contra (_, (co, contra)) T =
    53     fun analyze' is_contra (_, (co, contra)) T =
    54       (if co then analyze is_contra T else I)
    54       (if co then analyze is_contra T else I)
    59       | analyze is_contra T = add_variance is_contra T;
    59       | analyze is_contra T = add_variance is_contra T;
    60   in analyze false T [] end;
    60   in analyze false T [] end;
    61 
    61 
    62 fun construct_mapper ctxt atomic =
    62 fun construct_mapper ctxt atomic =
    63   let
    63   let
    64     val lookup = the o Symtab.lookup (entries ctxt);
    64     val lookup = hd o Symtab.lookup_list (entries ctxt);
    65     fun constructs is_contra (_, (co, contra)) T T' =
    65     fun constructs is_contra (_, (co, contra)) T T' =
    66       (if co then [construct is_contra T T'] else [])
    66       (if co then [construct is_contra T T'] else [])
    67       @ (if contra then [construct (not is_contra) T T'] else [])
    67       @ (if contra then [construct (not is_contra) T T'] else [])
    68     and construct is_contra (T as Type (tyco, Ts)) (T' as Type (_, Ts')) =
    68     and construct is_contra (T as Type (tyco, Ts)) (T' as Type (_, Ts')) =
    69           let
    69           let
   165       let
   165       let
   166         val (Ts', T') = strip_type T;
   166         val (Ts', T') = strip_type T;
   167         val (Ts'', T'') = split_last Ts';
   167         val (Ts'', T'') = split_last Ts';
   168       in (Ts'', T'', T') end;
   168       in (Ts'', T'', T') end;
   169 
   169 
   170 fun analyze_variances thy tyco T =
   170 fun analyze_variances ctxt tyco T =
   171   let
   171   let
   172     fun bad_typ () = error ("Bad mapper type: " ^ Syntax.string_of_typ_global thy T);
   172     fun bad_typ () = error ("Bad mapper type: " ^ Syntax.string_of_typ ctxt T);
   173     val (Ts, T1, T2) = split_mapper_typ tyco T
   173     val (Ts, T1, T2) = split_mapper_typ tyco T
   174       handle List.Empty => bad_typ ();
   174       handle List.Empty => bad_typ ();
   175     val _ = pairself
   175     val _ = pairself
   176       ((fn tyco' => if tyco' = tyco then () else bad_typ ()) o fst o dest_Type) (T1, T2)
   176       ((fn tyco' => if tyco' = tyco then () else bad_typ ()) o fst o dest_Type) (T1, T2)
   177       handle TYPE _ => bad_typ ();
   177       handle TYPE _ => bad_typ ();
   181       then bad_typ () else ();
   181       then bad_typ () else ();
   182     fun check_variance_pair (var1 as (v1, sort1), var2 as (v2, sort2)) =
   182     fun check_variance_pair (var1 as (v1, sort1), var2 as (v2, sort2)) =
   183       let
   183       let
   184         val coT = TFree var1 --> TFree var2;
   184         val coT = TFree var1 --> TFree var2;
   185         val contraT = TFree var2 --> TFree var1;
   185         val contraT = TFree var2 --> TFree var1;
   186         val sort = Sign.inter_sort thy (sort1, sort2);
   186         val sort = Sign.inter_sort (ProofContext.theory_of ctxt) (sort1, sort2);
   187       in
   187       in
   188         consume (op =) coT
   188         consume (op =) coT
   189         ##>> consume (op =) contraT
   189         ##>> consume (op =) contraT
   190         #>> pair sort
   190         #>> pair sort
   191       end;
   191       end;
   192     val (variances, left_variances) = fold_map check_variance_pair (vs1 ~~ vs2) Ts;
   192     val (variances, left_variances) = fold_map check_variance_pair (vs1 ~~ vs2) Ts;
   193     val _ = if null left_variances then () else bad_typ ();
   193     val _ = if null left_variances then () else bad_typ ();
   194   in variances end;
   194   in variances end;
   195 
   195 
   196 fun gen_type_lifting prep_term some_prfx raw_mapper thy =
   196 fun gen_type_lifting prep_term some_prfx raw_mapper lthy =
   197   let
   197   let
   198     val input_mapper = prep_term thy raw_mapper;
   198     val input_mapper = prep_term lthy raw_mapper;
   199     val T = fastype_of input_mapper;
   199     val T = fastype_of input_mapper;
   200     val _ = Type.no_tvars T;
   200     val _ = Type.no_tvars T;
   201     val mapper = singleton (Variable.polymorphic (ProofContext.init_global thy)) input_mapper;
   201     val mapper = singleton (Variable.polymorphic lthy) input_mapper;
       
   202     val _ = if null (Term.add_tfreesT (fastype_of mapper) []) then ()
       
   203       else error ("Illegal locally fixed variables in type: " ^ Syntax.string_of_typ lthy T);
   202     fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
   204     fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
   203       | add_tycos _ = I;
   205       | add_tycos _ = I;
   204     val tycos = add_tycos T [];
   206     val tycos = add_tycos T [];
   205     val tyco = if tycos = ["fun"] then "fun"
   207     val tyco = if tycos = ["fun"] then "fun"
   206       else case remove (op =) "fun" tycos
   208       else case remove (op =) "fun" tycos
   207        of [tyco] => tyco
   209        of [tyco] => tyco
   208         | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ_global thy T);
   210         | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ lthy T);
   209     val prfx = the_default (Long_Name.base_name tyco) some_prfx;
   211     val prfx = the_default (Long_Name.base_name tyco) some_prfx;
   210     val variances = analyze_variances thy tyco T;
   212     val variances = analyze_variances lthy tyco T;
   211     val ctxt = ProofContext.init_global thy;
   213     val (comp_prop, prove_compositionality) = make_comp_prop lthy variances (tyco, mapper);
   212     val (comp_prop, prove_compositionality) = make_comp_prop ctxt variances (tyco, mapper);
   214     val (id_prop, prove_identity) = make_id_prop lthy variances (tyco, mapper);
   213     val (id_prop, prove_identity) = make_id_prop ctxt variances (tyco, mapper);
       
   214     val qualify = Binding.qualify true prfx o Binding.name;
   215     val qualify = Binding.qualify true prfx o Binding.name;
   215     fun mapper_declaration comp_thm id_thm phi context =
   216     fun mapper_declaration comp_thm id_thm phi context =
   216       let
   217       let
   217         val typ_instance = Type.typ_instance (ProofContext.tsig_of (Context.proof_of context));
   218         val typ_instance = Type.typ_instance (ProofContext.tsig_of (Context.proof_of context));
   218         val mapper' = Morphism.term phi mapper;
   219         val mapper' = Morphism.term phi mapper;
   219         val T_T' = pairself fastype_of (mapper, mapper');
   220         val T_T' = pairself fastype_of (mapper, mapper');
   220       in if typ_instance T_T' andalso typ_instance (swap T_T')
   221       in if typ_instance T_T' andalso typ_instance (swap T_T')
   221         then Data.map (Symtab.update (tyco,
   222         then (Data.map o Symtab.cons_list) (tyco,
   222           { mapper = mapper', variances = variances,
   223           { mapper = mapper', variances = variances,
   223             comp = Morphism.thm phi comp_thm, id = Morphism.thm phi id_thm })) context
   224             comp = Morphism.thm phi comp_thm, id = Morphism.thm phi id_thm }) context
   224         else context
   225         else context
   225       end;
   226       end;
   226     fun after_qed [single_comp_thm, single_id_thm] lthy =
   227     fun after_qed [single_comp_thm, single_id_thm] lthy =
   227       lthy
   228       lthy
   228       |> Local_Theory.note ((qualify compN, []), single_comp_thm)
   229       |> Local_Theory.note ((qualify compN, []), single_comp_thm)
   235         |> Local_Theory.note ((qualify identityN, []),
   236         |> Local_Theory.note ((qualify identityN, []),
   236             [prove_identity lthy id_thm])
   237             [prove_identity lthy id_thm])
   237         |> snd
   238         |> snd
   238         |> Local_Theory.declaration false (mapper_declaration comp_thm id_thm))
   239         |> Local_Theory.declaration false (mapper_declaration comp_thm id_thm))
   239   in
   240   in
   240     thy
   241     lthy
   241     |> Named_Target.theory_init
       
   242     |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [comp_prop, id_prop])
   242     |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [comp_prop, id_prop])
   243   end
   243   end
   244 
   244 
   245 val type_lifting = gen_type_lifting Sign.cert_term;
   245 val type_lifting = gen_type_lifting Syntax.check_term;
   246 val type_lifting_cmd = gen_type_lifting Syntax.read_term_global;
   246 val type_lifting_cmd = gen_type_lifting Syntax.read_term;
   247 
   247 
   248 val _ =
   248 val _ = Outer_Syntax.local_theory_to_proof "type_lifting"
   249   Outer_Syntax.command "type_lifting" "register operations managing the functorial structure of a type" Keyword.thy_goal
   249   "register operations managing the functorial structure of a type"
   250     (Scan.option (Parse.name --| Parse.$$$ ":") -- Parse.term
   250   Keyword.thy_goal (Scan.option (Parse.name --| Parse.$$$ ":") -- Parse.term
   251       >> (fn (prfx, t) => Toplevel.print o (Toplevel.theory_to_proof (type_lifting_cmd prfx t))));
   251     >> (fn (prfx, t) => type_lifting_cmd prfx t));
   252 
   252 
   253 end;
   253 end;