% Public interface. sw_marg_log_prob/3 is exported alongside the other
% sw_* whole-parameter-set wrappers; marg_log_prob/3 is its per-switch
% building block and stays exported for existing callers.
:- module(ccp_switches, [ map_sw/3, map_swc/3, map_sum_sw/3, map_sum_sw/4, map_swc/4
                        , sw_mode/2, sw_samples/2, sw_expectations/2, sw_log_prob/3, sw_posteriors/3
                        , sw_marg_log_prob/3, marg_log_prob/3
                        , sw_init/3, dirichlet/2
                        ]).
18:- use_module(library(callutils), [(*)/4, const/3]). 19:- use_module(library(data/pair), [fsnd/3, snd/2]). 20:- use_module(library(math), [map_sum/3, map_sum/4]). 21:- use_module(library(plrand), [log_prob_dirichlet/3, log_partition_dirichlet/2]). 22:- use_module(library(prob/tagless),[dirichlet//2]). 23:- use_module(library(lazymath), [add/3, mul/3, max/3, stoch/2]). 24:- use_module(effects, [sample/2]). 25
%! dirichlet(+Alphas, -Probs)
%
%  Requests a draw Probs from a Dirichlet distribution with parameter
%  vector Alphas by issuing the effect sample(dirichlet(Alphas), Probs)
%  through effects:sample/2. What actually produces Probs depends on the
%  active effect handler, which is not visible in this module.
dirichlet(Alphas, Probs) :-
    sample(dirichlet(Alphas), Probs).
% Meta-argument specs: the integer is the number of extra arguments each
% closure is called with.
:- meta_predicate
       map_sw(2, ?, ?),
       map_swc(2, ?, ?),
       map_swc(3, ?, ?, ?),
       map_sum_sw(2, +, -),
       map_sum_sw(3, +, +, -).

%! map_sw(:P, ?Xs, ?Ys)
%
%  Applies the 2-argument closure P to the parameter (second) component
%  of every `SW-Params` pair in Xs, via data/pair:fsnd/3.
map_sw(P, Xs, Ys) :-
    maplist(fsnd(P), Xs, Ys).

%! map_swc(:P, ?Xs, ?Ys)
%
%  Like map_sw/3, but P is mapped over each parameter list with maplist/3,
%  i.e. applied to every individual component of every switch.
map_swc(P, Xs, Ys) :-
    map_sw(maplist(P), Xs, Ys).

%! map_swc(:P, ?Xs, ?Ys, ?Zs)
%
%  Three-list version of map_swc/3. Corresponding elements `SW-X`, `SW-Y`,
%  `SW-Z` must share the same key SW, and maplist(P, X, Y, Z) is called.
map_swc(P, Xs, Ys, Zs) :-
    maplist(fsnd3(maplist(P)), Xs, Ys, Zs).

%! map_sum_sw(:P, +Xs, -Sum)
%
%  Folds the values obtained by applying P to each pair's parameter
%  component (P*snd: snd/2 first, then P) with math:map_sum/3.
%  NOTE(review): presumably map_sum/3 sums the results; confirm.
map_sum_sw(P, Xs, Sum) :-
    map_sum(P*snd, Xs, Sum).

%! map_sum_sw(:P, +Xs, +Ys, -Sum)
%
%  For corresponding pairs `SW-X` in Xs and `SW-Y` in Ys, which must share
%  the key SW, computes call(P, X, Y, Z). The Z values are combined by
%  math:map_sum/4.
map_sum_sw(P, Xs, Ys, Sum) :-
    map_sum(f2sw1(P), Xs, Ys, Sum).

% fsnd3(+P, ?Pair1, ?Pair2, ?Pair3): calls P on the second components of
% three pairs that all carry the same first component.
fsnd3(P, Key-X, Key-Y, Key-Z) :-
    call(P, X, Y, Z).

% f2sw1(+P, ?Pair1, ?Pair2, -Z): calls P on the second components of two
% pairs with the same key; P's third argument is returned without a key.
f2sw1(P, Key-X, Key-Y, Z) :-
    call(P, X, Y, Z).
%! sw_posteriors(+Prior, +Eta, -Post)
%
%  Element-wise sum of Eta and Prior (lazymath add/3), switch by switch.
%  Note that map_swc/4 receives Eta before Prior.
sw_posteriors(Prior, Eta, Post) :-
    map_swc(add, Eta, Prior, Post).

%! sw_expectations(+Alphas, -Probs)
%
%  Applies stoch/2 to each switch's parameter list.
%  NOTE(review): presumably stoch/2 normalises to sum 1, giving the
%  Dirichlet mean; confirm against library(lazymath).
sw_expectations(Alphas, Probs) :-
    map_sw(stoch, Alphas, Probs).

%! sw_samples(+Alphas, -Probs)
%
%  Obtains one distribution per switch by calling dirichlet/2 on its
%  parameter list.
sw_samples(Alphas, Probs) :-
    map_sw(dirichlet, Alphas, Probs).

%! sw_log_prob(+Alphas, +Probs, -LP)
%
%  Combines, over all switches, plrand:log_prob_dirichlet/3 of each
%  switch's Probs under the corresponding Alphas.
sw_log_prob(Alphas, Probs, LP) :-
    map_sum_sw(log_prob_dirichlet, Alphas, Probs, LP).

%! sw_marg_log_prob(+Prior, +Eta, -LP)
%
%  Combines marg_log_prob/3 over all switches.
sw_marg_log_prob(Prior, Eta, LP) :-
    map_sum_sw(marg_log_prob, Prior, Eta, LP).

%! sw_mode(+Alphas, -Probs)
%
%  Per-switch mode, computed by dirichlet_mode/2.
sw_mode(Alphas, Probs) :-
    map_sw(dirichlet_mode, Alphas, Probs).

% dirichlet_mode(+As, -Ps): subtracts 1.0 from every parameter, clamps
% each result below at 0.0 (max/3), then applies stoch/2.
% NOTE(review): if every parameter is <= 1 the intermediate list is all
% zeros; check how stoch/2 behaves in that case.
dirichlet_mode(As, Ps) :-
    maplist(max(0.0)*add(-1.0), As, Shifted),
    stoch(Shifted, Ps).
%! marg_log_prob(+Prior, +Eta, -LP)
%
%  LP = logZ(Prior + Eta) - logZ(Prior). Here logZ is
%  plrand:log_partition_dirichlet/2, and Prior + Eta is taken element-wise
%  with add/3.
%  NOTE(review): presumably Eta holds counts and LP is the log marginal
%  likelihood of those counts under the Dirichlet prior; confirm.
marg_log_prob(Prior, Eta, LP) :-
    maplist(add, Prior, Eta, Post),
    log_partition_dirichlet(Prior, LogZPrior),
    log_partition_dirichlet(Post, LogZPost),
    LP is LogZPost - LogZPrior.
%! sw_init(+Spec, +SW, -Pair)
%
%  Builds the initial entry `SW-P` for switch SW. The switch is called as
%  call(SW, ID, Vals, []) to obtain an identifier ID and its list of
%  values Vals.
%  NOTE(review): the role of the trailing [] argument is not visible here;
%  confirm against the switch definitions.
%
%  If Spec is `\Pred`, P is computed by call(Pred, ID, Vals, P) and no
%  other clause is tried. Otherwise, P is computed by init(Spec, Vals, P).
sw_init(\Pred, SW, SW-P) :- !,
    call(SW, ID, Vals, []),
    call(Pred, ID, Vals, P).
sw_init(Spec, SW, SW-P) :-
    call(SW, _, Vals, []),
    init(Spec, Vals, P).

% init(+Spec, +Vals, -Params): one parameter per element of Vals.
%   uniform    - 1.0/N for each value, where N is the length of Vals.
%                Vals must be non-empty, because N is the divisor.
%   unit       - 1.0 for each value (via const/3).
%   random     - dirichlet/2 applied to the `unit` parameters.
%   K*Spec     - Spec's parameters, each scaled by K via mul/3.
%   S1+S2      - element-wise add/3 of S1's and S2's parameters.
%   log(Spec)  - log/2 applied to each of Spec's parameters.
init(uniform, Vs, Params) :- length(Vs, N), P is 1.0/N, maplist(const(P), Vs, Params).
init(unit, Vs, Params) :- maplist(const(1.0), Vs, Params).
init(random, Vs, Params) :- call(dirichlet*init(unit), Vs, Params).
init(K*Spec, Vs, Params) :- call(maplist(mul(K))*init(Spec), Vs, Params).
init(S1+S2, Vs, Params) :- init(S1, Vs, P1), init(S2, Vs, P2), maplist(add, P1, P2, Params).
% NOTE(review): log/2 is not in this file's visible import list (lazymath
% is imported with add/3, mul/3, max/3 and stoch/2 only). Confirm that
% log/2 is defined or importable in this module.
init(log(Spec), Vs, Params) :- call(maplist(log)*init(Spec), Vs, Params).
/** <module> Tools for working with lists of switch parameters

Switch parameters are represented as a list of pairs `SW-Params`, i.e.
`[SW1-Params1, SW2-Params2, ...]`, where each `SW` is a switch term.
Each switch term is associated with a list of numbers, one for each value the switch can take. The meaning of the numbers is context-dependent, but is usually either a normalised probability distribution over the values or the parameters for a Dirichlet distribution over switch value distributions. */