-
-
Save fish2000/1af4b852d20b7568a9b9c90fe2346b6d to your computer and use it in GitHub Desktop.
A Taste Of Julia / C++ in Python – simple Python multiple dispatch from type hints
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import * | |
| import re | |
def to_regex(typevar, groups):
    """Translate one type annotation into a regex fragment.

    The fragment is meant to match the type descriptions produced by
    ``to_callee`` (e.g. ``"int"``, ``"list[int]"``).

    :param typevar: the annotation to translate: ``int``/``float``/``str``,
        a ``Sequence[X]`` generic, a ``TypeVar``, or anything else.
    :param groups: set of TypeVars already seen earlier in the same
        signature. It is updated in place. The first occurrence of a TypeVar
        becomes a named capture group; later occurrences become
        back-references that must repeat exactly the same text.
    :returns: the regex fragment. Unrecognised annotations yield the
        wildcard ``.*?``.
    """
    import collections.abc

    def to_matchgroup(arg, groups):
        # TypeVars capture (first sighting) or back-reference (later
        # sightings) so that e.g. Sequence[T], Sequence[T] demands equal
        # element types; anything else is translated recursively.
        if isinstance(arg, TypeVar):
            if arg in groups:
                return "(?P={})".format(arg.__name__)
            groups.add(arg)
            return "(?P<{}>.*?)".format(arg.__name__)
        return to_regex(arg, groups)

    if isinstance(typevar, TypeVar):
        return to_matchgroup(typevar, groups)
    if typevar in {float, int, str}:
        return typevar.__name__
    # Inspect __origin__ instead of calling mro(): on Python 3.7+ Sequence[T]
    # is a generic alias that forwards mro() to collections.abc.Sequence,
    # whose second base is Reversible, so an mro()-based test never matches.
    origin = getattr(typevar, "__origin__", None)
    if origin is Sequence or origin is collections.abc.Sequence:
        args = getattr(typevar, "__args__", None)
        element = to_matchgroup(args[0], groups) if args else ".*?"
        return r"(?:list|set|tuple)\[{}\]".format(element)
    return ".*?"
def get_element_types(sequence):
    """Return the set of distinct runtime types of the items in *sequence*.

    An empty *sequence* yields an empty set.
    """
    return {type(el) for el in sequence}
def to_callee(arg):
    """Describe the runtime type of one call argument as a signature string.

    Scalars (``float``, ``int``, ``str``) map to their type name, e.g.
    ``3 -> "int"``. Lists, sets and tuples map to
    ``"<container>[<element type>]"``, e.g. ``[1, 2] -> "list[int]"``.
    Nested containers are described recursively, so
    ``[[1], [2]] -> "list[list[int]]"``. A nested container that cannot be
    described (empty or mixed contents) falls back to its bare type name.

    :raises RuntimeError: if *arg* is of any other type, or is an empty
        container, or is a container whose elements do not all share one
        description.
    """
    kind = type(arg)
    # Exact type checks (not isinstance): e.g. a bool is reported as
    # unsupported instead of being described as an int.
    if kind in (float, int, str):
        return kind.__name__
    if kind not in (list, set, tuple):
        raise RuntimeError(
            "Not implemented yet: unsupported argument type {}".format(kind.__name__)
        )

    def element_name(el):
        # Recurse into nested containers; if that is impossible, keep the
        # bare type name (e.g. [[], []] -> "list[list]").
        if type(el) in (list, set, tuple):
            try:
                return to_callee(el)
            except RuntimeError:
                pass
        return type(el).__name__

    names = {element_name(el) for el in arg}
    if len(names) != 1:
        raise RuntimeError(
            "Not implemented yet: cannot describe {} with element types {}".format(
                kind.__name__, sorted(names)
            )
        )
    (name,) = names
    return "{}[{}]".format(kind.__name__, name)
def to_match_target(caller_signature):
    """Render the runtime types of call arguments as one comparable string.

    Each argument is described by ``to_callee`` and the descriptions are
    joined with ", ", e.g. ``(1, [2.0])`` -> ``"int, list[float]"``.
    """
    return ", ".join(map(to_callee, caller_signature))
