graphviz官方参考链接:
http://www.graphviz.org/documentation/
https://graphviz.readthedocs.io/en/stable/index.html
文章目录
- 需求描述
- 环境配置
- 实现思路
- 代码实现
需求描述
根据各模块之间的传参关系绘制出数据流,如下图所示:
并且生成对应的graphviz代码:
digraph my_graph {
Input [fillcolor=gray70 shape=box style=filled]
Output [fillcolor=gray70 shape=box style=filled]
NodeA
NodeB
NodeC
Input -> NodeA [label=0]
Input -> NodeA [label=1]
NodeA -> NodeB [label=0]
NodeA -> NodeC [label=1]
NodeB -> Output [label=0]
NodeC -> Output [label=0]
}
环境配置
- 安装Python中需要使用的`graphviz`包:
pip install graphviz
- 安装`graphviz`工具(可选,如果不安装则无法直接使用Python的`graphviz`包导出图片)。例如ubuntu系统的安装指令如下,其他系统可参考官方文档 https://www.graphviz.org/download/ :
sudo apt install graphviz
- VSCODE安装`Graphviz Interactive Preview`插件(可选,如果使用vscode开发建议安装此插件,通过此插件可以直接可视化graphviz代码,并保存图片)
实现思路
实现一个Node基类,所有的模块实现都继承自该基类。再实现一个Message基类,模块之间传递的数据都继承自该基类。然后在数据传递过程中记录流经的每个模块的名称以及数据的传递方向即可绘制出想要的数据流。
代码实现
下面给出了一个简易的实现方式:
import os
from graphviz import Digraph
# Registry filled while a graph is traced: maps a node name to the list of
# EdgeInfo objects describing the edges that leave that node.
__graph_dict__ = {}
class Message:
    """Placeholder value handed from one node to the next during tracing.

    It carries no payload: it only remembers which node produced it and the
    position of that value among the producer's outputs.
    """

    def __init__(self, node_name: str, idx: int):
        # node_name: name of the producing node; idx: its output position.
        self.node_name, self.idx = node_name, idx
class EdgeInfo:
    """A directed, labelled edge between two named nodes."""

    def __init__(self, start_node_name: str, end_node_name: str, label: str) -> None:
        # Tail node, head node and the text shown on the edge.
        self.start_node_name, self.end_node_name = start_node_name, end_node_name
        self.label = label

    def __str__(self):
        """Format the edge as one DOT edge statement."""
        return '{} -> {} [label="{}"];'.format(
            self.start_node_name, self.end_node_name, self.label
        )
class Node:
    """Base class for the modules of a traced graph.

    Subclasses set ``input_num``, ``output_num`` and ``node_name``. Calling a
    node does no real computation: it records, in the module-level
    ``__graph_dict__`` registry, one edge per received message and returns
    fresh messages that identify this node as their producer.
    """

    input_num: int  # number of positional messages the node expects
    output_num: int  # number of messages the node emits
    node_name: str  # name under which the node appears in the graph

    def __call__(self, *args):
        """Record the incoming edges and emit this node's output messages.

        Args:
            *args: Exactly ``input_num`` objects exposing ``node_name`` and
                ``idx`` (normally ``Message`` instances).

        Returns:
            A single ``Message`` when ``output_num == 1``; otherwise a tuple
            of ``output_num`` messages indexed ``0 .. output_num - 1``.

        Raises:
            AssertionError: If the number of arguments differs from
                ``input_num``.
        """
        # Raised explicitly (rather than via ``assert``) so the check also
        # runs under ``python -O``; the exception type is unchanged.
        if len(args) != self.input_num:
            raise AssertionError(
                f"{self.node_name} expects {self.input_num} input(s), "
                f"got {len(args)}"
            )

        # Register this node even if none of its outputs is ever consumed.
        __graph_dict__.setdefault(self.node_name, [])

        for input_ in args:
            edge = EdgeInfo(input_.node_name, self.node_name, str(input_.idx))
            # Edges are stored under their start node. setdefault avoids a
            # KeyError when the producer has not been registered yet.
            __graph_dict__.setdefault(input_.node_name, []).append(edge)

        res = tuple(Message(self.node_name, i) for i in range(self.output_num))
        if self.output_num == 1:
            return res[0]
        return res
