1:- module(ccp_learn, [converge/5, learn/4]).
6:- use_module(library(data/pair), [snd/2]). 7:- use_module(library(callutils), [(*)/4, true2/2]). 8:- use_module(library(plrand), [log_partition_dirichlet/2]). 9:- use_module(library(autodiff2), [esc/3, add/3, mul/3, pow/3, max/3, gather_ops/3]). 10:- use_module(library(clambda), [clambda/2, run_lambda_compiler/1]). 11:- use_module(library(plflow), [ops_body/4, sub/3, stoch/2, mean_log_dir/2, log_part_dir/2, log_prob_dir/3]). 12:- use_module(graph, [graph_counts/6]). 13:- use_module(switches, [map_sw/3, map_swc/3, map_swc/4, map_sum_sw/3]). 14
15
% mul_add(K, X, Y, Z): relates Z to X + K*Y by chaining mul/3 and add/3.
% The product of K and Y is formed first, then X is added to it.
% NOTE(review): mul/3 and add/3 are imported from library(autodiff2);
% presumably they post arithmetic nodes rather than evaluate eagerly - confirm.
mul_add(K, X, Y, Z) :-
   mul(K, Y, Scaled),
   add(X, Scaled, Z).
% map_sum_(P, Xs, Sum): maps the binary predicate P over Xs with maplist/3
% and passes the resulting list to esc(sum_list, Terms, [Sum]).
% NOTE(review): esc/3 is imported from library(autodiff2); presumably it wraps
% sum_list/2 as an operation whose single output is Sum - confirm.
map_sum_(P, Xs, Sum) :-
   maplist(P, Xs, Terms),
   esc(sum_list, Terms, [Sum]).
% map_sum_(P, Xs, Ys, Sum): maps the ternary predicate P over the paired
% elements of Xs and Ys with maplist/4 and passes the resulting list to
% esc(sum_list, Terms, [Sum]).
% NOTE(review): esc/3 is imported from library(autodiff2); presumably it wraps
% sum_list/2 as an operation whose single output is Sum - confirm.
map_sum_(P, Xs, Ys, Sum) :-
   maplist(P, Xs, Ys, Terms),
   esc(sum_list, Terms, [Sum]).
% map_sum_sw_(P, Switches, Sum): maps the composed goal P*snd over every
% element of Switches and passes the per-element results to
% esc(sum_list, Terms, [Sum]).
% NOTE(review): Switches is presumably a list of Switch-Value pairs, with
% snd/2 selecting the value before P is applied - confirm against callers.
map_sum_sw_(P, Switches, Sum) :-
   maplist(P*snd, Switches, Terms),
   esc(sum_list, Terms, [Sum]).
% map_sum_sw_(P, Xs, Ys, Sum): walks two lists of Switch-Value pairs in step.
% For each position, f2sw1/4 checks that both pairs carry the same switch and
% calls P on the two values; the per-position results are then passed to
% esc(sum_list, Terms, [Sum]).
map_sum_sw_(P, Xs, Ys, Sum) :-
   maplist(f2sw1(P), Xs, Ys, Terms),
   esc(sum_list, Terms, [Sum]).
% f2sw1(P, Pair1, Pair2, Z): lifts a ternary predicate P to Switch-Value
% pairs. Both pairs must unify on the same switch term; P is then called on
% the two values to produce Z. Fails if the switches do not unify.
f2sw1(P, Pair1, Pair2, Z) :-
   Pair1 = Switch-Val1,
   Pair2 = Switch-Val2,
   call(P, Val1, Val2, Z).
