src/Provers/Arith/assoc_fold.ML
changeset 13462 56610e2ba220
parent 12262 11ff5f47df6e
child 13480 bb72bd43c6c3
equal deleted inserted replaced
13461:f93f7d766895 13462:56610e2ba220
     9 *)
     9 *)
    10 
    10 
    11 
    11 
    12 signature ASSOC_FOLD_DATA =
    12 signature ASSOC_FOLD_DATA =
    13 sig
    13 sig
    14   val ss		: simpset	(*basic simpset of object-logtic*)
    14   val ss                : simpset       (*basic simpset of object-logtic*)
    15   val eq_reflection	: thm		(*object-equality to meta-equality*)
    15   val eq_reflection     : thm           (*object-equality to meta-equality*)
    16   val sg_ref 		: Sign.sg_ref	(*the operator's signature*)
    16   val sg_ref            : Sign.sg_ref   (*the operator's signature*)
    17   val T			: typ		(*the operator's numeric type*)
    17   val T                 : typ           (*the operator's numeric type*)
    18   val plus		: term		(*the operator being folded*)
    18   val plus              : term          (*the operator being folded*)
    19   val add_ac		: thm list      (*AC-rewrites for plus*)
    19   val add_ac            : thm list      (*AC-rewrites for plus*)
    20 end;
    20 end;
    21 
    21 
    22 
    22 
    23 functor Assoc_Fold (Data: ASSOC_FOLD_DATA) =
    23 functor Assoc_Fold (Data: ASSOC_FOLD_DATA) =
    24 struct
    24 struct
    25 
    25 
    26  val assoc_ss = Data.ss addsimps Data.add_ac;
    26  val assoc_ss = Data.ss addsimps Data.add_ac;
    27 
    27 
    28  (*prove while suppressing timing information*)
    28  (*prove while suppressing timing information*)
    29  fun prove name ct tacf = 
    29  fun prove name ct tacf =
    30      setmp Library.timing false (prove_goalw_cterm [] ct) tacf
    30      setmp Library.timing false (prove_goalw_cterm [] ct) tacf
    31      handle ERROR =>
    31      handle ERROR =>
    32 	 error(name ^ " simproc:\nfailed to prove " ^ string_of_cterm ct);
    32          error(name ^ " simproc:\nfailed to prove " ^ string_of_cterm ct);
    33                 
    33 
    34  exception Assoc_fail;
    34  exception Assoc_fail;
    35 
    35 
    36  fun mk_sum []  = raise Assoc_fail
    36  fun mk_sum []  = raise Assoc_fail
    37    | mk_sum tms = foldr1 (fn (x,y) => Data.plus $ x $ y) tms;
    37    | mk_sum tms = foldr1 (fn (x,y) => Data.plus $ x $ y) tms;
    38 
    38 
    39  (*Separate the literals from the other terms being combined*)
    39  (*Separate the literals from the other terms being combined*)
    40  fun sift_terms (t, (lits,others)) =
    40  fun sift_terms (t, (lits,others)) =
    41      case t of
    41      case t of
    42 	  Const("Numeral.number_of", _) $ _ =>
    42           Const("Numeral.number_of", _) $ _ =>
    43 	      (t::lits, others)         (*new literal*)
    43               (t::lits, others)         (*new literal*)
    44 	| (f as Const _) $ x $ y =>
    44         | (f as Const _) $ x $ y =>
    45 	      if f = Data.plus 
    45               if f = Data.plus
    46               then sift_terms (x, sift_terms (y, (lits,others)))
    46               then sift_terms (x, sift_terms (y, (lits,others)))
    47 	      else (lits, t::others)    (*arbitrary summand*)
    47               else (lits, t::others)    (*arbitrary summand*)
    48 	| _ => (lits, t::others);
    48         | _ => (lits, t::others);
    49 
    49 
    50  val trace = ref false;
    50  val trace = ref false;
    51 
    51 
    52  (*Make a simproc to combine all literals in a associative nest*)
    52  (*Make a simproc to combine all literals in a associative nest*)
    53  fun proc sg _ lhs =
    53  fun proc sg _ lhs =
    54    let fun show t = string_of_cterm (Thm.cterm_of sg t)
    54    let fun show t = string_of_cterm (Thm.cterm_of sg t)
    55        val _ = if !trace then tracing ("assoc_fold simproc: LHS = " ^ show lhs)
    55        val _ = if !trace then tracing ("assoc_fold simproc: LHS = " ^ show lhs)
    56 	       else ()
    56                else ()
    57        val (lits,others) = sift_terms (lhs, ([],[]))
    57        val (lits,others) = sift_terms (lhs, ([],[]))
    58        val _ = if length lits < 2
    58        val _ = if length lits < 2
    59                then raise Assoc_fail (*we can't reduce the number of terms*)
    59                then raise Assoc_fail (*we can't reduce the number of terms*)
    60                else ()  
    60                else ()
    61        val rhs = Data.plus $ mk_sum lits $ mk_sum others
    61        val rhs = Data.plus $ mk_sum lits $ mk_sum others
    62        val _ = if !trace then tracing ("RHS = " ^ show rhs) else ()
    62        val _ = if !trace then tracing ("RHS = " ^ show rhs) else ()
    63        val th = prove "assoc_fold" 
    63        val th = prove "assoc_fold"
    64 	           (Thm.cterm_of sg (Logic.mk_equals (lhs, rhs)))
    64                    (Thm.cterm_of sg (Logic.mk_equals (lhs, rhs)))
    65 		   (fn _ => [rtac Data.eq_reflection 1,
    65                    (fn _ => [rtac Data.eq_reflection 1,
    66 			     simp_tac assoc_ss 1])
    66                              simp_tac assoc_ss 1])
    67    in Some th end
    67    in Some th end
    68    handle Assoc_fail => None;
    68    handle Assoc_fail => None;
    69  
    69 
    70  val conv = 
    70  val conv =
    71      Simplifier.mk_simproc "assoc_fold"
    71      Simplifier.simproc_i (Sign.deref Data.sg_ref) "assoc_fold"
    72        [Thm.cterm_of (Sign.deref Data.sg_ref)
    72        [Data.plus $ Free ("x", Data.T) $ Free ("y",Data.T)] proc;
    73 	             (Data.plus $ Free("x",Data.T) $ Free("y",Data.T))]
       
    74        proc;
       
    75 
    73 
    76 end;
    74 end;
    77 
    75 
    78 
    76 
    79 (*test data:
    77 (*test data: