src/HOL/Tools/Predicate_Compile/code_prolog.ML
changeset 38728 182b180e9804
parent 38727 c7f5f0b7dc7f
child 38729 9c9d14827380
--- a/src/HOL/Tools/Predicate_Compile/code_prolog.ML	Wed Aug 25 16:59:46 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/code_prolog.ML	Wed Aug 25 16:59:48 2010 +0200
@@ -56,6 +56,9 @@
 datatype prol_term = Var of string | Cons of string | AppF of string * prol_term list
   | Number of int | ArithOp of arith_op * prol_term list;
 
+fun maybe_AppF (c, []) = Cons c
+  | maybe_AppF (c, xs) = AppF (c, xs)
+
 fun is_Var (Var _) = true
   | is_Var _ = false
 
@@ -136,7 +139,7 @@
         val l' = translate_term ctxt constant_table l
         val r' = translate_term ctxt constant_table r
       in
-        (if is_Var l' andalso is_arith_term r' then ArithEq else Eq) (l', r')
+        (if is_Var l' andalso is_arith_term r' andalso not (is_Var r') then ArithEq else Eq) (l', r')
       end
   | (Const (c, _), args) =>
       Rel (translate_const constant_table c, map (translate_term ctxt constant_table) args)
@@ -204,32 +207,49 @@
   | add_ground_typ (Ground (_, T)) = insert (op =) T
   | add_ground_typ _ = I
 
-fun mk_ground_impl ctxt (Type (Tcon, [])) constant_table =
-  let
-    fun mk_impl (constr_name, T) constant_table =
-      if binder_types T = [] then
+fun mk_relname (Type (Tcon, Targs)) =
+  first_lower (Long_Name.base_name Tcon) ^ space_implode "_" (map mk_relname Targs)
+  | mk_relname _ = raise Fail "unexpected type"
+
+(* This is copied from "pat_completeness.ML" *)
+fun inst_constrs_of thy (T as Type (name, _)) =
+  map (fn (Cn,CT) =>
+    Envir.subst_term_types (Sign.typ_match thy (body_type CT, T) Vartab.empty) (Const (Cn, CT)))
+    (the (Datatype.get_constrs thy name))
+  | inst_constrs_of thy T = raise TYPE ("inst_constrs_of", [T], [])
+  
+fun mk_ground_impl ctxt (T as Type (Tcon, Targs)) (seen, constant_table) =
+  if member (op =) seen T then ([], (seen, constant_table))
+  else
+    let
+      val rel_name = mk_relname T
+      fun mk_impl (Const (constr_name, T)) (seen, constant_table) =
         let
           val constant_table' = declare_consts [constr_name] constant_table
-          val clause = (("is_" ^ first_lower (Long_Name.base_name Tcon),
-            [Cons (translate_const constant_table' constr_name)]), Conj [])
+          val (rec_clauses, (seen', constant_table'')) =
+            fold_map (mk_ground_impl ctxt) (binder_types T) (seen, constant_table')
+          val vars = map (fn i => Var ("x" ^ string_of_int i)) (1 upto (length (binder_types T)))    
+          fun mk_prem v T = Rel (mk_relname T, [v])
+          val clause =
+            ((rel_name, [maybe_AppF (translate_const constant_table'' constr_name, vars)]),
+             Conj (map2 mk_prem vars (binder_types T)))
         in
-          (clause, constant_table')
+          (clause :: flat rec_clauses, (seen', constant_table''))
         end
-        else raise Fail "constructor with arguments" 
-    val constrs = the (Datatype.get_constrs (ProofContext.theory_of ctxt) Tcon)
-  in fold_map mk_impl constrs constant_table end
-  | mk_ground_impl ctxt (Type (Tcon, _)) constant_table =
-    raise Fail "type constructor with type arguments"
-  
+      val constrs = inst_constrs_of (ProofContext.theory_of ctxt) T
+    in apfst flat (fold_map mk_impl constrs (T :: seen, constant_table)) end
+ | mk_ground_impl ctxt T (seen, constant_table) =
+   raise Fail ("unexpected type :" ^ Syntax.string_of_typ ctxt T)
+
 fun replace_ground (Conj prems) = Conj (map replace_ground prems)
-  | replace_ground (Ground (x, Type (Tcon, []))) =
-    Rel ("is_" ^ first_lower (Long_Name.base_name Tcon), [Var x])  
+  | replace_ground (Ground (x, T)) =
+    Rel (mk_relname T, [Var x])  
   | replace_ground p = p
   
 fun add_ground_predicates ctxt (p, constant_table) =
   let
     val ground_typs = fold (add_ground_typ o snd) p []
-    val (grs, constant_table') = fold_map (mk_ground_impl ctxt) ground_typs constant_table
+    val (grs, (_, constant_table')) = fold_map (mk_ground_impl ctxt) ground_typs ([], constant_table)
     val p' = map (apsnd replace_ground) p
   in
     ((flat grs) @ p', constant_table')
@@ -290,7 +310,7 @@
   Scan.many1 Symbol.is_ascii_digit
 
 val scan_atom =
-  Scan.many1 (fn s => Symbol.is_ascii_lower s orelse Symbol.is_ascii_quasi s)
+  Scan.many1 (fn s => Symbol.is_ascii_lower s orelse Symbol.is_ascii_digit s orelse Symbol.is_ascii_quasi s)
 
 val scan_var =
   Scan.many1