Mybatis拦截器物理分页完整实例

在网上找了一些Mybatis物理分页方法,基本上大部分都是使用Mybatis的拦截器进行分页。我们思路是对要执行的sql语句进行拦截,其次对Sql进行修改,组装成符合各种数据库分页的Sql语句,并同时创建查询总条数的Sql,一起执行。这样在返回的结果集中可以自动的带上总条数。那我根据网上分享的分页案例进行了实际的应用,并且根据个人使用做了修改,下面看看具体的细节。

环境:
spring 4.3.6
mybatis 3.4.2

拦截器

利用拦截器实现Mybatis分页的原理:
要利用JDBC对数据库进行操作必须有一个对应的Statement对象,Mybatis在执行Sql语句前会产生一个包含Sql语句的Statement对象,而且对应的Sql语句是在Statement之前产生的,那可以在它生成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用StatementHandler对象的prepare方法,即调用invocation.proceed()。
对于分页而言,在拦截器里,通过获取到了原始的Sql语句后,把它改为对应的统计语句再利用Mybatis封装好的参数和设置参数的功能把Sql语句中的参数进行替换,之后再执行进行总记录数的统计。

对于StatementHandler其实只有两个实现类:
1.RoutingStatementHandler,
2.BaseStatementHandler 抽象类

BaseStatementHandler有三个子类:
1.SimpleStatementHandler 用于处理Statement的
2.PreparedStatementHandler 处理PreparedStatement的
3.CallableStatementHandler 处理CallableStatement的

Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,
而在RoutingStatementHandler里面拥有一个StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler、PreparedStatementHandler或CallableStatementHandler。
在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。

  1. 首先定义我们的请求的PageReq类,以及响应结果集的PageResp类:
    package com.wte.core;
    
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Map;
    
    public class PageReq {
        int skip = -1;
        int page = 1; //页码
        int pageSize = 10;  //页大小
        String sort;  //排序字段
        String order;  //排序
        int total= 0;  //总记录数
        Map<String, Object> filter;
    
        public Map<String, Object> getFilter() {
            if (filter != null)
                return filter;
            return Collections.emptyMap();
        }
    
        public void setFilterMap(Map<String, Object> filter) {
            this.filter = filter;
        }
        
        /**
         * 添加查询过滤条件
         * @param key 条件参数名
         * @param value 用户输入值
         */
        public void addFilter(String key, Object value) {
            if (filter == null) {
                filter = new HashMap<String, Object>();
            }
            if (key != null && value != null) {
                filter.put(key, value);
            }
        }
    	public int getPage() {
    		return page;
    	}
    	public void setPage(int page) {
    		this.page = page;
    	}
    	public int getPageSize() {
    		return pageSize;
    	}
    	public void setPageSize(int pageSize) {
    		this.pageSize = pageSize;
    	}
    	public String getSort() {
    		return sort;
    	}
    	public void setSort(String sort) {
    		this.sort = sort;
    	}
    	public String getOrder() {
    		return order;
    	}
    	public void setOrder(String order) {
    		this.order = order;
    	}
    	public int getTotal() {
    		return total;
    	}
    	public void setTotal(int total) {
    		this.total = total;
    	}
    	public void setFilter(Map<String, Object> filter) {
    		this.filter = filter;
    	}
    	public int getSkip() {
            if(skip==-1){
                return pageSize * (page - 1);
            }else{
                return skip;
            }
        }
        public void setSkip(int skip) {
            this.skip = skip;
        }
    }
    
package com.wte.core;

import java.util.ArrayList;
import java.util.List;

public class PageResp<T> {
	List<T> rows;  //结果集
    int total; //总行数
    public PageResp(List<T> records, int total) {
        super();
        this.rows = records;
        this.total = total;
    }
    public List<T> getRows() {
        return rows;
    }
    public void setRows(List<T> rows) {
        this.rows = rows;
    }
    public int getTotal() {
        return total;
    }
    public void setTotal(int total) {
        this.total = total;
    }
	public static <T> PageResp<T> fromSingleResult(T element) {
        List<T> records = new ArrayList<T>(1);
        records.add(element);
        PageResp<T> result = new PageResp<T>(records, 1);
        return result;
    }
}

