elipl/lib/tilly/bdd/ops.ex
Kacper Marzecki 748f87636a checkpoint
checkpoint

failing test

after fixing tests

checkpoint

checkpoint

checkpoint

re-work

asd

checkpoint

checkpoint

checkpoint

mix proj

checkpoint mix

first parser impl

checkpoint

fix tests

re-org parser

checkpoint strings

fix multiline strings

tuples

checkpoint maps

checkpoint

checkpoint

checkpoint

checkpoint

fix weird eof expression parse error

checkpoint before typing

checkpoint

checpoint

checkpoint

checkpoint

checkpoint ids in primitive types

checkpoint

checkpoint

fix tests

initial annotation

checkpoint

checkpoint

checkpoint

union subtyping

conventions

refactor - split typer

typing tuples

checkpoint test refactor

checkpoint test refactor

parsing atoms

checkpoint atoms

wip lists

checkpoint typing lists

checkopint

checkpoint

wip fixing

correct list typing

map discussion

checkpoint map basic typing

fix tests checkpoint

checkpoint

checkpoint

checkpoint

fix condition typing

fix literal keys in map types

checkpoint union types

checkpoint union type

checkpoint row types discussion & bidirectional typecheck

checkpoint

basic lambdas

checkpoint lambdas typing application

wip function application

checkpoint

checkpoint

checkpoint cduce

checkpoint

checkpoint

checkpoint

checkpoint

checkpoint

checkpoint

checkpoint
2025-06-13 23:48:07 +02:00

348 lines
13 KiB
Elixir