% learn(Method, StatsMethod, Graph, Step): builds a compiled learning step.
%
% Calls learn/7 with the fifth-position temperature argument fixed at 1.0 to
% obtain an objective term and two parameter structures. The variables of the
% first structure become the inputs and the variables of the second become
% the outputs. The operations linking the inputs to the objective and outputs
% are gathered with gather_ops/3, turned into a goal body with ops_body/4,
% and compiled with clambda/2 into Step, a lambda over
% [Objective, ParamsIn, ParamsOut].
% The number of gathered operations is reported on debug topic learn(setup).
learn(Method, StatsMethod, Graph, Step) :-
   learn(Method, StatsMethod, 1.0, Graph, Objective, ParamsIn, ParamsOut),
   term_variables(ParamsIn, Inputs),
   term_variables(ParamsOut, Outputs),
   Targets = [Objective|Outputs],
   gather_ops(Inputs, Targets, Ops),
   length(Ops, NumOps),
   debug(learn(setup), 'Compiled ~d operations.', [NumOps]),
   ops_body(Inputs, Targets, Ops, Body),
   clambda(lambda([Objective,ParamsIn,ParamsOut], Body), Step).
40
% learn(ml, Stats, ITemp, Graph, LL, P1, P2): learn/7 clause for method `ml`.
%
% Takes the first solution of graph_counts/6 in `lin` mode, which yields the
% parameter structure PP, the counts Eta and the objective LL. P1 is tied to
% PP through pow(ITemp) mapped by map_swc/3, and P2 is obtained by mapping
% stoch/2 over Eta.
% NOTE(review): ITemp is presumably an inverse temperature for deterministic
% annealing, and stoch/2 presumably normalises each count vector - confirm.
learn(ml, Stats, ITemp, Graph, LL, P1, P2) :-
   once(graph_counts(Stats, lin, Graph, Probs, Counts, LL)),
   map_swc(pow(ITemp), P1, Probs),
   map_sw(stoch, Counts, P2).
45
% learn(map(Prior), Stats, ITemp, Graph, Obj, P1, P2): learn/7 clause for
% method `map(Prior)`.
%
% Takes the first solution of graph_counts/6 in `lin` mode (PP, Eta, LL).
% LogPrior sums log_prob_dir/3 over the switches of Prior and P1. Post adds
% Eta to Prior switch by switch. P2 is derived from Post by adding -1.0 to
% each entry, applying max(0.0), and then stoch/2 per switch. The objective
% is built by mul_add/4, so Obj relates to LL + ITemp*LogPrior. Finally P1 is
% tied to PP through pow(ITemp).
% NOTE(review): the order in which the (*)/2 composition applies its parts
% is defined by library(callutils); the description above assumes the
% right-hand goal runs first - confirm.
learn(map(Prior), Stats, ITemp, Graph, Obj, P1, P2) :-
   once(graph_counts(Stats, lin, Graph, Probs, Counts, LL)),
   map_sum_sw_(log_prob_dir, Prior, P1, LogPrior),
   map_swc(add, Counts, Prior, Post),
   map_sw(stoch*maplist(max(0.0)*add(-1.0)), Post, P2),
   mul_add(ITemp, LL, LogPrior, Obj),
   map_swc(pow(ITemp), P1, Probs).
53
% learn(vb(Prior), Stats, ITemp, Graph, Obj, A1, A2): learn/7 clause for
% method `vb(Prior)`.
%
% A1 and Pi are first related to Prior through map_swc(true2, ...).
% EffPrior is built per entry by mul_add/4 as (1.0-ITemp) + ITemp*Prior.
% LogZPrior sums log_partition_dirichlet/2 over the switches of Prior.
% vb_helper/6 then ties A1 to Pi and produces Div. graph_counts/6 in `log`
% mode (first solution only) gives the counts Eta and LL from Pi. A2 is built
% per entry as EffPrior + ITemp*Eta, and Obj comes from sub(Div, LL, Obj).
% NOTE(review): map_swc(true2, ...) presumably only gives A1 and Pi the same
% shape as Prior - confirm.
% NOTE(review): `1.0-ITemp` is passed to mul_add/4 as an unevaluated term;
% presumably add/3 from library(autodiff2) accepts such expressions - confirm.
learn(vb(Prior), Stats, ITemp, Graph, Obj, A1, A2) :-
   map_swc(true2, Prior, A1),
   map_swc(true2, Prior, Pi),
   map_swc(mul_add(ITemp,1.0-ITemp), Prior, EffPrior),
   map_sum_sw(log_partition_dirichlet, Prior, LogZPrior),
   vb_helper(ITemp, LogZPrior, EffPrior, A1, Pi, Div),
   once(graph_counts(Stats, log, Graph, Pi, Counts, LL)),
   map_swc(mul_add(ITemp), EffPrior, Counts, A2),
   sub(Div, LL, Obj).
