import os
import logging

# Suppress TensorFlow logging. The environment variable must be set before
# TensorFlow is imported for the C++ log level to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
logging.getLogger('tensorflow').setLevel(logging.FATAL)

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Reshape, Flatten
import t3f
class SoftmaxEmbeddingLayer(tf.keras.layers.Layer):
    """
    Parameter embedding layer that generates the weights used for stacking the tensor
    networks. It takes the parameter array, lambda = (ell, a1, a2), as input and outputs
    K numbers that sum to 1.

    Attributes:
        output_dim (int): The dimension K of the output.
        d (int): Number of dense expansion layers.
        expansion_dim (int): The dimension used for expanding the input in intermediate layers.
    """

    def __init__(self, output_dim, d, expansion_dim = 30, **kwargs):
        super(SoftmaxEmbeddingLayer, self).__init__(**kwargs)
        self.reduction_layer = None
        self.expansion_layers = None
        self.output_dim = output_dim
        self.expansion_dim = expansion_dim
        self.d = d  # Number of dense layers

    def build(self, input_shape):
        # Expansion layers to increase dimensionality
        self.expansion_layers = [tf.keras.layers.Dense(self.expansion_dim, activation = 'relu')
                                 for _ in range(self.d)]
        # Reduction layer to bring dimensionality back to the desired output dimension
        self.reduction_layer = tf.keras.layers.Dense(self.output_dim)

    def call(self, inputs):
        expanded = inputs
        for layer in self.expansion_layers:
            expanded = layer(expanded)
        # The softmax guarantees the K outputs are positive and sum to 1.
        return tf.nn.softmax(self.reduction_layer(expanded))

    def get_config(self):
        # 'd' is included so the layer can be re-created from its config.
        return {'output_dim': self.output_dim, 'd': self.d, 'expansion_dim': self.expansion_dim}
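
# A minimal usage sketch (the shapes below are illustrative assumptions, not
# values prescribed by this module): applied to a batch of parameter vectors
# lambda = (ell, a1, a2), the layer returns K = output_dim stacking weights
# per sample that sum to 1.
#
#   embed = SoftmaxEmbeddingLayer(output_dim = 4, d = 2)
#   weights = embed(tf.random.uniform((8, 3)))   # shape (8, 4)
#   tf.debugging.assert_near(tf.reduce_sum(weights, axis = -1), tf.ones(8))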
class EinsumTTLRegularizer(tf.keras.regularizers.Regularizer):
    """
    Regularizer for the Einsum layer of the TTL layer class, penalizing high-frequency
    components of the weights vector.

    Attributes:
        strength (float): The regularization strength.
        midpoint (int): Index demarcating the inner and outer boundaries, i.e. x[:midpoint]
            contains data for the inner boundary, and x[midpoint:] contains data for the
            outer boundary. The regularization is designed so it does not penalize
            variations across this index.
    """

    def __init__(self, strength, midpoint):
        self.strength = strength
        self.midpoint = midpoint

    def __call__(self, x):
        # Sum of absolute first differences within each half of the weight vector;
        # no term couples x[midpoint - 1] with x[midpoint], so the jump between the
        # inner and outer halves is never penalized.
        diff = tf.abs(x[1:self.midpoint - 1] - x[0:self.midpoint - 2]) \
            + tf.abs(x[self.midpoint + 1:2 * self.midpoint - 1] - x[self.midpoint:2 * self.midpoint - 2])
        return self.strength * tf.reduce_sum(diff)

    def get_config(self):
        return {'strength': self.strength, 'midpoint': self.midpoint}
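
# Behaviour sketch (illustrative numbers): with midpoint = 4 the penalty sums
# |x[1]-x[0]| + |x[2]-x[1]| on the inner half and |x[5]-x[4]| + |x[6]-x[5]| on
# the outer half, so the step from x[3] to x[4] costs nothing.
#
#   reg = EinsumTTLRegularizer(strength = 0.01, midpoint = 4)
#   penalty = reg(tf.constant([0., 1., 0., 1., 5., 5., 5., 5.]))  # 0.01 * 2.0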
def cosine_initializer(kx = 1.0):
    """
    Initializer for the Einsum layer of the TTL layer class. Sets the weights to the
    negative of a randomly weighted combination of |cos(kx * x)| and |cos(2 * kx * x)|,
    where x is a uniform grid on [-pi, pi] with one point per weight.

    Args:
        kx (float, optional): Frequency of the cosine terms. Defaults to 1.0.

    Returns:
        _initializer: Weight initializer function
    """
    def _initializer(shape, dtype = None):
        x_values = np.linspace(-np.pi, np.pi, shape[0])
        cos_values = np.random.uniform(-0.1, 0.3) * np.abs(np.cos(kx * x_values)) \
            + np.random.uniform(-0.05, 0.05) * np.abs(np.cos(2.0 * kx * x_values))
        return tf.convert_to_tensor(-cos_values, dtype = dtype)
    return _initializer
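
# Usage sketch: the returned closure has the standard Keras initializer
# signature, so it can be evaluated directly or passed to add_weight (the
# shape value here is illustrative).
#
#   init = cosine_initializer(kx = 2.0)
#   w0 = init((64,), dtype = tf.float32)   # tensor of shape (64,)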
class EinsumTTL(tf.keras.layers.Layer):
    """
    Layer that contracts the input tensor over its last dimension before passing it to
    the TTL. If regularization is enabled, it applies an `EinsumTTLRegularizer` to the
    kernels.

    Attributes:
        nx2, nx3 (int): Shape parameters characterizing the input tensor dimensions.
            The shape of the input tensor is (2 * nx2, nx3 // 2).
        W (int): Number of einsum contractions.
        kernels (list): List of weight vectors, one per einsum contraction.
        regularization_strength (float): The strength of the regularization, if used.
        use_regularization (bool): Flag to indicate whether regularization is used.
    """

    def __init__(self, nx2, nx3, W, use_regularization, regularization_strength = 0.005, **kwargs):
        super(EinsumTTL, self).__init__(**kwargs)
        self.nx2 = nx2
        self.nx3 = nx3
        self.W = W
        self.kernels = []
        self.regularization_strength = regularization_strength
        self.use_regularization = use_regularization
        if self.use_regularization:
            regularizer = EinsumTTLRegularizer(self.regularization_strength, self.nx3 // 4)
        else:
            regularizer = None
        # Cycle through the base frequencies 1.0, 0.5, 2.0, 3.0 to assign one
        # initializer frequency per kernel.
        initializer_values = ([1.0, 0.5, 2.0, 3.0] * W)[:W]
        for i in range(W):
            self.kernels.append(self.add_weight(
                name = f'w{i + 1}',
                shape = (nx3 // 2,),
                regularizer = regularizer,
                initializer = cosine_initializer(initializer_values[i])
            ))

    def call(self, inputs):
        parts = []
        for w in self.kernels:
            # Rows :nx2 (inner boundary) are contracted with the first half of the
            # kernel, rows nx2: (outer boundary) with the second half; within each,
            # the second block of columns uses the reversed kernel half, enforcing
            # mirror symmetry.
            part_a = tf.einsum('abc,c->ab', inputs[:, :self.nx2, :self.nx3 // 4], w[:self.nx3 // 4]) + \
                tf.einsum('abc,c->ab', inputs[:, :self.nx2, self.nx3 // 4:self.nx3 // 2],
                          tf.reverse(w[:self.nx3 // 4], axis = [0]))
            part_b = tf.einsum('abc,c->ab', inputs[:, self.nx2:, :self.nx3 // 4], w[self.nx3 // 4:self.nx3 // 2]) + \
                tf.einsum('abc,c->ab', inputs[:, self.nx2:, self.nx3 // 4:self.nx3 // 2],
                          tf.reverse(w[self.nx3 // 4:self.nx3 // 2], axis = [0]))
            parts.extend([part_a, part_b])
        # Output shape: (batch, 2 * nx2 * W).
        return tf.concat(parts, axis = 1)

    def get_config(self):
        # The shape parameters are included so the layer can be re-created from its config.
        return {'nx2': self.nx2, 'nx3': self.nx3, 'W': self.W,
                'use_regularization': self.use_regularization,
                'regularization_strength': self.regularization_strength}
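
# Shape sketch (hypothetical grid sizes): an input of shape
# (batch, 2 * nx2, nx3 // 2) is contracted down to (batch, 2 * nx2 * W),
# i.e. one (part_a, part_b) pair of width-nx2 slices per kernel.
#
#   layer = EinsumTTL(nx2 = 8, nx3 = 16, W = 2, use_regularization = False)
#   out = layer(tf.random.normal((4, 16, 8)))   # shape (4, 32)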
class TTL(tf.keras.layers.Layer):
    """
    TTL (Tensor Train Layer) is a custom TensorFlow Keras layer that builds a model
    based on the given configuration. This layer is designed to work with
    tensor train decomposition in neural networks.

    Attributes:
        config (dict): Configuration dictionary containing parameters for the model.
            'nx1', 'nx2', 'nx3': Integers, dimensions of the finite-difference grid.
            'shape1': List of integers, defines the shape of the output tensor in the
                tensor train format. Must have the same length as shape2.
            'shape2': List of integers, specifies the shape of the input tensor in the
                tensor train format. Must have the same length as shape1.
            'ranks': List of integers, represents the ranks in the tensor train
                decomposition. The length of this list determines the complexity and
                the number of parameters in the tensor train layer.
            'W' (int): Number of weight vectors to use in the initial EinsumTTL layer.
                Setting W = 0 means that no EinsumTTL is used.
            'use_regularization' (bool, optional, default: False): Whether regularization
                is used in the EinsumTTL.
            'regularization_strength' (float, optional, default: 0.0): Strength of the
                regularization.
        model (tf.keras.Sequential): The Sequential model built based on the provided
            configuration.

    Methods:
        load_config(config): Loads and validates the configuration.
        build_model(): Builds the layer.
        call(inputs): Forward pass of the layer.
    """

    def __init__(self, config, **kwargs):
        super(TTL, self).__init__(**kwargs)
        self.model = Sequential()
        self.nx1 = None
        self.nx2 = None
        self.nx3 = None
        self.shape1 = None
        self.shape2 = None
        self.ranks = None
        self.W = None
        self.use_regularization = None
        self.regularization_strength = None
        self._required_keys = ['nx1', 'nx2', 'nx3', 'shape1', 'shape2', 'ranks', 'W']
        config.setdefault('use_regularization', False)
        config.setdefault('regularization_strength', 0.0)
        self.load_config(config)
        self.config = config
        self.build_model()

    def load_config(self, config):
        missing_keys = [key for key in self._required_keys if key not in config]
        if missing_keys:
            raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")
        if not isinstance(config['use_regularization'], bool):
            raise TypeError('use_regularization must be a boolean.')
        self.use_regularization = config['use_regularization']
        self.regularization_strength = config['regularization_strength']
        for key in ['nx1', 'nx2', 'nx3', 'W']:
            if not isinstance(config[key], int):
                raise TypeError(f"{key} must be an integer.")
        for key in ['nx1', 'nx2', 'nx3']:
            if config[key] <= 0:
                raise ValueError(f"{key} must be positive.")
        if config['W'] < 0:
            raise ValueError("W must be non-negative.")
        nx1, nx2, nx3 = config['nx1'], config['nx2'], config['nx3']
        self.nx1 = nx1
        self.nx2 = nx2
        self.nx3 = nx3
        W = config['W']
        self.W = W
        # With W einsum kernels the EinsumTTL output has dimension 2 * nx2 * W;
        # without it (W = 0) the flattened input has dimension nx2 * nx3.
        input_dim = 2 * nx2 * W
        if W == 0:
            input_dim = nx2 * nx3
        shape1, shape2 = config['shape1'], config['shape2']
        if len(shape1) != len(shape2):
            raise ValueError(
                f'shape1 and shape2 must have the same length. '
                f'Received: shape1 = {shape1}, shape2 = {shape2}.'
            )
        elif np.prod(np.array(shape1)) != nx1 * nx2:
            raise ValueError(
                f'prod(shape1) must be equal to the output dimension of the TTL '
                f'(nx1 * nx2,). Received: prod(shape1) = {np.prod(np.array(shape1))}, '
                f'nx1 * nx2 = {nx1 * nx2}.'
            )
        elif np.prod(np.array(shape2)) != input_dim:
            raise ValueError(
                f'prod(shape2) must be equal to the input dimension of the TTL '
                f'(2 * nx2 * W, or nx2 * nx3 if W = 0). '
                f'Received: prod(shape2) = {np.prod(np.array(shape2))}, required input dimension = {input_dim}.'
            )
        else:
            self.shape1 = shape1
            self.shape2 = shape2
        self.ranks = config['ranks']

    def build_model(self):
        if self.W == 0:
            self.model.add(Flatten(input_shape = (2 * self.nx2, self.nx3 // 2)))
        else:
            self.model.add(EinsumTTL(self.nx2, self.nx3, self.W, self.use_regularization,
                                     regularization_strength = self.regularization_strength,
                                     input_shape = (2 * self.nx2, self.nx3 // 2)))
            self.model.add(Flatten())
        tt_layer = t3f.nn.KerasDense(input_dims = self.shape2, output_dims = self.shape1,
                                     tt_rank = self.ranks, use_bias = False, activation = 'linear')
        self.model.add(tt_layer)
        self.model.add(Reshape((self.nx1, self.nx2)))

    def call(self, inputs):
        return self.model(inputs)
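

# A minimal end-to-end sketch. All grid sizes, shapes, and the rank below are
# illustrative assumptions, not values prescribed by this module; t3f also
# accepts a plain integer tt_rank, which it broadcasts internally.
if __name__ == '__main__':
    config = {
        'nx1': 4, 'nx2': 8, 'nx3': 16,   # finite-difference grid dimensions
        'W': 2,                          # two einsum kernels -> input dim 2 * nx2 * W = 32
        'shape2': [4, 8],                # prod(shape2) == 32, the TTL input dimension
        'shape1': [4, 8],                # prod(shape1) == nx1 * nx2 == 32
        'ranks': 4,                      # TT-rank of the t3f KerasDense layer
        'use_regularization': True,
        'regularization_strength': 0.005,
    }
    ttl = TTL(config)
    # Input shape: (batch, 2 * nx2, nx3 // 2); output shape: (batch, nx1, nx2).
    y = ttl(tf.random.normal((3, 2 * config['nx2'], config['nx3'] // 2)))
    print(y.shape)  # expected: (3, 4, 8)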