Skip to content

Instantly share code, notes, and snippets.

@jinyangustc
Created April 19, 2025 22:14
Show Gist options
  • Select an option

  • Save jinyangustc/28fedfb71cd8c3b2e78c68931b8de3e6 to your computer and use it in GitHub Desktop.

Select an option

Save jinyangustc/28fedfb71cd8c3b2e78c68931b8de3e6 to your computer and use it in GitHub Desktop.

Revisions

  1. jinyangustc created this gist Apr 19, 2025.
    103 changes: 103 additions & 0 deletions python_sum_types.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,103 @@
    import time
    import traceback
    from typing import Literal, final


    class String:
    """
    A variant containing a string
    """


    class ListOfInts:
    """
    A variant containing a list of ints
    """


    @final
    class MySumType:
    def __init__(
    self,
    adt: tuple[String, str] | tuple[ListOfInts, list[int]],
    ) -> None:
    self.adt = adt


    def consume(err_data: str):
    x1 = MySumType((String(), 'hello'))
    x2 = MySumType((ListOfInts(), [1, 2, 3]))
    x3 = MySumType((ListOfInts(), err_data))
    for x in [x1, x2, x3]:
    match x.adt:
    case (String(), data):
    print(data.upper())
    case (ListOfInts(), data):
    print(sum(data))


    if __name__ == '__main__':
    try:
    consume('lsp will show error')
    except TypeError as e:
    traceback.print_exc()

    print('---')

    # --- benchmark ---
    data = ['hello', 'world', [1, 2, 3], list(range(1000))]

    sum_type_data: list[MySumType] = []
    tuple_data: list[
    tuple[Literal['str'], str] | tuple[Literal['list_of_ints'], list[int]]
    ] = []
    for x in data:
    if isinstance(x, str):
    sum_type_data.append(MySumType((String(), x)))
    tuple_data.append(('str', x))
    else:
    sum_type_data.append(MySumType((ListOfInts(), x)))
    tuple_data.append(('list_of_ints', x))

    max_iter = 1_000_000

    counter = 0
    start_time = time.perf_counter()
    for i in range(max_iter):
    for x in sum_type_data:
    match x.adt:
    case (String(), xx):
    counter += len(xx)
    case (ListOfInts(), xx):
    counter += sum(xx)
    end_time = time.perf_counter()
    print(
    f'match with wrapper class: {end_time - start_time:.4f} seconds for {max_iter} iterations'
    )

    counter = 0
    start_time = time.perf_counter()
    for i in range(max_iter):
    for x in tuple_data:
    match x:
    case ('str', xx):
    counter += len(xx)
    case ('list_of_ints', xx):
    counter += sum(xx)
    end_time = time.perf_counter()
    print(
    f'match without wrapper class: {end_time - start_time:.4f} seconds for {max_iter} iterations'
    )

    counter = 0
    start_time = time.perf_counter()
    for i in range(max_iter):
    for x in data:
    if isinstance(x, str):
    counter += len(x)
    else:
    counter += sum(x)
    end_time = time.perf_counter()
    print(
    f'if-else on native values: {end_time - start_time:.4f} seconds for {max_iter} iterations'
    )