Mybatis拦截器物理分页完整实例

在网上找了一些Mybatis物理分页方法,基本上大部分都是使用Mybatis的拦截器进行分页。我们思路是对要执行的sql语句进行拦截,其次对Sql进行修改,组装成符合各种数据库分页的Sql语句,并同时创建查询总条数的Sql,一起执行。这样在返回的结果集中可以自动的带上总条数。那我根据网上分享的分页案例进行了实际的应用,并且根据个人使用做了修改,下面看看具体的细节。

环境:
spring 4.3.6
mybatis 3.4.2

拦截器

利用拦截器实现Mybatis分页的原理:
要利用JDBC对数据库进行操作必须有一个对应的Statement对象,Mybatis在执行Sql语句前会产生一个包含Sql语句的Statement对象,而且对应的Sql语句是在Statement之前产生的,那可以在它生成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用StatementHandler对象的prepare方法,即调用invocation.proceed()。
对于分页而言,在拦截器里,通过获取到了原始的Sql语句后,把它改为对应的统计语句再利用Mybatis封装好的参数和设置参数的功能把Sql语句中的参数进行替换,之后再执行进行总记录数的统计。

对于StatementHandler其实只有两个实现类:
1.RoutingStatementHandler,
2.BaseStatementHandler 抽象类

BaseStatementHandler有三个子类:
1.SimpleStatementHandler 用于处理Statement的
2.PreparedStatementHandler 处理PreparedStatement的
3.CallableStatementHandler 处理CallableStatement的

Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,
而在RoutingStatementHandler里面拥有一个StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler、PreparedStatementHandler或CallableStatementHandler。
在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。

  1. 首先定义我们的请求的PageReq类,以及响应结果集的PageResp类:
    package com.wte.core;
    
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Map;
    
    public class PageReq {
        int skip = -1;
        int page = 1; //页码
        int pageSize = 10;  //页大小
        String sort;  //排序字段
        String order;  //排序
        int total= 0;  //总记录数
        Map<String, Object> filter;
    
        public Map<String, Object> getFilter() {
            if (filter != null)
                return filter;
            return Collections.emptyMap();
        }
    
        public void setFilterMap(Map<String, Object> filter) {
            this.filter = filter;
        }
        
        /**
         * 添加查询过滤条件
         * @param key 条件参数名
         * @param value 用户输入值
         */
        public void addFilter(String key, Object value) {
            if (filter == null) {
                filter = new HashMap<String, Object>();
            }
            if (key != null && value != null) {
                filter.put(key, value);
            }
        }
    	public int getPage() {
    		return page;
    	}
    	public void setPage(int page) {
    		this.page = page;
    	}
    	public int getPageSize() {
    		return pageSize;
    	}
    	public void setPageSize(int pageSize) {
    		this.pageSize = pageSize;
    	}
    	public String getSort() {
    		return sort;
    	}
    	public void setSort(String sort) {
    		this.sort = sort;
    	}
    	public String getOrder() {
    		return order;
    	}
    	public void setOrder(String order) {
    		this.order = order;
    	}
    	public int getTotal() {
    		return total;
    	}
    	public void setTotal(int total) {
    		this.total = total;
    	}
    	public void setFilter(Map<String, Object> filter) {
    		this.filter = filter;
    	}
    	public int getSkip() {
            if(skip==-1){
                return pageSize * (page - 1);
            }else{
                return skip;
            }
        }
        public void setSkip(int skip) {
            this.skip = skip;
        }
    }
    
package com.wte.core;

import java.util.ArrayList;
import java.util.List;

