src/HOL/Matrix_LP/Compute_Oracle/compute.ML
changeset 77869 1156aa9db7f5
parent 77863 760515c45864
child 78795 f7e972d567f3
equal deleted inserted replaced
77868:6ea0030b9ee9 77869:1156aa9db7f5
    15     exception Make of string
    15     exception Make of string
    16     val make : machine -> theory -> thm list -> computer
    16     val make : machine -> theory -> thm list -> computer
    17     val make_with_cache : machine -> theory -> term list -> thm list -> computer
    17     val make_with_cache : machine -> theory -> term list -> thm list -> computer
    18     val theory_of : computer -> theory
    18     val theory_of : computer -> theory
    19     val hyps_of : computer -> term list
    19     val hyps_of : computer -> term list
    20     val shyps_of : computer -> Sortset.T
    20     val shyps_of : computer -> sort list
    21     (* ! *) val update : computer -> thm list -> unit
    21     (* ! *) val update : computer -> thm list -> unit
    22     (* ! *) val update_with_cache : computer -> term list -> thm list -> unit
    22     (* ! *) val update_with_cache : computer -> term list -> thm list -> unit
    23     
    23     
    24     (* ! *) val set_naming : computer -> naming -> unit
    24     (* ! *) val set_naming : computer -> naming -> unit
    25     val naming_of : computer -> naming
    25     val naming_of : computer -> naming
   167 type naming = int -> string
   167 type naming = int -> string
   168 
   168 
   169 fun default_naming i = "v_" ^ string_of_int i
   169 fun default_naming i = "v_" ^ string_of_int i
   170 
   170 
   171 datatype computer = Computer of
   171 datatype computer = Computer of
   172   (theory * Encode.encoding * term list * Sortset.T * prog * unit Unsynchronized.ref * naming)
   172   (theory * Encode.encoding * term list * unit Sorttab.table * prog * unit Unsynchronized.ref * naming)
   173     option Unsynchronized.ref
   173     option Unsynchronized.ref
   174 
   174 
   175 fun theory_of (Computer (Unsynchronized.ref (SOME (thy,_,_,_,_,_,_)))) = thy
   175 fun theory_of (Computer (Unsynchronized.ref (SOME (thy,_,_,_,_,_,_)))) = thy
   176 fun hyps_of (Computer (Unsynchronized.ref (SOME (_,_,hyps,_,_,_,_)))) = hyps
   176 fun hyps_of (Computer (Unsynchronized.ref (SOME (_,_,hyps,_,_,_,_)))) = hyps
   177 fun shyps_of (Computer (Unsynchronized.ref (SOME (_,_,_,shypset,_,_,_)))) = shypset
   177 fun shyps_of (Computer (Unsynchronized.ref (SOME (_,_,_,shyptable,_,_,_)))) = Sorttab.keys (shyptable)
       
   178 fun shyptab_of (Computer (Unsynchronized.ref (SOME (_,_,_,shyptable,_,_,_)))) = shyptable
   178 fun stamp_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,_,stamp,_)))) = stamp
   179 fun stamp_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,_,stamp,_)))) = stamp
   179 fun prog_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,prog,_,_)))) = prog
   180 fun prog_of (Computer (Unsynchronized.ref (SOME (_,_,_,_,prog,_,_)))) = prog
   180 fun encoding_of (Computer (Unsynchronized.ref (SOME (_,encoding,_,_,_,_,_)))) = encoding
   181 fun encoding_of (Computer (Unsynchronized.ref (SOME (_,encoding,_,_,_,_,_)))) = encoding
   181 fun set_encoding (Computer (r as Unsynchronized.ref (SOME (p1,_,p2,p3,p4,p5,p6)))) encoding' = 
   182 fun set_encoding (Computer (r as Unsynchronized.ref (SOME (p1,_,p2,p3,p4,p5,p6)))) encoding' = 
   182     (r := SOME (p1,encoding',p2,p3,p4,p5,p6))
   183     (r := SOME (p1,encoding',p2,p3,p4,p5,p6))
   185     (r := SOME (p1,p2,p3,p4,p5,p6,naming'))
   186     (r := SOME (p1,p2,p3,p4,p5,p6,naming'))
   186 
   187 
   187 fun ref_of (Computer r) = r
   188 fun ref_of (Computer r) = r
   188 
   189 
   189 
   190 
   190 datatype cthm = ComputeThm of term list * Sortset.T * term
   191 datatype cthm = ComputeThm of term list * sort list * term
   191 
   192 
   192 fun thm2cthm th = 
   193 fun thm2cthm th = 
   193     (if not (null (Thm.tpairs_of th)) then raise Make "theorems may not contain tpairs" else ();
   194     (if not (null (Thm.tpairs_of th)) then raise Make "theorems may not contain tpairs" else ();
   194      ComputeThm (Thm.hyps_of th, Thm.shyps_of th, Thm.prop_of th))
   195      ComputeThm (Thm.hyps_of th, Thm.shyps_of th, Thm.prop_of th))
   195 
   196 
   217                     raise (Make "patterns may not start with a variable")
   218                     raise (Make "patterns may not start with a variable")
   218                   | AbstractMachine.PConst (c, args) =>
   219                   | AbstractMachine.PConst (c, args) =>
   219                     (n, vars, AbstractMachine.PConst (c, args@[pb]))
   220                     (n, vars, AbstractMachine.PConst (c, args@[pb]))
   220             end
   221             end
   221 
   222 
   222         fun thm2rule (encoding, hyptable, shypset) th =
   223         fun thm2rule (encoding, hyptable, shyptable) th =
   223             let
   224             let
   224                 val (ComputeThm (hyps, shyps, prop)) = th
   225                 val (ComputeThm (hyps, shyps, prop)) = th
   225                 val hyptable = fold (fn h => Termtab.update (h, ())) hyps hyptable
   226                 val hyptable = fold (fn h => Termtab.update (h, ())) hyps hyptable
   226                 val shypset = Sortset.merge (shyps, shypset)
   227                 val shyptable = fold (fn sh => Sorttab.update (sh, ())) shyps shyptable
   227                 val (prems, prop) = (Logic.strip_imp_prems prop, Logic.strip_imp_concl prop)
   228                 val (prems, prop) = (Logic.strip_imp_prems prop, Logic.strip_imp_concl prop)
   228                 val (a, b) = Logic.dest_equals prop
   229                 val (a, b) = Logic.dest_equals prop
   229                   handle TERM _ => raise (Make "theorems must be meta-level equations (with optional guards)")
   230                   handle TERM _ => raise (Make "theorems must be meta-level equations (with optional guards)")
   230                 val a = Envir.eta_contract a
   231                 val a = Envir.eta_contract a
   231                 val b = Envir.eta_contract b
   232                 val b = Envir.eta_contract b
   267                     AbstractMachine.Abs (rename (level+1) vars m)
   268                     AbstractMachine.Abs (rename (level+1) vars m)
   268                     
   269                     
   269                 fun rename_guard (AbstractMachine.Guard (a,b)) = 
   270                 fun rename_guard (AbstractMachine.Guard (a,b)) = 
   270                     AbstractMachine.Guard (rename 0 vars a, rename 0 vars b)
   271                     AbstractMachine.Guard (rename 0 vars a, rename 0 vars b)
   271             in
   272             in
   272                 ((encoding, hyptable, shypset), (map rename_guard prems, pattern, rename 0 vars right))
   273                 ((encoding, hyptable, shyptable), (map rename_guard prems, pattern, rename 0 vars right))
   273             end
   274             end
   274 
   275 
   275         val ((encoding, hyptable, shypset), rules) =
   276         val ((encoding, hyptable, shyptable), rules) =
   276           fold_rev (fn th => fn (encoding_hyptable, rules) =>
   277           fold_rev (fn th => fn (encoding_hyptable, rules) =>
   277             let
   278             let
   278               val (encoding_hyptable, rule) = thm2rule encoding_hyptable th
   279               val (encoding_hyptable, rule) = thm2rule encoding_hyptable th
   279             in (encoding_hyptable, rule::rules) end)
   280             in (encoding_hyptable, rule::rules) end)
   280           ths ((encoding, Termtab.empty, Sortset.empty), [])
   281           ths ((encoding, Termtab.empty, Sorttab.empty), [])
   281 
   282 
   282         fun make_cache_pattern t (encoding, cache_patterns) =
   283         fun make_cache_pattern t (encoding, cache_patterns) =
   283             let
   284             let
   284                 val (encoding, a) = remove_types encoding t
   285                 val (encoding, a) = remove_types encoding t
   285                 val (_,_,p) = make_pattern encoding 0 Inttab.empty a
   286                 val (_,_,p) = make_pattern encoding 0 Inttab.empty a
   296               | HASKELL => ProgHaskell (AM_GHC.compile rules)
   297               | HASKELL => ProgHaskell (AM_GHC.compile rules)
   297               | SML => ProgSML (AM_SML.compile rules)
   298               | SML => ProgSML (AM_SML.compile rules)
   298 
   299 
   299         fun has_witness s = not (null (Sign.witness_sorts thy [] [s]))
   300         fun has_witness s = not (null (Sign.witness_sorts thy [] [s]))
   300 
   301 
   301         val shypset = Sortset.fold (fn s => has_witness s ? Sortset.remove s) shypset shypset
   302         val shyptable = fold Sorttab.delete (filter has_witness (Sorttab.keys (shyptable))) shyptable
   302 
   303 
   303     in (thy, encoding, Termtab.keys hyptable, shypset, prog, stamp, default_naming) end
   304     in (thy, encoding, Termtab.keys hyptable, shyptable, prog, stamp, default_naming) end
   304 
   305 
   305 fun make_with_cache machine thy cache_patterns raw_thms =
   306 fun make_with_cache machine thy cache_patterns raw_thms =
   306   Computer (Unsynchronized.ref (SOME (make_internal machine thy (Unsynchronized.ref ()) Encode.empty cache_patterns raw_thms)))
   307   Computer (Unsynchronized.ref (SOME (make_internal machine thy (Unsynchronized.ref ()) Encode.empty cache_patterns raw_thms)))
   307 
   308 
   308 fun make machine thy raw_thms = make_with_cache machine thy [] raw_thms
   309 fun make machine thy raw_thms = make_with_cache machine thy [] raw_thms
   332     fun add hyps tab = fold (fn h => fn tab => Termtab.update (h, ()) tab) hyps tab
   333     fun add hyps tab = fold (fn h => fn tab => Termtab.update (h, ()) tab) hyps tab
   333 in
   334 in
   334     Termtab.keys (add hyps2 (add hyps1 Termtab.empty))
   335     Termtab.keys (add hyps2 (add hyps1 Termtab.empty))
   335 end
   336 end
   336 
   337 
       
   338 fun add_shyps shyps tab = fold (fn h => fn tab => Sorttab.update (h, ()) tab) shyps tab
       
   339 
       
   340 fun merge_shyps shyps1 shyps2 = Sorttab.keys (add_shyps shyps2 (add_shyps shyps1 Sorttab.empty))
       
   341 
   337 val (_, export_oracle) = Context.>>> (Context.map_theory_result
   342 val (_, export_oracle) = Context.>>> (Context.map_theory_result
   338   (Thm.add_oracle (\<^binding>\<open>compute\<close>, fn (thy, hyps, shyps, prop) =>
   343   (Thm.add_oracle (\<^binding>\<open>compute\<close>, fn (thy, hyps, shyps, prop) =>
   339     let
   344     let
   340         fun remove_term t = Sortset.subtract (Sortset.build (Sortset.insert_term t))
   345         val shyptab = add_shyps shyps Sorttab.empty
       
   346         fun delete s shyptab = Sorttab.delete s shyptab handle Sorttab.UNDEF _ => shyptab
       
   347         fun delete_term t shyptab = fold delete (Sorts.insert_term t []) shyptab
   341         fun has_witness s = not (null (Sign.witness_sorts thy [] [s]))
   348         fun has_witness s = not (null (Sign.witness_sorts thy [] [s]))
   342         val shyps = Sortset.fold (fn s => has_witness s ? Sortset.remove s) shyps shyps
   349         val shyptab = fold Sorttab.delete (filter has_witness (Sorttab.keys (shyptab))) shyptab
   343         val shyps =
   350         val shyps = if Sorttab.is_empty shyptab then [] else Sorttab.keys (fold delete_term (prop::hyps) shyptab)
   344           if Sortset.is_empty shyps then Sortset.empty
       
   345           else fold remove_term (prop::hyps) shyps
       
   346         val _ =
   351         val _ =
   347           if not (Sortset.is_empty shyps) then
   352           if not (null shyps) then
   348             raise Compute ("dangling sort hypotheses: " ^
   353             raise Compute ("dangling sort hypotheses: " ^
   349               commas (map (Syntax.string_of_sort_global thy) (Sortset.dest shyps)))
   354               commas (map (Syntax.string_of_sort_global thy) shyps))
   350           else ()
   355           else ()
   351     in
   356     in
   352         Thm.global_cterm_of thy (fold_rev (fn hyp => fn p => Logic.mk_implies (hyp, p)) hyps prop)
   357         Thm.global_cterm_of thy (fold_rev (fn hyp => fn p => Logic.mk_implies (hyp, p)) hyps prop)
   353     end)));
   358     end)));
   354 
   359 
   372         val (encoding, t) = remove_types (encoding_of computer) t'
   377         val (encoding, t) = remove_types (encoding_of computer) t'
   373         val t = runprog (prog_of computer) t
   378         val t = runprog (prog_of computer) t
   374         val t = infer_types naming encoding ty t
   379         val t = infer_types naming encoding ty t
   375         val eq = Logic.mk_equals (t', t)
   380         val eq = Logic.mk_equals (t', t)
   376     in
   381     in
   377         export_thm thy (hyps_of computer) (shyps_of computer) eq
   382         export_thm thy (hyps_of computer) (Sorttab.keys (shyptab_of computer)) eq
   378     end
   383     end
   379 
   384 
   380 (* --------- Simplify ------------ *)
   385 (* --------- Simplify ------------ *)
   381 
   386 
   382 datatype prem = EqPrem of AbstractMachine.term * AbstractMachine.term * Term.typ * int 
   387 datatype prem = EqPrem of AbstractMachine.term * AbstractMachine.term * Term.typ * int 
   383               | Prem of AbstractMachine.term
   388               | Prem of AbstractMachine.term
   384 datatype theorem = Theorem of theory * unit Unsynchronized.ref * (int * typ) Symtab.table * (AbstractMachine.term option) Inttab.table  
   389 datatype theorem = Theorem of theory * unit Unsynchronized.ref * (int * typ) Symtab.table * (AbstractMachine.term option) Inttab.table  
   385                * prem list * AbstractMachine.term * term list * Sortset.T
   390                * prem list * AbstractMachine.term * term list * sort list
   386 
   391 
   387 
   392 
   388 exception ParamSimplify of computer * theorem
   393 exception ParamSimplify of computer * theorem
   389 
   394 
   390 fun make_theorem computer raw_th vars =
   395 fun make_theorem computer raw_th vars =
   612       | SOME varsubst =>
   617       | SOME varsubst =>
   613         let
   618         let
   614             val th = update_varsubst varsubst th
   619             val th = update_varsubst varsubst th
   615             val th = update_prems (splicein prem_no (prems_of_theorem th') prems) th
   620             val th = update_prems (splicein prem_no (prems_of_theorem th') prems) th
   616             val th = update_hyps (merge_hyps (hyps_of_theorem th) (hyps_of_theorem th')) th
   621             val th = update_hyps (merge_hyps (hyps_of_theorem th) (hyps_of_theorem th')) th
   617             val th = update_shyps (Sortset.merge (shyps_of_theorem th, shyps_of_theorem th')) th
   622             val th = update_shyps (merge_shyps (shyps_of_theorem th) (shyps_of_theorem th')) th
   618         in
   623         in
   619             update_theory thy th
   624             update_theory thy th
   620         end
   625         end
   621 end
   626 end
   622                      
   627                      
   629     fun infer t = infer_types naming encoding \<^typ>\<open>prop\<close> t
   634     fun infer t = infer_types naming encoding \<^typ>\<open>prop\<close> t
   630     fun run t = infer (runprog (prog_of computer) (apply_subst true varsubst t))
   635     fun run t = infer (runprog (prog_of computer) (apply_subst true varsubst t))
   631     fun runprem p = run (prem2term p)
   636     fun runprem p = run (prem2term p)
   632     val prop = Logic.list_implies (map runprem (prems_of_theorem th), run (concl_of_theorem th))
   637     val prop = Logic.list_implies (map runprem (prems_of_theorem th), run (concl_of_theorem th))
   633     val hyps = merge_hyps (hyps_of computer) (hyps_of_theorem th)
   638     val hyps = merge_hyps (hyps_of computer) (hyps_of_theorem th)
   634     val shyps = Sortset.merge (shyps_of_theorem th, shyps_of computer)
   639     val shyps = merge_shyps (shyps_of_theorem th) (Sorttab.keys (shyptab_of computer))
   635 in
   640 in
   636     export_thm (theory_of_theorem th) hyps shyps prop
   641     export_thm (theory_of_theorem th) hyps shyps prop
   637 end
   642 end
   638 
   643 
   639 end
   644 end
       
   645