:- module(tor,
	[(tor)/2
	,op(1100,xfy,tor)
	,op(1150,fx,tor)
	,search/1
	,tor_handlers/3
	,tor_before_handlers/3
        ,tor_merge/2
        ,dbs_tree/1
        ,dbs/2
        ,dibs_tree/1
        ,dibs/2
        ,id/1
        ,nbs/2
        ,nbs_tree/1
        ,bab/2
        ,lds/1
        ,dbs/3
        ,iterate/1
	,tor_statistics/1
	,solution_count/2
	,node_count/2
	,failure_count/2
        ,log/1
        ,parallel/1
	]).

/** <module> Tor infrastructure and many handlers.
 
 This module contains the basic Tor infrastructure for hookable disjunction as
 well as the definition of the search strategies.

*/

:- use_module(library(apply)).  
:- use_module(library(lists)).  
:- use_module(library(terms)).  
:- use_module(library(mutable_variables)).
:- use_module(library(clpfd)).
:- use_module(library(unix)).

%-------------------------------------------------------------------------------
% Tor hookable disjunction
%-------------------------------------------------------------------------------

:- meta_predicate tor(0,0).

%% tor(+G1, +G2)
%
% Hookable disjunction; use it instead of plain disjunction. The first
% alternative runs G1 through the handler stored in the global variable
% `left', the second alternative runs G2 through the handler stored in
% `right'.
tor(G1, G2) :-
  (   tor_branch(left, G1)
  ;   tor_branch(right, G2)
  ).

% tor_branch(+Key, +Branch)
%
% Calls Branch through the handler currently stored (with b_setval/2) in the
% global variable Key.
tor_branch(Key, Branch) :-
  b_getval(Key, Handler),
  call(Handler, Branch).

%-------------------------------------------------------------------------------
% Infrastructure
%-------------------------------------------------------------------------------

:- meta_predicate search(0).

%% search(+Goal)
%
%  Opens a new search scope for Goal: the global pruned flag is cleared and
%  both tor/2 hooks are reset to plain call/1, so that inside Goal tor/2
%  behaves like ordinary disjunction until handlers are installed.
search(Goal) :-
  reset_pruned,
  b_setval(right,call),
  b_setval(left,call),
  call(Goal).

:- meta_predicate tor_handlers(0,1,1).

%% tor_handlers(+Goal,+Left,+Right)
%
% Around advice. While Goal runs, the installed tor/2 hooks are the previous
% hooks composed with Left and Right (see compose/3). Once Goal succeeds, the
% previous hooks are reinstalled for the continuation. Because b_setval/2 is
% backtrackable, backtracking into Goal sees the composed hooks again.
tor_handlers(Goal,Left,Right) :-
  current_tor_hooks(OuterLeft,OuterRight),
  set_tor_hooks(compose(OuterLeft,Left),compose(OuterRight,Right)),
  call(Goal),
  set_tor_hooks(OuterLeft,OuterRight).

% current_tor_hooks(-Left,-Right): the handlers currently installed for the
% left and the right branch of tor/2.
current_tor_hooks(Left,Right) :-
  b_getval(left,Left),
  b_getval(right,Right).

% set_tor_hooks(+Left,+Right): installs Left and Right (backtrackably).
set_tor_hooks(Left,Right) :-
  b_setval(left,Left),
  b_setval(right,Right).


:- meta_predicate compose(1,1,0).

% compose(+Outer, +Inner, +Goal)
%
% Handler composition, conceptually Outer(Inner(Goal)): Outer is called with
% the goal call(Inner, Goal), so Outer controls if and how the inner handler
% (and thereby Goal) runs.
compose(Outer,Inner,Goal) :-
  Wrapped = call(Inner,Goal),
  call(Outer,Wrapped).

:- meta_predicate tor_before_handlers(0,0,0).

% tor_before_handlers(+Goal, +Left, +Right)
%
% Before advice, for handlers that only need to run a goal of their own ahead
% of the branch goal: while Goal runs, Left is executed before every left
% branch of tor/2 and Right before every right branch.
tor_before_handlers(Goal,Left,Right) :-
  LeftHandler = before(Left),
  RightHandler = before(Right),
  tor_handlers(Goal,LeftHandler,RightHandler).

:- meta_predicate before(0,0).

% before(+First, +Then): runs First and then Then.
before(First,Then) :-
  call(First),
  call(Then).

% :- meta_predicate tor_merge(0,0).

%% tor_merge(+Heuristic, +Goal)
%
% Extracts left and right handler definitions from the source code of a
% high-level search heuristic definition and invokes tor_handlers.
%
% Steps:
%  1. read a clause of Heuristic's predicate with clause/2 (using a head with
%     only fresh variables);
%  2. translate/7 splits the body at tor/2 into a left and a right handler
%     body and computes BVarPos, the argument positions that change between
%     the head and a recursive call;
%  3. both handler bodies are asserted under fresh gensym names;
%  4. for every position in BVarPos a mutable variable (new_bvar/2) is
%     created, initialised with the corresponding argument of Heuristic;
%  5. the two handler closures are installed around Goal with tor_handlers/3.
%
% NOTE(review): the handler clauses are asserted on every call and never
% retracted, so repeated calls (e.g. one per iteration in id_loop/3) keep
% growing the database.
% NOTE(review): clause/2 is called unqualified, so the heuristic presumably
% must be visible from this module. If the heuristic has several clauses,
% backtracking into tor_merge/2 retries the translation with the next clause.
tor_merge(Heuristic,Goal) :-
  % For a correct translation, we need to have a head that only contains free variables
  construct_template(Heuristic,FreeHead),
  clause(FreeHead,Body),
  % Do translation
  % BVarPos is a list of positions at which we need a mutable variable
  translate(FreeHead,Body, HandlerLeft, HandlerRight, HandlerLeftHeadVars, HandlerRightHeadVars, BVarPos),
  % Assert handlers
  assert_handler(HandlerLeftHeadVars, HandlerLeft, LeftSym),
  assert_handler(HandlerRightHeadVars, HandlerRight, RightSym),
  % Create and initialize mutable variables
  maplist(create_bvar(Heuristic),BVarPos,MutableVariables),
  install_handlers(Heuristic, MutableVariables, BVarPos, LeftSym, RightSym, InstallLeft, InstallRight),
  tor_handlers(Goal,InstallLeft,InstallRight).

% assert_handler(+HandlerHeadVars, +HandlerBody, -Sym)
%
% Asserts the clause Sym(HandlerHeadVars...) :- HandlerBody, where Sym is a
% fresh atom produced by gensym/2 from the base `handler'.
assert_handler(HandlerHeadVars, HandlerBody, Sym) :-
  gensym(handler, Sym),
  Head =.. [Sym|HandlerHeadVars],
  Clause = (Head :- HandlerBody),
  assertz(Clause).

% create_bvar(+Head, +Pos, -MutableVariable)
%
% MutableVariable is a new mutable variable (new_bvar/2) initialised with the
% argument of Head at zero-based position Pos. Designed for maplist/3.
create_bvar(Head,Pos,MutableVariable) :-
  Index is Pos + 1,
  Head =.. [_Name|Args],
  nth1(Index,Args,Initial),
  new_bvar(Initial,MutableVariable).

% install_handlers(+Heuristic, +MutableVariables, +BVarPos, +LeftSym, +RightSym, -InstallLeft, -InstallRight)
%
% Builds the two handler closures LeftSym(Args...) and RightSym(Args...). Args
% are Heuristic's arguments, with the mutable variables substituted at the
% positions listed in BVarPos. The closures are only partially applied:
% tor_handlers/3 supplies the final branch-goal argument when calling them.
install_handlers(Heuristic, MutableVariables, BVarPos, LeftSym, RightSym, InstallLeft, InstallRight) :-
  Heuristic =.. [_Name|HeuristicArgs],
  merge_by_pos(HeuristicArgs,MutableVariables,BVarPos,ClosureArgs),
  InstallLeft =.. [LeftSym|ClosureArgs],
  InstallRight =.. [RightSym|ClosureArgs].

%% merge_by_pos(+List1, +List2, +PosList, -ListOut)
%
% ListOut is a copy of List1 in which the element at every zero-based
% position listed in PosList is replaced by the next unused element of List2.
% Once List2 is exhausted, the rest of List1 is taken unchanged. Leftover
% elements of List2 are ignored.
% Example: merge_by_pos([a,b,c], [x,y], [0,2], [x,b,y]).
% Cheaper than merge_by_pos(List1, List2, PosList, ListOut, []).
merge_by_pos(List1,List2,PosList,ListOut) :-
  merge_by_pos_(List1,List2,PosList,0,ListOut).

merge_by_pos_([],_Fillers,_PosList,_Pos,[]).
merge_by_pos_([X|Xs],Fillers,PosList,Pos,Out) :-
  (   Fillers = []
  ->  Out = [X|Xs]
  ;   Fillers = [F|Fs],
      Next is Pos + 1,
      (   memberchk(Pos,PosList)
      ->  Out = [F|Rest],
          merge_by_pos_(Xs,Fs,PosList,Next,Rest)
      ;   Out = [X|Rest],
          merge_by_pos_(Xs,Fillers,PosList,Next,Rest)
      )
  ).

%% merge_by_pos(+List1, +List2, +PosList, -ListOut, ?TailOut)
%
% Difference-list version of merge_by_pos/4. ListOut-TailOut holds List1,
% with the element at every zero-based position in PosList replaced by the
% next unused element of List2. Once List2 is exhausted, the remaining
% elements of List1 are copied one by one.
% Use merge_by_pos/4 instead when TailOut would be [] (it is cheaper).
merge_by_pos(List1,List2,PosList,ListOut,TailOut) :-
  merge_by_pos_(List1,List2,PosList,0,ListOut,TailOut).

merge_by_pos_([],_Fillers,_PosList,_Pos,Tail,Tail).
merge_by_pos_([X|Xs],Fillers,PosList,Pos,[Z|Out],Tail) :-
  (   Fillers = []
  ->  % No replacements left: positions no longer matter.
      Z = X,
      merge_by_pos_(Xs,[],PosList,Pos,Out,Tail)
  ;   Fillers = [F|Fs],
      Next is Pos + 1,
      (   memberchk(Pos,PosList)
      ->  Z = F,
          merge_by_pos_(Xs,Fs,PosList,Next,Out,Tail)
      ;   Z = X,
          merge_by_pos_(Xs,Fillers,PosList,Next,Out,Tail)
      )
  ).

% Derives left and right handler definitions from a high-level heuristic definition.
% This process involves:
% - splitting into left and right handler by finding tor operator
% - replacing regular parameters with mutable variables
% - adding b_get and b_put for those variables
% - eliminating explicit recursive calls
%
% Outputs: HandlerLeft/HandlerRight are the handler bodies,
% HandlerLeftHeadVars/HandlerRightHeadVars their head argument lists (see
% translate_handler/5), and DiffPos the zero-based argument positions at
% which some recursive call differs from Head.
translate(Head, Body, HandlerLeft, HandlerRight,HandlerLeftHeadVars,HandlerRightHeadVars, DiffPos) :-
  split_handlers(Body, Left, Right),
  % Find out what to b_get: that is, we need to look at all the recursive calls and if a parameter changes between the head and the recursive call, we need to use a mutable variable and a b_get.
  % First find all recursive calls
  functor(Head,HeadName,HeadArity),
  find_goals(HeadName, HeadArity, Body, RecursiveCalls),
  % For each recursive call: check in which argument positions head and call have different variables.
  maplist(differentVariablePosList(Head),RecursiveCalls,ListListDiffPos),
  foldl(union,ListListDiffPos,[],DiffPos),
  % Translate left and right handler separately.
  translate_handler(Head, Left, DiffPos, HandlerLeft, HandlerLeftHeadVars),
  translate_handler(Head, Right, DiffPos, HandlerRight, HandlerRightHeadVars).

% differentVariablePosList(+Head1, +Head2, -DifferentPos)
%
% DifferentPos lists, in increasing order, the zero-based argument positions
% at which Head1 and Head2 hold terms that are not identical (==). Both heads
% must have the same number of arguments.
differentVariablePosList(Head1,Head2, DifferentPos) :-
  Head1 =.. [_|Args1],
  Head2 =.. [_|Args2],
  differentVariablePosList_(Args1,Args2,0,DifferentPos-[]).

differentVariablePosList_([],[],_,Tail-Tail).
differentVariablePosList_([A|As],[B|Bs],Index,Positions-Tail) :-
  Next is Index + 1,
  (   A == B
  ->  Positions = Rest
  ;   Positions = [Index|Rest]
  ),
  differentVariablePosList_(As,Bs,Next,Rest-Tail).

% construct_template(+Predicate, -Template)
%
% Template has the same name and arity as Predicate, but all of its arguments
% are distinct fresh variables.
construct_template(Predicate,Template) :-
  Predicate =.. [Name|Args],
  length(Args,Arity),
  length(FreshArgs,Arity),
  Template =.. [Name|FreshArgs].

% Does translation for a single handler
% Parameters:
% - Head: the heuristic's clause head
% - Body: handler code to transform (does not contain tor disjunctions anymore)
% - MutableVarsPositions: list of positions, starting from zero, that indicate which head variables need to become mutable variables.
% - Handler (output): the handler body. It consists of b_get/2 goals for the
%   mutable positions (omitted if there are none) followed by Body, in which
%   the recursive call is replaced by b_put/2 goals (omitted if there are none)
%   followed by call(Goal).
% - HandlerHeadVars (output): the handler's head arguments. These are Head's
%   arguments with the mutable variables at MutableVarsPositions, followed by
%   Goal, the branch goal that tor_handlers/3 passes as the last argument.
%
% Note: replace_goal/4 unifies TemplateRecursive with the recursive call it
% finds, so Args2 below are the arguments of that recursive call.
translate_handler(Head, Body, MutableVarsPositions, Handler, HandlerHeadVars) :-
  % Construct template for recursive call and replace it by a free variable in Body, resulting in Body2.
  construct_template(Head,TemplateRecursive),
  replace_goal(TemplateRecursive, Body,Free,Body2),  
  % Generalized adding of b_get here:
  % Make a conjunction of b_gets, and one of b_puts
  % Use a pair for the first argument, since maplist/6 is not defined anymore, but maplist/5 is
  Head =.. [_|Args],
  TemplateRecursive =.. [_|Args2],
  maplist(make_bgetput_arg_pos(Args-Args2),MutableVarsPositions,MutableVars,BGetList,BPutList),
  list_to_conj(BGetList,BGetConj),
  list_to_conj(BPutList,BPutConj),
  % Add b_gets and b_puts to handlers.
  % list_to_conj gives true in case of empty list, avoid to insert these
  empty_different(BGetList,Handler,(BGetConj,Body2),Body2),
  empty_different(BPutList,Free,(BPutConj,call(Goal)),call(Goal)),
  % Now create the head variables for the handler by combining free variables that must become mutable variables and regular arguments.
  merge_by_pos(Args,MutableVars,MutableVarsPositions,HandlerHeadVars, [Goal]). % We should be able to use argument from recursive call instead and get the same result.

% empty_different(?List, -Variable, +Nonempty, +Empty)
%
% Variable is Empty when List is [], and Nonempty otherwise.
empty_different([],Variable,_Nonempty,Empty) :- !,
  Variable = Empty.
empty_different(_List,Variable,Nonempty,_Empty) :-
  Variable = Nonempty.

% make_bgetput_arg_pos(+HeadArgs-CallArgs, +Position, ?MutableVariable, -Get, -Put)
%
% For use with maplist/5. Takes argument number Position (zero-based) from
% both the head arguments and the recursive-call arguments, and produces the
% intended goals b_get(MutableVariable, HeadArg) and
% b_put(MutableVariable, CallArg).
make_bgetput_arg_pos(HeadArgs-CallArgs,Position,MutableVariable,Get,Put) :-
  nth0(Position,HeadArgs,Old),
  nth0(Position,CallArgs,New),
  Get = b_get(MutableVariable,Old),
  Put = b_put(MutableVariable,New).

% find_goals(+PredicateName, +Arity, +Term, -ResultList)
%
% ResultList holds the calls to PredicateName/Arity occurring in the goal
% Term, as collected by find_goals_/4.
find_goals(PredicateName, Arity, Term, ResultList) :-
  DiffList = ResultList-[],
  find_goals_(PredicateName, Arity, Term, DiffList).

% Does term contain conjunction, disjunction or tor disjunction?
% If yes, Arg1 and Arg2 are both operands, and Operator is the operator.
% Fails (instead of raising an instantiation error from =../2) when Term is
% unbound, and fails for any term that is not one of the three binary
% operators.
has_selected_binary_operator(Term, Operator, Arg1, Arg2) :-
  compound(Term),
  compound_name_arguments(Term, Operator, [Arg1, Arg2]),
  memberchk(Operator,[',', ';', 'tor']).

% Does term contain conjunction, disjunction?
% If yes, Arg1 and Arg2 are both operands, and Operator is the operator.
% Fails (instead of raising an instantiation error from =../2) when Term is
% unbound.
has_selected_binary_operator2(Term, Operator, Arg1, Arg2) :-
  compound(Term),
  compound_name_arguments(Term, Operator, [Arg1, Arg2]),
  memberchk(Operator,[',', ';']).

% find_goals_(+PredicateName, +Arity, +Term, ?List-Tail)
%
% Collects into the difference list List-Tail every goal in Term that is a
% call to PredicateName/Arity. It descends through conjunction, disjunction,
% tor/2 and the two branches of if-then-else; the condition of an
% if-then-else is not searched.
%
% An unbound Term matches nothing. The first clause is needed because a
% variable would otherwise unify with the if-then-else head pattern below,
% and the recursion on the fresh sub-terms would never terminate.
find_goals_(_PredicateName, _Arity, Term, List-Tail) :-
  var(Term), !,
  List = Tail.
% Does not look in test - needs to be tested before disjunction
find_goals_(PredicateName, Arity, (_Test -> Term1 ; Term2), List-Tail) :- !,
  find_goals_(PredicateName, Arity, Term1, List-Tail1),
  find_goals_(PredicateName, Arity, Term2, Tail1-Tail).
find_goals_(PredicateName, Arity, Term, List-Tail) :-
  has_selected_binary_operator(Term, _Operator, Term1, Term2), !, % Mind the place of the cut
  find_goals_(PredicateName, Arity, Term1, List-Tail1),
  find_goals_(PredicateName, Arity, Term2, Tail1-Tail).
find_goals_(PredicateName, Arity, Term, List-Tail) :-
  functor(Template,PredicateName, Arity),
  ( (nonvar(Term), Term = Template) ->
     % Add to difflist
     List = [Template|Tail]
     ;
     List = Tail
  ).

% replace_goal(+Template, +Term, ?Free, -Result)
%
% In a term that is
% - a conjunction of subterms
% - a disjunction of subterms (including tor/2)
% - an if containing two subterms (test not investigated further),
% replace goal that matches given template by free variable.
% Note that matching unifies Template with the matched goal.
%
% An unbound Term is returned unchanged. The first clause is needed because a
% variable would otherwise unify with the if-then-else head pattern below,
% and the recursion would never terminate.
replace_goal(_Template,Term,_Free,Result) :-
  var(Term), !,
  Result = Term.
% Must be before disjunction
replace_goal(Template,(Test -> Term1 ; Term2),Free, (Test -> Result1 ; Result2)) :- !,
  replace_goal(Template,Term1,Free,Result1),
  replace_goal(Template,Term2,Free,Result2).
replace_goal(Template,Term,Free,Result) :- 
  has_selected_binary_operator(Term,Operator,Term1,Term2), !, % Mind the place of the cut
  replace_goal(Template,Term1,Free,Result1),
  replace_goal(Template,Term2,Free,Result2),
  Result =.. [Operator,Result1,Result2].
replace_goal(Template,Term,Free,Result) :-
  ((nonvar(Term), Template = Term) ->
    % Found match!
    Result = Free
  ;
    Result = Term
  ).

% split_handlers(+Term, -Left, -Right)
%
% Splits a heuristic clause body into the bodies of the left and the right
% handler. Every tor(X, Y) goal becomes X in Left and Y in Right.
% Conjunctions, disjunctions and the branches of if-then-else are traversed;
% the if-then-else condition and any other goal are copied into both bodies.
%
% An unbound Term is copied into both bodies. The first clause is needed
% because a variable would otherwise unify with the conjunction pattern below,
% and the recursion would never terminate.
split_handlers(Term,Result,Result2) :-
  var(Term), !,
  Result = Term,
  Result2 = Term.
% Must be before disjunction
split_handlers((Term1,Term2),(Result1,Result2),(Result3,Result4)) :- !,
  split_handlers(Term1,Result1,Result3),
  split_handlers(Term2,Result2,Result4).
split_handlers((Test -> Term1 ; Term2), (Test -> Result1 ; Result2), (Test -> Result3 ; Result4)) :- !,
  split_handlers(Term1,Result1,Result3),
  split_handlers(Term2,Result2,Result4).
split_handlers((Term1;Term2),(Result1;Result2),(Result3;Result4)) :- !,
  split_handlers(Term1,Result1,Result3),
  split_handlers(Term2,Result2,Result4).
split_handlers(Term,Result,Result2) :-
  ((nonvar(Term), tor(X,Y) = Term) ->
    % Found match!
    Result = X,
    Result2 = Y
  ;
    Result = Term,
    Result2 = Term
  ).


% list_to_conj(+Goals, -Conj)
%
% Conj is the right-nested conjunction of the goals in Goals: [] gives true,
% [G] gives G, and [G1,G2,G3] gives (G1,(G2,G3)).
list_to_conj([],true) :- !.
list_to_conj([Goal|Goals],Conj) :-
  list_to_conj_(Goals,Goal,Conj).

% list_to_conj_(+Rest, +Previous, -Conj): Previous is the goal waiting to be
% emitted; it becomes the whole conjunction once Rest is empty.
list_to_conj_([],Last,Last).
list_to_conj_([Next|Rest],Previous,(Previous,Conj)) :-
  list_to_conj_(Rest,Next,Conj).

%-------------------------------------------------------------------------------
% Search Methods listed in the paper
%-------------------------------------------------------------------------------

%% dbs_tree(+Depth)
%
% Depth bounded search tree. Use with tor_merge.
% Each tor/2 choice point requires Depth > 0 and passes Depth - 1 to both
% branches; a choice point reached with Depth =< 0 fails.
% tor_merge/2 reads this clause with clause/2 and translates it structurally,
% so its shape (tor/2 disjunction, recursive calls) is significant.
dbs_tree(D) :-
  D > 0, ND is D - 1,
  (dbs_tree(ND) tor dbs_tree(ND)).

:- meta_predicate dbs(+,0).

%% dbs(+Depth, :Goal)
%
% Depth bounded search: runs Goal under the handlers that tor_merge/2 derives
% from dbs_tree(Depth). Each tor/2 choice point consumes one unit of Depth.
% Goal is module-sensitive (like the Goal of lds/1 or tor_statistics/1), so
% goals defined in the caller's module are resolved correctly.
dbs(Depth, Goal) :-
  tor_merge(dbs_tree(Depth),Goal).

%% dibs_tree(+Discrepancies)
%
% Discrepancy-bounded search tree. Use with tor_merge.
% Uses prune instead of fail so it can be used to define lds in terms of it.
% The left branch keeps D; the right branch (a discrepancy) continues with
% D - 1. Both branches require D > 0.
% After split_handlers/3, each handler has the form ((D > 0, ...) ; prune), so
% prune/0 sets the pruned flag whenever the branch is exhausted on
% backtracking.
% NOTE(review): because the left branch also requires D > 0, a choice point
% with D = 0 fails both ways. And since prune/0 runs on every exhaustion, not
% only when D reaches 0, lds/1 may keep iterating on a finite tree without
% solutions. Confirm both are intended.
dibs_tree(D) :-
  (
    ( D > 0, dibs_tree(D)
    tor
      D > 0, ND is D - 1, dibs_tree(ND)
    )
  ;
    prune
  ).

:- meta_predicate dibs(+,0).

%% dibs(+Discrepancies, :Goal)
%
% Discrepancy-bounded search: runs Goal under the handlers that tor_merge/2
% derives from dibs_tree(Discrepancies). Goal is module-sensitive, so goals
% defined in the caller's module are resolved correctly.
dibs(Discrepancies, Goal) :-
  tor_merge(dibs_tree(Discrepancies),Goal).

:- meta_predicate id(0).

%% id(:Goal)
%
% Iterative deepening: runs Goal with the depth bounds 0, 1, 2, ... (see
% id_tree/2). The next bound is tried only when the run under the current
% bound has no (more) solutions and some choice point was cut off by the
% bound. Solutions found under a smaller bound are produced again under
% larger ones. Goal is module-sensitive, so goals defined in the caller's
% module are resolved correctly.
id(Goal) :-
  new_nbvar(not_pruned,PVar),
  id_loop(Goal,0,PVar).

% id_loop(+Goal, +Depth, +PVar)
%
% One iteration with bound Depth. PVar is reset to not_pruned first;
% id_tree/2 sets it to pruned when it cuts off a choice point.
id_loop(Goal,Depth,PVar) :-
  nb_put(PVar,not_pruned),
  ( tor_merge(id_tree(Depth,PVar),Goal)
  ;
    nb_get(PVar,Value),
    Value == pruned,
    NDepth is Depth + 1,
    id_loop(Goal,NDepth,PVar)
  ).

% id_tree(+Depth, +PruneVar)
%
% Depth-bounded search tree for iterative deepening. Use with tor_merge (see
% id/1). Each tor/2 choice point consumes one unit of Depth. A choice point
% reached with Depth =< 0 stores `pruned' in the non-backtrackable variable
% PruneVar and fails, which tells id_loop/3 that a deeper iteration may find
% more solutions.
% Because Depth differs between the head and the recursive calls,
% tor_merge/2 turns it into a mutable variable.
id_tree(Depth,PruneVar) :-
  ( Depth > 0 ->
    NDepth is Depth - 1
  ;
    nb_put(PruneVar,pruned), false
  ),
  ( id_tree(NDepth, PruneVar)
  tor
    id_tree(NDepth, PruneVar)
  ).

:- meta_predicate nbs(+,0).

%% nbs(+NumberOfNodes, :Goal)
%
% Node-bounded search: runs Goal under the handlers that tor_merge/2 derives
% from nbs_tree(NodesVar), where NodesVar is a nonbacktrackable counter
% initialised to NumberOfNodes. When the budget is exhausted, nbs_tree/1
% throws out_of_nodes(NodesVar). The exception is caught here and the whole
% search fails. Goal is module-sensitive, so goals defined in the caller's
% module are resolved correctly.
nbs(Nodes,Goal) :-
  new_nbvar(Nodes,NodesVar),
  catch(
    tor_merge(nbs_tree(NodesVar),Goal),
    out_of_nodes(NodesVar),
    fail
  ).

%% nbs_tree(+NodesVar)
%
% Node-bounded search tree. Use with tor_merge.
% Throws out_of_nodes exception.
% Every tor/2 choice point decrements the counter stored in NodesVar (with
% nb_put/2, so the count survives backtracking). A choice point reached when
% the counter is no longer positive throws out_of_nodes(NodesVar).
nbs_tree(Var) :-
  nb_get(Var,N),
  ( N > 0 ->
    N1 is N - 1, nb_put(Var, N1), (nbs_tree(Var) tor nbs_tree(Var))
  ;
    throw(out_of_nodes(Var))
  ).

:- meta_predicate bab(?,0).

%% bab(?Objective, :Goal)
%
% Branch-and-bound, maximising the CLP(FD) variable Objective. The best value
% found so far is kept in a nonbacktrackable variable, initialised to one
% below the current lower bound of Objective. After every solution of Goal,
% the value of Objective is stored as the new best. Subsequent tor/2 choice
% points (see bab_tree/3) then post Objective #> Best, so later solutions must
% be strictly better.
% NOTE(review): presumably Goal labels Objective; otherwise an unbound copy is
% stored as the best value.
% Goal is module-sensitive, so goals defined in the caller's module are
% resolved correctly.
bab(Objective,Goal) :-
  fd_inf(Objective,Inf),
  LowerBound is Inf - 1,
  new_nbvar(LowerBound,BestVar),
  Current = inf,
  tor_merge(bab_tree(Objective,BestVar,Current),Goal),
  nb_put(BestVar,Objective).

% bab_tree(+Objective, +BestVar, +Current)
%
% Branch-and-bound search tree. Use with tor_merge (see bab/2).
% BestVar holds the best objective value found so far. Current is the bound
% last posted on this path, or inf if none has been posted. At every tor/2
% choice point, if Best is better than Current, Objective #> Best is posted
% and Best becomes the new Current. Both branches then continue.
% Because Current differs between the head and the recursive calls,
% tor_merge/2 turns it into a mutable variable.
% NOTE(review): bab/2 initialises BestVar with a number, so the `Best \= inf`
% test appears to be always true.
bab_tree(Objective, BestVar, Current) :-
  nb_get(BestVar, Best),
  ( Best \= inf, (Current == inf ; Best > Current ) ->
    Objective #> Best,
    NCurrent = Best
  ;
    NCurrent = Current
  ),
  ( bab_tree(Objective, BestVar, NCurrent)
  tor
    bab_tree(Objective, BestVar, NCurrent)
  ).

:- meta_predicate lds(0).

%% lds(+Goal)
%
% Limited discrepancy search: iterate/1 runs discrepancy-bounded search
% (dibs/2) on Goal with the discrepancy bounds 0, 1, 2, ...
lds(Goal) :-
  Step = flip(dibs,Goal),
  iterate(Step).

% Level is a number, so it is not module-sensitive (+).
% Method is called with the branch goal as one extra argument, so its
% meta-argument specifier must be 1.
:- meta_predicate dbs(+,1,0).

%% dbs(+Level, +Method, +Goal)
%
% Variant on depth-bounded search. When the depth bound is reached, it does
% not prune the remaining subtree, but activates the search method Method.
dbs(Level, Method, Goal) :-
  new_bvar(yes(Level),DepthVar),
  Handler = dbs_handler(DepthVar,Method),
  tor_handlers(Goal,Handler,Handler).

% dbs_handler(+Var, +Method, +Goal)
%
% Branch handler of dbs/3. Var holds yes(Depth) while depth is left on this
% path, and no once Method has been activated.
% - yes(Depth) with Depth > 1: store yes(Depth - 1) and run Goal.
% - yes(Depth) otherwise: store no and run Goal under Method.
% - no: just run Goal.
dbs_handler(Var,Method,Goal) :-
  b_get(Var,State),
  (   State = yes(Depth)
  ->  (   Depth > 1
      ->  Remaining is Depth - 1,
          b_put(Var,yes(Remaining)),
          call(Goal)
      ;   b_put(Var,no),
          call(Method,Goal)
      )
  ;   State == no
  ->  call(Goal)
  ).

%-------------------------------------------------------------------------------
% Iteration patterns and pruning
%-------------------------------------------------------------------------------

% The pruned flag is the global variable `pruned' (nb_setval/2, so it is not
% restored on backtracking), holding true or false.

% prune: sets the pruned flag and fails.
prune :-
  set_pruned(true),
  false.

% reset_pruned: clears the pruned flag.
reset_pruned :-
  set_pruned(false).

% is_pruned: succeeds if the pruned flag is set.
is_pruned :-
  get_pruned(Flag),
  Flag == true.

get_pruned(Flag) :-
  nb_getval(pruned,Flag).

set_pruned(Flag) :-
  nb_setval(pruned,Flag).

% scope_pruned(+Goal)
%
% Runs Goal with a cleared pruned flag. When Goal finally fails, the flag
% value from before the call is restored before failing.
scope_pruned(Goal) :-
  get_pruned(Saved),
  reset_pruned,
  (   call(Goal)
  ;   set_pruned(Saved),
      fail
  ).

% pruned_union(?Flag1, ?Flag2, ?Union)
%
% Union is the logical or of the two pruned flags (true/false).
% NOTE(review): not referenced anywhere in the visible code.
pruned_union(true,_,true).
pruned_union(false,true,true).
pruned_union(false,false,false).

% PGoal is a closure that iterate_loop/2 calls with the iteration number as
% one extra argument, hence meta-argument specifier 1 (0 would describe a goal
% called without extra arguments).
:- meta_predicate iterate(1).

%% iterate(+PGoal)
%
% Factors out the common iteration part of iterative deepening and limited
% discrepancy search. It calls PGoal(0), PGoal(1), ..., moving on only while
% the previous call failed with the pruned flag set (see iterate_loop/2). The
% pruned flag of the enclosing scope is restored when the iteration finally
% fails (see scope_pruned/1).
iterate(PGoal) :-
  scope_pruned(
    iterate_loop(0,PGoal)).

:- meta_predicate iterate_loop(+,1).

% iterate_loop(+N, +PGoal)
%
% Calls PGoal(N). Once that has no more solutions and the pruned flag is set,
% the flag is cleared and the loop continues with PGoal(N + 1).
iterate_loop(N,PGoal) :-
  call(PGoal,N).
iterate_loop(N,PGoal) :-
  is_pruned,
  reset_pruned,
  Next is N + 1,
  iterate_loop(Next,PGoal).

% BaseStrategy is a closure called with two extra arguments (Number and Goal),
% hence meta-argument specifier 2. The third argument is a number, which is
% not module-sensitive, so it is + rather than 0.
:- meta_predicate flip(2,0,+).

% flip(+BaseStrategy, +Goal, +Number)
%
% Calls BaseStrategy(Number, Goal). It swaps the last two arguments so that a
% strategy of the form Strategy(Limit, Goal) can serve as an iterate/1
% closure, which receives the iteration number last.
flip(BaseStrategy, Goal, Number) :-
  call(BaseStrategy,Number, Goal).

%-------------------------------------------------------------------------------
% Statistics and visualisation
%-------------------------------------------------------------------------------

:- meta_predicate tor_statistics(0).

%% tor_statistics(+Goal)
%
% Runs Goal and prints statistics about the search:
%  * number of solutions
%  * number of tor/2 branches entered (node_count/2)
%  * number of tor/2 branches whose goal had no solution (failure_count/2)
% The cumulative counts are printed after every solution of Goal, and once
% more when Goal has no further solutions. The final report does not add a
% solution of its own: tor_statistics/1 succeeds exactly as often as Goal.
tor_statistics(Goal) :-
  new_nbvar(0,SolutionVar),
  new_nbvar(0,NodeVar),
  new_nbvar(0,FailureVar),
  Vars  = [SolutionVar,NodeVar,FailureVar],
  Names = ['solutions','nodes','failures'],
  ( solution_count(SolutionVar,node_count(NodeVar,failure_count(FailureVar,Goal))),
    maplist(nb_report,Vars,Names)
  ;
    % Final report after exhaustion. Fail afterwards so that the caller does
    % not see a spurious extra solution with Goal's variables unbound (e.g.
    % tor_statistics(fail) must fail).
    maplist(nb_report,Vars,Names),
    fail
  ).

% nb_report(+Var, +Name)
%
% Prints the number stored in the nonbacktrackable variable Var, labelled with
% Name, padded with dots to column 34.
nb_report(Var,Name) :-
  nb_get(Var,Count),
  format('% Number of ~w: ~`.t ~d~34|~n',[Name,Count]).

:- meta_predicate solution_count(+,0).

%% solution_count(+SolutionVar, +Goal)
%
% Calls Goal and increments the nonbacktrackable counter SolutionVar once per
% solution.
solution_count(SolutionVar,Goal) :-
  Goal,
  nb_inc(SolutionVar).

:- meta_predicate node_count(+,0).

% node_count(+NodeVar, +Goal)
%
% Runs Goal, incrementing the nonbacktrackable counter NodeVar before every
% left and every right branch of tor/2.
node_count(NodeVar,Goal) :-
  Increment = nb_inc(NodeVar),
  tor_before_handlers(Goal,Increment,Increment).

% nb_inc(+Var): adds one to the number stored in the nonbacktrackable
% variable Var.
nb_inc(Var) :-
  nb_get(Var,Old),
  New is Old + 1,
  nb_put(Var,New).

:- meta_predicate failure_count(+,0).

% failure_count(+FailureVar, +Goal)
%
% Runs Goal with failure_handler/2 on both tor/2 branches. The handler
% increments the nonbacktrackable counter FailureVar whenever a branch goal
% has no solution.
failure_count(FailureVar,Goal) :-
  Handler = failure_handler(FailureVar),
  tor_handlers(Goal,Handler,Handler).

% failure_handler(+Counter, +Goal)
%
% Calls Goal and passes all of its solutions through. If Goal has no solution
% at all, Counter is incremented and the handler fails.
failure_handler(Counter,Goal) :-
  (   call(Goal)
  *-> true
  ;   nb_inc(Counter),
      false
  ).

:- meta_predicate log(0).

%% log(:Goal)
%
% Emits a textual representation of the search tree.
% This log can be turned into a PDF image using the provided tool.
% The handlers come from log_tree/0; `solution' is printed after each
% solution of Goal. Goal is module-sensitive, so goals defined in the
% caller's module are resolved correctly.
log(Goal) :-
  tor_merge(log_tree, Goal),
  writeln(solution).

% log_tree
%
% Logging search tree. Use with tor_merge (see log/1).
% After split_handlers/3, the left handler prints `left' and the right
% handler prints `right' before running their branch. When a branch is
% exhausted on backtracking, `false' is printed and the handler fails.
log_tree :-
  ( ( writeln(left)
    tor
      writeln(right)
    ),
    log_tree
  ;
    writeln(false),
    false
  ).

%-------------------------------------------------------------------------------
% Parallel infrastructure
%-------------------------------------------------------------------------------

:- meta_predicate parallel(0).

%% parallel(:Goal)
%
% Parallel search: every tor/2 branch inside Goal is handed to tor_fork/1,
% which runs it in a forked child process. The slot counter in the file
% `num_threads' (current working directory) is reset to 0 first. The stream
% is closed even if writing fails. Goal is module-sensitive, so goals
% defined in the caller's module are resolved correctly.
parallel(Goal) :-
  setup_call_cleanup(open('num_threads', write, Stream, [lock(exclusive)]),
                     format(Stream, "0.\n", []),
                     close(Stream)),
  general_tor_hook(Goal, tor_fork, tor_fork).

%% tor_fork(+Goal)
%
% Parallel tor/2 handler. It waits for a free slot and then forks. The child
% process runs the branch Goal; the parent fails immediately and so moves on
% to the next branch.
%
% When the child has exhausted Goal (and whatever ran after it), it gives its
% slot back and halts. It must not backtrack into choice points created
% before the fork: the parent explores those, so continuing in the child
% would explore the same branches twice and never free the slot.
% NOTE(review): a child whose continuation succeeds all the way up to the
% top-level goal never reaches the halt below and keeps its slot.
tor_fork(Goal) :-
  wait_for_slot,
  fork(PID),
  (   PID == child
  ->  (   call(Goal)
      ;   % Best effort: the child exits whether or not the release worked.
          catch(ignore(release_slot), _, true),
          halt
      )
  ;   false
  ).

% wait_for_slot
%
% Claims a slot, polling every 0.5 seconds while 6 slots are in use. The
% counter is only accessed while holding the exclusive lock on `mylock'. The
% lock is released between polls, so that finishing children can give their
% slot back in the meantime.
wait_for_slot :-
  repeat,
    (   claim_slot
    ->  !
    ;   sleep(0.5),
        fail
    ).

% claim_slot: increments the counter and succeeds if at most 5 slots are in
% use; fails otherwise.
claim_slot :-
  with_slot_lock(( read_num_threads(Num),
                   Num =< 5,
                   Num1 is Num + 1,
                   write_num_threads(Num1) )).

% release_slot: decrements the counter, never going below zero.
release_slot :-
  with_slot_lock(( read_num_threads(Num),
                   Num1 is max(0, Num - 1),
                   write_num_threads(Num1) )).

% with_slot_lock(+Goal): runs Goal once while holding an exclusive lock on
% the file `mylock'. The lock is released even if Goal fails or raises.
with_slot_lock(Goal) :-
  setup_call_cleanup(open(mylock, write, Lock, [lock(exclusive)]),
                     once(Goal),
                     close(Lock)).

% read_num_threads(-Num): reads the counter from `num_threads'. If the file
% cannot be opened, the error is printed and the predicate fails. It also
% fails if the file does not hold an integer.
read_num_threads(Num) :-
  catch(open('num_threads', read, Stream, []),
        E,
        ( writeln(E), fail )),
  call_cleanup(read(Stream, Num0), close(Stream)),
  integer(Num0),
  Num = Num0.

% write_num_threads(+Num): overwrites `num_threads' with the term Num.
write_num_threads(Num) :-
  setup_call_cleanup(open('num_threads', write, Stream, [lock(exclusive)]),
                     format(Stream, "~w.\n", [Num]),
                     close(Stream)).

% general_tor_hook(+Goal, +Left, +Right)
%
% Runs Goal with Left and Right composed onto the currently installed tor/2
% hooks, and reinstalls the previous hooks once Goal succeeds. This is
% exactly the around advice of tor_handlers/3, so delegate to it rather than
% duplicating the hook bookkeeping.
general_tor_hook(Goal,Left,Right) :-
  tor_handlers(Goal,Left,Right).

%-------------------------------------------------------------------------------
% tor/1 declaration
%-------------------------------------------------------------------------------
% The declaration `:- tor Pred/Arity.' replaces the implicit disjunctions
% between the clauses of Pred/Arity by explicit calls to tor/2 disjunction.
%
%   TODO
%     - perform substitutions to simplify the unifier
%     - eliminate unifications with singleton variables from the unifier
%

:- multifile user:term_expansion/2.
:- dynamic '$tor_predicate'/3.
:- dynamic '$tor_clause'/5.

% tor_expansion(+Term, +File, -Expansion)
%
% Implements the `:- tor Name/Arity.' declaration for the file File:
%  * the declaration itself is recorded and removed from the source;
%  * facts and rules of a declared predicate are recorded and removed;
%  * at end_of_file, the recorded clauses of each declared predicate are
%    merged into one clause that combines them with tor/2 (see
%    merge_tor_clauses/4), followed by end_of_file itself.
% For any other term it fails, so the term is compiled normally. In
% particular, end_of_file is only intercepted for files that actually contain
% a tor declaration; otherwise other term_expansion/2 hooks for end_of_file
% would never run.
tor_expansion((:- tor F/A), File, []) :- !,
  assertz('$tor_predicate'(F,A,File)).
tor_expansion(end_of_file, File, Clauses) :- !,
  once('$tor_predicate'(_,_,File)),
  findall(Clause,
          ( retract('$tor_predicate'(F,A,File)),
            merge_tor_clauses(F,A,File,Clause)
          ),
          MergedClauses),
  append(MergedClauses, [end_of_file], Clauses).
tor_expansion((Head :- Body),File,[]) :- !,
  functor(Head,F,A),
  '$tor_predicate'(F,A,File),
  assertz('$tor_clause'(F,A,File,Head,Body)).
tor_expansion(Head,File,[]) :-
  functor(Head,F,A),
  '$tor_predicate'(F,A,File),
  assertz('$tor_clause'(F,A,File,Head,true)).

% merge_tor_clauses(+F, +A, +File, -Clause)
%
% Removes all recorded clauses of F/A from File and combines them into a
% single clause. The head is the most specific generalisation of all clause
% heads. The body is a tor/2 chain of the original bodies in source order,
% each preceded by the unifications that specialise the head.
merge_tor_clauses(F,A,File,(Head :- Body)) :-
  findall(Term-TermBody,
          retract('$tor_clause'(F,A,File,Term,TermBody)),
          Pairs),
  merge_tor_head(Pairs,Head),
  reverse(Pairs,ReversedPairs),
  merge_tor_bodies(ReversedPairs,Head,Body).

% merge_tor_head(+Pairs, -Head)
%
% Head is the most specific term that subsumes the heads (the Term of every
% Term-Body pair) in Pairs, computed with term_subsumer/3.
merge_tor_head([Term-_|Pairs],Head) :-
  foldl(subsume_head,Pairs,Term,Head).

% subsume_head(+Term-Body, +Acc, -Generalised)
subsume_head(Term-_,Acc,Generalised) :-
  term_subsumer(Term,Acc,Generalised).

% merge_tor_bodies(+Pairs, +Head, -Body)
%
% Pairs are the clause Term-Body pairs in reverse source order. The first pair
% becomes the innermost alternative, and each following pair P wraps the
% accumulated body as tor(Alternative(P), Acc). The result therefore lists
% the clauses in source order.
merge_tor_bodies([Pair|Pairs],Head,Body) :-
  tor_alternative(Head,Pair,Innermost),
  foldl(add_tor_alternative(Head),Pairs,Innermost,Body).

% add_tor_alternative(+Head, +Pair, +Acc, -Disjunction)
add_tor_alternative(Head,Pair,Acc,tor(Goal,Acc)) :-
  tor_alternative(Head,Pair,Goal).

% tor_alternative(+Head, +Term-TermBody, -Goal): Goal first unifies Head with
% the clause head Term (as explicit =/2 goals) and then runs TermBody.
tor_alternative(Head,Term-TermBody,Goal) :-
  head_matcher(Head,Term,Matcher),
  optimize_conjunction(Matcher,TermBody,Goal).

% head_matcher(+Head, +Term, -Matcher)
%
% Matcher is a conjunction of Var = Value goals that, when executed, unifies
% Head with Term. It is true when the two are already identical. Fails if
% Head and Term do not unify.
head_matcher(Head,Term,Matcher) :-
  unifiable(Head,Term,Unifier),
  head_matcher(Unifier,Matcher).

% head_matcher(+Goals, -Conj): right-nested conjunction of Goals (true for []).
head_matcher([],true).
head_matcher([Eq|Eqs],Matcher) :-
  (   Eqs == []
  ->  Matcher = Eq
  ;   Matcher = (Eq,Rest),
      head_matcher(Eqs,Rest)
  ).

% optimize_conjunction(+G1, +G2, -NG)
%
% NG is the conjunction (G1,G2), simplified when G1 is true (NG = G2), when G1
% is false (NG = false), or when G2 is true (NG = G1).
optimize_conjunction(G1,G2,NG) :-
  G1 == true, !,
  NG = G2.
optimize_conjunction(G1,_G2,NG) :-
  G1 == false, !,
  NG = false.
optimize_conjunction(G1,G2,NG) :-
  G2 == true, !,
  NG = G1.
optimize_conjunction(G1,G2,(G1,G2)).

%-------------------------------------------------------------------------------
% tor/1 declaration
%-------------------------------------------------------------------------------
%
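% The term_expansion/2 hook is deliberately defined last in this file,
% presumably so that it only becomes active once everything above is loaded.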
%
%  (thanks to Jan Wielemaker)

% Hook: while a source file is being loaded (and not during cross-referencing),
% hand every term to tor_expansion/3 together with the file being loaded.
user:term_expansion(TermIn, TermOut) :-
  prolog_load_context(source, File),
  \+ current_prolog_flag(xref, true),
  tor_expansion(TermIn, File, TermOut).
