{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e78720bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import singledispatch\n",
    "from contextlib import suppress"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26787b9b",
   "metadata": {},
   "source": [
    "Regular single dispatch: each implementation class is registered for the type it handles with the `@make_numeric.register(...)` decorator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2207ec50",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IntNumeric -> 2\n",
      "FloatNumeric -> 4.0\n",
      "ComplexNumeric -> (2+3j)\n"
     ]
    }
   ],
   "source": [
    "class Base:\n",
    "    def __init__(self, value=None):\n",
    "        self.value = value\n",
    "\n",
    "    def who(self):\n",
    "        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
    "\n",
    "@singledispatch\n",
    "def make_numeric(value):\n",
    "    raise NotImplementedError(\"Nada\")\n",
    "\n",
    "# `register(<type>)` used as a decorator returns the class unchanged,\n",
    "# so make_numeric(value) ends up calling e.g. IntNumeric(value).\n",
    "@make_numeric.register(int)\n",
    "class IntNumeric(Base):\n",
    "    pass\n",
    "\n",
    "@make_numeric.register(float)\n",
    "class FloatNumeric(Base):\n",
    "    pass\n",
    "\n",
    "@make_numeric.register(complex)\n",
    "class ComplexNumeric(Base):\n",
    "    pass\n",
    "\n",
    "#@make_numeric.register(unknown)\n",
    "#class UnknownNumeric(Base):\n",
    "#    pass\n",
    "\n",
    "i = make_numeric(2)\n",
    "f = make_numeric(4.0)\n",
    "c = make_numeric(2+3j)\n",
    "\n",
    "i.who()\n",
    "f.who()\n",
    "c.who()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55172db5",
   "metadata": {},
   "source": [
    "Single dispatch with the implementation class storing the `type` it specialises on as a regular class attribute, registered with explicit calls to `register`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "87e74e21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IntNumeric -> 2\n",
      "FloatNumeric -> 4.0\n",
      "ComplexNumeric -> (2+3j)\n"
     ]
    }
   ],
   "source": [
    "class Base:\n",
    "    def __init__(self, value=None):\n",
    "        self.value = value\n",
    "\n",
    "    def who(self):\n",
    "        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
    "\n",
    "@singledispatch\n",
    "def make_numeric(value):\n",
    "    raise NotImplementedError(\"Nada\")\n",
    "\n",
    "class IntNumeric(Base):\n",
    "    base_class = int\n",
    "\n",
    "class FloatNumeric(Base):\n",
    "    base_class = float\n",
    "\n",
    "class ComplexNumeric(Base):\n",
    "    base_class = complex\n",
    "\n",
    "# The attribute is evaluated at class-creation time, so an undefined type\n",
    "# would already fail here with a NameError.\n",
    "#class UnknownNumeric(Base):\n",
    "#    base_class = unknown\n",
    "\n",
    "make_numeric.register(IntNumeric.base_class, IntNumeric)\n",
    "make_numeric.register(FloatNumeric.base_class, FloatNumeric)\n",
    "make_numeric.register(ComplexNumeric.base_class, ComplexNumeric)\n",
    "\n",
    "i = make_numeric(2)\n",
    "f = make_numeric(4.0)\n",
    "c = make_numeric(2+3j)\n",
    "\n",
    "i.who()\n",
    "f.who()\n",
    "c.who()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77ef07e3",
   "metadata": {},
   "source": [
    "Single dispatch with the implementation class holding the `type` it specialises on in a `staticmethod`. Because the lambda is only evaluated when it is called, a class that refers to an unknown type can still be created; the unknown type is only exposed (as a `NameError`) by the explicit call to `register`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f0174eb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IntNumeric -> 2\n",
      "FloatNumeric -> 4.0\n",
      "ComplexNumeric -> (2+3j)\n"
     ]
    }
   ],
   "source": [
    "class Base: # This is the same as `class Base(metaclass=type)`\n",
    "    def __init__(self, value=None):\n",
    "        self.value = value\n",
    "\n",
    "    def who(self):\n",
    "        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
    "\n",
    "@singledispatch\n",
    "def make_numeric(value):\n",
    "    raise NotImplementedError(\"Nada\")\n",
    "\n",
    "\n",
    "class IntNumeric(Base):\n",
    "    base_class = staticmethod(lambda: int)\n",
    "\n",
    "\n",
    "class FloatNumeric(Base):\n",
    "    base_class = staticmethod(lambda: float)\n",
    "\n",
    "\n",
    "class ComplexNumeric(Base):\n",
    "    base_class = staticmethod(lambda: complex)\n",
    "\n",
    "# Creating this class succeeds: `unknown` is only looked up when the lambda runs.\n",
    "class UnknownNumeric(Base):\n",
    "    base_class = staticmethod(lambda: unknown)\n",
    "\n",
    "make_numeric.register(IntNumeric.base_class(), IntNumeric)\n",
    "make_numeric.register(FloatNumeric.base_class(), FloatNumeric)\n",
    "make_numeric.register(ComplexNumeric.base_class(), ComplexNumeric)\n",
    "#make_numeric.register(UnknownNumeric.base_class(), UnknownNumeric)  # NameError\n",
    "\n",
    "i = make_numeric(2)\n",
    "f = make_numeric(4.0)\n",
    "c = make_numeric(2+3j)\n",
    "\n",
    "i.who()\n",
    "f.who()\n",
    "c.who()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "079c7817",
   "metadata": {},
   "source": [
    "Single dispatch using a metaclass to register all implementation classes. Implementations for unknown types are silently skipped and not registered. The single-dispatch helper function is hidden and the base class (think `VetiverHandler`) takes on its role: `Base(value)` returns an instance of the implementation class registered for `type(value)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "75f9147e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IntNumeric -> 2\n",
      "FloatNumeric -> 4.0\n",
      "ComplexNumeric -> (2+3j)\n"
     ]
    }
   ],
   "source": [
    "@singledispatch\n",
    "def make_numeric(value):\n",
    "    raise NotImplementedError(\"Nada\")\n",
    "\n",
    "class AutoRegisterHandler(type): # inheriting from type/metaclass creates another metaclass\n",
    "    # __new__ of a metaclass is invoked when a new class is being created\n",
    "    def __new__(meta, name, bases, clsdict):\n",
    "        cls = super().__new__(meta, name, bases, clsdict)\n",
    "        # Classes without `base_class` (e.g. Base itself) raise AttributeError and\n",
    "        # classes whose `base_class` refers to an undefined type raise NameError;\n",
    "        # both are skipped, i.e. simply not registered.\n",
    "        with suppress(AttributeError, NameError):\n",
    "            make_numeric.register(cls.base_class(), cls)\n",
    "        return cls\n",
    "\n",
    "\n",
    "class Base(metaclass=AutoRegisterHandler):\n",
    "    # __new__ of a regular class is invoked before the object is instantiated,\n",
    "    # the object will be of the class it returns\n",
    "    def __new__(cls, value=None):\n",
    "        # `dispatch` resolves through the MRO of type(value) (e.g. bool -> int),\n",
    "        # whereas a direct `registry[...]` lookup raises KeyError for any type\n",
    "        # that was not registered exactly.\n",
    "        implementation_cls = make_numeric.dispatch(type(value))\n",
    "        # For unregistered types `dispatch` returns the generic function itself,\n",
    "        # which is not an implementation class that can be instantiated.\n",
    "        if not (isinstance(implementation_cls, type) and issubclass(implementation_cls, Base)):\n",
    "            raise NotImplementedError(f\"No implementation registered for {type(value).__name__}\")\n",
    "        return super().__new__(implementation_cls)\n",
    "\n",
    "    def __init__(self, value=None):\n",
    "        self.value = value\n",
    "\n",
    "    def who(self):\n",
    "        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
    "\n",
    "class IntNumeric(Base): # type\n",
    "    base_class = staticmethod(lambda: int)\n",
    "\n",
    "class FloatNumeric(Base):\n",
    "    base_class = staticmethod(lambda: float)\n",
    "\n",
    "class ComplexNumeric(Base):\n",
    "    base_class = staticmethod(lambda: complex)\n",
    "\n",
    "#unknown = str\n",
    "#del unknown\n",
    "class Unknown(Base):\n",
    "    base_class = staticmethod(lambda: unknown)\n",
    "\n",
    "i = Base(2)\n",
    "f = Base(4.0)\n",
    "c = Base(2+3j)\n",
    "#u = Base('u')\n",
    "\n",
    "i.who()\n",
    "f.who()\n",
    "c.who()\n",
    "#u.who()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a98a2e07",
   "metadata": {},
   "source": [
    "Using `__init_subclass__` instead of a metaclass. Thanks to [machow](https://github.com/machow)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3a92c752",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IntNumeric -> 2\n",
      "FloatNumeric -> 4.0\n",
      "ComplexNumeric -> (2+3j)\n"
     ]
    }
   ],
   "source": [
    "@singledispatch\n",
    "def make_numeric(value):\n",
    "    raise NotImplementedError(\"Nada\")\n",
    "\n",
    "class Base:\n",
    "    # Register the specialising implementation subclass when it is created.\n",
    "    # `__init_subclass__` is implicitly a classmethod, so no decorator is needed.\n",
    "    def __init_subclass__(cls, **kwargs):\n",
    "        super().__init_subclass__(**kwargs)\n",
    "        # Subclasses without `base_class` (AttributeError) or whose `base_class`\n",
    "        # refers to an undefined type (NameError) are skipped, i.e. not registered.\n",
    "        with suppress(AttributeError, NameError):\n",
    "            make_numeric.register(cls.base_class(), cls)\n",
    "\n",
    "    # __new__ of a regular class is invoked before the object is instantiated,\n",
    "    # the object will be of the class it returns\n",
    "    def __new__(cls, value=None):\n",
    "        # `dispatch` resolves through the MRO of type(value) (e.g. bool -> int),\n",
    "        # whereas a direct `registry[...]` lookup raises KeyError for any type\n",
    "        # that was not registered exactly.\n",
    "        implementation_cls = make_numeric.dispatch(type(value))\n",
    "        # For unregistered types `dispatch` returns the generic function itself,\n",
    "        # which is not an implementation class that can be instantiated.\n",
    "        if not (isinstance(implementation_cls, type) and issubclass(implementation_cls, Base)):\n",
    "            raise NotImplementedError(f\"No implementation registered for {type(value).__name__}\")\n",
    "        return super().__new__(implementation_cls)\n",
    "\n",
    "    def __init__(self, value=None):\n",
    "        self.value = value\n",
    "\n",
    "    def who(self):\n",
    "        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
    "\n",
    "class IntNumeric(Base): # type\n",
    "    base_class = staticmethod(lambda: int)\n",
    "\n",
    "class FloatNumeric(Base):\n",
    "    base_class = staticmethod(lambda: float)\n",
    "\n",
    "class ComplexNumeric(Base):\n",
    "    base_class = staticmethod(lambda: complex)\n",
    "\n",
    "#unknown = str\n",
    "#del unknown\n",
    "class Unknown(Base):\n",
    "    base_class = staticmethod(lambda: unknown)\n",
    "\n",
    "i = Base(2)\n",
    "f = Base(4.0)\n",
    "c = Base(2+3j)\n",
    "#u = Base('u')\n",
    "\n",
    "i.who()\n",
    "f.who()\n",
    "c.who()\n",
    "#u.who()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}