2. 拦截器代码

package com.wte.core.interceptor;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import com.wte.core.PageReq;
import com.wte.utils.ReflectHelper;
  
/**  
 * 分页拦截器,用于拦截需要进行分页查询的操作,然后对其进行分页处理。    
 */    
@Intercepts({@Signature(type=StatementHandler.class,method="prepare",args={Connection.class,Integer.class})})  
public class PageInterceptor implements Interceptor {  
    private String dialect; //数据库方言    
    private String pageSqlId; //mapper.xml中需要拦截的ID(正则匹配)
    @SuppressWarnings("unchecked")
    public Object intercept(Invocation invocation) throws Throwable {
        //我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler的时候是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。  
        if(invocation.getTarget() instanceof RoutingStatementHandler){    
            RoutingStatementHandler statementHandler = (RoutingStatementHandler)invocation.getTarget();  
            StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");  
            BoundSql boundSql = delegate.getBoundSql();
            Object param = boundSql.getParameterObject();
            if (param instanceof Map) {
            	PageReq page =(PageReq)((Map<String,Object>)param).get("page");
            	if(page!=null){
                    //通过反射获取delegate父类BaseStatementHandler的mappedStatement属性    
                    MappedStatement mappedStatement = (MappedStatement)ReflectHelper.getFieldValue(delegate, "mappedStatement");  
                    //拦截到的prepare方法参数是一个Connection对象    
                    Connection connection = (Connection)invocation.getArgs()[0];  
                    //获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句    
                    String sql = boundSql.getSql();  
                    
                    //给当前的page参数对象设置总记录数
                    this.setTotalRecord(page,param,mappedStatement, connection);
                    
                    //获取分页Sql语句
                    String pageSql = this.getPageSql(page, sql);  
                    //利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
                    ReflectHelper.setFieldValue(boundSql, "sql", pageSql);
            	}
            }
        }
        return invocation.proceed();  
    }  

