#!/usr/bin/env python3
"""Record factory. A record is the equivalent of a C struct.

``record(typename, field_names)`` renders ``class_template`` into the source
of a small mutable class that stores its fields in ``__slots__``, compiles
that source with ``exec`` and returns the resulting class.
"""
import sys
import keyword

# Source of the generated class.  The placeholders are filled in by
# ``record``; every substituted fragment is a single line of code so the
# indentation of the template is preserved.
class_template = """\
import collections

class {typename}:
    '{typename}({arg_list})'

    __slots__ = {field_names!r}

    def __init__(self, {init_args}):
        'Create new instance of {typename}.'
        {init_body}

    def __repr__(self):
        'Return a nicely formatted representation string.'
        return self.__class__.__name__ + '({repr_fmt})'.format({self_fields})

    def __iter__(self):
        'Iterates over {typename} fields and values.'
        return zip(self.__slots__, self.__values())

    def __eq__(self, other):
        'Returns True if *self* is equal to *other*.'
        try:
            return {eq_stmt}
        except AttributeError:
            return NotImplemented

    @classmethod
    def _make(cls, iterable):
        'Make a new {typename} object from *iterable*.'
        return cls(*iterable)

    @property
    def _fields(self):
        'Returns the names of the fields in the order they were defined.'
        return self.__slots__

    def __values(self):
        'Iterates through values in the order fields were defined.'
        {values_stmt}

    def _asdict(self):
        'Return a new OrderedDict which maps field names to their values.'
        result = collections.OrderedDict()
        {asdict_body}
        return result

    def _replace(self, **kwargs):
        'Return a new {typename} object replacing specified fields with new values'
        result = self._make(map(kwargs.pop, self.__slots__, self.__values()))
        if kwargs:
            raise ValueError('Got unexpected field names: ' + str(list(kwargs)))
        return result\
"""


def _caller_module(depth):
    """Return the ``__name__`` of the module *depth* frames above this helper.

    Returns None when the interpreter offers no ``sys._getframe`` or the
    stack is shallower than *depth*.
    """
    # pylint: disable=protected-access
    try:
        return sys._getframe(depth).f_globals.get('__name__', '__main__')
    except (AttributeError, ValueError):
        return None


def validate(typename, field_names):
    """Internal use, checks validity of typename and field_names.

    *field_names* may be a string of names separated by commas and/or
    whitespace, or an iterable of names.  Every name, and *typename*, is
    coerced with ``str()`` before being checked.

    Returns the ``(typename, field_names)`` pair with *field_names* as a list
    of strings.

    Raises ValueError if a name is not a valid identifier, is a keyword, or
    if a field name starts with an underscore or is duplicated.
    """
    if isinstance(field_names, str):
        field_names = field_names.replace(',', ' ').split()
    field_names = list(map(str, field_names))
    typename = str(typename)

    # Every name is pasted into generated source, so it must be usable as an
    # identifier there.
    for name in [typename] + field_names:
        if not name.isidentifier():
            raise ValueError('Type names and field names must be valid '
                             'identifiers: ' + name)
        if keyword.iskeyword(name):
            raise ValueError('Type names and field names cannot be a '
                             'keyword: ' + name)

    # Underscore-prefixed names are rejected so fields cannot clash with the
    # generated helpers (_make, _fields, _asdict, _replace); duplicates are
    # rejected as well.
    seen = set()
    for name in field_names:
        if name.startswith('_'):
            raise ValueError('Field names cannot start with an underscore: '
                             + name)
        if name in seen:
            raise ValueError('Encountered duplicate field name: ' + name)
        seen.add(name)
    return typename, field_names


def compile_(typename, source, module=None):
    """Internal use, compiles struct source.

    Executes *source* in a fresh namespace, takes the class called *typename*
    out of it, stores the source on the class as ``_source`` and sets the
    class's ``__module__``.  When *module* is None the module of the direct
    caller of this function is used, if it can be determined.

    NOTE: *source* is run with ``exec``; only pass source built from names
    that went through ``validate``.
    """
    # pylint: disable=exec-used, protected-access
    # The key must be the literal string '__name__' so the executed code sees
    # it as its module name.
    namespace = {'__name__': 'struct_{:s}'.format(typename)}
    exec(source, namespace)
    code = namespace[typename]
    code._source = source
    if module is None:
        # Frame 0 is the helper, 1 is compile_, 2 is whoever called compile_.
        module = _caller_module(2)
    if module is not None:
        code.__module__ = module
    return code


def record(typename, field_names, *, module=None):
    """Factory method that creates a record class.

    *typename* is the name of the new class; *field_names* is a string of
    names separated by commas and/or whitespace, or an iterable of names.
    An empty field list is allowed.  *module*, when given, becomes the
    ``__module__`` of the new class; otherwise the module calling ``record``
    is used.

    Raises ValueError for invalid names (see ``validate``).
    """
    typename, field_names = validate(typename, field_names)
    if module is None:
        # Resolve the caller here: from inside compile_ the calling frame
        # would be this function, not the user's module.
        module = _caller_module(2)

    names = tuple(field_names)
    arg_list = ', '.join(names)
    self_fields = ', '.join('self.' + f for f in names)
    if names:
        init_body = self_fields + ' = ' + arg_list
        eq_stmt = ' and '.join('self.' + f + ' == other.' + f for f in names)
        values_stmt = '; '.join('yield self.' + f for f in names)
        asdict_body = (', '.join("result['" + f + "']" for f in names)
                       + ' = ' + self_fields)
    else:
        # With no fields the joined fragments would be empty and the
        # generated source would not compile (or would misbehave).
        init_body = 'pass'
        eq_stmt = 'isinstance(other, self.__class__)'
        values_stmt = 'return iter(())'
        asdict_body = 'pass'

    source = class_template.format(
        typename=typename,
        field_names=names,
        arg_list=arg_list,
        init_args=', '.join(f + '=None' for f in names),
        init_body=init_body,
        eq_stmt=eq_stmt,
        repr_fmt=', '.join(f + '={!r}' for f in names),
        self_fields=self_fields,
        values_stmt=values_stmt,
        asdict_body=asdict_body,
    )
    return compile_(typename, source, module)