src/Pure/Tools/codegen_funcgr.ML
changeset 22554 d1499fff65d8
parent 22507 3572bc633d9a
child 22570 f166a5416b3f
equal deleted inserted replaced
22553:b860975e47b4 22554:d1499fff65d8
    78 fun consts_of (const, []) = []
    78 fun consts_of (const, []) = []
    79   | consts_of (const, thms as thm :: _) = 
    79   | consts_of (const, thms as thm :: _) = 
    80       let
    80       let
    81         val thy = Thm.theory_of_thm thm;
    81         val thy = Thm.theory_of_thm thm;
    82         val is_refl = curry CodegenConsts.eq_const const;
    82         val is_refl = curry CodegenConsts.eq_const const;
    83         fun the_const c = case try (CodegenConsts.norm_of_typ thy) c
    83         fun the_const c = case try (CodegenConsts.const_of_cexpr thy) c
    84          of SOME const => if is_refl const then I else insert CodegenConsts.eq_const const
    84          of SOME const => if is_refl const then I else insert CodegenConsts.eq_const const
    85           | NONE => I
    85           | NONE => I
    86       in fold_consts the_const thms [] end;
    86       in fold_consts the_const thms [] end;
    87 
    87 
    88 fun insts_of thy algebra c ty_decl ty =
    88 fun insts_of thy algebra c ty_decl ty =
   145         val tvars = fold match cs Vartab.empty;
   145         val tvars = fold match cs Vartab.empty;
   146       in map (CodegenFunc.inst_thm tvars) thms end;
   146       in map (CodegenFunc.inst_thm tvars) thms end;
   147 
   147 
   148 fun resort_funcss thy algebra funcgr =
   148 fun resort_funcss thy algebra funcgr =
   149   let
   149   let
   150     val typ_funcgr = try (fst o Constgraph.get_node funcgr o CodegenConsts.norm_of_typ thy);
   150     val typ_funcgr = try (fst o Constgraph.get_node funcgr o CodegenConsts.const_of_cexpr thy);
   151     fun resort_dep (const, thms) = (const, resort_thms algebra typ_funcgr thms)
   151     fun resort_dep (const, thms) = (const, resort_thms algebra typ_funcgr thms)
   152       handle Sorts.CLASS_ERROR e => raise INVALID ([const], Sorts.msg_class_error (Sign.pp thy) e
   152       handle Sorts.CLASS_ERROR e => raise INVALID ([const], Sorts.msg_class_error (Sign.pp thy) e
   153                     ^ ",\nfor constant " ^ CodegenConsts.string_of_const thy const
   153                     ^ ",\nfor constant " ^ CodegenConsts.string_of_const thy const
   154                     ^ "\nin defining equations\n"
   154                     ^ "\nin defining equations\n"
   155                     ^ (cat_lines o map string_of_thm) thms)
   155                     ^ (cat_lines o map string_of_thm) thms)
   160             val thms' as thm' :: _ = resort_thms algebra tap_typ thms
   160             val thms' as thm' :: _ = resort_thms algebra tap_typ thms
   161             val ty' = CodegenFunc.typ_func thm';
   161             val ty' = CodegenFunc.typ_func thm';
   162           in (Sign.typ_equiv thy (ty, ty'), (const, thms')) end;
   162           in (Sign.typ_equiv thy (ty, ty'), (const, thms')) end;
   163     fun resort_recs funcss =
   163     fun resort_recs funcss =
   164       let
   164       let
   165         fun tap_typ c_ty = case try (CodegenConsts.norm_of_typ thy) c_ty
   165         fun tap_typ c_ty = case try (CodegenConsts.const_of_cexpr thy) c_ty
   166          of SOME const => AList.lookup (CodegenConsts.eq_const) funcss const
   166          of SOME const => AList.lookup (CodegenConsts.eq_const) funcss const
   167               |> these
   167               |> these
   168               |> try hd
   168               |> try hd
   169               |> Option.map CodegenFunc.typ_func
   169               |> Option.map CodegenFunc.typ_func
   170           | NONE => NONE;
   170           | NONE => NONE;
   175       let
   175       let
   176         val (unchanged, funcss') = resort_recs funcss;
   176         val (unchanged, funcss') = resort_recs funcss;
   177       in if unchanged then funcss' else resort_rec_until funcss' end;
   177       in if unchanged then funcss' else resort_rec_until funcss' end;
   178   in map resort_dep #> resort_rec_until end;
   178   in map resort_dep #> resort_rec_until end;
   179 
   179 
   180 fun classop_const thy algebra class classop tyco =
       
   181   let
       
   182     val sorts = Sorts.mg_domain algebra tyco [class]
       
   183     val (var, _) = try (AxClass.params_of_class thy) class |> the_default ("'a", []);
       
   184     val vs = Name.names (Name.declare var Name.context) "'a" sorts;
       
   185     val arity_typ = Type (tyco, map TFree vs);
       
   186   in (classop, [arity_typ]) end;
       
   187 
       
   188 fun instances_of thy algebra insts =
   180 fun instances_of thy algebra insts =
   189   let
   181   let
   190     val thy_classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
   182     val thy_classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
   191     fun all_classops tyco class =
   183     fun all_classops tyco class =
   192       try (AxClass.params_of_class thy) class
   184       try (AxClass.params_of_class thy) class
   193       |> Option.map snd
   185       |> Option.map snd
   194       |> these
   186       |> these
   195       |> map (fn (c, _) => classop_const thy algebra class c tyco)
   187       |> map (fn (c, _) => (c, SOME tyco))
   196       |> map (CodegenConsts.norm thy)
       
   197   in
   188   in
   198     Symtab.empty
   189     Symtab.empty
   199     |> fold (fn (tyco, class) =>
   190     |> fold (fn (tyco, class) =>
   200         Symtab.map_default (tyco, []) (insert (op =) class)) insts
   191         Symtab.map_default (tyco, []) (insert (op =) class)) insts
   201     |> (fn tab => Symtab.fold (fn (tyco, classes) => append (maps (all_classops tyco)
   192     |> (fn tab => Symtab.fold (fn (tyco, classes) => append (maps (all_classops tyco)
   202          (Graph.all_succs thy_classes classes))) tab [])
   193          (Graph.all_succs thy_classes classes))) tab [])
   203   end;
   194   end;
   204 
   195 
   205 fun instances_of_consts thy algebra funcgr consts =
   196 fun instances_of_consts thy algebra funcgr consts =
   206   let
   197   let
   207     fun inst (const as (c, ty)) = case try (CodegenConsts.norm_of_typ thy) const
   198     fun inst (cexpr as (c, ty)) = insts_of thy algebra c
   208      of SOME const => insts_of thy algebra c (fst (Constgraph.get_node funcgr const)) ty
   199       ((fst o Constgraph.get_node funcgr o CodegenConsts.const_of_cexpr thy) cexpr)
   209       | NONE => [];
   200       ty handle CLASS_ERROR => [];
   210   in
   201   in
   211     []
   202     []
   212     |> fold (fold (insert (op =)) o inst) consts
   203     |> fold (fold (insert (op =)) o inst) consts
   213     |> instances_of thy algebra
   204     |> instances_of thy algebra
   214   end;
   205   end;
   246 fun merge_funcss rewrites thy algebra raw_funcss funcgr =
   237 fun merge_funcss rewrites thy algebra raw_funcss funcgr =
   247   let
   238   let
   248     val funcss = raw_funcss
   239     val funcss = raw_funcss
   249       |> resort_funcss thy algebra funcgr
   240       |> resort_funcss thy algebra funcgr
   250       |> filter_out (can (Constgraph.get_node funcgr) o fst);
   241       |> filter_out (can (Constgraph.get_node funcgr) o fst);
   251     fun classop_typ (c, [typarg]) class =
   242     fun typ_func const [] = CodegenData.default_typ thy const
   252       let
   243       | typ_func (_, NONE) (thm :: _) = CodegenFunc.typ_func thm
   253         val ty = Sign.the_const_type thy c;
   244       | typ_func (const as (c, SOME tyco)) (thms as (thm :: _)) =
   254         val inst = case typarg
   245           let
   255          of Type (tyco, _) => classop_const thy algebra class c tyco
   246             val ty = CodegenFunc.typ_func thm;
   256               |> snd
   247             val SOME class = AxClass.class_of_param thy c;
   257               |> the_single
   248             val sorts_decl = Sorts.mg_domain algebra tyco [class];
   258               |> Logic.varifyT
   249             val tys = CodegenConsts.typargs thy (c, ty);
   259           | _ => TVar (("'a", 0), [class]);
   250             val sorts = map (snd o dest_TVar) tys;
   260       in Term.map_type_tvar (K inst) ty end;
   251           in if sorts = sorts_decl then ty
   261     fun default_typ (const as (c, tys)) = case AxClass.class_of_param thy c
   252             else raise INVALID ([const], "Illegal instantation for class operation "
   262          of SOME class => classop_typ const class
   253               ^ CodegenConsts.string_of_const thy const
   263           | NONE => (case CodegenData.tap_typ thy const
   254               ^ "\nin defining equations\n"
   264              of SOME ty => ty
   255               ^ (cat_lines o map string_of_thm) thms)
   265               | NONE => (case CodegenData.get_constr_typ thy const
   256           end;
   266                  of SOME ty => ty
   257     fun add_funcs (const, thms) =
   267                   | NONE => Sign.the_const_type thy c))
   258       Constgraph.new_node (const, (typ_func const thms, thms));
   268     fun typ_func (const as (c, tys)) thms thm =
       
   269       let
       
   270         val ty = CodegenFunc.typ_func thm;
       
   271       in case AxClass.class_of_param thy c
       
   272        of SOME class => (case tys
       
   273            of [Type _] => let val ty_decl = classop_typ const class
       
   274               in if Sign.typ_equiv thy (ty, ty_decl) then ty
       
   275               else raise raise INVALID ([const], "Illegal instantation for class operation "
       
   276                     ^ CodegenConsts.string_of_const thy const
       
   277                     ^ ":\n" ^ CodegenConsts.string_of_typ thy ty_decl
       
   278                     ^ "\nto " ^ CodegenConsts.string_of_typ thy ty
       
   279                     ^ "\nin defining equations\n"
       
   280                     ^ (cat_lines o map string_of_thm) thms)
       
   281               end
       
   282             | _ => ty)
       
   283         | NONE => ty
       
   284       end;
       
   285     fun add_funcs (const, thms as thm :: _) =
       
   286           Constgraph.new_node (const, (typ_func const thms thm, thms))
       
   287       | add_funcs (const, []) =
       
   288           Constgraph.new_node (const, (default_typ const, []));
       
   289     fun add_deps (funcs as (const, thms)) funcgr =
   259     fun add_deps (funcs as (const, thms)) funcgr =
   290       let
   260       let
   291         val deps = consts_of funcs;
   261         val deps = consts_of funcs;
   292         val insts = instances_of_consts thy algebra funcgr
   262         val insts = instances_of_consts thy algebra funcgr
   293           (fold_consts (insert (op =)) thms []);
   263           (fold_consts (insert (op =)) thms []);
   337     val (consts', funcgr') = fold_map try_const consts funcgr;
   307     val (consts', funcgr') = fold_map try_const consts funcgr;
   338   in (map_filter I consts', funcgr') end;
   308   in (map_filter I consts', funcgr') end;
   339 
   309 
   340 fun ensure_consts_term rewrites thy f ct funcgr =
   310 fun ensure_consts_term rewrites thy f ct funcgr =
   341   let
   311   let
       
   312     fun consts_of thy t =
       
   313       fold_aterms (fn Const c => cons (CodegenConsts.const_of_cexpr thy c) | _ => I) t []
   342     fun rhs_conv conv thm =
   314     fun rhs_conv conv thm =
   343       let
   315       let
   344         val thm' = (conv o snd o Drule.dest_equals o Thm.cprop_of) thm;
   316         val thm' = (conv o snd o Drule.dest_equals o Thm.cprop_of) thm;
   345       in Thm.transitive thm thm' end
   317       in Thm.transitive thm thm' end
   346     val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
   318     val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
   347     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
   319     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
   348     val thm1 = CodegenData.preprocess_cterm ct
   320     val thm1 = CodegenData.preprocess_cterm ct
   349       |> fold (rhs_conv o MetaSimplifier.rewrite false o single) (rewrites thy);
   321       |> fold (rhs_conv o MetaSimplifier.rewrite false o single) (rewrites thy);
   350     val ct' = Drule.dest_equals_rhs (Thm.cprop_of thm1);
   322     val ct' = Drule.dest_equals_rhs (Thm.cprop_of thm1);
   351     val consts = CodegenConsts.consts_of thy (Thm.term_of ct');
   323     val consts = consts_of thy (Thm.term_of ct');
   352     val funcgr' = ensure_consts rewrites thy consts funcgr;
   324     val funcgr' = ensure_consts rewrites thy consts funcgr;
   353     val algebra = CodegenData.coregular_algebra thy;
   325     val algebra = CodegenData.coregular_algebra thy;
   354     val (_, thm2) = Thm.varifyT' [] thm1;
   326     val (_, thm2) = Thm.varifyT' [] thm1;
   355     val thm3 = Thm.reflexive (Drule.dest_equals_rhs (Thm.cprop_of thm2));
   327     val thm3 = Thm.reflexive (Drule.dest_equals_rhs (Thm.cprop_of thm2));
   356     val typ_funcgr = try (fst o Constgraph.get_node funcgr' o CodegenConsts.norm_of_typ thy);
   328     val typ_funcgr = try (fst o Constgraph.get_node funcgr' o CodegenConsts.const_of_cexpr thy);
   357     val [thm4] = resort_thms algebra typ_funcgr [thm3];
   329     val [thm4] = resort_thms algebra typ_funcgr [thm3];
   358     val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
   330     val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
   359     fun inst thm =
   331     fun inst thm =
   360       let
   332       let
   361         val tvars = Term.add_tvars (Thm.prop_of thm) [];
   333         val tvars = Term.add_tvars (Thm.prop_of thm) [];