Spaces:
Sleeping
Sleeping
| import pydot | |
| import re | |
| from keras.models import Model | |
| from keras.layers import Layer, InputLayer | |
| from pygments.lexers import graphviz | |
| # May be necessary to manually add Graphviz to PATH, e.g. | |
| # import os | |
| # os.environ["PATH"] += os.pathsep + r'C:\Program Files\Graphviz\bin' | |
# Edges whose *source* layer name matches one of these patterns are drawn with
# Graphviz style 'invis': they are hidden in the picture but still take part in
# layout. Matching is case-insensitive and uses re.search (match anywhere).
# NOTE(review): in tf.keras the first slicing op is usually named
# 'tf.__operators__.getitem' with no numeric suffix, so it would not match the
# second pattern -- confirm whether that omission is intended.
_INVISIBLE_SOURCE_PATTERNS = (
    re.compile('sequential', flags=re.IGNORECASE),
    re.compile(r'operators__.getitem_[0-9]+$', flags=re.IGNORECASE),
)


def _is_invisible_source(layer_name):
    """Return True if edges leaving a layer with this name should be drawn invisibly."""
    return any(pattern.search(layer_name) for pattern in _INVISIBLE_SOURCE_PATTERNS)


def _group_for(layer_name, groupings):
    """Return the first group (in dict order) whose members contain layer_name, else None."""
    if not groupings:
        return None
    for group, members in groupings.items():
        if layer_name in members:
            return group
    return None


def _collect_inbound_layers(layer):
    """Return the distinct layers feeding `layer` across ALL of its inbound nodes.

    Parents are returned in first-seen order. A layer without an `_inbound_nodes`
    attribute (or with an empty/None one) yields an empty list. Each node's
    `inbound_layers` may be a single object or a list/tuple; both are accepted.
    """
    # assumes each inbound node exposes `inbound_layers` (tf.keras Node API) -- TODO confirm for Keras 3
    inbound_nodes = getattr(layer, '_inbound_nodes', None)
    if not inbound_nodes:
        return []
    parents_by_id = {}
    for inbound_node in inbound_nodes:
        parents = inbound_node.inbound_layers
        if not isinstance(parents, (list, tuple)):
            parents = [parents]
        for parent in parents:
            # Key by identity so duplicates across nodes collapse into one edge.
            parents_by_id.setdefault(id(parent), parent)
    return list(parents_by_id.values())


def _add_layer_nodes(graph, layers, layer_labels, layer_colors, groupings, exclude_input_layer):
    """Add one filled box node per layer to `graph` (or its group's cluster); return the set of node ids added."""
    subgraphs = {}
    node_ids = set()
    for i, layer in enumerate(layers):
        if exclude_input_layer and isinstance(layer, InputLayer):
            continue
        # Node ids are derived from object identity so distinct layers never collide.
        layer_id = str(id(layer))
        node_ids.add(layer_id)
        node = pydot.Node(layer_id, label=layer_labels[i], style='filled',
                          fillcolor=layer_colors[i], shape='box')
        group_name = _group_for(layer.name, groupings)
        if not group_name:
            graph.add_node(node)
            continue
        if group_name not in subgraphs:
            subgraphs[group_name] = pydot.Cluster(group_name, label=group_name,
                                                  style='dashed', fontsize=24)
        subgraphs[group_name].add_node(node)
    for subgraph in subgraphs.values():
        graph.add_subgraph(subgraph)
    return node_ids


def _add_layer_edges(graph, layers, node_ids, verbose):
    """Add one directed edge from each distinct parent layer to each drawn layer."""
    for layer in layers:
        dest_id = str(id(layer))
        if dest_id not in node_ids:
            # Layer was excluded (e.g. InputLayer with exclude_input_layer=True).
            continue
        for parent in _collect_inbound_layers(layer):
            src_id = str(id(parent))
            # Only draw edges between layers that actually have nodes in this graph.
            if not isinstance(parent, Layer) or src_id not in node_ids:
                continue
            if _is_invisible_source(parent.name):
                graph.add_edge(pydot.Edge(src_id, dest_id, style='invis'))
            else:
                graph.add_edge(pydot.Edge(src_id, dest_id))
            if verbose:
                print(f"Added edge from {parent.name} to {layer.name}")


def _write_graph(graph, output_filename):
    """Render `graph` to a PNG file, reporting failures on stdout instead of raising."""
    try:
        graph.write_png(output_filename)
    except FileNotFoundError as e:
        print(f'\nFailed to create network visualization using pydot and graphviz. Please ensure that '
              'the output filename is valid, and graphviz is installed and included in the system PATH variable. '
              f'Original error: {e}')
    except Exception as e:
        # Deliberate best-effort: visualization failure should not abort the caller.
        print(f'\nFailed to create network visualization using pydot and graphviz. Original error: {e}')
    else:
        print(f'Model visualization saved to {output_filename}')


def visualize_model(model, layer_labels=None, layer_colors=None, groupings=None, exclude_input_layer=False,
                    verbose=False, output_filename='model_graph.png'):
    """
    Creates a visual graph of a keras model. There is an option to group certain layers into subgraphs
    (argument 'groupings').

    Args:
        model: A Keras Model instance.
        layer_labels (optional): List of labels, indexed like `model.layers`. Defaults to layer names.
        layer_colors (optional): List of fill colors, indexed like `model.layers`. Defaults to white for all layers.
        groupings (optional): Dictionary specifying groups of layers. Each key is a group name,
            and its value is a list of layer names belonging to that group. A layer listed in
            several groups goes to the first one in dict order. Each group is drawn as a dashed cluster.
        exclude_input_layer (optional): Boolean indicating whether to exclude InputLayer instances from the graph.
        verbose (boolean, optional): Whether to print each edge as it is added. Defaults to False.
        output_filename (optional): Name of the output file for saving the generated graph.

    Raises:
        ValueError: if `model` is not a Keras Model, or if `layer_labels` / `layer_colors`
            has fewer entries than `model.layers`.

    Output:
        PNG image file with name 'output_filename'. Rendering failures (e.g. Graphviz
        not on PATH) are printed rather than raised.
    """
    if not isinstance(model, Model):
        raise ValueError("model should be a Keras model instance")

    layers = model.layers
    num_layers = len(layers)

    # Default labels and colors if not provided.
    if not layer_labels:
        layer_labels = [layer.name for layer in layers]
    if not layer_colors:
        layer_colors = ['white'] * num_layers
    # Labels/colors are indexed by position in model.layers; fail early with a clear message.
    if len(layer_labels) < num_layers:
        raise ValueError(f"layer_labels has {len(layer_labels)} entries but the model has {num_layers} layers")
    if len(layer_colors) < num_layers:
        raise ValueError(f"layer_colors has {len(layer_colors)} entries but the model has {num_layers} layers")

    graph = pydot.Dot(graph_type='digraph', rankdir='LR')
    node_ids = _add_layer_nodes(graph, layers, layer_labels, layer_colors, groupings, exclude_input_layer)
    _add_layer_edges(graph, layers, node_ids, verbose)
    graph.set_graph_defaults(sep='+125,125')
    _write_graph(graph, output_filename)