     /**  
     * 给当前的参数对象page设置总记录数  
     * @param page Mapper映射语句对应的参数对象  
     * @param mappedStatement Mapper映射语句  
     * @param connection 当前的数据库连接  
     */    
    private void setTotalRecord(PageReq page,Object filter,
           MappedStatement mappedStatement, Connection connection) {    
       //获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。    
       //delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。    
       BoundSql boundSql = mappedStatement.getBoundSql(filter);
       //获取到我们自己写在Mapper映射语句中对应的Sql语句    
       String sql = boundSql.getSql();  
       //通过查询Sql语句获取到对应的计算总记录数的sql语句    
       String countSql = this.getCountSql(sql);  
       //通过BoundSql获取对应的参数映射    
       List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();  
       //利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。    
       BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, filter);  
       //通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象    
       ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, filter, countBoundSql);  
       //通过connection建立一个countSql对应的PreparedStatement对象。    
       PreparedStatement pstmt = null;  
       ResultSet rs = null;  
       try {    
           pstmt = connection.prepareStatement(countSql);  
           //通过parameterHandler给PreparedStatement对象设置参数    
           parameterHandler.setParameters(pstmt);  
           //之后就是执行获取总记录数的Sql语句和获取结果了。    
           rs = pstmt.executeQuery();  
           if (rs.next()) {    
              int totalRecord = rs.getInt(1);  
              //给当前的参数page对象设置总记录数    
              page.setTotal(totalRecord);  
           }    
       } catch (SQLException e) {    
           e.printStackTrace(); 
       } finally {    
           try {    
              if (rs != null)    
                  rs.close();  
               if (pstmt != null)    
                  pstmt.close();  
           } catch (SQLException e) {    
              e.printStackTrace();  
           }
       }
    }

    /**  
     * 根据原Sql语句获取对应的查询总记录数的Sql语句  
     * @param sql  
     * @return  
     */    
    private String getCountSql(String sql) {    
       int index = sql.indexOf("from");  
       return "select count(1) " + sql.substring(index);  
    }    

    /**  
     * 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle  
     * 其它的数据库都 没有进行分页  
     *  
     * @param page 分页对象  
     * @param sql 原sql语句  
     * @return  
     */    
    private String getPageSql(PageReq page, String sql) {    
       StringBuffer sqlBuffer = new StringBuffer(sql);  
       if ("mysql".equalsIgnoreCase(dialect)) {    
    	   return getMysqlPageSql(page, sqlBuffer);  
       } else if ("oracle".equalsIgnoreCase(dialect)) {    
           return getOraclePageSql(page, sqlBuffer);  
       }    
       return sqlBuffer.toString();  
    }    

    /**  
    * 获取Mysql数据库的分页查询语句  
    * @param page 分页对象  
    * @param sqlBuffer 包含原sql语句的StringBuffer对象  
    * @return Mysql数据库分页语句  
    */    
   private String getMysqlPageSql(PageReq page, StringBuffer sqlBuffer) {    
      //计算第一条记录的位置,Mysql中记录的位置是从0开始的。    
      //System.out.println("page:"+page.getPage()+"-------"+page.getRows());
      int offset = (page.getPage() - 1) * page.getPageSize();  
      sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize());  
      return sqlBuffer.toString();  
   }    
      
   /**  
    * 获取Oracle数据库的分页查询语句  
    * @param page 分页对象  
    * @param sqlBuffer 包含原sql语句的StringBuffer对象  
    * @return Oracle数据库的分页查询语句  
    */    
   private String getOraclePageSql(PageReq page, StringBuffer sqlBuffer) {    
      //计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的    
      int offset = (page.getPage() - 1) * page.getPageSize() + 1;  
      sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());  
      sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);  
      //上面的Sql语句拼接之后大概是这个样子:    
      //select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16    
      return sqlBuffer.toString();  
   }    
     
    /**  
     * 拦截器对应的封装原始对象的方法  
     */          
    public Object plugin(Object arg0) {    
        // TODO Auto-generated method stub    
        if (arg0 instanceof StatementHandler) {    
            return Plugin.wrap(arg0, this);  
        } else {    
            return arg0;  
        }   
    }    
    
    /**  
     * 设置注册拦截器时设定的属性  
     */   
    public void setProperties(Properties p) {  
    }  
  
    public String getDialect() {  
        return dialect;
    }  
  
    public void setDialect(String dialect) {  
        this.dialect = dialect;
    }  
  
    public String getPageSqlId() {  
        return pageSqlId;
    }  
  
    public void setPageSqlId(String pageSqlId) {  
        this.pageSqlId = pageSqlId;
    }  
}  

3. 其中使用到了一个帮助类,代码如下:

package com.wte.utils;
import java.lang.reflect.Field;  
import org.apache.commons.lang3.reflect.FieldUtils; 

public class ReflectHelper {  
    public static Object getFieldValue(Object obj , String fieldName ){  
        if(obj == null){  
            return null ;  
        }  
        Field targetField = getTargetField(obj.getClass(), fieldName);  
        try {  
            return FieldUtils.readField(targetField, obj, true ) ;  
        } catch (IllegalAccessException e) {  
            e.printStackTrace();  
        }   
        return null ;  
    }  
      
    public static Field getTargetField(Class<?> targetClass, String fieldName) {  
        Field field = null;  
        try {  
            if (targetClass == null) {  
                return field;  
            }  
  
            if (Object.class.equals(targetClass)) {  
                return field;  
            }  
  
            field = FieldUtils.getDeclaredField(targetClass, fieldName, true);  
            if (field == null) {  
                field = getTargetField(targetClass.getSuperclass(), fieldName);  
            }  
        } catch (Exception e) {  
        }
        return field;  
    }  

    public static void setFieldValue(Object obj , String fieldName , Object value ){  
        if(null == obj){return;}  
        Field targetField = getTargetField(obj.getClass(), fieldName);    
        try {  
             FieldUtils.writeField(targetField, obj, value) ;  
        } catch (IllegalAccessException e) {  
            e.printStackTrace();  
        }   
    }   
}  

4. 在我们的spring-mybatis.xml中注册plugins,在sqlSessionFactory中加入plugins的property:

