手写实现一个ORM框架
- 什么是ORM框架、ORM框架的作用
- 效果演示
- 框架设计
- 代码细节
- SqlBuilder
- Sql
- Executor
- StatementHandler
- ParameterHandler
- ResultSetHandler
- 逆序生成实体类
大家好,本人最近写了一个ORM框架,想在这里分享给大家,让大家来学习学习。
废话不多说,直接进入正题。
什么是ORM框架、ORM框架的作用
首先介绍一下ORM框架的相关知识。
数据库表是行列格式的,而Java是面向对象的,我们需要通过操作JDBC的结果集ResultSet,一行行遍历,再一列一列的处理结果,在new一个对象去set对应的值,这就显得非常繁琐,也与Java的面向对象编程格格不入。
ORM框架就可以解决这个问题,通过对象与关系型数据库建立一个映射关系,就可以省去操作JDBC的结果集ResultSet这一步繁琐的操作,直接把对库表的查询结果映射成对应的对象。
除此之外,操作JDBC还要我们自己调用PreparedStatement把参数一个一个的set进去,ORM框架的另一个作用就是省去设置参数的繁琐操作,根据参数类型自动调用PreparedStatement对应的set方法设置参数。
最后JDBC的操作的一般都是模板代码:通过DriverManager取得Connection,通过Connection取得PreparedStatement,然后通过PreparedStatement执行查询或更新,最后把获取到的ResultSet处理成返回结果。使用ORM框架,这些模板代码不需要我们重复的写,ORM框架帮我们封装了这些模板代码。
了解了ORM框架的作用之后,下面就开始介绍我们自己手写的ORM框架。
效果演示
库表:test.student 学生表
编写测试类:
public class StudentTest {
@Before
public void before() {
Configuration.init("com.mysql.jdbc.Driver", "jdbc:mysql://localhost:3306/test?useUnicode=true&characterEncoding=UTF-8", "root", "root");
}
@Test
public void testSimpleQuery() {
Configuration configuration = Configuration.get();
List<Student> result = SqlBuilder.createSql(configuration)
.select(Student.class)
.from(Student.class)
.where()
.eq(Student::getSex, "女")
.and()
.eq(Student::getGradeId, 1)
.build()
.query();
System.out.println(result);
}
}
执行测试类,控制台打印:
21:18:50.076 [main] INFO com.huangjunyi1993.easy.sql.sql.SqlBuilder - this sql is: select studentno, loginpwd, studentname, sex, gradeid, phone, address, borndate, email from student where sex=? and gradeid=? ;
[Student [studentNo=S1101002, loginPwd=228996246, studentName=洛飞, sex=女, gradeId=1, phone=666762663, address=天津市南开区, borndate=Wed Feb 07 00:00:00 CST 1990, email=jnqlpkdwb@nsjpt.com], Student [studentNo=S1101003, loginPwd=228996247, studentName=凌辉, sex=女, gradeId=1, phone=353149818, address=北京市海淀区成府路, borndate=Sun Apr 04 00:00:00 CST 1993, email=eepispykh@oitbl.com], Student [studentNo=S1101008, loginPwd=228996257, studentName=凌洋, sex=女, gradeId=1, phone=15812345680, address=湖南省长沙, borndate=Thu Nov 30 00:00:00 CST 1989, email=null], Student [studentNo=S1101011, loginPwd=228996267, studentName=圆荷, sex=女, gradeId=1, phone=13512344483, address=河北省石家庄, borndate=Thu Mar 16 00:00:00 CST 1989, email=idfwxlbjr@bkxko.com], Student [studentNo=S1101012, loginPwd=228996270, studentName=崔今生, sex=女, gradeId=1, phone=13512345684, address=河北省邯郸市, borndate=Fri Jan 05 00:00:00 CST 1990, email=qrakldetd@ogtso.com], Student [studentNo=S1101017, loginPwd=228996276, studentName=赵七, sex=女, gradeId=1, phone=511686053, address=北京市海淀区中关村, borndate=Thu Jun 27 00:00:00 CST 1985, email=ltshcitdp@qdpeh.com]]
一个字,酷!!!
框架设计
一共七个核心组件:Configuration、SqlBuilder、Sql、Executor、StatementHandler、ParameterHandler、ResultSetHandler。
Configuration是全局配置类,保存数据库启动类名、数据库url、用户名、密码等信息。
SqlBuilder是建造者模式的实现,流式编程的方式编写代码形式的sql,最后可以调用build()方法构造一个Sql对象,Sql对象就包含了真实的sql。
就像这样:
Sql<Student> sql = SqlBuilder.createSql(configuration)
.select(Student.class)
.from(Student.class)
.where()
.eq(Student::getSex, "女")
.and()
.eq(Student::getGradeId, 1)
.build()
SqlBuilder的build()创建一个Sql对象,Sql对象保存了要被执行的sql,以及用于执行Sql的Executor执行器。
调用Sql的query()方法或者execute()方法,sql将会被执行,里面会调用Executor执行sql。
Executor会通过JDBC的DriverManager获取数据库连接对象Connection,然后调用Connection的prepareStatement(sql)对sql进行预编译,获取到JDBC的PreparedStatement对象。然后把PreparedStatement对象交给StatementHandler处理。
StatementHandler会调用ParameterHandler进行预编译sql中的参数设置,然后调用JDBC的PreparedStatement对象的executeQuery()方法或者executeUpdate()方法执行sql,然后把JDBC的结果集对象ResultSet交给ResultSetHandler处理。
ParameterHandler保存了一个参数类型数组Class<?>[] parameterTypes,ParameterHandler根据参数类型调用PreparedStatement的setXXX方法进行参数设置。
ResultSetHandler保存了一个返回值类型Class<T> returnType,ResultSetHandler根据Class里面的字段类型调用ResultSet的getXXX方法获取返回值,通过反射的方式 field.set(obj, value) 设置到对象字段中。
代码细节
SqlBuilder
SqlBuilder有各种方法,比如select(Class<?> clazz),select(String… fieldNames)、update(Class<?> clazz)、set(String fieldName, Object parameter)、deleteFrom(Class<?> clazz)、insertInto(Class<?> clazz)、from(Class<?> clazz) 等等。满足大多数常用sql的编写。
我们看一个select(Class<?> clazz)方法:
public <T> SqlBuilder<T> select(Class<T> clazz) {
sql.append("select ");
Field[] fields = clazz.getDeclaredFields();
List<String> fieldNames = new ArrayList<>();
for (Field field : fields) {
FieldName fieldName = field.getAnnotation(FieldName.class);
if (fieldName == null) {
continue;
}
if (StringUtils.isBlank(fieldName.value())) {
throw new RuntimeException(String.format("fieldName is blank, class=%s, field=%", returnType.getName(), field.getName()));
}
fieldNames.add(fieldName.value());
}
sql.append(String.join(", ", fieldNames) + " ");
return (SqlBuilder<T>) this;
}
反射获取类中的字段对象Field,获取字段上的的@FieldName注解,取出注解中的值作为字段名,拼接sql如:“select studentno, loginpwd, studentname, sex, gradeid, phone, address, borndate, email ”。
再看一下 SqlBuilder 的 from(Class<?> clazz) 方法:
public SqlBuilder<T> from(Class<?> clazz) {
TableName annotation = clazz.getAnnotation(TableName.class);
if (annotation == null) {
throw new RuntimeException("table name is null");
}
if (StringUtils.isBlank(annotation.value())) {
throw new RuntimeException("table name is blank");
}
String tableName = annotation.value();
return from(tableName);
}
public SqlBuilder<T> from(String tableName) {
sql.append("from ").append(tableName).append(" ");
return this;
}
反射获取到类上的@TableName注解,取出注解中的值作为表名,与前面的select方法的sql进行拼接,把 “from 表名” 拼接在后面,此时的sql就是:“select studentno, loginpwd, studentname, sex, gradeid, phone, address, borndate, email from student ”。
where()方法和and()方法就不看了,非常简单。
再看一下SqlBuilder的eq(String fieldName, Object parameter):
public <F> SqlBuilder<T> eq(SFunction<F, ?> function, Object parameter) {
String fieldName = FieldNameUtil.getFieldName(function);
sql.append(fieldName + "=? ");
saveParameter(parameter);
return this;
}
private void saveParameter(Object parameter) {
if (parameters == null) {
parameters = new ArrayList<>();
}
parameters.add(parameter);
if (parameterTypes == null) {
parameterTypes = new ArrayList<>();
}
if (parameter instanceof Long) {
parameterTypes.add(Long.class);
} else if (parameter instanceof Integer) {
parameterTypes.add(Integer.class);
} else if (parameter instanceof Short) {
parameterTypes.add(Short.class);
} else if (parameter instanceof Byte) {
parameterTypes.add(Byte.class);
} else if (parameter instanceof Double) {
parameterTypes.add(Double.class);
} else if (parameter instanceof Float) {
parameterTypes.add(Float.class);
} else if (parameter instanceof Character) {
parameterTypes.add(Character.class);
} else if (parameter instanceof Boolean) {
parameterTypes.add(Boolean.class);
} else if (parameter instanceof String) {
parameterTypes.add(String.class);
} else if (parameter instanceof Date) {
parameterTypes.add(Date.class);
} else if (parameter instanceof Sql) {
parameterTypes.add(Sql.class);
} else {
throw new RuntimeException("no support type");
}
}
eq方法的第一行是通过方法引用取得字段上的@FieldName注解,然后取得注解中声明的字段名。
String fieldName = FieldNameUtil.getFieldName(function);
比如Student类中有个字段:
@FieldName("sex")
private String sex;
那么调用FieldNameUtil.getFieldName(Student::getSex),就取到了“sex”这个列名(这里表列名跟Student类的字段名相同)。
eq方法设置sql中的等值查询条件,但是为了防止sql注入,我们使用的是预编译sql,所以此时拼接到sql的是一个问号,然后参数保存到List<Object> parameters属性中,参数类型保存到 List<Class<?>> parameterTypes 属性中。
sql.append(fieldName + "=? ");
saveParameter(parameter);
加上eq条件之后,拼出的sql就是:“select studentno, loginpwd, studentname, sex, gradeid, phone, address, borndate, email from student where sex=? and gradeid=? ;”。
public Sql<T> build() {
this.sql.append(";");
LOGGER.info("this sql is: {}", this.sql);
Sql<T> sql = new Sql<>(
this.configuration,
this.sql.toString().trim(),
this.parameterTypes != null ? this.parameterTypes.toArray(new Class[]{}) : null,
this.parameters != null ? this.parameters.toArray(new Object[]{}) : null,
this.returnType);
return sql;
}
最后build()方法就是创建了一个Sql对象返回,创建Sql对象时把前面拼接的sql、以及组装好的参数信息parameters、parameterTypes 传递给了Sql对象的构造方法。
Sql
构造方法:
public Sql(Configuration configuration, String sql, Class<?>[] parameterTypes, Object[] parameters, Class<T> returnType) {
super();
this.configuration = configuration;
this.sql = sql;
this.parameterTypes = parameterTypes;
this.parameters = parameters;
this.returnType = returnType;
this.init();
}
private void init() {
executor = new Executor<>(configuration, sql, parameterTypes, parameters, returnType);
}
Sql对象的构造方法创建了一个Executor对象,并把sql、parameterTypes,、parameters等参数传递给Executor。
public int execute() {
return executor.executeUpdate();
}
@SuppressWarnings({ "unchecked", "hiding" })
public List<T> query() {
return executor.executeQuery();
}
然后Sql中的execute方法和query方法都是直接调用Executor对象。
Executor
查询类型的sql(select)会调用Executor的executeQuery()方法,增删改类型的sql(insert、update、delete)会调用Executor的executeUpdate()方法。我们看看executeQuery()方法:
public List<T> executeQuery() {
CacheKey cacheKey = CacheKey.get(sql, parameterTypes, parameters);
if (cache.containsKey(cacheKey)) {
return (List<T>) cache.get(cacheKey);
}
Connection conn = null;
PreparedStatement prepareStatement = null;
try {
conn = DriverManager.getConnection(configuration.getConnectionUrl(), configuration.getUserName(), configuration.getPassword());
prepareStatement = conn.prepareStatement(sql);
StatementHandler<T> statementHandler = new StatementHandler<>(parameterTypes, parameters, returnType, prepareStatement);
List<T> reuslt = statementHandler.handleQuery();
cache.put(cacheKey, reuslt);
return reuslt;
} catch (Exception e) {
......
} finally {
......
}
}
首先构造一个CacheKey对象,从缓存查找看看有没有之前执行过的结果,有就取缓存的结果返回,不再往下执行。
CacheKey cacheKey = CacheKey.get(sql, parameterTypes, parameters);
if (cache.containsKey(cacheKey)) {
return (List<T>) cache.get(cacheKey);
}
cache中的缓存会在当前Executor执行executeUpdate()方法时被清空。
如果没有就要往下执行,查询数据库。
conn = DriverManager.getConnection(configuration.getConnectionUrl(), configuration.getUserName(), configuration.getPassword());
prepareStatement = conn.prepareStatement(sql);
StatementHandler<T> statementHandler = new StatementHandler<>(parameterTypes, parameters, returnType, prepareStatement);
List<T> reuslt = statementHandler.handleQuery();
DriverManager.getConnection(…) 是JDBC的方法,获取数据库连接对象Connection。conn.prepareStatement(sql)也是JDBC的方法,对sql进行预编译,返回一个PreparedStatement对象。然后创建StatementHandler,StatementHandler保存了参数类型parameterTypes、参数parameters、返回值类型returnType、prepareStatement等属性。然后调用statementHandler的handleQuery执行查询。
StatementHandler
我们看看StatementHandler的handleQuery()方法:
public List<T> handleQuery() throws SQLException, InstantiationException, IllegalAccessException {
parameterHandler.handle(preparedStatement);
ResultSet resultSet = preparedStatement.executeQuery();
List<T> result = resultSetHandler.handle(resultSet);
if (resultSet != null) {
resultSet.close();
}
return result;
}
首先是通过ParameterHandler处理参数的设置:
parameterHandler.handle(preparedStatement);
然后调用JDBC的方法preparedStatement.executeQuery()执行查询,获取返回的结果集ResultSet:
ResultSet resultSet = preparedStatement.executeQuery();
然后调用ResultSetHandler处理结果集到返回对象的映射:
List<T> result = resultSetHandler.handle(resultSet);
ParameterHandler
ParameterHandler的handle处理参数的设置:
public void handle(PreparedStatement preparedStatement) throws SQLException {
if (parameterTypes == null) {
return;
}
for (int i = 0; i < parameterTypes.length; i++) {
if (Long.class.isAssignableFrom(parameterTypes[i]) || long.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setLong(i + 1, (long) parameters[i]);
} else if (Integer.class.isAssignableFrom(parameterTypes[i]) || int.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setInt(i + 1, (int) parameters[i]);
} else if (Short.class.isAssignableFrom(parameterTypes[i]) || short.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setShort(i + 1, (short) parameters[i]);
} else if (Byte.class.isAssignableFrom(parameterTypes[i]) || byte.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setByte(i + 1, (byte) parameters[i]);
} else if (Double.class.isAssignableFrom(parameterTypes[i]) || double.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setDouble(i + 1, (double) parameters[i]);
} else if (Float.class.isAssignableFrom(parameterTypes[i]) || float.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setFloat(i + 1, (float) parameters[i]);
} else if (Character.class.isAssignableFrom(parameterTypes[i]) || char.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setString(i + 1, String.valueOf((char) parameters[i]));
} else if (Boolean.class.isAssignableFrom(parameterTypes[i]) || boolean.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setBoolean(i + 1, (boolean) parameters[i]);
} else if (String.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setString(i + 1, (String) parameters[i]);
} else if (Timestamp.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setTimestamp(i + 1, (Timestamp) parameters[i]);
} else if (Date.class.isAssignableFrom(parameterTypes[i])) {
preparedStatement.setDate(i + 1, new java.sql.Date(((Date) parameters[i]).getTime()));
} else {
throw new RuntimeException("no support type");
}
}
}
就是根据parameterTypes中保存的参数类型,调用JDBC的方法,把parameters中的参数设置到sql中。比如Long类型,那么调用preparedStatement.setLong(index, parameter),如果是String类,调用
preparedStatement.setString(index, parameter)。
ResultSetHandler
ResultSetHandler的handle把结果集ResultSet转换成指定类型的对象:
public List<T> handle(ResultSet resultSet) throws SQLException, InstantiationException, IllegalAccessException {
if (returnType == null || Map.class.isAssignableFrom(returnType)) {
return handleForMap(resultSet);
}
if (isBaseReturnType(returnType)) {
return handleForBaseType(resultSet, returnType);
}
Field[] fields = returnType.getDeclaredFields();
List<T> result = new ArrayList<>();
while(resultSet.next()) {
T t = returnType.newInstance();
for (Field field : fields) {
FieldName annotation = field.getAnnotation(FieldName.class);
if (annotation == null) {
continue;
}
if (StringUtils.isBlank(annotation.value())) {
throw new RuntimeException(String.format("fieldName is blank, class=%s, field=%", returnType.getName(), field.getName()));
}
String name = annotation.value();
Class<?> type = field.getType();
field.setAccessible(true);
Object value = getResultSetValue(resultSet, type, name);
if (value != null) {
field.set(t, value);
}
}
result.add(t);
}
return result;
}
反射获取指定类型里面的字段
Field[] fields = returnType.getDeclaredFields();
然后遍历ResultSet,每一条查询记录对应一个对象:
List<T> result = new ArrayList<>();
while(resultSet.next()) {
......
}
每一条查询记录,创建一个指定类型的对象,给对象中的字段赋值,保存到返回的List中:
T t = returnType.newInstance();
for (Field field : fields) {
......
}
result.add(t);
对于每个对象中的字段的赋值,就是反射取得字段上的@FieldName注解,取得注解里的字段名name,然后反射取得字段类型type,调用getResultSetValue(resultSet, type, name)从结果集ResultSet中获取该字段对象的列值value,然后反射field.set(t, value)设置字段值。
FieldName annotation = field.getAnnotation(FieldName.class);
// 一些校验......
String name = annotation.value();
Class<?> type = field.getType();
field.setAccessible(true);
Object value = getResultSetValue(resultSet, type, name);
if (value != null) {
field.set(t, value);
}
getResultSetValue方法:
private Object getResultSetValue(ResultSet resultSet, Class<?> type, String name) {
Object value = null;
try {
if (Long.class.isAssignableFrom(type) || long.class.isAssignableFrom(type)) {
value = resultSet.getLong(name);
} else if (Integer.class.isAssignableFrom(type) || int.class.isAssignableFrom(type)) {
value = resultSet.getInt(name);
} else if (Short.class.isAssignableFrom(type) || short.class.isAssignableFrom(type)) {
value = resultSet.getShort(name);
} else if (Byte.class.isAssignableFrom(type) || byte.class.isAssignableFrom(type)) {
value = resultSet.getByte(name);
} else if (Double.class.isAssignableFrom(type) || double.class.isAssignableFrom(type)) {
value = resultSet.getDouble(name);
} else if (Float.class.isAssignableFrom(type) || float.class.isAssignableFrom(type)) {
value = resultSet.getFloat(name);
} else if (Character.class.isAssignableFrom(type) || char.class.isAssignableFrom(type)) {
value = resultSet.getString(name).charAt(0);
} else if (Boolean.class.isAssignableFrom(type) || boolean.class.isAssignableFrom(type)) {
value = resultSet.getBoolean(name);
} else if (String.class.isAssignableFrom(type)) {
value = resultSet.getString(name);
} else if (Timestamp.class.isAssignableFrom(type)) {
value = resultSet.getTimestamp(name);
} else if (Date.class.isAssignableFrom(type)) {
value = new Date(resultSet.getDate(name).getTime());
} else {
throw new RuntimeException("no support type");
}
} catch (SQLException e) {
LOGGER.warn("getResultSetValue: {}", e.getMessage());
}
return value;
}
参数type是字段的类型,name是字段上@FieldName注解声明的列名。根据指定类型,调用对应的JDBC方法,比如long类型则调用resultSet.getLong(name),String类型则调用resultSet.getString(name)。
逆序生成实体类
我们定义了两个注解:一个是@TableName,添加到类上,用于声明该类对应的表名;一个是@FieldName,添加到字段上,用于声明该字段对应表中的哪一列。
就像这样:
@TableName("t_sensitive_word")
public class SensitiveWord{
@FieldName("f_id")
private String id;
@FieldName("f_content")
private String content;
@FieldName("f_member_id")
private String memberId;
@FieldName("f_create_time")
private long createTime;
@FieldName("f_modify_date")
private Timestamp modifyDate;
@FieldName("f_company_id")
private String companyId;
// 下面各种set、get......
}
每个表都要手动写一个这个实体类,太繁琐了。于是还写了个逆向生成实体类的框架EntityGenerator,根据数据库表逆向生成所有的实体类。
通过查询 INFORMATION_SCHEMA.TABLES 获取指定数据库上的所有表名,然后遍历所有表名,通过PreparedStatement的getMetaData方法获取结果集中的元数据信息,里面就包含表字段名和表字段类型等信息,就可以根据这些信息生成一个实体类。
代码就不贴上来了,有兴趣的可以从代码仓把代码拉取到本地研究。
代码仓地址:https://gitee.com/huang_junyi/easy-sql
全文完。