# HG changeset patch # User haftmann # Date 1485552423 -3600 # Node ID 9ca021bd718da3280a1b860fd7efbcf70b949b51 # Parent 85b87da722ab6ec6f3f8086c00bddca18f672338 ML antiquotation for generated computations diff -r 85b87da722ab -r 9ca021bd718d src/HOL/ROOT --- a/src/HOL/ROOT Thu Jan 26 16:25:32 2017 +0100 +++ b/src/HOL/ROOT Fri Jan 27 22:27:03 2017 +0100 @@ -536,6 +536,7 @@ "~~/src/HOL/Library/Refute" "~~/src/HOL/Library/Transitive_Closure_Table" Cartouche_Examples + Computations theories Commands Adhoc_Overloading_Examples diff -r 85b87da722ab -r 9ca021bd718d src/HOL/ex/Computations.thy --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/ex/Computations.thy Fri Jan 27 22:27:03 2017 +0100 @@ -0,0 +1,63 @@ +(* Title: HOL/ex/Computations.thy + Author: Florian Haftmann, TU Muenchen +*) + +section \Simple example for computations generated by the code generator\ + +theory Computations + imports Nat Fun_Def Num Code_Numeral +begin + +fun even :: "nat \ bool" + where "even 0 \ True" + | "even (Suc 0) \ False" + | "even (Suc (Suc n)) \ even n" + +fun fib :: "nat \ nat" + where "fib 0 = 0" + | "fib (Suc 0) = Suc 0" + | "fib (Suc (Suc n)) = fib (Suc n) + fib n" + +declare [[ML_source_trace]] + +ML \ +local + +fun int_of_nat @{code "0 :: nat"} = 0 + | int_of_nat (@{code Suc} n) = int_of_nat n + 1; + +in + +val comp_nat = @{computation "0 :: nat" Suc + "plus :: nat \_" "times :: nat \ _" fib :: nat} + (fn post => post o HOLogic.mk_nat o int_of_nat); + +val comp_numeral = @{computation "0 :: nat" "1 :: nat" "2 :: nat" "3 :: nat" :: nat} + (fn post => post o HOLogic.mk_nat o int_of_nat); + +val comp_bool = @{computation True False HOL.conj HOL.disj HOL.implies + HOL.iff even "less_eq :: nat \ _" "less :: nat \ _" "HOL.eq :: nat \ _" :: bool } + (K I) + +end +\ + +declare [[ML_source_trace = false]] + +ML_val \ + comp_nat @{context} @{term "fib (Suc (Suc (Suc 0)) * Suc (Suc (Suc 0))) + Suc 0"} + |> Syntax.string_of_term @{context} + |> writeln +\ + +ML_val \ + comp_bool @{context} @{term "fib (Suc (Suc (Suc 0)) * Suc (Suc (Suc 0))) + Suc 0 < fib (Suc (Suc 0))"} +\ + +ML_val \ + comp_numeral @{context} @{term "Suc 42 + 7"} + |> Syntax.string_of_term @{context} + |> writeln +\ + +end diff -r 85b87da722ab -r 9ca021bd718d src/Tools/Code/code_runtime.ML --- a/src/Tools/Code/code_runtime.ML Thu Jan 26 16:25:32 2017 +0100 +++ b/src/Tools/Code/code_runtime.ML Fri Jan 27 22:27:03 2017 +0100 @@ -28,10 +28,8 @@ -> Proof.context -> term -> 'a Exn.result val dynamic_holds_conv: Proof.context -> conv val static_holds_conv: { ctxt: Proof.context, consts: string list } -> Proof.context -> conv - val experimental_computation: (term -> 'a) cookie - -> { ctxt: Proof.context, lift_postproc: (term -> term) -> 'a -> 'a, - terms: term list, T: typ } - -> Proof.context -> term -> 'a (*EXPERIMENTAL!*) + val mount_computation: Proof.context -> (string * typ) list -> typ + -> (term -> 'b) -> ((term -> term) -> 'b -> 'a) -> Proof.context -> term -> 'a val code_reflect: (string * string list option) list -> string list -> string -> string option -> theory -> theory datatype truth = Holds @@ -49,9 +47,9 @@ (* technical prerequisites *) -val this = "Code_Runtime"; -val s_truth = Long_Name.append this "truth"; -val s_Holds = Long_Name.append this "Holds"; +val thisN = "Code_Runtime"; +val truthN = Long_Name.append thisN "truth"; +val HoldsN = Long_Name.append thisN "Holds"; val target = "Eval"; @@ -60,10 +58,10 @@ val _ = Theory.setup (Code_Target.add_derived_target (target, [(Code_ML.target_SML, I)]) #> Code_Target.set_printings (Type_Constructor (@{type_name prop}, - [(target, SOME (0, (K o K o K) (Code_Printer.str s_truth)))])) + [(target, SOME (0, (K o K o K) (Code_Printer.str truthN)))])) #> Code_Target.set_printings (Constant (@{const_name Code_Generator.holds}, - [(target, SOME (Code_Printer.plain_const_syntax s_Holds))])) - #> Code_Target.add_reserved target this + [(target, SOME (Code_Printer.plain_const_syntax HoldsN))])) + #> Code_Target.add_reserved target thisN #> fold (Code_Target.add_reserved target) ["oo", "ooo", "oooo", "upto", "downto", "orf", "andf"]); (*avoid further pervasive infix names*) @@ -153,7 +151,7 @@ fun init _ = empty; ); val put_truth = Truth_Result.put; -val truth_cookie = (Truth_Result.get, put_truth, Long_Name.append this "put_truth"); +val truth_cookie = (Truth_Result.get, put_truth, Long_Name.append thisN "put_truth"); local @@ -231,6 +229,8 @@ in +val covered_constsN = "covered_consts"; + fun of_term_for_typ Ts = let val names = Ts @@ -277,6 +277,8 @@ (* code generation for of the universal morphism *) +val print_const = ML_Syntax.print_pair ML_Syntax.print_string ML_Syntax.print_typ; + fun print_of_term_funs { typ_sign_for, eval_for_const, of_term_for_typ } Ts = let val var_names = map_range (fn n => "t" ^ string_of_int (n + 1)); @@ -305,42 +307,47 @@ |> prefix "fun " end; -fun print_computation_code ctxt compiled_value requested_Ts cTs = - let - val proper_cTs = map_filter I cTs; - val typ_sign_for = typ_signatures proper_cTs; - fun add_typ T Ts = - if member (op =) Ts T - then Ts - else Ts - |> cons T - |> fold (fold add_typ o snd) (typ_sign_for T); - val required_Ts = fold add_typ requested_Ts []; - val of_term_for_typ' = of_term_for_typ required_Ts; - val eval_for_const' = eval_for_const ctxt proper_cTs; - val eval_for_const'' = the_default "_" o Option.map eval_for_const'; - val eval_tuple = enclose "(" ")" (commas (map eval_for_const' proper_cTs)); - fun mk_abs s = "fn " ^ s ^ " => "; - val eval_abs = space_implode "" - (map (mk_abs o eval_for_const'') cTs); - val of_term_code = print_of_term_funs { - typ_sign_for = typ_sign_for, - eval_for_const = eval_for_const', - of_term_for_typ = of_term_for_typ' } required_Ts; - in - (cat_lines [ - "structure " ^ generated_computationN ^ " =", - "struct", - "", - "val " ^ eval_tuple ^ " = " ^ compiled_value ^ " ()", - " (" ^ eval_abs, - " " ^ eval_tuple ^ ");", - "", - of_term_code, - "", - "end" - ], map (prefix (generated_computationN ^ ".") o of_term_for_typ') requested_Ts) - end; +fun print_computation_code ctxt compiled_value [] requested_Ts = ("", []) + | print_computation_code ctxt compiled_value cTs requested_Ts = + let + val proper_cTs = map_filter I cTs; + val typ_sign_for = typ_signatures proper_cTs; + fun add_typ T Ts = + if member (op =) Ts T + then Ts + else Ts + |> cons T + |> fold (fold add_typ o snd) (typ_sign_for T); + val required_Ts = fold add_typ requested_Ts []; + val of_term_for_typ' = of_term_for_typ required_Ts; + val eval_for_const' = eval_for_const ctxt proper_cTs; + val eval_for_const'' = the_default "_" o Option.map eval_for_const'; + val eval_tuple = enclose "(" ")" (commas (map eval_for_const' proper_cTs)); + fun mk_abs s = "fn " ^ s ^ " => "; + val eval_abs = space_implode "" + (map (mk_abs o eval_for_const'') cTs); + val of_term_code = print_of_term_funs { + typ_sign_for = typ_sign_for, + eval_for_const = eval_for_const', + of_term_for_typ = of_term_for_typ' } required_Ts; + val of_term_names = map (Long_Name.append generated_computationN + o of_term_for_typ') requested_Ts; + in + cat_lines [ + "structure " ^ generated_computationN ^ " =", + "struct", + "", + "val " ^ covered_constsN ^ " = " ^ ML_Syntax.print_list print_const proper_cTs ^ ";", + "", + "val " ^ eval_tuple ^ " = " ^ compiled_value ^ " ()", + " (" ^ eval_abs, + " " ^ eval_tuple ^ ");", + "", + of_term_code, + "", + "end" + ] |> rpair of_term_names + end; fun mount_computation ctxt cTs T raw_computation lift_postproc = Code_Preproc.static_value { ctxt = ctxt, lift_postproc = lift_postproc, consts = [] } @@ -350,48 +357,12 @@ #> check_computation_input ctxt' cTs #> raw_computation); -fun compile_computation cookie ctxt T program evals vs_ty_evals deps = - let - fun the_const (Const cT) = cT - | the_const t = error ("No constant after preprocessing: " ^ Syntax.string_of_term ctxt t) - val raw_cTs = case evals of - Abs (_, _, t) => (map the_const o snd o strip_comb) t - | _ => error ("Bad term after preprocessing: " ^ Syntax.string_of_term ctxt evals); - val cTs = fold_rev (fn cT => fn cTs => - (if member (op =) cTs (SOME cT) then NONE else SOME cT) :: cTs) raw_cTs []; - fun comp' vs_ty_evals = - let - val (generated_code, compiled_value) = - build_compilation_text ctxt NONE program deps vs_ty_evals; - val (of_term_code, [of_term]) = print_computation_code ctxt compiled_value [T] cTs; - in - (generated_code ^ "\n" ^ of_term_code, - enclose "(" ")" ("fn () => " ^ of_term)) - end; - val compiled_computation = - Exn.release (run_compilation_text cookie ctxt comp' vs_ty_evals []); - in (map_filter I cTs, compiled_computation) end; + +(** variants of universal runtime code generation **) -fun experimental_computation cookie { ctxt, lift_postproc, terms = ts, T } = - let - val _ = if not (monomorphic T) - then error ("Polymorphic type: " ^ Syntax.string_of_typ ctxt T) - else (); - val cTs = (fold o fold_aterms) - (fn (t as Const (cT as (_, T))) => - if not (monomorphic T) then error ("Polymorphic constant:" ^ Syntax.string_of_term ctxt t) - else insert (op =) cT | _ => I) ts []; - val evals = Abs ("eval", map snd cTs ---> TFree (Name.aT, []), list_comb (Bound 0, map Const cTs)); - val (cTs, raw_computation) = Code_Thingol.dynamic_value ctxt - (K I) (compile_computation cookie ctxt T) evals; - in - mount_computation ctxt cTs T raw_computation lift_postproc - end; +(*FIXME consolidate variants*) - -(** code antiquotation **) - -fun runtime_code ctxt module_name program tycos consts = +fun runtime_code'' ctxt module_name program tycos consts = let val thy = Proof_Context.theory_of ctxt; val (ml_modules, target_names) = @@ -411,63 +382,159 @@ | SOME tyco' => (tyco, tyco')) tycos tycos'; in (ml_code, (tycos_map, consts_map)) end; +fun runtime_code' ctxt some_module_name named_tycos named_consts computation_Ts program evals vs_ty_evals deps = + let + val thy = Proof_Context.theory_of ctxt; + fun the_const (Const cT) = cT + | the_const t = error ("No constant after preprocessing: " ^ Syntax.string_of_term ctxt t) + val raw_computation_cTs = case evals of + Abs (_, _, t) => (map the_const o snd o strip_comb) t + | _ => error ("Bad term after preprocessing: " ^ Syntax.string_of_term ctxt evals); + val computation_cTs = fold_rev (fn cT => fn cTs => + (if member (op =) cTs (SOME cT) then NONE else SOME cT) :: cTs) raw_computation_cTs []; + val consts' = fold (fn NONE => I | SOME (c, _) => insert (op =) c) + computation_cTs named_consts; + val program' = Code_Thingol.consts_program ctxt consts'; + (*FIXME insufficient interfaces require double invocation of code generator*) + val ((ml_modules, compiled_value), deresolve) = + Code_Target.compilation_text' ctxt target some_module_name program' + (map Code_Symbol.Type_Constructor named_tycos @ map Code_Symbol.Constant consts' @ deps) true vs_ty_evals; + (*FIXME constrain signature*) + fun deresolve_const c = case (deresolve o Code_Symbol.Constant) c of + NONE => error ("Constant " ^ (quote o Code.string_of_const thy) c ^ + "\nhas a user-defined serialization") + | SOME c' => c'; + fun deresolve_tyco tyco = case (deresolve o Code_Symbol.Type_Constructor) tyco of + NONE => error ("Type " ^ quote (Proof_Context.markup_type ctxt tyco) ^ + "\nhas a user-defined serialization") + | SOME c' => c'; + val tyco_names = map deresolve_const named_tycos; + val const_names = map deresolve_const named_consts; + val generated_code = space_implode "\n\n" (map snd ml_modules); + val (of_term_code, of_term_names) = + print_computation_code ctxt compiled_value computation_cTs computation_Ts; + val compiled_computation = generated_code ^ "\n" ^ of_term_code; + in + compiled_computation + |> rpair { tyco_map = named_tycos ~~ tyco_names, + const_map = named_consts ~~ const_names, + of_term_map = computation_Ts ~~ of_term_names } + end; + +fun funs_of_maps { tyco_map, const_map, of_term_map } = + { name_for_tyco = the o AList.lookup (op =) tyco_map, + name_for_const = the o AList.lookup (op =) const_map, + of_term_for_typ = the o AList.lookup (op =) of_term_map }; + +fun runtime_code ctxt some_module_name named_tycos named_consts computation_Ts program evals vs_ty_evals deps = + runtime_code' ctxt some_module_name named_tycos named_consts computation_Ts program evals vs_ty_evals deps + ||> funs_of_maps; + + +(** code and computation antiquotations **) + +val mount_computationN = Long_Name.append thisN "mount_computation"; + local structure Code_Antiq_Data = Proof_Data ( type T = { named_consts: string list, - first_occurrence: bool, - generated_code: { - code: string, - name_for_const: string -> string - } lazy + computation_Ts: typ list, computation_cTs: (string * typ) list, + position_index: int, + generated_code: (string * { + name_for_tyco: string -> string, + name_for_const: string -> string, + of_term_for_typ: typ -> string + }) lazy }; val empty: T = { named_consts = [], - first_occurrence = true, - generated_code = Lazy.value { - code = "", - name_for_const = I - } + computation_Ts = [], computation_cTs = [], + position_index = 0, + generated_code = Lazy.lazy (fn () => raise Fail "empty") }; fun init _ = empty; ); -val is_first_occurrence = #first_occurrence o Code_Antiq_Data.get; +val current_position_index = #position_index o Code_Antiq_Data.get; -fun lazy_code ctxt consts = Lazy.lazy (fn () => +fun register { named_consts, computation_Ts, computation_cTs } ctxt = let - val program = Code_Thingol.consts_program ctxt consts; - val (code, (_, consts_map)) = - runtime_code ctxt Code_Target.generatedN program [] consts - in { code = code, name_for_const = the o AList.lookup (op =) consts_map } end); - -fun register_const const ctxt = - let - val consts = insert (op =) const ((#named_consts o Code_Antiq_Data.get) ctxt); + val data = Code_Antiq_Data.get ctxt; + val named_consts' = union (op =) named_consts (#named_consts data); + val computation_Ts' = union (op =) computation_Ts (#computation_Ts data); + val computation_cTs' = union (op =) computation_cTs (#computation_cTs data); + val position_index' = #position_index data + 1; + fun generated_code' () = + let + val evals = Abs ("eval", map snd computation_cTs' ---> + TFree (Name.aT, []), list_comb (Bound 0, map Const computation_cTs')); + in Code_Thingol.dynamic_value ctxt + (K I) (runtime_code ctxt NONE [] named_consts' computation_Ts') evals + end; in ctxt |> Code_Antiq_Data.put { - named_consts = consts, - first_occurrence = false, - generated_code = lazy_code ctxt consts + named_consts = named_consts', + computation_Ts = computation_Ts', + computation_cTs = computation_cTs', + position_index = position_index', + generated_code = Lazy.lazy generated_code' } end; -fun print_code is_first_occ const ctxt = +fun register_const const = + register { named_consts = [const], + computation_Ts = [], + computation_cTs = [] }; + +fun register_computation cTs T = + register { named_consts = [], + computation_Ts = [T], + computation_cTs = cTs }; + +fun print body_code_for ctxt ctxt' = let - val { code, name_for_const } = (Lazy.force o #generated_code o Code_Antiq_Data.get) ctxt; - val context_code = if is_first_occ then code else ""; - val body_code = ML_Context.struct_name ctxt ^ "." ^ name_for_const const; + val position_index = current_position_index ctxt; + val (code, name_ofs) = (Lazy.force o #generated_code o Code_Antiq_Data.get) ctxt'; + val context_code = if position_index = 0 then code else ""; + val body_code = body_code_for name_ofs (ML_Context.struct_name ctxt') position_index; in (context_code, body_code) end; +fun print_code ctxt const = + print (fn { name_for_const, ... } => fn prfx => fn _ => + Long_Name.append prfx (name_for_const const)) ctxt; + +fun print_computation ctxt T = + print (fn { of_term_for_typ, ... } => fn prfx => fn _ => + space_implode " " [ + mount_computationN, + "(Context.proof_of (Context.the_generic_context ()))", + Long_Name.implode [prfx, generated_computationN, covered_constsN], + (ML_Syntax.atomic o ML_Syntax.print_typ) T, + Long_Name.append prfx (of_term_for_typ T) + ]) ctxt; + in fun ml_code_antiq raw_const ctxt = let val thy = Proof_Context.theory_of ctxt; val const = Code.check_const thy raw_const; - val is_first = is_first_occurrence ctxt; - in (print_code is_first const, register_const const ctxt) end; + in (print_code ctxt const, register_const const ctxt) end; + +fun ml_computation_antiq (raw_ts, raw_T) ctxt = + let + val ts = map (Syntax.check_term ctxt) raw_ts; + val T = Syntax.check_typ ctxt raw_T; + val _ = if not (monomorphic T) + then error ("Polymorphic type: " ^ Syntax.string_of_typ ctxt T) + else (); + val cTs = (fold o fold_aterms) + (fn (t as Const (cT as (_, T))) => + if not (monomorphic T) then error ("Polymorphic constant: " ^ Syntax.string_of_term ctxt t) + else insert (op =) cT | _ => I) ts []; + in (print_computation ctxt T, register_computation cTs T ctxt) end; end; (*local*) @@ -548,7 +615,7 @@ val functions = map (prep_const thy) raw_functions; val consts = constrs @ functions; val program = Code_Thingol.consts_program ctxt consts; - val result = runtime_code ctxt module_name program tycos consts + val result = runtime_code'' ctxt module_name program tycos consts |> (apsnd o apsnd) (chop (length constrs)); in thy @@ -562,7 +629,12 @@ (** Isar setup **) val _ = - Theory.setup (ML_Antiquotation.declaration @{binding code} Args.term (fn _ => ml_code_antiq)); + Theory.setup (ML_Antiquotation.declaration @{binding code} + Args.term (fn _ => ml_code_antiq)); + +val _ = + Theory.setup (ML_Antiquotation.declaration @{binding computation} + (Scan.repeat Args.term --| Scan.lift (Args.$$$ "::") -- Args.typ) (fn _ => ml_computation_antiq)); local diff -r 85b87da722ab -r 9ca021bd718d src/Tools/Code/code_target.ML --- a/src/Tools/Code/code_target.ML Thu Jan 26 16:25:32 2017 +0100 +++ b/src/Tools/Code/code_target.ML Fri Jan 27 22:27:03 2017 +0100 @@ -31,6 +31,9 @@ val compilation_text: Proof.context -> string -> Code_Thingol.program -> Code_Symbol.T list -> bool -> ((string * class list) list * Code_Thingol.itype) * Code_Thingol.iterm -> (string * string) list * string + val compilation_text': Proof.context -> string -> string option -> Code_Thingol.program + -> Code_Symbol.T list -> bool -> ((string * class list) list * Code_Thingol.itype) * Code_Thingol.iterm + -> ((string * string) list * string) * (Code_Symbol.T -> string option) type serializer type literals = Code_Printer.literals @@ -414,17 +417,20 @@ val (program_code, deresolve) = produce (mounted_serializer program (if all_public then [] else [Code_Symbol.value])); val value_name = the (deresolve Code_Symbol.value); - in (program_code, value_name) end; + in ((program_code, value_name), deresolve) end; -fun compilation_text ctxt target_name program syms = +fun compilation_text' ctxt target_name some_module_name program syms = let val (mounted_serializer, (_, prepared_program)) = - mount_serializer ctxt target_name NONE generatedN [] program syms; + mount_serializer ctxt target_name NONE (the_default generatedN some_module_name) [] program syms; in Code_Preproc.timed_exec "serializing" (fn () => dynamic_compilation_text mounted_serializer prepared_program syms) ctxt end; +fun compilation_text ctxt target_name program syms = + fst oo compilation_text' ctxt target_name NONE program syms + end; (* local *)