AST学习
AST简介:
AST(Abstract Syntax Tree,抽象语法树)是编译原理中的概念,是源代码语法结构的一种抽象表示:它以树的形式表现编程语言的语法结构,树上的每个节点都表示源代码中的一种结构。
下面的代码展示了如何在demo.py中用ast模块解析source_code.py,修改其语法树,再把修改后的语法树转回源代码并写入target_code.py。这个过程可以用于客户化定制。
(mmlab中的config机制采用的是另一种方式:它并不解析config文件的语法,而是以base config为基础,与个人的config进行merge并替换相应字段,得到最终的config;然后通过底层维护的"字符串→类"映射,根据config中type字段的字符串找到对应的类,从而拿到类及其参数。)
按语法规则解析并修改后,得到的是可以直接执行的Python文件(mmlab中的config虽然也是.py文件,但它只用来承载配置,本身并不包含可执行的逻辑)。
demo.py
import ast
import copy

import astor
# demo.py -- rewrite selected assignments in a class's __init__ through the AST.
#
# SOURCE_PATH may point at any .py file. The script copies that file's
# module-level imports plus the class named CLASS_NAME into TARGET_PATH, after
# replacing the right-hand side of every ``<expr>.<attr> = ...`` statement in
# the class's __init__ whose attribute name appears in REPLACEMENTS.
SOURCE_PATH = "./ast_learning/source_code.py"
TARGET_PATH = "./ast_learning/target_code.py"
CLASS_NAME = "Classification_2d"

# attribute name -> Python expression source for its new value.
# The expressions refer to the ``models`` and ``nn`` aliases, so the source file
# must import those modules under exactly those names.
# NOTE(review): the generated __init__ keeps the resnet-specific
# ``self.net.fc`` lines; torchvision's ConvNeXt exposes its head as
# ``.classifier`` rather than ``.fc`` -- confirm before running the generated
# module, or pick a resnet variant here.
REPLACEMENTS = {
    "net": "models.convnext_large(pretrained=False)",
    # The loss class itself (no call), matching the form used in the source.
    "loss": "nn.CrossEntropyLoss",
}


def _find_class(tree, class_name):
    """Return the first module-level ``class <class_name>`` definition in *tree*.

    Raises:
        LookupError: if the module defines no such class at top level.
    """
    for node in tree.body:
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return node
    raise LookupError(f"class {class_name!r} not found at module level")


def _top_level_imports(tree):
    """Return the module-level ``import`` / ``from ... import`` statements, in order.

    Imports nested inside functions or classes are skipped on purpose: hoisting
    them to module level would duplicate them in the generated file.
    """
    return [node for node in tree.body if isinstance(node, (ast.Import, ast.ImportFrom))]


def _assigned_attribute(stmt):
    """Return ``attr`` if *stmt* is a single-target ``<expr>.attr = ...`` assignment, else None."""
    if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
        target = stmt.targets[0]
        if isinstance(target, ast.Attribute):
            return target.attr
    return None


def transform_source(source_code, class_name=CLASS_NAME, replacements=None):
    """Build a module holding *source_code*'s imports and its rewritten class.

    Args:
        source_code: Python source text to parse.
        class_name: name of the module-level class to copy and rewrite.
        replacements: mapping of attribute name -> expression source; defaults
            to REPLACEMENTS. Only statements directly in the body of
            ``__init__`` are considered.

    Returns:
        A new ``ast.Module`` containing the module-level imports followed by the
        rewritten class. The class is deep-copied, so the tree parsed from
        *source_code* is left untouched.

    Raises:
        SyntaxError: if *source_code*, or a replacement expression that gets
            used, does not parse.
        LookupError: if *class_name* is not defined at module level.
    """
    if replacements is None:
        replacements = REPLACEMENTS
    tree = ast.parse(source_code)
    class_node = copy.deepcopy(_find_class(tree, class_name))
    for stmt in class_node.body:
        if not (isinstance(stmt, ast.FunctionDef) and stmt.name == "__init__"):
            continue
        # Walk the statements of __init__ (super() call, assignments, ...).
        for sub_stmt in stmt.body:
            attr = _assigned_attribute(sub_stmt)
            if attr in replacements:
                # Parsing the expression yields the proper node type (Call for
                # ``models.x(...)``, Attribute for ``nn.X``). Hand-building
                # ``ast.Name(id='nn.L1Loss', ctx=ast.Load())`` unparses to the
                # same text, but turns the node into a Name whose id is not a
                # valid identifier.
                sub_stmt.value = ast.parse(replacements[attr], mode="eval").body
    # ``type_ignores`` is a required field of ast.Module (Python 3.8+).
    return ast.Module(body=_top_level_imports(tree) + [class_node], type_ignores=[])


def main():
    """Read SOURCE_PATH, rewrite it with transform_source, and write TARGET_PATH."""
    with open(SOURCE_PATH, "r", encoding="utf-8") as f:
        source_code = f.read()
    code_tree = transform_source(source_code)
    # Four spaces per indentation level.
    copied_code = astor.to_source(code_tree, indent_with=" " * 4)
    # Use the same encoding as the read so non-ASCII text round-trips on every platform.
    with open(TARGET_PATH, "w", encoding="utf-8") as f:
        f.write(copied_code)


if __name__ == "__main__":
    main()
source_code.py
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torchvision.models as models
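# These two imports must use exactly the aliases `models` and `nn`: demo.py's replacement expressions refer to them by these names.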
import torch.nn as nn
# Sample input for demo.py: the ``self.net = ...`` and ``self.loss = ...``
# assignments in __init__ are the ones it rewrites. Only ``#`` comments are used
# here: a docstring is an AST node and would be copied into target_code.py,
# whereas comments are dropped by the round trip.
class Classification_2d(pl.LightningModule):
    # NOTE(review): ``label_dict={}`` is a mutable default shared by every
    # instance created without the argument, and it is stored on the instance
    # as ``self.label_dict``. ``label_dict=None`` with a ``{}`` fallback would
    # avoid the sharing (the generated target_code.py would change to match).
    def __init__(self, label_dict={},log_dir=''):
        super(Classification_2d, self).__init__()
        # One class per entry in label_dict; sizes the fc head below.
        self.num_classes = len(label_dict)
        # Rewritten by demo.py (matched on the single assignment target ``.net``).
        self.net=models.resnet18(pretrained=True)
        # resnet series: the next two lines rely on the backbone having an ``.fc`` layer.
        self.fc = nn.Linear(self.net.fc.in_features, self.num_classes)
        self.net.fc = nn.Identity()
        # Rewritten by demo.py. Stores the loss class itself, not an instance;
        # NOTE(review): confirm callers instantiate it before computing a loss.
        self.loss=nn.L1Loss
        self.label_dict=label_dict
        # Inverse of label_dict (value -> key); if two keys share a value the
        # later one wins. Presumably label_dict maps name -> label id -- TODO confirm.
        self.label_to_name_dict={v:k for k,v in label_dict.items()}
        self.training_save=True
        self.log_dir=log_dir
target_code.py(运行demo.py后生成的文件内容)
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torchvision.models as models
import torch.nn as nn
# Class emitted by demo.py from source_code.py. astor drops comments (e.g. the
# "resnet" note in the source), so the comments below annotate this listing only.
class Classification_2d(pl.LightningModule):
    def __init__(self, label_dict={}, log_dir=''):
        super(Classification_2d, self).__init__()
        self.num_classes = len(label_dict)
        # Rewritten by demo.py (was models.resnet18(pretrained=True)).
        self.net = models.convnext_large(pretrained=False)
        # NOTE(review): these two lines were carried over from the resnet version.
        # torchvision's ConvNeXt exposes its head as ``.classifier``, not ``.fc``,
        # so constructing this class likely raises AttributeError -- confirm
        # against the installed torchvision and adjust demo.py's replacement.
        self.fc = nn.Linear(self.net.fc.in_features, self.num_classes)
        self.net.fc = nn.Identity()
        # Rewritten by demo.py (was nn.L1Loss); still the class, not an instance.
        self.loss = nn.CrossEntropyLoss
        self.label_dict = label_dict
        self.label_to_name_dict = {v: k for k, v in label_dict.items()}
        self.training_save = True
        self.log_dir = log_dir