@pkhuong
Last active November 8, 2024 15:56

Revisions

  1. pkhuong revised this gist Aug 2, 2023. 2 changed files with 102 additions and 48 deletions.
    75 changes: 51 additions & 24 deletions yannakakis.md
    @@ -53,15 +53,17 @@ to handle aggregation queries like the following

    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0)).value
    ...
    >>> def sum_odd_or_even_skus(mod_two):
    ... @map_reduce.over(id_skus, Sum())
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus, Sum()).value
    ... @map_reduce.over(sku_costs, Min(0))
    ... def min_cost(sku_cost):
    ... if sku_cost[0] == sku:
    ... return Min(sku_cost[1])
    ... return Sum(min_cost)
    ... return count_if_mod_two
    ...

    with linear scaling in the length of `id_skus` and `sku_costs`, and
    @@ -71,10 +73,10 @@ At a small scale, everything's fast.

    >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    20
    0.0002319812774658203
    0.0007169246673583984
    >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    50
    0.0002617835998535156
    0.0002627372741699219

    As we increase the scale by 1000x for both input lists, the runtime
    scales (sub) linearly for the first query, and is unchanged for the
    @@ -84,28 +86,28 @@ second:
    >>> sku_costs = sku_costs * 1000
    >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    20000
    0.10173487663269043
    0.09455370903015137
    >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    50000
    0.0002841949462890625
    0.00025773048400878906

    This still pretty much holds up when we multiply by another factor of 100:

    >>> id_skus = id_skus * 100
    >>> sku_costs = sku_costs * 100
    >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    2000000
    6.7300660610198975
    6.946590185165405
    >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    5000000
    0.0002579689025878906
    0.00025200843811035156

    The magic behind the curtains is memoisation (unsurprisingly), but
    a special implementation that can share work for similar closures:
    the memoisation key consists of the function *without closed over
    bindings* and the call arguments, while the memoised value is a
    data structure from the tuple of closed over values to the
    `map_reduce` output.
    `map_reduce.over` output.
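    To make that caching scheme concrete, here is a minimal sketch (not the
    gist's actual code) of how a closure splits into a shareable code object
    and its captured values, using CPython's `__code__` and `__closure__`
    attributes; `closure_key` is a hypothetical helper name.

    def closure_key(fn):
        """Split a closure into (code object, captured cell values)."""
        code = fn.__code__  # shared by every instance of the same closure
        cells = tuple(cell.cell_contents for cell in (fn.__closure__ or ()))
        return code, cells

    def count_eql(needle):
        return lambda x: 1 if x == needle else None

    # Both closures share one code object; only the captured `needle` differs,
    # so one cached branching program can serve every needle.
    assert closure_key(count_eql(2))[0] is closure_key(count_eql(4))[0]
    assert closure_key(count_eql(2))[1] == (2,)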

    This concrete representation of the function as a branching program
    is the core of Yannakakis's algorithm: we'll iterate over each
    @@ -522,7 +524,7 @@ results for identical keys together.

    Enumerates all constraints for the `OpaqueValue` instances in
    `values`, and yields a pair of equality constraints for the
    `value` and the corresponding results, for all non-zero results.
    `value` and the corresponding result, for all non-zero results.

    This essentially turns `function()` into a branching program on
    `values`.
    @@ -609,9 +611,10 @@ level: it's essential that merge operators be commutative and
    associative, so isolating the merge logic in dedicated classes
    makes sense to me.

    This file defines a single mergeable value type, `Sum`, but we
    could have different ones, e.g., hyperloglog unique counts, or
    streaming statistical moments.
    This file defines two trivial mergeable value types, `Sum` and
    `Min`, but we could have different ones, e.g., HyperLogLog unique
    counts, or streaming statistical moments... or even a list of row
    ids.
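    As a concrete, hypothetical example of another mergeable value, here is a
    sketch of a `Max` accumulator in the same mold; the protocol (constructor
    signature, in-place `merge`, `.value` accessor) is assumed from how
    `Sum`/`Min` are used in the doctests, not copied from the gist.

    class Max:
        """Tracks the maximum of all merged values (None = no value yet)."""

        def __init__(self, value=None):
            self.value = value

        def merge(self, other):
            # Must be commutative and associative: merge order can't matter.
            if self.value is None:
                self.value = other.value
            elif other.value is not None and other.value > self.value:
                self.value = other.value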


    class Sum:
    @@ -659,7 +662,11 @@ variables). That's actually reasonable because we don't expect too
    many join variables (compare to [range trees](https://en.wikipedia.org/wiki/Range_tree#Range_queries)
    that are also exponential in the number of dimensions... but with
    a base of log(n) instead of 2).


    A real implementation could maybe save work by memoising merged
    results for internal subtrees... it's tempting to broadcast
    wildcard values to the individual keyed entries, but I think that
    might explode the time complexity of our pre-computation phase.
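    A standalone sketch of the exact-or-wildcard lookup described above
    (hypothetical plain-dict node layout, not the gist's `NestedDictLevel`):
    each level branches into at most two children, which is where the
    2^depth worst case comes from.

    def visit_matches(node, keys, depth, visitor):
        """Visit every stored value whose key tuple matches `keys`,
        where a wildcard component matches anything."""
        if depth == len(keys):
            visitor(node["value"])
            return
        if node["wildcard"] is not None:        # branch 1: wildcard child
            visit_matches(node["wildcard"], keys, depth + 1, visitor)
        child = node["dict"].get(keys[depth])   # branch 2: exact-key child
        if child is not None:
            visit_matches(child, keys, depth + 1, visitor)

    leaf = lambda v: {"value": v, "wildcard": None, "dict": {}}
    root = {"value": None, "wildcard": leaf("any x"), "dict": {4: leaf("x == 4")}}
    visit_matches(root, (4,), 0, print)  # prints both "any x" and "x == 4"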

    class NestedDictLevel:
    """One level in a nested dictionary index. We may have a value
    @@ -805,6 +812,16 @@ as an instance of bottom-up dynamic programming.
    return dst


    def _extractValues(accumulator):
        if accumulator is None:
            return None
        if isinstance(accumulator, tuple):
            return tuple(item.value for item in accumulator)
        if isinstance(accumulator, list):
            return list(item.value for item in accumulator)

        return accumulator.value
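    A quick illustration of what `_extractValues` does, assuming (as the
    doctests imply, but stated here as an assumption) that accumulators expose
    a `.value` attribute and that `Sum(v).value == v`:

        _extractValues(None)              # -> None (empty reduction)
        _extractValues(Sum(3))            # -> 3, the unwrapped scalar
        _extractValues((Sum(1), Sum(2)))  # -> (1, 2), element-wise unwrapping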

    def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    @@ -832,7 +849,7 @@ as an instance of bottom-up dynamic programming.
    AGGREGATE_CACHE = IdMap() # Map from function, input sequence -> NestedDict


    def map_reduce(function, inputIterable, initialValue):
    def map_reduce(function, inputIterable, initialValue=None, *, extractResult=True):
    """Returns the result of merging `map(function, inputIterable)`
    into `initialValue`.

    @@ -861,24 +878,24 @@ as an instance of bottom-up dynamic programming.
    ... return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data, Sum()).value
    >>> map_reduce(count_eql(4), data, Sum(), extractResult=False).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data, Sum()).value
    >>> map_reduce(count_eql(2), data)
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0)).value
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0))
    >>> def sum_odd_or_even_skus(mod_two):
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus, Sum()).value
    ... return map_reduce(count_if_mod_two, id_skus, Sum())
    >>> sum_odd_or_even_skus(0)
    20
    >>> sum_odd_or_even_skus(1)
    @@ -892,8 +909,18 @@ as an instance of bottom-up dynamic programming.
    AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
    function, len(closure), inputIterable
    )
    AGGREGATE_CACHE[code, inputIterable].visit(closure, lambda result: _merge(initialValue, result))
    return initialValue

        acc = [initialValue]
        def visitor(result):
            acc[0] = _merge(acc[0], result)

        AGGREGATE_CACHE[code, inputIterable].visit(closure, visitor)
        return _extractValues(acc[0]) if extractResult else acc[0]


    map_reduce.over = \
        lambda inputIterable, initialValue=None, *, extractResult=True: \
        lambda fn: map_reduce(fn, inputIterable, initialValue, extractResult=extractResult)
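    For readers skimming the diff, the decorator form defined just above can
    be restated on its own, following the doctest at the top of the file (this
    sketch assumes the gist's `Min` class and the default `extractResult=True`):

    sku_costs = [(1, 10), (2, 20), (3, 30)]

    def sku_min_cost(sku):
        # `map_reduce.over(...)` reduces the decorated function over the input
        # immediately, so `min_cost` is bound to the extracted result (a plain
        # number), not to a function.
        @map_reduce.over(sku_costs, Min(0))
        def min_cost(sku_cost):
            if sku_cost[0] == sku:
                return Min(sku_cost[1])
        return min_cost

    sku_min_cost(2)  # -> 20, matching the doctest above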


    if __name__ == "__main__":
    75 changes: 51 additions & 24 deletions yannakakis.py
    @@ -51,15 +51,17 @@
    ##
    ## >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    ## >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    ## >>> def sku_min_cost(sku):
    ## ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0)).value
    ## ...
    ## >>> def sum_odd_or_even_skus(mod_two):
    ## ... @map_reduce.over(id_skus, Sum())
    ## ... def count_if_mod_two(id_sku):
    ## ... id, sku = id_sku
    ## ... if id % 2 == mod_two:
    ## ... return Sum(sku_min_cost(sku))
    ## ... return map_reduce(count_if_mod_two, id_skus, Sum()).value
    ## ... @map_reduce.over(sku_costs, Min(0))
    ## ... def min_cost(sku_cost):
    ## ... if sku_cost[0] == sku:
    ## ... return Min(sku_cost[1])
    ## ... return Sum(min_cost)
    ## ... return count_if_mod_two
    ## ...
    ##
    ## with linear scaling in the length of `id_skus` and `sku_costs`, and
    @@ -69,10 +71,10 @@
    ##
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    ## 20
    ## 0.0002319812774658203
    ## 0.0007169246673583984
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    ## 50
    ## 0.0002617835998535156
    ## 0.0002627372741699219
    ##
    ## As we increase the scale by 1000x for both input lists, the runtime
    ## scales (sub) linearly for the first query, and is unchanged for the
    @@ -82,28 +84,28 @@
    ## >>> sku_costs = sku_costs * 1000
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    ## 20000
    ## 0.10173487663269043
    ## 0.09455370903015137
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    ## 50000
    ## 0.0002841949462890625
    ## 0.00025773048400878906
    ##
    ## This still pretty much holds up when we multiply by another factor of 100:
    ##
    ## >>> id_skus = id_skus * 100
    ## >>> sku_costs = sku_costs * 100
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    ## 2000000
    ## 6.7300660610198975
    ## 6.946590185165405
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    ## 5000000
    ## 0.0002579689025878906
    ## 0.00025200843811035156
    ##
    ## The magic behind the curtains is memoisation (unsurprisingly), but
    ## a special implementation that can share work for similar closures:
    ## the memoisation key consists of the function *without closed over
    ## bindings* and the call arguments, while the memoised value is a
    ## data structure from the tuple of closed over values to the
    ## `map_reduce` output.
    ## `map_reduce.over` output.
    ##
    ## This concrete representation of the function as a branching program
    ## is the core of Yannakakis's algorithm: we'll iterate over each
    @@ -518,7 +520,7 @@ def enumerate_opaque_values(function, values):

    Enumerates all constraints for the `OpaqueValue` instances in
    `values`, and yields a pair of equality constraints for the
    `value` and the corresponding results, for all non-zero results.
    `value` and the corresponding result, for all non-zero results.

    This essentially turns `function()` into a branching program on
    `values`.
    @@ -605,9 +607,10 @@ def enumerate_supporting_values(function, args):
    ## associative, so isolating the merge logic in dedicated classes
    ## makes sense to me.
    ##
    ## This file defines a single mergeable value type, `Sum`, but we
    ## could have different ones, e.g., hyperloglog unique counts, or
    ## streaming statistical moments.
    ## This file defines two trivial mergeable value types, `Sum` and
    ## `Min`, but we could have different ones, e.g., HyperLogLog unique
    ## counts, or streaming statistical moments... or even a list of row
    ## ids.


    class Sum:
    @@ -655,7 +658,11 @@ def merge(self, other):
    ## many join variables (compare to [range trees](https://en.wikipedia.org/wiki/Range_tree#Range_queries)
    ## that are also exponential in the number of dimensions... but with
    ## a base of log(n) instead of 2).

    ##
    ## A real implementation could maybe save work by memoising merged
    ## results for internal subtrees... it's tempting to broadcast
    ## wildcard values to the individual keyed entries, but I think that
    ## might explode the time complexity of our pre-computation phase.

    class NestedDictLevel:
    """One level in a nested dictionary index. We may have a value
    @@ -799,6 +806,16 @@ def _merge(dst, update):
    return dst


    def _extractValues(accumulator):
        if accumulator is None:
            return None
        if isinstance(accumulator, tuple):
            return tuple(item.value for item in accumulator)
        if isinstance(accumulator, list):
            return list(item.value for item in accumulator)

        return accumulator.value

    def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    @@ -826,7 +843,7 @@ def _precompute_map_reduce(function, depth, inputIterable):
    AGGREGATE_CACHE = IdMap() # Map from function, input sequence -> NestedDict


    def map_reduce(function, inputIterable, initialValue):
    def map_reduce(function, inputIterable, initialValue=None, *, extractResult=True):
    """Returns the result of merging `map(function, inputIterable)`
    into `initialValue`.

    @@ -855,24 +872,24 @@ def map_reduce(function, inputIterable, initialValue):
    ... return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data, Sum()).value
    >>> map_reduce(count_eql(4), data, Sum(), extractResult=False).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data, Sum()).value
    >>> map_reduce(count_eql(2), data)
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0)).value
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0))
    >>> def sum_odd_or_even_skus(mod_two):
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus, Sum()).value
    ... return map_reduce(count_if_mod_two, id_skus, Sum())
    >>> sum_odd_or_even_skus(0)
    20
    >>> sum_odd_or_even_skus(1)
    @@ -886,8 +903,18 @@ def map_reduce(function, inputIterable, initialValue):
    AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
    function, len(closure), inputIterable
    )
    AGGREGATE_CACHE[code, inputIterable].visit(closure, lambda result: _merge(initialValue, result))
    return initialValue

        acc = [initialValue]
        def visitor(result):
            acc[0] = _merge(acc[0], result)

        AGGREGATE_CACHE[code, inputIterable].visit(closure, visitor)
        return _extractValues(acc[0]) if extractResult else acc[0]


    map_reduce.over = \
        lambda inputIterable, initialValue=None, *, extractResult=True: \
        lambda fn: map_reduce(fn, inputIterable, initialValue, extractResult=extractResult)


    if __name__ == "__main__":
  2. pkhuong revised this gist Jul 25, 2023. No changes.
  3. pkhuong revised this gist Jul 25, 2023. 2 changed files with 226 additions and 214 deletions.
    220 changes: 113 additions & 107 deletions yannakakis.md
    @@ -54,51 +54,51 @@ to handle aggregation queries like the following
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0)).value
    ...
    >>> def sum_odd_or_even_skus(mod_two):
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus)
    ... return map_reduce(count_if_mod_two, id_skus, Sum()).value
    ...

    with linear scaling in the length of `id_skus` and `sku_costs`, and
    caching for similar queries.

    At a small scale, everything's fast.

    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    20
    0.0012328624725341797
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    0.0002319812774658203
    >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    50
    0.00023674964904785156
    0.0002617835998535156

    As we increase the scale by 1000x in both tuples, the runtime
    scales (sub) linearly for the first query, and is unchanged
    for the second:
    As we increase the scale by 1000x for both input lists, the runtime
    scales (sub) linearly for the first query, and is unchanged for the
    second:

    >>> id_skus = id_skus * 1000
    >>> sku_costs = sku_costs * 1000
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    20000
    0.101287841796875
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    0.10173487663269043
    >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    50000
    0.0002338886260986328
    0.0002841949462890625

    This still pretty much holds up when we multiply by another factor of 100:

    >>> id_skus = id_skus * 100
    >>> sku_costs = sku_costs * 100
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    2000000
    6.881294012069702
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    6.7300660610198975
    >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    5000000
    0.00024199485778808594
    0.0002579689025878906

    The magic behind the curtains is memoisation (unsurprisingly), but
    a special implementation that can share work for similar closures:
    @@ -117,7 +117,7 @@ ordering here, hence the group structure).

    The output data structure controls the join we can implement. We
    show how a simple stack of nested key-value mappings handles
    equijoins, but a [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007)
    equijoins, but a [k-d range tree](https://dl.acm.org/doi/10.1145/356789.356797)
    would handle inequalities, i.e., "theta" joins (the rest of the
    machinery already works in terms of less than/greater than
    constraints).
    @@ -241,7 +241,7 @@ constraints on all `OpaqueValue`s in our result data structure).
    Currently, we only support nested dictionaries, so each
    `OpaqueValue`s must be either fully unconstrained (wildcard that
    matches any value), or constrained to be exactly equal to a value.
    There's no reason we can't use k-d trees though, and it's not
    There's no reason we can't use k-d range trees though, and it's not
    harder to track a pair of bounds (lower and upper) than a set of
    inequalities, so we'll handle the general ordered `OpaqueValue`
    case.
    @@ -300,9 +300,8 @@ ground data structure to represent finitely supported functions.
    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name, index=0):
    def __init__(self, name):
    self.name = name
    self.index = index
    # Upper and lower bounds
    self.lower = self.upper = None
    self.lowerExcl = self.upperExcl = True
    @@ -440,7 +439,7 @@ ground data structure to represent finitely supported functions.
    return self.__cmp__(other) != 0

    # No other comparator because we don't index ranges
    # (no k-d tree).
    # (no range tree).

    # def __lt__(self, other):
    # return self.__cmp__(other) < 0
    @@ -523,15 +522,14 @@ results for identical keys together.

    Enumerates all constraints for the `OpaqueValue` instances in
    `values`, and yields a pair of equality constraints for the
    `value` and the corresponding return value, for all non-zero
    values.
    `value` and the corresponding results, for all non-zero results.

    This essentially turns `function()` into a *not necessarily
    ordered* branching program on `values`.
    This essentially turns `function()` into a branching program on
    `values`.

    >>> x, y = OpaqueValue("x", 0), OpaqueValue("y", 1)
    >>> x, y = OpaqueValue("x"), OpaqueValue("y")
    >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
    [(((0, 0),), 1), (((0, 1), (1, 2)), 2)]
    [((0, None), 1), ((1, 2), 2)]

    """
    explorationStack = [] # List of (value, bitmaskOfCmp)
    @@ -540,7 +538,6 @@ results for identical keys together.
    value.reset()

    stackIndex = 0
    constraints = []

    def handle(value, other):
    nonlocal stackIndex
    @@ -554,7 +551,6 @@ results for identical keys together.
    if (mask & 1) != 0:
    ret = -1
    elif (mask & 2) != 0:
    constraints.append((value, other))
    ret = 0
    elif (mask & 4) != 0:
    ret = 1
    @@ -571,8 +567,7 @@ results for identical keys together.
    assert (
    value.definite() or value.indefinite()
    ), f"partially constrained {value} temporarily unsupported"
    assert value.indefinite() or any(key is value for key, _ in constraints)
    yield (tuple((key.index, other) for key, other in constraints),
    yield (tuple(key.value() if key.definite() else None for key in values),
    result)

    # Drop everything that was fully explored, then move the next
    @@ -596,10 +591,10 @@ results for identical keys together.

    >>> def count_eql(needle): return lambda x: 1 if x == needle else None
    >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
    [(((0, 1),), 1), (((0, 2),), 1), (((0, 4),), 1), (((0, 4),), 1), (((0, 2),), 1)]
    [((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
    """
    _, _, rebind, names = extract_function_state(function)
    values = [OpaqueValue(name, index) for index, name in enumerate(names)]
    values = [OpaqueValue(name) for name in names]
    reboundFunction = rebind(values)
    for arg in args:
    yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)
    @@ -655,77 +650,86 @@ In a real implementation, this data structure would host most of
    the complexity: it's the closest thing we have to indexes.

    For now, support equality *with ground value* as our only
    constraint. This means we can dispatch on the closed over values
    by order of appearance, with either a wildcard value (matches
    everything) or a hash map. At each internal level, the value is
    another `NestedDictLevel`. At the leaf, the value is the
    precomputed value.
    constraint. This means the only two cases we must look for at
    each level are an exact match, or a wildcard match.

    We do have to check for both cases at each level, so the worst-case
    complexity for lookups is exponential in the depth (number of join
    variables). That's actually reasonable because we don't expect too
    many join variables (compare to [range trees](https://en.wikipedia.org/wiki/Range_tree#Range_queries)
    that are also exponential in the number of dimensions... but with
    a base of log(n) instead of 2).


    class NestedDictLevel:
    """One level in a nested dictionary index. We may either have a
    value for everything (leaf node), or a key-value dict for a specific
    index in the tuple key.
    """One level in a nested dictionary index. We may have a value
    for everything (leaf node), or a key-value dict *and a wildcard
    entry* for a specific index in the tuple key.
    """

    def __init__(self, indexValues, depth=0):
    assert depth <= len(indexValues)
    self.keyIndex = indexValues[depth][0] if depth < len(indexValues) else None
    def __init__(self, depth):
    self.depth = depth
    self.value = None
    self.wildcard = None
    self.dict = dict()

    def get(self, keys, default):
    """Gets the value for `keys` in this level, or `default` if None."""
    assert self.value is not None or self.keyIndex is not None

    def visit(self, keys, visitor):
    """Passes the values for `keys` to `visitor`."""
    if self.value is not None:
    return self.value
    next = self.dict.get(keys[self.keyIndex], None)
    if next is None:
    return default
    return next.get(keys, default)
    assert self.wildcard is None and not self.dict
    visitor(self.value)

    if self.depth >= len(keys):
    return

    if self.wildcard is not None:
    self.wildcard.visit(keys, visitor)

    def set(self, indexValues, mergeFunction, depth=0):
    next = self.dict.get(keys[self.depth], None)
    if next is not None:
    next.visit(keys, visitor)

    def set(self, keys, mergeFunction, depth=0):
    """Sets the value for `keys` in this level."""
    assert depth <= len(indexValues)
    if depth == len(indexValues): # Leaf
    assert depth <= len(keys)
    if depth == len(keys): # Leaf
    self.value = mergeFunction(self.value)
    return

    index, value = indexValues[depth]
    assert self.keyIndex == index
    if value not in self.dict:
    self.dict[value] = NestedDictLevel(indexValues, depth + 1)
    self.dict[value].set(indexValues, mergeFunction, depth + 1)
    assert self.depth == depth
    key = keys[depth]

    if key is None:
    if self.wildcard is None:
    self.wildcard = NestedDictLevel(depth + 1)
    dst = self.wildcard
    else:
    dst = self.dict.get(key)
    if dst is None:
    dst = NestedDictLevel(depth + 1)
    self.dict[key] = dst
    dst.set(keys, mergeFunction, depth + 1)


    class NestedDict:
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value. Each `NestedDictLevel` handles a different level. We
    don't test each index in the tuple in fixed order to avoid the
    combinatorial explosion that can happen when change ordering
    (e.g., conversion from BDD to OBDD).
    a value. Each `NestedDictLevel` handles a different level.
    """

    def __init__(self, length):
    self.top = None
    self.top = NestedDictLevel(0)
    self.length = length

    def get(self, keys, default=None):
    def visit(self, keys, visitor):
    """Gets the value associated with `keys`, or `default` if None."""
    assert len(keys) == self.length
    assert all(key is not None for key in keys)
    self.top.visit(keys, visitor)

    if self.top is None:
    return default
    return self.top.get(keys, default)

    def set(self, indexKeyValues, mergeFn):
    def set(self, keys, mergeFn):
    """Sets the value associated with `((index, key), ...)`."""
    assert all(0 <= index < self.length for index, _ in indexKeyValues)

    if self.top is None:
    self.top = NestedDictLevel(indexKeyValues)
    self.top.set(indexKeyValues, mergeFn)
    assert len(keys) == self.length
    self.top.set(keys, mergeFn)


    <details><summary><h2>Identity key-value maps</h2></summary>
    @@ -788,6 +792,18 @@ This last `map_reduce` definition ties everything together, and
    I think is really the general heart of Yannakakis's algorithm
    as an instance of bottom-up dynamic programming.

    def _merge(dst, update):
    if dst is None:
    return update

    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    for value, new in zip(dst, update):
    value.merge(new)
    else:
    dst.merge(update)
    return dst


    def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    @@ -799,44 +815,32 @@ as an instance of bottom-up dynamic programming.

    >>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
    >>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
    >>> nd.get((0,))
    >>> nd.get((1,)).value
    >>> nd.visit((0,), lambda sum: print(sum.value))
    >>> nd.visit((1,), lambda sum: print(sum.value))
    1
    >>> nd.get((2,)).value
    >>> nd.visit((2,), lambda sum: print(sum.value))
    2
    >>> nd.get((4,)).value
    >>> nd.visit((4,), lambda sum: print(sum.value))
    2
    """

    def merge(dst, update):
    if dst is None:
    return update

    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    for value, new in zip(dst, update):
    value.merge(new)
    else:
    dst.merge(update)
    return dst

    cache = NestedDict(depth)
    for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
    cache.set(indexKeyValues, lambda old: merge(old, result))
    cache.set(indexKeyValues, lambda old: _merge(old, result))
    return cache


    AGGREGATE_CACHE = IdMap() # Map from function, input sequence -> NestedDict


    def map_reduce(function, inputIterable):
    """Returns the result of merging `map(function, inputIterable)`.
    def map_reduce(function, inputIterable, initialValue):
    """Returns the result of merging `map(function, inputIterable)`
    into `initialValue`.

    `None` return values represent neutral elements (i.e., the result
    of mapping an empty `inputIterable`), and values are otherwise
    reduced by calling `merge` on a mutable accumulator.

    Assuming `function` is well-behaced, `map_reduce` runs in time
    Assuming `function` is well-behaved, `map_reduce` runs in time
    linear wrt `len(inputIterable)`. It's also always cached on a
    composite key that consists of the `function`'s code object (i.e.,
    without closed over values) and the `inputIterable`.
    @@ -857,28 +861,29 @@ as an instance of bottom-up dynamic programming.
    ... return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data).value
    >>> map_reduce(count_eql(4), data, Sum()).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data).value
    >>> map_reduce(count_eql(2), data, Sum()).value
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0)).value
    >>> def sum_odd_or_even_skus(mod_two):
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus)
    >>> sum_odd_or_even_skus(0).value
    ... return map_reduce(count_if_mod_two, id_skus, Sum()).value
    >>> sum_odd_or_even_skus(0)
    20
    >>> sum_odd_or_even_skus(1).value
    >>> sum_odd_or_even_skus(1)
    50

    """
    assert isinstance(inputIterable, collections.abc.Iterable)
    assert not isinstance(inputIterable, collections.abc.Iterator)
    @@ -887,7 +892,8 @@ as an instance of bottom-up dynamic programming.
    AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
    function, len(closure), inputIterable
    )
    return AGGREGATE_CACHE[code, inputIterable].get(closure, None)
    AGGREGATE_CACHE[code, inputIterable].visit(closure, lambda result: _merge(initialValue, result))
    return initialValue


    if __name__ == "__main__":
    @@ -989,7 +995,7 @@ realistic queries.
    At a higher level, we could support comparison joins (e.g., less
    than, greater than or equal, in range) if only we represented the
    branching programs with a data structure that supported these
    queries. A [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007) would
    queries. A [range tree](https://dl.acm.org/doi/10.1145/356789.356797) would
    let us handle these "theta" joins, for the low, low cost of a
    polylogarithmic multiplicative factor in space and time.

    220 changes: 113 additions & 107 deletions yannakakis.py
    @@ -52,51 +52,51 @@
    ## >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    ## >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    ## >>> def sku_min_cost(sku):
    ## ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    ## ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0)).value
    ## ...
    ## >>> def sum_odd_or_even_skus(mod_two):
    ## ... def count_if_mod_two(id_sku):
    ## ... id, sku = id_sku
    ## ... if id % 2 == mod_two:
    ## ... return Sum(sku_min_cost(sku))
    ## ... return map_reduce(count_if_mod_two, id_skus)
    ## ... return map_reduce(count_if_mod_two, id_skus, Sum()).value
    ## ...
    ##
    ## with linear scaling in the length of `id_skus` and `sku_costs`, and
    ## caching for similar queries.
    ##
    ## At a small scale, everything's fast.
    ##
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    ## 20
    ## 0.0012328624725341797
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 0.0002319812774658203
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    ## 50
    ## 0.00023674964904785156
    ## 0.0002617835998535156
    ##
    ## As we increase the scale by 1000x in both tuples, the runtime
    ## scales (sub) linearly for the first query, and is unchanged
    ## for the second:
    ## As we increase the scale by 1000x for both input lists, the runtime
    ## scales (sub) linearly for the first query, and is unchanged for the
    ## second:
    ##
    ## >>> id_skus = id_skus * 1000
    ## >>> sku_costs = sku_costs * 1000
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    ## 20000
    ## 0.101287841796875
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 0.10173487663269043
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    ## 50000
    ## 0.0002338886260986328
    ## 0.0002841949462890625
    ##
    ## This still pretty much holds up when we multiply by another factor of 100:
    ##
    ## >>> id_skus = id_skus * 100
    ## >>> sku_costs = sku_costs * 100
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    ## 2000000
    ## 6.881294012069702
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 6.7300660610198975
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    ## 5000000
    ## 0.00024199485778808594
    ## 0.0002579689025878906
    ##
    ## The magic behind the curtains is memoisation (unsurprisingly), but
    ## a special implementation that can share work for similar closures:
    @@ -115,7 +115,7 @@
    ##
    ## The output data structure controls the join we can implement. We
    ## show how a simple stack of nested key-value mappings handles
    ## equijoins, but a [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007)
    ## equijoins, but a [k-d range tree](https://dl.acm.org/doi/10.1145/356789.356797)
    ## would handle inequalities, i.e., "theta" joins (the rest of the
    ## machinery already works in terms of less than/greater than
    ## constraints).
    @@ -237,7 +237,7 @@ def rebind(values):
    ## Currently, we only support nested dictionaries, so each
    ## `OpaqueValue`s must be either fully unconstrained (wildcard that
    ## matches any value), or constrained to be exactly equal to a value.
    ## There's no reason we can't use k-d trees though, and it's not
    ## There's no reason we can't use k-d range trees though, and it's not
    ## harder to track a pair of bounds (lower and upper) than a set of
    ## inequalities, so we'll handle the general ordered `OpaqueValue`
    ## case.
    @@ -296,9 +296,8 @@ class OpaqueValue:
    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name, index=0):
    def __init__(self, name):
    self.name = name
    self.index = index
    # Upper and lower bounds
    self.lower = self.upper = None
    self.lowerExcl = self.upperExcl = True
    @@ -436,7 +435,7 @@ def __ne__(self, other):
    return self.__cmp__(other) != 0

    # No other comparator because we don't index ranges
    # (no k-d tree).
    # (no range tree).

    # def __lt__(self, other):
    # return self.__cmp__(other) < 0
    @@ -519,15 +518,14 @@ def enumerate_opaque_values(function, values):

    Enumerates all constraints for the `OpaqueValue` instances in
    `values`, and yields a pair of equality constraints for the
    `value` and the corresponding return value, for all non-zero
    values.
    `value` and the corresponding results, for all non-zero results.

    This essentially turns `function()` into a *not necessarily
    ordered* branching program on `values`.
    This essentially turns `function()` into a branching program on
    `values`.

    >>> x, y = OpaqueValue("x", 0), OpaqueValue("y", 1)
    >>> x, y = OpaqueValue("x"), OpaqueValue("y")
    >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
    [(((0, 0),), 1), (((0, 1), (1, 2)), 2)]
    [((0, None), 1), ((1, 2), 2)]

    """
    explorationStack = [] # List of (value, bitmaskOfCmp)
    @@ -536,7 +534,6 @@ def enumerate_opaque_values(function, values):
    value.reset()

    stackIndex = 0
    constraints = []

    def handle(value, other):
    nonlocal stackIndex
    @@ -550,7 +547,6 @@ def handle(value, other):
    if (mask & 1) != 0:
    ret = -1
    elif (mask & 2) != 0:
    constraints.append((value, other))
    ret = 0
    elif (mask & 4) != 0:
    ret = 1
    @@ -567,8 +563,7 @@ def handle(value, other):
    assert (
    value.definite() or value.indefinite()
    ), f"partially constrained {value} temporarily unsupported"
    assert value.indefinite() or any(key is value for key, _ in constraints)
    yield (tuple((key.index, other) for key, other in constraints),
    yield (tuple(key.value() if key.definite() else None for key in values),
    result)

    # Drop everything that was fully explored, then move the next
    @@ -592,10 +587,10 @@ def enumerate_supporting_values(function, args):

    >>> def count_eql(needle): return lambda x: 1 if x == needle else None
    >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
    [(((0, 1),), 1), (((0, 2),), 1), (((0, 4),), 1), (((0, 4),), 1), (((0, 2),), 1)]
    [((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
    """
    _, _, rebind, names = extract_function_state(function)
    values = [OpaqueValue(name, index) for index, name in enumerate(names)]
    values = [OpaqueValue(name) for name in names]
    reboundFunction = rebind(values)
    for arg in args:
    yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)
    @@ -651,77 +646,86 @@ def merge(self, other):
    ## the complexity: it's the closest thing we have to indexes.
    ##
    ## For now, support equality *with ground value* as our only
    ## constraint. This means we can dispatch on the closed over values
    ## by order of appearance, with either a wildcard value (matches
    ## everything) or a hash map. At each internal level, the value is
    ## another `NestedDictLevel`. At the leaf, the value is the
    ## precomputed value.
    ## constraint. This means the only two cases we must look for at
    ## each level are an exact match, or a wildcard match.
    ##
    ## We do have to check for both cases at each level, so the worst-case
    ## complexity for lookups is exponential in the depth (number of join
    ## variables). That's actually reasonable because we don't expect too
    ## many join variables (compare to [range trees](https://en.wikipedia.org/wiki/Range_tree#Range_queries)
    ## that are also exponential in the number of dimensions... but with
    ## a base of log(n) instead of 2).


    class NestedDictLevel:
    """One level in a nested dictionary index. We may either have a
    value for everything (leaf node), or a key-value dict for a specific
    index in the tuple key.
    """One level in a nested dictionary index. We may have a value
    for everything (leaf node), or a key-value dict *and a wildcard
    entry* for a specific index in the tuple key.
    """

    def __init__(self, indexValues, depth=0):
    assert depth <= len(indexValues)
    self.keyIndex = indexValues[depth][0] if depth < len(indexValues) else None
    def __init__(self, depth):
    self.depth = depth
    self.value = None
    self.wildcard = None
    self.dict = dict()

    def get(self, keys, default):
    """Gets the value for `keys` in this level, or `default` if None."""
    assert self.value is not None or self.keyIndex is not None

    def visit(self, keys, visitor):
    """Passes the values for `keys` to `visitor`."""
    if self.value is not None:
    return self.value
    next = self.dict.get(keys[self.keyIndex], None)
    if next is None:
    return default
    return next.get(keys, default)
    assert self.wildcard is None and not self.dict
    visitor(self.value)

    if self.depth >= len(keys):
    return

    if self.wildcard is not None:
    self.wildcard.visit(keys, visitor)

    def set(self, indexValues, mergeFunction, depth=0):
    next = self.dict.get(keys[self.depth], None)
    if next is not None:
    next.visit(keys, visitor)

    def set(self, keys, mergeFunction, depth=0):
    """Sets the value for `keys` in this level."""
    assert depth <= len(indexValues)
    if depth == len(indexValues): # Leaf
    assert depth <= len(keys)
    if depth == len(keys): # Leaf
    self.value = mergeFunction(self.value)
    return

    index, value = indexValues[depth]
    assert self.keyIndex == index
    if value not in self.dict:
    self.dict[value] = NestedDictLevel(indexValues, depth + 1)
    self.dict[value].set(indexValues, mergeFunction, depth + 1)
    assert self.depth == depth
    key = keys[depth]

    if key is None:
    if self.wildcard is None:
    self.wildcard = NestedDictLevel(depth + 1)
    dst = self.wildcard
    else:
    dst = self.dict.get(key)
    if dst is None:
    dst = NestedDictLevel(depth + 1)
    self.dict[key] = dst
    dst.set(keys, mergeFunction, depth + 1)


    class NestedDict:
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value. Each `NestedDictLevel` handles a different level. We
    don't test each index in the tuple in fixed order to avoid the
    combinatorial explosion that can happen when change ordering
    (e.g., conversion from BDD to OBDD).
    a value. Each `NestedDictLevel` handles a different level.
    """

    def __init__(self, length):
    self.top = None
    self.top = NestedDictLevel(0)
    self.length = length

    def get(self, keys, default=None):
    def visit(self, keys, visitor):
    """Gets the value associated with `keys`, or `default` if None."""
    assert len(keys) == self.length
    assert all(key is not None for key in keys)
    self.top.visit(keys, visitor)

    if self.top is None:
    return default
    return self.top.get(keys, default)

    def set(self, indexKeyValues, mergeFn):
    def set(self, keys, mergeFn):
    """Sets the value associated with `((index, key), ...)`."""
    assert all(0 <= index < self.length for index, _ in indexKeyValues)

    if self.top is None:
    self.top = NestedDictLevel(indexKeyValues)
    self.top.set(indexKeyValues, mergeFn)
    assert len(keys) == self.length
    self.top.set(keys, mergeFn)


    ##{<h2>Identity key-value maps</h2>
    @@ -782,6 +786,18 @@ def __setitem__(self, keys, value):
    ## I think is really the general heart of Yannakakis's algorithm
    ## as an instance of bottom-up dynamic programming.

    def _merge(dst, update):
    if dst is None:
    return update

    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    for value, new in zip(dst, update):
    value.merge(new)
    else:
    dst.merge(update)
    return dst


    def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    @@ -793,44 +809,32 @@ def _precompute_map_reduce(function, depth, inputIterable):

    >>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
    >>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
    >>> nd.get((0,))
    >>> nd.get((1,)).value
    >>> nd.visit((0,), lambda sum: print(sum.value))
    >>> nd.visit((1,), lambda sum: print(sum.value))
    1
    >>> nd.get((2,)).value
    >>> nd.visit((2,), lambda sum: print(sum.value))
    2
    >>> nd.get((4,)).value
    >>> nd.visit((4,), lambda sum: print(sum.value))
    2
    """

    def merge(dst, update):
    if dst is None:
    return update

    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    for value, new in zip(dst, update):
    value.merge(new)
    else:
    dst.merge(update)
    return dst

    cache = NestedDict(depth)
    for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
    cache.set(indexKeyValues, lambda old: merge(old, result))
    cache.set(indexKeyValues, lambda old: _merge(old, result))
    return cache


    AGGREGATE_CACHE = IdMap() # Map from function, input sequence -> NestedDict


    def map_reduce(function, inputIterable):
    """Returns the result of merging `map(function, inputIterable)`.
    def map_reduce(function, inputIterable, initialValue):
    """Returns the result of merging `map(function, inputIterable)`
    into `initialValue`.

    `None` return values represent neutral elements (i.e., the result
    of mapping an empty `inputIterable`), and values are otherwise
    reduced by calling `merge` on a mutable accumulator.

    Assuming `function` is well-behaced, `map_reduce` runs in time
    Assuming `function` is well-behaved, `map_reduce` runs in time
    linear wrt `len(inputIterable)`. It's also always cached on a
    composite key that consists of the `function`'s code object (i.e.,
    without closed over values) and the `inputIterable`.
    @@ -851,28 +855,29 @@ def map_reduce(function, inputIterable):
    ... return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data).value
    >>> map_reduce(count_eql(4), data, Sum()).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data).value
    >>> map_reduce(count_eql(2), data, Sum()).value
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0)).value
    >>> def sum_odd_or_even_skus(mod_two):
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus)
    >>> sum_odd_or_even_skus(0).value
    ... return map_reduce(count_if_mod_two, id_skus, Sum()).value
    >>> sum_odd_or_even_skus(0)
    20
    >>> sum_odd_or_even_skus(1).value
    >>> sum_odd_or_even_skus(1)
    50

    """
    assert isinstance(inputIterable, collections.abc.Iterable)
    assert not isinstance(inputIterable, collections.abc.Iterator)
    @@ -881,7 +886,8 @@ def map_reduce(function, inputIterable):
    AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
    function, len(closure), inputIterable
    )
    return AGGREGATE_CACHE[code, inputIterable].get(closure, None)
    AGGREGATE_CACHE[code, inputIterable].visit(closure, lambda result: _merge(initialValue, result))
    return initialValue


    if __name__ == "__main__":
    @@ -983,7 +989,7 @@ def map_reduce(function, inputIterable):
    ## At a higher level, we could support comparison joins (e.g., less
    ## than, greater than or equal, in range) if only we represented the
    ## branching programs with a data structure that supported these
    ## queries. A [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007) would
    ## queries. A [range tree](https://dl.acm.org/doi/10.1145/356789.356797) would
    ## let us handle these "theta" joins, for the low, low cost of a
    ## polylogarithmic multiplicative factor in space and time.
    ##
  4. pkhuong revised this gist Jul 25, 2023. 2 changed files with 14 additions and 14 deletions.
    14 changes: 7 additions & 7 deletions yannakakis.md
    @@ -71,10 +71,10 @@ At a small scale, everything's fast.

    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    20
    0.001024007797241211
    0.0012328624725341797
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    50
    4.982948303222656e-05
    0.00023674964904785156

    As we increase the scale by 1000x in both tuples, the runtime
    scales (sub) linearly for the first query, and is unchanged
    @@ -84,21 +84,21 @@ for the second:
    >>> sku_costs = sku_costs * 1000
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    20000
    0.09638118743896484
    0.101287841796875
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    50000
    8.797645568847656e-05
    0.0002338886260986328

    This still kind of holds when we multiply by another factor of 100:
    This still pretty much holds up when we multiply by another factor of 100:

    >>> id_skus = id_skus * 100
    >>> sku_costs = sku_costs * 100
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    2000000
    5.9715189933776855
    6.881294012069702
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    5000000
    0.00021195411682128906
    0.00024199485778808594

    The magic behind the curtains is memoisation (unsurprisingly), but
    a special implementation that can share work for similar closures:
    14 changes: 7 additions & 7 deletions yannakakis.py
    @@ -69,10 +69,10 @@
    ##
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 20
    ## 0.001024007797241211
    ## 0.0012328624725341797
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 50
    ## 4.982948303222656e-05
    ## 0.00023674964904785156
    ##
    ## As we increase the scale by 1000x in both tuples, the runtime
    ## scales (sub) linearly for the first query, and is unchanged
    @@ -82,21 +82,21 @@
    ## >>> sku_costs = sku_costs * 1000
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 20000
    ## 0.09638118743896484
    ## 0.101287841796875
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 50000
    ## 8.797645568847656e-05
    ## 0.0002338886260986328
    ##
    ## This still kind of holds when we multiply by another factor of 100:
    ## This still pretty much holds up when we multiply by another factor of 100:
    ##
    ## >>> id_skus = id_skus * 100
    ## >>> sku_costs = sku_costs * 100
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 2000000
    ## 5.9715189933776855
    ## 6.881294012069702
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 5000000
    ## 0.00021195411682128906
    ## 0.00024199485778808594
    ##
    ## The magic behind the curtains is memoisation (unsurprisingly), but
    ## a special implementation that can share work for similar closures:
  5. pkhuong revised this gist Jul 25, 2023. 2 changed files with 2 additions and 2 deletions.
    2 changes: 1 addition & 1 deletion yannakakis.md
    @@ -439,7 +439,7 @@ ground data structure to represent finitely supported functions.
    def __ne__(self, other):
    return self.__cmp__(other) != 0

    # No other comparator because we don't have index ranges
    # No other comparator because we don't index ranges
    # (no k-d tree).

    # def __lt__(self, other):
    2 changes: 1 addition & 1 deletion yannakakis.py
    @@ -435,7 +435,7 @@ def __eq__(self, other):
    def __ne__(self, other):
    return self.__cmp__(other) != 0

    # No other comparator because we don't have index ranges
    # No other comparator because we don't index ranges
    # (no k-d tree).

    # def __lt__(self, other):
  6. pkhuong revised this gist Jul 25, 2023. 2 changed files with 204 additions and 174 deletions.
    189 changes: 102 additions & 87 deletions yannakakis.md
    @@ -280,27 +280,29 @@ ground data structure to represent finitely supported functions.
    >>> 1 if x else 2
    1
    >>> x.reset()
    >>> x > 4
    False
    >>> x < 4
    False
    >>> x == 4
    True
    >>> x < 10
    True
    >>> x >= 10
    False
    >>> x > -10
    True
    >>> x <= -10
    False
    >>> ### Not supported by our index data structure (yet)
    >>> # >>> x > 4
    >>> # False
    >>> # >>> x < 4
    >>> # False
    >>> # >>> x == 4
    >>> # True
    >>> # >>> x < 10
    >>> # True
    >>> # >>> x >= 10
    >>> # False
    >>> # >>> x > -10
    >>> # True
    >>> # >>> x <= -10
    >>> # False
    """

    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name):
    def __init__(self, name, index=0):
    self.name = name
    self.index = index
    # Upper and lower bounds
    self.lower = self.upper = None
    self.lowerExcl = self.upperExcl = True
    @@ -437,17 +439,20 @@ ground data structure to represent finitely supported functions.
    def __ne__(self, other):
    return self.__cmp__(other) != 0

    def __lt__(self, other):
    return self.__cmp__(other) < 0
    # No other comparator because we don't have index ranges
    # (no k-d tree).

    # def __lt__(self, other):
    # return self.__cmp__(other) < 0

    def __le__(self, other):
    return self.__cmp__(other) <= 0
    # def __le__(self, other):
    # return self.__cmp__(other) <= 0

    def __gt__(self, other):
    return self.__cmp__(other) > 0
    # def __gt__(self, other):
    # return self.__cmp__(other) > 0

    def __ge__(self, other):
    return self.__cmp__(other) >= 0
    # def __ge__(self, other):
    # return self.__cmp__(other) >= 0


    ## Depth-first exploration of a function call's support
    @@ -516,16 +521,17 @@ results for identical keys together.
    """Explores the set of `OpaqueValue` constraints when calling
    `function`.

    Enumerates all constraints for the `OpaqueValue` instances in `values`,
    and yields a pair of tuple of ground values for each `value` and the
    corresponding return value, for all non-zero values.
    Enumerates all constraints for the `OpaqueValue` instances in
    `values`, and yields a pair of equality constraints for the
    `value` and the corresponding return value, for all non-zero
    values.

    This essentially turns `function()` into a branching program on
    `values`.
    This essentially turns `function()` into a *not necessarily
    ordered* branching program on `values`.

    >>> x, y = OpaqueValue("x"), OpaqueValue("y")
    >>> x, y = OpaqueValue("x", 0), OpaqueValue("y", 1)
    >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
    [((0, None), 1), ((1, 2), 2)]
    [(((0, 0),), 1), (((0, 1), (1, 2)), 2)]

    """
    explorationStack = [] # List of (value, bitmaskOfCmp)
    @@ -534,6 +540,7 @@ results for identical keys together.
    value.reset()

    stackIndex = 0
    constraints = []

    def handle(value, other):
    nonlocal stackIndex
    @@ -547,6 +554,7 @@ results for identical keys together.
    if (mask & 1) != 0:
    ret = -1
    elif (mask & 2) != 0:
    constraints.append((value, other))
    ret = 0
    elif (mask & 4) != 0:
    ret = 1
    @@ -559,13 +567,13 @@ results for identical keys together.
    OpaqueValue.CMP_HANDLER = handle
    result = function()
    if not is_zero_result(result):
    keys = []
    for value in values:
    assert (
    value.definite() or value.indefinite()
    ), f"{value} temporarily unsupported"
    keys.append(value.value() if value.definite() else None)
    yield (tuple(keys), result)
    ), f"partially constrained {value} temporarily unsupported"
    assert value.indefinite() or any(key is value for key, _ in constraints)
    yield (tuple((key.index, other) for key, other in constraints),
    result)

    # Drop everything that was fully explored, then move the next
    # top of stack to the next option.
    @@ -588,10 +596,10 @@ results for identical keys together.

    >>> def count_eql(needle): return lambda x: 1 if x == needle else None
    >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
    [((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
    [(((0, 1),), 1), (((0, 2),), 1), (((0, 4),), 1), (((0, 4),), 1), (((0, 2),), 1)]
    """
    _, _, rebind, names = extract_function_state(function)
    values = [OpaqueValue(name) for name in names]
    values = [OpaqueValue(name, index) for index, name in enumerate(names)]
    reboundFunction = rebind(values)
    for arg in args:
    yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)
    @@ -656,61 +664,68 @@ precomputed value.

    class NestedDictLevel:
    """One level in a nested dictionary index. We may either have a
    value for everything, or a key-value dict.
    value for everything (leaf node), or a key-value dict for a specific
    index in the tuple key.
    """

    def __init__(self):
    def __init__(self, indexValues, depth=0):
    assert depth <= len(indexValues)
    self.keyIndex = indexValues[depth][0] if depth < len(indexValues) else None
    self.value = None
    self.dict = dict()
    self.wildcard = None

    def get(self, key, default):
    """Gets the value for `key` in this level, or `default` if None."""
    return self.dict.get(
    key, self.wildcard if self.wildcard is not None else default
    )
    def get(self, keys, default):
    """Gets the value for `keys` in this level, or `default` if None."""
    assert self.value is not None or self.keyIndex is not None

    def set(self, key, value):
    """Sets the value for `key` in this level. A `key` of `None`
    represents a wildcard value.
    """
    if key is not None:
    self.dict[key] = value
    else:
    self.wildcard = value
    if self.value is not None:
    return self.value
    next = self.dict.get(keys[self.keyIndex], None)
    if next is None:
    return default
    return next.get(keys, default)

    def set(self, indexValues, mergeFunction, depth=0):
    """Sets the value for `keys` in this level."""
    assert depth <= len(indexValues)
    if depth == len(indexValues): # Leaf
    self.value = mergeFunction(self.value)
    return

    index, value = indexValues[depth]
    assert self.keyIndex == index
    if value not in self.dict:
    self.dict[value] = NestedDictLevel(indexValues, depth + 1)
    self.dict[value].set(indexValues, mergeFunction, depth + 1)


    class NestedDict:
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value. Each position in the tuple is handled by a `NestedDictLevel`.
    a value. Each `NestedDictLevel` handles a different level. We
    don't test the tuple's indices in a fixed order, to avoid the
    combinatorial explosion that can happen when changing the ordering
    (e.g., conversion from BDD to OBDD).
    """

    def __init__(self, depth):
    assert depth >= 1
    self.top = NestedDictLevel()
    self.depth = depth

    def get(self, key, default=None):
    """Gets the value associated with `key`, or `default` if None."""
    assert len(key) == self.depth
    current = self.top
    for levelKey in key:
    current = current.get(levelKey, None)
    if current is None:
    return default
    return current

    def set(self, key, value):
    """Sets the value associated with `key`."""
    assert len(key) == self.depth
    current = self.top
    for idx, levelKey in enumerate(key):
    if idx == len(key) - 1:
    current.set(levelKey, value)
    else:
    next = current.get(levelKey, None)
    if next is None:
    next = NestedDictLevel()
    current.set(levelKey, next)
    def __init__(self, length):
    self.top = None
    self.length = length

    def get(self, keys, default=None):
    """Gets the value associated with `keys`, or `default` if None."""
    assert len(keys) == self.length

    if self.top is None:
    return default
    return self.top.get(keys, default)

    def set(self, indexKeyValues, mergeFn):
    """Sets the value associated with `((index, key), ...)`."""
    assert all(0 <= index < self.length for index, _ in indexKeyValues)

    if self.top is None:
    self.top = NestedDictLevel(indexKeyValues)
    self.top.set(indexKeyValues, mergeFn)


    <details><summary><h2>Identity key-value maps</h2></summary>
    @@ -753,7 +768,7 @@ supporting values, and merge the result associated with identical
    supporting values.

    Again, we only support ground equality constraints (see assertion
    on L558), i.e., only equijoins. There's nothing that stops a more
    on L568), i.e., only equijoins. There's nothing that stops a more
    sophisticated implementation from using range trees to support
    inequality or range joins.

    @@ -794,20 +809,20 @@ as an instance of bottom-up dynamic programming.
    """

    def merge(dst, update):
    if dst is None:
    return update

    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    for value, new in zip(dst, update):
    value.merge(new)
    else:
    dst.merge(update)
    return dst

    cache = NestedDict(depth)
    for key, result in enumerate_supporting_values(function, inputIterable):
    prev = cache.get(key, None)
    if prev is None:
    cache.set(key, result)
    else:
    merge(prev, result)
    for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
    cache.set(indexKeyValues, lambda old: merge(old, result))
    return cache
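    A small usage sketch (mine, not from the gist) of the revised `NestedDict` above: keys are tuples of `(closure-slot index, ground value)` equality constraints, the levels test slots in whatever order the branching program produced them, and a slot left unconstrained behaves like a wildcard because the leaf is reached early. Plain integers stand in for the mergeable aggregates (`Sum`, `Min`, ...) that the real code stores.

        nd = NestedDict(2)                    # two closed-over slots
        bump = lambda old: 1 if old is None else old + 1
        nd.set(((0, "a"), (1, 10)), bump)
        nd.set(((0, "a"), (1, 10)), bump)     # same constraints -> merged in place
        nd.set(((0, "b"),), lambda old: 7)    # slot 1 left unconstrained
        assert nd.get(("a", 10)) == 2
        assert nd.get(("b", 99)) == 7         # any value for slot 1 matches
        assert nd.get(("a", 11)) is None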


    189 changes: 102 additions & 87 deletions yannakakis.py
    @@ -276,27 +276,29 @@ class OpaqueValue:
    >>> 1 if x else 2
    1
    >>> x.reset()
    >>> x > 4
    False
    >>> x < 4
    False
    >>> x == 4
    True
    >>> x < 10
    True
    >>> x >= 10
    False
    >>> x > -10
    True
    >>> x <= -10
    False
    >>> ### Not supported by our index data structure (yet)
    >>> # >>> x > 4
    >>> # False
    >>> # >>> x < 4
    >>> # False
    >>> # >>> x == 4
    >>> # True
    >>> # >>> x < 10
    >>> # True
    >>> # >>> x >= 10
    >>> # False
    >>> # >>> x > -10
    >>> # True
    >>> # >>> x <= -10
    >>> # False
    """

    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name):
    def __init__(self, name, index=0):
    self.name = name
    self.index = index
    # Upper and lower bounds
    self.lower = self.upper = None
    self.lowerExcl = self.upperExcl = True
    @@ -433,17 +435,20 @@ def __eq__(self, other):
    def __ne__(self, other):
    return self.__cmp__(other) != 0

    def __lt__(self, other):
    return self.__cmp__(other) < 0
    # No other comparator because we don't have index ranges
    # (no k-d tree).

    # def __lt__(self, other):
    # return self.__cmp__(other) < 0

    def __le__(self, other):
    return self.__cmp__(other) <= 0
    # def __le__(self, other):
    # return self.__cmp__(other) <= 0

    def __gt__(self, other):
    return self.__cmp__(other) > 0
    # def __gt__(self, other):
    # return self.__cmp__(other) > 0

    def __ge__(self, other):
    return self.__cmp__(other) >= 0
    # def __ge__(self, other):
    # return self.__cmp__(other) >= 0


    ## ## Depth-first exploration of a function call's support
    @@ -512,16 +517,17 @@ def enumerate_opaque_values(function, values):
    """Explores the set of `OpaqueValue` constraints when calling
    `function`.

    Enumerates all constraints for the `OpaqueValue` instances in `values`,
    and yields a pair of tuple of ground values for each `value` and the
    corresponding return value, for all non-zero values.
    Enumerates all constraints for the `OpaqueValue` instances in
    `values`, and yields a pair of equality constraints for the
    `value` and the corresponding return value, for all non-zero
    values.

    This essentially turns `function()` into a branching program on
    `values`.
    This essentially turns `function()` into a *not necessarily
    ordered* branching program on `values`.

    >>> x, y = OpaqueValue("x"), OpaqueValue("y")
    >>> x, y = OpaqueValue("x", 0), OpaqueValue("y", 1)
    >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
    [((0, None), 1), ((1, 2), 2)]
    [(((0, 0),), 1), (((0, 1), (1, 2)), 2)]

    """
    explorationStack = [] # List of (value, bitmaskOfCmp)
    @@ -530,6 +536,7 @@ def enumerate_opaque_values(function, values):
    value.reset()

    stackIndex = 0
    constraints = []

    def handle(value, other):
    nonlocal stackIndex
    @@ -543,6 +550,7 @@ def handle(value, other):
    if (mask & 1) != 0:
    ret = -1
    elif (mask & 2) != 0:
    constraints.append((value, other))
    ret = 0
    elif (mask & 4) != 0:
    ret = 1
    @@ -555,13 +563,13 @@ def handle(value, other):
    OpaqueValue.CMP_HANDLER = handle
    result = function()
    if not is_zero_result(result):
    keys = []
    for value in values:
    assert (
    value.definite() or value.indefinite()
    ), f"{value} temporarily unsupported"
    keys.append(value.value() if value.definite() else None)
    yield (tuple(keys), result)
    ), f"partially constrained {value} temporarily unsupported"
    assert value.indefinite() or any(key is value for key, _ in constraints)
    yield (tuple((key.index, other) for key, other in constraints),
    result)

    # Drop everything that was fully explored, then move the next
    # top of stack to the next option.
    @@ -584,10 +592,10 @@ def enumerate_supporting_values(function, args):

    >>> def count_eql(needle): return lambda x: 1 if x == needle else None
    >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
    [((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
    [(((0, 1),), 1), (((0, 2),), 1), (((0, 4),), 1), (((0, 4),), 1), (((0, 2),), 1)]
    """
    _, _, rebind, names = extract_function_state(function)
    values = [OpaqueValue(name) for name in names]
    values = [OpaqueValue(name, index) for index, name in enumerate(names)]
    reboundFunction = rebind(values)
    for arg in args:
    yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)
    @@ -652,61 +660,68 @@ def merge(self, other):

    class NestedDictLevel:
    """One level in a nested dictionary index. We may either have a
    value for everything, or a key-value dict.
    value for everything (leaf node), or a key-value dict for a specific
    index in the tuple key.
    """

    def __init__(self):
    def __init__(self, indexValues, depth=0):
    assert depth <= len(indexValues)
    self.keyIndex = indexValues[depth][0] if depth < len(indexValues) else None
    self.value = None
    self.dict = dict()
    self.wildcard = None

    def get(self, key, default):
    """Gets the value for `key` in this level, or `default` if None."""
    return self.dict.get(
    key, self.wildcard if self.wildcard is not None else default
    )
    def get(self, keys, default):
    """Gets the value for `keys` in this level, or `default` if None."""
    assert self.value is not None or self.keyIndex is not None

    def set(self, key, value):
    """Sets the value for `key` in this level. A `key` of `None`
    represents a wildcard value.
    """
    if key is not None:
    self.dict[key] = value
    else:
    self.wildcard = value
    if self.value is not None:
    return self.value
    next = self.dict.get(keys[self.keyIndex], None)
    if next is None:
    return default
    return next.get(keys, default)

    def set(self, indexValues, mergeFunction, depth=0):
    """Sets the value for `keys` in this level."""
    assert depth <= len(indexValues)
    if depth == len(indexValues): # Leaf
    self.value = mergeFunction(self.value)
    return

    index, value = indexValues[depth]
    assert self.keyIndex == index
    if value not in self.dict:
    self.dict[value] = NestedDictLevel(indexValues, depth + 1)
    self.dict[value].set(indexValues, mergeFunction, depth + 1)


    class NestedDict:
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value. Each position in the tuple is handled by a `NestedDictLevel`.
    a value. Each `NestedDictLevel` handles a different level. We
    don't test the tuple's indices in a fixed order, to avoid the
    combinatorial explosion that can happen when changing the ordering
    (e.g., conversion from BDD to OBDD).
    """

    def __init__(self, depth):
    assert depth >= 1
    self.top = NestedDictLevel()
    self.depth = depth

    def get(self, key, default=None):
    """Gets the value associated with `key`, or `default` if None."""
    assert len(key) == self.depth
    current = self.top
    for levelKey in key:
    current = current.get(levelKey, None)
    if current is None:
    return default
    return current

    def set(self, key, value):
    """Sets the value associated with `key`."""
    assert len(key) == self.depth
    current = self.top
    for idx, levelKey in enumerate(key):
    if idx == len(key) - 1:
    current.set(levelKey, value)
    else:
    next = current.get(levelKey, None)
    if next is None:
    next = NestedDictLevel()
    current.set(levelKey, next)
    def __init__(self, length):
    self.top = None
    self.length = length

    def get(self, keys, default=None):
    """Gets the value associated with `keys`, or `default` if None."""
    assert len(keys) == self.length

    if self.top is None:
    return default
    return self.top.get(keys, default)

    def set(self, indexKeyValues, mergeFn):
    """Sets the value associated with `((index, key), ...)`."""
    assert all(0 <= index < self.length for index, _ in indexKeyValues)

    if self.top is None:
    self.top = NestedDictLevel(indexKeyValues)
    self.top.set(indexKeyValues, mergeFn)


    ##{<h2>Identity key-value maps</h2>
    @@ -747,7 +762,7 @@ def __setitem__(self, keys, value):
    ## supporting values.
    ##
    ## Again, we only support ground equality constraints (see assertion
    ## on L558), i.e., only equijoins. There's nothing that stops a more
    ## on L568), i.e., only equijoins. There's nothing that stops a more
    ## sophisticated implementation from using range trees to support
    ## inequality or range joins.
    ##
    @@ -788,20 +803,20 @@ def _precompute_map_reduce(function, depth, inputIterable):
    """

    def merge(dst, update):
    if dst is None:
    return update

    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    for value, new in zip(dst, update):
    value.merge(new)
    else:
    dst.merge(update)
    return dst

    cache = NestedDict(depth)
    for key, result in enumerate_supporting_values(function, inputIterable):
    prev = cache.get(key, None)
    if prev is None:
    cache.set(key, result)
    else:
    merge(prev, result)
    for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
    cache.set(indexKeyValues, lambda old: merge(old, result))
    return cache
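    To read the new doctest output above (my paraphrase, not code from the gist): each entry is one path of the branching program, a tuple of `(closure-slot, value)` equality constraints plus the result for that path. Evaluating the memoised function for concrete closed-over values means finding the path whose constraints all hold; the `NestedDict` is simply an index that answers this lookup without scanning every path.

        paths = [(((0, 0),), 1), (((0, 1), (1, 2)), 2)]   # from the doctest above

        def lookup(paths, closure):                       # closure = (x, y)
            for constraints, result in paths:
                if all(closure[slot] == value for slot, value in constraints):
                    return result
            return None                                   # the "zero" result

        assert lookup(paths, (0, 99)) == 1                # x == 0, y unconstrained
        assert lookup(paths, (1, 2)) == 2
        assert lookup(paths, (1, 3)) is None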


  7. pkhuong revised this gist Jul 25, 2023. 2 changed files with 26 additions and 0 deletions.
    13 changes: 13 additions & 0 deletions yannakakis.md
    @@ -899,6 +899,19 @@ just CQ, and that's actually a liability: this hack will blindly
    try to convert any function to a branching program, instead of
    giving up noisily when the function is too complex.

    The other difference from classical CQs is that we focus on
    aggregates. That's because aggregates are the more general form:
    if we just want to avoid useless work while enumerating all join
    rows, we only need a boolean aggregate that tells us whether the
    join will yield at least one row. We could also special case types
    for which merges don't save space (e.g., set of row ids), and
    instead enumerate values by walking the branching program tree.

    The aggregate viewpoint also works for
    [fun extensions like indexed access to ranked results](https://ntzia.github.io/download/Tractable_Orders_2020.pdf):
    that extension ends up counting the number of output values up to a
    certain key.

    I guess, in a way, we just showed a trivial way to decorrelate
    queries with a hypertree-width of 1. We just have to be OK with
    building one index for each loop in the nest... but it should be
    13 changes: 13 additions & 0 deletions yannakakis.py
    @@ -893,6 +893,19 @@ def map_reduce(function, inputIterable):
    ## try to convert any function to a branching program, instead of
    ## giving up noisily when the function is too complex.
    ##
    ## The other difference from classical CQs is that we focus on
    ## aggregates. That's because aggregates are the more general form:
    ## if we just want to avoid useless work while enumerating all join
    ## rows, we only need a boolean aggregate that tells us whether the
    ## join will yield at least one row. We could also special case types
    ## for which merges don't save space (e.g., set of row ids), and
    ## instead enumerate values by walking the branching program tree.
    ##
    ## The aggregate viewpoint also works for
    ## [fun extensions like indexed access to ranked results](https://ntzia.github.io/download/Tractable_Orders_2020.pdf):
    ## that extension ends up counting the number of output values up to a
    ## certain key.
    ##
    ## I guess, in a way, we just showed a trivial way to decorrelate
    ## queries with a hypertree-width of 1. We just have to be OK with
    ## building one index for each loop in the nest... but it should be
  8. pkhuong revised this gist Jul 25, 2023. 2 changed files with 2 additions and 2 deletions.
    2 changes: 1 addition & 1 deletion yannakakis.md
    @@ -1,6 +1,6 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|

    <details><summary><h2>license, imports</h2></summary>
    <details><summary>license, imports</summary>

    # Yannakakis.py by Paul Khuong
    #
    2 changes: 1 addition & 1 deletion yannakakis.py
    @@ -1,6 +1,6 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|

    ##{<h2>license, imports</h2>
    ##{license, imports
    # Yannakakis.py by Paul Khuong
    #
    # To the extent possible under law, the person who associated CC0 with
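    For context on this small revision (my reading, based on the shebang visible at the top of both files): the `#!/usr/bin/env sed ...` line is the renderer, so yannakakis.md is presumably just yannakakis.py passed through those sed rules, and the `##{...` marker must carry the final `<summary>` text verbatim -- hence dropping the `<h2>` in both files at once. Roughly, in Python terms:

        import re

        def render(line):  # approximate rendition of the sed rules in the shebang
            line = "    " + line                          # s|^|    |
            if m := re.match(r"^    ##\{(.*)$", line):    # ##{title -> <details><summary>title</summary>
                return f"<details><summary>{m.group(1)}</summary>\n"
            if line == "    ##}":                         # ##} -> </details>
                return "\n</details>"
            return re.sub(r"^    ## ?", "", line)         # strip ## from prose, keep code indented

        assert render("##{license, imports") == "<details><summary>license, imports</summary>\n"
        assert render("## # Linear-time analytical queries in plain Python") == "# Linear-time analytical queries in plain Python"
        assert render("import collections.abc") == "    import collections.abc"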
  9. pkhuong revised this gist Jul 25, 2023. 2 changed files with 422 additions and 142 deletions.
    285 changes: 214 additions & 71 deletions yannakakis.md
    @@ -1,20 +1,37 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|

    <details><summary><h2>license, imports</h2></summary>

    # Yannakakis.py by Paul Khuong
    #
    # To the extent possible under law, the person who associated CC0 with
    # Yannakakis.py has waived all copyright and related or neighboring rights
    # to Yannakakis.py.
    #
    # You should have received a copy of the CC0 legalcode along with this
    # work. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.

    import collections.abc

    # Linear-time analytics queries in raw Python

    </details>

    # Linear-time analytical queries in plain Python

    <small>I didn't know [Mihalis Yannakakis is about to turn 70](https://mihalisfest.cs.columbia.edu), but that's a fun coïncidence.</small>

    This short hack shows how [Yannakakis's algorithm](https://pages.cs.wisc.edu/~paris/cs784-f19/lectures/lecture4.pdf)
    lets us implement linear-time (wrt the number of input data rows)
    analytical queries in regular programming languages, without
    offloading joins to a specialised database query language, thus
    avoiding the associated impedance mismatch. There are restrictions
    on the queries we can express -- Yannakakis's algorithm relies on a
    hypertree width of 1, and associated decomposition -- but that's
    kind of reasonable: a fractional hypertree width > 1 would mean tere
    are databases for which the intermediate results could grow much
    larger than the input database (superlinearly), from the [AGM bound](https://arxiv.org/abs/1711.03860).
    Structured programs also naturally yield a hypertree decomposition,
    hypertree width of 1, and on having hypertree decomposition as a witness
    -- but that's kind of reasonable: a fractional hypertree width > 1
    would mean there are databases for which the intermediate results
    could be superlinearly larger than the input database due to the [AGM bound](https://arxiv.org/abs/1711.03860).
    The hypertree decomposition witness isn't a burden either:
    structured programs naturally yield a hypertree decomposition,
    unlike more declarative logic programs that tend to hide the
    structure implicit in the programmer's thinking.

    @@ -26,80 +43,84 @@ positive, and the only negative argument is the current data row.
    We also assume that these functions are always used in a map/reduce
    pattern, and thus we only memoise the result of
    `map_reduce(function, input)`, with a group-structured reduction:
    it must be associative and commutative, and there must be a zero
    (neutral) value.
    the reduce function must be associative and commutative, and there
    must be a zero (neutral) value.

    With these constraints, we can express joins in natural Python
    without incurring the poor runtime scaling of the nested loops we
    actually wrote. This Python file describes the building blocks
    to handle aggregation queries like the following

    >>> id_skus = ((1, 2), (2, 2), (1, 3))
    >>> sku_costs = ((1, 10), (2, 20), (3, 30))
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    ...
    >>> def sum_odd_or_even_skus(mod_two):
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus)
    ...

    with linear scaling in the length of `id_skus` and `sku_costs`, and
    caching for similar queries.

    At small scale, everything's fast.
    At a small scale, everything's fast.

    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    20
    0.00020003318786621094
    0.001024007797241211
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    50
    0.0001990795135498047
    4.982948303222656e-05

    As we increase the scale by 1000x in both tuples, the runtime
    scales roughly linearly for the first query, and is unchanged
    scales (sub) linearly for the first query, and is unchanged
    for the second:

    >>> id_skus = id_skus * 1000
    >>> sku_costs = sku_costs * 1000
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    20000
    0.19550323486328125
    0.09638118743896484
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    50000
    0.00029087066650390625
    8.797645568847656e-05

    This still kind of holds when we multiply by another factor of 10
    (albeit with additional slowdowns due to the memory footprint):
    This still kind of holds when we multiply by another factor of 100:

    >>> id_skus = id_skus * 10
    >>> sku_costs = sku_costs * 10
    >>> id_skus = id_skus * 100
    >>> sku_costs = sku_costs * 100
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    200000
    11.275727272033691
    2000000
    5.9715189933776855
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    500000
    0.0015010833740234375
    5000000
    0.00021195411682128906

    The magic behind the curtains is memoisation (unsurprisingly), but
    a special implementation that can share work for similar closures:
    the memoisation key consists of the function *without closed over
    bindings* and the input, and the memoised value is a data structure
    from the tuple of closed over values to the `map_reduce` output.

    This data structure is where Yannakakis's algorithm comes in: we'll
    iterate over each datum in the input, run the function on it with
    logical variables instead of the closed over values, and generate a
    mapping from closed over values to result for all non-zero results.
    We'll then merge the mappings for all input data together (there is
    no natural ordering here, hence the group structure).
    bindings* and the call arguments, while the memoised value is a
    data structure from the tuple of closed over values to the
    `map_reduce` output.

    This concrete representation of the function as a branching program
    is the core of Yannakakis's algorithm: we'll iterate over each
    datum in the input, run the function on it with logical variables
    instead of the closed over values, and generate a mapping from
    closed over values to result for all non-zero results. We'll then
    merge the mappings for all input data together (there is no natural
    ordering here, hence the group structure).

    The output data structure controls the join we can implement. We
    show how a simple stack of nested key-value mappings handles
    equijoins, but a [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007)
    would handle inequalities (the rest of the machinery already works
    .in terms of less than/greater than constraints).
    would handle inequalities, i.e., "theta" joins (the rest of the
    machinery already works in terms of less than/greater than
    constraints).

    As long as we explore a bounded number of paths for each datum and
    a bounded number of `function, input` cache keys, we'll spend a
    @@ -108,9 +129,9 @@ total. The magic of Yannakakis's algorithm is that this works even
    when there are nested `map_reduce` calls, which would naïvely
    result in polynomial time (degree equal to the nesting depth).



    ## Memoising through a Python function's closed over values
    <details><summary><h2>Memoising through a Python function's closed over values</h2></summary>


    Even if functions were hashable and comparable for extensional
    equality, directly using closures as memoisation keys in calls like
    @@ -129,6 +150,7 @@ complicated cases; this is a hack, after all. In short, it only
    handles closing over immutable atomic values like integers or
    strings, but not, e.g., functions (yet), or mutable bindings.


    def extract_function_state(function):
    """Accepts a function object and returns information about it: a
    hash key for the object, a tuple of closed over values, a function
    @@ -164,12 +186,16 @@ strings, but not, e.g., functions (yet), or mutable bindings.
    code = function.__code__
    names = code.co_freevars

    if function.__closure__ is None: # Toplevel function
    if function.__closure__ is None: # Toplevel function
    assert names == ()

    def rebind(values):
    if len(values) != 0:
    raise RuntimeError(f"Values must be empty for toplevel function. values={values}")
    raise RuntimeError(
    f"Values must be empty for toplevel function. values={values}"
    )
    return function

    return code, (), rebind, names

    closure = tuple(cell.cell_contents for cell in function.__closure__)
    @@ -178,18 +204,26 @@ strings, but not, e.g., functions (yet), or mutable bindings.
    # TODO: rebind recursively (functions are also cells)
    def rebind(values):
    if len(values) != len(names):
    raise RuntimeError(f"Values must match names. names={names} values={values}")
    raise RuntimeError(
    f"Values must match names. names={names} values={values}"
    )
    return function.__class__(
    code,
    function.__globals__,
    function.__name__,
    function.__defaults__,
    tuple(cell.__class__(value)
    for cell, value in zip(function.__closure__, values)))
    tuple(
    cell.__class__(value)
    for cell, value in zip(function.__closure__, values)
    ),
    )

    return code, closure, rebind, names



    </details>

    ## Logical variables for closed-over values

    @@ -228,6 +262,7 @@ search space.
    N.B., the set of constraints we can handle is determined by the
    ground data structure to represent finitely supported functions.


    class OpaqueValue:
    """An opaque value is a one-dimensional range of Python values,
    represented as a lower and an upper bound, each of which is
    @@ -263,7 +298,7 @@ ground data structure to represent finitely supported functions.

    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name):
    self.name = name
    # Upper and lower bounds
    @@ -286,7 +321,12 @@ ground data structure to represent finitely supported functions.

    def definite(self):
    """Returns whether is `OpaqueValue` is constrained to an exact value."""
    return self.lower == self.upper and self.lower is not None and not self.lowerExcl and not self.upperExcl
    return (
    self.lower == self.upper
    and self.lower is not None
    and not self.lowerExcl
    and not self.upperExcl
    )

    def value(self):
    """Returns the exact value for this `OpaqueValue`, assuming there is one."""
    @@ -346,17 +386,22 @@ ground data structure to represent finitely supported functions.
    that value. Otherwise, we ask `CMP_HANDLER` what value to return
    and update the bound accordingly.
    """
    if isinstance(other, OpaqueValue) and \
    self.definite() and not other.definite():
    if isinstance(other, OpaqueValue) and self.definite() and not other.definite():
    # If we have a definite value and `other` is an indefinite
    # `OpaqueValue`, flip the comparison order to let the `other`
    # argument be a ground value.
    return -other.__cmp__(self.value())

    assert not isinstance(other, OpaqueValue) or other.definite()
    if isinstance(other, OpaqueValue) and not other.definite():
    raise RuntimeError(
    f"OpaqueValue may only be compared with ground values. self={self} other={other}"
    )
    if isinstance(other, OpaqueValue):
    other = other.value() # Make sure `other` is a ground value
    assert other is not None # We use `None` internally, and it doesn't compare well

    if other is None:
    # We use `None` internally, and it doesn't compare well
    raise RuntimeError("OpaqueValue may not be compared with None")

    compatible, order = self._contains(other, False)
    if order is not None:
    @@ -404,7 +449,6 @@ ground data structure to represent finitely supported functions.
    def __ge__(self, other):
    return self.__cmp__(other) >= 0



    ## Depth-first exploration of a function call's support

    @@ -429,6 +473,7 @@ with a non-recursive depth-first traversal.
    We then do the same for each datum in our input sequence, and merge
    results for identical keys together.


    def is_zero_result(value):
    """Checks if `value` is a "zero" aggregate value: either `None`,
    or an iterable of all `None`.
    @@ -489,6 +534,7 @@ results for identical keys together.
    value.reset()

    stackIndex = 0

    def handle(value, other):
    nonlocal stackIndex
    if len(explorationStack) == stackIndex:
    @@ -515,7 +561,9 @@ results for identical keys together.
    if not is_zero_result(result):
    keys = []
    for value in values:
    assert value.definite() or value.indefinite(), f"{value} temporarily unsupported"
    assert (
    value.definite() or value.indefinite()
    ), f"{value} temporarily unsupported"
    keys.append(value.value() if value.definite() else None)
    yield (tuple(keys), result)

    @@ -546,9 +594,7 @@ results for identical keys together.
    values = [OpaqueValue(name) for name in names]
    reboundFunction = rebind(values)
    for arg in args:
    yield from enumerate_opaque_values(lambda: reboundFunction(arg),
    values)

    yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)


    ## Type driven merges
    @@ -564,8 +610,10 @@ This file defines a single mergeable value type, `Sum`, but we
    could have different ones, e.g., hyperloglog unique counts, or
    streaming statistical moments.


    class Sum:
    """A counter for summed values."""

    def __init__(self, value=0):
    self.value = value

    @@ -576,6 +624,7 @@ streaming statistical moments.

    class Min:
    """A running `min` value tracker."""

    def __init__(self, value=None):
    self.value = value

    @@ -585,6 +634,7 @@ streaming statistical moments.
    self.value = other.value
    elif other.value:
    self.value = min(self.value, other.value)


    ## Nested dictionary with wildcard

    @@ -603,17 +653,21 @@ everything) or a hash map. At each internal level, the value is
    another `NestedDictLevel`. At the leaf, the value is the
    precomputed value.


    class NestedDictLevel:
    """One level in a nested dictionary index. We may either have a
    value for everything, or a key-value dict.
    """

    def __init__(self):
    self.dict = dict()
    self.wildcard = None

    def get(self, key, default):
    """Gets the value for `key` in this level, or `default` if None."""
    return self.dict.get(key, self.wildcard if self.wildcard is not None else default)
    return self.dict.get(
    key, self.wildcard if self.wildcard is not None else default
    )

    def set(self, key, value):
    """Sets the value for `key` in this level. A `key` of `None`
    @@ -629,6 +683,7 @@ precomputed value.
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value. Each position in the tuple is handled by a `NestedDictLevel`.
    """

    def __init__(self, depth):
    assert depth >= 1
    self.top = NestedDictLevel()
    @@ -658,6 +713,33 @@ precomputed value.
    current.set(levelKey, next)


    <details><summary><h2>Identity key-value maps</h2></summary>

    class IdMap:
    def __init__(self):
    self.entries = dict() # tuple of id -> (key, value)
    # the value's first element keeps the ids stable.

    def get(self, keys, default=None):
    ids = tuple(id(key) for key in keys)
    return self.entries.get(ids, (None, default))[1]

    def __contains__(self, keys):
    ids = tuple(id(key) for key in keys)
    return ids in self.entries

    def __getitem__(self, keys):
    ids = tuple(id(key) for key in keys)
    return self.entries[ids][1]

    def __setitem__(self, keys, value):
    ids = tuple(id(key) for key in keys)
    self.entries[ids] = (keys, value)



    </details>

    ## Cached `map_reduce`

    @@ -671,7 +753,7 @@ supporting values, and merge the result associated with identical
    supporting values.

    Again, we only support ground equality constraints (see assertion
    on L453), i.e., only equijoins. There's nothing that stops a more
    on L558), i.e., only equijoins. There's nothing that stops a more
    sophisticated implementation from using range trees to support
    inequality or range joins.

    @@ -691,6 +773,7 @@ This last `map_reduce` definition ties everything together, and
    I think is really the general heart of Yannakakis's algorithm
    as an instance of bottom-up dynamic programming.


    def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    @@ -709,6 +792,7 @@ as an instance of bottom-up dynamic programming.
    >>> nd.get((4,)).value
    2
    """

    def merge(dst, update):
    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    @@ -727,7 +811,7 @@ as an instance of bottom-up dynamic programming.
    return cache


    AGGREGATE_CACHE = dict() # Map from function, input sequence -> NestedDict
    AGGREGATE_CACHE = IdMap() # Map from function, input sequence -> NestedDict


    def map_reduce(function, inputIterable):
    @@ -766,8 +850,8 @@ as an instance of bottom-up dynamic programming.
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = ((1, 2), (2, 2), (1, 3))
    >>> sku_costs = ((1, 10), (2, 20), (3, 30))
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    >>> def sum_odd_or_even_skus(mod_two):
    @@ -785,23 +869,73 @@ as an instance of bottom-up dynamic programming.
    assert not isinstance(inputIterable, collections.abc.Iterator)
    code, closure, *_ = extract_function_state(function)
    if (code, inputIterable) not in AGGREGATE_CACHE:
    AGGREGATE_CACHE[(code, inputIterable)] = \
    _precompute_map_reduce(function, len(closure), inputIterable)
    return AGGREGATE_CACHE[(code, inputIterable)].get(closure, None)

    AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
    function, len(closure), inputIterable
    )
    return AGGREGATE_CACHE[code, inputIterable].get(closure, None)


    if __name__ == "__main__":
    import doctest

    doctest.testmod()



    ## Is this actually a DB post?

    Although the intro name-dropped Yannakakis, the presentation here
    has a very programming language / logic programming flavour. I
    think the logic programming point of view, where we run a program
    backwards with logical variables, is much clearer than the specific
    case of conjunctive equijoin queries in the usual presentation of
    Yannakakis's algorithm. In particular, I think there's a clear
    path to handle range or comparison joins: it's all about having an
    index data structure to handle range queries.

    It should be clear how to write conjunctive queries as Python
    functions, given a hypertree decomposition. The reverse is much
    more complex, if only because Python is much more powerful than
    just CQ, and that's actually a liability: this hack will blindly
    try to convert any function to a branching program, instead of
    giving up noisily when the function is too complex.

    I guess, in a way, we just showed a trivial way to decorrelate
    queries with a hypertree-width of 1. We just have to be OK with
    building one index for each loop in the nest... but it should be
    possible to pattern match on pre-defined indexes and avoid obvious
    redundancy.

    ## Extensions and future work

    The closure hack (`extract_function_state`) should really be
    extended to cover local functions. This is mostly a question of
    going deeply into values that are mapped to functions, and of
    maintaining an id-keyed map from cell to `OpaqueValue`.
    ### Use a dedicated DSL

    First, the whole idea of introspecting closures to stub in logical
    variables is a terrible hack (looks cool, though ;). A real
    production implementation should apply CPS partial evaluation to a
    purely functional programming language, then bulk reverse-evaluate
    with a SIMD implementation of the logical program.

    There'll be restrictions on the output traces, but that's OK: a
    different prototype makes me believe the restrictions correspond to
    deterministic logspace (L), and it makes sense to restrict our
    analyses to L. Just like grammars are easier to work with when
    restricted to LL(1), DSLs that only capture L tend to be easier to
    analyse and optimise... and L is reasonably large (a
    polynomial-time algorithm that's not in L would be a *huge*
    result).

    ### Handle local functions

    While we sin with the closure hack (`extract_function_state`) it
    should really be extended to cover local functions. This is mostly
    a question of going deeply into values that are mapped to
    functions, and of maintaining an id-keyed map from cell to
    `OpaqueValue`.

    We could also add support for partial application objects, which
    may be easier for multiprocessing.

    ### Parallelism

    There is currently no support for parallelism, only caching. It
    should be easy to handle the return values (`NestedDict`s and
    @@ -822,19 +956,28 @@ than cores.
    That being said, the complexity is probably worth the speed up on
    realistic queries.

    ### Theta joins

    At a higher level, we could support comparison joins (e.g., less
    than, greater than or equal, in range) if only we represented the
    branching programs with a data structure that supported these
    queries. A [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007) would
    let us handle these "theta" joins, for the low, low cost of a
    polylogarithmic multiplicative factor in space and time.

    ### Self-adjusting computation

    Finally, we could update the indexed branching programs
    incrementally after small changes to the input data. This might
    sound like a job for streaming engines like [timely dataflow](https://github.com/timelydataflow/timely-dataflow),
    but I think viewing each `_precompute_map_reduce` call as a purely
    functional map/reduce job gives a better fit with [self-adjusting computation](https://www.umut-acar.org/research#h.x3l3dlvx3g5f).

    I guess, in a way, this code shows how we can simply decorrelate
    nested loops: we just have to be OK with building one index for
    each loop in the nest.
    Once we add logic to recycle previously constructed indexes, it
    will probably make sense to allow an initial filtering step before
    map/reduce, with a cache key on the filter function (with closed
    over values and all). We can often implement the filtering more
    efficiently than we can run functions backward, and we'll also
    observe that slightly different filter functions often result
    in not too dissimilar filtered sets. Factoring out this filtering
    can thus enable more reuse of partial precomputed results.
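    One more sketch (mine) for the `IdMap` this revision introduces: the memoisation key includes the input sequence itself, and now that the examples use plain lists -- which are not hashable -- the cache presumably has to key on object identity, holding a reference to the keys so their ids stay stable.

        cache = IdMap()
        code, rows = object(), [(1, 2), (2, 2), (1, 3)]   # stand-ins for a code object and an input list
        cache[code, rows] = "precomputed NestedDict"
        assert (code, rows) in cache
        assert cache[code, rows] == "precomputed NestedDict"
        assert (code, list(rows)) not in cache            # equal contents, different identity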
    279 changes: 208 additions & 71 deletions yannakakis.py
    @@ -1,20 +1,35 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|

    ##{<h2>license, imports</h2>
    # Yannakakis.py by Paul Khuong
    #
    # To the extent possible under law, the person who associated CC0 with
    # Yannakakis.py has waived all copyright and related or neighboring rights
    # to Yannakakis.py.
    #
    # You should have received a copy of the CC0 legalcode along with this
    # work. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.

    import collections.abc

    ## # Linear-time analytics queries in raw Python
    ##}

    ## # Linear-time analytical queries in plain Python
    ##
    ## <small>I didn't know [Mihalis Yannakakis is about to turn 70](https://mihalisfest.cs.columbia.edu), but that's a fun coïncidence.</small>
    ##
    ## This short hack shows how [Yannakakis's algorithm](https://pages.cs.wisc.edu/~paris/cs784-f19/lectures/lecture4.pdf)
    ## lets us implement linear-time (wrt the number of input data rows)
    ## analytical queries in regular programming languages, without
    ## offloading joins to a specialised database query language, thus
    ## avoiding the associated impedance mismatch. There are restrictions
    ## on the queries we can express -- Yannakakis's algorithm relies on a
    ## hypertree width of 1, and associated decomposition -- but that's
    ## kind of reasonable: a fractional hypertree width > 1 would mean tere
    ## are databases for which the intermediate results could grow much
    ## larger than the input database (superlinearly), from the [AGM bound](https://arxiv.org/abs/1711.03860).
    ## Structured programs also naturally yield a hypertree decomposition,
    ## hypertree width of 1, and on having hypertree decomposition as a witness
    ## -- but that's kind of reasonable: a fractional hypertree width > 1
    ## would mean there are databases for which the intermediate results
    ## could be superlinearly larger than the input database due to the [AGM bound](https://arxiv.org/abs/1711.03860).
    ## The hypertree decomposition witness isn't a burden either:
    ## structured programs naturally yield a hypertree decomposition,
    ## unlike more declarative logic programs that tend to hide the
    ## structure implicit in the programmer's thinking.
    ##
    @@ -26,80 +41,84 @@
    ## We also assume that these functions are always used in a map/reduce
    ## pattern, and thus we only memoise the result of
    ## `map_reduce(function, input)`, with a group-structured reduction:
    ## it must be associative and commutative, and there must be a zero
    ## (neutral) value.
    ## the reduce function must be associative and commutative, and there
    ## must be a zero (neutral) value.
    ##
    ## With these constraints, we can express joins in natural Python
    ## without incurring the poor runtime scaling of the nested loops we
    ## actually wrote. This Python file describes the building blocks
    ## to handle aggregation queries like the following
    ##
    ## >>> id_skus = ((1, 2), (2, 2), (1, 3))
    ## >>> sku_costs = ((1, 10), (2, 20), (3, 30))
    ## >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    ## >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    ## >>> def sku_min_cost(sku):
    ## ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    ## ...
    ## >>> def sum_odd_or_even_skus(mod_two):
    ## ... def count_if_mod_two(id_sku):
    ## ... id, sku = id_sku
    ## ... if id % 2 == mod_two:
    ## ... return Sum(sku_min_cost(sku))
    ## ... return map_reduce(count_if_mod_two, id_skus)
    ## ...
    ##
    ## with linear scaling in the length of `id_skus` and `sku_costs`, and
    ## caching for similar queries.
    ##
    ## At small scale, everything's fast.
    ## At a small scale, everything's fast.
    ##
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 20
    ## 0.00020003318786621094
    ## 0.001024007797241211
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 50
    ## 0.0001990795135498047
    ## 4.982948303222656e-05
    ##
    ## As we increase the scale by 1000x in both tuples, the runtime
    ## scales roughly linearly for the first query, and is unchanged
    ## scales (sub) linearly for the first query, and is unchanged
    ## for the second:
    ##
    ## >>> id_skus = id_skus * 1000
    ## >>> sku_costs = sku_costs * 1000
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 20000
    ## 0.19550323486328125
    ## 0.09638118743896484
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 50000
    ## 0.00029087066650390625
    ## 8.797645568847656e-05
    ##
    ## This still kind of holds when we multiply by another factor of 10
    ## (albeit with additional slowdowns due to the memory footprint):
    ## This still kind of holds when we multiply by another factor of 100:
    ##
    ## >>> id_skus = id_skus * 10
    ## >>> sku_costs = sku_costs * 10
    ## >>> id_skus = id_skus * 100
    ## >>> sku_costs = sku_costs * 100
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 200000
    ## 11.275727272033691
    ## 2000000
    ## 5.9715189933776855
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 500000
    ## 0.0015010833740234375
    ## 5000000
    ## 0.00021195411682128906
    ##
    ## The magic behind the curtains is memoisation (unsurprisingly), but
    ## a special implementation that can share work for similar closures:
    ## the memoisation key consists of the function *without closed over
    ## bindings* and the input, and the memoised value is a data structure
    ## from the tuple of closed over values to the `map_reduce` output.
    ##
    ## This data structure is where Yannakakis's algorithm comes in: we'll
    ## iterate over each datum in the input, run the function on it with
    ## logical variables instead of the closed over values, and generate a
    ## mapping from closed over values to result for all non-zero results.
    ## We'll then merge the mappings for all input data together (there is
    ## no natural ordering here, hence the group structure).
    ## bindings* and the call arguments, while the memoised value is a
    ## data structure from the tuple of closed over values to the
    ## `map_reduce` output.
    ##
    ## This concrete representation of the function as a branching program
    ## is the core of Yannakakis's algorithm: we'll iterate over each
    ## datum in the input, run the function on it with logical variables
    ## instead of the closed over values, and generate a mapping from
    ## closed over values to result for all non-zero results. We'll then
    ## merge the mappings for all input data together (there is no natural
    ## ordering here, hence the group structure).
    ##
    ## The output data structure controls the join we can implement. We
    ## show how a simple stack of nested key-value mappings handles
    ## equijoins, but a [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007)
    ## would handle inequalities (the rest of the machinery already works
    ##.in terms of less than/greater than constraints).
    ## would handle inequalities, i.e., "theta" joins (the rest of the
    ## machinery already works in terms of less than/greater than
    ## constraints).
    ##
    ## As long as we explore a bounded number of paths for each datum and
    ## a bounded number of `function, input` cache keys, we'll spend a
    @@ -108,9 +127,8 @@
    ## when there are nested `map_reduce` calls, which would naïvely
    ## result in polynomial time (degree equal to the nesting depth).



    ## ## Memoising through a Python function's closed over values
    ##{<h2>Memoising through a Python function's closed over values</h2>
    ##
    ## Even if functions were hashable and comparable for extensional
    ## equality, directly using closures as memoisation keys in calls like
    @@ -129,6 +147,7 @@
    ## handles closing over immutable atomic values like integers or
    ## strings, but not, e.g., functions (yet), or mutable bindings.


    def extract_function_state(function):
    """Accepts a function object and returns information about it: a
    hash key for the object, a tuple of closed over values, a function
    @@ -164,12 +183,16 @@ def extract_function_state(function):
    code = function.__code__
    names = code.co_freevars

    if function.__closure__ is None: # Toplevel function
    if function.__closure__ is None: # Toplevel function
    assert names == ()

    def rebind(values):
    if len(values) != 0:
    raise RuntimeError(f"Values must be empty for toplevel function. values={values}")
    raise RuntimeError(
    f"Values must be empty for toplevel function. values={values}"
    )
    return function

    return code, (), rebind, names

    closure = tuple(cell.cell_contents for cell in function.__closure__)
    @@ -178,18 +201,25 @@ def rebind(values):
    # TODO: rebind recursively (functions are also cells)
    def rebind(values):
    if len(values) != len(names):
    raise RuntimeError(f"Values must match names. names={names} values={values}")
    raise RuntimeError(
    f"Values must match names. names={names} values={values}"
    )
    return function.__class__(
    code,
    function.__globals__,
    function.__name__,
    function.__defaults__,
    tuple(cell.__class__(value)
    for cell, value in zip(function.__closure__, values)))
    tuple(
    cell.__class__(value)
    for cell, value in zip(function.__closure__, values)
    ),
    )

    return code, closure, rebind, names


    ##}


    ## ## Logical variables for closed-over values
    ##
    @@ -228,6 +258,7 @@ def rebind(values):
    ## N.B., the set of constraints we can handle is determined by the
    ## ground data structure to represent finitely supported functions.


    class OpaqueValue:
    """An opaque value is a one-dimensional range of Python values,
    represented as a lower and an upper bound, each of which is
    @@ -263,7 +294,7 @@ class OpaqueValue:

    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name):
    self.name = name
    # Upper and lower bounds
    @@ -286,7 +317,12 @@ def indefinite(self):

    def definite(self):
    """Returns whether is `OpaqueValue` is constrained to an exact value."""
    return self.lower == self.upper and self.lower is not None and not self.lowerExcl and not self.upperExcl
    return (
    self.lower == self.upper
    and self.lower is not None
    and not self.lowerExcl
    and not self.upperExcl
    )

    def value(self):
    """Returns the exact value for this `OpaqueValue`, assuming there is one."""
    @@ -346,17 +382,22 @@ def __cmp__(self, other):
    that value. Otherwise, we ask `CMP_HANDLER` what value to return
    and update the bound accordingly.
    """
    if isinstance(other, OpaqueValue) and \
    self.definite() and not other.definite():
    if isinstance(other, OpaqueValue) and self.definite() and not other.definite():
    # If we have a definite value and `other` is an indefinite
    # `OpaqueValue`, flip the comparison order to let the `other`
    # argument be a ground value.
    return -other.__cmp__(self.value())

    assert not isinstance(other, OpaqueValue) or other.definite()
    if isinstance(other, OpaqueValue) and not other.definite():
    raise RuntimeError(
    f"OpaqueValue may only be compared with ground values. self={self} other={other}"
    )
    if isinstance(other, OpaqueValue):
    other = other.value() # Make sure `other` is a ground value
    assert other is not None # We use `None` internally, and it doesn't compare well

    if other is None:
    # We use `None` internally, and it doesn't compare well
    raise RuntimeError("OpaqueValue may not be compared with None")

    compatible, order = self._contains(other, False)
    if order is not None:
    @@ -404,7 +445,6 @@ def __gt__(self, other):
    def __ge__(self, other):
    return self.__cmp__(other) >= 0



    ## ## Depth-first exploration of a function call's support
    ##
    @@ -429,6 +469,7 @@ def __ge__(self, other):
    ## We then do the same for each datum in our input sequence, and merge
    ## results for identical keys together.


    def is_zero_result(value):
    """Checks if `value` is a "zero" aggregate value: either `None`,
    or an iterable of all `None`.
    @@ -489,6 +530,7 @@ def enumerate_opaque_values(function, values):
    value.reset()

    stackIndex = 0

    def handle(value, other):
    nonlocal stackIndex
    if len(explorationStack) == stackIndex:
    @@ -515,7 +557,9 @@ def handle(value, other):
    if not is_zero_result(result):
    keys = []
    for value in values:
    assert value.definite() or value.indefinite(), f"{value} temporarily unsupported"
    assert (
    value.definite() or value.indefinite()
    ), f"{value} temporarily unsupported"
    keys.append(value.value() if value.definite() else None)
    yield (tuple(keys), result)

    @@ -546,9 +590,7 @@ def enumerate_supporting_values(function, args):
    values = [OpaqueValue(name) for name in names]
    reboundFunction = rebind(values)
    for arg in args:
    yield from enumerate_opaque_values(lambda: reboundFunction(arg),
    values)

    yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)


    ## ## Type driven merges
    @@ -564,8 +606,10 @@ def enumerate_supporting_values(function, args):
    ## could have different ones, e.g., hyperloglog unique counts, or
    ## streaming statistical moments.


    class Sum:
    """A counter for summed values."""

    def __init__(self, value=0):
    self.value = value

    @@ -576,6 +620,7 @@ def merge(self, other):

    class Min:
    """A running `min` value tracker."""

    def __init__(self, value=None):
    self.value = value

    @@ -585,6 +630,7 @@ def merge(self, other):
    self.value = other.value
    elif other.value:
    self.value = min(self.value, other.value)


    ## ## Nested dictionary with wildcard
    ##
    @@ -603,17 +649,21 @@ def merge(self, other):
    ## another `NestedDictLevel`. At the leaf, the value is the
    ## precomputed value.


    class NestedDictLevel:
    """One level in a nested dictionary index. We may either have a
    value for everything, or a key-value dict.
    """

    def __init__(self):
    self.dict = dict()
    self.wildcard = None

    def get(self, key, default):
    """Gets the value for `key` in this level, or `default` if None."""
    return self.dict.get(key, self.wildcard if self.wildcard is not None else default)
    return self.dict.get(
    key, self.wildcard if self.wildcard is not None else default
    )

    def set(self, key, value):
    """Sets the value for `key` in this level. A `key` of `None`
    @@ -629,6 +679,7 @@ class NestedDict:
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value. Each position in the tuple is handled by a `NestedDictLevel`.
    """

    def __init__(self, depth):
    assert depth >= 1
    self.top = NestedDictLevel()
    @@ -658,6 +709,31 @@ def set(self, key, value):
    current.set(levelKey, next)


    ##{<h2>Identity key-value maps</h2>
    class IdMap:
    def __init__(self):
    self.entries = dict() # tuple of id -> (key, value)
    # the value's first element keeps the ids stable.

    def get(self, keys, default=None):
    ids = tuple(id(key) for key in keys)
    return self.entries.get(ids, (None, default))[1]

    def __contains__(self, keys):
    ids = tuple(id(key) for key in keys)
    return ids in self.entries

    def __getitem__(self, keys):
    ids = tuple(id(key) for key in keys)
    return self.entries[ids][1]

    def __setitem__(self, keys, value):
    ids = tuple(id(key) for key in keys)
    self.entries[ids] = (keys, value)


    ##}


    ## ## Cached `map_reduce`
    ##
    @@ -671,7 +747,7 @@ def set(self, key, value):
    ## supporting values.
    ##
    ## Again, we only support ground equality constraints (see assertion
    ## on L453), i.e., only equijoins. There's nothing that stops a more
    ## on L558), i.e., only equijoins. There's nothing that stops a more
    ## sophisticated implementation from using range trees to support
    ## inequality or range joins.
    ##
    @@ -691,6 +767,7 @@ def set(self, key, value):
    ## I think is really the general heart of Yannakakis's algorithm
    ## as an instance of bottom-up dynamic programming.


    def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    @@ -709,6 +786,7 @@ def _precompute_map_reduce(function, depth, inputIterable):
    >>> nd.get((4,)).value
    2
    """

    def merge(dst, update):
    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    @@ -727,7 +805,7 @@ def merge(dst, update):
    return cache


    AGGREGATE_CACHE = dict() # Map from function, input sequence -> NestedDict
    AGGREGATE_CACHE = IdMap() # Map from function, input sequence -> NestedDict


    def map_reduce(function, inputIterable):
    @@ -766,8 +844,8 @@ def map_reduce(function, inputIterable):
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = ((1, 2), (2, 2), (1, 3))
    >>> sku_costs = ((1, 10), (2, 20), (3, 30))
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    >>> def sum_odd_or_even_skus(mod_two):
    @@ -785,23 +863,73 @@ def map_reduce(function, inputIterable):
    assert not isinstance(inputIterable, collections.abc.Iterator)
    code, closure, *_ = extract_function_state(function)
    if (code, inputIterable) not in AGGREGATE_CACHE:
    AGGREGATE_CACHE[(code, inputIterable)] = \
    _precompute_map_reduce(function, len(closure), inputIterable)
    return AGGREGATE_CACHE[(code, inputIterable)].get(closure, None)

    AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
    function, len(closure), inputIterable
    )
    return AGGREGATE_CACHE[code, inputIterable].get(closure, None)


    if __name__ == "__main__":
    import doctest

    doctest.testmod()



    ## ## Is this actually a DB post?
    ##
    ## Although the intro name-dropped Yannakakis, the presentation here
    ## has a very programming language / logic programming flavour. I
    ## think the logic programming point of view, where we run a program
    ## backwards with logical variables, is much clearer than the specific
    ## case of conjunctive equijoin queries in the usual presentation of
    ## Yannakakis's algorithm. In particular, I think there's a clear
    ## path to handle range or comparison joins: it's all about having an
    ## index data structure to handle range queries.
    ##
    ## It should be clear how to write conjunctive queries as Python
    ## functions, given a hypertree decomposition. The reverse is much
    ## more complex, if only because Python is much more powerful than
    ## just CQ, and that's actually a liability: this hack will blindly
    ## try to convert any function to a branching program, instead of
    ## giving up noisily when the function is too complex.
    ##
    ## I guess, in a way, we just showed a trivial way to decorrelate
    ## queries with a hypertree-width of 1. We just have to be OK with
    ## building one index for each loop in the nest... but it should be
    ## possible to pattern match on pre-defined indexes and avoid obvious
    ## redundancy.
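##
## As a hedged sketch (not in the original file) of that correspondence,
## take the acyclic conjunctive query "for a given `a`, count the pairs
## R(a, b), S(b, c) that join on b".  Each bag of its (trivial) hypertree
## decomposition becomes one `map_reduce` loop, and the selected or shared
## attributes become closed-over atomic values; `R_DEMO` and `S_DEMO` are
## hypothetical example relations, represented as lists of pairs.


R_DEMO = [(1, 10), (2, 10), (1, 20)]
S_DEMO = [(10, "x"), (10, "y"), (20, "z")]


def count_matches_in_s(b):
    # Closes over the join key `b` only, so the branching program built
    # for `S_DEMO` is shared by every row of `R_DEMO`.
    return map_reduce(lambda bc: Sum(1) if bc[0] == b else None, S_DEMO)


def count_joined_rows(a):
    def per_row(ab):
        if ab[0] != a:
            return None
        inner = count_matches_in_s(ab[1])
        return Sum(inner.value) if inner is not None else None

    return map_reduce(per_row, R_DEMO)


## With the data above, `count_joined_rows(1)` would return `Sum(3)`:
## (1, 10) matches two rows of `S_DEMO` and (1, 20) matches one, all in
## time linear in the sizes of the two relations.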
    ##
    ## ## Extensions and future work
    ##
    ## The closure hack (`extract_function_state`) should really be
    ## extended to cover local functions. This is mostly a question of
    ## going deeply into values that are mapped to functions, and of
    ## maintaining an id-keyed map from cell to `OpaqueValue`.
    ## ### Use a dedicated DSL
    ##
    ## First, the whole idea of introspecting closures to stub in logical
    ## variable is a terrible hack (looks cool though ;). A real
    ## production implementation should apply CPS partial evaluation to a
    ## purely functional programming language, then bulk reverse-evaluate
    ## with a SIMD implementation of the logical program.
    ##
    ## There'll be restrictions on the output traces, but that's OK: a
    ## different prototype makes me believe the restrictions correspond to
    ## deterministic logspace (L), and it makes sense to restrict our
    ## analyses to L. Just like grammars are easier to work with when
    ## restricted to LL(1), DSLs that only capture L tend to be easier to
## analyse and optimise... and L is reasonably large (a
    ## polynomial-time algorithm that's not in L would be a *huge*
    ## result).
    ##
    ## ### Handle local functions
    ##
## While the closure hack (`extract_function_state`) is a sin, it
    ## should really be extended to cover local functions. This is mostly
    ## a question of going deeply into values that are mapped to
    ## functions, and of maintaining an id-keyed map from cell to
    ## `OpaqueValue`.
    ##
    ## We could also add support for partial application objects, which
    ## may be easier for multiprocessing.
    ##
    ## ### Parallelism
    ##
    ## There is currently no support for parallelism, only caching. It
    ## should be easy to handle the return values (`NestedDict`s and
    @@ -822,19 +950,28 @@ def map_reduce(function, inputIterable):
    ## That being said, the complexity is probably worth the speed up on
    ## realistic queries.
    ##
    ## ### Theta joins
    ##
    ## At a higher level, we could support comparison joins (e.g., less
    ## than, greater than or equal, in range) if only we represented the
    ## branching programs with a data structure that supported these
    ## queries. A [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007) would
## let us handle these "theta" joins, for the low low cost of a
    ## polylogarithmic multiplicative factor in space and time.
    ##
    ## ### Self-adjusting computation
    ##
    ## Finally, we could update the indexed branching programs
    ## incrementally after small changes to the input data. This might
    ## sound like a job for streaming engines like [timely dataflow](https://github.com/timelydataflow/timely-dataflow),
    ## but I think viewing each `_precompute_map_reduce` call as a purely
    ## functional map/reduce job gives a better fit with [self-adjusting computation](https://www.umut-acar.org/research#h.x3l3dlvx3g5f).
    ##
    ## I guess, in a way, this code shows how we can simply decorrelate
    ## nested loops: we just have to be OK with building one index for
    ## each loop in the nest.
    ## Once we add logic to recycle previously constructed indexes, it
    ## will probably make sense to allow an initial filtering step before
    ## map/reduce, with a cache key on the filter function (with closed
    ## over values and all). We can often implement the filtering more
    ## efficiently than we can run functions backward, and we'll also
    ## observe that slightly different filter functions often result
    ## in not too dissimilar filtered sets. Factoring out this filtering
    ## can thus enable more reuse of partial precomputed results.
  10. pkhuong revised this gist Jul 24, 2023. 2 changed files with 15 additions and 23 deletions.
    33 changes: 11 additions & 22 deletions yannakakis.md
    @@ -1,4 +1,4 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|^\x20{4}\x0c|\n<br>\n|
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|

    import collections.abc

    @@ -49,6 +49,7 @@ with linear scaling in the length of `id_skus` and `sku_costs`, and
    caching for similar queries.

    At small scale, everything's fast.

    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    20
    0.00020003318786621094
    @@ -59,6 +60,7 @@ At small scale, everything's fast.
    As we increase the scale by 1000x in both tuples, the runtime
    scales roughly linearly for the first query, and is unchanged
    for the second:

    >>> id_skus = id_skus * 1000
    >>> sku_costs = sku_costs * 1000
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    @@ -70,6 +72,7 @@ for the second:

    This still kind of holds when we multiply by another factor of 10
    (albeit with additional slowdowns due to the memory footprint):

    >>> id_skus = id_skus * 10
    >>> sku_costs = sku_costs * 10
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    @@ -106,9 +109,7 @@ when there are nested `map_reduce` calls, which would naïvely
    result in polynomial time (degree equal to the nesting depth).



    <br>


    ## Memoising through a Python function's closed over values

    Even if functions were hashable and comparable for extensional
    @@ -189,9 +190,7 @@ strings, but not, e.g., functions (yet), or mutable bindings.
    return code, closure, rebind, names



    <br>


    ## Logical variables for closed-over values

    We wish to enumerate the support of a function call (parameterised
    @@ -406,9 +405,7 @@ ground data structure to represent finitely supported functions.
    return self.__cmp__(other) >= 0



    <br>


    ## Depth-first exploration of a function call's support

    We assume a `None` result represents a zero value wrt the aggregate
    @@ -553,9 +550,7 @@ results for identical keys together.
    values)



    <br>


    ## Type driven merges

    The interesting part of map/reduce is the reduction step. While
    @@ -590,9 +585,7 @@ streaming statistical moments.
    self.value = other.value
    elif other.value:
    self.value = min(self.value, other.value)

    <br>


    ## Nested dictionary with wildcard

    There's a direct relationship between the data structure we use to
    @@ -665,9 +658,7 @@ precomputed value.
    current.set(levelKey, next)



    <br>


    ## Cached `map_reduce`

    As mentioned earlier, we assume `reduce` is determined implicitly
    @@ -804,9 +795,7 @@ as an instance of bottom-up dynamic programming.
    doctest.testmod()



    <br>


    ## Extensions and future work

    The closure hack (`extract_function_state`) should really be
    5 changes: 4 additions & 1 deletion yannakakis.py
    @@ -1,4 +1,4 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|^\x20{4}\x0c|\n<br>\n|
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|

    import collections.abc

    @@ -49,6 +49,7 @@
    ## caching for similar queries.
    ##
    ## At small scale, everything's fast.
    ##
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 20
    ## 0.00020003318786621094
    @@ -59,6 +60,7 @@
    ## As we increase the scale by 1000x in both tuples, the runtime
    ## scales roughly linearly for the first query, and is unchanged
    ## for the second:
    ##
    ## >>> id_skus = id_skus * 1000
    ## >>> sku_costs = sku_costs * 1000
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    @@ -70,6 +72,7 @@
    ##
    ## This still kind of holds when we multiply by another factor of 10
    ## (albeit with additional slowdowns due to the memory footprint):
    ##
    ## >>> id_skus = id_skus * 10
    ## >>> sku_costs = sku_costs * 10
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
  11. pkhuong revised this gist Jul 24, 2023. 2 changed files with 107 additions and 3 deletions.
    54 changes: 53 additions & 1 deletion yannakakis.md
    @@ -29,7 +29,59 @@ pattern, and thus we only memoise the result of
    it must be associative and commutative, and there must be a zero
    (neutral) value.

    The memoisation key consists of the function *without closed over
    With these constraints, we can express joins in natural Python
    without incurring the poor runtime scaling of the nested loops we
    actually wrote. This Python file describes the building blocks
    to handle aggregation queries like the following

    >>> id_skus = ((1, 2), (2, 2), (1, 3))
    >>> sku_costs = ((1, 10), (2, 20), (3, 30))
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    >>> def sum_odd_or_even_skus(mod_two):
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus)

    with linear scaling in the length of `id_skus` and `sku_costs`, and
    caching for similar queries.

    At small scale, everything's fast.
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    20
    0.00020003318786621094
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    50
    0.0001990795135498047

    As we increase the scale by 1000x in both tuples, the runtime
    scales roughly linearly for the first query, and is unchanged
    for the second:
    >>> id_skus = id_skus * 1000
    >>> sku_costs = sku_costs * 1000
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    20000
    0.19550323486328125
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    50000
    0.00029087066650390625

    This still kind of holds when we multiply by another factor of 10
    (albeit with additional slowdowns due to the memory footprint):
    >>> id_skus = id_skus * 10
    >>> sku_costs = sku_costs * 10
    >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    200000
    11.275727272033691
    >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    500000
    0.0015010833740234375

    The magic behind the curtains is memoisation (unsurprisingly), but
    a special implementation that can share work for similar closures:
    the memoisation key consists of the function *without closed over
    bindings* and the input, and the memoised value is a data structure
    from the tuple of closed over values to the `map_reduce` output.

    56 changes: 54 additions & 2 deletions yannakakis.py
    @@ -28,8 +28,60 @@
    ## `map_reduce(function, input)`, with a group-structured reduction:
    ## it must be associative and commutative, and there must be a zero
    ## (neutral) value.
    ##
    ## The memoisation key consists of the function *without closed over
    ##
    ## With these constraints, we can express joins in natural Python
    ## without incurring the poor runtime scaling of the nested loops we
    ## actually wrote. This Python file describes the building blocks
    ## to handle aggregation queries like the following
    ##
    ## >>> id_skus = ((1, 2), (2, 2), (1, 3))
    ## >>> sku_costs = ((1, 10), (2, 20), (3, 30))
    ## >>> def sku_min_cost(sku):
    ## ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    ## >>> def sum_odd_or_even_skus(mod_two):
    ## ... def count_if_mod_two(id_sku):
    ## ... id, sku = id_sku
    ## ... if id % 2 == mod_two:
    ## ... return Sum(sku_min_cost(sku))
    ## ... return map_reduce(count_if_mod_two, id_skus)
    ##
    ## with linear scaling in the length of `id_skus` and `sku_costs`, and
    ## caching for similar queries.
    ##
    ## At small scale, everything's fast.
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 20
    ## 0.00020003318786621094
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 50
    ## 0.0001990795135498047
    ##
    ## As we increase the scale by 1000x in both tuples, the runtime
    ## scales roughly linearly for the first query, and is unchanged
    ## for the second:
    ## >>> id_skus = id_skus * 1000
    ## >>> sku_costs = sku_costs * 1000
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 20000
    ## 0.19550323486328125
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 50000
    ## 0.00029087066650390625
    ##
    ## This still kind of holds when we multiply by another factor of 10
    ## (albeit with additional slowdowns due to the memory footprint):
    ## >>> id_skus = id_skus * 10
    ## >>> sku_costs = sku_costs * 10
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
    ## 200000
    ## 11.275727272033691
    ## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
    ## 500000
    ## 0.0015010833740234375
    ##
    ## The magic behind the curtains is memoisation (unsurprisingly), but
    ## a special implementation that can share work for similar closures:
    ## the memoisation key consists of the function *without closed over
    ## bindings* and the input, and the memoised value is a data structure
    ## from the tuple of closed over values to the `map_reduce` output.
    ##
  12. pkhuong revised this gist Jul 24, 2023. 2 changed files with 23 additions and 9 deletions.
    30 changes: 22 additions & 8 deletions yannakakis.md
    @@ -1,4 +1,4 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|<br>|
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|^\x20{4}\x0c|\n<br>\n|

    import collections.abc

    @@ -54,7 +54,9 @@ when there are nested `map_reduce` calls, which would naïvely
    result in polynomial time (degree equal to the nesting depth).


    <br>

    <br>

    ## Memoising through a Python function's closed over values

    Even if functions were hashable and comparable for extensional
    @@ -135,7 +137,9 @@ strings, but not, e.g., functions (yet), or mutable bindings.
    return code, closure, rebind, names


    <br>

    <br>

    ## Logical variables for closed-over values

    We wish to enumerate the support of a function call (parameterised
    @@ -350,7 +354,9 @@ ground data structure to represent finitely supported functions.
    return self.__cmp__(other) >= 0


    <br>

    <br>

    ## Depth-first exploration of a function call's support

    We assume a `None` result represents a zero value wrt the aggregate
    @@ -495,7 +501,9 @@ results for identical keys together.
    values)


    <br>

    <br>

    ## Type driven merges

    The interesting part of map/reduce is the reduction step. While
    @@ -530,7 +538,9 @@ streaming statistical moments.
    self.value = other.value
    elif other.value:
    self.value = min(self.value, other.value)
    <br>

    <br>

    ## Nested dictionary with wildcard

    There's a direct relationship between the data structure we use to
    @@ -603,7 +613,9 @@ precomputed value.
    current.set(levelKey, next)


    <br>

    <br>

    ## Cached `map_reduce`

    As mentioned earlier, we assume `reduce` is determined implicitly
    @@ -740,7 +752,9 @@ as an instance of bottom-up dynamic programming.
    doctest.testmod()


    <br>

    <br>

    ## Extensions and future work

    The closure hack (`extract_function_state`) should really be
    2 changes: 1 addition & 1 deletion yannakakis.py
    @@ -1,4 +1,4 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|<br>|
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|^\x20{4}\x0c|\n<br>\n|

    import collections.abc

  13. pkhuong created this gist Jul 24, 2023.
    785 changes: 785 additions & 0 deletions yannakakis.md
    @@ -0,0 +1,785 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|<br>|

    import collections.abc

    # Linear-time analytics queries in raw Python

    This short hack shows how [Yannakakis's algorithm](https://pages.cs.wisc.edu/~paris/cs784-f19/lectures/lecture4.pdf)
    lets us implement linear-time (wrt the number of input data rows)
    analytical queries in regular programming languages, without
    offloading joins to a specialised database query language, thus
    avoiding the associated impedance mismatch. There are restrictions
    on the queries we can express -- Yannakakis's algorithm relies on a
    hypertree width of 1, and associated decomposition -- but that's
kind of reasonable: a fractional hypertree width > 1 would mean there
    are databases for which the intermediate results could grow much
    larger than the input database (superlinearly), from the [AGM bound](https://arxiv.org/abs/1711.03860).
    Structured programs also naturally yield a hypertree decomposition,
    unlike more declarative logic programs that tend to hide the
    structure implicit in the programmer's thinking.

    The key is to mark function arguments as either negative (true
    inputs) or positive (possible interesting input values derived from
    negative arguments). In this hack, closed over values are
    positive, and the only negative argument is the current data row.

    We also assume that these functions are always used in a map/reduce
    pattern, and thus we only memoise the result of
    `map_reduce(function, input)`, with a group-structured reduction:
    it must be associative and commutative, and there must be a zero
    (neutral) value.
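
For instance, here is a minimal sketch using the `Sum` reducer and the
`map_reduce` wrapper defined later in this file (it is essentially the
`count_eql` helper used in the doctests further down): counting the rows
equal to a closed-over `needle` is just a closure handed to `map_reduce`.

    def count_eql(needle):
        # `needle` is positive (closed over); each row is the negative input.
        return lambda row: Sum(1) if row == needle else None

    # `map_reduce(count_eql(4), data)` merges the per-row `Sum`s, and
    # returns `None` (the neutral value) when no row matches.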

    The memoisation key consists of the function *without closed over
    bindings* and the input, and the memoised value is a data structure
    from the tuple of closed over values to the `map_reduce` output.

    This data structure is where Yannakakis's algorithm comes in: we'll
    iterate over each datum in the input, run the function on it with
    logical variables instead of the closed over values, and generate a
    mapping from closed over values to result for all non-zero results.
    We'll then merge the mappings for all input data together (there is
    no natural ordering here, hence the group structure).

    The output data structure controls the join we can implement. We
    show how a simple stack of nested key-value mappings handles
    equijoins, but a [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007)
    would handle inequalities (the rest of the machinery already works
in terms of less than/greater than constraints).

    As long as we explore a bounded number of paths for each datum and
    a bounded number of `function, input` cache keys, we'll spend a
    bounded amount of time on each input datum, and thus linear time
    total. The magic of Yannakakis's algorithm is that this works even
    when there are nested `map_reduce` calls, which would naïvely
    result in polynomial time (degree equal to the nesting depth).


    <br>
    ## Memoising through a Python function's closed over values

    Even if functions were hashable and comparable for extensional
    equality, directly using closures as memoisation keys in calls like
    `map_reduce(function, input)` would result in superlinear runtime
    for nested `map_reduce` calls.

    This `extract_function_state` accepts a function (with or without
    closed over state), and returns four values:
    1. The underlying code object
    2. The tuple of closed over values (current value for mutable cells)
    3. A function to rebind the closure with new closed over values
    4. The name of the closed over bindings

    The third return value, the `rebind` function, silently fails on
    complicated cases; this is a hack, after all. In short, it only
    handles closing over immutable atomic values like integers or
    strings, but not, e.g., functions (yet), or mutable bindings.

    def extract_function_state(function):
    """Accepts a function object and returns information about it: a
    hash key for the object, a tuple of closed over values, a function
    to return a fresh closure with a different tuple of closed over
values, and a tuple of closed over names.

    >>> def test(x): return lambda y: x == y
    >>> extract_function_state(test)[1]
    ()
    >>> extract_function_state(test)[3]
    ()
    >>> fun = test(4)
    >>> extract_function_state(fun)[1]
    (4,)
    >>> extract_function_state(fun)[3]
    ('x',)
    >>>
    >>> fun(4)
    True
    >>> fun(5)
    False
    >>> rebind = extract_function_state(fun)[2]
    >>> rebound_4, rebound_5 = rebind([4]), rebind([5])
    >>> rebound_4(4)
    True
    >>> rebound_4(5)
    False
    >>> rebound_5(4)
    False
    >>> rebound_5(5)
    True
    """
    code = function.__code__
    names = code.co_freevars

    if function.__closure__ is None: # Toplevel function
    assert names == ()
    def rebind(values):
    if len(values) != 0:
    raise RuntimeError(f"Values must be empty for toplevel function. values={values}")
    return function
    return code, (), rebind, names

    closure = tuple(cell.cell_contents for cell in function.__closure__)
assert len(names) == len(closure), (closure, names)

    # TODO: rebind recursively (functions are also cells)
    def rebind(values):
    if len(values) != len(names):
    raise RuntimeError(f"Values must match names. names={names} values={values}")
    return function.__class__(
    code,
    function.__globals__,
    function.__name__,
    function.__defaults__,
    tuple(cell.__class__(value)
    for cell, value in zip(function.__closure__, values)))

    return code, closure, rebind, names


    <br>
    ## Logical variables for closed-over values

    We wish to enumerate the support of a function call (parameterised
    over closed over values), and the associated result value. We'll
    do that by rebinding the closure to point at instances of
    `OpaqueValue` and enumerating all the possible constraints on these
    `OpaqueValue`s. These `OpaqueValue`s work like logical variables
that let us run a function in reverse: when we get a non-zero
    (non-None) return value, we look at the accumulated constraint set
    on the opaque values and use them to update the data representation
    of the function's result (we assume that we can represent the
    constraints on all `OpaqueValue`s in our result data structure).

    Currently, we only support nested dictionaries, so each
`OpaqueValue` must be either fully unconstrained (wildcard that
    matches any value), or constrained to be exactly equal to a value.
    There's no reason we can't use k-d trees though, and it's not
    harder to track a pair of bounds (lower and upper) than a set of
    inequalities, so we'll handle the general ordered `OpaqueValue`
    case.

    In the input program (the query), we assume closed over values are
    only used for comparisons (equality, inequality, relational
    operators, or conversion to bool, i.e., non-zero testing). Knowing
    the result of each (non-redundant) comparison tightens the range
    of potential values for the `OpaqueValue`... eventually down to a
    single point value that our hash-based indexes can handle.

    Of course, if a comparison isn't redundant, there are multiple
    feasible results, so we need an external oracle to pick one. An
    external caller is responsible for injecting its logic as
    `OpaqueValue.CMP_HANDLER`, and driving the exploration of the
    search space.

    N.B., the set of constraints we can handle is determined by the
    ground data structure to represent finitely supported functions.
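
As a small sketch (not one of the original doctests) of how that oracle
plugs in: forcing every unresolved comparison on the `OpaqueValue` class
defined just below to answer "greater than" tightens the lower bound.

    x = OpaqueValue("x")
    OpaqueValue.CMP_HANDLER = lambda opaque, value: 1  # oracle: always answer "greater"
    x > 3            # True: the oracle picked the "greater" branch...
    x.contains(2)    # False: ...so x's range tightened to (3, +infinity)
    OpaqueValue.CMP_HANDLER = lambda opaque, value: 0  # restore the default oracle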

    class OpaqueValue:
    """An opaque value is a one-dimensional range of Python values,
    represented as a lower and an upper bound, each of which is
    potentially exclusive.

    `OpaqueValue`s are only used in queries for comparisons with
    ground values. All comparisons are turned into three-way
    `__cmp__` calls; non-redundant `__cmp__` calls (which could return
    more than one value) are resolved by calling `CMP_HANDLER`
    and tightening the bound in response.

    >>> x = OpaqueValue("x")
    >>> x == True
    True
    >>> 1 if x else 2
    1
    >>> x.reset()
    >>> x > 4
    False
    >>> x < 4
    False
    >>> x == 4
    True
    >>> x < 10
    True
    >>> x >= 10
    False
    >>> x > -10
    True
    >>> x <= -10
    False
    """

    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name):
    self.name = name
    # Upper and lower bounds
    self.lower = self.upper = None
    self.lowerExcl = self.upperExcl = True

    def __str__(self):
    left = "(" if self.lowerExcl else "["
    right = ")" if self.upperExcl else "]"
    return f"<OpaqueValue {self.name} {left}{self.lower}, {self.upper}{right}>"

    def reset(self):
    """Clears all accumulated constraints on this `OpaqueValue`."""
    self.lower = self.upper = None
    self.lowerExcl = self.upperExcl = True

    def indefinite(self):
    """Returns whether this `OpaqueValue` is still unconstrained."""
    return self.lower == None and self.upper == None

def definite(self):
"""Returns whether this `OpaqueValue` is constrained to an exact value."""
    return self.lower == self.upper and self.lower is not None and not self.lowerExcl and not self.upperExcl

    def value(self):
    """Returns the exact value for this `OpaqueValue`, assuming there is one."""
    return self.lower

    def _contains(self, value, strictly):
    """Returns whether this `OpaqueValue`'s range includes `value`
    (maybe `strictly` inside the range).

    The first value is the containment truth value, and the second
    is the forced `__cmp__` value, if the `value` is *not*
    (strictly) contained in the range.
    """
    if self.lower is not None:
    if self.lower > value:
    return False, 1
    if self.lower == value and (strictly or self.lowerExcl):
    return False, 1
    if self.upper is not None:
    if self.upper < value:
    return False, -1
    if self.upper == value and (strictly or self.upperExcl):
    return False, -1
    return True, 0 if self.definite() else None

def contains(self, value, strictly=False):
"""Returns whether `value` is `strictly` contained in the
    `OpaqueValue`'s range.
    """
    try:
    return self._contains(value, strictly)[0]
    except TypeError:
    return False

    def potential_mask(self, other):
    """Returns the set of potential `__cmp__` values for `other`
    that are compatible with the current range: bit 0 is set if
    `-1` is possible, bit 1 if `0` is possible, and bit 2 if `1`
    is possible.
    """
    if not self.contains(other):
    return 0
    if self.contains(other, strictly=True):
    return 7
if self.definite() and self.value() == other:
    return 2
    # We have a non-strict inclusion, and inequality.
    if self.lower == other and self.lowerExcl:
    return 6
    assert self.upper == other and self.upperExcl
    return 3

    def __cmp__(self, other):
    """Three-way comparison between this `OpaqueValue` and `other`.

    When the result is known from the current bound, we just return
    that value. Otherwise, we ask `CMP_HANDLER` what value to return
    and update the bound accordingly.
    """
    if isinstance(other, OpaqueValue) and \
    self.definite() and not other.definite():
    # If we have a definite value and `other` is an indefinite
    # `OpaqueValue`, flip the comparison order to let the `other`
    # argument be a ground value.
    return -other.__cmp__(self.value())

    assert not isinstance(other, OpaqueValue) or other.definite()
    if isinstance(other, OpaqueValue):
    other = other.value() # Make sure `other` is a ground value
    assert other is not None # We use `None` internally, and it doesn't compare well

    compatible, order = self._contains(other, False)
    if order is not None:
    return order
    order = OpaqueValue.CMP_HANDLER(self, other)
    if order < 0:
    self._add_bound(upper=other, upperExcl=True)
    elif order == 0:
    self._add_bound(lower=other, lowerExcl=False, upper=other, upperExcl=False)
    else:
    self._add_bound(lower=other, lowerExcl=True)
    return order

    def _add_bound(self, lower=None, lowerExcl=False, upper=None, upperExcl=False):
    """Updates the internal range for this new bound."""
    assert lower is None or self.contains(lower, strictly=lowerExcl)
    assert upper is None or self.contains(upper, strictly=upperExcl)
    if lower is not None:
    self.lower = lower
    self.lowerExcl = lowerExcl

    assert upper is None or self.contains(upper, strictly=upperExcl)
    if upper is not None:
    self.upper = upper
    self.upperExcl = upperExcl

    def __bool__(self):
    return self != 0

    def __eq__(self, other):
    return self.__cmp__(other) == 0

    def __ne__(self, other):
    return self.__cmp__(other) != 0

    def __lt__(self, other):
    return self.__cmp__(other) < 0

    def __le__(self, other):
    return self.__cmp__(other) <= 0

    def __gt__(self, other):
    return self.__cmp__(other) > 0

    def __ge__(self, other):
    return self.__cmp__(other) >= 0


    <br>
    ## Depth-first exploration of a function call's support

    We assume a `None` result represents a zero value wrt the aggregate
    merging function (e.g., 0 for a sum). For convenience, we also treat
    tuples and lists of `None`s identically.

    We simply maintain a stack of `CMP_HANDLER` calls, where each entry
in the stack consists of an `OpaqueValue` and a bitset of `CMP_HANDLER`
    results still to explore (-1, 0, or 1). This stack is filled on demand,
    and `CMP_HANDLER` returns the first result allowed by the bitset.

    Once we have a result, we tweak the stack to force depth-first
    exploration of a different part of the solution space: we drop
    the first bit in the bitset of results to explore, and drop the
    entry wholesale if the bitset is now empty (all zero). When
    this tweaking leaves an empty stack, we're done.

    This ends up enumerating all the paths through the function call
    with a non-recursive depth-first traversal.

    We then do the same for each datum in our input sequence, and merge
    results for identical keys together.
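
To make the bitset encoding concrete, here is a tiny helper (a sketch,
not used by the code below) that decodes the masks produced by
`OpaqueValue.potential_mask`: bit 0 stands for a `-1` outcome, bit 1 for
`0`, and bit 2 for `1`.

    def decode_mask(mask):
        """Lists the `__cmp__` outcomes a bitmask still allows."""
        return [order for bit, order in ((1, -1), (2, 0), (4, 1)) if mask & bit]

    # decode_mask(7) == [-1, 0, 1]   # unconstrained: every branch remains to explore
    # decode_mask(2) == [0]          # only the equality branch is left to try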

    def is_zero_result(value):
    """Checks if `value` is a "zero" aggregate value: either `None`,
    or an iterable of all `None`.

    >>> is_zero_result(None)
    True
    >>> is_zero_result(False)
    False
    >>> is_zero_result(True)
    False
    >>> is_zero_result(0)
    False
    >>> is_zero_result(-10)
    False
    >>> is_zero_result(1.5)
    False
    >>> is_zero_result("")
    False
    >>> is_zero_result("asd")
    False
    >>> is_zero_result((None, None))
    True
    >>> is_zero_result((None, 1))
    False
    >>> is_zero_result([])
    True
    >>> is_zero_result([None])
    True
    >>> is_zero_result([None, (None, None)])
    False
    """
    if value is None:
    return True
    if isinstance(value, (tuple, list)):
    return all(item is None for item in value)
    return False


    def enumerate_opaque_values(function, values):
    """Explores the set of `OpaqueValue` constraints when calling
    `function`.

    Enumerates all constraints for the `OpaqueValue` instances in `values`,
and yields, for each non-zero result, a pair of the tuple of ground
values for `values` and the corresponding return value.

    This essentially turns `function()` into a branching program on
    `values`.

    >>> x, y = OpaqueValue("x"), OpaqueValue("y")
    >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
    [((0, None), 1), ((1, 2), 2)]

    """
    explorationStack = [] # List of (value, bitmaskOfCmp)
    while True:
    for value in values:
    value.reset()

    stackIndex = 0
    def handle(value, other):
    nonlocal stackIndex
    if len(explorationStack) == stackIndex:
    explorationStack.append((value, value.potential_mask(other)))

    expectedValue, mask = explorationStack[stackIndex]
    assert value is expectedValue
    assert mask != 0

    if (mask & 1) != 0:
    ret = -1
    elif (mask & 2) != 0:
    ret = 0
    elif (mask & 4) != 0:
    ret = 1
    else:
    assert False, f"bad mask {mask}"

    stackIndex += 1
    return ret

    OpaqueValue.CMP_HANDLER = handle
    result = function()
    if not is_zero_result(result):
    keys = []
    for value in values:
    assert value.definite() or value.indefinite(), f"{value} temporarily unsupported"
    keys.append(value.value() if value.definite() else None)
    yield (tuple(keys), result)

    # Drop everything that was fully explored, then move the next
    # top of stack to the next option.
    while explorationStack:
    value, mask = explorationStack[-1]
    assert 0 <= mask < 8

    mask &= mask - 1 # Drop first bit
    if mask != 0:
    explorationStack[-1] = (value, mask)
    break
    explorationStack.pop()
    if not explorationStack:
    break


def enumerate_supporting_values(function, args):
"""Lists the bag of mappings from closed over values to non-zero results,
for all calls `function(arg) for arg in args`.

    >>> def count_eql(needle): return lambda x: 1 if x == needle else None
    >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
    [((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
    """
    _, _, rebind, names = extract_function_state(function)
    values = [OpaqueValue(name) for name in names]
    reboundFunction = rebind(values)
    for arg in args:
    yield from enumerate_opaque_values(lambda: reboundFunction(arg),
    values)


    <br>
    ## Type driven merges

    The interesting part of map/reduce is the reduction step. While
    some like to use first-class functions to describe reduction, in my
    opinion, it often makes more sense to define reduction at the type
    level: it's essential that merge operators be commutative and
    associative, so isolating the merge logic in dedicated classes
    makes sense to me.

This file defines two simple mergeable value types, `Sum` and `Min`, but we
    could have different ones, e.g., hyperloglog unique counts, or
    streaming statistical moments.

    class Sum:
    """A counter for summed values."""
    def __init__(self, value=0):
    self.value = value

    def merge(self, other):
    assert isinstance(other, Sum)
    self.value += other.value


    class Min:
    """A running `min` value tracker."""
    def __init__(self, value=None):
    self.value = value

    def merge(self, other):
    assert isinstance(other, Min)
if self.value is None:  # only None means "no value yet", so a real 0 minimum is kept
self.value = other.value
elif other.value is not None:
self.value = min(self.value, other.value)
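
As a sketch of another reducer obeying the same `merge` protocol
(hypothetical, not used elsewhere in this file), a running `max` tracker
mirrors `Min`:

    class Max:
        """A running `max` value tracker."""

        def __init__(self, value=None):
            self.value = value

        def merge(self, other):
            assert isinstance(other, Max)
            if self.value is None:
                self.value = other.value
            elif other.value is not None:
                self.value = max(self.value, other.value)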
    <br>
    ## Nested dictionary with wildcard

    There's a direct relationship between the data structure we use to
    represent the result of function calls as branching functions, and
    the constraints we can support on closed over values for non-zero
    results.

    In a real implementation, this data structure would host most of
    the complexity: it's the closest thing we have to indexes.

    For now, support equality *with ground value* as our only
    constraint. This means we can dispatch on the closed over values
    by order of appearance, with either a wildcard value (matches
    everything) or a hash map. At each internal level, the value is
    another `NestedDictLevel`. At the leaf, the value is the
    precomputed value.

    class NestedDictLevel:
    """One level in a nested dictionary index. We may either have a
    value for everything, or a key-value dict.
    """
    def __init__(self):
    self.dict = dict()
    self.wildcard = None

    def get(self, key, default):
    """Gets the value for `key` in this level, or `default` if None."""
    return self.dict.get(key, self.wildcard if self.wildcard is not None else default)

    def set(self, key, value):
    """Sets the value for `key` in this level. A `key` of `None`
    represents a wildcard value.
    """
    if key is not None:
    self.dict[key] = value
    else:
    self.wildcard = value


    class NestedDict:
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value. Each position in the tuple is handled by a `NestedDictLevel`.
    """
    def __init__(self, depth):
    assert depth >= 1
    self.top = NestedDictLevel()
    self.depth = depth

    def get(self, key, default=None):
    """Gets the value associated with `key`, or `default` if None."""
    assert len(key) == self.depth
    current = self.top
    for levelKey in key:
    current = current.get(levelKey, None)
    if current is None:
    return default
    return current

    def set(self, key, value):
        """Sets the value associated with `key`."""
        assert len(key) == self.depth
        current = self.top
        for idx, levelKey in enumerate(key):
            if idx == len(key) - 1:
                current.set(levelKey, value)
            else:
                next = current.get(levelKey, None)
                if next is None:
                    next = NestedDictLevel()
                    current.set(levelKey, next)
                current = next  # descend one level before handling the next key component
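
A quick usage sketch (not an original doctest): a `None` key component
acts as a wildcard that matches any lookup value at that position.

    nd = NestedDict(1)
    nd.set((4,), "only for 4")
    nd.set((None,), "anything else")  # wildcard entry
    nd.get((4,))   # -> "only for 4"
    nd.get((7,))   # -> "anything else" (falls back to the wildcard)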


    <br>
    ## Cached `map_reduce`

    As mentioned earlier, we assume `reduce` is determined implicitly
    by the reduced values' type. We also have
    `enumerate_supporting_values` to find all the closed over values
    that yield a non-zero result, for all values in a sequence.

    We can thus accept a function and an input sequence, find the
    supporting values, and merge the result associated with identical
    supporting values.

    Again, we only support ground equality constraints (see assertion
    on L453), i.e., only equijoins. There's nothing that stops a more
    sophisticated implementation from using range trees to support
    inequality or range joins.

    We'll cache the precomputed values by code object (i.e., function
    without closed over values) and input sequence. If we don't have a
    precomputed value, we'll use `enumerate_supporting_values` to run
    the function backward for each input datum from the sequence, and
    accumulate the results in a `NestedDict`. Working backward to find
    closure values that yield a non-zero result (for each input datum)
    lets us precompute a branching program that directly yields the
    result. We represent these branching programs explicitly, so we
    can also directly update a branching program for the result of
    merging all the values returned by mapping over the input sequence,
    for a given closure.

    This last `map_reduce` definition ties everything together, and
    I think is really the general heart of Yannakakis's algorithm
    as an instance of bottom-up dynamic programming.
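
To make the caching contract concrete, here is a hypothetical sketch of
the lookups performed below (`count_eql` and `data` are the names used in
the doctests): two closures built from the same code object share one
precomputed `NestedDict`, and only the final branch lookup differs.

    # code, closure, *_ = extract_function_state(count_eql(4))  # closure == (4,)
    # AGGREGATE_CACHE[(code, data)]            # one NestedDict shared by every needle
    # AGGREGATE_CACHE[(code, data)].get((4,))  # the branch selected for needle == 4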

    def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    representation for `reduce(map(function, inputIterable))`, where
    the reduction step simply calls `merge` on the return values
    (tuples are merged elementwise), and the `NestedDict` keys
    represent closed over values.

    >>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
    >>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
    >>> nd.get((0,))
    >>> nd.get((1,)).value
    1
    >>> nd.get((2,)).value
    2
    >>> nd.get((4,)).value
    2
    """
    def merge(dst, update):
    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    for value, new in zip(dst, update):
    value.merge(new)
    else:
    dst.merge(update)

    cache = NestedDict(depth)
    for key, result in enumerate_supporting_values(function, inputIterable):
    prev = cache.get(key, None)
    if prev is None:
    cache.set(key, result)
    else:
    merge(prev, result)
    return cache


    AGGREGATE_CACHE = dict() # Map from function, input sequence -> NestedDict


    def map_reduce(function, inputIterable):
    """Returns the result of merging `map(function, inputIterable)`.

    `None` return values represent neutral elements (i.e., the result
    of mapping an empty `inputIterable`), and values are otherwise
    reduced by calling `merge` on a mutable accumulator.

Assuming `function` is well-behaved, `map_reduce` runs in time
    linear wrt `len(inputIterable)`. It's also always cached on a
    composite key that consists of the `function`'s code object (i.e.,
    without closed over values) and the `inputIterable`.

    These complexity guarantees let us nest `map_reduce` with
    different closed over values, and still guarantee a linear-time
    total complexity.

    This wrapper ties together all the components

    >>> INVOCATION_COUNTER = 0
    >>> data = (1, 2, 2, 4, 2, 4)
    >>> def count_eql(needle):
    ... def count(x):
    ... global INVOCATION_COUNTER
    ... INVOCATION_COUNTER += 1
    ... return Sum(x) if x == needle else None
    ... return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data).value
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = ((1, 2), (2, 2), (1, 3))
    >>> sku_costs = ((1, 10), (2, 20), (3, 30))
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    >>> def sum_odd_or_even_skus(mod_two):
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus)
    >>> sum_odd_or_even_skus(0).value
    20
    >>> sum_odd_or_even_skus(1).value
    50
    """
    assert isinstance(inputIterable, collections.abc.Iterable)
    assert not isinstance(inputIterable, collections.abc.Iterator)
    code, closure, *_ = extract_function_state(function)
    if (code, inputIterable) not in AGGREGATE_CACHE:
    AGGREGATE_CACHE[(code, inputIterable)] = \
    _precompute_map_reduce(function, len(closure), inputIterable)
    return AGGREGATE_CACHE[(code, inputIterable)].get(closure, None)


    if __name__ == "__main__":
    import doctest
    doctest.testmod()


    <br>
    ## Extensions and future work

    The closure hack (`extract_function_state`) should really be
    extended to cover local functions. This is mostly a question of
    going deeply into values that are mapped to functions, and of
    maintaining an id-keyed map from cell to `OpaqueValue`.

    There is currently no support for parallelism, only caching. It
    should be easy to handle the return values (`NestedDict`s and
    aggregate classes like `Sum` or `Min`). Distributing the work in
    `_precompute_map_reduce` to merge locally is also not hard.

    The main issue with parallelism is that we can't pass functions
    as work units, so we'd have to stick to the `fork` process pool.

    There's also no support for moving (child) work forward when
    blocked waiting on a future. We'd have to spawn workers on the fly
    to oversubscribe when workers are blocked on a result (spawning on
    demand is already a given for `fork` workers), and to implement our
    own concurrency control to avoid wasted work, and probably internal
    throttling to avoid thrashing when we'd have more active threads
    than cores.

    That being said, the complexity is probably worth the speed up on
    realistic queries.

    At a higher level, we could support comparison joins (e.g., less
    than, greater than or equal, in range) if only we represented the
    branching programs with a data structure that supported these
    queries. A [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007) would
let us handle these "theta" joins, for the low low cost of a
    polylogarithmic multiplicative factor in space and time.

    Finally, we could update the indexed branching programs
    incrementally after small changes to the input data. This might
    sound like a job for streaming engines like [timely dataflow](https://github.com/timelydataflow/timely-dataflow),
    but I think viewing each `_precompute_map_reduce` call as a purely
    functional map/reduce job gives a better fit with [self-adjusting computation](https://www.umut-acar.org/research#h.x3l3dlvx3g5f).

    I guess, in a way, this code shows how we can simply decorrelate
    nested loops: we just have to be OK with building one index for
    each loop in the nest.
    785 changes: 785 additions & 0 deletions yannakakis.py
    @@ -0,0 +1,785 @@
    #!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|<br>|

    import collections.abc

    ## # Linear-time analytics queries in raw Python
    ##
    ## This short hack shows how [Yannakakis's algorithm](https://pages.cs.wisc.edu/~paris/cs784-f19/lectures/lecture4.pdf)
    ## lets us implement linear-time (wrt the number of input data rows)
    ## analytical queries in regular programming languages, without
    ## offloading joins to a specialised database query language, thus
    ## avoiding the associated impedance mismatch. There are restrictions
    ## on the queries we can express -- Yannakakis's algorithm relies on a
    ## hypertree width of 1, and associated decomposition -- but that's
## kind of reasonable: a fractional hypertree width > 1 would mean there
    ## are databases for which the intermediate results could grow much
    ## larger than the input database (superlinearly), from the [AGM bound](https://arxiv.org/abs/1711.03860).
    ## Structured programs also naturally yield a hypertree decomposition,
    ## unlike more declarative logic programs that tend to hide the
    ## structure implicit in the programmer's thinking.
    ##
    ## The key is to mark function arguments as either negative (true
    ## inputs) or positive (possible interesting input values derived from
    ## negative arguments). In this hack, closed over values are
    ## positive, and the only negative argument is the current data row.
    ##
    ## We also assume that these functions are always used in a map/reduce
    ## pattern, and thus we only memoise the result of
    ## `map_reduce(function, input)`, with a group-structured reduction:
    ## it must be associative and commutative, and there must be a zero
    ## (neutral) value.
    ##
    ## The memoisation key consists of the function *without closed over
    ## bindings* and the input, and the memoised value is a data structure
    ## from the tuple of closed over values to the `map_reduce` output.
    ##
    ## This data structure is where Yannakakis's algorithm comes in: we'll
    ## iterate over each datum in the input, run the function on it with
    ## logical variables instead of the closed over values, and generate a
    ## mapping from closed over values to result for all non-zero results.
    ## We'll then merge the mappings for all input data together (there is
    ## no natural ordering here, hence the group structure).
    ##
    ## The output data structure controls the join we can implement. We
    ## show how a simple stack of nested key-value mappings handles
    ## equijoins, but a [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007)
    ## would handle inequalities (the rest of the machinery already works
## in terms of less than/greater than constraints).
    ##
    ## As long as we explore a bounded number of paths for each datum and
    ## a bounded number of `function, input` cache keys, we'll spend a
    ## bounded amount of time on each input datum, and thus linear time
    ## total. The magic of Yannakakis's algorithm is that this works even
    ## when there are nested `map_reduce` calls, which would naïvely
    ## result in polynomial time (degree equal to the nesting depth).



    ## ## Memoising through a Python function's closed over values
    ##
    ## Even if functions were hashable and comparable for extensional
    ## equality, directly using closures as memoisation keys in calls like
    ## `map_reduce(function, input)` would result in superlinear runtime
    ## for nested `map_reduce` calls.
    ##
    ## This `extract_function_state` accepts a function (with or without
    ## closed over state), and returns four values:
    ## 1. The underlying code object
    ## 2. The tuple of closed over values (current value for mutable cells)
    ## 3. A function to rebind the closure with new closed over values
    ## 4. The name of the closed over bindings
    ##
    ## The third return value, the `rebind` function, silently fails on
    ## complicated cases; this is a hack, after all. In short, it only
    ## handles closing over immutable atomic values like integers or
    ## strings, but not, e.g., functions (yet), or mutable bindings.
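    ##
    ## For example, in CPython (a quick sketch of the introspection this
    ## hack relies on):
    ##
    ##     def adder(n):
    ##         return lambda x: x + n
    ##     f = adder(3)
    ##     f.__code__.co_freevars                          # ('n',)
    ##     [cell.cell_contents for cell in f.__closure__]  # [3]
    ##
    ## `extract_function_state` packages exactly this information, plus a
    ## `rebind` helper that rebuilds the closure around fresh cells.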

    def extract_function_state(function):
    """Accepts a function object and returns information about it: a
    hash key for the object, a tuple of closed over values, a function
    to return a fresh closure with a different tuple of closed over
    values, and a tuple of closed over names.

    >>> def test(x): return lambda y: x == y
    >>> extract_function_state(test)[1]
    ()
    >>> extract_function_state(test)[3]
    ()
    >>> fun = test(4)
    >>> extract_function_state(fun)[1]
    (4,)
    >>> extract_function_state(fun)[3]
    ('x',)
    >>>
    >>> fun(4)
    True
    >>> fun(5)
    False
    >>> rebind = extract_function_state(fun)[2]
    >>> rebound_4, rebound_5 = rebind([4]), rebind([5])
    >>> rebound_4(4)
    True
    >>> rebound_4(5)
    False
    >>> rebound_5(4)
    False
    >>> rebound_5(5)
    True
    """
    code = function.__code__
    names = code.co_freevars

    if function.__closure__ is None: # Toplevel function
    assert names == ()
    def rebind(values):
    if len(values) != 0:
    raise RuntimeError(f"Values must be empty for toplevel function. values={values}")
    return function
    return code, (), rebind, names

    closure = tuple(cell.cell_contents for cell in function.__closure__)
    assert len(names) == len(closure), (closure, names)

    # TODO: rebind recursively (functions are also cells)
    def rebind(values):
    if len(values) != len(names):
    raise RuntimeError(f"Values must match names. names={names} values={values}")
    return function.__class__(
    code,
    function.__globals__,
    function.__name__,
    function.__defaults__,
    tuple(cell.__class__(value)
    for cell, value in zip(function.__closure__, values)))

    return code, closure, rebind, names



    ## ## Logical variables for closed-over values
    ##
    ## We wish to enumerate the support of a function call (parameterised
    ## over closed over values), and the associated result value. We'll
    ## do that by rebinding the closure to point at instances of
    ## `OpaqueValue` and enumerating all the possible constraints on these
    ## `OpaqueValue`s. These `OpaqueValue`s work like logical variables
    ## that let us run a function in reverse: when get a non-zero
    ## (non-None) return value, we look at the accumulated constraint set
    ## on the opaque values and use them to update the data representation
    ## of the function's result (we assume that we can represent the
    ## constraints on all `OpaqueValue`s in our result data structure).
    ##
    ## Currently, we only support nested dictionaries, so each
    ## `OpaqueValue`s must be either fully unconstrained (wildcard that
    ## matches any value), or constrained to be exactly equal to a value.
    ## There's no reason we can't use k-d trees though, and it's not
    ## harder to track a pair of bounds (lower and upper) than a set of
    ## inequalities, so we'll handle the general ordered `OpaqueValue`
    ## case.
    ##
    ## In the input program (the query), we assume closed over values are
    ## only used for comparisons (equality, inequality, relational
    ## operators, or conversion to bool, i.e., non-zero testing). Knowing
    ## the result of each (non-redundant) comparison tightens the range
    ## of potential values for the `OpaqueValue`... eventually down to a
    ## single point value that our hash-based indexes can handle.
    ##
    ## Of course, if a comparison isn't redundant, there are multiple
    ## feasible results, so we need an external oracle to pick one. An
    ## external caller is responsible for injecting its logic as
    ## `OpaqueValue.CMP_HANDLER`, and driving the exploration of the
    ## search space.
    ##
    ## N.B., the set of constraints we can handle is determined by the
    ## ground data structure to represent finitely supported functions.
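    ##
    ## As a small sketch of that protocol (a hypothetical driver, not part
    ## of the real exploration code below): an oracle that always answers
    ## "equal" pins the value to the first ground operand it meets.
    ##
    ##     x = OpaqueValue("x")
    ##     OpaqueValue.CMP_HANDLER = lambda value, other: 0  # always "equal"
    ##     x == 3         # True: the handler chose 0, so x is now pinned
    ##     x.definite()   # True
    ##     x.value()      # 3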

    class OpaqueValue:
    """An opaque value is a one-dimensional range of Python values,
    represented as a lower and an upper bound, each of which is
    potentially exclusive.

    `OpaqueValue`s are only used in queries for comparisons with
    ground values. All comparisons are turned into three-way
    `__cmp__` calls; non-redundant `__cmp__` calls (which could return
    more than one value) are resolved by calling `CMP_HANDLER`
    and tightening the bound in response.

    >>> x = OpaqueValue("x")
    >>> x == True
    True
    >>> 1 if x else 2
    1
    >>> x.reset()
    >>> x > 4
    False
    >>> x < 4
    False
    >>> x == 4
    True
    >>> x < 10
    True
    >>> x >= 10
    False
    >>> x > -10
    True
    >>> x <= -10
    False
    """

    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name):
    self.name = name
    # Upper and lower bounds
    self.lower = self.upper = None
    self.lowerExcl = self.upperExcl = True

    def __str__(self):
    left = "(" if self.lowerExcl else "["
    right = ")" if self.upperExcl else "]"
    return f"<OpaqueValue {self.name} {left}{self.lower}, {self.upper}{right}>"

    def reset(self):
    """Clears all accumulated constraints on this `OpaqueValue`."""
    self.lower = self.upper = None
    self.lowerExcl = self.upperExcl = True

    def indefinite(self):
    """Returns whether this `OpaqueValue` is still unconstrained."""
    return self.lower == None and self.upper == None

    def definite(self):
    """Returns whether is `OpaqueValue` is constrained to an exact value."""
    return self.lower == self.upper and self.lower is not None and not self.lowerExcl and not self.upperExcl

    def value(self):
    """Returns the exact value for this `OpaqueValue`, assuming there is one."""
    return self.lower

    def _contains(self, value, strictly):
    """Returns whether this `OpaqueValue`'s range includes `value`
    (maybe `strictly` inside the range).

    The first return value is the containment truth value; the second
    is the forced `__cmp__` result when it is fully determined (i.e.,
    `value` falls outside the range, or the range is pinned to a single
    point), and `None` otherwise.
    """
    if self.lower is not None:
    if self.lower > value:
    return False, 1
    if self.lower == value and (strictly or self.lowerExcl):
    return False, 1
    if self.upper is not None:
    if self.upper < value:
    return False, -1
    if self.upper == value and (strictly or self.upperExcl):
    return False, -1
    return True, 0 if self.definite() else None

    def contains(self, value, strictly=False):
    """Returns whether `values` is `strictly` contained in the
    `OpaqueValue`'s range.
    """
    try:
    return self._contains(value, strictly)[0]
    except TypeError:
    return False

    def potential_mask(self, other):
    """Returns the set of potential `__cmp__` values for `other`
    that are compatible with the current range: bit 0 is set if
    `-1` is possible, bit 1 if `0` is possible, and bit 2 if `1`
    is possible.
    """
    if not self.contains(other):
    return 0
    if self.contains(other, strictly=True):
    return 7
    if self.definite() and self.value() == other:
    return 2
    # We have a non-strict inclusion, and inequality.
    if self.lower == other and self.lowerExcl:
    return 6
    assert self.upper == other and self.upperExcl
    return 3

    def __cmp__(self, other):
    """Three-way comparison between this `OpaqueValue` and `other`.

    When the result is known from the current bound, we just return
    that value. Otherwise, we ask `CMP_HANDLER` what value to return
    and update the bound accordingly.
    """
    if isinstance(other, OpaqueValue) and \
    self.definite() and not other.definite():
    # If we have a definite value and `other` is an indefinite
    # `OpaqueValue`, flip the comparison order to let the `other`
    # argument be a ground value.
    return -other.__cmp__(self.value())

    assert not isinstance(other, OpaqueValue) or other.definite()
    if isinstance(other, OpaqueValue):
    other = other.value() # Make sure `other` is a ground value
    assert other is not None # We use `None` internally, and it doesn't compare well

    compatible, order = self._contains(other, False)
    if order is not None:
    return order
    order = OpaqueValue.CMP_HANDLER(self, other)
    if order < 0:
    self._add_bound(upper=other, upperExcl=True)
    elif order == 0:
    self._add_bound(lower=other, lowerExcl=False, upper=other, upperExcl=False)
    else:
    self._add_bound(lower=other, lowerExcl=True)
    return order

    def _add_bound(self, lower=None, lowerExcl=False, upper=None, upperExcl=False):
    """Updates the internal range for this new bound."""
    assert lower is None or self.contains(lower, strictly=lowerExcl)
    assert upper is None or self.contains(upper, strictly=upperExcl)
    if lower is not None:
    self.lower = lower
    self.lowerExcl = lowerExcl

    assert upper is None or self.contains(upper, strictly=upperExcl)
    if upper is not None:
    self.upper = upper
    self.upperExcl = upperExcl

    def __bool__(self):
    return self != 0

    def __eq__(self, other):
    return self.__cmp__(other) == 0

    def __ne__(self, other):
    return self.__cmp__(other) != 0

    def __lt__(self, other):
    return self.__cmp__(other) < 0

    def __le__(self, other):
    return self.__cmp__(other) <= 0

    def __gt__(self, other):
    return self.__cmp__(other) > 0

    def __ge__(self, other):
    return self.__cmp__(other) >= 0



    ## ## Depth-first exploration of a function call's support
    ##
    ## We assume a `None` result represents a zero value wrt the aggregate
    ## merging function (e.g., 0 for a sum). For convenience, we also treat
    ## tuples and lists of `None`s identically.
    ##
    ## We simply maintain a stack of `CMP_HANDLER` calls, where each entry
    ## in the stack consists of an `OpaqueValue` and bitset of `CMP_HANDLER`
    ## results still to explore (-1, 0, or 1). This stack is filled on demand,
    ## and `CMP_HANDLER` returns the first result allowed by the bitset.
    ##
    ## Once we have a result, we tweak the stack to force depth-first
    ## exploration of a different part of the solution space: we drop
    ## the first bit in the bitset of results to explore, and drop the
    ## entry wholesale if the bitset is now empty (all zero). When
    ## this tweaking leaves an empty stack, we're done.
    ##
    ## This ends up enumerating all the paths through the function call
    ## with a non-recursive depth-first traversal.
    ##
    ## We then do the same for each datum in our input sequence, and merge
    ## results for identical keys together.
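    ##
    ## For instance, for the `enumerate_opaque_values` doctest below
    ## (`1 if x == 0 else (2 if x == 1 and y == 2 else None)`), the driver
    ## walks these seven paths (a sketch of the traversal, not literal
    ## output):
    ##
    ##     x < 0                      -> None
    ##     x == 0                     -> 1, yields ((0, None), 1)
    ##     0 < x < 1                  -> None
    ##     x == 1, y < 2              -> None
    ##     x == 1, y == 2             -> 2, yields ((1, 2), 2)
    ##     x == 1, y > 2              -> None
    ##     x > 1                      -> None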

    def is_zero_result(value):
    """Checks if `value` is a "zero" aggregate value: either `None`,
    or a tuple or list whose items are all `None`.

    >>> is_zero_result(None)
    True
    >>> is_zero_result(False)
    False
    >>> is_zero_result(True)
    False
    >>> is_zero_result(0)
    False
    >>> is_zero_result(-10)
    False
    >>> is_zero_result(1.5)
    False
    >>> is_zero_result("")
    False
    >>> is_zero_result("asd")
    False
    >>> is_zero_result((None, None))
    True
    >>> is_zero_result((None, 1))
    False
    >>> is_zero_result([])
    True
    >>> is_zero_result([None])
    True
    >>> is_zero_result([None, (None, None)])
    False
    """
    if value is None:
    return True
    if isinstance(value, (tuple, list)):
    return all(item is None for item in value)
    return False


    def enumerate_opaque_values(function, values):
    """Explores the set of `OpaqueValue` constraints when calling
    `function`.

    Enumerates all constraint sets for the `OpaqueValue` instances in
    `values` and, for each constraint set with a non-zero return value,
    yields a pair of (tuple of ground values for `values`, return value).

    This essentially turns `function()` into a branching program on
    `values`.

    >>> x, y = OpaqueValue("x"), OpaqueValue("y")
    >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
    [((0, None), 1), ((1, 2), 2)]

    """
    explorationStack = [] # List of (value, bitmaskOfCmp)
    while True:
    for value in values:
    value.reset()

    stackIndex = 0
    def handle(value, other):
    nonlocal stackIndex
    if len(explorationStack) == stackIndex:
    explorationStack.append((value, value.potential_mask(other)))

    expectedValue, mask = explorationStack[stackIndex]
    assert value is expectedValue
    assert mask != 0

    if (mask & 1) != 0:
    ret = -1
    elif (mask & 2) != 0:
    ret = 0
    elif (mask & 4) != 0:
    ret = 1
    else:
    assert False, f"bad mask {mask}"

    stackIndex += 1
    return ret

    OpaqueValue.CMP_HANDLER = handle
    result = function()
    if not is_zero_result(result):
    keys = []
    for value in values:
    assert value.definite() or value.indefinite(), f"{value} temporarily unsupported"
    keys.append(value.value() if value.definite() else None)
    yield (tuple(keys), result)

    # Drop everything that was fully explored, then move the next
    # top of stack to the next option.
    while explorationStack:
    value, mask = explorationStack[-1]
    assert 0 <= mask < 8

    mask &= mask - 1 # Drop first bit
    if mask != 0:
    explorationStack[-1] = (value, mask)
    break
    explorationStack.pop()
    if not explorationStack:
    break


    def enumerate_supporting_values(function, args):
    """Lists the bag of mapping from closed over values to non-zero result,
    for all calls `function(arg) for args in args`.

    >>> def count_eql(needle): return lambda x: 1 if x == needle else None
    >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
    [((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
    """
    _, _, rebind, names = extract_function_state(function)
    values = [OpaqueValue(name) for name in names]
    reboundFunction = rebind(values)
    for arg in args:
    yield from enumerate_opaque_values(lambda: reboundFunction(arg),
    values)



    ## ## Type driven merges
    ##
    ## The interesting part of map/reduce is the reduction step. While
    ## some like to use first-class functions to describe reduction, in my
    ## opinion, it often makes more sense to define reduction at the type
    ## level: it's essential that merge operators be commutative and
    ## associative, so isolating the merge logic in dedicated classes
    ## makes sense to me.
    ##
    ## This file defines a single mergeable value type, `Sum`, but we
    ## could have different ones, e.g., hyperloglog unique counts, or
    ## streaming statistical moments.
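    ##
    ## For instance, a hypothetical `Max` aggregate (not used anywhere
    ## below, shown only to illustrate the protocol) would need the same
    ## two pieces: a mutable `value`, and a commutative, associative
    ## `merge`:
    ##
    ##     class Max:
    ##         def __init__(self, value=None):
    ##             self.value = value
    ##
    ##         def merge(self, other):
    ##             assert isinstance(other, Max)
    ##             if self.value is None:
    ##                 self.value = other.value
    ##             elif other.value is not None:
    ##                 self.value = max(self.value, other.value)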

    class Sum:
    """A counter for summed values."""
    def __init__(self, value=0):
    self.value = value

    def merge(self, other):
    assert isinstance(other, Sum)
    self.value += other.value


    class Min:
    """A running `min` value tracker."""
    def __init__(self, value=None):
    self.value = value

    def merge(self, other):
    assert isinstance(other, Min)
    if self.value is None:  # only `None` means "no value yet": keep a legitimate minimum of 0
    self.value = other.value
    elif other.value is not None:
    self.value = min(self.value, other.value)

    ## ## Nested dictionary with wildcard
    ##
    ## There's a direct relationship between the data structure we use to
    ## represent the result of function calls as branching functions, and
    ## the constraints we can support on closed over values for non-zero
    ## results.
    ##
    ## In a real implementation, this data structure would host most of
    ## the complexity: it's the closest thing we have to indexes.
    ##
    ## For now, we support equality *with a ground value* as our only
    ## constraint. This means we can dispatch on the closed over values
    ## by order of appearance, with either a wildcard value (matches
    ## everything) or a hash map. At each internal level, the value is
    ## another `NestedDictLevel`. At the leaf, the value is the
    ## precomputed value.
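    ##
    ## For example (a sketch of the intended behaviour, with wildcards
    ## stored under the key `None`):
    ##
    ##     nd = NestedDict(2)
    ##     nd.set((1, None), Sum(10))   # x == 1, any y
    ##     nd.set((1, 2), Sum(20))      # x == 1, y == 2
    ##     nd.get((1, 2)).value         # 20: the exact key wins
    ##     nd.get((1, 7)).value         # 10: falls back to the wildcard
    ##     nd.get((3, 2))               # None: no match at the first level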

    class NestedDictLevel:
    """One level in a nested dictionary index. We may either have a
    value for everything, or a key-value dict.
    """
    def __init__(self):
    self.dict = dict()
    self.wildcard = None

    def get(self, key, default):
    """Gets the value for `key` in this level, or `default` if None."""
    return self.dict.get(key, self.wildcard if self.wildcard is not None else default)

    def set(self, key, value):
    """Sets the value for `key` in this level. A `key` of `None`
    represents a wildcard value.
    """
    if key is not None:
    self.dict[key] = value
    else:
    self.wildcard = value


    class NestedDict:
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value. Each position in the tuple is handled by a `NestedDictLevel`.
    """
    def __init__(self, depth):
    assert depth >= 1
    self.top = NestedDictLevel()
    self.depth = depth

    def get(self, key, default=None):
    """Gets the value associated with `key`, or `default` if None."""
    assert len(key) == self.depth
    current = self.top
    for levelKey in key:
    current = current.get(levelKey, None)
    if current is None:
    return default
    return current

    def set(self, key, value):
    """Sets the value associated with `key`."""
    assert len(key) == self.depth
    current = self.top
    for idx, levelKey in enumerate(key):
    if idx == len(key) - 1:
    current.set(levelKey, value)
    else:
    next = current.get(levelKey, None)
    if next is None:
    next = NestedDictLevel()
    current.set(levelKey, next)
    current = next  # descend after the `if`; otherwise keys of depth > 1 would all land on the top level



    ## ## Cached `map_reduce`
    ##
    ## As mentioned earlier, we assume `reduce` is determined implicitly
    ## by the reduced values' type. We also have
    ## `enumerate_supporting_values` to find all the closed over values
    ## that yield a non-zero result, for all values in a sequence.
    ##
    ## We can thus accept a function and an input sequence, find the
    ## supporting values, and merge the result associated with identical
    ## supporting values.
    ##
    ## Again, we only support ground equality constraints (see assertion
    ## on L453), i.e., only equijoins. There's nothing that stops a more
    ## sophisticated implementation from using range trees to support
    ## inequality or range joins.
    ##
    ## We'll cache the precomputed values by code object (i.e., function
    ## without closed over values) and input sequence. If we don't have a
    ## precomputed value, we'll use `enumerate_supporting_values` to run
    ## the function backward for each input datum from the sequence, and
    ## accumulate the results in a `NestedDict`. Working backward to find
    ## closure values that yield a non-zero result (for each input datum)
    ## lets us precompute a branching program that directly yields the
    ## result. We represent these branching programs explicitly, so we
    ## can also directly update a branching program for the result of
    ## merging all the values returned by mapping over the input sequence,
    ## for a given closure.
    ##
    ## This last `map_reduce` definition ties everything together, and
    ## I think is really the general heart of Yannakakis's algorithm
    ## as an instance of bottom-up dynamic programming.
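    ##
    ## As a concrete sketch (a hypothetical toy query, using only equality
    ## on the closed over value, which is all the `NestedDict` supports):
    ##
    ##     skus = (1, 2, 2, 3)            # must be hashable for the cache key
    ##     def count_sku(sku):
    ##         def hit(x):
    ##             return Sum(1) if x == sku else None
    ##         return map_reduce(hit, skus)
    ##
    ##     count_sku(2).value   # 2: the first call scans `skus` once and
    ##                          #    builds a NestedDict keyed on `sku`
    ##     count_sku(3).value   # 1: same code object and input, pure lookup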

    def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    representation for `reduce(map(function, inputIterable))`, where
    the reduction step simply calls `merge` on the return values
    (tuples are merged elementwise), and the `NestedDict` keys
    represent closed over values.

    >>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
    >>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
    >>> nd.get((0,))
    >>> nd.get((1,)).value
    1
    >>> nd.get((2,)).value
    2
    >>> nd.get((4,)).value
    2
    """
    def merge(dst, update):
    if isinstance(dst, (tuple, list)):
    assert len(dst) == len(update)
    for value, new in zip(dst, update):
    value.merge(new)
    else:
    dst.merge(update)

    cache = NestedDict(depth)
    for key, result in enumerate_supporting_values(function, inputIterable):
    prev = cache.get(key, None)
    if prev is None:
    cache.set(key, result)
    else:
    merge(prev, result)
    return cache


    AGGREGATE_CACHE = dict() # Map from function, input sequence -> NestedDict


    def map_reduce(function, inputIterable):
    """Returns the result of merging `map(function, inputIterable)`.

    `None` return values represent neutral elements (i.e., the result
    of mapping an empty `inputIterable`), and values are otherwise
    reduced by calling `merge` on a mutable accumulator.

    Assuming `function` is well-behaved, `map_reduce` runs in time
    linear wrt `len(inputIterable)`. It's also always cached on a
    composite key that consists of the `function`'s code object (i.e.,
    without closed over values) and the `inputIterable`.

    These complexity guarantees let us nest `map_reduce` with
    different closed over values, and still guarantee a linear-time
    total complexity.

    This wrapper ties together all the components:

    >>> INVOCATION_COUNTER = 0
    >>> data = (1, 2, 2, 4, 2, 4)
    >>> def count_eql(needle):
    ... def count(x):
    ... global INVOCATION_COUNTER
    ... INVOCATION_COUNTER += 1
    ... return Sum(x) if x == needle else None
    ... return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data).value
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = ((1, 2), (2, 2), (1, 3))
    >>> sku_costs = ((1, 10), (2, 20), (3, 30))
    >>> def sku_min_cost(sku):
    ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    >>> def sum_odd_or_even_skus(mod_two):
    ... def count_if_mod_two(id_sku):
    ... id, sku = id_sku
    ... if id % 2 == mod_two:
    ... return Sum(sku_min_cost(sku))
    ... return map_reduce(count_if_mod_two, id_skus)
    >>> sum_odd_or_even_skus(0).value
    20
    >>> sum_odd_or_even_skus(1).value
    50
    """
    assert isinstance(inputIterable, collections.abc.Iterable)
    assert not isinstance(inputIterable, collections.abc.Iterator)
    code, closure, *_ = extract_function_state(function)
    if (code, inputIterable) not in AGGREGATE_CACHE:
    AGGREGATE_CACHE[(code, inputIterable)] = \
    _precompute_map_reduce(function, len(closure), inputIterable)
    return AGGREGATE_CACHE[(code, inputIterable)].get(closure, None)


    if __name__ == "__main__":
    import doctest
    doctest.testmod()



    ## ## Extensions and future work
    ##
    ## The closure hack (`extract_function_state`) should really be
    ## extended to cover local functions. This is mostly a question of
    ## going deeply into values that are mapped to functions, and of
    ## maintaining an id-keyed map from cell to `OpaqueValue`.
    ##
    ## There is currently no support for parallelism, only caching. It
    ## should be easy to ship the return values (`NestedDict`s and
    ## aggregate classes like `Sum` or `Min`) back to the parent process
    ## and merge them there. Distributing the work in
    ## `_precompute_map_reduce` to merge locally is also not hard.
    ##
    ## The main issue with parallelism is that we can't pass functions
    ## as work units, so we'd have to stick to the `fork` process pool.
    ##
    ## There's also no support for moving (child) work forward when
    ## blocked waiting on a future. We'd have to spawn workers on the fly
    ## to oversubscribe when workers are blocked on a result (spawning on
    ## demand is already a given for `fork` workers), and to implement our
    ## own concurrency control to avoid wasted work, and probably internal
    ## throttling to avoid thrashing when we'd have more active threads
    ## than cores.
    ##
    ## That being said, the complexity is probably worth the speed up on
    ## realistic queries.
    ##
    ## At a higher level, we could support comparison joins (e.g., less
    ## than, greater than or equal, in range) if only we represented the
    ## branching programs with a data structure that supported these
    ## queries. A [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007) would
    ## let us handle these "theta" joins, for the low, low cost of a
    ## polylogarithmic multiplicative factor in space and time.
    ##
    ## Finally, we could update the indexed branching programs
    ## incrementally after small changes to the input data. This might
    ## sound like a job for streaming engines like [timely dataflow](https://github.com/timelydataflow/timely-dataflow),
    ## but I think viewing each `_precompute_map_reduce` call as a purely
    ## functional map/reduce job gives a better fit with [self-adjusting computation](https://www.umut-acar.org/research#h.x3l3dlvx3g5f).
    ##
    ## I guess, in a way, this code shows how we can simply decorrelate
    ## nested loops: we just have to be OK with building one index for
    ## each loop in the nest.