(* Hindley-Milner Type Inference. http://ian-grant.net/hm *)

(* Type terms: a type variable, or a named type constructor applied to a
   (possibly empty) list of argument types. *)
datatype term = Tyvar of string
              | Tyapp of string * (term list)

(* subs s term: apply the substitution s to term.  A substitution is a list
   of (replacement, variable-name) pairs.  For each variable the FIRST pair
   naming it is used, and its replacement is not substituted again. *)
fun subs [] term = term
  | subs ((t1,v1)::ss) (term as Tyvar name) =
      if name = v1 then t1 else subs ss term
  | subs _ (term as Tyapp(name,[])) = term
  | subs l (Tyapp(name,args)) =
      let fun arglist r [] = rev r
            | arglist r (h::t) = arglist ((subs l h)::r) t
      in Tyapp(name, arglist [] args) end

(* compose s s1: fold the bindings of s into s1 one at a time, in list
   order.  Each binding is applied to the replacement terms accumulated so
   far and is then prepended, so it shadows any earlier binding of the same
   variable when the result is used with subs. *)
fun compose [] s1 = s1
  | compose (s::ss) s1 =
      let fun iter r s [] = rev r
            | iter r s ((t1,v1)::rest) = iter (((subs [s] t1),v1)::r) s rest
      in compose ss (s::(iter [] s s1)) end

exception Unify of string

(* unify t1 t2: a substitution that makes t1 and t2 equal.
   Raises Unify "Arity"  when one constructor name is applied to argument
                         lists of different lengths,
          Unify "Const"  when two different constructor names meet,
          Unify "Occurs" when a variable would have to contain itself. *)
fun unify t1 t2 =
    let
      (* iter is only ever called with r = [], the substitution that the
         result extends. *)
      fun iter r t1 t2 =
          let
            (* occurs v t: does the variable named v appear anywhere in t? *)
            fun occurs v (Tyapp(name,[])) = false
              | occurs v (Tyapp(name,((Tyvar vn)::t))) =
                  if vn = v then true else occurs v (Tyapp(name,t))
              | occurs v (Tyapp(name,(s::t))) =
                  occurs v s orelse occurs v (Tyapp(name,t))
              | occurs v (Tyvar vn) = vn = v
            (* Unify argument lists pairwise, left to right, applying the
               substitution found so far to each later pair first. *)
            fun unify_args r [] [] = rev r
              | unify_args r [] _ = raise Unify "Arity"
              | unify_args r _ [] = raise Unify "Arity"
              | unify_args r (t1::t1s) (t2::t2s) =
                  unify_args (compose (iter [] (subs r t1) (subs r t2)) r) t1s t2s
          in
            case (t1,t2) of
                (Tyvar v1,Tyvar v2) =>
                  (* identical variables need no new binding *)
                  if v1 = v2 then r else ((t1, v2)::r)
              | (Tyvar v,Tyapp(_,[])) => ((t2, v)::r)
              | (Tyapp(_,[]),Tyvar v) => ((t1, v)::r)
              | (Tyvar v,Tyapp _) =>
                  if occurs v t2 then raise Unify "Occurs" else ((t2, v)::r)
              | (Tyapp _,Tyvar v) =>
                  if occurs v t1 then raise Unify "Occurs" else ((t1, v)::r)
              | (Tyapp(name1,args1),Tyapp(name2,args2)) =>
                  if name1 = name2 then unify_args r args1 args2
                  else raise Unify "Const"
          end
    in iter [] t1 t2 end

(* Type schemes: Forall(v, s) binds the variable named v in s; Type t is an
   unquantified type. *)
datatype typescheme = Forall of string * typescheme
                    | Type of term

(* mem p l: is p an element of l (by equality)? *)
fun mem p l =
    let fun iter [] = false
          | iter (x::xs) = (x = p) orelse iter xs
    in iter l end

(* fbtyvars f b t: prepend to f every variable occurrence in t whose name is
   not in b (a variable occurring several times is added several times).
   Returns the extended f paired with b unchanged. *)
fun fbtyvars f b (Tyvar v) = if mem v b then (f,b) else (v::f,b)
  | fbtyvars f b (Tyapp(name,args)) =
      let fun iter r [] = r
            | iter r (t::ts) = let val (f',_) = fbtyvars r b t in iter f' ts end
      in (iter f args, b) end

(* fbtsvs f b sigma: like fbtyvars, but each name bound by a Forall in sigma
   is consed onto b first.  Returns (f extended with the free variable
   occurrences of sigma, b extended with the names sigma binds). *)
fun fbtsvs f b (Forall(v,s)) = fbtsvs f (v::b) s
  | fbtsvs f b (Type t) = fbtyvars f b t

(* tyvars t: every variable occurrence in t. *)
fun tyvars t = let val (f,b) = fbtyvars [] [] t in f@b end

(* varno v: the index of a generated variable name.  "a".."z" map to 0..25
   and each trailing apostrophe adds 26 (so "b'" is 27); any other name,
   including "", gives ~1. *)
fun varno "" = ~1
  | varno v =
      let val vl = explode v
          val letter = (ord (hd vl)) - (ord #"a")
          fun primes r [] = r
            | primes r (h::t) = if h = #"\039" then primes (r+26) t else ~1
      in if letter >= 0 andalso letter <= 25 then primes letter (tl vl) else ~1 end

(* lastusedtsvar nv sigma: the largest of nv and the varno of every variable
   name, free or bound, in sigma. *)
fun lastusedtsvar nv sigma =
    let val vars = let val (f,b) = fbtsvs [] [] sigma in f@b end
        fun iter r [] = r
          | iter r (h::t) =
              let val vn = varno h in if vn > r then iter vn t else iter r t end
    in iter nv vars end

(* lastfreetsvar nv sigma: as lastusedtsvar, but over the free variables of
   sigma only. *)
fun lastfreetsvar nv sigma =
    let val (vars,_) = fbtsvs [] [] sigma
        fun iter r [] = r
          | iter r (h::t) =
              let val vn = varno h in if vn > r then iter vn t else iter r t end
    in iter nv vars end

(* newvar v: the next counter value v+1 paired with the name it denotes
   (the inverse of varno): letter (v+1) mod 26 followed by (v+1) div 26
   apostrophes. *)
fun newvar v =
    let val nv = v+1
        fun prime v 0 = v
          | prime v n = prime (v^"'") (n-1)
        val primes = nv div 26
        val var = str(chr ((ord #"a") + (nv mod 26)))
    in (nv,prime var primes) end

(* tssubs nv S sigma: apply the (t, v) pairs of S to sigma one after another,
   avoiding variable capture.  A binder whose name occurs in t is renamed to
   a fresh generated name first.  A binder that rebinds v stops the
   substitution below it, but enclosing renamings are still applied there.
   Returns the updated name counter and the new scheme. *)
fun tssubs nv [] sigma = (nv, sigma)
  | tssubs nv ((tvp as (t,v))::tvs) sigma =
    let
      val (fvs,_) = fbtyvars [] [] t
      (* Fresh names must lie above every generated-style name already in
         sigma or t, otherwise a renamed binder could clash with them. *)
      val fresh_base =
          foldl (fn (x, m) => Int.max (varno x, m)) (lastusedtsvar nv sigma) fvs
      (* forget the renaming of x below a binder that rebinds x *)
      fun unbind x rnss = List.filter (fn (_, y) => y <> x) rnss
      (* active: v is still free at this point, so tvp must be applied *)
      fun iter nv rnss active (Forall(sv,sts)) =
            if active andalso sv = v then
              (* sv shadows v: stop substituting, but keep the renamings of
                 enclosing binders, or their occurrences would become free *)
              let val (nv, sigma') = iter nv (unbind sv rnss) false sts
              in (nv, Forall(sv, sigma')) end
            else if active andalso mem sv fvs then
              (* sv would capture a variable of t: rename the binder *)
              let val (nv, newv) = newvar (Int.max (nv, fresh_base))
                  val (nv, sigma') =
                      iter nv (compose [(Tyvar newv,sv)] rnss) true sts
              in (nv, Forall(newv, sigma')) end
            else
              let val (nv, sigma') = iter nv (unbind sv rnss) active sts
              in (nv, Forall(sv, sigma')) end
        | iter nv rnss active (Type term) =
            let val renamed = subs rnss term
            in (nv, Type (if active then subs [tvp] renamed else renamed)) end
      val (nv, sigma') = iter nv [] true sigma
    in tssubs nv tvs sigma' end

exception Assum of string

(* assq p l: the value paired with key p in the association list l (first
   match); raises Assum p when p is absent. *)
fun assq p l =
    let fun iter [] = raise Assum p
          | iter ((k,v)::xs) = if k = p then v else iter xs
    in iter l end

(* fassumvars Gamma: the free variable occurrences of every type scheme in
   the assumption list Gamma.  fbtsvs accumulates onto its first argument,
   so its result already contains everything collected so far. *)
fun fassumvars Gamma =
    let fun iter f [] = f
          | iter f ((_,ts)::Gamma') =
              let val (fvs,_) = fbtsvs f [] ts
              in iter fvs Gamma' end
    in iter [] Gamma end

(* assumvars Gamma: every variable name, free or bound, of every type scheme
   in Gamma (same accumulation as fassumvars, plus bound names). *)
fun assumvars Gamma =
    let fun iter f [] = f
          | iter f ((_,ts)::Gamma') =
              let val (fvs,bvs) = fbtsvs f [] ts
              in iter (fvs@bvs) Gamma' end
    in iter [] Gamma end
(* lastfreeassumvar Gamma: the largest varno of any free variable in the
   assumption list Gamma, or ~1 if there is none. *)
fun lastfreeassumvar Gamma =
    let fun iter r [] = r
          | iter r ((_,sigma)::Gamma') = iter (lastfreetsvar r sigma) Gamma'
    in iter ~1 Gamma end

(* assumsubs nv S Gamma: apply the substitution S to every type scheme in
   Gamma with tssubs, threading the name counter nv; order is preserved. *)
fun assumsubs nv S Gamma =
    let fun iter r nv S [] = (nv, rev r)
          | iter r nv S ((v,sigma)::Gamma') =
              let val (nv, sigma') = tssubs nv S sigma
              in iter ((v,sigma')::r) nv S Gamma' end
    in iter [] nv S Gamma end

(* tsclosure Gamma tau: quantify tau over each of its variables that is not
   free in Gamma (each such variable is bound once). *)
fun tsclosure Gamma tau =
    let val favs = fassumvars Gamma
        val (ftvs,_) = fbtyvars [] [] tau
        fun iter bvs [] = Type tau
          | iter bvs (v::vs) =
              if (mem v favs) orelse (mem v bvs) then iter bvs vs
              else Forall(v,iter (v::bvs) vs)
    in iter [] ftvs end

(* Lambda-calculus expressions with let. *)
datatype exp = Var of string
             | Comb of exp * exp
             | Abs of string * exp
             | Let of (string * exp) * exp

(* Function types are the constructor "%f" applied to argument and result. *)
infixr -->
fun tau1 --> tau2 = Tyapp("%f",[tau1,tau2])

(* W nv Gamma e: Algorithm W.  nv is the counter of the last generated
   variable name (see newvar); Gamma is an association list from term
   variables to type schemes.  Returns (nv', (S, tau)): the updated counter,
   the substitution found, and the inferred type of e.
   Raises Assum v for a variable v not in Gamma, and Unify from unify. *)
fun W nv Gamma e =
    case e of
        (Var v) =>
          let (* instantiate each quantified variable with a fresh name *)
              fun tsinst nv (Type tau) = (nv, tau)
                | tsinst nv (Forall(alpha,sigma)) =
                    let val (nv, beta) = newvar (lastusedtsvar nv sigma)
                        val (nv, sigma') = (tssubs nv [(Tyvar beta,alpha)] sigma)
                    in tsinst nv sigma' end
              val (nv, tau) = tsinst nv (assq v Gamma)
          in (nv, ([], tau)) end
      | (Comb(e1,e2)) =>
          let val (nv, (S1,tau1)) = W nv Gamma e1
              val (nv, S1Gamma) = assumsubs nv S1 Gamma
              val (nv, (S2,tau2)) = W nv S1Gamma e2
              val S2tau1 = subs S2 tau1
              val (nv,beta) = newvar nv
              val V = unify S2tau1 (tau2 --> Tyvar beta)
              val Vbeta = subs V (Tyvar beta)
              val VS2S1 = compose V (compose S2 S1)
          in (nv, (VS2S1, Vbeta)) end
      | (Abs(v,e)) =>
          let val (nv,beta) = newvar nv
              val (nv,(S1,tau1)) = W nv ((v,Type (Tyvar beta))::Gamma) e
              val S1beta = subs S1 (Tyvar beta)
          in (nv, (S1,(S1beta --> tau1))) end
      | (Let((v,e1),e2)) =>
          let val (nv, (S1,tau1)) = W nv Gamma e1
              val (nv, S1Gamma) = assumsubs nv S1 Gamma
              (* v is bound to tau1 generalised over S1Gamma *)
              val (nv, (S2,tau2)) = W nv ((v,tsclosure S1Gamma tau1)::S1Gamma) e2
              val S2S1 = compose S2 S1
          in (nv, (S2S1,tau2)) end

(* principalts Gamma e: run W from the largest free variable of Gamma, apply
   the resulting substitution to Gamma, and close the inferred type over the
   variables not free in the substituted Gamma. *)
fun principalts Gamma e =
    let val (var, (S, tau)) = W (lastfreeassumvar Gamma) Gamma e
        val (_,SGamma) = assumsubs var S Gamma
    in tsclosure SGamma tau end

(* pptsterm tau: render a type in ML-like syntax.  "%f" is printed as a
   right-associative infix "->"; an arrow is parenthesised when it is the
   left operand of another arrow or a constructor argument.  Other
   constructors are printed postfix, e.g. "a list" or "(a, b) pair". *)
fun pptsterm tau =
    let fun iter prec (Tyvar name) = name
          | iter prec (Tyapp(name,[])) = name
          | iter prec (Tyapp("%f",[a1,a2])) =
              let fun maybebracket s = if prec <= 10 then s else "("^s^")"
              in maybebracket ((iter 11 a1)^" -> "^(iter 10 a2)) end
          | iter prec (Tyapp(name,args)) =
              let fun arglist r [] = r
                    | arglist r (h::t) =
                        arglist (r^(iter 30 h)^(if t=[] then "" else ", ")) t
              in if (length args) > 1 then (arglist "(" args)^") "^name
                 else (arglist "" args)^" "^name
              end
    in iter 10 tau end

(* ppterm t: render a term in prefix form, e.g. "f(x,y)". *)
fun ppterm (Tyvar name) = name
  | ppterm (Tyapp(name,[])) = name
  | ppterm (Tyapp(name,args)) =
      let fun arglist r [] = r
            | arglist r (h::t) = arglist (r^(ppterm h)^(if t=[] then "" else ",")) t
      in name^(arglist "(" args)^")" end

(* ppsubs s: render a substitution as "[t1/v1,t2/v2,...]". *)
fun ppsubs s =
    let fun iter r [] = r^"]"
          | iter r ((term,var)::t) =
              iter (r^(ppterm term)^"/"^var^(if t=[] then "" else ",")) t
    in iter "[" s end

(* ppexp e: render an expression; applications and abstractions are fully
   parenthesised, and an abstraction is written "(\v.body)". *)
fun ppexp e =
    let fun ppe r e =
            case e of
                (Var v) => r^v
              | (Comb(e1,e2)) => r^"("^(ppe "" e1)^" "^(ppe "" e2)^")"
              | (Abs(v,e)) => r^"(\\"^v^"."^(ppe "" e)^")"
              | (Let((v,e1),e2)) => r^"let "^v^"="^(ppe "" e1)^" in "^(ppe "" e2)
    in ppe "" e end

(* ppts sigma: render a type scheme, writing each Forall(v, ...) as "!v.". *)
fun ppts sigma =
    let fun iter r (Forall(sv,sts)) = iter (r^"!"^sv^".") sts
          | iter r (Type term) = r^(pptsterm term)
    in iter "" sigma end

(* ppassums Gamma: render an assumption list as "x:sigma,y:sigma,...". *)
fun ppassums Gamma =
    let fun iter r [] = r
          | iter r ((v,ts)::assums) =
              iter (r^v^":"^(ppts ts)^(if assums=[] then "" else ",")) assums
    in iter "" Gamma end

(* Examples *)

(* Unification *)
val x = Tyvar "x"
val y = Tyvar "y"
val z = Tyvar "z"
fun apply s l = Tyapp(s,l)
val a = apply "a" []
fun j(x, y, z) = apply "j" [x, y, z]
fun f(x, y) = apply "f" [x, y]
val t1 = j(x,y,z)
val t2 = j(f(y,y), f(z,z), f(a,a));
ppterm t1;
ppterm t2;
val U = unify t1 t2;
ppsubs U;
print ((ppterm (subs U t1))^"\n");
print ((ppterm (subs U t2))^"\n");

(* Constructors for types *)
fun mk_func name args = Tyapp(name,args)
fun mk_nullary name = mk_func name []
fun mk_unary name arg = mk_func name [arg]
fun mk_binary name arg1 arg2 = mk_func name [arg1, arg2]
fun mk_ternary name arg1 arg2 arg3 = mk_func name [arg1, arg2, arg3]
fun pairt t1 t2 = mk_binary "pair" t1 t2
fun listt t = mk_unary "list" t
val boolt = mk_nullary "bool"

(* Type variables *)
val alpha = Tyvar "a"
val beta = Tyvar "b"
val alpha' = Tyvar "a'"
val beta' = Tyvar "b'"

(* Type-schemes: quantify, in order, over the given type variables. *)
fun mk_tyscheme [] t = Type t
  | mk_tyscheme ((Tyvar v)::vs) t = Forall (v, mk_tyscheme vs t)
  | mk_tyscheme _ _ = raise Fail "mk_tyscheme: Invalid type-scheme."

(* Now we can construct type-schemes.  For example, here is the type-scheme
   of a polymorphic function taking a pair of functions and two lists to a
   pair of lists: *)
val dmapts = mk_tyscheme [alpha, alpha', beta, beta']
                 (pairt (alpha --> alpha') (beta --> beta') --> listt alpha --> listt beta
                  --> pairt (listt alpha') (listt beta'));
ppts dmapts;

(* Lambda expressions with let bindings *)
fun labs (Var v) e = Abs(v,e)
  | labs _ _ = raise Fail "labs: Invalid argument"
fun llet (Var v) e1 e2 = Let((v,e1),e2)
  | llet _ _ _ = raise Fail "llet: Invalid argument"
infix @:
fun e1 @: e2 = Comb(e1,e2)
fun lambda [] e = e
  | lambda (x::xs) e = labs x (lambda xs e)
fun letbind [] e = e
  | letbind ((v,e1)::bs) e = llet v e1 (letbind bs e)
fun lapply r [] = r
  | lapply r (e::es) = lapply (r @: e) es

(* Variables *)
val x = Var "x"
val y = Var "y"
val z = Var "z"
val p = Var "p"
val f = Var "f"
val m = Var "m"
val n = Var "n"
val s = Var "s"
val i = Var "i"

(* Church numerals: num n is \f x. f (f ... (f x)) with n applications of f. *)
fun num n =
    let val f = Var "f"
        val x = Var "x"
        fun iter r 0 = lambda [f,x] r
          | iter r n = iter (f @: r) (n-1)
    in iter x n end

(* Now we can construct assumptions and expressions *)
(* S ZERO = (λ n f x.n f (f x)) λ f x.x *)
val ZERO = num 0
val S = lambda [n,f,x] (n @: f @: (f @: x));
ppts (principalts [] (S @: ZERO));

(* PRED and PRED 6 *)
val PAIR = (lambda [x, y, f] (f @: x @: y))
val FST = (lambda [p] (p @: (lambda [x, y] x)))
val SND = (lambda [p] (p @: (lambda [x, y] y)))
val G = lambda [f,p] (PAIR @: (f @: (FST @: p)) @: (FST @: p))
val PRED = lambda [n] (SND @: (n @: (G @: S) @: (PAIR @: ZERO @: ZERO)))
val SUB = lambda [m, n] (n @: PRED @: m);
ppts (principalts [] PRED);
ppts (principalts [] (PRED @: (num 6)));
ppts (principalts [] SUB);

(* The definition of PRED from Larry Paulson's lecture notes *)
val PREDp = lambda [n,f,x] (SND @: (n @: (G @: f) @: (PAIR @: x @: x)))
val SUBp = lambda [m, n] (n @: PREDp @: m);
ppts (principalts [] PREDp);
ppts (principalts [] (PREDp @: (num 6)));
ppexp SUBp;
ppts (principalts [] SUBp);

(* let i=λx.x in i i *)
val polylet = letbind [(i,lambda [x] x)] (i @: i);
ppexp polylet;
ppts (principalts [] polylet);

(* map *)
val condts = mk_tyscheme [alpha] (boolt --> alpha --> alpha --> alpha)
val fixts = mk_tyscheme [alpha] ((alpha --> alpha) --> alpha)
val nullts = mk_tyscheme [alpha] (listt alpha --> boolt)
val nilts = mk_tyscheme [alpha] (listt alpha)
val consts = mk_tyscheme [alpha] (alpha --> listt alpha --> listt alpha)
val hdts = mk_tyscheme [alpha] (listt alpha --> alpha)
val tlts = mk_tyscheme [alpha] (listt alpha --> listt alpha)
val pairts = mk_tyscheme [alpha, beta] (alpha --> beta --> pairt alpha beta)
val fstts = mk_tyscheme [alpha, beta] (pairt alpha beta --> alpha)
val sndts = mk_tyscheme [alpha, beta] (pairt alpha beta --> beta)
val bool_assums = [("true",Type(boolt)),("false",Type(boolt)),("cond",condts)]
val pair_assums = [("pair",pairts),("fst",fstts),("snd",sndts)]
val fix_assums = [("fix",fixts)]
val list_assums = [("null",nullts),("nil",nilts),
                   ("cons",consts),("hd",hdts),("tl",tlts)]

(* let map = (fix (λ map f s.
     (cond (null s) nil (cons (f (hd s)) (map f (tl s)))))) in map *)
val assums = bool_assums@fix_assums@list_assums
val map' = Var "map"
val fix = Var "fix"
val null' = Var "null"
val nil' = Var "nil"
val cond = Var "cond"
val cons = Var "cons"
val hd' = Var "hd"
val tl' = Var "tl"
val mapdef =
    letbind [(map', (fix @: (lambda [map', f, s]
                 (cond @: (null' @: s) @: nil'
                       @: (cons @: (f @: (hd' @: s)) @: (map' @: f @: (tl' @: s)))))))]
            map';
ppassums assums;
ppexp mapdef;
val mapdefts = principalts assums mapdef;
ppts mapdefts;

(* Mairson's expression in ML
   let fun pair x y = fn z => z x y
       val x1 = fn y => pair y y
       val x2 = fn y => x1 (x1 y)
       val x3 = fn y => x2 (x2 y)
       val x4 = fn y => x3 (x3 y)
       val x5 = fn y => x4 (x4 y)
   in x5 (fn z => z) end *)
val x1 = Var "x1"
val x2 = Var "x2"
val x3 = Var "x3"
val x4 = Var "x4"
val x5 = Var "x5"
val pair = Var "pair"
val mairson = letbind [(pair,lambda [x,y,z] (z @: x @: y)),
                       (x1,lambda [y] (pair @: y @: y)),
                       (x2,lambda [y] (x1 @: (x1 @: y))),
                       (x3,lambda [y] (x2 @: (x2 @: y))),
                       (x4,lambda [y] (x3 @: (x3 @: y))) ]
                      (x4 @: (lambda [x] x));
ppts (principalts [] mairson);
(* val mairson = letbind [(pair,lambda [x,y,z] (z @: x @: y)),
                          (x1,lambda [y] (pair @: y @: y)),
                          (x2,lambda [y] (x1 @: (x1 @: y))),
                          (x3,lambda [y] (x2 @: (x2 @: y))),
                          (x4,lambda [y] (x3 @: (x3 @: y))),
                          (x5,lambda [y] (x4 @: (x4 @: y))) ]
                         (x5 @: (lambda [x] x));
   principalts [] mairson; *)

(* Handle expected exceptions.
   expect_Unify se f: run f () and require it to raise Unify se.  Raises
   TestFail if it raises Unify with a different message, or if it returns
   normally (an expected failure that did not happen). *)
exception TestFail
fun expect_Unify se f =
    (ignore (f ()); raise TestFail)
    handle Unify s => if se = s then () else raise TestFail

(* omega *)
val omegaexp = Let(("omega",Abs ("x",Comb(Var "x",Var "x"))),Var "omega")
val omegadef = fn () => principalts [] omegaexp;
expect_Unify "Occurs" omegadef;

(* Y *)
val r = Abs("x",Comb(Var "f",Comb(Var "x",Var "x")))
val ydef = fn () => principalts [] (Let(("Y",Abs ("f",Comb(r,r))),Var "Y"));
expect_Unify "Occurs" ydef;

(* Church numerals *)
val nassums = [];
val cn_zerodef = Abs("f",Abs("x",Var "x"));
ppassums nassums;
ppexp cn_zerodef;
val cn_zero = principalts [] cn_zerodef;
ppts cn_zero;
val cn_onedef = Abs("f",Abs("x",Comb(Var "f", Var "x")));
ppassums nassums;
ppexp cn_onedef;
val cn_one = principalts [] cn_onedef;
ppts cn_one;
val cn_twodef = Abs("f",Abs("x",Comb(Var "f",Comb(Var "f",Var "x"))));
ppassums nassums;
ppexp cn_twodef;
val cn_two = principalts [] cn_twodef;
ppts cn_two;
ppts cn_zero;
ppts cn_one;
ppts cn_two;

(* Church numerals in SML *)
fn f => fn x => x;
fn f => fn x => f x;
fn f => fn x => f (f x);