spring覆盖容器bean的注解实现@overridebean
项目开发中,有时第三方框架会自动注入bean到spring容器中,当我们有修改对应内置bean实现的需求时,可以采用偷梁换柱的方式来重写内置的bean,使用这种方式需要注意以下两点:
- 1、对应的bean在其他地方使用时,是基于接口注入的。
- 2、如果不是基于接口注入的bean,你可能需要同包名同类名的这种方式重写(可能会有问题,不推荐)。
从以上2点我们还可以得出一个结论,那就是“基于接口编程”的好处。
具体实现参考一下代码
(代码片段,仅供参考,根据实际使用场景修改后使用):
import java.lang.annotation.elementtype;
import java.lang.annotation.retention;
import java.lang.annotation.retentionpolicy;
import java.lang.annotation.target;
/**
* 覆盖spring容器中的bean
*
* @author shanhy
* @date 2021/4/25 13:40
*/
@retention(retentionpolicy.runtime)
@target(elementtype.type)
public @interface overridebean {
/**
* 需要替换的 bean 的名称
*
* @return
*/
string value();
}
import org.slf4j.logger;
import org.slf4j.loggerfactory;
import org.springframework.beans.beansexception;
import org.springframework.beans.factory.beanfactory;
import org.springframework.beans.factory.beanfactoryaware;
import org.springframework.beans.factory.config.beandefinitionholder;
import org.springframework.beans.factory.config.configurablelistablebeanfactory;
import org.springframework.beans.factory.support.beandefinitionregistry;
import org.springframework.beans.factory.support.beandefinitionregistrypostprocessor;
import org.springframework.beans.factory.support.genericbeandefinition;
import org.springframework.boot.autoconfigure.autoconfigurationpackages;
import org.springframework.context.annotation.classpathbeandefinitionscanner;
import org.springframework.context.annotation.configuration;
import org.springframework.core.type.filter.annotationtypefilter;
import org.springframework.util.stringutils;
import java.util.arraylist;
import java.util.list;
import java.util.objects;
import java.util.set;
/**
* 重写bean的配置类
*
* @author shanhy
* @date 2021/4/25 13:41
*/
@configuration
public class overridebeanconfiguration implements beandefinitionregistrypostprocessor, beanfactoryaware {
private static final logger log = loggerfactory.getlogger(overridebeanconfiguration.class);
private beanfactory beanfactory;
@override
public void postprocessbeandefinitionregistry(beandefinitionregistry registry) throws beansexception {
log.debug("searching for classes annotated with @overridebean");
// 自定义 scanner 扫描 classpath 下的指定注解
classpathoverridebeanannotationscanner scanner = new classpathoverridebeanannotationscanner(registry);
try {
// 获取包路径
list<string> packages = autoconfigurationpackages.get(this.beanfactory);
if (log.isdebugenabled()) {
for (string p : packages) {
log.debug("using auto-configuration base package: {}", p);
}
}
// 扫描所有加载的包
scanner.doscan(stringutils.tostringarray(packages));
} catch (illegalstateexception ex) {
log.debug("could not determine auto-configuration package, automatic overridebean scanning disabled.", ex);
}
}
@override
public void postprocessbeanfactory(configurablelistablebeanfactory factory) throws beansexception {
}
@override
public void setbeanfactory(beanfactory beanfactory) throws beansexception {
this.beanfactory = beanfactory;
}
private static class classpathoverridebeanannotationscanner extends classpathbeandefinitionscanner {
classpathoverridebeanannotationscanner(beandefinitionregistry registry) {
super(registry, false);
// 设置过滤器。仅扫描 @overridebean
addincludefilter(new annotationtypefilter(overridebean.class));
}
@override
public set<beandefinitionholder> doscan(string... basepackages) {
list<string> overrideclassnames = new arraylist<>();
// 扫描全部 package 下 annotationclass 指定的 bean
set<beandefinitionholder> beandefinitions = super.doscan(basepackages);
genericbeandefinition definition;
for (beandefinitionholder holder : beandefinitions) {
definition = (genericbeandefinition) holder.getbeandefinition();
// 获取类名,并创建 class 对象
string classname = definition.getbeanclassname();
class<?> clazz = classnametoclass(classname);
// 解析注解上的 value
overridebean annotation = objects.requirenonnull(clazz).getannotation(overridebean.class);
if (annotation == null || annotation.value().length() == 0) {
continue;
}
// 使用当前加载的 @overridebean 指定的 bean 替换 value 里指定名称的 bean
if (objects.requirenonnull(getregistry()).containsbeandefinition(annotation.value())) {
getregistry().removebeandefinition(annotation.value());
getregistry().registerbeandefinition(annotation.value(), definition);
overrideclassnames.add(clazz.getname());
}
}
log.info("found override beans: " + overrideclassnames);
return beandefinitions;
}
// 反射通过 class 名称获取 class 对象
private class<?> classnametoclass(string classname) {
try {
return class.forname(classname);
} catch (classnotfoundexception e) {
log.error("create instance failed.", e);
}
return null;
}
}
}
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持代码网。
发表评论