public class PageResp<T> {
	List<T> rows;  //结果集
    int total; //总行数
    public PageResp(List<T> records, int total) {
        super();
        this.rows = records;
        this.total = total;
    }
    public List<T> getRows() {
        return rows;
    }
    public void setRows(List<T> rows) {
        this.rows = rows;
    }
    public int getTotal() {
        return total;
    }
    public void setTotal(int total) {
        this.total = total;
    }
	public static <T> PageResp<T> fromSingleResult(T element) {
        List<T> records = new ArrayList<T>(1);
        records.add(element);
        PageResp<T> result = new PageResp<T>(records, 1);
        return result;
    }
}
  1. 拦截器代码
    package com.wte.core.interceptor;
    import java.sql.Connection;
    import java.sql.PreparedStatement;
    import java.sql.ResultSet;
    import java.sql.SQLException;
    import java.util.List;
    import java.util.Map;
    import java.util.Properties;
    import org.apache.ibatis.executor.parameter.ParameterHandler;
    import org.apache.ibatis.executor.statement.RoutingStatementHandler;
    import org.apache.ibatis.executor.statement.StatementHandler;
    import org.apache.ibatis.mapping.BoundSql;
    import org.apache.ibatis.mapping.MappedStatement;
    import org.apache.ibatis.mapping.ParameterMapping;
    import org.apache.ibatis.plugin.Interceptor;
    import org.apache.ibatis.plugin.Intercepts;
    import org.apache.ibatis.plugin.Invocation;
    import org.apache.ibatis.plugin.Plugin;
    import org.apache.ibatis.plugin.Signature;
    import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
    import com.wte.core.PageReq;
    import com.wte.utils.ReflectHelper;
      
    /**  
     * 分页拦截器,用于拦截需要进行分页查询的操作,然后对其进行分页处理。    
     */    
    @Intercepts({@Signature(type=StatementHandler.class,method="prepare",args={Connection.class,Integer.class})})  
    public class PageInterceptor implements Interceptor {  
        private String dialect; //数据库方言    
        private String pageSqlId; //mapper.xml中需要拦截的ID(正则匹配)
        @SuppressWarnings("unchecked")
        public Object intercept(Invocation invocation) throws Throwable {
            //我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler的时候是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。  
            if(invocation.getTarget() instanceof RoutingStatementHandler){    
                RoutingStatementHandler statementHandler = (RoutingStatementHandler)invocation.getTarget();  
                StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");  
                BoundSql boundSql = delegate.getBoundSql();
                Object param = boundSql.getParameterObject();
                if (param instanceof Map) {
                	PageReq page =(PageReq)((Map<String,Object>)param).get("page");
                	if(page!=null){
                        //通过反射获取delegate父类BaseStatementHandler的mappedStatement属性    
                        MappedStatement mappedStatement = (MappedStatement)ReflectHelper.getFieldValue(delegate, "mappedStatement");  
                        //拦截到的prepare方法参数是一个Connection对象    
                        Connection connection = (Connection)invocation.getArgs()[0];  
                        //获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句    
                        String sql = boundSql.getSql();  
                        
                        //给当前的page参数对象设置总记录数
                        this.setTotalRecord(page,param,mappedStatement, connection);
                        
                        //获取分页Sql语句
                        String pageSql = this.getPageSql(page, sql);  
                        //利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
                        ReflectHelper.setFieldValue(boundSql, "sql", pageSql);
                	}
                }
            }
            return invocation.proceed();  
        }  
    
         /**  
         * 给当前的参数对象page设置总记录数  
         * @param page Mapper映射语句对应的参数对象  
         * @param mappedStatement Mapper映射语句  
         * @param connection 当前的数据库连接  
         */    
        private void setTotalRecord(PageReq page,Object filter,
               MappedStatement mappedStatement, Connection connection) {    
           //获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。    
           //delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。    
           BoundSql boundSql = mappedStatement.getBoundSql(filter);
           //获取到我们自己写在Mapper映射语句中对应的Sql语句    
           String sql = boundSql.getSql();  
           //通过查询Sql语句获取到对应的计算总记录数的sql语句    
           String countSql = this.getCountSql(sql);  
           //通过BoundSql获取对应的参数映射    
           List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();  
           //利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。    
           BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, filter);  
           //通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象    
           ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, filter, countBoundSql);  
           //通过connection建立一个countSql对应的PreparedStatement对象。    
           PreparedStatement pstmt = null;  
           ResultSet rs = null;  
           try {    
               pstmt = connection.prepareStatement(countSql);  
               //通过parameterHandler给PreparedStatement对象设置参数    
               parameterHandler.setParameters(pstmt);  
               //之后就是执行获取总记录数的Sql语句和获取结果了。    
               rs = pstmt.executeQuery();  
               if (rs.next()) {    
                  int totalRecord = rs.getInt(1);  
                  //给当前的参数page对象设置总记录数    
                  page.setTotal(totalRecord);  
               }    
           } catch (SQLException e) {    
               e.printStackTrace(); 
           } finally {    
               try {    
                  if (rs != null)    
                      rs.close();  
                   if (pstmt != null)    
                      pstmt.close();  
               } catch (SQLException e) {    
                  e.printStackTrace();  
               }
           }
        }
    
        /**  
         * 根据原Sql语句获取对应的查询总记录数的Sql语句  
         * @param sql  
         * @return  
         */    
        private String getCountSql(String sql) {    
           int index = sql.indexOf("from");  
           return "select count(1) " + sql.substring(index);  
        }    
    
        /**  
         * 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle  
         * 其它的数据库都 没有进行分页  
         *  
         * @param page 分页对象  
         * @param sql 原sql语句  
         * @return  
         */    
        private String getPageSql(PageReq page, String sql) {    
           StringBuffer sqlBuffer = new StringBuffer(sql);  
           if ("mysql".equalsIgnoreCase(dialect)) {    
        	   return getMysqlPageSql(page, sqlBuffer);  
           } else if ("oracle".equalsIgnoreCase(dialect)) {    
               return getOraclePageSql(page, sqlBuffer);  
           }    
           return sqlBuffer.toString();  
        }    
    
        /**  
        * 获取Mysql数据库的分页查询语句  
        * @param page 分页对象  
        * @param sqlBuffer 包含原sql语句的StringBuffer对象  
        * @return Mysql数据库分页语句  
        */    
       private String getMysqlPageSql(PageReq page, StringBuffer sqlBuffer) {    
          //计算第一条记录的位置,Mysql中记录的位置是从0开始的。    
          //System.out.println("page:"+page.getPage()+"-------"+page.getRows());
          int offset = (page.getPage() - 1) * page.getPageSize();  
          sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize());  
          return sqlBuffer.toString();  
       }    
          
       /**  
        * 获取Oracle数据库的分页查询语句  
        * @param page 分页对象  
        * @param sqlBuffer 包含原sql语句的StringBuffer对象  
        * @return Oracle数据库的分页查询语句  
        */    
       private String getOraclePageSql(PageReq page, StringBuffer sqlBuffer) {    
          //计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的    
          int offset = (page.getPage() - 1) * page.getPageSize() + 1;  
          sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());  
          sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);  
          //上面的Sql语句拼接之后大概是这个样子:    
          //select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16    
          return sqlBuffer.toString();  
       }    
         
        /**  
         * 拦截器对应的封装原始对象的方法  
         */          
        public Object plugin(Object arg0) {    
            // TODO Auto-generated method stub    
            if (arg0 instanceof StatementHandler) {    
                return Plugin.wrap(arg0, this);  
            } else {    
                return arg0;  
            }   
        }    
        
        /**  
         * 设置注册拦截器时设定的属性  
         */   
        public void setProperties(Properties p) {  
        }  
      
        public String getDialect() {  
            return dialect;
        }  
      
        public void setDialect(String dialect) {  
            this.dialect = dialect;
        }  
      
        public String getPageSqlId() {  
            return pageSqlId;
        }  
      
        public void setPageSqlId(String pageSqlId) {  
            this.pageSqlId = pageSqlId;
        }  
    }  
    

  2. 其中使用到了一个帮助类,代码如下:

    package com.wte.utils;
    import java.lang.reflect.Field;  
    import org.apache.commons.lang3.reflect.FieldUtils; 
    
    public class ReflectHelper {  
        public static Object getFieldValue(Object obj , String fieldName ){  
            if(obj == null){  
                return null ;  
            }  
            Field targetField = getTargetField(obj.getClass(), fieldName);  
            try {  
                return FieldUtils.readField(targetField, obj, true ) ;  
            } catch (IllegalAccessException e) {  
                e.printStackTrace();  
            }   
            return null ;  
        }  
          
        public static Field getTargetField(Class<?> targetClass, String fieldName) {  
            Field field = null;  
            try {  
                if (targetClass == null) {  
                    return field;  
                }  
      
                if (Object.class.equals(targetClass)) {  
                    return field;  
                }  
      
                field = FieldUtils.getDeclaredField(targetClass, fieldName, true);  
                if (field == null) {  
                    field = getTargetField(targetClass.getSuperclass(), fieldName);  
                }  
            } catch (Exception e) {  
            }
            return field;  
        }  
    
        public static void setFieldValue(Object obj , String fieldName , Object value ){  
            if(null == obj){return;}  
            Field targetField = getTargetField(obj.getClass(), fieldName);    
            try {  
                 FieldUtils.writeField(targetField, obj, value) ;  
            } catch (IllegalAccessException e) {  
                e.printStackTrace();  
            }   
        }   
    }  
    

  3. 在我们的spring-mybatis.xml中注册plugins,在sqlSessionFactory中加入plugins的property:

    <!-- spring和MyBatis完美整合,不需要mybatis的配置映射文件 -->
    	<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
    		<property name="dataSource" ref="dataSource" />
    		<!-- 自动扫描mapping.xml文件 -->
            <property name="mapperLocations" value="classpath:/mapping/*.xml"></property>
            <!-- Mybatis分页拦截器 -->
            <property name="plugins">
    	        <array>
    	            <bean class="com.wte.core.interceptor.PageInterceptor">
    	                <property name="dialect" value="mysql" />
    	                <property name="pageSqlId" value="findPaging"/>  
    	            </bean>
    	        </array>
    	    </property>
            
            <!-- mybatis配置文件 -->
            <property name="configLocation" value="classpath:mybatis-config.xml"></property>
    	</bean>
    

