src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 33147 180dc60bd88c
parent 33146 bf852ef586f2
child 33148 0808f7d0d0d7
equal deleted inserted replaced
33146:bf852ef586f2 33147:180dc60bd88c
   436 fun check_expected_modes (options : Predicate_Compile_Aux.options) modes =
   436 fun check_expected_modes (options : Predicate_Compile_Aux.options) modes =
   437   case expected_modes options of
   437   case expected_modes options of
   438     SOME (s, ms) => (case AList.lookup (op =) modes s of
   438     SOME (s, ms) => (case AList.lookup (op =) modes s of
   439       SOME modes =>
   439       SOME modes =>
   440         if not (eq_set (map (map (rpair NONE)) ms, map snd modes)) then
   440         if not (eq_set (map (map (rpair NONE)) ms, map snd modes)) then
   441           error ("expected modes were not inferred:"
   441           error ("expected modes were not inferred:\n"
   442             ^ "infered modes for " ^ s ^ ": " ^ commas (map (string_of_smode o snd) modes))
   442           ^ "inferred modes for " ^ s ^ ": "
       
   443           ^ commas (map ((enclose "[" "]") o string_of_smode o snd) modes))
   443         else ()
   444         else ()
   444       | NONE => ())
   445       | NONE => ())
   445   | NONE => ()
   446   | NONE => ()
   446 
   447 
   447 (* importing introduction rules *)
   448 (* importing introduction rules *)
  1163         val (t', nvs') = distinct_v t nvs;
  1164         val (t', nvs') = distinct_v t nvs;
  1164         val (u', nvs'') = distinct_v u nvs';
  1165         val (u', nvs'') = distinct_v u nvs';
  1165       in (t' $ u', nvs'') end
  1166       in (t' $ u', nvs'') end
  1166   | distinct_v x nvs = (x, nvs);
  1167   | distinct_v x nvs = (x, nvs);
  1167 
  1168 
  1168 fun compile_match thy compfuns eqs eqs' out_ts success_t =
       
  1169   let
       
  1170     val eqs'' = maps mk_eq eqs @ eqs'
       
  1171     val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
       
  1172     val name = Name.variant names "x";
       
  1173     val name' = Name.variant (name :: names) "y";
       
  1174     val T = mk_tupleT (map fastype_of out_ts);
       
  1175     val U = fastype_of success_t;
       
  1176     val U' = dest_predT compfuns U;
       
  1177     val v = Free (name, T);
       
  1178     val v' = Free (name', T);
       
  1179   in
       
  1180     lambda v (fst (Datatype.make_case
       
  1181       (ProofContext.init thy) DatatypeCase.Quiet [] v
       
  1182       [(mk_tuple out_ts,
       
  1183         if null eqs'' then success_t
       
  1184         else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
       
  1185           foldr1 HOLogic.mk_conj eqs'' $ success_t $
       
  1186             mk_bot compfuns U'),
       
  1187        (v', mk_bot compfuns U')]))
       
  1188   end;
       
  1189 
       
  1190 (*FIXME function can be removed*)
       
  1191 fun mk_funcomp f t =
       
  1192   let
       
  1193     val names = Term.add_free_names t [];
       
  1194     val Ts = binder_types (fastype_of t);
       
  1195     val vs = map Free
       
  1196       (Name.variant_list names (replicate (length Ts) "x") ~~ Ts)
       
  1197   in
       
  1198     fold_rev lambda vs (f (list_comb (t, vs)))
       
  1199   end;
       
  1200 
       
  1201 fun compile_param compilation_modifiers compfuns thy (NONE, t) = t
       
  1202   | compile_param compilation_modifiers compfuns thy (m as SOME (Mode (mode, _, ms)), t) =
       
  1203    let
       
  1204      val (f, args) = strip_comb (Envir.eta_contract t)
       
  1205      val (params, args') = chop (length ms) args
       
  1206      val params' = map (compile_param compilation_modifiers compfuns thy) (ms ~~ params)
       
  1207      val f' =
       
  1208        case f of
       
  1209          Const (name, T) => Const (#const_name_of compilation_modifiers thy name mode,
       
  1210            #funT_of compilation_modifiers compfuns mode T)
       
  1211        | Free (name, T) => Free (name, #funT_of compilation_modifiers compfuns mode T)
       
  1212        | _ => error ("PredicateCompiler: illegal parameter term")
       
  1213    in
       
  1214      list_comb (f', params' @ args')
       
  1215    end
       
  1216 
       
  1217 fun compile_expr compilation_modifiers compfuns thy ((Mode (mode, _, ms)), t) inargs additional_arguments =
       
  1218   case strip_comb t of
       
  1219     (Const (name, T), params) =>
       
  1220        let
       
  1221          val params' = map (compile_param compilation_modifiers compfuns thy) (ms ~~ params)
       
  1222            (*val mk_fun_of = if depth_limited then mk_depth_limited_fun_of else mk_fun_of*)
       
  1223          val name' = #const_name_of compilation_modifiers thy name mode
       
  1224          val T' = #funT_of compilation_modifiers compfuns mode T
       
  1225        in
       
  1226          (list_comb (Const (name', T'), params' @ inargs @ additional_arguments))
       
  1227        end
       
  1228   | (Free (name, T), params) =>
       
  1229     list_comb (Free (name, #funT_of compilation_modifiers compfuns mode T), params @ inargs @ additional_arguments)
       
  1230 
       
  1231 (** specific rpred functions -- move them to the correct place in this file *)
  1169 (** specific rpred functions -- move them to the correct place in this file *)
  1232 
  1170 
  1233 fun mk_Eval_of depth ((x, T), NONE) names = (x, names)
  1171 fun mk_Eval_of additional_arguments ((x, T), NONE) names = (x, names)
  1234   | mk_Eval_of depth ((x, T), SOME mode) names =
  1172   | mk_Eval_of additional_arguments ((x, T), SOME mode) names =
  1235 	let
  1173 	let
  1236     val Ts = binder_types T
  1174     val Ts = binder_types T
  1237     (*val argnames = Name.variant_list names
  1175     (*val argnames = Name.variant_list names
  1238         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
  1176         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
  1239     val args = map Free (argnames ~~ Ts)
  1177     val args = map Free (argnames ~~ Ts)
  1274 							fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
  1212 							fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
  1275 						in ((tuple inargs, tuple outargs), args) end
  1213 						in ((tuple inargs, tuple outargs), args) end
  1276 			end
  1214 			end
  1277 		val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
  1215 		val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
  1278     val (inargs, outargs) = pairself flat (split_list inoutargs)
  1216     val (inargs, outargs) = pairself flat (split_list inoutargs)
  1279     val depth_t = case depth of NONE => [] | SOME (polarity, depth_t) => [polarity, depth_t]
  1217 		val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs @ additional_arguments), mk_tuple outargs)
  1280 		val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs @ depth_t), mk_tuple outargs)
       
  1281     val t = fold_rev mk_split_lambda args r
  1218     val t = fold_rev mk_split_lambda args r
  1282   in
  1219   in
  1283     (t, names)
  1220     (t, names)
  1284   end;
  1221   end;
  1285 
  1222 
  1286 fun compile_arg depth thy param_vs iss arg = 
  1223 fun compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss arg = 
  1287   let
  1224   let
  1288     val funT_of = case depth of NONE => funT_of | SOME _ => depth_limited_funT_of
       
  1289     fun map_params (t as Free (f, T)) =
  1225     fun map_params (t as Free (f, T)) =
  1290       if member (op =) param_vs f then
  1226       if member (op =) param_vs f then
  1291         case (the (AList.lookup (op =) (param_vs ~~ iss) f)) of
  1227         case (the (AList.lookup (op =) (param_vs ~~ iss) f)) of
  1292           SOME is => let val T' = funT_of PredicateCompFuns.compfuns ([], is) T
  1228           SOME is =>
  1293             in fst (mk_Eval_of depth ((Free (f, T'), T), SOME is) []) end
  1229             let
       
  1230               val T' = #funT_of compilation_modifiers compfuns ([], is) T
       
  1231             in fst (mk_Eval_of additional_arguments ((Free (f, T'), T), SOME is) []) end
  1294         | NONE => t
  1232         | NONE => t
  1295       else t
  1233       else t
  1296       | map_params t = t
  1234       | map_params t = t
  1297     in map_aterms map_params arg end
  1235     in map_aterms map_params arg end
  1298   
  1236 
       
  1237 fun compile_match compilation_modifiers compfuns additional_arguments param_vs iss thy eqs eqs' out_ts success_t =
       
  1238   let
       
  1239     val eqs'' = maps mk_eq eqs @ eqs'
       
  1240     val eqs'' =
       
  1241       map (compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss) eqs''
       
  1242     val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
       
  1243     val name = Name.variant names "x";
       
  1244     val name' = Name.variant (name :: names) "y";
       
  1245     val T = mk_tupleT (map fastype_of out_ts);
       
  1246     val U = fastype_of success_t;
       
  1247     val U' = dest_predT compfuns U;
       
  1248     val v = Free (name, T);
       
  1249     val v' = Free (name', T);
       
  1250   in
       
  1251     lambda v (fst (Datatype.make_case
       
  1252       (ProofContext.init thy) DatatypeCase.Quiet [] v
       
  1253       [(mk_tuple out_ts,
       
  1254         if null eqs'' then success_t
       
  1255         else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
       
  1256           foldr1 HOLogic.mk_conj eqs'' $ success_t $
       
  1257             mk_bot compfuns U'),
       
  1258        (v', mk_bot compfuns U')]))
       
  1259   end;
       
  1260 
       
  1261 (*FIXME function can be removed*)
       
  1262 fun mk_funcomp f t =
       
  1263   let
       
  1264     val names = Term.add_free_names t [];
       
  1265     val Ts = binder_types (fastype_of t);
       
  1266     val vs = map Free
       
  1267       (Name.variant_list names (replicate (length Ts) "x") ~~ Ts)
       
  1268   in
       
  1269     fold_rev lambda vs (f (list_comb (t, vs)))
       
  1270   end;
       
  1271 
       
  1272 fun compile_param compilation_modifiers compfuns thy (NONE, t) = t
       
  1273   | compile_param compilation_modifiers compfuns thy (m as SOME (Mode (mode, _, ms)), t) =
       
  1274    let
       
  1275      val (f, args) = strip_comb (Envir.eta_contract t)
       
  1276      val (params, args') = chop (length ms) args
       
  1277      val params' = map (compile_param compilation_modifiers compfuns thy) (ms ~~ params)
       
  1278      val f' =
       
  1279        case f of
       
  1280          Const (name, T) => Const (#const_name_of compilation_modifiers thy name mode,
       
  1281            #funT_of compilation_modifiers compfuns mode T)
       
  1282        | Free (name, T) => Free (name, #funT_of compilation_modifiers compfuns mode T)
       
  1283        | _ => error ("PredicateCompiler: illegal parameter term")
       
  1284    in
       
  1285      list_comb (f', params' @ args')
       
  1286    end
       
  1287 
       
  1288 fun compile_expr compilation_modifiers compfuns thy ((Mode (mode, _, ms)), t) inargs additional_arguments =
       
  1289   case strip_comb t of
       
  1290     (Const (name, T), params) =>
       
  1291        let
       
  1292          val params' = map (compile_param compilation_modifiers compfuns thy) (ms ~~ params)
       
  1293            (*val mk_fun_of = if depth_limited then mk_depth_limited_fun_of else mk_fun_of*)
       
  1294          val name' = #const_name_of compilation_modifiers thy name mode
       
  1295          val T' = #funT_of compilation_modifiers compfuns mode T
       
  1296        in
       
  1297          (list_comb (Const (name', T'), params' @ inargs @ additional_arguments))
       
  1298        end
       
  1299   | (Free (name, T), params) =>
       
  1300     list_comb (Free (name, #funT_of compilation_modifiers compfuns mode T), params @ inargs @ additional_arguments)
       
  1301 
  1299 fun compile_clause compilation_modifiers compfuns thy all_vs param_vs additional_arguments (iss, is) inp (ts, moded_ps) =
  1302 fun compile_clause compilation_modifiers compfuns thy all_vs param_vs additional_arguments (iss, is) inp (ts, moded_ps) =
  1300   let
  1303   let
       
  1304     val compile_match = compile_match compilation_modifiers compfuns additional_arguments param_vs iss thy
  1301     fun check_constrt t (names, eqs) =
  1305     fun check_constrt t (names, eqs) =
  1302       if is_constrt thy t then (t, (names, eqs)) else
  1306       if is_constrt thy t then (t, (names, eqs)) else
  1303         let
  1307         let
  1304           val s = Name.variant names "x"
  1308           val s = Name.variant names "x"
  1305           val v = Free (s, fastype_of t)
  1309           val v = Free (s, fastype_of t)
  1314             val (out_ts'', (names', eqs')) =
  1318             val (out_ts'', (names', eqs')) =
  1315               fold_map check_constrt out_ts' (names, []);
  1319               fold_map check_constrt out_ts' (names, []);
  1316             val (out_ts''', (names'', constr_vs)) = fold_map distinct_v
  1320             val (out_ts''', (names'', constr_vs)) = fold_map distinct_v
  1317               out_ts'' (names', map (rpair []) vs);
  1321               out_ts'' (names', map (rpair []) vs);
  1318           in
  1322           in
  1319             compile_match thy compfuns constr_vs (eqs @ eqs') out_ts'''
  1323             compile_match constr_vs (eqs @ eqs') out_ts'''
  1320               (mk_single compfuns (mk_tuple out_ts))
  1324               (mk_single compfuns (mk_tuple out_ts))
  1321           end
  1325           end
  1322       | compile_prems out_ts vs names ((p, mode as Mode ((_, is), _, _)) :: ps) =
  1326       | compile_prems out_ts vs names ((p, mode as Mode ((_, is), _, _)) :: ps) =
  1323           let
  1327           let
  1324             val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
  1328             val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
  1330               #transform_additional_arguments compilation_modifiers p additional_arguments
  1334               #transform_additional_arguments compilation_modifiers p additional_arguments
  1331             val (compiled_clause, rest) = case p of
  1335             val (compiled_clause, rest) = case p of
  1332                Prem (us, t) =>
  1336                Prem (us, t) =>
  1333                  let
  1337                  let
  1334                    val (in_ts, out_ts''') = split_smode is us;
  1338                    val (in_ts, out_ts''') = split_smode is us;
  1335                      (* TODO: add test case for compile_arg *)
  1339                    val in_ts = map (compile_arg compilation_modifiers compfuns additional_arguments
  1336                    (*val in_ts = map (compile_arg depth thy param_vs iss) in_ts*)
  1340                      thy param_vs iss) in_ts
  1337                      (* additional_arguments
       
  1338                    val args = case depth of
       
  1339                      NONE => in_ts
       
  1340                    | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
       
  1341                    *)
       
  1342                    val u =
  1341                    val u =
  1343                      compile_expr compilation_modifiers compfuns thy (mode, t) in_ts additional_arguments'
  1342                      compile_expr compilation_modifiers compfuns thy (mode, t) in_ts additional_arguments'
  1344                    val rest = compile_prems out_ts''' vs' names'' ps
  1343                    val rest = compile_prems out_ts''' vs' names'' ps
  1345                  in
  1344                  in
  1346                    (u, rest)
  1345                    (u, rest)
  1347                  end
  1346                  end
  1348              | Negprem (us, t) =>
  1347              | Negprem (us, t) =>
  1349                  let
  1348                  let
  1350                    val (in_ts, out_ts''') = split_smode is us
  1349                    val (in_ts, out_ts''') = split_smode is us
  1351                      (* additional_arguments
  1350                    val in_ts = map (compile_arg compilation_modifiers compfuns additional_arguments
  1352                    val args = case depth of
  1351                      thy param_vs iss) in_ts
  1353                      NONE => in_ts
       
  1354                    | SOME (polarity, depth_t) => in_ts @ [HOLogic.mk_not polarity, depth_t]
       
  1355                    *)
       
  1356                    val u = mk_not compfuns
  1352                    val u = mk_not compfuns
  1357                      (compile_expr compilation_modifiers compfuns thy (mode, t) in_ts additional_arguments')
  1353                      (compile_expr compilation_modifiers compfuns thy (mode, t) in_ts additional_arguments')
  1358                    val rest = compile_prems out_ts''' vs' names'' ps
  1354                    val rest = compile_prems out_ts''' vs' names'' ps
  1359                  in
  1355                  in
  1360                    (u, rest)
  1356                    (u, rest)
  1361                  end
  1357                  end
  1362              | Sidecond t =>
  1358              | Sidecond t =>
  1363                  let
  1359                  let
       
  1360                    val t = compile_arg compilation_modifiers compfuns additional_arguments
       
  1361                      thy param_vs iss t
  1364                    val rest = compile_prems [] vs' names'' ps;
  1362                    val rest = compile_prems [] vs' names'' ps;
  1365                  in
  1363                  in
  1366                    (mk_if compfuns t, rest)
  1364                    (mk_if compfuns t, rest)
  1367                  end
  1365                  end
  1368              | Generator (v, T) =>
  1366              | Generator (v, T) =>
  1372                    val rest = compile_prems [Free (v, T)]  vs' names'' ps;
  1370                    val rest = compile_prems [Free (v, T)]  vs' names'' ps;
  1373                  in
  1371                  in
  1374                    (u, rest)
  1372                    (u, rest)
  1375                  end
  1373                  end
  1376           in
  1374           in
  1377             compile_match thy compfuns constr_vs' eqs out_ts''
  1375             compile_match constr_vs' eqs out_ts''
  1378               (mk_bind compfuns (compiled_clause, rest))
  1376               (mk_bind compfuns (compiled_clause, rest))
  1379           end
  1377           end
  1380     val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps;
  1378     val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps;
  1381   in
  1379   in
  1382     mk_bind compfuns (mk_single compfuns inp, prem_t)
  1380     mk_bind compfuns (mk_single compfuns inp, prem_t)
  1461 	val (args, argnames) = fold_map mk_args (1 upto (length Ts2) ~~ Ts2) []
  1459 	val (args, argnames) = fold_map mk_args (1 upto (length Ts2) ~~ Ts2) []
  1462   val (inargs, outargs) = split_smode is args
  1460   val (inargs, outargs) = split_smode is args
  1463   val param_names' = Name.variant_list (param_names @ argnames)
  1461   val param_names' = Name.variant_list (param_names @ argnames)
  1464     (map (fn i => "p" ^ string_of_int i) (1 upto (length iss)))
  1462     (map (fn i => "p" ^ string_of_int i) (1 upto (length iss)))
  1465   val param_vs = map Free (param_names' ~~ Ts1)
  1463   val param_vs = map Free (param_names' ~~ Ts1)
  1466   val (params', names) = fold_map (mk_Eval_of NONE) ((params ~~ Ts1) ~~ iss) []
  1464   val (params', names) = fold_map (mk_Eval_of []) ((params ~~ Ts1) ~~ iss) []
  1467   val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ args))
  1465   val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ args))
  1468   val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ args))
  1466   val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ args))
  1469   val param_eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (param_vs ~~ params')
  1467   val param_eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (param_vs ~~ params')
  1470   val funargs = params @ inargs
  1468   val funargs = params @ inargs
  1471   val funpropE = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs),
  1469   val funpropE = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs),
  1564 						 in (((if null Tins then [] else [xin], if null Touts then [] else [xout]), xarg), name_in :: name_out :: names) end
  1562 						 in (((if null Tins then [] else [xin], if null Touts then [] else [xout]), xarg), name_in :: name_out :: names) end
  1565 						 end
  1563 						 end
  1566    	  val (xinoutargs, names) = fold_map mk_vars ((1 upto (length Ts2)) ~~ Ts2) param_names
  1564    	  val (xinoutargs, names) = fold_map mk_vars ((1 upto (length Ts2)) ~~ Ts2) param_names
  1567       val (xinout, xargs) = split_list xinoutargs
  1565       val (xinout, xargs) = split_list xinoutargs
  1568 			val (xins, xouts) = pairself flat (split_list xinout)
  1566 			val (xins, xouts) = pairself flat (split_list xinout)
  1569 			val (xparams', names') = fold_map (mk_Eval_of NONE) ((xparams ~~ Ts1) ~~ iss) names
  1567 			val (xparams', names') = fold_map (mk_Eval_of []) ((xparams ~~ Ts1) ~~ iss) names
  1570       fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
  1568       fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
  1571         | mk_split_lambda [x] t = lambda x t
  1569         | mk_split_lambda [x] t = lambda x t
  1572         | mk_split_lambda xs t =
  1570         | mk_split_lambda xs t =
  1573         let
  1571         let
  1574           fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
  1572           fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))