网上的一些Mybatis物理分页方法,基本上大部分都是使用Mybatis的拦截器进行分页。基本思路是对要执行的sql语句进行拦截,其次对Sql进行修改,组装成符合各种数据库分页的Sql语句,并同时创建查询总条数的Sql,一起执行。这样在返回的结果集中可以自动的带上总条数。根据网上分享的分页案例进行了实际的应用,并且根据个人使用做了修改,下面看看具体的细节。
环境:
spring 4.3.6
mybatis 3.4.2
拦截器
利用拦截器实现Mybatis分页的原理:
要利用JDBC对数据库进行操作必须有一个对应的Statement对象,Mybatis在执行Sql语句前会产生一个包含Sql语句的Statement对象,而且对应的Sql语句是在Statement之前产生的,那可以在它生成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用StatementHandler对象的prepare方法,即调用invocation.proceed()。
对于分页而言,在拦截器里,通过获取到了原始的Sql语句后,把它改为对应的统计语句再利用Mybatis封装好的参数和设置参数的功能把Sql语句中的参数进行替换,之后再执行进行总记录数的统计。
对于StatementHandler其实只有两个实现类:
1.RoutingStatementHandler,
2.BaseStatementHandler 抽象类
BaseStatementHandler有三个子类:
1.SimpleStatementHandler 用于处理Statement的
2.PreparedStatementHandler 处理PreparedStatement的
3.CallableStatementHandler 处理CallableStatement的
Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,
而在RoutingStatementHandler里面拥有一个StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler、PreparedStatementHandler或CallableStatementHandler。
在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。
- 首先定义我们的请求的PageReq类,以及响应结果集的PageResp类:
package com.wte.core;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
public class PageReq {
int skip = -1;
int page = 1; //页码
int pageSize = 10; //页大小
String sort; //排序字段
String order; //排序
int total= 0; //总记录数
Map<String, Object> filter;
public Map<String, Object> getFilter() {
if (filter != null)
return filter;
return Collections.emptyMap();
}
public void setFilterMap(Map<String, Object> filter) {
this.filter = filter;
}
/**
* 添加查询过滤条件
* @param key 条件参数名
* @param value 用户输入值
*/
public void addFilter(String key, Object value) {
if (filter == null) {
filter = new HashMap<String, Object>();
}
if (key != null && value != null) {
filter.put(key, value);
}
}
public int getPage() {
return page;
}
public void setPage(int page) {
this.page = page;
}
public int getPageSize() {
return pageSize;
}
public void setPageSize(int pageSize) {
this.pageSize = pageSize;
}
public String getSort() {
return sort;
}
public void setSort(String sort) {
this.sort = sort;
}
public String getOrder() {
return order;
}
public void setOrder(String order) {
this.order = order;
}
public int getTotal() {
return total;
}
public void setTotal(int total) {
this.total = total;
}
public void setFilter(Map<String, Object> filter) {
this.filter = filter;
}
public int getSkip() {
if(skip==-1){
return pageSize * (page - 1);
}else{
return skip;
}
}
public void setSkip(int skip) {
this.skip = skip;
}
}
package com.wte.core;
import java.util.ArrayList;
import java.util.List;
public class PageResp<T> {
List<T> rows; //结果集
int total; //总行数
public PageResp(List<T> records, int total) {
super();
this.rows = records;
this.total = total;
}
public List<T> getRows() {
return rows;
}
public void setRows(List<T> rows) {
this.rows = rows;
}
public int getTotal() {
return total;
}
public void setTotal(int total) {
this.total = total;
}
public static <T> PageResp<T> fromSingleResult(T element) {
List<T> records = new ArrayList<T>(1);
records.add(element);
PageResp<T> result = new PageResp<T>(records, 1);
return result;
}
}
2. 拦截器代码
package com.wte.core.interceptor;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import com.wte.core.PageReq;
import com.wte.utils.ReflectHelper;
/**
* 分页拦截器,用于拦截需要进行分页查询的操作,然后对其进行分页处理。
*/
@Intercepts({@Signature(type=StatementHandler.class,method="prepare",args={Connection.class,Integer.class})})
public class PageInterceptor implements Interceptor {
private String dialect; //数据库方言
private String pageSqlId; //mapper.xml中需要拦截的ID(正则匹配)
@SuppressWarnings("unchecked")
public Object intercept(Invocation invocation) throws Throwable {
//我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler的时候是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。
if(invocation.getTarget() instanceof RoutingStatementHandler){
RoutingStatementHandler statementHandler = (RoutingStatementHandler)invocation.getTarget();
StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");
BoundSql boundSql = delegate.getBoundSql();
Object param = boundSql.getParameterObject();
if (param instanceof Map) {
PageReq page =(PageReq)((Map<String,Object>)param).get("page");
if(page!=null){
//通过反射获取delegate父类BaseStatementHandler的mappedStatement属性
MappedStatement mappedStatement = (MappedStatement)ReflectHelper.getFieldValue(delegate, "mappedStatement");
//拦截到的prepare方法参数是一个Connection对象
Connection connection = (Connection)invocation.getArgs()[0];
//获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句
String sql = boundSql.getSql();
//给当前的page参数对象设置总记录数
this.setTotalRecord(page,param,mappedStatement, connection);
//获取分页Sql语句
String pageSql = this.getPageSql(page, sql);
//利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
ReflectHelper.setFieldValue(boundSql, "sql", pageSql);
}
}
}
return invocation.proceed();
}
/**
* 给当前的参数对象page设置总记录数
* @param page Mapper映射语句对应的参数对象
* @param mappedStatement Mapper映射语句
* @param connection 当前的数据库连接
*/
private void setTotalRecord(PageReq page,Object filter,
MappedStatement mappedStatement, Connection connection) {
//获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。
//delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。
BoundSql boundSql = mappedStatement.getBoundSql(filter);
//获取到我们自己写在Mapper映射语句中对应的Sql语句
String sql = boundSql.getSql();
//通过查询Sql语句获取到对应的计算总记录数的sql语句
String countSql = this.getCountSql(sql);
//通过BoundSql获取对应的参数映射
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
//利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, filter);
//通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, filter, countBoundSql);
//通过connection建立一个countSql对应的PreparedStatement对象。
PreparedStatement pstmt = null;
ResultSet rs = null;
try {
pstmt = connection.prepareStatement(countSql);
//通过parameterHandler给PreparedStatement对象设置参数
parameterHandler.setParameters(pstmt);
//之后就是执行获取总记录数的Sql语句和获取结果了。
rs = pstmt.executeQuery();
if (rs.next()) {
int totalRecord = rs.getInt(1);
//给当前的参数page对象设置总记录数
page.setTotal(totalRecord);
}
} catch (SQLException e) {
e.printStackTrace();
} finally {
try {
if (rs != null)
rs.close();
if (pstmt != null)
pstmt.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
}
/**
* 根据原Sql语句获取对应的查询总记录数的Sql语句
* @param sql
* @return
*/
private String getCountSql(String sql) {
int index = sql.indexOf("from");
return "select count(1) " + sql.substring(index);
}
/**
* 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle
* 其它的数据库都 没有进行分页
*
* @param page 分页对象
* @param sql 原sql语句
* @return
*/
private String getPageSql(PageReq page, String sql) {
StringBuffer sqlBuffer = new StringBuffer(sql);
if ("mysql".equalsIgnoreCase(dialect)) {
return getMysqlPageSql(page, sqlBuffer);
} else if ("oracle".equalsIgnoreCase(dialect)) {
return getOraclePageSql(page, sqlBuffer);
}
return sqlBuffer.toString();
}
/**
* 获取Mysql数据库的分页查询语句
* @param page 分页对象
* @param sqlBuffer 包含原sql语句的StringBuffer对象
* @return Mysql数据库分页语句
*/
private String getMysqlPageSql(PageReq page, StringBuffer sqlBuffer) {
//计算第一条记录的位置,Mysql中记录的位置是从0开始的。
//System.out.println("page:"+page.getPage()+"-------"+page.getRows());
int offset = (page.getPage() - 1) * page.getPageSize();
sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize());
return sqlBuffer.toString();
}
/**
* 获取Oracle数据库的分页查询语句
* @param page 分页对象
* @param sqlBuffer 包含原sql语句的StringBuffer对象
* @return Oracle数据库的分页查询语句
*/
private String getOraclePageSql(PageReq page, StringBuffer sqlBuffer) {
//计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
int offset = (page.getPage() - 1) * page.getPageSize() + 1;
sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());
sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);
//上面的Sql语句拼接之后大概是这个样子:
//select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16
return sqlBuffer.toString();
}
/**
* 拦截器对应的封装原始对象的方法
*/
public Object plugin(Object arg0) {
// TODO Auto-generated method stub
if (arg0 instanceof StatementHandler) {
return Plugin.wrap(arg0, this);
} else {
return arg0;
}
}
/**
* 设置注册拦截器时设定的属性
*/
public void setProperties(Properties p) {
}
public String getDialect() {
return dialect;
}
public void setDialect(String dialect) {
this.dialect = dialect;
}
public String getPageSqlId() {
return pageSqlId;
}
public void setPageSqlId(String pageSqlId) {
this.pageSqlId = pageSqlId;
}
}
3. 其中使用到了一个帮助类,代码如下:
package com.wte.utils;
import java.lang.reflect.Field;
import org.apache.commons.lang3.reflect.FieldUtils;
public class ReflectHelper {
public static Object getFieldValue(Object obj , String fieldName ){
if(obj == null){
return null ;
}
Field targetField = getTargetField(obj.getClass(), fieldName);
try {
return FieldUtils.readField(targetField, obj, true ) ;
} catch (IllegalAccessException e) {
e.printStackTrace();
}
return null ;
}
public static Field getTargetField(Class<?> targetClass, String fieldName) {
Field field = null;
try {
if (targetClass == null) {
return field;
}
if (Object.class.equals(targetClass)) {
return field;
}
field = FieldUtils.getDeclaredField(targetClass, fieldName, true);
if (field == null) {
field = getTargetField(targetClass.getSuperclass(), fieldName);
}
} catch (Exception e) {
}
return field;
}
public static void setFieldValue(Object obj , String fieldName , Object value ){
if(null == obj){return;}
Field targetField = getTargetField(obj.getClass(), fieldName);
try {
FieldUtils.writeField(targetField, obj, value) ;
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
4. 在我们的spring-mybatis.xml中注册plugins,在sqlSessionFactory中加入plugins的property:
<!-- spring和MyBatis完美整合,不需要mybatis的配置映射文件 -->
<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
<property name="dataSource" ref="dataSource" />
<!-- 自动扫描mapping.xml文件 -->
<property name="mapperLocations" value="classpath:/mapping/*.xml"></property>
<!-- Mybatis分页拦截器 -->
<property name="plugins">
<array>
<bean class="com.wte.core.interceptor.PageInterceptor">
<property name="dialect" value="mysql" />
<property name="pageSqlId" value="findPaging"/>
</bean>
</array>
</property>
<!-- mybatis配置文件 -->
<property name="configLocation" value="classpath:mybatis-config.xml"></property>
</bean>
业务代码
首先我们的Controller代码:
/**
* 获取用户分页列表
* @param pageReq
* @param user
* @param response
* @return
* @throws Exception
*/
@RequestMapping("/list")
public @ResponseBody PageResp<UserBean> pageList(PageReq pageReq, UserBean user, HttpServletResponse response) throws Exception {
Map<String, Object> filter= new HashMap<String, Object>();
if (!StringUtils.isEmpty(user.getUserId())) {
filter.put("userId", StringUtil.formatLike(user.getUserId()));
}
if (!StringUtils.isEmpty(user.getUserName())) {
filter.put("userName", StringUtil.formatLike(user.getUserName()));
}
pageReq.setFilter(filter);
PageResp<UserBean> page= userService.findPaging(pageReq);
return page;
}
Service继承了我们的Service泛型的抽象类,抽象类中对findPaging定义代码如下:
/**
* 分页及条件查询
* @param pageReq 分页及查询过滤参数
* @return 指定页的查询结果
*/
public PageResp<E> findPaging(PageReq pageReq) {
int pageNo = pageReq.getPage();
int pageSize = pageReq.getPageSize();
String sortCol = pageReq.getSort();
String sortDir = pageReq.getOrder();
pageReq.addFilter("pageNo", pageNo);
pageReq.addFilter("pageSize", pageSize);
pageReq.addFilter("sortCol", sortCol);
pageReq.addFilter("sortDir", sortDir);
pageReq.addFilter("page",pageReq);
List<E> list= dao.findPaging(pageReq.getFilter());
PageResp<E> result= new PageResp<E>(list,pageReq.getTotal());
return result;
}
注:这里对于sortCol,sortDir应当去除特殊字符,因为我们的Mapper使用的是${},这里我没有做。
dao当然只要定义findPaging的接口方法就了。那我们的Mapper中的Sql就只需要如下:
<sql id="Base_Column_List">
UserId,UserName,Password,PasswordTime,State,EmpId,SysUser,CreateUser,CreateTime,UpdateUser,UpdateTime
</sql>
<select id="findPaging" parameterType="Map" resultMap="userResult">
select <include refid="Base_Column_List" />
from bp_user
<where>
<if test="userId!=null and userId!='' ">
and userId like #{userId}
</if>
<if test="userName!=null and userName!='' ">
and userName like #{userName}
</if>
</where>
<if test="sortCol!=null and sortCol!='' ">
order by ${sortCol}
<if test="sortDir!=null and sortDir!='' ">
${sortDir}
</if>
</if>
</select>
注:mybatis中的#{}与${}要加以区分。#{}适合填入值,${}适合放列明,但是${}不能避免sql注入,所以需要自己来处理order by传入的参数。
这样,我们的分类就实现了,当然,我们现在的分页中只支持单列排序,那么如果要修改成多列排序,知道原理,做适当修改就可以实现了。
参考地址:
- 在mybatis执行SQL语句之前进行拦击处理:http://blog.csdn.net/hfmbook/article/details/41985853
- 分享一个完整的Mybatis分页解决方案:http://blog.csdn.net/u013306940/article/details/51168359