springboot配置多数据源后mybatis拦截器失效的解决

目录

  • 1. 解析配置文件初始化数据源
  • 2. 定义数据源枚举类型
  • 3. ThreadLocal保存数据源类型
  • 4. 自定义sqlSessionProxy
  • 5. 自定义路由
  • 6. 定义切面,dao层定义切面
  • 7. 最后在写库增加事务管理
  • 8. 在配置文件中增加数据源配置
配置文件通过 Spring Cloud Config 进行远程分布式配置,采用阿里 Druid 数据源,并支持一主多从的读写分离。分页组件通过拦截器拦截 id 以 Page 结尾的方法(MappedStatement),动态设置 total 总数。

1. 解析配置文件初始化数据源
/**
 * Builds the master (write) data source and the list of slave (read) data sources from
 * externalized properties.
 */
@Configuration
public class DataSourceConfiguration {

    /** Connection-pool implementation class, read from {@code spring.datasource.type}. */
    @Value("${spring.datasource.type}")
    private Class<? extends DataSource> dataSourceType;

    /**
     * Master data source, bound to the {@code spring.datasource.*} properties.
     *
     * @return the primary (write) data source
     */
    @Bean(name = "masterDataSource", destroyMethod = "close")
    @Primary
    @ConfigurationProperties(prefix = "spring.datasource")
    public DataSource masterDataSource() {
        return DataSourceBuilder.create().type(dataSourceType).build();
    }

    /**
     * First slave data source, bound to the {@code spring.slave0.*} properties.
     *
     * @return slave (read) data source #0
     */
    @Bean(name = "slaveDataSource0")
    @ConfigurationProperties(prefix = "spring.slave0")
    public DataSource slaveDataSource0() {
        return DataSourceBuilder.create().type(dataSourceType).build();
    }

    /**
     * All slave data sources, in index order.
     *
     * @return the list of slave (read) data sources
     */
    @Bean(name = "slaveDataSources")
    public List<DataSource> slaveDataSources() {
        List<DataSource> slaveDataSources = new ArrayList<>();
        // Inside a @Configuration class this call returns the "slaveDataSource0" bean.
        slaveDataSources.add(slaveDataSource0());
        return slaveDataSources;
    }
}


2. 定义数据源枚举类型
/**
 * Data-source roles used as routing keys: {@code master} (write statements) and
 * {@code slave} (read statements).
 */
public enum DataSourceType {
    master("master", "master"),
    slave("slave", "slave");

    private String type;
    private String name;

    DataSourceType(String type, String name) {
        this.type = type;
        this.name = name;
    }

    /** @return the routing key for this role */
    public String getType() {
        return type;
    }

    /**
     * @deprecated enum constants are JVM-wide singletons, so changing the key affects every
     *     thread and every caller; kept only for backward compatibility.
     */
    @Deprecated
    public void setType(String type) {
        this.type = type;
    }

    /** @return the display name of this role */
    public String getName() {
        return name;
    }

    /**
     * @deprecated enum constants are JVM-wide singletons, so changing the name affects every
     *     caller; kept only for backward compatibility.
     */
    @Deprecated
    public void setName(String name) {
        this.name = name;
    }
}


3. ThreadLocal保存数据源类型
/**
 * Holds the data-source routing key ("master" or "slave") for the current thread.
 *
 * <p>Call {@link #clearDataSource()} when the work is done so a stale key is not seen by
 * later work running on the same thread.
 */
public class DataSourceContextHolder {

    /** Per-thread routing key; {@code null} until one of the selectors is called. */
    private static final ThreadLocal<String> local = new ThreadLocal<>();

    /** @return the underlying thread-local that stores the routing key */
    public static ThreadLocal<String> getLocal() {
        return local;
    }

    /** Selects the slave (read) role for the current thread. */
    public static void slave() {
        local.set(DataSourceType.slave.getType());
    }

    /** Selects the master (write) role for the current thread. */
    public static void master() {
        local.set(DataSourceType.master.getType());
    }

    /** @return the routing key of the current thread, or {@code null} if none is set */
    public static String getJdbcType() {
        return local.get();
    }

