Spaces:
Sleeping
Sleeping
| import json | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import matplotlib.ticker as ticker | |
| import numpy as np | |
| import sympy | |
| from matplotlib.cm import get_cmap | |
| from stnn.nn import stnn | |
| from stnn.pde.pde_system import PDESystem | |
def adjust_to_nice_number(value, round_down = False):
    """
    Snap ``value`` to a "nice" number of the form {1, 2, 5, 10} x 10**k. Used for colorbar tickmarks.

    With ``round_down=False`` (the default) a positive ``value`` is rounded *up* to the
    smallest nice number >= ``value``. With ``round_down=True`` the leading digits are
    instead snapped with the thresholds 1.5 / 3 / 7 (below 1.5 -> 1, below 3 -> 2,
    below 7 -> 5, otherwise 10), which may round up as well as down.

    Negative values are always processed in ``round_down`` mode on their magnitude and
    the sign is restored afterwards. Zero is returned unchanged.
    """
    if value == 0:
        return value
    is_negative = False
    if value < 0:
        round_down = True
        is_negative = True
        value = -value
    exponent = np.floor(np.log10(value))  # Find exponent of 10
    fractional_part = value / 10**exponent  # Leading digit(s), nominally in [1, 10)
    if round_down:
        if fractional_part < 1.5:
            nice_fractional = 1
        elif fractional_part < 3:
            nice_fractional = 2
        elif fractional_part < 7:
            nice_fractional = 5
        else:
            nice_fractional = 10
    else:
        if fractional_part <= 1:
            nice_fractional = 1
        elif fractional_part <= 2:
            nice_fractional = 2
        elif fractional_part <= 5:
            nice_fractional = 5
        else:
            nice_fractional = 10
    nice_value = nice_fractional * 10**exponent if round_down or nice_fractional != 10 else 10**(exponent + 1)
    if is_negative:
        nice_value = -nice_value
    return nice_value


def find_nice_values(min_val_raw, max_val, num_values = 4):
    """
    Calculate roughly 'num_values' evenly spaced "nice" values within the given range.
    Used for colorbar tickmarks.

    The start value is ``min_val_raw`` rounded with ``adjust_to_nice_number``. The raw
    minimum is used instead when the rounded start moves less than a 1/num_values
    fraction of the range, moves below the range, or lands at/above ``max_val``; the
    last case previously produced an empty (or NaN) result for narrow ranges.
    The spacing is the raw spacing snapped to the closest of {1, 2, 5, 10} x 10**k.

    Parameters
    ----------
    min_val_raw, max_val : float
        Range to cover (e.g. the colorbar limits).
    num_values : int
        Target number of values; must be >= 2.

    Returns
    -------
    list
        Values lying between the chosen start and ``max_val`` (inclusive). A degenerate
        range (``max_val <= min_val_raw``) yields ``[min_val_raw]``.
    """
    if max_val <= min_val_raw:
        # Constant (or inverted) range: the fraction and spacing below would divide by zero.
        return [min_val_raw]
    min_val = adjust_to_nice_number(min_val_raw)
    frac_val = (min_val - min_val_raw) / (max_val - min_val_raw)
    # NOTE(review): the raw minimum is kept when the nice start is *close* to it; confirm
    # this is the intended direction. Starting at/above max_val would give a non-positive
    # spacing (and no ticks), so that case always falls back to the raw minimum.
    if frac_val < 1 / num_values or min_val >= max_val:
        min_val = min_val_raw
    # Calculate rough spacing between values
    raw_spacing = (max_val - min_val) / (num_values - 1)
    # Snap the spacing to the closest of {1, 2, 5, 10} times its power of ten
    magnitude = np.floor(np.log10(raw_spacing))
    nice_factors = np.array([1, 2, 5, 10])
    normalized_spacing = raw_spacing / (10**magnitude)
    closest_factor = nice_factors[np.argmin(np.abs(nice_factors - normalized_spacing))]
    nice_spacing = closest_factor * (10**magnitude)
    nice_values = min_val + nice_spacing * np.arange(num_values)
    # Append one more value if the last one falls more than a full spacing short of max_val
    if nice_values[-1] < max_val - nice_spacing:
        last_val = nice_values[-1]
        nice_values = np.append(nice_values, [last_val + nice_spacing])
    # Drop values that overshoot the range (snapping can enlarge the spacing)
    return [val for val in nice_values if min_val <= val <= max_val]
