"""A function to generate a modular connectivity mask. You can use it to mask a weight matrix, and then mask the gradients using a hook to fully disable those connections. self.hh = Parameter(torch.normal(...)) mask = hierarchical_modular_mask(200, level=3, density=0.05, scale=2.0) self.hh.data *= mask self.hh.register_hook(lambda grad: grad * mask) """ def hierarchical_modular_mask( size: int, *, level: int = 3, density: float = 0.05, scale: float = 2.0, dale_mask: Optional[Tensor] = None, inhibitory_distance: int = 0 ): """Generate a connectivity matrix mask of hierarchical modules. Modules are densely connected locally, with increasing sparsity to nodes in more distal modules. If a `dale_mask` is given, the inhibitory units (0 values in the mask) will be disabled for all but the immediately local module. This gives a "local inhibition" effect, with long-range excitatory connections. If you want more distal inhibition, you can configure it with the `inhibitory_distance` argument. Args: size: The side-length of the square mask. level: The number of hierarchical levels to generate. density: The density of the highest level of the hierarchy (i.e.g the off-diagonal regions). scale: The scaling factor for the density at each subsequent level of the hierarchy. dale_mask: A tensor of shape (size), specifying a 1 for positive units and a 0 for negative units. inhibitory_distance: A number specifying the number of hierarchical steps inhibitory connections may span. Returns: A (size, size) mask matrix of 1's and 0's ref: https://www.nature.com/articles/srep22057 """ assert inhibitory_distance <= level, "inhibitory_distance cannot be larger than the total number of hierarchical levels" first_half = math.ceil(size/2) second_half = math.floor(size/2) bg = (torch.rand((size, size)) <= density).float() blank_diag = 1 - torch.block_diag( torch.ones(first_half, first_half), torch.ones(second_half, second_half), ) if level > 0: inner_diag = torch.zeros(size, size) + torch.block_diag( hierarchical_modular_mask( first_half, level-1, density * scale, scale, dale_mask[:first_half] if dale_mask is not None else dale_mask ), hierarchical_modular_mask( second_half, level-1, density * scale, scale, dale_mask[first_half:] if dale_mask is not None else dale_mask ), ) else: inner_diag = torch.zeros(size, size) + torch.block_diag( (torch.rand(first_half, first_half) <= density * scale).float(), (torch.rand(second_half, second_half) <= density * scale).float(), ) if dale_mask is not None and level > inhibitory_distance: # Disable non-local inhibitory connections bg *= dale_mask mask = (bg * blank_diag) + inner_diag return mask