    /** Removes the routing key from the current thread. */
    public static void clearDataSource() {
        local.remove();
    }
}


4. 自定义sqlSessionProxy 并将数据源填充到DataSourceRoute
// Builds the MyBatis SqlSessionFactory on top of the "roundRobinDataSourceProxy" routing
// DataSource (a DataSourceRoute), which registers masterDataSource under the "master" key
// and then the beans from the "slaveDataSources" list.
// NOTE(review): this snippet is truncated in the source -- it stops inside the slave
// registration loop of roundRobinDataSourceProxy(), so the rest of that method and the
// closing braces are missing; restore them from the original project before compiling.
// NOTE(review): this version never calls sqlSessionFactoryBean.setPlugins(...), so no
// MyBatis interceptor (e.g. PageInterceptor) is registered on the factory it builds.
@Configuration@ConditionalOnClass({EnableTransactionManagement.class})@Import({DataSourceConfiguration.class})public class DataSourceSqlSessionFactory {private Logger logger = Logger.getLogger(DataSourceSqlSessionFactory.class); @Value("${spring.datasource.type}")private Class dataSourceType; @Value("${mybatis.mapper-locations}")private String mapperLocations; @Value("${mybatis.type-aliases-package}")private String aliasesPackage; @Value("${slave.datasource.number}")private int dataSourceNumber; @Resource(name = "masterDataSource")private DataSource masterDataSource; @Resource(name = "slaveDataSources")private List slaveDataSources; @Bean@ConditionalOnMissingBeanpublic SqlSessionFactory sqlSessionFactory() throws Exception {logger.info("======================= init sqlSessionFactory"); SqlSessionFactoryBean sqlSessionFactoryBean = new SqlSessionFactoryBean(); sqlSessionFactoryBean.setDataSource(roundRobinDataSourceProxy()); PathMatchingResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(); sqlSessionFactoryBean.setMapperLocations(resolver.getResources(mapperLocations)); sqlSessionFactoryBean.setTypeAliasesPackage(aliasesPackage); sqlSessionFactoryBean.getObject().getConfiguration().setMapUnderscoreToCamelCase(true); return sqlSessionFactoryBean.getObject(); }@Bean(name = "roundRobinDataSourceProxy")public AbstractRoutingDataSource roundRobinDataSourceProxy() {logger.info("======================= init robinDataSourceProxy"); DataSourceRoute proxy = new DataSourceRoute(dataSourceNumber); Map targetDataSources = new HashMap(); targetDataSources.put(DataSourceType.master.getType(), masterDataSource); if(null != slaveDataSources) {for(int i=0; i
【springboot配置多数据源后mybatis拦截器失效的解决】
5. 自定义路由
public class DataSourceRoute extends AbstractRoutingDataSource {private Logger logger = Logger.getLogger(DataSourceRoute.class); private final int dataSourceNumber; public DataSourceRoute(int dataSourceNumber) {this.dataSourceNumber = dataSourceNumber; }@Overrideprotected Object determineCurrentLookupKey() {String typeKey = DataSourceContextHolder.getJdbcType(); logger.info("==================== swtich dataSource:" + typeKey); if (typeKey.equals(DataSourceType.master.getType())) {return DataSourceType.master.getType(); }else{//从数据源随机分配Random random = new Random(); int slaveDsIndex = random.nextInt(dataSourceNumber); return slaveDsIndex; }}}


6. 定义切面,dao层定义切面
/**
 * Mapper-layer aspect that selects the data source before each mapper call, based on the
 * method name: read-style names select a slave, write-style names select the master.
 *
 * <p>The choice is stored in {@link DataSourceContextHolder}. This aspect never clears it.
 * NOTE(review): confirm something calls DataSourceContextHolder.clearDataSource(); otherwise
 * the last selection stays on the thread.
 */
@Aspect
@Component
public class DataSourceAop {

    /** Read-style mapper methods: get/isExist/select/count/list/query/find/search. */
    private static final String READ_POINTCUT =
            "execution(* com.dbq.iot.mapper..*.get*(..)) || execution(* com.dbq.iot.mapper..*.isExist*(..)) "
                    + "|| execution(* com.dbq.iot.mapper..*.select*(..)) || execution(* com.dbq.iot.mapper..*.count*(..)) "
                    + "|| execution(* com.dbq.iot.mapper..*.list*(..)) || execution(* com.dbq.iot.mapper..*.query*(..))"
                    + "|| execution(* com.dbq.iot.mapper..*.find*(..))|| execution(* com.dbq.iot.mapper..*.search*(..))";

    /** Write-style mapper methods: add/del/upDate/insert/create/update/delete/remove/save/relieve/edit. */
    private static final String WRITE_POINTCUT =
            "execution(* com.dbq.iot.mapper..*.add*(..)) || execution(* com.dbq.iot.mapper..*.del*(..))"
                    + "||execution(* com.dbq.iot.mapper..*.upDate*(..)) || execution(* com.dbq.iot.mapper..*.insert*(..))"
                    + "||execution(* com.dbq.iot.mapper..*.create*(..)) || execution(* com.dbq.iot.mapper..*.update*(..))"
                    + "||execution(* com.dbq.iot.mapper..*.delete*(..)) || execution(* com.dbq.iot.mapper..*.remove*(..))"
                    + "||execution(* com.dbq.iot.mapper..*.save*(..)) || execution(* com.dbq.iot.mapper..*.relieve*(..))"
                    + "|| execution(* com.dbq.iot.mapper..*.edit*(..))";

    private Logger logger = Logger.getLogger(DataSourceAop.class);

    /** Routes the upcoming read-style mapper call to a slave. */
    @Before(READ_POINTCUT)
    public void setSlaveDataSourceType(JoinPoint joinPoint) {
        DataSourceContextHolder.slave();
        logger.info("=========slave, method:" + methodName(joinPoint));
    }

    /** Routes the upcoming write-style mapper call to the master. */
    @Before(WRITE_POINTCUT)
    public void setMasterDataSourceType(JoinPoint joinPoint) {
        DataSourceContextHolder.master();
        logger.info("=========master, method:" + methodName(joinPoint));
    }

    /** Name of the intercepted mapper method, for logging. */
    private static String methodName(JoinPoint joinPoint) {
        return joinPoint.getSignature().getName();
    }
}


7. 最后在写库增加事务管理
/**
 * Registers the "transactionManager" bean, bound to the master (write) data source.
 *
 * <p>NOTE(review): the manager wraps the "masterDataSource" bean directly. If MyBatis runs on a
 * routing proxy instead, verify that transactional statements really share this manager's
 * connection.
 */
@Configuration
@Import({DataSourceConfiguration.class})
public class DataSouceTranscation extends DataSourceTransactionManagerAutoConfiguration {

    private Logger log = Logger.getLogger(DataSouceTranscation.class);

    /** The write data source the transaction manager is bound to. */
    @Resource(name = "masterDataSource")
    private DataSource writeDataSource;

    /**
     * Creates the transaction manager for the master data source.
     *
     * @return a DataSourceTransactionManager over the "masterDataSource" bean
     */
    @Bean(name = "transactionManager")
    public DataSourceTransactionManager transactionManagers() {
        log.info("===================== init transactionManager");
        DataSourceTransactionManager manager = new DataSourceTransactionManager(writeDataSource);
        return manager;
    }
}


8. 在配置文件中增加数据源配置
# Master (write) data source
# NOTE: example credentials only; do not commit real passwords
spring.datasource.name=writedb
spring.datasource.url=jdbc:mysql://192.168.0.1/master?useUnicode=true&characterEncoding=utf8&autoReconnect=true&failOverReadOnly=false
spring.datasource.username=root
spring.datasource.password=1234
spring.datasource.type=com.alibaba.druid.pool.DruidDataSource
spring.datasource.driver-class-name=com.mysql.jdbc.Driver
spring.datasource.filters=stat
spring.datasource.initialSize=20
spring.datasource.minIdle=20
spring.datasource.maxActive=200
spring.datasource.maxWait=60000
# Number of slave data sources (DataSourceRoute picks a random slave index below this value)
slave.datasource.number=1
# Slave (read) data source #0
spring.slave0.name=readdb
spring.slave0.url=jdbc:mysql://192.168.0.2/slave?useUnicode=true&characterEncoding=utf8&autoReconnect=true&failOverReadOnly=false
spring.slave0.username=root
spring.slave0.password=1234
spring.slave0.type=com.alibaba.druid.pool.DruidDataSource
spring.slave0.driver-class-name=com.mysql.jdbc.Driver
spring.slave0.filters=stat
spring.slave0.initialSize=20
spring.slave0.minIdle=20
spring.slave0.maxActive=200
spring.slave0.maxWait=60000

这样就实现了在 Spring Cloud 框架下的读写分离,并且支持多个从库的负载均衡(这里只是简单地随机选择从库;也有网友用一个线程安全的自增计数器(如 AtomicInteger)对从库数量取余,实现平均分配。个人觉得没有太大必要。如果有大神有更好的方法,欢迎一起探讨。)
MyBatis 分页可通过 dao 层的拦截器对特定方法进行拦截,拦截后加入自己的逻辑,比如计算 total 等。具体代码如下(参考了网友的代码,主要通过 @Intercepts 注解实现):
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})public class PageInterceptor implements Interceptor {private static final Log logger = LogFactory.getLog(PageInterceptor.class); private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory(); private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory(); private static final ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory(); private static String defaultDialect = "mysql"; // 数据库类型(默认为mysql)private static String defaultPageSqlId = ".*Page$"; // 需要拦截的ID(正则匹配)private String dialect = ""; // 数据库类型(默认为mysql)private String pageSqlId = ""; // 需要拦截的ID(正则匹配)@Overridepublic Object intercept(Invocation invocation) throws Throwable {StatementHandler statementHandler = (StatementHandler) invocation.getTarget(); MetaObject metaStatementHandler = MetaObject.forObject(statementHandler, DEFAULT_OBJECT_FACTORY,DEFAULT_OBJECT_WRAPPER_FACTORY,DEFAULT_REFLECTOR_FACTORY); // 分离代理对象链(由于目标类可能被多个拦截器拦截,从而形成多次代理,通过下面的两次循环可以分离出最原始的的目标类)while (metaStatementHandler.hasGetter("h")) {Object object = metaStatementHandler.getValue("h"); metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY,DEFAULT_REFLECTOR_FACTORY); }// 分离最后一个代理对象的目标类while (metaStatementHandler.hasGetter("target")) {Object object = metaStatementHandler.getValue("target"); metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY,DEFAULT_REFLECTOR_FACTORY); }Configuration configuration = (Configuration) metaStatementHandler.getValue("delegate.configuration"); if (null == dialect || "".equals(dialect)) {logger.warn("Property dialect is not setted,use default 'mysql' "); dialect = defaultDialect; }if (null == pageSqlId || "".equals(pageSqlId)) {logger.warn("Property pageSqlId is not setted,use default '.*Page$' "); 
pageSqlId = defaultPageSqlId; }MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement"); // 只重写需要分页的sql语句。通过MappedStatement的ID匹配,默认重写以Page结尾的MappedStatement的sqlif (mappedStatement.getId().matches(pageSqlId)) {BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql"); Object parameterObject = boundSql.getParameterObject(); if (parameterObject == null) {throw new NullPointerException("parameterObject is null!"); } else {PageParameter page = (PageParameter) metaStatementHandler.getValue("delegate.boundSql.parameterObject.page"); String sql = boundSql.getSql(); // 重写sqlString pageSql = buildPageSql(sql, page); metaStatementHandler.setValue("delegate.boundSql.sql", pageSql); metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET); metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT); Connection connection = (Connection) invocation.getArgs()[0]; // 重设分页参数里的总页数等setPageParameter(sql, connection, mappedStatement, boundSql, page); }}// 将执行权交给下一个拦截器return invocation.proceed(); }/*** @param sql* @param connection* @param mappedStatement* @param boundSql* @param page*/private void setPageParameter(String sql, Connection connection, MappedStatement mappedStatement,BoundSql boundSql, PageParameter page) {// 记录总记录数String countSql = "select count(0) from (" + sql + ") as total"; PreparedStatement countStmt = null; ResultSet rs = null; try {countStmt = connection.prepareStatement(countSql); BoundSql countBS = new BoundSql(mappedStatement.getConfiguration(), countSql,boundSql.getParameterMappings(), boundSql.getParameterObject()); Field metaParamsField = ReflectUtil.getFieldByFieldName(boundSql, "metaParameters"); if (metaParamsField != null) {try {MetaObject mo = (MetaObject) ReflectUtil.getValueByFieldName(boundSql, "metaParameters"); ReflectUtil.setValueByFieldName(countBS, "metaParameters", mo); } catch (SecurityException | NoSuchFieldException | 
IllegalArgumentException| IllegalAccessException e) {// TODO Auto-generated catch blocklogger.error("Ignore this exception", e); }}Field additionalField = ReflectUtil.getFieldByFieldName(boundSql, "additionalParameters"); if (additionalField != null) {try {Map map = (Map) ReflectUtil.getValueByFieldName(boundSql, "additionalParameters"); ReflectUtil.setValueByFieldName(countBS, "additionalParameters", map); } catch (SecurityException | NoSuchFieldException | IllegalArgumentException| IllegalAccessException e) {// TODO Auto-generated catch blocklogger.error("Ignore this exception", e); }}setParameters(countStmt, mappedStatement, countBS, boundSql.getParameterObject()); rs = countStmt.executeQuery(); int totalCount = 0; if (rs.next()) {totalCount = rs.getInt(1); }page.setTotalCount(totalCount); int totalPage = totalCount / page.getPageSize() + ((totalCount % page.getPageSize() == 0) ? 0 : 1); page.setTotalPage(totalPage); } catch (SQLException e) {logger.error("Ignore this exception", e); } finally {try {if (rs != null){rs.close(); }} catch (SQLException e) {logger.error("Ignore this exception", e); }try {if (countStmt != null){countStmt.close(); }} catch (SQLException e) {logger.error("Ignore this exception", e); }}}/*** 对SQL参数(?)设值** @param ps* @param mappedStatement* @param boundSql* @param parameterObject* @throws SQLException*/private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,Object parameterObject) throws SQLException {ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, boundSql); parameterHandler.setParameters(ps); }/*** 根据数据库类型,生成特定的分页sql** @param sql* @param page* @return*/private String buildPageSql(String sql, PageParameter page) {if (page != null) {StringBuilder pageSql = new StringBuilder(); pageSql = buildPageSqlForMysql(sql,page); return pageSql.toString(); } else {return sql; }}/*** mysql的分页语句** @param sql* @param page* @return String*/public StringBuilder 
buildPageSqlForMysql(String sql, PageParameter page) {StringBuilder pageSql = new StringBuilder(100); String beginrow = String.valueOf((page.getCurrentPage() - 1) * page.getPageSize()); pageSql.append(sql); pageSql.append(" limit " + beginrow + "," + page.getPageSize()); return pageSql; }@Overridepublic Object plugin(Object target) {if (target instanceof StatementHandler) {return Plugin.wrap(target, this); } else {return target; }}@Overridepublic void setProperties(Properties properties) {}}

这里碰到一个比较有趣的问题:如果 SQL 中使用了 foreach 参数,拦截后生成的 count 查询无法注入这些参数。需要加入以下代码才可以,即同时复制 metaParameters 和 additionalParameters(有的资料只提到重置 metaParameters)。
// Copy the private BoundSql state that foreach-generated parameters rely on
// (metaParameters and additionalParameters) from the original BoundSql onto countBS.
for (String hiddenField : new String[] {"metaParameters", "additionalParameters"}) {
    if (ReflectUtil.getFieldByFieldName(boundSql, hiddenField) == null) {
        continue;
    }
    try {
        Object state = ReflectUtil.getValueByFieldName(boundSql, hiddenField);
        ReflectUtil.setValueByFieldName(countBS, hiddenField, state);
    } catch (SecurityException | NoSuchFieldException | IllegalArgumentException | IllegalAccessException e) {
        logger.error("Ignore this exception", e);
    }
}

读写分离倒是写好了,但是发现增加了mysql一主多从的读写分离后,此分页拦截器直接失效。
分析后发现,原因是我们在做主从分离时自定义了 SqlSessionFactory,导致此拦截器没有被注册。
只需在上面第4步的 DataSourceSqlSessionFactory 中注入拦截器即可,具体代码如下:
通过注解引入拦截器类:
// Import PageInterceptor alongside DataSourceConfiguration so it is registered as a bean
// and can be injected into DataSourceSqlSessionFactory.
@Import({DataSourceConfiguration.class,PageInterceptor.class})

注入拦截器
// Paging interceptor to pass to SqlSessionFactoryBean.setPlugins(...).
@Autowired
private PageInterceptor pageInterceptor;

SqlSessionFactoryBean中设置拦截器
// Register the interceptor on the factory bean; this must happen before getObject() is called.
sqlSessionFactoryBean.setPlugins(new Interceptor[]{pageInterceptor});

这里有一个坑:setPlugins 必须在 sqlSessionFactoryBean.getObject() 之前调用。
SqlSessionFactory 在构建时会读取 plugins 并注册到 Configuration 中;构建完成之后再设置的插件不会生效。
可跟踪源码看到:
sqlSessionFactoryBean.getObject()

// Quoted from mybatis-spring SqlSessionFactoryBean: the factory is built lazily on the first
// call (via afterPropertiesSet()); later calls return the same cached instance.
public SqlSessionFactory getObject() throws Exception {if (this.sqlSessionFactory == null) {afterPropertiesSet(); }return this.sqlSessionFactory; }

// Quoted from mybatis-spring SqlSessionFactoryBean: checks that dataSource and
// sqlSessionFactoryBuilder are set and that configuration/configLocation are not both given,
// then builds the factory via buildSqlSessionFactory(). Anything the build reads, including
// plugins, must already be set on the bean at this point.
public void afterPropertiesSet() throws Exception {notNull(dataSource, "Property 'dataSource' is required"); notNull(sqlSessionFactoryBuilder, "Property 'sqlSessionFactoryBuilder' is required"); state((configuration == null && configLocation == null) || !(configuration != null && configLocation != null),"Property 'configuration' and 'configLocation' can not specified with together"); this.sqlSessionFactory = buildSqlSessionFactory(); }

buildSqlSessionFactory()
// Excerpt from buildSqlSessionFactory(): each configured plugin is added to the MyBatis
// Configuration here, during the build, so plugins set on the bean afterwards are not registered.
if (!isEmpty(this.plugins)) {for (Interceptor plugin : this.plugins) {configuration.addInterceptor(plugin); if (LOGGER.isDebugEnabled()) {LOGGER.debug("Registered plugin: '" + plugin + "'"); }}}

最后贴上正确的配置代码(DataSourceSqlSessionFactory代码片段)
/**
 * Builds the MyBatis SqlSessionFactory on the routing data source, with the paging
 * interceptor registered as a plugin and underscore-to-camelCase mapping enabled.
 *
 * @return the configured SqlSessionFactory
 * @throws Exception if mapper resources cannot be resolved or the factory cannot be built
 */
@Bean
@ConditionalOnMissingBean
public SqlSessionFactory sqlSessionFactory() throws Exception {
    logger.info("======================= init sqlSessionFactory");
    SqlSessionFactoryBean factoryBean = new SqlSessionFactoryBean();
    factoryBean.setDataSource(roundRobinDataSourceProxy());
    factoryBean.setMapperLocations(new PathMatchingResourcePatternResolver().getResources(mapperLocations));
    factoryBean.setTypeAliasesPackage(aliasesPackage);
    // Plugins are copied into the MyBatis Configuration only while the factory is being built,
    // so they must be set before getObject() is called.
    factoryBean.setPlugins(new Interceptor[] {pageInterceptor});
    SqlSessionFactory factory = factoryBean.getObject();
    factory.getConfiguration().setMapUnderscoreToCamelCase(true);
    return factory;
}

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    推荐阅读