325 and of_univ bounds (Const (idx, ts)) typidx = |
325 and of_univ bounds (Const (idx, ts)) typidx = |
326 let |
326 let |
327 val ts' = take_until is_dict ts; |
327 val ts' = take_until is_dict ts; |
328 val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx; |
328 val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx; |
329 val T = Code.default_typ thy c; |
329 val T = Code.default_typ thy c; |
330 val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T; |
330 val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T; |
331 val typidx' = typidx + maxidx_of_typ T' + 1; |
331 val typidx' = typidx + maxidx_of_typ T' + 1; |
332 in of_apps bounds (Term.Const (c, T'), ts') typidx' end |
332 in of_apps bounds (Term.Const (c, T'), ts') typidx' end |
333 | of_univ bounds (Free (name, ts)) typidx = |
333 | of_univ bounds (Free (name, ts)) typidx = |
334 of_apps bounds (Term.Free (name, dummyT), ts) typidx |
334 of_apps bounds (Term.Free (name, dummyT), ts) typidx |
335 | of_univ bounds (BVar (name, ts)) typidx = |
335 | of_univ bounds (BVar (name, ts)) typidx = |
371 vs_ty_t |
371 vs_ty_t |
372 |> eval_term gr deps |
372 |> eval_term gr deps |
373 |> term_of_univ thy idx_tab |
373 |> term_of_univ thy idx_tab |
374 end; |
374 end; |
375 |
375 |
|
376 (* trivial type classes *) |
|
377 |
|
378 structure Nbe_Triv_Classes = TheoryDataFun |
|
379 ( |
|
380 type T = class list * (string * string) list; |
|
381 val empty = ([], []); |
|
382 val copy = I; |
|
383 val extend = I; |
|
384 fun merge _ ((classes1, consts1), (classes2, consts2)) = |
|
385 (Library.merge (op =) (classes1, classes2), Library.merge (op =) (consts1, consts2)); |
|
386 ) |
|
387 |
|
388 fun add_triv_classes thy = |
|
389 let |
|
390 val (trivs, _) = Nbe_Triv_Classes.get thy; |
|
391 val inters = curry (Sorts.inter_sort (Sign.classes_of thy)) trivs; |
|
392 fun map_sorts f = (map_types o map_atyps) |
|
393 (fn TVar (v, sort) => TVar (v, f sort) |
|
394 | TFree (v, sort) => TFree (v, f sort)); |
|
395 in map_sorts inters end; |
|
396 |
|
397 fun subst_triv_consts thy = |
|
398 let |
|
399 fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => (case f c |
|
400 of SOME c' => Term.Const (c', ty) |
|
401 | NONE => t) |
|
402 | t => t); |
|
403 val (_, consts) = Nbe_Triv_Classes.get thy; |
|
404 val subst_inst = perhaps (Option.map fst o AxClass.inst_of_param thy); |
|
405 in map_aterms (subst_const (AList.lookup (op =) consts o subst_inst)) end; |
|
406 |
376 (* evaluation with type reconstruction *) |
407 (* evaluation with type reconstruction *) |
377 |
408 |
378 fun eval thy code t vs_ty_t deps = |
409 fun eval thy t code vs_ty_t deps = |
379 let |
410 let |
380 val ty = type_of t; |
411 val ty = type_of t; |
381 fun subst_Frees [] = I |
412 val type_free = AList.lookup (op =) |
382 | subst_Frees inst = |
413 (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t [])); |
383 Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s) |
414 val type_frees = Term.map_aterms |
384 | t => t); |
415 (fn (t as Term.Free (s, _)) => the_default t (type_free s) | t => t); |
385 val anno_vars = |
416 fun type_infer t = [(t, ty)] |
386 subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t [])) |
417 |> TypeInfer.infer_types (Sign.pp thy) (Sign.tsig_of thy) I |
387 #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t [])) |
418 (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE) |
388 fun constrain t = |
419 Name.context 0 NONE |
389 singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t); |
420 |> fst |> the_single |> fst; |
390 fun check_tvars t = if null (Term.term_tvars t) then t else |
421 fun check_tvars t = if null (Term.term_tvars t) then t else |
391 error ("Illegal schematic type variables in normalized term: " |
422 error ("Illegal schematic type variables in normalized term: " |
392 ^ setmp show_types true (Sign.string_of_term thy) t); |
423 ^ setmp show_types true (Sign.string_of_term thy) t); |
393 val string_of_term = setmp show_types true (Sign.string_of_term thy); |
424 val string_of_term = setmp show_types true (Sign.string_of_term thy); |
394 in |
425 in |
395 compile_eval thy code vs_ty_t deps |
426 compile_eval thy code vs_ty_t deps |
396 |> tracing (fn t => "Normalized:\n" ^ string_of_term t) |
427 |> tracing (fn t => "Normalized:\n" ^ string_of_term t) |
397 |> anno_vars |
428 |> subst_triv_consts thy |
|
429 |> type_frees |
398 |> tracing (fn t => "Vars typed:\n" ^ string_of_term t) |
430 |> tracing (fn t => "Vars typed:\n" ^ string_of_term t) |
399 |> constrain |
431 |> type_infer |
400 |> tracing (fn t => "Types inferred:\n" ^ string_of_term t) |
432 |> tracing (fn t => "Types inferred:\n" ^ string_of_term t) |
|
433 |> check_tvars |
401 |> tracing (fn t => "---\n") |
434 |> tracing (fn t => "---\n") |
402 |> check_tvars |
|
403 end; |
435 end; |
404 |
436 |
405 (* evaluation oracle *) |
437 (* evaluation oracle *) |
406 |
438 |
407 exception Norm of CodeThingol.code * term |
439 exception Norm of term * CodeThingol.code |
408 * (CodeThingol.typscheme * CodeThingol.iterm) * string list; |
440 * (CodeThingol.typscheme * CodeThingol.iterm) * string list; |
409 |
441 |
410 fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) = |
442 fun norm_oracle (thy, Norm (t, code, vs_ty_t, deps)) = |
411 Logic.mk_equals (t, eval thy code t vs_ty_t deps); |
443 Logic.mk_equals (t, eval thy t code vs_ty_t deps); |
412 |
444 |
413 fun norm_invoke thy code t vs_ty_t deps = |
445 fun norm_invoke thy t code vs_ty_t deps = |
414 Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps)); |
446 Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (t, code, vs_ty_t, deps)); |
415 (*FIXME get rid of hardwired theory name*) |
447 (*FIXME get rid of hardwired theory name*) |
416 |
448 |
417 fun norm_conv ct = |
449 fun norm_conv ct = |
418 let |
450 let |
419 val thy = Thm.theory_of_cterm ct; |
451 val thy = Thm.theory_of_cterm ct; |
420 fun conv code vs_ty_t deps ct = |
452 fun evaluator' t code vs_ty_t deps = norm_invoke thy t code vs_ty_t deps; |
421 let |
453 fun evaluator t = (add_triv_classes thy t, evaluator' t); |
422 val t = Thm.term_of ct; |
454 in CodePackage.evaluate_conv thy evaluator ct end; |
423 in norm_invoke thy code t vs_ty_t deps end; |
455 |
424 in CodePackage.evaluate_conv thy conv ct end; |
456 fun norm_term thy t = |
425 |
457 let |
426 fun norm_term thy = |
458 fun evaluator' t code vs_ty_t deps = eval thy t code vs_ty_t deps; |
427 let |
459 fun evaluator t = (add_triv_classes thy t, evaluator' t); |
428 fun invoke code vs_ty_t deps t = |
460 in (Code.postprocess_term thy o CodePackage.evaluate_term thy evaluator) t end; |
429 eval thy code t vs_ty_t deps; |
|
430 in CodePackage.evaluate_term thy invoke #> Code.postprocess_term thy end; |
|
431 |
461 |
432 (* evaluation command *) |
462 (* evaluation command *) |
433 |
463 |
434 fun norm_print_term ctxt modes t = |
464 fun norm_print_term ctxt modes t = |
435 let |
465 let |