前置知识
访问者模式
(Visitor Pattern)是一种行为设计模式,它允许你定义在不改变被访问元素的类的前提下,扩展其功能。通过将操作(操作或算法)从对象结构中提取出来,可以在不修改这些对象的前提下,定义新的操作。这种模式非常适合处理对象结构中的元素,并且需要根据其类别执行不同操作的情况。
主要角色
- Visitor(访问者):
-
- 定义了对每个元素对象访问的行为,它的方法对不同类型的元素对象执行不同的操作。
- ConcreteVisitor(具体访问者):
-
- 实现了 Visitor 接口,提供具体的操作算法。
- Element(元素):
-
- 定义一个 accept 方法,该方法接受一个访问者对象作为参数,通常是一个抽象方法,被不同类型的元素实现。
- ConcreteElement(具体元素):
-
- 实现了 accept 方法,该方法调用访问者的访问方法,以便访问者可以处理它的数据。
- ObjectStructure(对象结构):
-
- 维护了一个元素的集合,提供一个高层接口允许访问者访问元素。
注意:这里的元素我们可以理解为被访问者。
示例代码
假设我们有一个程序,需要计算不同类型的员工(全职员工和兼职员工)的工资。我们将使用访问者模式来实现此功能。
- 定义被访问者接口和具体被访问者
// 被访问者接口
interface Employee {
void accept(Visitor visitor);
}
// 具体具体被访问者 - 全职员工
class FullTimeEmployee implements Employee {
private String name;
private double salary;
public FullTimeEmployee(String name, double salary) {
this.name = name;
this.salary = salary;
}
public String getName() {
return name;
}
public double getSalary() {
return salary;
}
@Override
public void accept(Visitor visitor) {
visitor.visit(this);
}
}
// 具体被访问者 - 兼职员工
class PartTimeEmployee implements Employee {
private String name;
private double wage;
public PartTimeEmployee(String name, double wage) {
this.name = name;
this.wage = wage;
}
public String getName() {
return name;
}
public double getWage() {
return wage;
}
@Override
public void accept(Visitor visitor) {
visitor.visit(this);
}
}
- 定义访问者接口和具体访问者
// 访问者接口,访问者接口定义了多个 visit 方法,每个方法对应一个具体被访问者类型,执行特定的操作。
interface Visitor {
void visit(FullTimeEmployee employee);
void visit(PartTimeEmployee employee);
}
// 具体访问者 - 工资计算器
class SalaryCalculator implements Visitor {
private double totalSalary = 0;
public double getTotalSalary() {
return totalSalary;
}
@Override
public void visit(FullTimeEmployee employee) {
totalSalary += employee.getSalary();
}
@Override
public void visit(PartTimeEmployee employee) {
totalSalary += employee.getWage() * 20; // 假设一个月工作 20 天
}
}
- 定义对象结构和测试代码
// 对象结构
/**
* 对象结构维护了一个集合(如列表),能够遍历其中的元素。
* 提供了一个接受访问者的方法,遍历所有元素并调用各自的 accept 方法。
*/
class Employees {
private List<Employee> employees = new ArrayList<>();
public void attach(Employee employee) {
employees.add(employee);
}
public void detach(Employee employee) {
employees.remove(employee);
}
public void accept(Visitor visitor) {
for (Employee employee : employees) {
employee.accept(visitor);
}
}
}
// 测试代码
public class VisitorPatternDemo {
public static void main(String[] args) {
Employees employees = new Employees();
employees.attach(new FullTimeEmployee("John", 5000));
employees.attach(new PartTimeEmployee("Jane", 20));
SalaryCalculator calculator = new SalaryCalculator();
employees.accept(calculator);
System.out.println("Total salary: " + calculator.getTotalSalary());
}
}
示例解释
- Employee 接口定义了 accept 方法,用于接受访问者访问。
- FullTimeEmployee 和 PartTimeEmployee 是具体的元素,实现了 accept 方法,并通过访问者接口将自己传递给访问者。
- Visitor 接口定义了访问者的操作方法,这里是 visit 方法,根据具体的元素类型执行不同的操作。
- SalaryCalculator 是具体的访问者,实现了 visit 方法,根据不同的元素类型累加工资或薪水。
- Employees 是对象结构,维护了员工列表,并提供了接受访问者的方法。
运行 VisitorPatternDemo 类,将会计算出所有员工的总工资,演示了访问者模式的使用和作用。
访问者模式的优点在于可以在不修改现有代码的情况下,增加新的操作(例如计算工资);缺点在于增加新的元素类可能需要修改访问者接口和所有的访问者实现类。
jsqlparse介绍
JSqlParse是一款很精简的sql解析工具,它可以将常用的sql文本解析成具有层级结构的“语法树”,我们可以针对解析后的“树节点(也即官网里说的有层次结构的java类)”进行处理进而生成符合我们要求的sql形式。
官网给的介绍很简洁:JSqlParser 解析 SQL 语句并将其转换为 Java 类的层次结构。生成的层次结构可以使用访问者模式进行访问(官网地址:JSqlParser - Home)。
官网的介绍即是该中间件的全部,虽然介绍很短,但是其功能着实强悍。
JSqlParser 是一个开源的 SQL 语句解析工具,它可以对 SQL 语句进行解析、重构等各种操作:
- 能够将 SQL 字符串转换成一个可操作的抽象语法树(AST),这使得程序能够理解和操作 SQL 语句的各个组成部分。
- 根据需求对解析出的AST进行修改,比如添加额外的过滤条件,然后再将AST转换回SQL字符串,实现需求定制化的SQL语句构建。
SELECT语法树简图:
jar包结构介绍
这里我使用的是4.0版本,maven依赖如下:
<dependency>
<groupId>com.github.jsqlparser</groupId>
<artifactId>jsqlparser</artifactId>
<version>4.0</version>
</dependency>
JSqlParse的总体代码量不大,结构也很简单,其项目整体结构图如下:
可以看到其总共只有五个大的包,各个包的功能定义也很清晰:
- expression:包含表达式相关的类和接口,可以简单看做sql解析后的组成对象之一。如果需要对sql进行一些更改变换,基本都会涉及到这个包。
- parse:JSqlParse最核心的包,这个包里的类实现了sql的解析,进而我们才可以对解析后的sql(“java类”)做各种自定义处理。虽然这个包是最核心的包,但如果纯粹从使用角度上来说可以不必太在意它,除非我们想深入了解sql解析的过程。
- schema:可以理解为模式,即定义一些和数据中概念相对应的类,如表Table、列Column等。
- statement:sql语句也分很多种,如增删改查等,这个包下就对应各种解析后java类所组成的sql语句,其内部结构如下:
util:JSqlParse解析中用到的工具类,基本也不用太在意,不过有个TablesNamesFinder类则具有较强的参考价值。
其中该组件最厉害的地方是parse包的解析,即将sql解析成一组有血缘(或者成层级嵌套)的对象集,要了解这块,需要对antlr有较深的理解才行。感兴趣的可以专门去看一下。不过如果我们只是使用,就不需要专门了解语法的解析了,我们只需要知道如何对解析后的sql进行修改即可。下面我会先讲解大致大体的如何去做,最后一节再讲解其中的一些原理。
使用介绍
sql语句的修改是通过实现对应的访问者接口实现的,比如你想对from之后的table名称进行处理,那么你只需要实现 FromItemVisitor 接口并重写 访问Table的方法即可。如果你想对sql中的函数进行处理,那么你只需要实现ExpressionVisitor接口并重写其中对应的方法接口即可。
是不是很简单,不过这里有个问题就是我们如何把我们自定义的访问者传给解析后的sql对象。因为解析后的sql对象是具有层级的,我们要处理的对象很有可能在最内层。如果你想自己遍历解析后的sql对象,然后把访问者传给特定的对象,这个方法虽然可行,但只能用于于不包含嵌套或者嵌套层次不深的sql语句,一旦包含嵌套语句或者sql语句很复杂,你很难一层层的去处理。
正确的做法是从sql解析后的第一层开始,将每个遇到的相关访问者接口都实现一遍,这样在获得解析后的sql对象后,直接就可以将自定义访问者对象传进去,也不需要我们自己一层层去剥开sql对象。我们只需要专注于自己需要的重写的访问者方法即可。展示下我实际中变更select语句用到的一些访问者接口,贴出来给大家看下:
StatementVisitor, SelectVisitor, SelectItemVisitor, FromItemVisitor, GroupByVisitor, ExpressionVisitor,ItemsListVisitor
这些访问者接口我也不是一次性全实现的,而是从最外层的StatementVisitor开始,一点点加的,后续如果有需要可能还会再加,这个过程是一个比较繁琐的逐渐深入和查漏补缺的过程,所以在sql语法替换时一定要保持谨慎。但这也给出一个建议,千万不要试图追踪各个模块的迭代处理
情况,这样很容易把你绕进去,你只需关注当前所在的模块即可,其它的通过accpet交给其它对应的visitor去处理。
下面以更改select类型语句,将from之后table表名称从table1改为table2,和将max函数修改为min函数作为目标,我们来实现下这个需求:
首先是流程代码,如下:
public class Main {
public static void main(String[] args) throws Exception{
//1、获取原始sql输入
String sql = "select max(age) from table1";
System.out.println("old sql:[{}]"+sql);
//2、创建解析器
CCJSqlParserManager mgr = new CCJSqlParserManager();
//3、使用解析器解析sql生成具有层次结构的java类
Statement stmt = mgr.parse(new StringReader(sql));
//4、将自定义访问者传入解析后的sql对象
stmt.accept(new MyJSqlVisitor());
//5、打印转换后的sql语句
System.out.println("new sql:[{}]" + stmt.toString());
}
}
其次是最核心的访问者接口实现类,这里为了便于向大家展示sql修改的过程,我们一个个的添加接口:
首先是stmt.accept,这个对象接收的是一个StatementVisitor,所以我们在自定义的类MyJSqlVisitor中先实现这个接口,因为我们要改的是select类语句,所以我们可以找到对应的visitor方法(至于为什么这个接口就是跟selet语句相关,一个是根据方法名推断,一个是debug查看,debug可以看到sql语句一层层的对象,再细就不啰嗦了,实战个几次就懂了)
public class MyJSqlVisitor implements StatementVisitor {
@Override
public void visit(Select select) {
SelectBody selectBody = select.getSelectBody();
if (selectBody != null) {
selectBody.accept(this);
}
}
}
注意下,这里我只列出了一个实现的方法,是因为篇幅有限,我只截取了实现改动的方法,后续也是只展示实现了变动的代码,接着可以看到selectBody也需要一个SelectVisitor类型的访问者,所以我们再MyJSqlVisitor中添加实现该接口:
public class MyJSqlVisitor implements StatementVisitor, SelectVisitor {
@Override
public void visit(Select select) {
SelectBody selectBody = select.getSelectBody();
if (selectBody != null) {
selectBody.accept(this);
}
}
@Override
public void visit(PlainSelect plainSelect) {
/** 处理select字段 */
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (selectItems != null && selectItems.size() > 0) {
selectItems.forEach(selectItem -> {
selectItem.accept(this);
});
}
/** 处理表名或子查询 */
FromItem fromItem = plainSelect.getFromItem();
if (fromItem!=null){
fromItem.accept(this);
}
}
}
该接口对应的visit方法中 selectItem和fromItem同时还需要SelectItemVisitor,FromItemVisitor两种访问者,所以我们先来实现SelectItemVisitor这个接口:
public class MyJSqlVisitor implements StatementVisitor, SelectVisitor ,SelectItemVisitor {
@Override
public void visit(Select select) {
SelectBody selectBody = select.getSelectBody();
if (selectBody != null) {
selectBody.accept(this);
}
}
@Override
public void visit(PlainSelect plainSelect) {
/** 处理select字段 */
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (selectItems != null && selectItems.size() > 0) {
selectItems.forEach(selectItem -> {
selectItem.accept(this);
});
}
/** 处理表名或子查询 */
FromItem fromItem = plainSelect.getFromItem();
if (fromItem!=null){
fromItem.accept(this);
}
}
// 这个方法我们并没有考虑完全,比如select项目中可能有子查询还有可能有case表达式,这些我们都没考虑,这里只是先展示了一种思路。
@Override
public void visit(SelectExpressionItem selectExpressionItem) {
if (Function.class.isInstance(selectExpressionItem.getExpression())) {
Function function = (Function) selectExpressionItem.getExpression();
function.accept(this);
}
}
}
可以看到function.accept还需要一个ExpressionVisitor,这里我们接着实现它:
public class MyJSqlVisitor implements StatementVisitor, SelectVisitor ,SelectItemVisitor, ExpressionVisitor {
@Override
public void visit(Select select) {
SelectBody selectBody = select.getSelectBody();
if (selectBody != null) {
selectBody.accept(this);
}
}
@Override
public void visit(PlainSelect plainSelect) {
/** 处理select字段 */
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (selectItems != null && selectItems.size() > 0) {
selectItems.forEach(selectItem -> {
selectItem.accept(this);
});
}
/** 处理表名或子查询 */
FromItem fromItem = plainSelect.getFromItem();
if (fromItem!=null){
fromItem.accept(this);
}
}
// 这个方法我们并没有考虑完全,比如select项目中可能有子查询还有可能有case表达式,这些我们都没考虑,这里只是先展示了一种思路。
@Override
public void visit(SelectExpressionItem selectExpressionItem) {
if (Function.class.isInstance(selectExpressionItem.getExpression())) {
Function function = (Function) selectExpressionItem.getExpression();
function.accept(this);
}
}
@Override
public void visit(Function function) {
if (function.getName().equalsIgnoreCase("max")){
function.setName("min");
}
}
}
至此,max转min已经结束,我们再回过头实现FromItemVisitor接口:
public class MyJSqlVisitor implements StatementVisitor, SelectVisitor ,SelectItemVisitor, ExpressionVisitor,FromItemVisitor {
@Override
public void visit(Select select) {
SelectBody selectBody = select.getSelectBody();
if (selectBody != null) {
selectBody.accept(this);
}
}
@Override
public void visit(PlainSelect plainSelect) {
/** 处理select字段 */
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (selectItems != null && selectItems.size() > 0) {
selectItems.forEach(selectItem -> {
selectItem.accept(this);
});
}
/** 处理表名或子查询 */
FromItem fromItem = plainSelect.getFromItem();
if (fromItem!=null){
fromItem.accept(this);
}
}
// 这个方法我们并没有考虑完全,比如select项目中可能有子查询还有可能有case表达式,这些我们都没考虑,这里只是先展示了一种思路。
@Override
public void visit(SelectExpressionItem selectExpressionItem) {
if (Function.class.isInstance(selectExpressionItem.getExpression())) {
Function function = (Function) selectExpressionItem.getExpression();
function.accept(this);
}
}
// 实现将max函数转为min函数
@Override
public void visit(Function function) {
if (function.getName().equalsIgnoreCase("max")){
function.setName("min");
}
}
//实现表名称的更换
@Override
public void visit(Table table) {
if (table.getName().equalsIgnoreCase("table1")){
table.setName("table2");
}
}
}
至此,我们的两个修改目标已经达成,运行main看下效果:
old sql:[{}]select max(age) from table1
new sql:[{}]SELECT min(age) FROM table2
Process finished with exit code 0
可以看到我们的目的实现了,不过这里请留意我们并没有考虑子查询等其它情况,这个demo只是展示一种修改思路,工作中具体的操作要考虑的比这细致的多。
可以看到我们的目的实现了,不过这里请留意我们并没有考虑子查询等其它情况,这个demo只是展示一种修改思路,工作中具体的操作要考虑的比这细致的多。
使用建议:
1)一个个的添加接口,遇到什么类型的访问者,加什么类型的实现接口,防止一次性加太多忘记实现逻辑。
2)不要试图追踪各个sql对象的迭代处理情况,这样很容易把你绕进去,你只需关注当前所在的方法模块即可,其它的通过accpet交给其它对应的visitor去处理即可。
3)不要试图一次性实现所有的访问者接口,根据需要进行实现
4)sql语法树具有很强的层次性,当被访问者在进行处理时,要考虑到自己的子元素是不是也要进行迭代处理,如果需要的话,那么就调用对应子元素的accpect方法,并将相关访问者传递进去
5)如果没有使用容器技术,所有的访问者接口尽量放在一个类中实现,这样当有accept需要visitor对象的时候直接传this就行。(我一开始没有用容器管理bean,每个visitor接口我都单独创建一个实现类,最后因为使用不到,造成迭代访问时栈溢出错误)
核心原理介绍
这块只是展示sql迭代访问修改的原理,并不涉及将sql文本解析为对象类的原理。好了,进入正文。
要想理解sql迭代修改的原理,其实只要了解访问者模式和多态这两个知识点就行。如果不了解的可以先去查看对应的知识点,然后再看下源码仔细体会下。下面我会简单介绍下,在前文我们也提过,要想修改sql,只需要实现对应的访问接口即可,然后将访问者传入被访问的sql对象中。
在JSqlParse中,将解析后的sql对象看做被访问者,我们自定义的visitor则看做访问者。该组件同时将各类被访问者和访问者都抽象出了接口,我们代码编辑时通过接口确定大体的执行流程,在具体的代码运行阶段,就会通过多态寻找对应的实现类。就拿demo中的statement来说,它是一个接口,但是运行的时候就会根据sql情况定位到具体的实现类,我们demo中对应的具体实现类就是select对象,此时进入该对象查看具体的accept方法:
可以看到被访问者调用的还是访问者的visit方法,也就是我们对应的重写方法。以此类推,剩下的各个层级处理也是通过重复这个过程,所以想理解这个处理过程,一定要理解访问者模式
JSQLParer高效使用
JSqlParser 是SQL语句分析的插件,他使用Java语言去解析SQL。
sqlparser提供很多的数据库语法解析支持其中支持很多oracle的特殊语法。
可以结合mybatis的拦截修改SQL来实现多租户、SQL拼接甚至联表的功能。
SQL解析
获取SQL中的信息
public class TestSqlparser {
public static void main(String[] args) throws JSQLParserException {
// 根据sql创建select
Select stmt = (Select) CCJSqlParserUtil.parse("SELECT col1 AS a, col2 AS b, col3 AS c FROM table T WHERE col1 = 10 AND col2 = 20 AND col3 = 30");
Map<String, Expression> map = new HashMap<>();
Map<String, String> mapTable = new HashMap<>();
((PlainSelect) stmt.getSelectBody()).getFromItem().accept(new FromItemVisitorAdapter() {
@Override
public void visit(Table table) {
// 获取别名 => 表名
mapTable.put(table.getAlias().getName(), table.getName());
}
});
((PlainSelect) stmt.getSelectBody()).getWhere().accept(new ExpressionVisitorAdapter() {
@Override
public void visit(AndExpression expr) {
// 获取where表达式
System.out.println(expr);
}
});
for (SelectItem selectItem : ((PlainSelect)stmt.getSelectBody()).getSelectItems()) {
selectItem.accept(new SelectItemVisitorAdapter() {
@Override
public void visit(SelectExpressionItem item) {
// 获取字段别名 => 字段名
map.put(item.getAlias().getName(), item.getExpression());
}
});
}
System.out.println("map " + map);
System.out.println("mapTables" + mapTable);
}
}
创建Select的方式
创建Select(非SQL String 创建)
@Test
void testUnionAll3() {
try {
Select t1 = SelectUtils.buildSelectFromTable(new Table("t1"));
Select select = SelectUtils.buildSelectFromTableAndExpressions(new Table("t2"), new Column("id"), new Column("username"));
Select select1 = SelectUtils.buildSelectFromTableAndExpressions(new Table("t3"), "1+1", "2+2");
System.out.println(t1.toString());
System.out.println(select.toString());
System.out.println(select1.toString());
} catch (JSQLParserException e) {
throw new RuntimeException(e);
}
Insert 插入字段和值
@Test
void testUnionAll4() {
try {
Insert parse = (Insert) CCJSqlParserUtil.parse("insert into testTable (c1,c2) values(1,3)");
System.out.println(parse.toString());
parse.addColumns(new Column("c3"));
parse.getItemsList().accept(new ItemsListVisitor() {
@Override
public void visit(SubSelect subSelect) {
throw new UnsupportedOperationException("Not supported yet.");
}
@Override
public void visit(ExpressionList expressionList) {
expressionList.getExpressions().add(new LongValue(4));
}
@Override
public void visit(NamedExpressionList namedExpressionList) {
throw new UnsupportedOperationException("Not supported yet.");
}
@Override
public void visit(MultiExpressionList multiExprList) {
}
});
System.out.println(parse.toString());
parse.getColumns().add(new Column("c4"));
((ExpressionList)parse.getItemsList()).getExpressions().add(new LongValue(5));
System.out.println(parse.toString());
} catch (JSQLParserException e) {
throw new RuntimeException(e);
}
列替换
public class ReplaceColumnValues {
static class ReplaceColumnAndLongValues extends ExpressionDeParser {
@Override
public void visit(LongValue longValue) {
this.getBuffer().append("?");
}
@Override
public void visit(StringValue stringValue) {
this.getBuffer().append("?");
}
public static String cleanStatement(String sql) throws JSQLParserException {
StringBuilder buffer = new StringBuilder();
ExpressionDeParser expr = new ReplaceColumnAndLongValues();
SelectDeParser selectDeparser = new SelectDeParser(expr, buffer);
expr.setSelectVisitor(selectDeparser);
expr.setBuffer(buffer);
StatementDeParser stmtDeparser = new StatementDeParser(expr, selectDeparser, buffer);
Statement stmt = CCJSqlParserUtil.parse(sql);
stmt.accept(stmtDeparser);
return stmtDeparser.getBuffer().toString();
}
public static void main(String[] args) throws JSQLParserException {
System.out.println(cleanStatement("SELECT 'abc', 5 FROM mytable WHERE col='test'"));
System.out.println(cleanStatement("UPDATE table1 A SET A.columna = 'XXX' WHERE A.cod_table = 'YYY'"));
System.out.println(cleanStatement("INSERT INTO example (num, name, address, tel) VALUES (1, 'name', 'test ', '1234-1234')"));
System.out.println(cleanStatement("DELETE FROM table1 where col=5 and col2=4"));
}
}
}
where条件中字段替换
替换条件字段col_1到col1
@Test
void testUnionAll5() {
try {
Select stmt = (Select) CCJSqlParserUtil.parse("SELECT col1 AS a, col2 AS b, col3 AS c FROM table WHERE col_1 = 10 AND col_2 = 20 AND col_3 = 30");
System.out.println("before " + stmt.toString());
((PlainSelect)stmt.getSelectBody()).getWhere().accept(new ExpressionVisitorAdapter() {
@Override
public void visit(Column column) {
column.setColumnName(column.getColumnName().replace("_", ""));
}
});
System.out.println("after " + stmt.toString());
} catch (JSQLParserException e) {
throw new RuntimeException(e);
}
}
解析SQL例子
Statement stmt = CCJSqlParserUtil.parse("SELECT * FROM tab1");
Statements stmt = CCJSqlParserUtil.parseStatements("SELECT * FROM tab1; SELECT * FROM tab2");
Expression expr = CCJSqlParserUtil.parseExpression("a*(5+mycolumn)");
可以直接将String SQL片段解析成Expression再将expr插入到SQL语句中。
获取所有tableNames
Statement statement = CCJSqlParserUtil.parse("SELECT * FROM MY_TABLE1");
Select selectStatement = (Select) statement;
TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
List<String> tableList = tablesNamesFinder.getTableList(selectStatement);
自动生成别名
Select select = (Select) CCJSqlParserUtil.parse("select a,b,c from test");
final AddAliasesVisitor instance = new AddAliasesVisitor();
select.getSelectBody().accept(instance);
结果:
SELECT a AS A1, b AS A2, c AS A3 FROM test
SQL函数
@Test
void testUnionAll6() {
/**
* SQL 函数
* SELECT function(列) FROM 表
*/
Table t1 = new Table("tab1").withAlias(new Alias("t1").withUseAs(true)); // 表1
PlainSelect plainSelect = new PlainSelect();
plainSelect.setFromItem(t1); // 设置FROM t1= > SELECT FROM tab1 AS t1
List<SelectItem> selectItemList = new ArrayList<>(); // 查询元素集合
SelectExpressionItem selectExpressionItem001 = new SelectExpressionItem(); // 元素1表达式
selectExpressionItem001.setExpression(new Column(t1,"col001"));
SelectExpressionItem selectExpressionItem002 = new SelectExpressionItem(); // 元素2表达式
selectExpressionItem002.setExpression(new Column(t1,"col002"));
selectItemList.add(0, selectExpressionItem001); // 添加入队
selectItemList.add(1, selectExpressionItem002); // 添加入队
// COUNT
SelectExpressionItem selectExpressionItemCount = new SelectExpressionItem(); // 创建函数元素表达式
selectExpressionItemCount.setAlias(new Alias("count")); // 设置别名
Function function = new Function(); // 创建函数对象 Function extends ASTNodeAccessImpl implements Expression
function.setName("COUNT"); // 设置函数名
ExpressionList expressionListCount = new ExpressionList(); // 创建参数表达式
expressionListCount.setExpressions(Collections.singletonList(new Column(t1, "id")));
function.setParameters(expressionListCount); // 设置参数
selectExpressionItemCount.setExpression(function);
selectItemList.add(2,selectExpressionItemCount);
plainSelect.setSelectItems(selectItemList); // 添加查询元素集合入select对象
System.err.println(plainSelect); // SELECT t1.col001, t1.col002, COUNT(t1.id) AS count FROM tab1 AS t1
}
单表where条件拼装
@Test
void testUnionAll7() {
/**
* 单表SQL查询
*
* @throws JSQLParserException
*/
// 单表全量
Table table = new Table("test");
Select select = SelectUtils.buildSelectFromTable(table);
System.err.println(select); // SELECT * FROM test
// 指定列查询
Select buildSelectFromTableAndExpressions = SelectUtils.buildSelectFromTableAndExpressions(new Table("test"), new Column("col1"), new Column("col2"));
System.err.println(buildSelectFromTableAndExpressions); // SELECT col1, col2 FROM test
// WHERE =
EqualsTo equalsTo = new EqualsTo(); // 等于表达式
equalsTo.setLeftExpression(new Column(table, "user_id")); // 设置表达式左边值
equalsTo.setRightExpression(new StringValue("123456"));// 设置表达式右边值
PlainSelect plainSelect = (PlainSelect) select.getSelectBody(); // 转换为更细化的Select对象
plainSelect.setWhere(equalsTo);
System.err.println(plainSelect);// SELECT * FROM test WHERE test.user_id = '123456'
// WHERE != <>
NotEqualsTo notEqualsTo = new NotEqualsTo();
notEqualsTo.setLeftExpression(new Column(table, "user_id")); // 设置表达式左边值
notEqualsTo.setRightExpression(new StringValue("123456"));// 设置表达式右边值
PlainSelect plainSelectNot = (PlainSelect) select.getSelectBody();
plainSelectNot.setWhere(notEqualsTo);
System.err.println(plainSelectNot);// SELECT * FROM test WHERE test.user_id <> '123456'
// 其他运算符, 参考上面代码添加表达式即可
GreaterThan gt = new GreaterThan(); // ">"
GreaterThanEquals geq = new GreaterThanEquals(); // ">="
MinorThan mt = new MinorThan(); // "<"
MinorThanEquals leq = new MinorThanEquals();// "<="
IsNullExpression isNull = new IsNullExpression(); // "is null"
isNull.setNot(true);// "is not null"
LikeExpression nlike = new LikeExpression();
nlike.setNot(true); // "not like"
Between bt = new Between();
bt.setNot(true);// "not between"
// WHERE LIKE
LikeExpression likeExpression = new LikeExpression(); // 创建Like表达式对象
likeExpression.setLeftExpression(new Column("username")); // 表达式左边
likeExpression.setRightExpression(new StringValue("张%")); // 右边表达式
PlainSelect plainSelectLike = (PlainSelect) select.getSelectBody();
plainSelectLike.setWhere(likeExpression);
System.err.println(plainSelectLike); // SELECT * FROM test WHERE username LIKE '张%'
// WHERE IN
Set<String> deptIds = new HashSet<>(); // 创建IN范围的元素集合
deptIds.add("0001");
deptIds.add("0002");
ItemsList itemsList = new ExpressionList(deptIds.stream().map(StringValue::new).collect(Collectors.toList())); // 把集合转变为JSQLParser需要的元素列表
InExpression inExpression = new InExpression(new Column("dept_id "), itemsList); // 创建IN表达式对象,传入列名及IN范围列表
PlainSelect plainSelectIn = (PlainSelect) select.getSelectBody();
plainSelectIn.setWhere(inExpression);
System.err.println(plainSelectIn); // SELECT * FROM test WHERE dept_id IN ('0001', '0002')
// WHERE BETWEEN AND
Between between = new Between();
between.setBetweenExpressionStart(new LongValue(18)); // 设置起点值
between.setBetweenExpressionEnd(new LongValue(30)); // 设置终点值
between.setLeftExpression(new Column("age")); // 设置左边的表达式,一般为列
PlainSelect plainSelectBetween = (PlainSelect) select.getSelectBody();
plainSelectBetween.setWhere(between);
System.err.println(plainSelectBetween); // SELECT * FROM test WHERE age BETWEEN 18 AND 30
// WHERE AND 多个条件结合,都需要成立
AndExpression andExpression = new AndExpression(); // AND 表达式
andExpression.setLeftExpression(equalsTo); // AND 左边表达式
andExpression.setRightExpression(between); // AND 右边表达式
PlainSelect plainSelectAnd = (PlainSelect) select.getSelectBody();
plainSelectAnd.setWhere(andExpression);
System.err.println(plainSelectAnd); // SELECT * FROM test WHERE test.user_id = '123456' AND age BETWEEN 18 AND 30
// WHERE OR 多个条件满足一个条件成立返回
OrExpression orExpression = new OrExpression();// OR 表达式
orExpression.setLeftExpression(equalsTo); // OR 左边表达式
orExpression.setRightExpression(between); // OR 右边表达式
PlainSelect plainSelectOr = (PlainSelect) select.getSelectBody();
plainSelectOr.setWhere(orExpression);
System.err.println(plainSelectOr); // SELECT * FROM test WHERE test.user_id = '123456' OR age BETWEEN 18 AND 30
// ORDER BY 排序
OrderByElement orderByElement = new OrderByElement(); // 创建排序对象
orderByElement.isAsc(); // 设置升序排列 从小到大
orderByElement.setExpression(new Column("col01")); // 设置排序字段
PlainSelect plainSelectOrderBy = (PlainSelect) select.getSelectBody();
plainSelectOrderBy.addOrderByElements(orderByElement);
System.err.println(plainSelectOrderBy); // SELECT * FROM test WHERE test.user_id = '123456' OR age BETWEEN 18 AND 30 ORDER BY col01
}
JOIN 拼装
/**
* 多表SQL查询
* JOIN / INNER JOIN: 如果表中有至少一个匹配,则返回行
* LEFT JOIN: 即使右表中没有匹配,也从左表返回所有的行
* RIGHT JOIN: 即使左表中没有匹配,也从右表返回所有的行
* FULL JOIN: 只要其中一个表中存在匹配,就返回行
*/
@Test
public void testSelectManyTable() {
Table t1 = new Table("tab1").withAlias(new Alias("t1").withUseAs(true)); // 表1
Table t2 = new Table("tab2").withAlias(new Alias("t2", false)); // 表2
PlainSelect plainSelect = new PlainSelect().addSelectItems(new AllColumns()).withFromItem(t1); // SELECT * FROM tab1 AS t1
// JOIN ON 如果表中有至少一个匹配,则返回行
Join join = new Join(); // 创建Join对象
join.withRightItem(t2); // 添加Join的表 JOIN t2 =>JOIN tab2 t2
EqualsTo equalsTo = new EqualsTo(); // 添加 = 条件表达式 t1.user_id = t2.user_id
equalsTo.setLeftExpression(new Column(t1, "user_id "));
equalsTo.setRightExpression(new Column(t2, "user_id "));
join.withOnExpression(equalsTo);// 添加ON
plainSelect.addJoins(join);
System.err.println(plainSelect); // SELECT * FROM tab1 AS t1 JOIN tab2 t2 ON t1.user_id = t2.user_id
// 设置join参数可实现其他类型join
// join.setLeft(true); LEFT JOIN
// join.setRight(true); RIGHT JOIN
// join.setFull(true); FULL JOIN
// join.setInner(true);
}
校验SQL
String sql = "DROP INDEX IF EXISTS idx_tab2_id;";
// validate statement if it's valid for all given databases.
Validation validation = new Validation(Arrays.asList(DatabaseType.SQLSERVER, DatabaseType.MARIADB,
DatabaseType.POSTGRESQL, DatabaseType.H2), sql);
List<ValidationError> errors = validation.validate();
// validate against pre-defined FeaturesAllowed.DML set
String sql = "CREATE TABLE tab1 (id NUMERIC(10), val VARCHAR(30))";
Validation validation = new Validation(Arrays.asList(FeaturesAllowed.DML), sql);
List<ValidationError> errors = validation.validate();
// only DML is allowed, got error for using a DDL statement
log.error (errors);
Validates metadata such as names of tables, views, columns for their existence or non-existence
java.sql.Connection connection = ...;
String sql = "ALTER TABLE mytable ADD price numeric(10,5) not null";
Validation validation = new Validation(Arrays.asList(new JdbcDatabaseMetaDataCapability(connection,
// NamesLookup: Databases handle names differently
NamesLookup.UPPERCASE)), sql);
List<ValidationError> errors = validation.validate();
// do something else with the parsed statements
Statements statements = validation.getParsedStatements();
// check for validation-errors
if (!errors.isEmpty()) {
...
}
基于springboot和mybatis的拦截器和JSQLParser实现数据隔离
在构建多租户系统或需要数据权限控制的应用时,数据隔离是一个关键问题,而解决这一问题的有效方案之一是在项目的数据库访问层实现数据过滤。 Spring Boot 项目中利用Mybatis的强大拦截器机制结合JSqlParser ——一个功能丰富的 SQL 解析器,来轻松实现数据隔离的目标。本文根据示例展示如何根据当前的运行环境来实现数据隔离。
Mybatis拦截器
Mybatis 支持在 SQL 执行的不同阶段拦截并插入自定义逻辑。
本文将通过拦截 StatementHandler 接口的 prepare方法修改SQL语句,实现数据隔离的目的。
详细步骤
1. 导入依赖
Mybatis 依赖:
<dependency>
<groupId>org.mybatis.spring.boot</groupId>
<artifactId>mybatis-spring-boot-starter</artifactId>
<version>3.0.3</version>
</dependency>
JSqlParser 依赖:
<dependency>
<groupId>com.github.jsqlparser</groupId>
<artifactId>jsqlparser</artifactId>
<version>4.6</version>
</dependency>
注意: 如果项目选择了 Mybatis Plus 作为数据持久层框架,那么就无需另外添加 Mybatis 和 JSqlParser 的依赖。Mybatis Plus 自身已经包含了这两项依赖,并且保证了它们之间的兼容性。重复添加这些依赖可能会引起版本冲突,从而干扰项目的稳定性。
2. 定义一个拦截器
拦截所有 query 语句并在条件中加入 env 条件
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.RowConstructor;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.values.ValuesStatement;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.util.List;
@Component
@Intercepts(
{
@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
}
)
public class DataIsolationInterceptor implements Interceptor {
/**
* 从配置文件中环境变量
*/
@Value("${spring.profiles.active}")
private String env;
@Override
public Object intercept(Invocation invocation) throws Throwable {
Object target = invocation.getTarget();
//确保只有拦截的目标对象是 StatementHandler 类型时才执行特定逻辑
if (target instanceof StatementHandler) {
StatementHandler statementHandler = (StatementHandler) target;
// 获取 BoundSql 对象,包含原始 SQL 语句
BoundSql boundSql = statementHandler.getBoundSql();
String originalSql = boundSql.getSql();
String newSql = setEnvToStatement(originalSql);
// 使用MetaObject对象将新的SQL语句设置到BoundSql对象中
MetaObject metaObject = SystemMetaObject.forObject(boundSql);
metaObject.setValue("sql", newSql);
}
// 执行SQL
return invocation.proceed();
}
private String setEnvToStatement(String originalSql) {
net.sf.jsqlparser.statement.Statement statement;
try {
statement = CCJSqlParserUtil.parse(originalSql);
} catch (JSQLParserException e) {
throw new RuntimeException("EnvironmentVariableInterceptor::SQL语句解析异常:"+originalSql);
}
if (statement instanceof Select) {
Select select = (Select) statement;
PlainSelect selectBody = select.getSelectBody(PlainSelect.class);
if (selectBody.getFromItem() instanceof Table) {
Expression newWhereExpression;
if (selectBody.getJoins() == null || selectBody.getJoins().isEmpty()) {
newWhereExpression = setEnvToWhereExpression(selectBody.getWhere(), null);
} else {
// 如果是多表关联查询,在关联查询中新增每个表的环境变量条件
newWhereExpression = multipleTableJoinWhereExpression(selectBody);
}
// 将新的where设置到Select中
selectBody.setWhere(newWhereExpression);
} else if (selectBody.getFromItem() instanceof SubSelect) {
// 如果是子查询,在子查询中新增环境变量条件
// 当前方法只能处理单层子查询,如果有多层级的子查询的场景需要通过递归设置环境变量
SubSelect subSelect = (SubSelect) selectBody.getFromItem();
PlainSelect subSelectBody = subSelect.getSelectBody(PlainSelect.class);
Expression newWhereExpression = setEnvToWhereExpression(subSelectBody.getWhere(), null);
subSelectBody.setWhere(newWhereExpression);
}
// 获得修改后的语句
return select.toString();
} else if (statement instanceof Insert) {
Insert insert = (Insert) statement;
setEnvToInsert(insert);
return insert.toString();
} else if (statement instanceof Update) {
Update update = (Update) statement;
Expression newWhereExpression = setEnvToWhereExpression(update.getWhere(),null);
// 将新的where设置到Update中
update.setWhere(newWhereExpression);
return update.toString();
} else if (statement instanceof Delete) {
Delete delete = (Delete) statement;
Expression newWhereExpression = setEnvToWhereExpression(delete.getWhere(),null);
// 将新的where设置到delete中
delete.setWhere(newWhereExpression);
return delete.toString();
}
return originalSql;
}
/**
* 将需要隔离的字段加入到SQL的Where语法树中
* @param whereExpression SQL的Where语法树
* @param alias 表别名
* @return 新的SQL Where语法树
*/
private Expression setEnvToWhereExpression(Expression whereExpression, String alias) {
// 添加SQL语法树的一个where分支,并添加环境变量条件
AndExpression andExpression = new AndExpression();
EqualsTo envEquals = new EqualsTo();
envEquals.setLeftExpression(new Column(StringUtils.isNotBlank(alias) ? String.format("%s.env", alias) : "env"));
envEquals.setRightExpression(new StringValue(env));
if (whereExpression == null){
return envEquals;
} else {
// 将新的where条件加入到原where条件的右分支树
andExpression.setRightExpression(envEquals);
andExpression.setLeftExpression(whereExpression);
return andExpression;
}
}
/**
* 多表关联查询时,给关联的所有表加入环境隔离条件
* @param selectBody select语法树
* @return 新的SQL Where语法树
*/
private Expression multipleTableJoinWhereExpression(PlainSelect selectBody){
Table mainTable = selectBody.getFromItem(Table.class);
String mainTableAlias = mainTable.getAlias().getName();
// 将 t1.env = ENV 的条件添加到where中
Expression newWhereExpression = setEnvToWhereExpression(selectBody.getWhere(), mainTableAlias);
List<Join> joins = selectBody.getJoins();
for (Join join : joins) {
FromItem joinRightItem = join.getRightItem();
if (joinRightItem instanceof Table) {
Table joinTable = (Table) joinRightItem;
String joinTableAlias = joinTable.getAlias().getName();
// 将每一个join的 tx.env = ENV 的条件添加到where中
newWhereExpression = setEnvToWhereExpression(newWhereExpression, joinTableAlias);
}
}
return newWhereExpression;
}
/**
* 新增数据时,插入env字段
* @param insert Insert 语法树
*/
private void setEnvToInsert(Insert insert) {
// 添加env列
List<Column> columns = insert.getColumns();
columns.add(new Column("env"));
// values中添加环境变量值
List<SelectBody> selects = insert.getSelect().getSelectBody(SetOperationList.class).getSelects();
for (SelectBody select : selects) {
if (select instanceof ValuesStatement){
ValuesStatement valuesStatement = (ValuesStatement) select;
ExpressionList expressions = (ExpressionList) valuesStatement.getExpressions();
List<Expression> values = expressions.getExpressions();
for (Expression expression : values){
if (expression instanceof RowConstructor) {
RowConstructor rowConstructor = (RowConstructor) expression;
ExpressionList exprList = rowConstructor.getExprList();
exprList.addExpressions(new StringValue(env));
}
}
}
}
}
}
3. 测试
Select
Mapper:
<select id="queryAllByOrgLevel" resultType="com.lyx.mybatis.entity.AllInfo">
SELECT a.username,a.code,o.org_code,o.org_name,o.level
FROM admin a left join organize o on a.org_id=o.id
WHERE a.dr=0 and o.level=#{level}
</select>
刚进入拦截器时,Mybatis 解析的 SQL 语句:
SELECT a.username,a.code,o.org_code,o.org_name,o.level
FROM admin a left join organize o on a.org_id=o.id
WHERE a.dr=0 and o.level=?
执行完 setEnvToStatement(originalSql) 方法后,得到的新 SQL 语句:
SELECT a.username, a.code, o.org_code, o.org_name, o.level
FROM admin a LEFT JOIN organize o ON a.org_id = o.id
WHERE a.dr = 0 AND o.level = ? AND a.env = 'test' AND o.env = 'test'
Insert
刚进入拦截器时,Mybatis 解析的 SQL 语句:
INSERT INTO admin ( id, username, code, org_id ) VALUES ( ?, ?, ?, ? )
执行完 setEnvToInsert(insert) 方法后,得到的新 SQL 语句:
INSERT INTO admin (id, username, code, org_id, env) VALUES (?, ?, ?, ?, 'test')
Update
刚进入拦截器时,Mybatis 解析的 SQL 语句:
UPDATE admin SET username=?, code=?, org_id=? WHERE id=?
执行完 setWhere(newWhereExpression) 方法后,得到的新 SQL 语句:
UPDATE admin SET username = ?, code = ?, org_id = ? WHERE id = ? AND env = 'test'
Delete
刚进入拦截器时,Mybatis 解析的 SQL 语句:
DELETE FROM admin WHERE id=?
执行完 setWhere(newWhereExpression) 方法后,得到的新 SQL 语句:
DELETE FROM admin WHERE id = ? AND env = 'test'
4. 为什么要拦截 StatementHandler 接口的 prepare 方法?
可以注意到,在这个例子中定义拦截器时 @Signature 注解中拦截的是 StatementHandler 接口的 prepare 方法,为什么拦截的是 prepare 方法而不是 query 和 update 方法?为什么拦截 query 和 update 方法修改 SQL 语句后仍然执行的是原 SQL ?
这是因为 SQL 语句是在 prepare 方法中被构建和参数化的。prepare 方法是负责准备 PreparedStatement 对象的,这个对象表示即将要执行的 SQL 语句。在 prepare 方法中可以对 SQL 语句进行修改,而这些修改将会影响最终执行的 SQL 。
而 query 和 update 方法是在 prepare 方法之后被调用的。它们主要的作用是执行已经准备好的 PreparedStatement 对象。在这个阶段,SQL 语句已经被创建并绑定了参数值,所以拦截这两个方法并不能改变已经准备好的 SQL 语句。
简单来说,如果想要修改SQL语句的内容(比如增加 WHERE 子句、改变排序规则等),那么需要在 SQL 语句被准备之前进行拦截,即在 prepare 方法的执行过程中进行。
以下是 MyBatis 执行过程中的几个关键步骤:
- 解析配置和映射文件: MyBatis 启动时,首先加载配置文件和映射文件,解析里面的 SQL 语句。
- 生成 StatementHandler 和 BoundSql : 当执行一个操作,比如查询或更新时,MyBatis 会创建一个 StatementHandler 对象,并包装了 BoundSql 对象,后者包含了即将要执行的 SQL 语句及其参数。
- 执行 prepare 方法: StatementHandler 的 prepare 方法被调用,完成 PreparedStatement 的创建和参数设置。
- 执行 query 或 update : 根据执行的是查询操作还是更新操作,MyBatis 再调用 query 或 update 方法来实际执行 SQL 。
- 通过在 prepare 方法进行拦截,我们可以在 SQL 语句被最终确定之前更改它,从而使修改生效。如果在 query 或 update 方法中进行拦截,则无法更改 SQL 语句,只能在执行前后进行其他操作,比如日志记录或者结果处理。