defmodule Tilly.BDD.Ops do
  @moduledoc """
  Generic BDD algorithms and smart constructors.

  These functions operate on BDD node IDs and use an `ops_module`
  to dispatch to specific element/leaf operations.

  The algorithms follow CDuce's `bdd.ml`, in which a split node
  `{:split, x, p, i, n}` denotes `(x AND p) OR i OR (NOT x AND n)`:
  `p` applies where `x` holds, `n` where it does not, and the "ignore"
  child `i` applies regardless of `x`.

  The `ops_module` is called with:

    * `test_leaf_value(value)` -> `:empty | :full | :other`
    * `compare_elements(x1, x2)` -> `:lt | :eq | :gt`
    * `union_leaves(typing_ctx, v1, v2)`, `intersection_leaves(typing_ctx, v1, v2)`
      and `negation_leaf(typing_ctx, v)` -> a leaf value, passed on to `leaf/3`

  Every operation threads `typing_ctx` (a map holding a `:bdd_store`) and
  returns `{new_typing_ctx, node_id}`. Union, intersection and negation
  results are memoized in `typing_ctx.bdd_store.ops_cache`.
  """
  alias Tilly.BDD
  alias Tilly.BDD.Node

  @typedoc "Typing context map; must contain a `:bdd_store` whose `:ops_cache` is a map."
  @type typing_ctx :: map()

  @typedoc "Identifier of an interned BDD node, as produced by `Tilly.BDD`."
  @type node_id :: term()

  @doc """
  Smart constructor for leaf nodes.

  Uses the `ops_module` to test whether `leaf_value` corresponds to an empty
  or universal set for that module. Those values collapse to the shared
  false/true nodes; any other value is interned via `Node.mk_leaf/1`.

  Returns `{new_typing_ctx, node_id}`.
  """
  @spec leaf(typing_ctx(), term(), module()) :: {typing_ctx(), node_id()}
  def leaf(typing_ctx, leaf_value, ops_module) do
    case ops_module.test_leaf_value(leaf_value) do
      :empty ->
        {typing_ctx, BDD.false_node_id()}

      :full ->
        {typing_ctx, BDD.true_node_id()}

      :other ->
        BDD.get_or_intern_node(typing_ctx, Node.mk_leaf(leaf_value), ops_module)
    end
  end

  @doc """
  Smart constructor for split nodes. Applies simplification rules.

  Simplifications (as in CDuce's `split`):

    * `i_id` is True: the whole node is True.
    * `p_id == n_id`: the node does not depend on `element` and equals
      `p OR i` (this also collapses `p = n = False` to `i`).

  Otherwise the node is interned. Returns `{new_typing_ctx, node_id}`.
  """
  @spec split(typing_ctx(), term(), node_id(), node_id(), node_id(), module()) ::
          {typing_ctx(), node_id()}
  def split(typing_ctx, element, p_id, i_id, n_id, ops_module) do
    cond do
      # (x AND p) OR True OR (NOT x AND n) == True
      BDD.is_true_node?(typing_ctx, i_id) ->
        {typing_ctx, BDD.true_node_id()}

      # All three children identical: p OR p == p.
      p_id == n_id and p_id == i_id ->
        {typing_ctx, p_id}

      # (x AND p) OR i OR (NOT x AND p) == p OR i.
      # This re-enters split/6 through union_bdds/3 (mutual recursion); apply_op/4 only
      # memoizes the result, it does not break the recursion.
      p_id == n_id ->
        union_bdds(typing_ctx, p_id, i_id)

      # NOTE: p = True, i = False, n = False must NOT collapse to `p`: it denotes
      # `element AND True`, i.e. the singleton `element`.
      # TODO: Add more simplification rules from CDuce bdd.ml `split` as needed.
      true ->
        BDD.get_or_intern_node(typing_ctx, Node.mk_split(element, p_id, i_id, n_id), ops_module)
    end
  end

  @doc """
  Computes the union of two BDDs.
  Returns `{new_typing_ctx, result_node_id}`.
  """
  @spec union_bdds(typing_ctx(), node_id(), node_id()) :: {typing_ctx(), node_id()}
  def union_bdds(typing_ctx, bdd1_id, bdd2_id) do
    apply_op(typing_ctx, :union, bdd1_id, bdd2_id)
  end

  @doc """
  Computes the intersection of two BDDs.
  Returns `{new_typing_ctx, result_node_id}`.
  """
  @spec intersection_bdds(typing_ctx(), node_id(), node_id()) :: {typing_ctx(), node_id()}
  def intersection_bdds(typing_ctx, bdd1_id, bdd2_id) do
    apply_op(typing_ctx, :intersection, bdd1_id, bdd2_id)
  end

  @doc """
  Computes the negation of a BDD.
  Returns `{new_typing_ctx, result_node_id}`.
  """
  @spec negation_bdd(typing_ctx(), node_id()) :: {typing_ctx(), node_id()}
  def negation_bdd(typing_ctx, bdd_id) do
    # Unary operation: the second operand slot of apply_op/4 is nil.
    apply_op(typing_ctx, :negation, bdd_id, nil)
  end

  @doc """
  Computes the difference of two BDDs (bdd1 - bdd2).
  Returns `{new_typing_ctx, result_node_id}`.
  Implemented as `bdd1 INTERSECTION (NEGATION bdd2)`.
  """
  @spec difference_bdd(typing_ctx(), node_id(), node_id()) :: {typing_ctx(), node_id()}
  def difference_bdd(typing_ctx, bdd1_id, bdd2_id) do
    {ctx, neg_bdd2_id} = negation_bdd(typing_ctx, bdd2_id)
    intersection_bdds(ctx, bdd1_id, neg_bdd2_id)
  end

  # --- Union ---

  # Uncached union (only apply_op/4 calls this): terminal cases first, then structural merge.
  defp do_union_bdds(typing_ctx, bdd1_id, bdd2_id) do
    cond do
      bdd1_id == bdd2_id -> {typing_ctx, bdd1_id}
      BDD.is_true_node?(typing_ctx, bdd1_id) -> {typing_ctx, BDD.true_node_id()}
      BDD.is_true_node?(typing_ctx, bdd2_id) -> {typing_ctx, BDD.true_node_id()}
      BDD.is_false_node?(typing_ctx, bdd1_id) -> {typing_ctx, bdd2_id}
      BDD.is_false_node?(typing_ctx, bdd2_id) -> {typing_ctx, bdd1_id}
      true -> perform_union(typing_ctx, bdd1_id, bdd2_id)
    end
  end

  defp perform_union(typing_ctx, bdd1_id, bdd2_id) do
    {s1, s2, ops_m} = fetch_operands!(typing_ctx, bdd1_id, bdd2_id, "union")

    case {s1, s2} do
      {{:leaf, v1}, {:leaf, v2}} ->
        leaf(typing_ctx, ops_m.union_leaves(typing_ctx, v1, v2), ops_m)

      {{:split, _, _, _, _}, {:leaf, _}} ->
        union_into_ignore(typing_ctx, s1, bdd2_id, ops_m)

      {{:leaf, _}, {:split, _, _, _, _}} ->
        union_into_ignore(typing_ctx, s2, bdd1_id, ops_m)

      {{:split, x1, p1_id, i1_id, n1_id}, {:split, x2, p2_id, i2_id, n2_id}} ->
        case ops_m.compare_elements(x1, x2) do
          :eq ->
            # Same root element: merge the children pointwise.
            {ctx0, new_p_id} = union_bdds(typing_ctx, p1_id, p2_id)
            {ctx1, new_i_id} = union_bdds(ctx0, i1_id, i2_id)
            {ctx2, new_n_id} = union_bdds(ctx1, n1_id, n2_id)
            split(ctx2, x1, new_p_id, new_i_id, new_n_id, ops_m)

          :lt ->
            union_into_ignore(typing_ctx, s1, bdd2_id, ops_m)

          :gt ->
            union_into_ignore(typing_ctx, s2, bdd1_id, ops_m)
        end
    end
  end

  # CDuce: `split x p (i ++ other) n`. Used when `other` is a leaf or its root element
  # orders after `x`, so `other` is folded into the ignore child.
  defp union_into_ignore(typing_ctx, {:split, x, p_id, i_id, n_id}, other_id, ops_m) do
    {ctx, new_i_id} = union_bdds(typing_ctx, i_id, other_id)
    split(ctx, x, p_id, new_i_id, n_id, ops_m)
  end

  # --- Intersection ---

  # Uncached intersection (only apply_op/4 calls this). Terminal cases are checked before
  # the singleton fast path so the fast path only ever sees non-terminal nodes.
  defp do_intersection_bdds(typing_ctx, bdd1_id, bdd2_id) do
    cond do
      bdd1_id == bdd2_id -> {typing_ctx, bdd1_id}
      BDD.is_false_node?(typing_ctx, bdd1_id) -> {typing_ctx, BDD.false_node_id()}
      BDD.is_false_node?(typing_ctx, bdd2_id) -> {typing_ctx, BDD.false_node_id()}
      BDD.is_true_node?(typing_ctx, bdd1_id) -> {typing_ctx, bdd2_id}
      BDD.is_true_node?(typing_ctx, bdd2_id) -> {typing_ctx, bdd1_id}
      disjoint_singletons?(typing_ctx, bdd1_id, bdd2_id) -> {typing_ctx, BDD.false_node_id()}
      true -> perform_intersection(typing_ctx, bdd1_id, bdd2_id)
    end
  end

  # Fast path: two singletons `{:split, x, True, False, False}` over the same ops_module with
  # different elements are treated as disjoint. Both children must really be the true/false
  # nodes; nodes that merely share identical p/i/n children are NOT singletons.
  # NOTE(review): sound only if distinct elements of the ops_module denote disjoint sets
  # (e.g. literal atoms); confirm this holds for every ops_module reaching this code.
  defp disjoint_singletons?(typing_ctx, bdd1_id, bdd2_id) do
    case {BDD.get_node_data(typing_ctx, bdd1_id), BDD.get_node_data(typing_ctx, bdd2_id)} do
      {%{structure: {:split, x1, t, f, f}, ops_module: m},
       %{structure: {:split, x2, t, f, f}, ops_module: m}}
      when x1 != x2 ->
        BDD.is_true_node?(typing_ctx, t) and BDD.is_false_node?(typing_ctx, f)

      _ ->
        false
    end
  end

  defp perform_intersection(typing_ctx, bdd1_id, bdd2_id) do
    {s1, s2, ops_m} = fetch_operands!(typing_ctx, bdd1_id, bdd2_id, "intersect")

    # Intersect each child of one split with the whole other operand (argument order kept).
    # CDuce: split x1 (p1 ** b) (i1 ** b) (n1 ** b) where b is bdd2
    with_bdd2 = &intersection_bdds(&1, &2, bdd2_id)
    # CDuce: split x2 (a ** p2) (a ** i2) (a ** n2) where a is bdd1
    with_bdd1 = &intersection_bdds(&1, bdd1_id, &2)

    case {s1, s2} do
      {{:leaf, v1}, {:leaf, v2}} ->
        leaf(typing_ctx, ops_m.intersection_leaves(typing_ctx, v1, v2), ops_m)

      {{:split, _, _, _, _}, {:leaf, _}} ->
        map_split_children(typing_ctx, s1, ops_m, with_bdd2)

      {{:leaf, _}, {:split, _, _, _, _}} ->
        map_split_children(typing_ctx, s2, ops_m, with_bdd1)

      {{:split, x1, _, _, _}, {:split, x2, _, _, _}} ->
        case ops_m.compare_elements(x1, x2) do
          :eq -> intersect_same_root(typing_ctx, s1, s2, ops_m)
          :lt -> map_split_children(typing_ctx, s1, ops_m, with_bdd2)
          :gt -> map_split_children(typing_ctx, s2, ops_m, with_bdd1)
        end
    end
  end

  # Rebuilds a split node with `fun` applied to its p, i and n children, threading the ctx.
  defp map_split_children(typing_ctx, {:split, x, p_id, i_id, n_id}, ops_m, fun) do
    {ctx0, new_p_id} = fun.(typing_ctx, p_id)
    {ctx1, new_i_id} = fun.(ctx0, i_id)
    {ctx2, new_n_id} = fun.(ctx1, n_id)
    split(ctx2, x, new_p_id, new_i_id, new_n_id, ops_m)
  end

  # CDuce: split x ((p1**(p2++i2))++(p2**i1)) (i1**i2) ((n1**(n2++i2))++(n2**i1))
  defp intersect_same_root(
         typing_ctx,
         {:split, x, p1_id, i1_id, n1_id},
         {:split, _x, p2_id, i2_id, n2_id},
         ops_m
       ) do
    {ctx0, p2_u_i2} = union_bdds(typing_ctx, p2_id, i2_id)
    {ctx1, n2_u_i2} = union_bdds(ctx0, n2_id, i2_id)
    {ctx2, p1_i_p2ui2} = intersection_bdds(ctx1, p1_id, p2_u_i2)
    {ctx3, p2_i_i1} = intersection_bdds(ctx2, p2_id, i1_id)
    {ctx4, new_p_id} = union_bdds(ctx3, p1_i_p2ui2, p2_i_i1)
    {ctx5, new_i_id} = intersection_bdds(ctx4, i1_id, i2_id)
    {ctx6, n1_i_n2ui2} = intersection_bdds(ctx5, n1_id, n2_u_i2)
    {ctx7, n2_i_i1} = intersection_bdds(ctx6, n2_id, i1_id)
    {ctx8, new_n_id} = union_bdds(ctx7, n1_i_n2ui2, n2_i_i1)
    split(ctx8, x, new_p_id, new_i_id, new_n_id, ops_m)
  end

  # Loads both nodes' structures and checks they share an ops_module; mixing is unsupported.
  defp fetch_operands!(typing_ctx, bdd1_id, bdd2_id, verb) do
    %{structure: s1, ops_module: ops_m1} = BDD.get_node_data(typing_ctx, bdd1_id)
    %{structure: s2, ops_module: ops_m2} = BDD.get_node_data(typing_ctx, bdd2_id)

    if ops_m1 != ops_m2 do
      raise ArgumentError,
            "Cannot #{verb} BDDs with different ops_modules: " <>
              "#{inspect(ops_m1)} and #{inspect(ops_m2)}"
    end

    {s1, s2, ops_m1}
  end

  # --- Negation ---

  # Uncached negation (only apply_op/4 calls this).
  defp do_negation_bdd(typing_ctx, bdd_id) do
    cond do
      BDD.is_true_node?(typing_ctx, bdd_id) -> {typing_ctx, BDD.false_node_id()}
      BDD.is_false_node?(typing_ctx, bdd_id) -> {typing_ctx, BDD.true_node_id()}
      true -> perform_negation(typing_ctx, bdd_id)
    end
  end

  defp perform_negation(typing_ctx, bdd_id) do
    %{structure: structure, ops_module: ops_m} = BDD.get_node_data(typing_ctx, bdd_id)

    case structure do
      {:leaf, v} ->
        leaf(typing_ctx, ops_m.negation_leaf(typing_ctx, v), ops_m)

      {:split, x, p_id, i_id, n_id} ->
        # CDuce: ~~i ** split x (~~p) (~~(p++n)) (~~n)
        # NOT((x AND p) OR i OR (NOT x AND n)) = NOT i AND ((x AND NOT p) OR (NOT x AND NOT n));
        # the middle NOT(p OR n) child is implied by the other two.
        {ctx0, neg_i_id} = negation_bdd(typing_ctx, i_id)
        {ctx1, neg_p_id} = negation_bdd(ctx0, p_id)
        {ctx2, p_u_n_id} = union_bdds(ctx1, p_id, n_id)
        {ctx3, neg_p_u_n_id} = negation_bdd(ctx2, p_u_n_id)
        {ctx4, neg_n_id} = negation_bdd(ctx3, n_id)
        {ctx5, split_part_id} = split(ctx4, x, neg_p_id, neg_p_u_n_id, neg_n_id, ops_m)
        intersection_bdds(ctx5, neg_i_id, split_part_id)
    end
  end

  # --- Caching wrapper for BDD operations ---

  # Looks the operation up in ops_cache. On a miss it computes the result and records it in
  # the store of the context returned by the computation, because that context may contain
  # newly interned nodes and cached sub-results.
  defp apply_op(typing_ctx, op_key, bdd1_id, bdd2_id) do
    cache_key = make_cache_key(op_key, bdd1_id, bdd2_id)

    case Map.fetch(typing_ctx.bdd_store.ops_cache, cache_key) do
      {:ok, cached_result_id} ->
        {typing_ctx, cached_result_id}

      :error ->
        {new_typing_ctx, result_id} = compute_op(typing_ctx, op_key, bdd1_id, bdd2_id)

        new_typing_ctx =
          update_in(new_typing_ctx.bdd_store.ops_cache, &Map.put(&1, cache_key, result_id))

        {new_typing_ctx, result_id}
    end
  end

  # Dispatches to the uncached implementation; for :negation the second operand is unused.
  defp compute_op(typing_ctx, :union, bdd1_id, bdd2_id),
    do: do_union_bdds(typing_ctx, bdd1_id, bdd2_id)

  defp compute_op(typing_ctx, :intersection, bdd1_id, bdd2_id),
    do: do_intersection_bdds(typing_ctx, bdd1_id, bdd2_id)

  defp compute_op(typing_ctx, :negation, bdd_id, _unused),
    do: do_negation_bdd(typing_ctx, bdd_id)

  defp compute_op(_typing_ctx, op_key, _bdd1_id, _bdd2_id),
    do: raise("Unsupported op_key: #{op_key}")

  defp make_cache_key(:negation, bdd_id, nil), do: {:negation, bdd_id}

  defp make_cache_key(op_key, id1, id2) when op_key in [:union, :intersection] do
    # Canonical operand order for commutative binary operations.
    if id1 <= id2, do: {op_key, id1, id2}, else: {op_key, id2, id1}
  end

  defp make_cache_key(op_key, id1, id2), do: {op_key, id1, id2}
end