<!-- spring和MyBatis完美整合,不需要mybatis的配置映射文件 -->
	<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
		<property name="dataSource" ref="dataSource" />
		<!-- 自动扫描mapping.xml文件 -->
        <property name="mapperLocations" value="classpath:/mapping/*.xml"></property>
        <!-- Mybatis分页拦截器 -->
        <property name="plugins">
	        <array>
	            <bean class="com.wte.core.interceptor.PageInterceptor">
	                <property name="dialect" value="mysql" />
	                <property name="pageSqlId" value="findPaging"/>  
	            </bean>
	        </array>
	    </property>
        
        <!-- mybatis配置文件 -->
        <property name="configLocation" value="classpath:mybatis-config.xml"></property>
	</bean>

业务代码

  1. 首先我们的Controller代码:
    	/**
    	 * 获取用户分页列表
    	 * @param pageReq
    	 * @param user
    	 * @param response
    	 * @return
    	 * @throws Exception
    	 */
    	@RequestMapping("/list")
        public @ResponseBody PageResp<UserBean> pageList(PageReq pageReq, UserBean user, HttpServletResponse response) throws Exception {
            Map<String, Object> filter= new HashMap<String, Object>();
            if (!StringUtils.isEmpty(user.getUserId())) {
            	filter.put("userId", StringUtil.formatLike(user.getUserId()));
            }
            if (!StringUtils.isEmpty(user.getUserName())) {
            	filter.put("userName", StringUtil.formatLike(user.getUserName()));
            }
            pageReq.setFilter(filter);
            PageResp<UserBean> page= userService.findPaging(pageReq);
            return page;
        }
    
  2. Service继承了我们的Service泛型的抽象类,抽象类中对findPaging定义代码如下:
    	/**
         * 分页及条件查询
         * @param pageReq 分页及查询过滤参数
         * @return 指定页的查询结果
         */
    	public PageResp<E> findPaging(PageReq pageReq) {
    		int pageNo = pageReq.getPage();
    		int pageSize = pageReq.getPageSize();
    		String sortCol = pageReq.getSort();
    		String sortDir = pageReq.getOrder();
    		
    		pageReq.addFilter("pageNo", pageNo);
    		pageReq.addFilter("pageSize", pageSize);
    		pageReq.addFilter("sortCol", sortCol);
    		pageReq.addFilter("sortDir", sortDir);
    		pageReq.addFilter("page",pageReq);
    		List<E> list= dao.findPaging(pageReq.getFilter());
    		PageResp<E> result=  new PageResp<E>(list,pageReq.getTotal());
    		return result;
    	}
    

    注:这里对于sortCol,sortDir应当去除特殊字符,因为我们的Mapper使用的是${},这里我没有做。

  3. dao当然只要定义findPaging的接口方法就了。那我们的Mapper中的Sql就只需要如下:
    	<sql id="Base_Column_List">
    		UserId,UserName,Password,PasswordTime,State,EmpId,SysUser,CreateUser,CreateTime,UpdateUser,UpdateTime
    	</sql>
    	<select id="findPaging" parameterType="Map" resultMap="userResult">
            select <include refid="Base_Column_List" />
            from bp_user
           <where>
                <if test="userId!=null and userId!='' ">
                     and userId like #{userId}
                </if>
                <if test="userName!=null and userName!='' ">
                     and userName like #{userName}
                </if>
            </where>
           	<if test="sortCol!=null and sortCol!='' ">
           		order by ${sortCol}
           		<if test="sortDir!=null and sortDir!='' ">
                 	${sortDir}
            	</if>
           	</if>
        </select>
    

    注:mybatis中的#{}与${}要加以区分。#{}适合填入值,${}适合放列明,但是${}不能避免sql注入,所以需要自己来处理order by传入的参数。

这样,我们的分类就实现了,当然,我们现在的分页中只支持单列排序,那么如果要修改成多列排序,知道原理,做适当修改就可以实现了。
Mybatis拦截器实现物理分页

参考地址:

  1. 在mybatis执行SQL语句之前进行拦击处理:http://blog.csdn.net/hfmbook/article/details/41985853
  2. 分享一个完整的Mybatis分页解决方案:http://blog.csdn.net/u013306940/article/details/51168359

发表评论

电子邮件地址不会被公开。 必填项已用*标注