1:- module(egraph, [add_term//2, union//2, saturate//1, saturate//2,
    2                   extract/1, extract//0, lookup/2]).

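/** <module> E-graph implementation for term rewriting and saturation

This module implements an E-graph (Equivalence Graph) data structure, commonly used for efficient term rewriting, congruence closure, and e-matching. The E-graph state is typically threaded through operations using DCG notation.

Rewrite rules are automatically compiled into efficient DCG predicates via term expansion. See the egraph_compile module for full details, including the supported rule declarations.

Main predicates:

  - add_term//2: add a term to the E-graph and obtain its e-class ID
  - union//2: merge two e-classes
  - saturate//1, saturate//2: apply compiled rewrite rules until saturation (optionally bounded by an iteration count)
  - extract/1, extract//0: extract the optimal term(s) based on term costs
  - lookup/2: retrieve a value from a sorted list of Key-Value pairs

*/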

   27:- use_module(library(dcg/high_order)).   28:- use_module(library(ordsets)).   29:- use_module(library(rbtrees)).   30:- use_module(library(clpr)).   31
   32:- use_module(egraph/compile).   33
   % Attribute unification hooks for the two attributes this module attaches
   % to e-class IDs.
   %
   % `cost`:  unification never constrains the attribute, so the hook
   %          always succeeds.
   % `const`: holds the numeric value of a class. When two IDs are unified:
   %            - if both carry a constant, the values must be arithmetically
   %              equal (=:=), otherwise a domain_error is raised;
   %            - if the other side is still a variable without a constant,
   %              the constant is copied onto it;
   %            - if the other side is already bound, nothing is checked.
   cost:attr_unify_hook(_Cost, _Other).

   const:attr_unify_hook(Const, Other) :-
       get_attr(Other, const, OtherConst),
       !,
       (   Const =:= OtherConst
       ->  true
       ;   domain_error(Const, OtherConst)
       ).
   const:attr_unify_hook(Const, Other) :-
       var(Other),
       !,
       % Other survives the unification, so it must inherit the constant.
       put_attr(Other, const, Const).
   const:attr_unify_hook(_Const, _Other).
 lookup(+Pair, +SortedPairs) is semidet
Retrieves a value from a sorted list of pairs using standard term comparison. The search is unrolled for performance. Adapted from ord_memberchk/2.
Arguments:
Pair- A Key-Value pair where Key is the target key to find, and Value is unified with the associated value.
SortedPairs- A list of Key-Value pairs sorted by Key.
   55lookup(Item-V, [X1-V1, X2-V2, X3-V3, X4-V4|Xs]) :-
   56   !,
   57   compare(R4, Item, X4),
   58   (   R4=(>)
   59   ->  lookup(Item-V, Xs)
   60   ;   R4=(<)
   61   ->  compare(R2, Item, X2),
   62      (   R2=(>)
   63      ->  Item==X3, V = V3
   64      ;   R2=(<)
   65      ->  Item==X1, V = V1
   66      ;   V = V2
   67      )
   68   ;   V = V4
   69   ).
   70lookup(Item-V, [X1-V1, X2-V2|Xs]) :-
   71   !,
   72   compare(R2, Item, X2),
   73   (   R2=(>)
   74   ->  lookup(Item-V, Xs)
   75   ;   R2=(<)
   76   ->  Item==X1, V = V1
   77   ;   V = V2
   78   ).
   79lookup(Item-V, [X1-V1]) :-
   80   Item==X1, V = V1.
 add_term(+Term, -Id)// is det
Adds a term to the E-graph, returning its e-class ID. Compound terms are recursively traversed and their arguments are added to the E-graph first. Variables are represented using '$VAR'/1 wrappers.
Arguments:
Term- The term to be added.
Id- The e-class ID representing the added term.
   92add_term(Term, Id), var(Term) ==>
   93   add_node('$VAR'(Term), Id).
   94add_term(Term, Id), is_dict(Term) ==>
   95   {
   96      dict_pairs(Term, Tag, Pairs),
   97      pairs_keys_values(Pairs, Keys, Values)
   98   },
   99   foldl(add_term, Values, Ids),
  100   {
  101      pairs_keys_values(Data, Keys, Ids),
  102      dict_create(Node, Tag, Data)
  103   },
  104   add_node(Node, Id).
  105add_term(Term, Id), compound(Term) ==>
  106   { Term =.. [F | Args] },
  107   foldl(add_term, Args, Ids),
  108   { Node =.. [F | Ids] },
  109   add_node(Node, Id).
  110add_term(Term, Id) ==>
  111   add_node(Term, Id).
  112
  % add_node(+NodeIdPair, +In, -Out)
  % add_node(+Node, ?Id, +In, -Out)
  %
  % Ensures Node is present in the sorted node list In. If an entry for
  % Node already exists, Id is unified with its e-class ID and the list is
  % returned unchanged. Otherwise a new entry Node-node(Id, 1) (cost 1) is
  % inserted in order, and when Node is a number its value is recorded on
  % Id through the `const` attribute.
  add_node(Node-Id, In, Out) :-
      add_node(Node, Id, In, Out).
  add_node(Node, Id, In, Out) :-
      lookup(Node-node(Id, _ExistingCost), In),
      !,
      Out = In.
  add_node(Node, Id, In, Out) :-
      ord_add_element(In, Node-node(Id, 1), Out),
      (   number(Node)
      ->  put_attr(Id, const, Node)
      ;   true
      ).
  124
  % rules(+Rules, +Index, +Entry, ?Unifs0, ?Unifs)//
  %
  % Calls every rule in Rules, in order, on one E-graph entry of the form
  % Pat-node(Id, Cost). Each rule is invoked as a DCG non-terminal with the
  % arguments Pat, Id, Index and a pair of variables threading the list of
  % pending unifications; the DCG state is threaded through the rules as
  % well.
  % NOTE(review): what a rule emits on the DCG list is defined by the rule
  % compiler (egraph/compile), which is not visible here.
  rules([Rule|Rest], Index, Entry, Unifs0, Unifs) -->
      { Entry = Pat-node(Id, _Cost) },
      call(Rule, Pat, Id, Index, Unifs0, Unifs1),
      rules(Rest, Index, Entry, Unifs1, Unifs).
  rules([], _Index, _Entry, Unifs, Unifs) -->
      [].
  129
  % make_index(+Nodes, -Index)
  %
  % Builds a red-black tree that maps every e-class ID occurring in Nodes
  % to the list of nodes belonging to that class. Nodes is a list of
  % Node-node(Id, Cost) entries; costs are ignored.
  make_index(Nodes, Index) :-
      index_pairs(Nodes, IdNodePairs),
      keysort(IdNodePairs, SortedPairs),
      group_pairs_by_key(SortedPairs, NodesById),
      ord_list_to_rbtree(NodesById, Index).

  % index_pairs(+Nodes, -IdNodePairs)
  %
  % Turns each Node-node(Id, _Cost) entry into an Id-Node pair, keeping the
  % original order. Done with an explicit loop so that the ID variables are
  % shared with the input rather than copied.
  index_pairs([], []).
  index_pairs([Node-node(Id, _Cost)|Nodes], [Id-Node|Pairs]) :-
      index_pairs(Nodes, Pairs).
  139
  % match(+Entries, +Rules, +Index, ?Unifs0, ?Unifs)//
  %
  % Runs all Rules (see rules//5) on every entry of Entries in order,
  % threading both the DCG state and the list of pending unifications
  % from one entry to the next.
  match([], _Rules, _Index, Unifs, Unifs) -->
      [].
  match([Entry|Entries], Rules, Index, Unifs0, Unifs) -->
      rules(Rules, Index, Entry, Unifs0, Unifs1),
      match(Entries, Rules, Index, Unifs1, Unifs).
 union(+Id1, +Id2)// is det
Merges two e-classes by unifying their IDs and merging their underlying nodes.
Arguments:
Id1- The first e-class ID.
Id2- The second e-class ID.
  152union(A, A) -->
  153   merge_nodes.
  154
  % merge_nodes(+In, -Out)
  %
  % Normalises a node list after e-class IDs have been unified. The list is
  % sorted (which also drops exact duplicates) and grouped by node; entries
  % sharing the same node get their IDs unified and their costs minimised
  % by merge_groups/4. Since unifying IDs can make further nodes identical,
  % the process repeats until a pass merges nothing.
  merge_nodes(In, Out) :-
      sort(In, Sorted),
      group_pairs_by_key(Sorted, BySignature),
      merge_groups(BySignature, Merged, false, Changed),
      (   Changed \== true
      ->  Out = Sorted
      ;   merge_nodes(Merged, Out)
      ).
  163
  % merge_groups(+Groups, -Entries, +Changed0, -Changed)
  %
  % Groups is a list of Sig-Values pairs, Values being the non-empty list of
  % node(Id, Cost) terms stored under the node Sig. Each group is collapsed
  % into a single Sig-node(Id, Cost) entry by merge_group/3. Changed becomes
  % `true` as soon as one group held more than one value; otherwise it
  % keeps the value of Changed0.
  merge_groups([Sig-[First|Others]|Groups], [Sig-Node|Entries], Changed0, Changed) :-
      (   Others == []
      ->  Changed1 = Changed0
      ;   Changed1 = true
      ),
      merge_group(Others, First, Node),
      merge_groups(Groups, Entries, Changed1, Changed).
  merge_groups([], [], Changed, Changed).
  172
  % merge_group(+Others, +Node0, -Node)
  %
  % Folds the node(Id, Cost) terms in Others into Node0: all e-class IDs are
  % unified with the ID of Node0 and the smallest cost is kept.
  merge_group([], Node, Node).
  merge_group([node(Id, Cost)|Others], node(Id0, Cost0), Node) :-
      Id = Id0,
      Best is min(Cost, Cost0),
      merge_group(Others, node(Id, Best), Node).
  177
  % apply_unifs(+Unifs)
  %
  % Performs every pending unification in Unifs, a list of Left=Right
  % terms. Fails if an element is not of that form or does not unify.
  apply_unifs([]).
  apply_unifs([Left=Right|Unifs]) :-
      Left = Right,
      apply_unifs(Unifs).
  181
  % rebuild(+Matches, +Unifs, -Out)
  %
  % Restores the E-graph invariants after a matching pass: first applies
  % the pending unifications in Unifs (merging e-classes), then normalises
  % the node list Matches with merge_nodes/2, giving Out.
  rebuild(Nodes, Unifs, Out) :-
      apply_unifs(Unifs),
      merge_nodes(Nodes, Out).
 saturate(+Rules)// is det
Applies a list of compiled rewrite rules to the E-graph until saturation is reached.
Arguments:
Rules- A list of compiled rewrite rule names to apply.
  193saturate(Rules) -->
  194   saturate(Rules, inf).
 saturate(+Rules, +N)// is det
Applies a list of compiled rewrite rules to the E-graph up to N times or until saturation is reached.
Arguments:
Rules- A list of compiled rewrite rule names to apply.
N- The maximum number of iterations (or inf for no limit).
  204saturate(Rules, N, In, Out) :-
  205   (  N > 0
  206   -> make_index(In, Index),
  207      match(In, Rules, Index, Unifs, [], Matches, In),
  208      rebuild(Matches, Unifs, Tmp),
  209      length(In, Len1),
  210      length(Tmp, Len2),
  211      (  Len1 \== Len2
  212      -> (  N == inf
  213         -> N1 = N
  214         ;  N1 is N - 1
  215         ),
  216         saturate(Rules, N1, Tmp, Out)
  217      ;  Out = Tmp
  218      )
  219   ;  Out = In
  220   ).
 extract(+Nodes) is det
 extract// is det
Extracts the optimal term(s) from the E-graph based on term costs.
Arguments:
Nodes- A list of E-graph nodes representing the state.
  229extract(Nodes) :-
  230   extract(Nodes, Nodes).
  231extract(Nodes, Nodes) :-
  232   transpose_pairs(Nodes, Pairs),
  233   maplist([node(Id, Cost)-Node, Id-node(Cost, Node)]>>true, Pairs, IdPairs),
  234   group_pairs_by_key(IdPairs, ClassNodes),
  235   maplist([Id-_Node]>>({Cost >= 0}, put_attr(Id, cost, Cost)), ClassNodes),
  236   maplist(compute_class_cost, ClassNodes, NewClassNodes),
  237   maplist(extract_class, NewClassNodes).
  238
  % extract_class(+Class)
  %
  % Class is Id-Nodes, Nodes being a list of node(Cost, Node) terms. Binds
  % Id to one of the nodes, trying them in standard order of
  % node(Cost, Node), i.e. cheapest first when the costs are numbers;
  % further nodes are offered on backtracking. A '$VAR'(Var) node binds Id
  % to Var itself. If Id is still a variable afterwards its `cost`
  % attribute is removed.
  extract_class(Id-Nodes) :-
      % The ordering below is only meaningful once costs are instantiated.
      sort(Nodes, Ranked),
      member(node(_Cost, Node), Ranked),
      class_value(Node, Id),
      (   nonvar(Id)
      ->  true
      ;   del_attr(Id, cost)
      ).

  % class_value(+Node, ?Value)
  %
  % Value is the wrapped variable for a '$VAR'(Var) node and the node
  % itself otherwise.
  class_value('$VAR'(Var), Value) :-
      !,
      Value = Var.
  class_value(Node, Node).
  251
  % compute_class_cost(+Class0, -Class)
  %
  % Class0 is Id-Nodes. Every node gets its cost expressed by
  % compute_node_cost/3, giving NewNodes, and the CLP(R) minimum of these
  % node costs is unified with the value stored in the `cost` attribute of
  % Id. Fails if the class has no nodes or Id carries no `cost` attribute.
  compute_class_cost(Id-Nodes, Id-NewNodes) :-
      maplist(compute_node_cost, Nodes, NewNodes, NodeCosts),
      NodeCosts = [Cost0|MoreCosts],
      foldl([NodeCost, Best0, Best]>>
            {Best = min(NodeCost, Best0)},
            MoreCosts, Cost0, ClassCost),
      get_attr(Id, cost, ClassCost).
  259compute_node_cost(node(Offset, Node), node(Cost, Node), Cost) :-
  260   (  Node = '$VAR'(_)
  261   -> Cost = Offset
  262   ;  is_dict(Node)
  263   -> dict_pairs(Node, _, Pairs),
  264      pairs_keys_values(Pairs, _, Ids),
  265      foldl([Id, In, Out]>>(
  266         get_attr(Id, cost, IdCost),
  267         {Out = In + IdCost}
  268      ), Ids, 0, CCost),
  269      { Cost = CCost + Offset }
  270   ;  compound(Node)
  271   -> Node =.. [_ | Ids],
  272      foldl([Id, In, Out]>>(
  273         get_attr(Id, cost, IdCost),
  274         {Out = In + IdCost}
  275      ), Ids, 0, CCost),
  276      { Cost = CCost + Offset }
  277   ;  Cost = Offset
  278   )