62
% vb_helper(ITemp, LogZPrior, EffPrior, A, Pi, Div): shared construction for
% the vb clause of learn/7.
%
% MeanLog maps mean_log_dir/2 over the switches of A. Delta is the per-entry
% sub/3 of EffPrior and A. Pi is MeanLog scaled per entry by ITemp via mul/3.
% LogZA sums log_part_dir/2 over the switches of A. Dot sums, over all
% switches and entries, the mul/3 product of MeanLog and Delta. Div is then
% produced by the composed goal sub(LogZA)*mul_add(ITemp,Dot) applied to
% LogZPrior.
% NOTE(review): assuming (*)/2 from library(callutils) runs its right-hand
% goal first, Div is sub(LogZA, Dot + ITemp*LogZPrior, Div) - confirm.
vb_helper(ITemp, LogZPrior, EffPrior, A, Pi, Div) :-
   map_sw(mean_log_dir, A, MeanLog),
   map_swc(sub, EffPrior, A, Delta),
   map_swc(mul(ITemp), MeanLog, Pi),
   map_sum_sw_(log_part_dir, A, LogZA),
   map_sum_sw_(map_sum_(mul), MeanLog, Delta, Dot),
   call(sub(LogZA)*mul_add(ITemp,Dot), LogZPrior, Div).
% converge(Test, Setup, History, S0, SFinal): iterates a learning step until
% a convergence test succeeds.
%
% Setup is called with one extra argument to obtain Step. Step is called once
% as call(Step, X0, S0, S1) to get the first objective value X0, and
% converge_x/6 continues from there. History is [X0|Rest], where Rest holds
% the values produced by the later steps. SFinal is the state after the last
% step. Test is abs(Eps) or rel(Del), as accepted by converged/3.
% Everything runs inside run_lambda_compiler/1; the setup and the iteration
% are each wrapped in time/1. A message is logged on debug topic learn(setup)
% before setup starts.
:- meta_predicate converge(+,1,-,+,-).
converge(Test, Setup, [First|Rest], S0, SFinal) :-
   debug(learn(setup), 'converge: Setting up...',[]),
   run_lambda_compiler((
      time(call(Setup, Step)),
      call(Step, First, S0, S1),
      time(converge_x(Test, Step, First, Rest, S1, SFinal))
   )).
85
% converge_x(Test, Step, Prev, History, S1, SFinal): iteration loop used by
% converge/5.
%
% Runs call(Step, Curr, S1, S2) to get the next objective value Curr, which
% becomes the head of History. If converged(Test, Prev, Curr) holds, History
% is closed off and SFinal is S2. Otherwise the loop recurses with Curr as
% the new previous value and S2 as the new state.
converge_x(Test, Step, Prev, [Curr|Rest], S1, SFinal) :-
   call(Step, Curr, S1, S2),
   (  converged(Test, Prev, Curr)
   -> Rest = [],
      SFinal = S2
   ;  converge_x(Test, Step, Curr, Rest, S2, SFinal)
   ).
91
% converged(Test, X1, X2): succeeds when two successive values X1 and X2 are
% close enough under Test.
%
%   abs(Eps): absolute difference |X1-X2| is at most Eps.
%   rel(Del): relative difference |(X1-X2)/(X1+X2)| is at most Del.
%
% When X1+X2 is zero the relative difference cannot be evaluated (the
% division would raise an arithmetic evaluation error), so that case is
% handled explicitly: it counts as converged only if X1 and X2 are equal,
% i.e. both are zero. For any non-zero sum the original formula is used
% unchanged.
converged(abs(Eps), X1, X2) :-
   abs(X1-X2) =< Eps.
converged(rel(Del), X1, X2) :-
   Sum is X1+X2,
   (  Sum =:= 0
   -> X1 =:= X2
   ;  abs((X1-X2)/Sum) =< Del
   ).
/*
   Expectation-maximisation, variational Bayes and deterministic annealing.
*/