:- module(plflow, [ops_body/4, sub/3, stoch/2, log_prob_dir/3, log_part_dir/2, mean_log_dir/2]).

% plflow: translates op(OpCode, Ins, Outs) terms into executable Prolog
% goals (via the op_goal/4 table below), and exports thin wrappers that
% hand several of those opcodes to autodiff2:esc/3.

:- use_module(library(math), [stoch/3]).
% Loaded with an empty import list: plrand predicates are only called
% module-qualified (plrand:...) in the op_goal/4 table.
:- use_module(library(plrand), []).
:- use_module(library(autodiff2), [esc/3]).

%   ops_body(_, _, +Ops, -Body)
%
%   Body is the conjunction (G1, (G2, (..., true))) of the goals produced
%   by op_goal/4 for each op(OpCode, Ins, Outs) term in Ops, in list
%   order.  An empty Ops list gives Body = true.  The first two arguments
%   are not used by this predicate.  Fails if any OpCode has no
%   op_goal/4 clause.
%
%   foldl/4 threads an open tail: each step binds the current tail to
%   (Goal, Rest) and passes Rest on; foldl finally unifies the last tail
%   with `true`.
ops_body(_, _, Ops, Body) :-
    foldl(op_goal, Ops, Body, true).

% One fold step: emit the goal for this op and leave Rest open.
op_goal(op(OpCode, Ins, Outs), (Goal, Rest), Rest) :-
    op_goal(OpCode, Ins, Outs, Goal).

%   op_goal(?OpCode, ?Ins, ?Outs, -Goal)
%
%   Opcode table: Goal, when called, computes the Outs from the Ins.
%   Declared multifile so other modules can add opcodes.
:- multifile op_goal/4.
op_goal(add, [X,Y], [Z], Z is X + Y).
% NOTE: operand order is reversed here: Z = Y - X.
op_goal(sub, [X,Y], [Z], Z is Y - X).
op_goal(mul, [X,Y], [Z], Z is X * Y).
op_goal(div, [X,Y], [Z], Z is X / Y).
% NOTE: operand order is reversed here: Z = Y ** X (Y is the base).
op_goal(pow, [X,Y], [Z], Z is Y**X).
op_goal(max, [X,Y], [Z], Z is max(X,Y)).
op_goal(exp, [X], [Z], Z is exp(X)).
op_goal(log, [X], [Z], Z is log(X)).
% I = Z if X > Y, 0.0 if X < Y, Z/2.0 on a tie.
% NOTE(review): presumably the (sub)gradient routing for max/2 - confirm.
op_goal(chi, [X,Y,Z], [I], (X>Y -> I=Z; X<Y -> I=0.0; I is Z/2.0)).
% Z = log(S) + M.
op_goal(add_log, [M,S], [Z], Z is log(S) + M).
% S = exp(Z - M).
op_goal(exp_sub, [M,Z], [S], S is exp(Z-M)).
op_goal(sum_list, Xs, [Z], sum_list(Xs,Z)).
op_goal(max_list, Xs, [Z], max_list(Xs,Z)).
% The third argument of math:stoch/3 is discarded.
op_goal(stoch, Xs, Ys, math:stoch(Xs,Ys,_)).
op_goal(log_prob_dirichlet(As), Ps, [LP], plrand:log_prob_dirichlet(As,Ps,LP)).
op_goal(log_partition_dirichlet, As, [LZ], plrand:log_partition_dirichlet(As,LZ)).
op_goal(mean_log_dirichlet, As, Psi, plrand:mean_log_dirichlet(As,Psi)).

% Exported wrappers: each passes the matching opcode, input list and
% output list to autodiff2:esc/3.
% NOTE(review): esc/3 semantics live in library(autodiff2) and are not
% visible here; presumably it records the op for differentiation.
log_prob_dir(As, Ps, LP) :-
    esc(log_prob_dirichlet(As), Ps, [LP]).

log_part_dir(As, LZ) :-
    esc(log_partition_dirichlet, As, [LZ]).

% same_length/2 gives Psi one (fresh) element per element of As before
% the call, since the output list length is not fixed by the opcode.
mean_log_dir(As, Psi) :-
    same_length(As, Psi),
    esc(mean_log_dirichlet, As, Psi).

% Same pre-sizing of the output list as mean_log_dir/2.
stoch(Xs, Ys) :-
    same_length(Xs, Ys),
    esc(stoch, Xs, Ys).

% Z = Y - X, following the operand order of the `sub` opcode above.
% (The original clause lacked its terminating period, a syntax error.)
sub(X, Y, Z) :-
    esc(sub, [X,Y], [Z]).