defmodule Tilly.BDD.Ops do
  @moduledoc """
  Generic BDD algorithms and smart constructors.

  These functions operate on BDD node IDs and use an `ops_module` to dispatch to
  specific element/leaf operations (`test_leaf_value/1`, `union_leaves/3`,
  `intersection_leaves/3`, `negation_leaf/2` and `compare_elements/2`).

  A split node `{:split, x, p, i, n}` denotes `(x ∧ p) ∨ i ∨ (¬x ∧ n)`: `p` is the
  positive branch, `i` the "ignore" branch that holds regardless of `x`, and `n`
  the negative branch. The union, intersection and negation algorithms below
  (ported from CDuce's lazy BDDs) all rely on this reading.

  Every function threads a `typing_ctx` (whose `:bdd_store` carries an
  `:ops_cache`) and returns `{new_typing_ctx, result}`.
  """

  alias Tilly.BDD
  alias Tilly.BDD.Node

  @doc """
  Smart constructor for leaf nodes.

  Uses the `ops_module` to test if the `leaf_value` corresponds to an empty or
  universal set for that module; those map to the canonical false/true node IDs
  instead of being interned as a new leaf.

  Returns `{new_typing_ctx, node_id}`.
  """
  def leaf(typing_ctx, leaf_value, ops_module) do
    case ops_module.test_leaf_value(leaf_value) do
      :empty ->
        {typing_ctx, BDD.false_node_id()}

      :full ->
        {typing_ctx, BDD.true_node_id()}

      :other ->
        BDD.get_or_intern_node(typing_ctx, Node.mk_leaf(leaf_value), ops_module)
    end
  end

  @doc """
  Smart constructor for split nodes.

  Builds the node denoting `(element ∧ p) ∨ i ∨ (¬element ∧ n)` and applies the
  simplification rules that are valid under that reading:

    * `i` is the true node: the whole BDD is true;
    * `p == n`: the element test is irrelevant and the result is `p ∪ i`.

  Anything else is interned as a split node.

  Note: there is deliberately no "`i` and `n` are False, so the result is `p`"
  rule, because `split(x, p, False, False)` denotes `x ∧ p`, which differs from `p`.

  Returns `{new_typing_ctx, node_id}`.
  """
  def split(typing_ctx, element, p_id, i_id, n_id, ops_module) do
    cond do
      # (x ∧ p) ∨ True ∨ (¬x ∧ n) = True
      BDD.is_true_node?(typing_ctx, i_id) ->
        {typing_ctx, BDD.true_node_id()}

      # (x ∧ p) ∨ i ∨ (¬x ∧ p) = p ∨ i
      p_id == n_id ->
        if p_id == i_id do
          {typing_ctx, p_id}
        else
          # Mutual recursion with union_bdds/3 (which may call split/6 again).
          union_bdds(typing_ctx, p_id, i_id)
        end

      # TODO: CDuce additionally simplifies p and n against i before interning.
      true ->
        logical_structure = Node.mk_split(element, p_id, i_id, n_id)
        BDD.get_or_intern_node(typing_ctx, logical_structure, ops_module)
    end
  end

  @doc """
  Computes the union of two BDDs.

  Returns `{new_typing_ctx, result_node_id}`.
  """
  def union_bdds(typing_ctx, bdd1_id, bdd2_id) do
    apply_op(typing_ctx, :union, bdd1_id, bdd2_id)
  end

  @doc """
  Computes the intersection of two BDDs.

  Returns `{new_typing_ctx, result_node_id}`.
  """
  def intersection_bdds(typing_ctx, bdd1_id, bdd2_id) do
    apply_op(typing_ctx, :intersection, bdd1_id, bdd2_id)
  end

  @doc """
  Computes the negation of a BDD.

  Returns `{new_typing_ctx, result_node_id}`.
  """
  def negation_bdd(typing_ctx, bdd_id) do
    # Unary operation: the second operand slot of apply_op/4 is nil.
    apply_op(typing_ctx, :negation, bdd_id, nil)
  end

  @doc """
  Computes the difference of two BDDs (`bdd1 - bdd2`).

  Implemented as `bdd1 INTERSECTION (NEGATION bdd2)`.

  Returns `{new_typing_ctx, result_node_id}`.
  """
  def difference_bdd(typing_ctx, bdd1_id, bdd2_id) do
    {ctx, neg_bdd2_id} = negation_bdd(typing_ctx, bdd2_id)
    intersection_bdds(ctx, bdd1_id, neg_bdd2_id)
  end

  # --- Union -------------------------------------------------------------------

  # Uncached union; called only through apply_op/4.
  defp do_union_bdds(typing_ctx, bdd1_id, bdd2_id) do
    cond do
      bdd1_id == bdd2_id -> {typing_ctx, bdd1_id}
      BDD.is_true_node?(typing_ctx, bdd1_id) -> {typing_ctx, BDD.true_node_id()}
      BDD.is_true_node?(typing_ctx, bdd2_id) -> {typing_ctx, BDD.true_node_id()}
      BDD.is_false_node?(typing_ctx, bdd1_id) -> {typing_ctx, bdd2_id}
      BDD.is_false_node?(typing_ctx, bdd2_id) -> {typing_ctx, bdd1_id}
      true -> perform_union(typing_ctx, bdd1_id, bdd2_id)
    end
  end

  # Structural union of two non-terminal nodes sharing one ops_module.
  defp perform_union(typing_ctx, bdd1_id, bdd2_id) do
    {s1, s2, ops_m} = fetch_structures!(typing_ctx, bdd1_id, bdd2_id, "union")

    case {s1, s2} do
      {{:leaf, v1}, {:leaf, v2}} ->
        new_leaf_val = ops_m.union_leaves(typing_ctx, v1, v2)
        leaf(typing_ctx, new_leaf_val, ops_m)

      {{:split, x1, p1_id, i1_id, n1_id}, {:leaf, _v2}} ->
        # CDuce: split x1 p1 (i1 ++ b) n1
        {ctx, new_i1_id} = union_bdds(typing_ctx, i1_id, bdd2_id)
        split(ctx, x1, p1_id, new_i1_id, n1_id, ops_m)

      {{:leaf, _v1}, {:split, x2, p2_id, i2_id, n2_id}} ->
        # CDuce: split x2 p2 (i2 ++ a) n2
        {ctx, new_i2_id} = union_bdds(typing_ctx, i2_id, bdd1_id)
        split(ctx, x2, p2_id, new_i2_id, n2_id, ops_m)

      {{:split, x1, p1_id, i1_id, n1_id}, {:split, x2, p2_id, i2_id, n2_id}} ->
        case ops_m.compare_elements(x1, x2) do
          :eq ->
            # Same element: merge the children pairwise.
            {ctx0, new_p_id} = union_bdds(typing_ctx, p1_id, p2_id)
            {ctx1, new_i_id} = union_bdds(ctx0, i1_id, i2_id)
            {ctx2, new_n_id} = union_bdds(ctx1, n1_id, n2_id)
            split(ctx2, x1, new_p_id, new_i_id, new_n_id, ops_m)

          :lt ->
            # CDuce: split x1 p1 (i1 ++ b) n1
            {ctx, new_i1_id} = union_bdds(typing_ctx, i1_id, bdd2_id)
            split(ctx, x1, p1_id, new_i1_id, n1_id, ops_m)

          :gt ->
            # CDuce: split x2 p2 (i2 ++ a) n2
            {ctx, new_i2_id} = union_bdds(typing_ctx, i2_id, bdd1_id)
            split(ctx, x2, p2_id, new_i2_id, n2_id, ops_m)
        end
    end
  end

  # --- Intersection ------------------------------------------------------------

  # Uncached intersection; called only through apply_op/4.
  defp do_intersection_bdds(typing_ctx, bdd1_id, bdd2_id) do
    cond do
      bdd1_id == bdd2_id ->
        {typing_ctx, bdd1_id}

      BDD.is_false_node?(typing_ctx, bdd1_id) ->
        {typing_ctx, BDD.false_node_id()}

      BDD.is_false_node?(typing_ctx, bdd2_id) ->
        {typing_ctx, BDD.false_node_id()}

      BDD.is_true_node?(typing_ctx, bdd1_id) ->
        {typing_ctx, bdd2_id}

      BDD.is_true_node?(typing_ctx, bdd2_id) ->
        {typing_ctx, bdd1_id}

      disjoint_elements_fast_path?(typing_ctx, bdd1_id, bdd2_id) ->
        {typing_ctx, BDD.false_node_id()}

      true ->
        perform_intersection(typing_ctx, bdd1_id, bdd2_id)
    end
  end

  # Fast path for two nodes of the form `split(x, p, False, False)` (i.e. `x ∧ p`)
  # with the same `p`, the same ops_module and distinct elements: their
  # intersection is taken to be empty. The `i`/`n` child must be the false node;
  # otherwise both operands contain it and the intersection is not empty.
  # NOTE(review): this assumes distinct elements denote disjoint sets (true for
  # literal-like elements, not for arbitrary variables) — confirm per ops_module.
  defp disjoint_elements_fast_path?(typing_ctx, bdd1_id, bdd2_id) do
    case {BDD.get_node_data(typing_ctx, bdd1_id), BDD.get_node_data(typing_ctx, bdd2_id)} do
      {%{structure: {:split, x1, p, f, f}, ops_module: m},
       %{structure: {:split, x2, p, f, f}, ops_module: m}}
      when x1 != x2 ->
        BDD.is_false_node?(typing_ctx, f)

      _ ->
        false
    end
  end

  # Structural intersection of two non-terminal nodes sharing one ops_module.
  defp perform_intersection(typing_ctx, bdd1_id, bdd2_id) do
    {s1, s2, ops_m} = fetch_structures!(typing_ctx, bdd1_id, bdd2_id, "intersect")

    case {s1, s2} do
      {{:leaf, v1}, {:leaf, v2}} ->
        new_leaf_val = ops_m.intersection_leaves(typing_ctx, v1, v2)
        leaf(typing_ctx, new_leaf_val, ops_m)

      {{:split, x1, p1_id, i1_id, n1_id}, {:leaf, _v2}} ->
        # CDuce: split x1 (p1 ** b) (i1 ** b) (n1 ** b)
        {ctx0, new_p1_id} = intersection_bdds(typing_ctx, p1_id, bdd2_id)
        {ctx1, new_i1_id} = intersection_bdds(ctx0, i1_id, bdd2_id)
        {ctx2, new_n1_id} = intersection_bdds(ctx1, n1_id, bdd2_id)
        split(ctx2, x1, new_p1_id, new_i1_id, new_n1_id, ops_m)

      {{:leaf, _v1}, {:split, x2, p2_id, i2_id, n2_id}} ->
        # CDuce: split x2 (a ** p2) (a ** i2) (a ** n2)
        {ctx0, new_p2_id} = intersection_bdds(typing_ctx, bdd1_id, p2_id)
        {ctx1, new_i2_id} = intersection_bdds(ctx0, bdd1_id, i2_id)
        {ctx2, new_n2_id} = intersection_bdds(ctx1, bdd1_id, n2_id)
        split(ctx2, x2, new_p2_id, new_i2_id, new_n2_id, ops_m)

      {{:split, x1, p1_id, i1_id, n1_id}, {:split, x2, p2_id, i2_id, n2_id}} ->
        case ops_m.compare_elements(x1, x2) do
          :eq ->
            # CDuce: split x1 ((p1**(p2++i2))++(p2**i1)) (i1**i2) ((n1**(n2++i2))++(n2**i1))
            {ctx0, p2_u_i2} = union_bdds(typing_ctx, p2_id, i2_id)
            {ctx1, n2_u_i2} = union_bdds(ctx0, n2_id, i2_id)
            {ctx2, p1_i_p2ui2} = intersection_bdds(ctx1, p1_id, p2_u_i2)
            {ctx3, p2_i_i1} = intersection_bdds(ctx2, p2_id, i1_id)
            {ctx4, new_p_id} = union_bdds(ctx3, p1_i_p2ui2, p2_i_i1)
            {ctx5, new_i_id} = intersection_bdds(ctx4, i1_id, i2_id)
            {ctx6, n1_i_n2ui2} = intersection_bdds(ctx5, n1_id, n2_u_i2)
            {ctx7, n2_i_i1} = intersection_bdds(ctx6, n2_id, i1_id)
            {ctx8, new_n_id} = union_bdds(ctx7, n1_i_n2ui2, n2_i_i1)
            split(ctx8, x1, new_p_id, new_i_id, new_n_id, ops_m)

          :lt ->
            # CDuce: split x1 (p1 ** b) (i1 ** b) (n1 ** b), b = bdd2
            {ctx0, new_p1_id} = intersection_bdds(typing_ctx, p1_id, bdd2_id)
            {ctx1, new_i1_id} = intersection_bdds(ctx0, i1_id, bdd2_id)
            {ctx2, new_n1_id} = intersection_bdds(ctx1, n1_id, bdd2_id)
            split(ctx2, x1, new_p1_id, new_i1_id, new_n1_id, ops_m)

          :gt ->
            # CDuce: split x2 (a ** p2) (a ** i2) (a ** n2), a = bdd1
            {ctx0, new_p2_id} = intersection_bdds(typing_ctx, bdd1_id, p2_id)
            {ctx1, new_i2_id} = intersection_bdds(ctx0, bdd1_id, i2_id)
            {ctx2, new_n2_id} = intersection_bdds(ctx1, bdd1_id, n2_id)
            split(ctx2, x2, new_p2_id, new_i2_id, new_n2_id, ops_m)
        end
    end
  end

  # --- Negation ----------------------------------------------------------------

  # Uncached negation; called only through apply_op/4.
  defp do_negation_bdd(typing_ctx, bdd_id) do
    cond do
      BDD.is_true_node?(typing_ctx, bdd_id) -> {typing_ctx, BDD.false_node_id()}
      BDD.is_false_node?(typing_ctx, bdd_id) -> {typing_ctx, BDD.true_node_id()}
      true -> perform_negation(typing_ctx, bdd_id)
    end
  end

  # Structural negation of a non-terminal node.
  defp perform_negation(typing_ctx, bdd_id) do
    %{structure: s, ops_module: ops_m} = BDD.get_node_data(typing_ctx, bdd_id)

    case s do
      {:leaf, v} ->
        neg_leaf_val = ops_m.negation_leaf(typing_ctx, v)
        leaf(typing_ctx, neg_leaf_val, ops_m)

      {:split, x, p_id, i_id, n_id} ->
        # CDuce: ~~i ** split x (~~p) (~~(p++n)) (~~n)
        {ctx0, neg_i_id} = negation_bdd(typing_ctx, i_id)
        {ctx1, neg_p_id} = negation_bdd(ctx0, p_id)
        {ctx2, p_u_n_id} = union_bdds(ctx1, p_id, n_id)
        {ctx3, neg_p_u_n_id} = negation_bdd(ctx2, p_u_n_id)
        {ctx4, neg_n_id} = negation_bdd(ctx3, n_id)
        {ctx5, split_part_id} = split(ctx4, x, neg_p_id, neg_p_u_n_id, neg_n_id, ops_m)
        intersection_bdds(ctx5, neg_i_id, split_part_id)
    end
  end

  # --- Shared helpers ------------------------------------------------------------

  # Fetches both nodes' structures and their common ops_module.
  # Raises ArgumentError when the nodes belong to different ops_modules;
  # `verb` only fills the error message ("union" / "intersect").
  defp fetch_structures!(typing_ctx, bdd1_id, bdd2_id, verb) do
    %{structure: s1, ops_module: ops_m1} = BDD.get_node_data(typing_ctx, bdd1_id)
    %{structure: s2, ops_module: ops_m2} = BDD.get_node_data(typing_ctx, bdd2_id)

    if ops_m1 != ops_m2 do
      raise ArgumentError,
            "Cannot #{verb} BDDs with different ops_modules: " <>
              "#{inspect(ops_m1)} and #{inspect(ops_m2)}"
    end

    {s1, s2, ops_m1}
  end

  # --- Caching wrapper for BDD operations -----------------------------------------

  # Looks the operation up in `typing_ctx.bdd_store.ops_cache`; on a miss it
  # computes the result and records it in the cache.
  defp apply_op(typing_ctx, op_key, bdd1_id, bdd2_id) do
    cache_key = make_cache_key(op_key, bdd1_id, bdd2_id)

    case Map.get(typing_ctx.bdd_store.ops_cache, cache_key) do
      nil ->
        {new_typing_ctx, result_id} = compute_op(typing_ctx, op_key, bdd1_id, bdd2_id)

        # The operation may have interned new nodes, so the cache entry must go
        # into the store carried by new_typing_ctx, not the one we started with.
        final_typing_ctx =
          update_in(new_typing_ctx.bdd_store.ops_cache, &Map.put(&1, cache_key, result_id))

        {final_typing_ctx, result_id}

      cached_result_id ->
        {typing_ctx, cached_result_id}
    end
  end

  # Dispatches an op_key to its uncached implementation.
  defp compute_op(typing_ctx, :union, bdd1_id, bdd2_id),
    do: do_union_bdds(typing_ctx, bdd1_id, bdd2_id)

  defp compute_op(typing_ctx, :intersection, bdd1_id, bdd2_id),
    do: do_intersection_bdds(typing_ctx, bdd1_id, bdd2_id)

  # bdd2_id is nil for the unary negation.
  defp compute_op(typing_ctx, :negation, bdd_id, _bdd2_id),
    do: do_negation_bdd(typing_ctx, bdd_id)

  defp compute_op(_typing_ctx, op_key, _bdd1_id, _bdd2_id),
    do: raise("Unsupported op_key: #{op_key}")

  defp make_cache_key(:negation, bdd_id, nil), do: {:negation, bdd_id}

  defp make_cache_key(op_key, id1, id2) when op_key in [:union, :intersection] do
    # Canonical operand order for commutative binary operations.
    if id1 <= id2, do: {op_key, id1, id2}, else: {op_key, id2, id1}
  end

  defp make_cache_key(op_key, id1, id2), do: {op_key, id1, id2}
end