def export_graphviz(graph, num_input: int, save_path: str):
    """Trace ``graph`` and export its data flow as Graphviz source and image.

    The graph is called once with ``num_input`` placeholder messages coming
    from a synthetic ``Input`` node; every value it returns is connected to a
    synthetic ``Output`` node. The resulting DOT source is printed, then
    written next to a rendered ``jpg`` image.

    Args:
        graph: Callable built from ``Node`` instances. It may return a single
            ``Message``, an iterable of messages, or ``None``.
        num_input: Number of placeholder inputs passed to ``graph``.
        save_path: Path of the DOT source file to write. The file name without
            its extension is used as the graph name.

    Raises:
        ValueError: If ``save_path`` has no file-name component.
    """
    directory, file_name = os.path.split(save_path)
    if not file_name:
        raise ValueError(f"save_path must include a file name: {save_path!r}")
    # splitext copes with names containing zero or several dots, where
    # ``split(".")`` followed by two-target unpacking raised ValueError.
    name, _ = os.path.splitext(file_name)

    # Start from a clean registry holding only the two synthetic end points.
    __graph_dict__.clear()
    __graph_dict__.update({"Input": [], "Output": []})

    # infer and collect flow info
    input_args = tuple(Message("Input", i) for i in range(num_input))
    outputs = graph(*input_args)
    if outputs is None:
        outputs = ()
    elif isinstance(outputs, Message):
        # A node with a single output returns a bare Message, not a tuple.
        outputs = (outputs,)
    for output_ in outputs:
        edge = EdgeInfo(output_.node_name, "Output", str(output_.idx))
        __graph_dict__.setdefault(output_.node_name, []).append(edge)

    # create graph code; filename makes the source land at save_path itself
    digraph = Digraph(name=name, filename=file_name, format="jpg")

    # add nodes
    boundary_style = {"shape": "box", "style": "filled", "fillcolor": "gray70"}
    for node_name in __graph_dict__:
        if node_name in ("Input", "Output"):
            digraph.node(node_name, **boundary_style)
        else:
            digraph.node(node_name)

    # add edges
    for edges in __graph_dict__.values():
        for edge_info in edges:
            digraph.edge(edge_info.start_node_name,
                         edge_info.end_node_name,
                         edge_info.label)

    # print digraph code
    print(digraph.source)

    # export gv and jpg file; deliberately best-effort, because rendering
    # needs the Graphviz executables, which may not be installed.
    try:
        digraph.render(directory=directory)
    except Exception as e:
        print(f"export digraph failed, {e}")
class NodeA(Node):
    """Example node taking two inputs and producing two outputs."""

    def __init__(self):
        self.input_num = self.output_num = 2
        self.node_name = "NodeA"
class NodeB(Node):
    """Example node taking one input and producing one output."""

    def __init__(self):
        self.input_num = self.output_num = 1
        self.node_name = "NodeB"
class NodeC(Node):
    """Example node taking one input and producing one output."""

    def __init__(self):
        self.input_num = self.output_num = 1
        self.node_name = "NodeC"
class Graph:
    """Demo pipeline: NodeA's two outputs go to NodeB and NodeC respectively."""

    def __init__(self):
        self.node_a, self.node_b, self.node_c = NodeA(), NodeB(), NodeC()

    def __call__(self, x0, x1):
        """Run both inputs through NodeA, then one branch each through B and C."""
        branch_b, branch_c = self.node_a(x0, x1)
        return self.node_b(branch_b), self.node_c(branch_c)
if __name__ == "__main__":
    # Trace the demo graph with two inputs and export it to ./my_graph.gv.
    export_graphviz(Graph(), num_input=2, save_path="./my_graph.gv")
执行上述代码后会生成`my_graph.gv`以及`my_graph.gv.jpg`两个文件(如果没有安装graphviz工具是不会生成的),其中`my_graph.gv`是graphviz的代码形式,`my_graph.gv.jpg`是可视化后的结果。