业务代码

  1. 首先我们的Controller代码:

    	/**
    	 * 获取用户分页列表
    	 * @param pageReq
    	 * @param user
    	 * @param response
    	 * @return
    	 * @throws Exception
    	 */
    	@RequestMapping("/list")
        public @ResponseBody PageResp<UserBean> pageList(PageReq pageReq, UserBean user, HttpServletResponse response) throws Exception {
            Map<String, Object> filter= new HashMap<String, Object>();
            if (!StringUtils.isEmpty(user.getUserId())) {
            	filter.put("userId", StringUtil.formatLike(user.getUserId()));
            }
            if (!StringUtils.isEmpty(user.getUserName())) {
            	filter.put("userName", StringUtil.formatLike(user.getUserName()));
            }
            pageReq.setFilter(filter);
            PageResp<UserBean> page= userService.findPaging(pageReq);
            return page;
        }
    

  2. Service继承了我们的Service泛型的抽象类,抽象类中对findPaging定义代码如下:

    	/**
         * 分页及条件查询
         * @param pageReq 分页及查询过滤参数
         * @return 指定页的查询结果
         */
    	public PageResp<E> findPaging(PageReq pageReq) {
    		int pageNo = pageReq.getPage();
    		int pageSize = pageReq.getPageSize();
    		String sortCol = pageReq.getSort();
    		String sortDir = pageReq.getOrder();
    		
    		pageReq.addFilter("pageNo", pageNo);
    		pageReq.addFilter("pageSize", pageSize);
    		pageReq.addFilter("sortCol", sortCol);
    		pageReq.addFilter("sortDir", sortDir);
    		pageReq.addFilter("page",pageReq);
    		List<E> list= dao.findPaging(pageReq.getFilter());
    		PageResp<E> result=  new PageResp<E>(list,pageReq.getTotal());
    		return result;
    	}
    

    注:这里对于sortCol,sortDir应当去除特殊字符,因为我们的Mapper使用的是${},这里我没有做。

  3. dao当然只要定义findPaging的接口方法就了。那我们的Mapper中的Sql就只需要如下:

    	<sql id="Base_Column_List">
    		UserId,UserName,Password,PasswordTime,State,EmpId,SysUser,CreateUser,CreateTime,UpdateUser,UpdateTime
    	</sql>
    	<select id="findPaging" parameterType="Map" resultMap="userResult">
            select <include refid="Base_Column_List" />
            from bp_user
           <where>
                <if test="userId!=null and userId!='' ">
                     and userId like #{userId}
                </if>
                <if test="userName!=null and userName!='' ">
                     and userName like #{userName}
                </if>
            </where>
           	<if test="sortCol!=null and sortCol!='' ">
           		order by ${sortCol}
           		<if test="sortDir!=null and sortDir!='' ">
                 	${sortDir}
            	</if>
           	</if>
        </select>
    

    注:mybatis中的#{}与${}要加以区分。#{}适合填入值,${}适合放列明,但是${}不能避免sql注入,所以需要自己来处理order by传入的参数。

这样,我们的分类就实现了,当然,我们现在的分页中只支持单列排序,那么如果要修改成多列排序,知道原理,做适当修改就可以实现了。
Mybatis拦截器实现物理分页

参考地址:
1. 在mybatis执行SQL语句之前进行拦击处理:http://blog.csdn.net/hfmbook/article/details/41985853
2. 分享一个完整的Mybatis分页解决方案:http://blog.csdn.net/u013306940/article/details/51168359

发表评论

电子邮件地址不会被公开。 必填项已用*标注