def to_regex_sig(caller_signature):
    """Build the regex pattern for a whole annotated signature.

    Each annotation is translated by ``to_regex``. The fragments are joined
    with ", " to mirror the format of ``to_match_target``. One shared
    ``bound`` set is threaded through every call, so TypeVars introduced by
    earlier parameters are known when later parameters are translated.
    """
    bound = set()
    fragments = (to_regex(annotation, bound) for annotation in caller_signature)
    return ", ".join(fragments)
class overloaded(object):
    """Decorator providing simple multiple dispatch driven by type hints.

    Decorating several same-named functions with ``@overloaded`` registers
    each one as an implementation of that name. Calling the resulting
    object works in two steps:

    1. It renders the runtime argument types as a string
       (``to_match_target``).
    2. It dispatches to the first registered implementation whose
       annotation pattern (``to_regex_sig``) matches that whole string.

    Known limitation: unannotated parameters, and annotations that
    ``to_regex`` does not recognise, become the wildcard ``.*?``. That
    wildcard can also absorb the ", " separator, so arity is not strictly
    enforced for such signatures.
    """

    # When true, every attempted match is printed (the demo's trace output).
    verbose = True

    # (module, qualname) -> {signature pattern: implementation}. Each
    # function name gets its own table so that unrelated overloaded
    # functions do not share, and overwrite, each other's signatures.
    _registries = {}

    def __init__(self, f):
        """Register *f* under its (module, qualified name).

        :param f: the implementation; its positional parameter annotations
            form the dispatch signature.
        """
        import inspect

        # Keyword-only and **kwargs parameters can never receive a value
        # through __call__(*args), so only positional slots are matched.
        # Reading the signature (not __annotations__) has two effects:
        # a "return" annotation stays out of the pattern, and an unannotated
        # parameter contributes Parameter.empty (which to_regex turns into a
        # wildcard) instead of silently dropping the slot.
        positional = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.VAR_POSITIONAL,
        )
        signature = tuple(
            p.annotation
            for p in inspect.signature(f).parameters.values()
            if p.kind in positional
        )
        self.fmap = self._registries.setdefault((f.__module__, f.__qualname__), {})
        self.fmap[to_regex_sig(signature)] = f

    def __call__(self, *args):
        """Call the first registered implementation matching *args*.

        Implementations are tried in registration order.

        :raises RuntimeError: if no registered signature matches *args*
            (``to_match_target`` may also raise RuntimeError for argument
            types it cannot describe).
        """
        match_sig = to_match_target(args)
        for key, func in self.fmap.items():
            if self.verbose:
                print("Matching: {} against\n {}\n".format(match_sig, key))
            # fullmatch rather than match: re.match anchors only at the
            # start, so a one-argument pattern like "int" would also accept
            # the call string "int, int".
            if re.fullmatch(key, match_sig):
                if self.verbose:
                    print(" === MATCH ===\n\n")
                return func(*args)
        raise RuntimeError("No overload found for {}".format(match_sig))
# Type variables used by the generic overloads below. Repeating T within one
# signature requires the corresponding arguments to share an element type.
T = TypeVar('T')
U = TypeVar('U')  # not referenced by any overload below


# Each @overloaded definition below registers one more implementation of
# ``add``; a call tries the registered implementations in registration order.


@overloaded
def add(a: int, b: int):
    # Integer case: adds an extra 100 on top of the sum (presumably so this
    # overload is easy to recognise in the demo output).
    return a + b + 100


@overloaded
def add(a: float, b: float):
    # Float case: plain sum.
    return a + b


@overloaded
def add(a: Sequence[T], b: float):
    # Broadcast the scalar over every element of the sequence.
    return [item + b for item in a]


@overloaded
def add(a: Sequence[T], b: Sequence[T]):
    # Element-wise sum of two sequences with the same element type;
    # zip stops at the shorter one.
    return [left + right for left, right in zip(a, b)]


@overloaded
def add(a: Sequence[T], b: Sequence[str]):
    # Element-wise string concatenation; the left items are stringified first.
    return [str(left) + right for left, right in zip(a, b)]
if __name__ == '__main__':
    # One tuple per demo call; each result is printed in this order.
    demo_calls = [
        (3, 5),
        (4.5, 8.2),
        ([1, 2, 3], 5.0),
        ([1, 2, 3], [1, 2, 3]),
        ([1, 2, 3], ["a", "b", "c"]),
    ]
    for call_args in demo_calls:
        print(add(*call_args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment