pass constructor arity as part of case certficiate
authorhaftmann
Sat, 02 Apr 2022 17:03:35 +0000
changeset 75399 cdf84288d93c
parent 75398 a58718427bff
child 75400 970b9ab6c439
pass constructor arity as part of case certficiate
src/Pure/Isar/code.ML
src/Tools/Code/code_thingol.ML
--- a/src/Pure/Isar/code.ML	Sat Apr 02 17:03:34 2022 +0000
+++ b/src/Pure/Isar/code.ML	Sat Apr 02 17:03:35 2022 +0000
@@ -231,14 +231,14 @@
 
 (* cases *)
 
-type case_schema = int * (int * string option list);
+type case_schema = int * (int * (string * int) option list);
 
 datatype case_spec =
     No_Case
   | Case of {schema: case_schema, tycos: string list, cong: thm}
   | Undefined;
 
-fun associated_datatypes (Case {tycos, schema = (_, (_, raw_cos)), ...}) = (tycos, map_filter I raw_cos)
+fun associated_datatypes (Case {tycos, schema = (_, (_, raw_cos)), ...}) = (tycos, map fst (map_filter I raw_cos))
   | associated_datatypes _ = ([], []);
 
 
@@ -1235,7 +1235,7 @@
                       :: map (Pretty.quote o Syntax.pretty_typ_global thy) tys)) cos)
       );
     fun pretty_case_param NONE = "<ignored>"
-      | pretty_case_param (SOME c) = string_of_const thy c
+      | pretty_case_param (SOME (c, _)) = string_of_const thy c
     fun pretty_case (const, Case {schema = (_, (_, [])), ...}) =
           Pretty.str (string_of_const thy const)
       | pretty_case (const, Case {schema = (_, (_, cos)), ...}) =
@@ -1511,7 +1511,8 @@
         [] => ()
       | cs => error ("Non-constructor(s) in case certificate: " ^ commas_quote cs);
     val tycos = distinct (op =) (map_filter snd cos_with_tycos);
-    val schema = (1 + Int.max (1, length cos), (k, cos));
+    val schema = (1 + Int.max (1, length cos),
+      (k, (map o Option.map) (fn c => (c, args_number thy c)) cos));
     val cong = case_cong thy case_const schema;
   in
     thy
--- a/src/Tools/Code/code_thingol.ML	Sat Apr 02 17:03:34 2022 +0000
+++ b/src/Tools/Code/code_thingol.ML	Sat Apr 02 17:03:35 2022 +0000
@@ -483,7 +483,7 @@
 fun dest_tagged_type (Type ("", [T])) = (true, T)
   | dest_tagged_type T = (false, T);
 
-val untag_term = map_types (snd o dest_tagged_type);
+val fastype_of_tagged_term = fastype_of o map_types (snd o dest_tagged_type);
 
 fun tag_term (proj_sort, _) eqngr =
   let
@@ -722,14 +722,6 @@
   translate_const ctxt algbr eqngr permissive some_thm (c_ty, some_abs)
   ##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) ts
   #>> (fn (t, ts) => t `$$ ts)
-and translate_constr ctxt algbr eqngr permissive some_thm ty_case (c, t) =
-  let
-    val n = Code.args_number (Proof_Context.theory_of ctxt) c;
-    val ty = (untag_term #> fastype_of #> binder_types #> take n) t ---> ty_case;
-  in
-    translate_const ctxt algbr eqngr permissive some_thm ((c, ty), NONE)
-    #>> rpair n
-  end
 and translate_case ctxt algbr eqngr permissive some_thm (t_pos, []) (c_ty, ts) =
       let
         fun project_term xs = nth xs t_pos;
@@ -755,7 +747,9 @@
           |> curry (op ~~) case_pats
           |> map_filter (fn (NONE, _) => NONE | (SOME _, x) => SOME x);
         val ty_case = project_term (binder_types (snd c_ty));
-        val constrs = map_filter I case_pats ~~ project_cases ts;
+        val constrs = map_filter I case_pats ~~ project_cases ts
+          |> map (fn ((c, n), t) =>
+            ((c, (take n o binder_types o fastype_of_tagged_term) t ---> ty_case), n));
         fun distill_clauses constrs ts_clause =
           maps (fn ((constr as IConst { dom = tys, ... }, n), t) =>
             map (fn (pat_args, body) => (constr `$$ pat_args, body))
@@ -765,7 +759,8 @@
         translate_const ctxt algbr eqngr permissive some_thm (c_ty, NONE)
         ##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) ts
         ##>> translate_typ ctxt algbr eqngr permissive ty_case
-        ##>> fold_map (translate_constr ctxt algbr eqngr permissive some_thm ty_case) constrs
+        ##>> fold_map (fn (c_ty, n) =>
+          translate_const ctxt algbr eqngr permissive some_thm (c_ty, NONE) #>> rpair n) constrs
         #>> (fn (((t_app, ts), ty_case), constrs) =>
             ICase { term = project_term ts, typ = ty_case,
               clauses = (filter_out (is_undefined_clause ctxt) o distill_clauses constrs o project_cases) ts,