def format_tick_label(val):
    """
    Format a colorbar tick value as a short string, using scientific notation for large/small values.

    Based on the decimal exponent ``floor(log10(|val|))``:

    - ``|val| >= 1000`` or ``|val| < 0.01``: scientific notation, one decimal (``1.2e+03``)
    - ``[100, 1000)``: no decimals; ``[10, 100)``: 1 decimal; ``[1, 10)``: 2 decimals
    - ``[0.1, 1)``: 2 decimals; ``[0.01, 0.1)``: 3 decimals

    The exponent is used with its sign, so small values no longer share the branches
    meant for large ones (which printed e.g. 0.05 as "0" and 0.25 as "0.2").
    Zero and non-finite values are formatted with plain ``str`` formatting.
    """
    if val == 0 or not np.isfinite(val):
        return f'{val}'
    exponent = int(np.floor(np.log10(np.abs(val))))
    decimals = {2: 0, 1: 1, 0: 2, -1: 2, -2: 3}.get(exponent)
    if decimals is None:
        # |val| >= 1000 or |val| < 0.01
        return f'{val:.1e}'
    return f'{val:.{decimals}f}'
def plot_simple(system, rho, fontscale = 1):
    """
    Plot ``rho`` as a filled contour map over the domain of ``system``.

    Parameters
    ----------
    system : PDESystem
        Provides the plotting grids via ``get_xy_grids()`` and the outer-boundary major
        axis ``b2``, which sets the axis limits.
    rho : numpy.ndarray
        2D array of values on the (x, y) grids. Its first column is appended at the end
        (along axis 1) so the contour wraps around continuously.
    fontscale : float
        Multiplier for the title and tick-label font sizes.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Build the figure without pyplot: pyplot keeps every figure alive until it is
    # explicitly closed (a leak in a long-running server) and its global
    # current-figure/axes state is not thread-safe across Gradio worker threads.
    from matplotlib.figure import Figure

    # Major axis of outer boundary
    b2 = system.b2
    x, y = system.get_xy_grids()
    # Wrap around values for continuity
    rho = np.append(rho, rho[:, 0:1], axis = 1)
    # Color bar limits
    vmin = np.nanmin(rho)
    vmax = np.nanmax(rho)

    fig = Figure(figsize = (5, 5))
    ax = fig.subplots()
    im = ax.contourf(x, y, rho, levels = np.linspace(vmin, vmax, 100), cmap = 'hsv')
    ax.set_title('rho(x,y)', fontsize = fontscale * 16)
    # Applies to tick labels created later as well (e.g. after the limits change below)
    ax.tick_params(labelsize = fontscale * 12)
    ax.set_aspect(1.0)
    fac = 1.05
    ax.set_xlim([-fac * b2, fac * b2])
    ax.set_ylim([-fac * b2, fac * b2])
    cbar = fig.colorbar(im, shrink = 0.8)
    # Set colorbar ticks and labels to "nice" values
    nice_values = find_nice_values(vmin, vmax, num_values = 5)
    cbar.set_ticks(nice_values)
    cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda val, pos: format_tick_label(val)))
    return fig
def evaluate_2d_expression(expr_str, xvals, yvals):
    """
    Evaluate a text expression in the symbols ``s`` and ``t`` element-wise on numeric grids.

    Parameters
    ----------
    expr_str : str
        Expression parsed with ``sympy.sympify``; ``s`` and ``t`` are its variables.
    xvals, yvals : numpy.ndarray
        Values substituted for ``s`` and ``t`` respectively.

    Returns
    -------
    numpy.ndarray
        The evaluated expression. An expression that evaluates to a scalar (e.g. the
        constant ``"1"``) is broadcast to the shape of ``xvals``.

    Exceptions raised by sympy parsing or by evaluating the expression propagate.
    """
    # SECURITY NOTE(review): sympy.sympify evaluates its input with ``eval``, and in this
    # app the string comes straight from a public text box. Consider restricting or
    # sanitizing the input before exposing this to untrusted users.
    s, t = sympy.symbols('s t')
    expr = sympy.sympify(expr_str)
    f = sympy.lambdify((s, t), expr, modules = ['numpy'])
    result = f(xvals, yvals)
    if np.ndim(result) == 0:
        # Constant expressions yield a Python or NumPy scalar; broadcast it to the grid.
        return result * np.ones(np.shape(xvals))
    return result
# Disabled code: a direct (GMRES) solve of the PDE, kept inside an inert module-level
# string literal, so it is never executed and is not wired into the Gradio interface.
# NOTE(review): before re-enabling, it needs names that are not imported at the top of
# this file (timeit, warnings, csr_matrix, asarray, asnumpy, spx, xp). asnumpy/xp look
# like a NumPy/CuPy array-backend abstraction -- TODO confirm against the GitHub repo.
# If `spx` is scipy.sparse, SciPy 1.12 renamed gmres's `tol=` keyword to `rtol=`.
'''
# Currently unused in gradio interface
def direct_solution(ell, a2, eccentricity, ibc_str, obc_str, max_krylov_dim, max_iterations):
    # Direct solution
    start = timeit.default_timer()
    pde_config = {}
    for key in ['nx1', 'nx2', 'nx3']:
        pde_config[key] = stnn_config[key]
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)
    try:
        ibf_data = evaluate_2d_expression(ibc_str, system.x2_ib, system.x2_ib - system.x3_ib)[system.ib_slice]
    except:
        raise ValueError(f"Failed to parse the expression `{ibc_str}` for the boundary condition @ the inner boundary.")
    try:
        obf_data = evaluate_2d_expression(obc_str, system.x2_ob, system.x2_ob - system.x3_ob)[system.ob_slice]
    except:
        raise ValueError(f"Failed to parse the expression `{obc_str}` for the boundary condition @ the outer boundary.")
    if np.any(np.isnan(ibf_data)):
        raise ValueError(f"The expression `{ibc_str}` evaluates to nan at one or more grid points.")
    if np.any(np.isnan(obf_data)):
        raise ValueError(f"The expression `{obc_str}` evaluates to nan at one or more grid points.")
    ibf_data, obf_data, b = system.convert_boundary_data(ibf_data, obf_data)
    L_xp = csr_matrix(system.L) # Sparse matrix representation of the PDE operator
    nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']
    b_xp = asarray(b.reshape((nx1 * nx2 * nx3,))) # r.h.s. vector
    def callback(res):
        print(f'GMRES residual: {res}')
    f_xp, info = spx.linalg.gmres(L_xp, b_xp, maxiter=max_iterations, tol=1e-7, restart=max_krylov_dim, callback=callback)
    residual = (xp.linalg.norm(b_xp - L_xp @ f_xp) / xp.linalg.norm(b_xp))
    if info > 0:
        warnings.simplefilter('always')
        warnings.warn(f'GMRES solver did not converge. Number of iterations: {info}; residual: {residual}', RuntimeWarning)
    f = asnumpy(f_xp)
    rho_direct = np.sum(f.reshape((nx1, nx2, nx3)), axis=-1)
    direct_time = timeit.default_timer() - start
    print(f'Done with direct solution. Time: {direct_time} seconds.')
    fig = plot_simple(system, rho_direct)
    return fig, info
'''
def _evaluate_boundary_condition(expr_str, x2, x3, grid_slice, boundary_name):
    """
    Evaluate one boundary-condition expression on a boundary grid and validate the result.

    ``x2`` is passed to the expression as ``s`` and ``x2 - x3`` as ``t``; ``grid_slice``
    is applied to the evaluated array. ``boundary_name`` ('inner' / 'outer') is used only
    in error messages.

    Raises
    ------
    ValueError
        If the expression fails to parse/evaluate, or evaluates to NaN or infinity.
    """
    try:
        data = evaluate_2d_expression(expr_str, x2, x2 - x3)[grid_slice]
    except Exception as err:
        raise ValueError(
            f"Failed to parse the expression `{expr_str}` for the boundary condition @ the {boundary_name} boundary."
        ) from err
    if not np.all(np.isfinite(data)):
        raise ValueError(f"The expression `{expr_str}` evaluates to NaN or infinity at one or more grid points.")
    return data


def predict_pde_solution(ell, a2, eccentricity, ibc_str, obc_str):
    """
    Predict the density with the pretrained STNN model and return a plot of it.

    Callback of the demo's "Submit" button.

    Parameters
    ----------
    ell : float
        PDE parameter ell; must be > 0.
    a2 : float
        Minor axis of the outer boundary; must be greater than ``eccentricity``.
    eccentricity : float
        Eccentricity of the inner boundary; must lie in [0, 0.999].
    ibc_str, obc_str : str
        Inner / outer boundary-condition expressions in the angular coordinates ``s``, ``t``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure produced by ``plot_simple``.

    Raises
    ------
    ValueError
        For missing or out-of-range parameters, and for boundary expressions that fail
        to parse or evaluate to NaN/infinite values.
    """
    # Cleared number boxes arrive as None; reject them before comparing.
    if ell is None or a2 is None or eccentricity is None:
        raise ValueError('ell, a and ecc must all be provided.')
    if ell <= 0:
        raise ValueError(f'ell must be > 0 (got {ell}).')
    if not 0 <= eccentricity <= 0.999:
        raise ValueError(f'Eccentricity must be between 0 and 0.999 (got {eccentricity}).')
    if a2 <= eccentricity:
        raise ValueError(f'Outer minor axis must be greater than the eccentricity (here, {eccentricity}).')

    # Grid sizes come from the same config the model is built from.
    pde_config = {key: stnn_config[key] for key in ['nx1', 'nx2', 'nx3']}
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)

    ibf_data = _evaluate_boundary_condition(ibc_str, system.x2_ib, system.x3_ib, system.ib_slice, 'inner')
    obf_data = _evaluate_boundary_condition(obc_str, system.x2_ob, system.x3_ob, system.ob_slice, 'outer')
    # Permute and reshape boundary data to the format expected by the STNN model.
    # (Random boundary data could instead come from system.generate_random_bc(func_gen_id);
    # that path is currently unused in the gradio interface.)
    ibf_data, obf_data, _ = system.convert_boundary_data(ibf_data, obf_data)

    # Load some relevant quantities from the config dictionaries
    ell_min, ell_max = stnn_config['ell_min'], stnn_config['ell_max']
    a2_min, a2_max = stnn_config['a2_min'], stnn_config['a2_max']
    nx2, nx3 = pde_config['nx2'], pde_config['nx3']

    # Combine boundary data in a single (1, 2 * nx2, nx3 // 2) batch: inner rows first, then outer
    bf = np.zeros((1, 2 * nx2, nx3 // 2))
    bf[:, :nx2, :] = ibf_data[np.newaxis, ...]
    bf[:, nx2:, :] = obf_data[np.newaxis, ...]

    # Min-max normalize a2 and ell with the ranges stored in the config; eccentricity is used as-is
    params = np.zeros((1, 3))
    params[0, 0] = (a2 - a2_min) / (a2_max - a2_min)
    params[0, 1] = (ell - ell_min) / (ell_max - ell_min)
    params[0, 2] = eccentricity

    rho = model.predict([params, bf])
    return plot_simple(system, rho[0, ...])
# Read the STNN configuration (relative to the current working directory), then build
# the network from it and restore the pretrained weights.
with open('T5_config.json', mode = 'r', encoding = 'utf-8') as json_file:
    stnn_config = json.loads(json_file.read())
model = stnn.build_stnn(stnn_config)
model.load_weights('T5_weights.h5')
# Gradio user interface. LaTeX backslashes are written as "\\" so no string contains an
# invalid escape sequence (a SyntaxWarning since Python 3.12); the rendered text is the same.
with gr.Blocks() as demo:
    gr.Markdown("# Stacked Tensorial Neural Network (STNN) demo"
                "\nThis demo uses the model architecture from [arXiv:2312.14979](https://arxiv.org/abs/2312.14979) "
                "to solve a parametric PDE problem on an elliptical annular domain. "
                "See the paper for a detailed description of the problem and its applications."
                "<br/>The [GitHub repo](https://github.com/caleb399/stacked_tensorial_nn) contains additional examples, "
                "including instructions for solving the PDE using a conventional iterative method (GMRES). "
                "Due to the long runtime of solving the PDE in this way, it is not included in the demo.")
    gr.Markdown("<br/>The PDE is "
                "$\\ell \\left( \\boldsymbol{\\hat{u}} \\cdot \\nabla \\right) f(\\boldsymbol{r}, w) = \\partial_{ww} f(\\boldsymbol{r}, w)$, "
                "where $\\ell$ is a parameter and $\\boldsymbol{\\hat{u}} = (\\cos w, \\sin w)$. "
                "Here, $\\boldsymbol{r}$ is the 2D position vector, and $w$ is an angular coordinate unrelated to "
                "the spatial domain. The model predicts the density !\\rho(\\boldsymbol{r}) = \\int f(\\boldsymbol{r}, w) dw! "
                "on elliptical annular domains parameterized as shown below. ",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}, {"left": "!", "right": "!", "display": True}])
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "## PDE Parameters \n The model was trained on solutions of the PDE with $\\ell$ between 0.01 and 100, $a$ between 2 and 20, "
                "and $ecc$ between 0 and 0.8.", latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                                                    {"left": "!", "right": "!", "display": True}])
            ell_input = gr.Number(label = "ell (must be > 0)", value = 1.0)
            eccentricity_input = gr.Number(
                label = "ecc: eccentricity of the inner boundary (must be >= 0 and <= 0.999)",
                value = 0.5, minimum = 0.0, maximum = 0.999)
            a2_input = gr.Number(label = "a: Minor axis of outer boundary (must be > eccentricity)", value = 2.0)
            gr.Markdown(
                "## Boundary Conditions \n $(s, t)$ are angular coordinates parameterizing the PDE domain, "
                "related to $\\boldsymbol{r}$ and $w$ by a coordinate transformation. "
                "Specifically, $s$ is the polar elliptical coordinate along the boundary (inner or outer), with values "
                "between $-\\pi$ and $\\pi$, while $t = s - w$. Boundary conditions are generated from grid points "
                "distributed uniformly over the allowable values of $s$ and $t$."
                "<br/><br/>For the PDE problem to be well-posed, boundary data should only be specified where "
                "$\\boldsymbol{\\hat{u}} \\cdot \\boldsymbol{\\hat{n}} > 0$, where $\\boldsymbol{\\hat{n}}$ is the "
                "inward-pointing unit normal vector. This requirement constrains the allowable values of $t$"
                " and is automatically enforced when building boundary conditions from the user-specified expressions below.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}])
            inner_boundary = gr.Textbox(label = "Inner boundary condition", value = "0.5 * (1 + sign(cos(s)))")
            outer_boundary = gr.Textbox(label = "Outer boundary condition", value = "1 + 0.1 * cos(4*s)")
            submit_button = gr.Button("Submit")
        with gr.Column():
            gr.Markdown("## Predicted Solution")
            predicted_output_plot = gr.Plot()
    submit_button.click(
        fn = predict_pde_solution,
        inputs = [ell_input, a2_input, eccentricity_input, inner_boundary, outer_boundary],
        outputs = [predicted_output_plot]
    )
# show_error=True makes Gradio show the text of exceptions raised by the callback (e.g. the
# ValueError messages about invalid parameters or expressions) instead of a generic "Error".
demo.launch(show_error = True)