@pkhuong
Last active November 8, 2024 15:56
A minimal version of Yannakakis's algorithm for mostly plain Python
#!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|
license, imports
# Yannakakis.py by Paul Khuong
#
# To the extent possible under law, the person who associated CC0 with
# Yannakakis.py has waived all copyright and related or neighboring rights
# to Yannakakis.py.
#
# You should have received a copy of the CC0 legalcode along with this
# work.  If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.

import collections.abc

Linear-time analytical queries in plain Python

I didn't know Mihalis Yannakakis was about to turn 70, but that's a fun coïncidence.

This short hack shows how Yannakakis's algorithm lets us implement linear-time (wrt the number of input data rows) analytical queries in regular programming languages, without offloading joins to a specialised database query language, thus avoiding the associated impedance mismatch. There are restrictions on the queries we can express -- Yannakakis's algorithm relies on a hypertree width of 1, and on having a hypertree decomposition as a witness -- but that's kind of reasonable: a fractional hypertree width > 1 would mean there are databases for which the intermediate results could be superlinearly larger than the input database, due to the (tight) AGM bound. The hypertree decomposition witness isn't a burden either: structured programs naturally yield a hypertree decomposition, unlike more declarative logic programs that tend to hide the structure implicit in the programmer's thinking.
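To make the AGM remark concrete, consider the classic cyclic query, the triangle join R(a, b) ⋈ S(b, c) ⋈ T(c, a). Its fractional hypertree width is 3/2, and the AGM bound caps its output at N^(3/2) rows for relations of N rows each. The sketch below is plain Python, independent of this file (triangle_join is a hypothetical name): it uses the complete directed graph on k vertices for all three relations, so each has N = k(k - 1) rows while the output has k(k - 1)(k - 2) rows, i.e., output/N = k - 2 grows like the square root of N instead of staying constant.

```python
from itertools import permutations


def triangle_join(k):
    """Evaluates R(a, b), S(b, c), T(c, a) with R = S = T = all ordered
    pairs of distinct vertices in range(k).

    Returns (rows per relation, output rows)."""
    edges = set(permutations(range(k), 2))
    output = [
        (a, b, c)
        for (a, b) in edges                       # R(a, b)
        for c in range(k)
        if (b, c) in edges and (c, a) in edges    # S(b, c) and T(c, a)
    ]
    return len(edges), len(output)


for k in (10, 20, 40):
    n, out = triangle_join(k)
    print(f"N={n:5d} output={out:6d} output/N={out / n:.1f}")
```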

The key is to mark function arguments as either negative (true inputs) or positive (possible interesting input values derived from negative arguments). In this hack, closed over values are positive, and the only negative argument is the current data row.

We also assume that these functions are always used in a map/reduce pattern, and thus we only memoise the result of map_reduce(function, input), with a commutative-monoid reduction: the reduce function must be associative and commutative, and there must be a zero (neutral) value.
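Ignoring caching entirely, the semantics we expect from map_reduce(function, input) can be sketched as a plain reduction. In this sketch, reference_map_reduce is a hypothetical name, and the local Sum mirrors the Sum class defined later in this file, except that its merge returns self so it composes with functools.reduce. None plays the role of the zero and is skipped; because merge is associative and commutative, the input order does not change the result.

```python
from functools import reduce


class Sum:
    """Summed values; like this file's `Sum`, but `merge` returns `self`."""

    def __init__(self, value=0):
        self.value = value

    def merge(self, other):
        self.value += other.value
        return self


def reference_map_reduce(function, inputs):
    """Hypothetical uncached reference: map, skip `None` (the zero), merge the rest."""

    def merge(acc, item):
        if item is None:      # neutral element: nothing to merge
            return acc
        if acc is None:       # first non-zero value becomes the accumulator
            return item
        return acc.merge(item)

    return reduce(merge, map(function, inputs), None)


rows = [1, 2, 3, 4]
even_sum = reference_map_reduce(lambda x: Sum(x) if x % 2 == 0 else None, rows)
print(even_sum.value)  # 6
```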

With these constraints, we can express joins in natural Python without incurring the poor runtime scaling of the nested loops we actually wrote. This Python file describes the building blocks to handle aggregation queries like the following:

>>> id_skus = [(1, 2), (2, 2), (1, 3)]
>>> sku_costs = [(1, 10), (2, 20), (3, 30)]
>>> def sku_min_cost(sku):
...     return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
...
>>> def sum_odd_or_even_skus(mod_two):
...     def count_if_mod_two(id_sku):
...         id, sku = id_sku
...         if id % 2 == mod_two:
...             return Sum(sku_min_cost(sku))
...     return map_reduce(count_if_mod_two, id_skus)
...

with linear scaling in the length of id_skus and sku_costs, and caching for similar queries.

At a small scale, everything's fast.

>>> import time
>>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
20
0.001024007797241211
>>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
50
4.982948303222656e-05

As we increase the scale by 1000x in both lists, the runtime scales (sub)linearly for the first query, and is essentially unchanged for the second, which is served from the cache:

>>> id_skus = id_skus * 1000
>>> sku_costs = sku_costs * 1000
>>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
20000
0.09638118743896484
>>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
50000
8.797645568847656e-05

This still kind of holds when we multiply by another factor of 100:

>>> id_skus = id_skus * 100
>>> sku_costs = sku_costs * 100
>>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
2000000
5.9715189933776855
>>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
5000000
0.00021195411682128906
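For contrast, here is a sketch of evaluating the same query directly as the nested loops it was written as, with no map_reduce and no caching (naive_sum_odd_or_even_skus is a hypothetical name). It also counts inner-loop iterations: every matching outer row rescans sku_costs, so scaling both lists by 10x multiplies the inner-loop work by 100x, and the 1000x scaling above would multiply it by a million.

```python
def naive_sum_odd_or_even_skus(id_skus, sku_costs, mod_two):
    """Hypothetical uncached evaluation of the query as plain nested loops.

    Returns (result, number of inner-loop iterations)."""
    total, iterations = 0, 0
    for row_id, sku in id_skus:
        if row_id % 2 != mod_two:
            continue
        best = None
        for cost_sku, cost in sku_costs:   # full rescan for every matching outer row
            iterations += 1
            if cost_sku == sku and (best is None or cost < best):
                best = cost
        total += best                      # assumes every sku has a cost
    return total, iterations


id_skus = [(1, 2), (2, 2), (1, 3)]
sku_costs = [(1, 10), (2, 20), (3, 30)]
print(naive_sum_odd_or_even_skus(id_skus, sku_costs, 0))            # (20, 3)
print(naive_sum_odd_or_even_skus(id_skus * 10, sku_costs * 10, 0))  # (200, 300)
```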

The magic behind the curtains is memoisation (unsurprisingly), with a special implementation that can share work across similar closures: the memoisation key consists of the function without its closed over bindings, plus the call arguments, while the memoised value is a data structure mapping the tuple of closed over values to the map_reduce output.

This concrete representation of the function as a branching program is the core of Yannakakis's algorithm: we'll iterate over each datum in the input, run the function on it with logical variables instead of the closed over values, and generate a mapping from closed over values to result for all non-zero results. We'll then merge the mappings for all input data together (there is no natural ordering here, hence the commutative monoid structure).

The output data structure determines which joins we can implement. We show how a simple stack of nested key-value mappings handles equijoins, but a k-d tree would handle inequalities, i.e., "theta" joins (the rest of the machinery already works in terms of less than/greater than constraints).
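As a one-dimensional illustration of a non-hash output structure (a simple stand-in for the k-d tree mentioned above, not something this file implements), consider a query like `lambda row: Sum(row[1]) if row[0] < threshold else None` with `threshold` closed over. Sorting the rows once and storing prefix sums represents the result for every possible threshold at once, and each lookup becomes a binary search. PrefixSumIndex and sum_below are hypothetical names, and this sketch returns 0 rather than None for an empty range.

```python
import bisect
from itertools import accumulate


class PrefixSumIndex:
    """Hypothetical 1-D range index for `Sum(value) if key < threshold else None`,
    precomputed for every possible `threshold` at once."""

    def __init__(self, rows):
        rows = sorted(rows)                      # one O(n log n) sort
        self.keys = [key for key, _ in rows]
        # prefix[i] is the sum of the values for the i smallest keys.
        self.prefix = [0, *accumulate(value for _, value in rows)]

    def sum_below(self, threshold):
        """Sum of values whose key is strictly less than `threshold`."""
        return self.prefix[bisect.bisect_left(self.keys, threshold)]


index = PrefixSumIndex([(3, 30), (1, 10), (2, 20)])
print(index.sum_below(2), index.sum_below(3), index.sum_below(10))  # 10 30 60
```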

As long as we explore a bounded number of paths for each datum and a bounded number of (function, input) cache keys, we'll spend a bounded amount of time on each input datum, and thus linear time total. The magic of Yannakakis's algorithm is that this works even when there are nested map_reduce calls, which would naïvely result in polynomial time (degree equal to the nesting depth).

Memoising through a Python function's closed over values

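Even if functions were hashable and comparable for extensional equality, directly using closures as memoisation keys in calls like map_reduce(function, input) would result in superlinear runtime for nested map_reduce calls: every distinct closed over value would yield a distinct key, so each outer row with a new value would trigger a fresh scan of the inner input.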
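The sketch below shows both halves of that argument. Two closures built by the same factory share one code object and differ only in their cells' contents, which is what makes the code object usable as a shared cache key. A hypothetical per-closure memo (functools.lru_cache keyed on the closed over needle) instead misses once per distinct needle and rescans the whole input each time, so an outer loop that probes k distinct values pays k full scans.

```python
from functools import lru_cache


def count_eql(needle):
    return lambda x: 1 if x == needle else None


f4, f2 = count_eql(4), count_eql(2)
# Same code object, different closed over values.
print(f4.__code__ is f2.__code__)                                        # True
print(f4.__closure__[0].cell_contents, f2.__closure__[0].cell_contents)  # 4 2

DATA = (1, 2, 2, 4, 2, 4)
scans = 0


@lru_cache(maxsize=None)
def count_per_needle(needle):
    """Hypothetical memo keyed on the closed over value: one full scan per miss."""
    global scans
    scans += 1
    function = count_eql(needle)
    return sum(1 for x in DATA if function(x) is not None)


for needle in DATA:   # stands in for an outer loop probing once per outer row
    count_per_needle(needle)
print(scans)          # 3: one full scan per distinct needle (1, 2, 4)
```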

extract_function_state accepts a function (with or without closed over state), and returns four values:

  1. The underlying code object
  2. The tuple of closed over values (current value for mutable cells)
  3. A function to rebind the closure with new closed over values
  4. The names of the closed over bindings

The third return value, the rebind function, silently fails on complicated cases; this is a hack, after all. In short, it only handles closing over immutable atomic values like integers or strings, but not, e.g., functions (yet), or mutable bindings.
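The "functions (yet)" caveat is easy to trigger: a query that calls a local helper closes over the helper as well as over its data parameter, so both names show up in co_freevars, and rebinding every cell to a logical variable would make the helper an OpaqueValue, which is not callable. The sketch below uses a hypothetical rebind_by_name (built on types.FunctionType and types.CellType, Python 3.8+) that only replaces the named cells and shares the others with the original; it is one possible workaround, not what this file does.

```python
import types


def rebind_by_name(function, **overrides):
    """Hypothetical helper: copy `function`, replacing only the named
    closed over values and sharing every other cell with the original."""
    cells = tuple(
        types.CellType(overrides[name]) if name in overrides else cell
        for name, cell in zip(function.__code__.co_freevars, function.__closure__)
    )
    return types.FunctionType(
        function.__code__, function.__globals__, function.__name__,
        function.__defaults__, cells,
    )


def make_query(mod_two):
    def double(x):            # local helper: also closed over by `query`
        return 2 * x

    def query(row):
        return double(row) if row % 2 == mod_two else None

    return query


q = make_query(0)
print(sorted(q.__code__.co_freevars))  # ['double', 'mod_two']
odd = rebind_by_name(q, mod_two=1)     # `double` keeps its original cell
print(q(4), odd(4), odd(3))            # 8 None 6
```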

def extract_function_state(function):
    """Accepts a function object and returns information about it: its
    code object (usable as a hash key), a tuple of closed over values, a
    function to return a fresh closure with a different tuple of closed
    over values, and a tuple of closed over names.

    >>> def test(x): return lambda y: x == y
    >>> extract_function_state(test)[1]
    ()
    >>> extract_function_state(test)[3]
    ()
    >>> fun = test(4)
    >>> extract_function_state(fun)[1]
    (4,)
    >>> extract_function_state(fun)[3]
    ('x',)
    >>>
    >>> fun(4)
    True
    >>> fun(5)
    False
    >>> rebind = extract_function_state(fun)[2]
    >>> rebound_4, rebound_5 = rebind([4]), rebind([5])
    >>> rebound_4(4)
    True
    >>> rebound_4(5)
    False
    >>> rebound_5(4)
    False
    >>> rebound_5(5)
    True
    """
    code = function.__code__
    names = code.co_freevars

    if function.__closure__ is None:  # Toplevel function
        assert names == ()

        def rebind(values):
            if len(values) != 0:
                raise RuntimeError(
                    f"Values must be empty for toplevel function. values={values}"
                )
            return function

        return code, (), rebind, names

    closure = tuple(cell.cell_contents for cell in function.__closure__)
    assert len(names) == len(closure), (closure, names)

    # TODO: rebind recursively (functions are also cells)
    def rebind(values):
        if len(values) != len(names):
            raise RuntimeError(
                f"Values must match names. names={names} values={values}"
            )
        return function.__class__(
            code,
            function.__globals__,
            function.__name__,
            function.__defaults__,
            tuple(
                cell.__class__(value)
                for cell, value in zip(function.__closure__, values)
            ),
        )

    return code, closure, rebind, names

Logical variables for closed-over values

We wish to enumerate the support of a function call (parameterised over closed over values), and the associated result value. We'll do that by rebinding the closure to point at instances of OpaqueValue and enumerating all the possible constraints on these OpaqueValues. These OpaqueValues work like logical variables that let us run a function in reverse: when we get a non-zero (non-None) return value, we look at the accumulated constraint set on the opaque values and use them to update the data representation of the function's result (we assume that we can represent the constraints on all OpaqueValues in our result data structure).

Currently, we only support nested dictionaries, so each OpaqueValue must be either fully unconstrained (wildcard that matches any value), or constrained to be exactly equal to a value. There's no reason we can't use k-d trees though, and it's not harder to track a pair of bounds (lower and upper) than a set of inequalities, so we'll handle the general ordered OpaqueValue case.

In the input program (the query), we assume closed over values are only used for comparisons (equality, inequality, relational operators, or conversion to bool, i.e., non-zero testing). Knowing the result of each (non-redundant) comparison tightens the range of potential values for the OpaqueValue... eventually down to a single point value that our hash-based indexes can handle.
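Queries such as `x == needle` put the ground value on the left and the logical variable on the right. That still reaches OpaqueValue.__eq__ because of Python's reflected comparisons: int.__eq__ returns NotImplemented for an unknown type, so Python retries with the right operand's __eq__. The hypothetical Recorder class below only logs which hooks fire. It also shows that defining __eq__ without __hash__ makes instances unhashable, so an OpaqueValue cannot itself serve as a dict key.

```python
class Recorder:
    """Hypothetical stand-in for OpaqueValue that logs which hooks Python calls."""

    def __init__(self):
        self.log = []

    def __eq__(self, other):
        self.log.append(("==", other))
        return True

    def __ne__(self, other):
        self.log.append(("!=", other))
        return False

    def __bool__(self):
        self.log.append(("bool", None))
        return True


r = Recorder()
print(4 == r)         # int.__eq__ returns NotImplemented, then r.__eq__(4): True
print(5 != r)         # reflected r.__ne__(5): False
print(1 if r else 2)  # truth testing calls r.__bool__(): 1
print(r.log)          # [('==', 4), ('!=', 5), ('bool', None)]
print(Recorder.__hash__ is None)  # True: __eq__ without __hash__ disables hashing
```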

Of course, if a comparison isn't redundant, there are multiple feasible results, so we need an external oracle to pick one. An external caller is responsible for injecting its logic as OpaqueValue.CMP_HANDLER, and driving the exploration of the search space.

N.B., the set of constraints we can handle is determined by the ground data structure to represent finitely supported functions.

class OpaqueValue:
    """An opaque value is a one-dimensional range of Python values,
    represented as a lower and an upper bound, each of which is
    potentially exclusive.

    `OpaqueValue`s are only used in queries for comparisons with
    ground values.  All comparisons are turned into three-way
    `__cmp__` calls; non-redundant `__cmp__` calls (which could return
    more than one value) are resolved by calling `CMP_HANDLER`
    and tightening the bound in response.

    >>> x = OpaqueValue("x")
    >>> x == True
    True
    >>> 1 if x else 2
    1
    >>> x.reset()
    >>> ### Not supported by our index data structure (yet)
    >>> # >>> x > 4
    >>> # False
    >>> # >>> x < 4
    >>> # False
    >>> # >>> x == 4
    >>> # True
    >>> # >>> x < 10
    >>> # True
    >>> # >>> x >= 10
    >>> # False
    >>> # >>> x > -10
    >>> # True
    >>> # >>> x <= -10
    >>> # False
    """

    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name, index=0):
        self.name = name
        self.index = index
        # Upper and lower bounds
        self.lower = self.upper = None
        self.lowerExcl = self.upperExcl = True

    def __str__(self):
        left = "(" if self.lowerExcl else "["
        right = ")" if self.upperExcl else "]"
        return f"<OpaqueValue {self.name} {left}{self.lower}, {self.upper}{right}>"

    def reset(self):
        """Clears all accumulated constraints on this `OpaqueValue`."""
        self.lower = self.upper = None
        self.lowerExcl = self.upperExcl = True

    def indefinite(self):
        """Returns whether this `OpaqueValue` is still unconstrained."""
        return self.lower is None and self.upper is None

    def definite(self):
        """Returns whether this `OpaqueValue` is constrained to an exact value."""
        return (
            self.lower == self.upper
            and self.lower is not None
            and not self.lowerExcl
            and not self.upperExcl
        )

    def value(self):
        """Returns the exact value for this `OpaqueValue`, assuming there is one."""
        return self.lower

    def _contains(self, value, strictly):
        """Returns whether this `OpaqueValue`'s range includes `value`
        (maybe `strictly` inside the range).

        The first value is the containment truth value, and the second
        is the forced `__cmp__` value, if the `value` is *not*
        (strictly) contained in the range.
        """
        if self.lower is not None:
            if self.lower > value:
                return False, 1
            if self.lower == value and (strictly or self.lowerExcl):
                return False, 1
        if self.upper is not None:
            if self.upper < value:
                return False, -1
            if self.upper == value and (strictly or self.upperExcl):
                return False, -1
        return True, 0 if self.definite() else None

    def contains(self, value, strictly=False):
        """Returns whether `value` is `strictly` contained in the
        `OpaqueValue`'s range.
        """
        try:
            return self._contains(value, strictly)[0]
        except TypeError:
            return False

    def potential_mask(self, other):
        """Returns the set of potential `__cmp__` values for `other`
        that are compatible with the current range: bit 0 is set if
        `-1` is possible, bit 1 if `0` is possible, and bit 2 if `1`
        is possible.
        """
        if not self.contains(other):
            return 0
        if self.contains(other, strictly=True):
            return 7
        if self.definite() and self.value() == other:
            return 2
        # We have a non-strict inclusion: `other` sits on an inclusive bound.
        if self.lower == other and not self.lowerExcl:
            return 6  # `__cmp__` may be 0 or 1
        assert self.upper == other and not self.upperExcl
        return 3  # `__cmp__` may be -1 or 0

    def __cmp__(self, other):
        """Three-way comparison between this `OpaqueValue` and `other`.

        When the result is known from the current bound, we just return
        that value.  Otherwise, we ask `CMP_HANDLER` what value to return
        and update the bound accordingly.
        """
        if isinstance(other, OpaqueValue) and self.definite() and not other.definite():
            # If we have a definite value and `other` is an indefinite
            # `OpaqueValue`, flip the comparison order to let the `other`
            # argument be a ground value.
            return -other.__cmp__(self.value())

        if isinstance(other, OpaqueValue) and not other.definite():
            raise RuntimeError(
                f"OpaqueValue may only be compared with ground values. self={self} other={other}"
            )
        if isinstance(other, OpaqueValue):
            other = other.value()  # Make sure `other` is a ground value

        if other is None:
            # We use `None` internally, and it doesn't compare well
            raise RuntimeError("OpaqueValue may not be compared with None")

        compatible, order = self._contains(other, False)
        if order is not None:
            return order
        order = OpaqueValue.CMP_HANDLER(self, other)
        if order < 0:
            self._add_bound(upper=other, upperExcl=True)
        elif order == 0:
            self._add_bound(lower=other, lowerExcl=False, upper=other, upperExcl=False)
        else:
            self._add_bound(lower=other, lowerExcl=True)
        return order

    def _add_bound(self, lower=None, lowerExcl=False, upper=None, upperExcl=False):
        """Updates the internal range for this new bound."""
        assert lower is None or self.contains(lower, strictly=lowerExcl)
        assert upper is None or self.contains(upper, strictly=upperExcl)
        if lower is not None:
            self.lower = lower
            self.lowerExcl = lowerExcl

        assert upper is None or self.contains(upper, strictly=upperExcl)
        if upper is not None:
            self.upper = upper
            self.upperExcl = upperExcl

    def __bool__(self):
        return self != 0

    def __eq__(self, other):
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        return self.__cmp__(other) != 0

    # No other comparator because we don't index ranges
    # (no k-d tree).

    # def __lt__(self, other):
    #     return self.__cmp__(other) < 0

    # def __le__(self, other):
    #     return self.__cmp__(other) <= 0

    # def __gt__(self, other):
    #     return self.__cmp__(other) > 0

    # def __ge__(self, other):
    #     return self.__cmp__(other) >= 0

Depth-first exploration of a function call's support

We assume a None result represents a zero value wrt the aggregate merging function (e.g., 0 for a sum). For convenience, we also treat tuples and lists of Nones identically.

We simply maintain a stack of CMP_HANDLER calls, where each entry in the stack consists of an OpaqueValue and a bitset of CMP_HANDLER results still to explore (-1, 0, or 1). This stack is filled on demand, and CMP_HANDLER returns the first result allowed by the bitset.

Once we have a result, we tweak the stack to force depth-first exploration of a different part of the solution space: we clear the lowest set bit in the top entry's bitset of results to explore, and drop the entry wholesale if the bitset is now empty (all zero), repeating with the new top entry. When this tweaking leaves an empty stack, we're done.

This ends up enumerating all the paths through the function call with a non-recursive depth-first traversal.

We then do the same for each datum in our input sequence, and merge results for identical keys together.
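The exploration scheme itself does not depend on OpaqueValue. The stripped-down sketch below (enumerate_paths and program are hypothetical names) replays a program that calls choose(mask) at each branch, where bit 0 of the mask allows -1, bit 1 allows 0, and bit 2 allows 1, like the bitset handed to CMP_HANDLER. After each run, it clears the lowest set bit of the deepest stack entry and pops exhausted entries, so each path is visited exactly once, depth-first.

```python
def enumerate_paths(program):
    """Hypothetical replay-based DFS.  `program(choose)` calls `choose(mask)`
    at each branch; bit 0 of `mask` allows -1, bit 1 allows 0, bit 2 allows 1.
    Returns the program's result for every path, in depth-first order."""
    stack = []      # remaining choices (bitmask) for each branch on the current path
    results = []
    while True:
        depth = 0

        def choose(mask):
            nonlocal depth
            if depth == len(stack):      # first visit: record the allowed outcomes
                stack.append(mask)
            remaining = stack[depth]
            depth += 1
            lowest = remaining & -remaining   # lowest set bit: 1, 2 or 4
            return lowest.bit_length() - 2    # maps 1 -> -1, 2 -> 0, 4 -> 1

        results.append(program(choose))
        # Backtrack: clear the lowest bit of the deepest entry; pop exhausted ones.
        while stack:
            stack[-1] &= stack[-1] - 1
            if stack[-1]:
                break
            stack.pop()
        if not stack:
            return results


def program(choose):
    a = choose(0b111)                  # all three outcomes possible
    if a == 0:
        return (a, choose(0b101))      # only -1 or 1 possible here
    return (a,)


print(enumerate_paths(program))  # [(-1,), (0, -1), (0, 1), (1,)]
```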

def is_zero_result(value):
    """Checks if `value` is a "zero" aggregate value: either `None`,
    or an iterable of all `None`.

    >>> is_zero_result(None)
    True
    >>> is_zero_result(False)
    False
    >>> is_zero_result(True)
    False
    >>> is_zero_result(0)
    False
    >>> is_zero_result(-10)
    False
    >>> is_zero_result(1.5)
    False
    >>> is_zero_result("")
    False
    >>> is_zero_result("asd")
    False
    >>> is_zero_result((None, None))
    True
    >>> is_zero_result((None, 1))
    False
    >>> is_zero_result([])
    True
    >>> is_zero_result([None])
    True
    >>> is_zero_result([None, (None, None)])
    False
    """
    if value is None:
        return True
    if isinstance(value, (tuple, list)):
        return all(item is None for item in value)
    return False


def enumerate_opaque_values(function, values):
    """Explores the set of `OpaqueValue` constraints when calling
    `function`.

    Enumerates all constraints for the `OpaqueValue` instances in
    `values`, and yields a pair of equality constraints for the
    `value` and the corresponding return value, for all non-zero
    values.

    This essentially turns `function()` into a *not necessarily
    ordered* branching program on `values`.

    >>> x, y = OpaqueValue("x", 0), OpaqueValue("y", 1)
    >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
    [(((0, 0),), 1), (((0, 1), (1, 2)), 2)]

    """
    explorationStack = []  # List of (value, bitmaskOfCmp)
    while True:
        for value in values:
            value.reset()

        stackIndex = 0
        constraints = []

        def handle(value, other):
            nonlocal stackIndex
            if len(explorationStack) == stackIndex:
                explorationStack.append((value, value.potential_mask(other)))

            expectedValue, mask = explorationStack[stackIndex]
            assert value is expectedValue
            assert mask != 0

            if (mask & 1) != 0:
                ret = -1
            elif (mask & 2) != 0:
                constraints.append((value, other))
                ret = 0
            elif (mask & 4) != 0:
                ret = 1
            else:
                assert False, f"bad mask {mask}"

            stackIndex += 1
            return ret

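        # Install our handler for this run and restore the previous one
        # afterwards: a nested `map_reduce` (cache miss) runs its own
        # exploration with its own handler while `function` executes.
        previousHandler = OpaqueValue.CMP_HANDLER
        OpaqueValue.CMP_HANDLER = handle
        try:
            result = function()
        finally:
            OpaqueValue.CMP_HANDLER = previousHandler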
        if not is_zero_result(result):
            for value in values:
                assert (
                    value.definite() or value.indefinite()
                ), f"partially constrained {value} temporarily unsupported"
                assert value.indefinite() or any(key is value for key, _ in constraints)
            yield (tuple((key.index, other) for key, other in constraints),
                   result)

        # Drop everything that was fully explored, then move the next
        # top of stack to the next option.
        while explorationStack:
            value, mask = explorationStack[-1]
            assert 0 <= mask < 8

            mask &= mask - 1  # Drop first bit
            if mask != 0:
                explorationStack[-1] = (value, mask)
                break
            explorationStack.pop()
        if not explorationStack:
            break


def enumerate_supporting_values(function, args):
    """Lists the bag of mappings from closed over values to non-zero
    results, for all calls `function(arg) for arg in args`.

    >>> def count_eql(needle): return lambda x: 1 if x == needle else None
    >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
    [(((0, 1),), 1), (((0, 2),), 1), (((0, 4),), 1), (((0, 4),), 1), (((0, 2),), 1)]
    """
    _, _, rebind, names = extract_function_state(function)
    values = [OpaqueValue(name, index) for index, name in enumerate(names)]
    reboundFunction = rebind(values)
    for arg in args:
        yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)

Type driven merges

The interesting part of map/reduce is the reduction step. While some like to use first-class functions to describe reduction, in my opinion, it often makes more sense to define reduction at the type level: it's essential that merge operators be commutative and associative, so isolating the merge logic in dedicated classes makes sense to me.

This file defines two mergeable value types, Sum and Min, but we could have others, e.g., hyperloglog unique counts, or streaming statistical moments.

class Sum:
    """A counter for summed values."""

    def __init__(self, value=0):
        self.value = value

    def merge(self, other):
        assert isinstance(other, Sum)
        self.value += other.value


class Min:
    """A running `min` value tracker."""

    def __init__(self, value=None):
        self.value = value

    def merge(self, other):
        assert isinstance(other, Min)
        if self.value is None:  # `None` means "no value yet"; 0 is a real minimum
            self.value = other.value
        elif other.value is not None:
            self.value = min(self.value, other.value)
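As an example of the "streaming statistical moments" type mentioned above, here is a sketch of a mergeable count/mean/variance accumulator. It assumes the usual pairwise combination formulas (Chan et al.) for the mean and for M2, the sum of squared deviations. Moments is a hypothetical name; it follows the same merge(self, other) protocol as Sum and Min, with Moments() as the zero, and merging partial results in any order gives the same answer up to floating-point rounding.

```python
import math


class Moments:
    """Hypothetical mergeable accumulator: count, mean, and M2 (the sum of
    squared deviations from the mean).  `Moments()` is the zero value."""

    def __init__(self, value=None):
        if value is None:
            self.count, self.mean, self.m2 = 0, 0.0, 0.0
        else:
            self.count, self.mean, self.m2 = 1, float(value), 0.0

    def merge(self, other):
        assert isinstance(other, Moments)
        if other.count == 0:
            return
        if self.count == 0:
            self.count, self.mean, self.m2 = other.count, other.mean, other.m2
            return
        count = self.count + other.count
        delta = other.mean - self.mean
        # Pairwise update; uses the counts from *before* the merge.
        self.m2 += other.m2 + delta * delta * self.count * other.count / count
        self.mean += delta * other.count / count
        self.count = count

    def variance(self):
        """Population variance (0.0 for fewer than two values)."""
        return self.m2 / self.count if self.count > 1 else 0.0


values = [2, 4, 4, 4, 5, 5, 7, 9]
left, right = Moments(), Moments()
for v in values[:3]:
    left.merge(Moments(v))
for v in values[3:]:
    right.merge(Moments(v))
right.merge(left)   # merge order does not matter (up to rounding)
# Mathematically: 8 values, mean 5, population variance 4.
print(right.count, math.isclose(right.mean, 5.0), math.isclose(right.variance(), 4.0))
```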

Nested dictionary with wildcard

There's a direct relationship between the data structure we use to represent the result of function calls as branching functions, and the constraints we can support on closed over values for non-zero results.

In a real implementation, this data structure would host most of the complexity: it's the closest thing we have to indexes.

For now, we only support equality with a ground value as a constraint. This means we can dispatch on the closed over values by order of appearance, with either a wildcard value (matches everything) or a hash map. At each internal level, the value is another NestedDictLevel. At the leaf, the value is the precomputed value.

class NestedDictLevel:
    """One level in a nested dictionary index.  We may either have a
    value for everything (leaf node), or a key-value dict for a specific
    index in the tuple key.
    """

    def __init__(self, indexValues, depth=0):
        assert depth <= len(indexValues)
        self.keyIndex = indexValues[depth][0] if depth < len(indexValues) else None
        self.value = None
        self.dict = dict()

    def get(self, keys, default):
        """Gets the value for `keys` in this level, or `default` if None."""
        assert self.value is not None or self.keyIndex is not None

        if self.value is not None:
            return self.value
        next = self.dict.get(keys[self.keyIndex], None)
        if next is None:
            return default
        return next.get(keys, default)

    def set(self, indexValues, mergeFunction, depth=0):
        """Sets the value for `keys` in this level."""
        assert depth <= len(indexValues)
        if depth == len(indexValues):  # Leaf
            self.value = mergeFunction(self.value)
            return

        index, value = indexValues[depth]
        assert self.keyIndex == index
        if value not in self.dict:
            self.dict[value] = NestedDictLevel(indexValues, depth + 1)
        self.dict[value].set(indexValues, mergeFunction, depth + 1)


class NestedDict:
    """A nested dict of a given `depth` maps tuples of `depth` keys to
    a value.  Each `NestedDictLevel` handles a different level.  We
    don't test each index in the tuple in fixed order to avoid the
    combinatorial explosion that can happen when we change the ordering
    (e.g., conversion from BDD to OBDD).
    """

    def __init__(self, length):
        self.top = None
        self.length = length

    def get(self, keys, default=None):
        """Gets the value associated with `keys`, or `default` if None."""
        assert len(keys) == self.length

        if self.top is None:
            return default
        return self.top.get(keys, default)

    def set(self, indexKeyValues, mergeFn):
        """Sets the value associated with `((index, key), ...)`."""
        assert all(0 <= index < self.length for index, _ in indexKeyValues)

        if self.top is None:
            self.top = NestedDictLevel(indexKeyValues)
        self.top.set(indexKeyValues, mergeFn)

Identity key-value maps

class IdMap:
    def __init__(self):
        self.entries = dict()  # tuple of id -> (key, value)
        # the value's first element keeps the ids stable.

    def get(self, keys, default=None):
        ids = tuple(id(key) for key in keys)
        return self.entries.get(ids, (None, default))[1]

    def __contains__(self, keys):
        ids = tuple(id(key) for key in keys)
        return ids in self.entries

    def __getitem__(self, keys):
        ids = tuple(id(key) for key in keys)
        return self.entries[ids][1]

    def __setitem__(self, keys, value):
        ids = tuple(id(key) for key in keys)
        self.entries[ids] = (keys, value)
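IdMap keys its entries on object identity because the input sequences used in cache keys are typically lists, which are unhashable. Storing the keys next to the value is what keeps the ids meaningful: the id() documentation notes that objects with non-overlapping lifetimes may share an id() value, so a freed input could otherwise alias a new one. The plain-dict sketch below also shows the flip side: an equal but distinct copy of an input misses the cache. This is why, in the timings above, `id_skus = id_skus * 1000` builds fresh lists that start with a cold cache.

```python
data = [(1, 2), (2, 2), (1, 3)]
try:
    hash(data)
except TypeError as error:
    print("unhashable:", error)          # lists cannot be ordinary dict keys

cache = {}
# Keep a reference to `data` next to the value so its id() cannot be reused.
cache[(id(data),)] = (data, "precomputed result")

copy = list(data)                        # equal contents, different object
print(data == copy)                                 # True
print((id(data),) in cache, (id(copy),) in cache)   # True False
```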

Cached map_reduce

As mentioned earlier, we assume reduce is determined implicitly by the reduced values' type. We also have enumerate_supporting_values to find all the closed over values that yield a non-zero result, for all values in a sequence.

We can thus accept a function and an input sequence, find the supporting values, and merge the result associated with identical supporting values.

Again, we only support ground equality constraints (see the "partially constrained" assertion in enumerate_opaque_values), i.e., only equijoins. There's nothing that stops a more sophisticated implementation from using range trees to support inequality or range joins.

We'll cache the precomputed values by code object (i.e., function without closed over values) and input sequence. If we don't have a precomputed value, we'll use enumerate_supporting_values to run the function backward for each input datum from the sequence, and accumulate the results in a NestedDict. Working backward to find closure values that yield a non-zero result (for each input datum) lets us precompute a branching program that directly yields the result. We represent these branching programs explicitly, so we can also directly update a branching program for the result of merging all the values returned by mapping over the input sequence, for a given closure.

This last map_reduce definition ties everything together, and is, I think, really the general heart of Yannakakis's algorithm as an instance of bottom-up dynamic programming.
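For the introductory query, the bottom-up plan this machinery builds amounts to the hand-written sketch below (min_cost_by_sku and sum_by_parity are hypothetical names). One pass over sku_costs builds the inner aggregate for every possible sku, and one pass over id_skus builds the outer aggregate for every possible mod_two, so each later query is a dictionary lookup. Unlike the original query, which would fail on a sku with no cost, this sketch simply skips such rows.

```python
id_skus = [(1, 2), (2, 2), (1, 3)]
sku_costs = [(1, 10), (2, 20), (3, 30)]

# Inner map_reduce, keyed by its closed over value `sku`: one pass over sku_costs.
min_cost_by_sku = {}
for sku, cost in sku_costs:
    current = min_cost_by_sku.get(sku)
    min_cost_by_sku[sku] = cost if current is None else min(current, cost)

# Outer map_reduce, keyed by its closed over value `mod_two`: one pass over id_skus.
sum_by_parity = {}
for row_id, sku in id_skus:
    if sku in min_cost_by_sku:   # skip rows whose sku has no cost
        parity = row_id % 2      # the `id % 2 == mod_two` equality constraint
        sum_by_parity[parity] = sum_by_parity.get(parity, 0) + min_cost_by_sku[sku]

print(sum_by_parity.get(0), sum_by_parity.get(1))  # 20 50
```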

def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    representation for `reduce(map(function, inputIterable))`, where
    the reduction step simply calls `merge` on the return values
    (tuples are merged elementwise), and the `NestedDict` keys
    represent closed over values.

    >>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
    >>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
    >>> nd.get((0,))
    >>> nd.get((1,)).value
    1
    >>> nd.get((2,)).value
    2
    >>> nd.get((4,)).value
    2
    """

    def merge(dst, update):
        if dst is None:
            return update

        if isinstance(dst, (tuple, list)):
            assert len(dst) == len(update)
            for value, new in zip(dst, update):
                value.merge(new)
        else:
            dst.merge(update)
        return dst

    cache = NestedDict(depth)
    for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
        cache.set(indexKeyValues, lambda old: merge(old, result))
    return cache


AGGREGATE_CACHE = IdMap()  # Map from function, input sequence -> NestedDict


def map_reduce(function, inputIterable):
    """Returns the result of merging `map(function, inputIterable)`.

    `None` return values represent neutral elements (i.e., the result
    of mapping an empty `inputIterable`), and values are otherwise
    reduced by calling `merge` on a mutable accumulator.

    Assuming `function` is well-behaved, `map_reduce` runs in time
    linear wrt `len(inputIterable)`.  It's also always cached on a
    composite key that consists of the `function`'s code object (i.e.,
    without closed over values) and the `inputIterable`.

    These complexity guarantees let us nest `map_reduce` with
    different closed over values, and still guarantee a linear-time
    total complexity.

    This wrapper ties together all the components:

    >>> INVOCATION_COUNTER = 0
    >>> data = (1, 2, 2, 4, 2, 4)
    >>> def count_eql(needle):
    ...     def count(x):
    ...         global INVOCATION_COUNTER
    ...         INVOCATION_COUNTER += 1
    ...         return Sum(x) if x == needle else None
    ...     return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data).value
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ...     return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    >>> def sum_odd_or_even_skus(mod_two):
    ...     def count_if_mod_two(id_sku):
    ...         id, sku = id_sku
    ...         if id % 2 == mod_two:
    ...             return Sum(sku_min_cost(sku))
    ...     return map_reduce(count_if_mod_two, id_skus)
    >>> sum_odd_or_even_skus(0).value
    20
    >>> sum_odd_or_even_skus(1).value
    50
    """
    assert isinstance(inputIterable, collections.abc.Iterable)
    assert not isinstance(inputIterable, collections.abc.Iterator)
    code, closure, *_ = extract_function_state(function)
    if (code, inputIterable) not in AGGREGATE_CACHE:
        AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
            function, len(closure), inputIterable
        )
    return AGGREGATE_CACHE[code, inputIterable].get(closure, None)

 
if __name__ == "__main__":
    import doctest

    doctest.testmod()

Is this actually a DB post?

Although the intro name-dropped Yannakakis, the presentation here has a very programming language / logic programming flavour. I think the logic programming point of view, where we run a program backwards with logical variables, is much clearer than the specific case of conjunctive equijoin queries in the usual presentation of Yannakakis's algorithm. In particular, I think there's a clear path to handle range or comparison joins: it's all about having an index data structure to handle range queries.

It should be clear how to write conjunctive queries (CQs) as Python functions, given a hypertree decomposition. The reverse is much more complex, if only because Python is much more expressive than CQs, and that's actually a liability: this hack will blindly try to convert any function into a branching program instead of failing loudly when the function is too complex.

The other difference from classical CQs is that we focus on aggregates. That's because aggregates are the more general form: if we just want to avoid useless work while enumerating all join rows, we only need a boolean aggregate that tells us whether the join will yield at least one row. We could also special-case types for which merges don't save space (e.g., sets of row ids), and instead enumerate values by walking the branching program tree.
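
A boolean aggregate of that kind fits the same value/merge shape as Sum and Min. The sketch below is hypothetical, not part of the file: Exists is an invented class name, and naive_map_reduce is a plain, uncached fold standing in for map_reduce, only there to make the merge semantics concrete (None is the neutral element, logical OR is the merge).

```python
class Exists:
    """Hypothetical boolean aggregate with the same `value` / `merge`
    shape as `Sum` and `Min`: does at least one row support the result?"""

    def __init__(self, value=True):
        self.value = value

    def merge(self, other):
        assert isinstance(other, Exists)
        # Logical OR is associative and commutative; a `None` result
        # (no row at all) plays the role of the neutral element.
        self.value = self.value or other.value


def naive_map_reduce(function, rows):
    """Uncached stand-in for `map_reduce`: calls `function` on each row
    with ground values and merges the non-`None` results."""
    acc = None
    for row in rows:
        result = function(row)
        if result is None:
            continue
        if acc is None:
            acc = result
        else:
            acc.merge(result)
    return acc


sku_costs = [(1, 10), (2, 20), (3, 30)]


def has_sku(sku):
    return lambda sku_cost: Exists() if sku_cost[0] == sku else None


print(naive_map_reduce(has_sku(2), sku_costs).value)  # True
print(naive_map_reduce(has_sku(7), sku_costs))        # None: no supporting row
```

In the real map_reduce, Exists would presumably be returned wherever the doctests return Sum(...) or Min(...); caching and reverse evaluation don't depend on which aggregate is merged.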

The aggregate viewpoint also works for fun extensions like indexed access to ranked results: that extension ends up counting the number of output values up to a certain key.

I guess, in a way, we just showed a trivial way to decorrelate queries with a hypertree width of 1. We just have to be OK with building one index for each loop in the nest... but it should be possible to pattern match on predefined indexes and avoid obvious redundancy.

Extensions and future work

Use a dedicated DSL

First, the whole idea of introspecting closures to stub in logical variables is a terrible hack (looks cool though ;). A real production implementation should apply CPS partial evaluation to a purely functional programming language, then bulk reverse-evaluate with a SIMD implementation of the logical program.

There'll be restrictions on the output traces, but that's OK: a different prototype makes me believe the restrictions correspond to deterministic logspace (L), and it makes sense to restrict our analyses to L. Just like grammars are easier to work with when restricted to LL(1), DSLs that only capture L tend to be easier to analyse and optimise... and L is still reasonably large (proving that some polynomial-time problem lies outside L would be a huge result).

Handle local functions

While we're sinning with the closure hack (extract_function_state), we should really extend it to cover local functions. This is mostly a matter of recursing into closed-over values that are themselves functions, and of maintaining an id-keyed map from cell to OpaqueValue.
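
A minimal sketch of the first half of that extension, assuming CPython function objects: collect_cells (a hypothetical helper, not in the file) walks a function's closure, recurses into closed-over values that are themselves plain Python functions, and keys the cells it finds by id(). That way, a cell shared by several local functions, or a recursive local function's own cell, is visited only once. Allocating one OpaqueValue per cell and rebuilding the nested closures around those values is left out.

```python
import types


def collect_cells(function, seen=None):
    """Returns {id(cell): (name, cell)} for every closure cell reachable
    from `function`, recursing into closed-over plain functions."""
    if seen is None:
        seen = {}
    cells = function.__closure__ or ()  # None for toplevel functions
    for name, cell in zip(function.__code__.co_freevars, cells):
        if id(cell) in seen:
            continue  # shared cell, or a recursive local function
        seen[id(cell)] = (name, cell)
        try:
            contents = cell.cell_contents
        except ValueError:  # cell not bound yet
            continue
        if isinstance(contents, types.FunctionType):
            collect_cells(contents, seen)
    return seen


def make_query(needle, bonus):
    def matches(x):
        return x == needle

    def score(x):
        return bonus if matches(x) else None

    return score


cells = collect_cells(make_query(4, 10))
print(sorted(name for name, _ in cells.values()))  # ['bonus', 'matches', 'needle']
```

The reported names are the candidates for logical variables. The id-keyed dict is what would let two local functions that close over the same cell see the same OpaqueValue.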

We could also add support for partial application objects (e.g., functools.partial), which may be easier to use with multiprocessing.
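
Python's standard partial application objects, functools.partial instances, expose their wrapped function and bound arguments as the func, args and keywords attributes, so no closure introspection is needed. They can also be pickled when the wrapped function is defined at module level and the bound arguments are picklable, which closures cannot. That is presumably what makes them friendlier to multiprocessing. The helper below is a hypothetical sketch that loosely mirrors extract_function_state's (key, values, rebind) convention.

```python
import functools


def extract_partial_state(p):
    """Hypothetical analogue of `extract_function_state` for a
    `functools.partial`: returns a cache key that ignores the bound
    argument values, the tuple of bound values, and a `rebind` function."""
    n_args = len(p.args)
    kw_names = tuple(sorted(p.keywords))
    key = (p.func, n_args, kw_names)
    values = (*p.args, *(p.keywords[name] for name in kw_names))

    def rebind(new_values):
        new_values = tuple(new_values)
        if len(new_values) != len(values):
            raise ValueError(f"expected {len(values)} values, got {new_values}")
        kwargs = dict(zip(kw_names, new_values[n_args:]))
        return functools.partial(p.func, *new_values[:n_args], **kwargs)

    return key, values, rebind


def count_eql(needle, x):  # module-level, so partials of it can be pickled
    return 1 if x == needle else None


key, values, rebind = extract_partial_state(functools.partial(count_eql, 4))
print(values)                          # (4,)
print(rebind([5])(5), rebind([5])(4))  # 1 None
```

As with code objects, the key ignores the bound values, so partial(count_eql, 4) and partial(count_eql, 9) would share a cache entry.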

Parallelism

There is currently no support for parallelism, only caching. It should be easy to handle the return values (NestedDicts and aggregate classes like Sum or Min). Distributing the work in _precompute_map_reduce, with each worker merging its share locally, is also not hard.

The main issue with parallelism is that we can't send closures to worker processes as work units (pickle can't serialise them), so we'd have to stick to a fork-based process pool.
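
Merges are associative and commutative, so merging each chunk of the input locally and then merging the per-chunk tables gives the same result as a single pass. The sketch below shows that structure under simplifying assumptions. support(row) is a hypothetical stand-in for enumerate_supporting_values that returns one (key, aggregate) pair or None per row, and plain dicts stand in for NestedDict. It uses a thread pool only so the sketch can hand a lambda to workers. Actual CPU parallelism would need something like the fork-based process pool mentioned above, since threads in the default CPython build contend for the GIL.

```python
from concurrent.futures import ThreadPoolExecutor


class Sum:
    """Same `value` / `merge` shape as the file's `Sum`."""

    def __init__(self, value=0):
        self.value = value

    def merge(self, other):
        self.value += other.value


def merge_tables(dst, src):
    """Merges the keyed aggregates of `src` into `dst`, in place."""
    for key, aggregate in src.items():
        if key in dst:
            dst[key].merge(aggregate)
        else:
            dst[key] = aggregate
    return dst


def build_local(support, chunk):
    """Worker step: one chunk of the input, merged locally by key."""
    table = {}
    for row in chunk:
        entry = support(row)  # (key, aggregate) or None
        if entry is not None:
            merge_tables(table, {entry[0]: entry[1]})
    return table


def build_chunked(support, rows, workers=4):
    size = max(1, -(-len(rows) // workers))  # ceiling division
    chunks = [rows[i:i + size] for i in range(0, len(rows), size)]
    table = {}
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Any merge order works: merges are associative and commutative.
        for local in pool.map(lambda chunk: build_local(support, chunk), chunks):
            merge_tables(table, local)
    return table


data = (1, 2, 2, 4, 2, 4)
table = build_chunked(lambda x: (x, Sum(x)), data, workers=3)
print({key: agg.value for key, agg in sorted(table.items())})  # {1: 1, 2: 6, 4: 8}
```

The per-key totals agree with the count_eql doctest in map_reduce (8 for 4, 6 for 2), whatever the number of workers.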

There's also no support for making progress on (child) work while blocked waiting on a future. We'd have to spawn workers on the fly to oversubscribe when workers are blocked on a result (spawning on demand is already a given for fork workers), implement our own concurrency control to avoid wasted work, and probably add internal throttling to avoid thrashing when there are more active threads than cores.

That being said, the complexity is probably worth the speedup on realistic queries.

Theta joins

At a higher level, we could support comparison joins (e.g., less than, greater than or equal, in range) if only we represented the branching programs with a data structure that supported these queries. A k-d tree would let us handle these "theta" joins, and a range tree would do so for the low, low cost of a polylogarithmic multiplicative factor in space and time.
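
The one-dimensional case already shows the idea: sort once, then answer each comparison constraint on a closed-over value with a binary search instead of a hash lookup. The sketch below is an assumption, not the file's design. SortedSumIndex is a hypothetical class built on bisect and prefix sums that answers "sum of values whose key is below a threshold" in O(log n) after an O(n log n) build. sum_between subtracts two prefix sums, which relies on addition having an inverse. An aggregate like Min would need something like a segment tree, and constraints on several closed-over values at once are where k-d trees or range trees come in.

```python
import bisect
from itertools import accumulate


class SortedSumIndex:
    """Hypothetical 1-D ordered index over (key, value) rows."""

    def __init__(self, rows):
        rows = sorted(rows)
        self.keys = [key for key, _ in rows]
        # prefix[i] is the sum of the first i values in key order.
        self.prefix = [0, *accumulate(value for _, value in rows)]

    def sum_below(self, limit, inclusive=False):
        """Sum of values with key < limit (or <= limit if inclusive)."""
        find = bisect.bisect_right if inclusive else bisect.bisect_left
        return self.prefix[find(self.keys, limit)]

    def sum_between(self, lo, hi):
        """Sum of values with lo <= key < hi."""
        return self.sum_below(hi) - self.sum_below(lo)


# Theta join: for each order, total cost of the skus strictly below its max_sku.
sku_costs = [(1, 10), (2, 20), (3, 30)]
orders = [(100, 2), (101, 3)]  # (order id, max_sku)
index = SortedSumIndex(sku_costs)  # built once, shared by every order
print(sum(index.sum_below(max_sku) for _, max_sku in orders))  # 10 + 30 = 40
```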

Self-adjusting computation

Finally, we could update the indexed branching programs incrementally after small changes to the input data. This might sound like a job for streaming engines like timely dataflow, but I think viewing each _precompute_map_reduce call as a purely functional map/reduce job gives a better fit with self-adjusting computation.
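
For aggregates with an inverse, the simplest form of incremental maintenance is to add an inserted row's contribution and subtract a deleted row's. The sketch below is far simpler than self-adjusting computation proper and is not the file's design. IncrementalIndex is a hypothetical class, support(row) returns a (key, number) pair or None, and a per-key row count distinguishes "no supporting rows" (reported as None, like the file's zero value) from a sum that happens to be zero. It would not work as-is for Min, which has no inverse.

```python
class IncrementalIndex:
    """Hypothetical key -> summed value index, maintained under row
    insertions and deletions."""

    def __init__(self, support, rows=()):
        self.support = support
        self.table = {}  # key -> (number of supporting rows, summed value)
        for row in rows:
            self.insert(row)

    def _apply(self, row, sign):
        entry = self.support(row)
        if entry is None:
            return  # row contributes nothing
        key, amount = entry
        count, total = self.table.get(key, (0, 0))
        count, total = count + sign, total + sign * amount
        assert count >= 0, f"deleted a row that was never inserted: {row}"
        if count == 0:
            self.table.pop(key, None)  # no supporting rows left
        else:
            self.table[key] = (count, total)

    def insert(self, row):
        self._apply(row, +1)

    def delete(self, row):
        self._apply(row, -1)

    def get(self, key, default=None):
        entry = self.table.get(key)
        return default if entry is None else entry[1]


index = IncrementalIndex(lambda sku_cost: sku_cost, [(1, 10), (2, 20), (3, 30)])
index.insert((2, 5))
index.delete((2, 20))
print(index.get(2), index.get(4))  # 5 None
```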

Once we add logic to recycle previously constructed indexes, it will probably make sense to allow an initial filtering step before map/reduce, with a cache key on the filter function (with closed over values and all). We can often implement the filtering more efficiently than we can run functions backward, and we'll also observe that slightly different filter functions often result in not too dissimilar filtered sets. Factoring out this filtering can thus enable more reuse of partial precomputed results.

#!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|
##{license, imports
# Yannakakis.py by Paul Khuong
#
# To the extent possible under law, the person who associated CC0 with
# Yannakakis.py has waived all copyright and related or neighboring rights
# to Yannakakis.py.
#
# You should have received a copy of the CC0 legalcode along with this
# work. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.
import collections.abc
##}
## # Linear-time analytical queries in plain Python
##
## <small>I didn't know [Mihalis Yannakakis is about to turn 70](https://mihalisfest.cs.columbia.edu), but that's a fun coïncidence.</small>
##
## This short hack shows how [Yannakakis's algorithm](https://pages.cs.wisc.edu/~paris/cs784-f19/lectures/lecture4.pdf)
## lets us implement linear-time (wrt the number of input data rows)
## analytical queries in regular programming languages, without
## offloading joins to a specialised database query language, thus
## avoiding the associated impedance mismatch. There are restrictions
## on the queries we can express -- Yannakakis's algorithm relies on a
## hypertree width of 1, and on having hypertree decomposition as a witness
## -- but that's kind of reasonable: a fractional hypertree width > 1
## would mean there are databases for which the intermediate results
## could superlinearly larger than the input database due to the [AGM bound](https://arxiv.org/abs/1711.03860).
## The hypertree decomposition witness isn't a burden either:
## structured programs naturally yield a hypertree decomposition,
## unlike more declarative logic programs that tend to hide the
## structure implicit in the programmer's thinking.
##
## The key is to mark function arguments as either negative (true
## inputs) or positive (possible interesting input values derived from
## negative arguments). In this hack, closed over values are
## positive, and the only negative argument is the current data row.
##
## We also assume that these functions are always used in a map/reduce
## pattern, and thus we only memoise the result of
## `map_reduce(function, input)`, with a group-structured reduction:
## the reduce function must be associative and commutative, and there
## must be a zero (neutral) value.
##
## With these constraints, we can express joins in natural Python
## without incurring the poor runtime scaling of the nested loops we
## actually wrote. This Python file describes the building blocks
## to handle aggregation queries like the following
##
## >>> id_skus = [(1, 2), (2, 2), (1, 3)]
## >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
## >>> def sku_min_cost(sku):
## ... return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
## ...
## >>> def sum_odd_or_even_skus(mod_two):
## ... def count_if_mod_two(id_sku):
## ... id, sku = id_sku
## ... if id % 2 == mod_two:
## ... return Sum(sku_min_cost(sku))
## ... return map_reduce(count_if_mod_two, id_skus)
## ...
##
## with linear scaling in the length of `id_skus` and `sku_costs`, and
## caching for similar queries.
##
## At a small scale, everything's fast.
##
## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
## 20
## 0.001024007797241211
## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
## 50
## 4.982948303222656e-05
##
## As we increase the scale by 1000x in both tuples, the runtime
## scales (sub) linearly for the first query, and is unchanged
## for the second:
##
## >>> id_skus = id_skus * 1000
## >>> sku_costs = sku_costs * 1000
## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
## 20000
## 0.09638118743896484
## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
## 50000
## 8.797645568847656e-05
##
## This still kind of holds when we multiply by another factor of 100:
##
## >>> id_skus = id_skus * 100
## >>> sku_costs = sku_costs * 100
## >>> begin = time.time(); print(sum_odd_or_even_skus(0).value); print(time.time() - begin)
## 2000000
## 5.9715189933776855
## >>> begin = time.time(); print(sum_odd_or_even_skus(1).value); print(time.time() - begin)
## 5000000
## 0.00021195411682128906
##
## The magic behind the curtains is memoisation (unsurprisingly), but
## a special implementation that can share work for similar closures:
## the memoisation key consists of the function *without closed over
## bindings* and the call arguments, while the memoised value is a
## data structure from the tuple of closed over values to the
## `map_reduce` output.
##
## This concrete representation of the function as a branching program
## is the core of Yannakakis's algorithm: we'll iterate over each
## datum in the input, run the function on it with logical variables
## instead of the closed over values, and generate a mapping from
## closed over values to result for all non-zero results. We'll then
## merge the mappings for all input data together (there is no natural
## ordering here, hence the group structure).
##
## The output data structure controls the join we can implement. We
## show how a simple stack of nested key-value mappings handles
## equijoins, but a [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007)
## would handle inequalities, i.e., "theta" joins (the rest of the
## machinery already works in terms of less than/greater than
## constraints).
##
## As long as we explore a bounded number of paths for each datum and
## a bounded number of `function, input` cache keys, we'll spend a
## bounded amount of time on each input datum, and thus linear time
## total. The magic of Yannakakis's algorithm is that this works even
## when there are nested `map_reduce` calls, which would naïvely
## result in polynomial time (degree equal to the nesting depth).
##{<h2>Memoising through a Python function's closed over values</h2>
##
## Even if functions were hashable and comparable for extensional
## equality, directly using closures as memoisation keys in calls like
## `map_reduce(function, input)` would result in superlinear runtime
## for nested `map_reduce` calls.
##
## This `extract_function_state` accepts a function (with or without
## closed over state), and returns four values:
## 1. The underlying code object
## 2. The tuple of closed over values (current value for mutable cells)
## 3. A function to rebind the closure with new closed over values
## 4. The name of the closed over bindings
##
## The third return value, the `rebind` function, silently fails on
## complicated cases; this is a hack, after all. In short, it only
## handles closing over immutable atomic values like integers or
## strings, but not, e.g., functions (yet), or mutable bindings.
def extract_function_state(function):
"""Accepts a function object and returns information about it: a
hash key for the object, a tuple of closed over values, a function
to return a fresh closure with a different tuple of closed over
values, and a closure of closed over names
>>> def test(x): return lambda y: x == y
>>> extract_function_state(test)[1]
()
>>> extract_function_state(test)[3]
()
>>> fun = test(4)
>>> extract_function_state(fun)[1]
(4,)
>>> extract_function_state(fun)[3]
('x',)
>>>
>>> fun(4)
True
>>> fun(5)
False
>>> rebind = extract_function_state(fun)[2]
>>> rebound_4, rebound_5 = rebind([4]), rebind([5])
>>> rebound_4(4)
True
>>> rebound_4(5)
False
>>> rebound_5(4)
False
>>> rebound_5(5)
True
"""
code = function.__code__
names = code.co_freevars
if function.__closure__ is None: # Toplevel function
assert names == ()
def rebind(values):
if len(values) != 0:
raise RuntimeError(
f"Values must be empty for toplevel function. values={values}"
)
return function
return code, (), rebind, names
closure = tuple(cell.cell_contents for cell in function.__closure__)
assert len(names) == len(closure), (closures, names)
# TODO: rebind recursively (functions are also cells)
def rebind(values):
if len(values) != len(names):
raise RuntimeError(
f"Values must match names. names={names} values={values}"
)
return function.__class__(
code,
function.__globals__,
function.__name__,
function.__defaults__,
tuple(
cell.__class__(value)
for cell, value in zip(function.__closure__, values)
),
)
return code, closure, rebind, names
##}
## ## Logical variables for closed-over values
##
## We wish to enumerate the support of a function call (parameterised
## over closed over values), and the associated result value. We'll
## do that by rebinding the closure to point at instances of
## `OpaqueValue` and enumerating all the possible constraints on these
## `OpaqueValue`s. These `OpaqueValue`s work like logical variables
## that let us run a function in reverse: when get a non-zero
## (non-None) return value, we look at the accumulated constraint set
## on the opaque values and use them to update the data representation
## of the function's result (we assume that we can represent the
## constraints on all `OpaqueValue`s in our result data structure).
##
## Currently, we only support nested dictionaries, so each
## `OpaqueValue`s must be either fully unconstrained (wildcard that
## matches any value), or constrained to be exactly equal to a value.
## There's no reason we can't use k-d trees though, and it's not
## harder to track a pair of bounds (lower and upper) than a set of
## inequalities, so we'll handle the general ordered `OpaqueValue`
## case.
##
## In the input program (the query), we assume closed over values are
## only used for comparisons (equality, inequality, relational
## operators, or conversion to bool, i.e., non-zero testing). Knowing
## the result of each (non-redundant) comparison tightens the range
## of potential values for the `OpaqueValue`... eventually down to a
## single point value that our hash-based indexes can handle.
##
## Of course, if a comparison isn't redundant, there are multiple
## feasible results, so we need an external oracle to pick one. An
## external caller is responsible for injecting its logic as
## `OpaqueValue.CMP_HANDLER`, and driving the exploration of the
## search space.
##
## N.B., the set of constraints we can handle is determined by the
## ground data structure to represent finitely supported functions.
class OpaqueValue:
"""An opaque value is a one-dimensional range of Python values,
represented as a lower and an upper bound, each of which is
potentially exclusive.
`OpaqueValue`s are only used in queries for comparisons with
ground values. All comparisons are turned into three-way
`__cmp__` calls; non-redundant `__cmp__` calls (which could return
more than one value) are resolved by calling `CMP_HANDLER`
and tightening the bound in response.
>>> x = OpaqueValue("x")
>>> x == True
True
>>> 1 if x else 2
1
>>> x.reset()
>>> ### Not supported by our index data structure (yet)
>>> # >>> x > 4
>>> # False
>>> # >>> x < 4
>>> # False
>>> # >>> x == 4
>>> # True
>>> # >>> x < 10
>>> # True
>>> # >>> x >= 10
>>> # False
>>> # >>> x > -10
>>> # True
>>> # >>> x <= -10
>>> # False
"""
# Resolving function for opaque values
CMP_HANDLER = lambda opaque, value: 0
def __init__(self, name, index=0):
self.name = name
self.index = index
# Upper and lower bounds
self.lower = self.upper = None
self.lowerExcl = self.upperExcl = True
def __str__(self):
left = "(" if self.lowerExcl else "["
right = ")" if self.upperExcl else "]"
return f"<OpaqueValue {self.name} {left}{self.lower}, {self.upper}{right}>"
def reset(self):
"""Clears all accumulated constraints on this `OpaqueValue`."""
self.lower = self.upper = None
self.lowerExcl = self.upperExcl = True
def indefinite(self):
"""Returns whether this `OpaqueValue` is still unconstrained."""
return self.lower == None and self.upper == None
def definite(self):
"""Returns whether is `OpaqueValue` is constrained to an exact value."""
return (
self.lower == self.upper
and self.lower is not None
and not self.lowerExcl
and not self.upperExcl
)
def value(self):
"""Returns the exact value for this `OpaqueValue`, assuming there is one."""
return self.lower
def _contains(self, value, strictly):
"""Returns whether this `OpaqueValue`'s range includes `value`
(maybe `strictly` inside the range).
The first value is the containment truth value, and the second
is the forced `__cmp__` value, if the `value` is *not*
(strictly) contained in the range.
"""
if self.lower is not None:
if self.lower > value:
return False, 1
if self.lower == value and (strictly or self.lowerExcl):
return False, 1
if self.upper is not None:
if self.upper < value:
return False, -1
if self.upper == value and (strictly or self.upperExcl):
return False, -1
return True, 0 if self.definite() else None
def contains(self, value, strictly=False):
"""Returns whether `values` is `strictly` contained in the
`OpaqueValue`'s range.
"""
try:
return self._contains(value, strictly)[0]
except TypeError:
return False
def potential_mask(self, other):
"""Returns the set of potential `__cmp__` values for `other`
that are compatible with the current range: bit 0 is set if
`-1` is possible, bit 1 if `0` is possible, and bit 2 if `1`
is possible.
"""
if not self.contains(other):
return 0
if self.contains(other, strictly=True):
return 7
if self.definite() and self.value == other:
return 2
# We have a non-strict inclusion, and inequality.
if self.lower == other and self.lowerExcl:
return 6
assert self.upper == other and self.upperExcl
return 3
def __cmp__(self, other):
"""Three-way comparison between this `OpaqueValue` and `other`.
When the result is known from the current bound, we just return
that value. Otherwise, we ask `CMP_HANDLER` what value to return
and update the bound accordingly.
"""
if isinstance(other, OpaqueValue) and self.definite() and not other.definite():
# If we have a definite value and `other` is an indefinite
# `OpaqueValue`, flip the comparison order to let the `other`
# argument be a ground value.
return -other.__cmp__(self.value())
if isinstance(other, OpaqueValue) and not other.definite():
raise RuntimeError(
f"OpaqueValue may only be compared with ground values. self={self} other={other}"
)
if isinstance(other, OpaqueValue):
other = other.value() # Make sure `other` is a ground value
if other is None:
# We use `None` internally, and it doesn't compare well
raise RuntimeError("OpaqueValue may not be compared with None")
compatible, order = self._contains(other, False)
if order is not None:
return order
order = OpaqueValue.CMP_HANDLER(self, other)
if order < 0:
self._add_bound(upper=other, upperExcl=True)
elif order == 0:
self._add_bound(lower=other, lowerExcl=False, upper=other, upperExcl=False)
else:
self._add_bound(lower=other, lowerExcl=True)
return order
def _add_bound(self, lower=None, lowerExcl=False, upper=None, upperExcl=False):
"""Updates the internal range for this new bound."""
assert lower is None or self.contains(lower, strictly=lowerExcl)
assert upper is None or self.contains(upper, strictly=upperExcl)
if lower is not None:
self.lower = lower
self.lowerExcl = lowerExcl
assert upper is None or self.contains(upper, strictly=upperExcl)
if upper is not None:
self.upper = upper
self.upperExcl = upperExcl
def __bool__(self):
return self != 0
def __eq__(self, other):
return self.__cmp__(other) == 0
def __ne__(self, other):
return self.__cmp__(other) != 0
# No other comparator because we don't index ranges
# (no k-d tree).
# def __lt__(self, other):
# return self.__cmp__(other) < 0
# def __le__(self, other):
# return self.__cmp__(other) <= 0
# def __gt__(self, other):
# return self.__cmp__(other) > 0
# def __ge__(self, other):
# return self.__cmp__(other) >= 0
## ## Depth-first exploration of a function call's support
##
## We assume a `None` result represents a zero value wrt the aggregate
## merging function (e.g., 0 for a sum). For convenience, we also treat
## tuples and lists of `None`s identically.
##
## We simply maintain a stack of `CMP_HANDLER` calls, where each entry
## in the stack consists of an `OpaqueValue` and bitset of `CMP_HANDLER`
## results still to explore (-1, 0, or 1). This stack is filled on demand,
## and `CMP_HANDLER` returns the first result allowed by the bitset.
##
## Once we have a result, we tweak the stack to force depth-first
## exploration of a different part of the solution space: we drop
## the first bit in the bitset of results to explore, and drop the
## entry wholesale if the bitset is now empty (all zero). When
## this tweaking leaves an empty stack, we're done.
##
## This ends up enumerating all the paths through the function call
## with a non-recursive depth-first traversal.
##
## We then do the same for each datum in our input sequence, and merge
## results for identical keys together.
def is_zero_result(value):
"""Checks if `value` is a "zero" aggregate value: either `None`,
or an iterable of all `None`.
>>> is_zero_result(None)
True
>>> is_zero_result(False)
False
>>> is_zero_result(True)
False
>>> is_zero_result(0)
False
>>> is_zero_result(-10)
False
>>> is_zero_result(1.5)
False
>>> is_zero_result("")
False
>>> is_zero_result("asd")
False
>>> is_zero_result((None, None))
True
>>> is_zero_result((None, 1))
False
>>> is_zero_result([])
True
>>> is_zero_result([None])
True
>>> is_zero_result([None, (None, None)])
False
"""
if value is None:
return True
if isinstance(value, (tuple, list)):
return all(item is None for item in value)
return False
def enumerate_opaque_values(function, values):
"""Explores the set of `OpaqueValue` constraints when calling
`function`.
Enumerates all constraints for the `OpaqueValue` instances in
`values`, and yields a pair of equality constraints for the
`value` and the corresponding return value, for all non-zero
values.
This essentially turns `function()` into a *not necessarily
ordered* branching program on `values`.
>>> x, y = OpaqueValue("x", 0), OpaqueValue("y", 1)
>>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
[(((0, 0),), 1), (((0, 1), (1, 2)), 2)]
"""
explorationStack = [] # List of (value, bitmaskOfCmp)
while True:
for value in values:
value.reset()
stackIndex = 0
constraints = []
def handle(value, other):
nonlocal stackIndex
if len(explorationStack) == stackIndex:
explorationStack.append((value, value.potential_mask(other)))
expectedValue, mask = explorationStack[stackIndex]
assert value is expectedValue
assert mask != 0
if (mask & 1) != 0:
ret = -1
elif (mask & 2) != 0:
constraints.append((value, other))
ret = 0
elif (mask & 4) != 0:
ret = 1
else:
assert False, f"bad mask {mask}"
stackIndex += 1
return ret
OpaqueValue.CMP_HANDLER = handle
result = function()
if not is_zero_result(result):
for value in values:
assert (
value.definite() or value.indefinite()
), f"partially constrained {value} temporarily unsupported"
assert value.indefinite() or any(key is value for key, _ in constraints)
yield (tuple((key.index, other) for key, other in constraints),
result)
# Drop everything that was fully explored, then move the next
# top of stack to the next option.
while explorationStack:
value, mask = explorationStack[-1]
assert 0 <= mask < 8
mask &= mask - 1 # Drop first bit
if mask != 0:
explorationStack[-1] = (value, mask)
break
explorationStack.pop()
if not explorationStack:
break
def enumerate_supporting_values(function, args):
"""Lists the bag of mapping from closed over values to non-zero result,
for all calls `function(arg) for args in args`.
>>> def count_eql(needle): return lambda x: 1 if x == needle else None
>>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
[(((0, 1),), 1), (((0, 2),), 1), (((0, 4),), 1), (((0, 4),), 1), (((0, 2),), 1)]
"""
_, _, rebind, names = extract_function_state(function)
values = [OpaqueValue(name, index) for index, name in enumerate(names)]
reboundFunction = rebind(values)
for arg in args:
yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)
## ## Type driven merges
##
## The interesting part of map/reduce is the reduction step. While
## some like to use first-class functions to describe reduction, in my
## opinion, it often makes more sense to define reduction at the type
## level: it's essential that merge operators be commutative and
## associative, so isolating the merge logic in dedicated classes
## makes sense to me.
##
## This file defines a single mergeable value type, `Sum`, but we
## could have different ones, e.g., hyperloglog unique counts, or
## streaming statistical moments.
class Sum:
"""A counter for summed values."""
def __init__(self, value=0):
self.value = value
def merge(self, other):
assert isinstance(other, Sum)
self.value += other.value
class Min:
"""A running `min` value tracker."""
def __init__(self, value=None):
self.value = value
def merge(self, other):
assert isinstance(other, Min)
if not self.value:
self.value = other.value
elif other.value:
self.value = min(self.value, other.value)
## ## Nested dictionary with wildcard
##
## There's a direct relationship between the data structure we use to
## represent the result of function calls as branching functions, and
## the constraints we can support on closed over values for non-zero
## results.
##
## In a real implementation, this data structure would host most of
## the complexity: it's the closest thing we have to indexes.
##
## For now, support equality *with ground value* as our only
## constraint. This means we can dispatch on the closed over values
## by order of appearance, with either a wildcard value (matches
## everything) or a hash map. At each internal level, the value is
## another `NestedDictLevel`. At the leaf, the value is the
## precomputed value.
class NestedDictLevel:
"""One level in a nested dictionary index. We may either have a
value for everything (leaf node), or a key-value dict for a specific
index in the tuple key.
"""
def __init__(self, indexValues, depth=0):
assert depth <= len(indexValues)
self.keyIndex = indexValues[depth][0] if depth < len(indexValues) else None
self.value = None
self.dict = dict()
def get(self, keys, default):
"""Gets the value for `keys` in this level, or `default` if None."""
assert self.value is not None or self.keyIndex is not None
if self.value is not None:
return self.value
next = self.dict.get(keys[self.keyIndex], None)
if next is None:
return default
return next.get(keys, default)
def set(self, indexValues, mergeFunction, depth=0):
"""Sets the value for `keys` in this level."""
assert depth <= len(indexValues)
if depth == len(indexValues): # Leaf
self.value = mergeFunction(self.value)
return
index, value = indexValues[depth]
assert self.keyIndex == index
if value not in self.dict:
self.dict[value] = NestedDictLevel(indexValues, depth + 1)
self.dict[value].set(indexValues, mergeFunction, depth + 1)
class NestedDict:
"""A nested dict of a given `depth` maps tuples of `depth` keys to
a value. Each `NestedDictLevel` handles a different level. We
don't test each index in the tuple in fixed order to avoid the
combinatorial explosion that can happen when change ordering
(e.g., conversion from BDD to OBDD).
"""
def __init__(self, length):
self.top = None
self.length = length
def get(self, keys, default=None):
"""Gets the value associated with `keys`, or `default` if None."""
assert len(keys) == self.length
if self.top is None:
return default
return self.top.get(keys, default)
def set(self, indexKeyValues, mergeFn):
"""Sets the value associated with `((index, key), ...)`."""
assert all(0 <= index < self.length for index, _ in indexKeyValues)
if self.top is None:
self.top = NestedDictLevel(indexKeyValues)
self.top.set(indexKeyValues, mergeFn)
##{<h2>Identity key-value maps</h2>
class IdMap:
def __init__(self):
self.entries = dict() # tuple of id -> (key, value)
# the value's first element keeps the ids stable.
def get(self, keys, default=None):
ids = tuple(id(key) for key in keys)
return self.entries.get(ids, (None, default))[1]
def __contains__(self, keys):
ids = tuple(id(key) for key in keys)
return ids in self.entries
def __getitem__(self, keys):
ids = tuple(id(key) for key in keys)
return self.entries[ids][1]
def __setitem__(self, keys, value):
ids = tuple(id(key) for key in keys)
self.entries[ids] = (keys, value)
##}
## ## Cached `map_reduce`
##
## As mentioned earlier, we assume `reduce` is determined implicitly
## by the reduced values' type. We also have
## `enumerate_supporting_values` to find all the closed over values
## that yield a non-zero result, for all values in a sequence.
##
## We can thus accept a function and an input sequence, find the
## supporting values, and merge the result associated with identical
## supporting values.
##
## Again, we only support ground equality constraints (see assertion
## on L568), i.e., only equijoins. There's nothing that stops a more
## sophisticated implementation from using range trees to support
## inequality or range joins.
##
## We'll cache the precomputed values by code object (i.e., function
## without closed over values) and input sequence. If we don't have a
## precomputed value, we'll use `enumerate_supporting_values` to run
## the function backward for each input datum from the sequence, and
## accumulate the results in a `NestedDict`. Working backward to find
## closure values that yield a non-zero result (for each input datum)
## lets us precompute a branching program that directly yields the
## result. We represent these branching programs explicitly, so we
## can also directly update a branching program for the result of
## merging all the values returned by mapping over the input sequence,
## for a given closure.
##
## This last `map_reduce` definition ties everything together, and
## I think is really the general heart of Yannakakis's algorithm
## as an instance of bottom-up dynamic programming.
def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    representation for `reduce(map(function, inputIterable))`, where
    the reduction step simply calls `merge` on the return values
    (tuples are merged elementwise), and the `NestedDict` keys
    represent closed over values.

    >>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
    >>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
    >>> nd.get((0,))
    >>> nd.get((1,)).value
    1
    >>> nd.get((2,)).value
    2
    >>> nd.get((4,)).value
    2
    """

    def merge(dst, update):
        if dst is None:
            return update
        if isinstance(dst, (tuple, list)):
            assert len(dst) == len(update)
            for value, new in zip(dst, update):
                value.merge(new)
        else:
            dst.merge(update)
        return dst

    cache = NestedDict(depth)
    for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
        cache.set(indexKeyValues, lambda old: merge(old, result))
    return cache


AGGREGATE_CACHE = IdMap()  # Map from (code object, input sequence) -> NestedDict


def map_reduce(function, inputIterable):
    """Returns the result of merging `map(function, inputIterable)`.

    `None` return values represent neutral elements (i.e., the result
    of mapping an empty `inputIterable`), and values are otherwise
    reduced by calling `merge` on a mutable accumulator.

    Assuming `function` is well-behaved, `map_reduce` runs in time
    linear wrt `len(inputIterable)`. It's also always cached on a
    composite key that consists of the `function`'s code object (i.e.,
    without closed-over values) and the `inputIterable`.

    These complexity guarantees let us nest `map_reduce` with
    different closed-over values, and still guarantee a linear-time
    total complexity.

    This wrapper ties together all the components:

    >>> INVOCATION_COUNTER = 0
    >>> data = (1, 2, 2, 4, 2, 4)
    >>> def count_eql(needle):
    ...     def count(x):
    ...         global INVOCATION_COUNTER
    ...         INVOCATION_COUNTER += 1
    ...         return Sum(x) if x == needle else None
    ...     return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data).value
    6
    >>> INVOCATION_COUNTER
    18
    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ...     return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs).value
    >>> def sum_odd_or_even_skus(mod_two):
    ...     def count_if_mod_two(id_sku):
    ...         id, sku = id_sku
    ...         if id % 2 == mod_two:
    ...             return Sum(sku_min_cost(sku))
    ...     return map_reduce(count_if_mod_two, id_skus)
    >>> sum_odd_or_even_skus(0).value
    20
    >>> sum_odd_or_even_skus(1).value
    50
    """
    assert isinstance(inputIterable, collections.abc.Iterable)
    assert not isinstance(inputIterable, collections.abc.Iterator)
    code, closure, *_ = extract_function_state(function)
    if (code, inputIterable) not in AGGREGATE_CACHE:
        AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
            function, len(closure), inputIterable
        )
    return AGGREGATE_CACHE[code, inputIterable].get(closure, None)


if __name__ == "__main__":
    import doctest

    doctest.testmod()
## ## Is this actually a DB post?
##
## Although the intro name-dropped Yannakakis, the presentation here
## has a very programming language / logic programming flavour. I
## think the logic programming point of view, where we run a program
## backwards with logical variables, is much clearer than the specific
## case of conjunctive equijoin queries in the usual presentation of
## Yannakakis's algorithm. In particular, I think there's a clear
## path to handle range or comparison joins: it's all about having an
## index data structure to handle range queries.
##
## It should be clear how to write conjunctive queries as Python
## functions, given a hypertree decomposition. The reverse is much
## harder, if only because Python is far more expressive than plain
## CQs, and that's actually a liability: this hack will blindly try to
## convert any function to a branching program, instead of giving up
## noisily when the function is too complex.
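##
## For instance, take the hypothetical two-relation conjunctive query
## `Q(x, y, z) :- R(x, y), S(y, z)`, aggregated into a row count. The
## nested-loop version mirrors the query's (trivial) decomposition: the
## inner loop closes over `y`. The second version applies the idea
## behind `map_reduce` by hand. It builds one index for the inner loop,
## keyed on the closed-over `y`, so the total work is linear in
## `len(R) + len(S)`. Relation layouts and function names are
## illustrative assumptions.
##
```python
from collections import Counter


def count_nested(R, S):
    """Q(x, y, z) :- R(x, y), S(y, z), counted with plain nested loops (quadratic)."""
    def matches_for(y):
        # The inner loop closes over `y`, like a function passed to `map_reduce`.
        return sum(1 for (y2, _z) in S if y2 == y)
    return sum(matches_for(y) for (_x, y) in R)


def count_indexed(R, S):
    """Same count, with the inner loop replaced by one precomputed index."""
    # One linear pass over S: for each value of the closed-over `y`,
    # the inner loop's aggregate (here, a count).
    inner = Counter(y for (y, _z) in S)
    # One linear pass over R, with O(1) index lookups.
    return sum(inner[y] for (_x, y) in R)
```
##
## Join keys that are absent from `S` simply map to 0 in the `Counter`,
## which plays the role of the neutral element.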
##
## The other difference from classical CQs is that we focus on
## aggregates. That's because aggregates are the more general form:
## if we just want to avoid useless work while enumerating all join
## rows, we only need a boolean aggregate that tells us whether the
## join will yield at least one row. We could also special-case types
## for which merges don't save space (e.g., sets of row ids), and
## instead enumerate values by walking the branching program tree.
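##
## Here is a hypothetical sketch of that split, for the path query
## `Q(x, y, z, w) :- R(x, y), S(y, z), T(z, w)`. A bottom-up pass
## computes boolean aggregates, represented as sets of the keys whose
## OR is true. A top-down walk then enumerates output rows, keeping
## plain row lists where merging wouldn't save space. This is the usual
## semijoin-reduction idea, written by hand rather than through
## `map_reduce`. `enumerate_path` and the relation layout are
## assumptions.
##
```python
from collections import defaultdict


def enumerate_path(R, S, T):
    """Enumerate Q(x, y, z, w) :- R(x, y), S(y, z), T(z, w) without dead ends."""
    # Bottom-up: "is there at least one row below this key?" (OR-merged).
    t_nonempty = {z for (z, _w) in T}
    s_nonempty = {y for (y, z) in S if z in t_nonempty}
    # Row lists for enumeration, restricted to rows that lead to output.
    t_rows = defaultdict(list)
    for (z, w) in T:
        t_rows[z].append(w)
    s_rows = defaultdict(list)
    for (y, z) in S:
        if z in t_nonempty:
            s_rows[y].append(z)
    # Top-down: every branch we descend into yields at least one row.
    for (x, y) in R:
        if y in s_nonempty:
            for z in s_rows[y]:
                for w in t_rows[z]:
                    yield (x, y, z, w)
```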
##
## The aggregate viewpoint also works for
## [fun extensions like indexed access to ranked results](https://ntzia.github.io/download/Tractable_Orders_2020.pdf):
## that extension ends up counting the number of output values up to a
## certain key.
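##
## Here is a hedged sketch of the counting idea, not the linked paper's
## algorithm. Start from per-key output counts, the kind of value a
## `Sum(1)` aggregate per key would produce. Prefix sums over the sorted
## keys tell us how many output rows lie up to each key, and a binary
## search then finds the key of the k-th row. `ranked_index` is a
## hypothetical helper.
##
```python
import bisect
from itertools import accumulate


def ranked_index(counts_by_key):
    """Build `select(k)`: the key of the k-th (0-based) output row in key order.

    `counts_by_key` maps each key to its number of output rows.
    """
    keys = sorted(counts_by_key)
    # prefix[i] = number of output rows whose key is <= keys[i].
    prefix = list(accumulate(counts_by_key[key] for key in keys))
    total = prefix[-1] if prefix else 0

    def select(k):
        if not 0 <= k < total:
            raise IndexError(k)
        # First key whose running count exceeds k (skips zero-count keys).
        return keys[bisect.bisect_right(prefix, k)]

    return select
```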
##
## I guess, in a way, we just showed a trivial way to decorrelate
## queries with a hypertree width of 1. We just have to be OK with
## building one index for each loop in the nest... but it should be
## possible to pattern match on pre-defined indexes and avoid obvious
## redundancy.
##
## ## Extensions and future work
##
## ### Use a dedicated DSL
##
## First, the whole idea of introspecting closures to stub in logical
## variables is a terrible hack (looks cool though ;). A real
## production implementation should apply CPS partial evaluation to a
## purely functional programming language, then bulk reverse-evaluate
## with a SIMD implementation of the logical program.
##
## There'll be restrictions on the output traces, but that's OK: a
## different prototype makes me believe the restrictions correspond to
## deterministic logspace (L), and it makes sense to restrict our
## analyses to L. Just like grammars are easier to work with when
## restricted to LL(1), DSLs that only capture L tend to be easier to
## analyse and optimise... and L is reasonably large (proving that some
## polynomial-time problem is not in L would be a *huge* result, since
## it would separate L from P).
##
## ### Handle local functions
##
## While we're sinning with the closure hack (`extract_function_state`),
## we should really extend it to cover local functions. This is mostly
## a question of going deeply into values that are mapped to
## functions, and of maintaining an id-keyed map from cell to
## `OpaqueValue`.
##
## We could also add support for partial application objects, which
## may be easier for multiprocessing.
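##
## Here is a minimal sketch of what supporting `functools.partial`
## objects could look like. A partial exposes its function and bound
## arguments as `func`, `args` and `keywords`, which play the roles of
## the code object and the closed-over values. Unlike a lambda closure,
## a partial over a module-level or builtin function can be pickled,
## which would make it easier to ship to worker processes.
## `split_partial`, `less_than` and `less_than_partial` are illustrative
## names.
##
```python
import functools
import operator
import pickle


def split_partial(p):
    """Cache-key parts for a partial: the function, then its bound values."""
    return p.func, p.args, tuple(sorted(p.keywords.items()))


def less_than(limit):
    """Closure version: a local lambda, which pickle can't look up by name."""
    return lambda x: x < limit


def less_than_partial(limit):
    """Partial version: `operator.gt(limit, x)` is `x < limit`, and it pickles."""
    return functools.partial(operator.gt, limit)
```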
##
## ### Parallelism
##
## There is currently no support for parallelism, only caching. It
## should be easy to handle the return values (`NestedDict`s and
## aggregate classes like `Sum` or `Min`). Distributing the work in
## `_precompute_map_reduce` to merge locally is also not hard.
##
## The main issue with parallelism is that we can't pass functions
## (closures) as work units, since they can't be pickled, so we'd have
## to stick to the `fork` process pool.
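##
## Here is a hypothetical sketch of merging locally. We shard the input,
## build a partial index per shard, and then merge the indexes. The
## sketch uses a count aggregate, whose merge is addition
## (`Counter.update` adds counts), and runs the shards one after the
## other. A real version would hand each shard to a `fork`ed worker.
## `build_index`, `sharded_index` and `invert` are assumed names, with
## `invert` standing in for running the function backward.
##
```python
from collections import Counter


def build_index(shard, invert):
    """Per-shard index: closure tuple -> number of matching data."""
    return Counter(invert(x) for x in shard)


def sharded_index(data, invert, n_shards):
    """Build the same index as `build_index(data, invert)`, shard by shard."""
    shards = [data[i::n_shards] for i in range(n_shards)]
    # Each shard could be processed by its own worker; here they run in turn.
    partials = [build_index(shard, invert) for shard in shards]
    merged = Counter()
    for part in partials:
        merged.update(part)  # merging count aggregates is addition
    return merged
```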
##
## There's also no support for moving (child) work forward when
## blocked waiting on a future. We'd have to spawn workers on the fly
## to oversubscribe when workers are blocked on a result (spawning on
## demand is already a given for `fork` workers), and to implement our
## own concurrency control to avoid wasted work, and probably internal
## throttling to avoid thrashing when we'd have more active threads
## than cores.
##
## That being said, the complexity is probably worth the speedup on
## realistic queries.
##
## ### Theta joins
##
## At a higher level, we could support comparison joins (e.g., less
## than, greater than or equal, in range) if only we represented the
## branching programs with a data structure that supported these
## queries. A [k-d tree](https://dl.acm.org/doi/10.1145/361002.361007)
## or a range tree would let us handle these "theta" joins; a range tree
## does so for the low, low cost of a polylogarithmic multiplicative
## factor in space and time.
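##
## The special case of a single inequality predicate doesn't need a
## multidimensional structure. A sorted array with prefix sums costs
## O(n log n) to build and O(log n) per probe. The sketch below uses
## sorted arrays and `bisect` in place of a general k-d tree or range
## tree. It answers `sum(v for (k, v) in S if k < t)` for each row of
## `R` without a nested loop. The helper names are hypothetical.
##
```python
import bisect
from itertools import accumulate


def build_less_than_index(S):
    """Index S = [(key, value), ...] for "sum of values with key < t" probes."""
    pairs = sorted(S)
    keys = [k for k, _ in pairs]
    # prefix[i] = sum of the first i values in key order (prefix[0] == 0).
    prefix = [0, *accumulate(v for _, v in pairs)]

    def sum_below(t):
        # bisect_left counts the keys strictly less than t in O(log n).
        return prefix[bisect.bisect_left(keys, t)]

    return sum_below


def theta_join_sum(R, S):
    """Sum over R rows (x, t) of sum(v for (k, v) in S if k < t)."""
    sum_below = build_less_than_index(S)
    return sum(sum_below(t) for (_x, t) in R)
```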
##
## ### Self-adjusting computation
##
## Finally, we could update the indexed branching programs
## incrementally after small changes to the input data. This might
## sound like a job for streaming engines like [timely dataflow](https://github.com/timelydataflow/timely-dataflow),
## but I think viewing each `_precompute_map_reduce` call as a purely
## functional map/reduce job gives a better fit with [self-adjusting computation](https://www.umut-acar.org/research#h.x3l3dlvx3g5f).
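##
## Here is a minimal sketch of the incremental idea for a count
## aggregate. It assumes a hand-written `invert(datum)` that returns the
## closure tuple a datum contributes to, in place of backward
## evaluation. Inserts and deletes then adjust one leaf of the index in
## O(1) instead of rebuilding it. This works because counts and sums are
## invertible. A `Min` leaf can't be updated after deleting its current
## minimum without keeping more state. `IncrementalCountIndex` is a
## hypothetical name.
##
```python
from collections import Counter


class IncrementalCountIndex:
    """Index: closure tuple -> count, kept current under inserts and deletes."""

    def __init__(self, data, invert):
        self.invert = invert
        self.counts = Counter(invert(x) for x in data)

    def insert(self, x):
        self.counts[self.invert(x)] += 1

    def delete(self, x):
        # Assumes `x` is currently present in the indexed data.
        key = self.invert(x)
        self.counts[key] -= 1
        if self.counts[key] == 0:
            del self.counts[key]

    def get(self, closure):
        return self.counts[closure]
```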
##
## Once we add logic to recycle previously constructed indexes, it
## will probably make sense to allow an initial filtering step before
## map/reduce, with a cache key on the filter function (closed-over
## values and all). We can often implement the filtering more
## efficiently than we can run functions backward, and we'll also
## observe that slightly different filter functions often yield
## similar filtered sets. Factoring out this filtering can thus
## enable more reuse of partially precomputed results.