. 背景 本人从事后端刚满一年,主要从事java和python工作,准备在最近一年内多学习各框架的源码,实现技术突破,为了加深印象,本文将通过代码+图文+文字说明的形式来写一个自定义的spring核心控制器dispatcherServlet以及相关组件,以供后续学习。
2. 项目准备 1) 工具
idea2018, jdk1.8,tomcat 8.5
2) spring 核心原理
spring 通过容器并使用工厂模式来创建实例Bean,即控制反转,然后通过将被依赖的类装配到指定的类里实现调用,即我们常用的依赖注入, 通过扫描包的方式来管理所有需要的注解和类,Spring在容器启动的时候初始化好所有的Bean以及相关信息并存储起来,等请求的时候再获取出来,通过反射动态加载的形式调用Controller里对应的方法。
3. 知识储备 ①java注解相关知识 ② java反射相关知识 ③servlet相关知识 ④spring相关知识
4. 搭建项目 准备好相关工具后,我就可以开始搭建项目,coding....,完整步骤如下:
1) 建立应用
文章图片
2) 导入依赖
我在这里没有使用maven来管理依赖,如果没有用maven的话,那么就按照如下方式添加tomcat依赖:
文章图片
3) 项目目录结构
项目的目录结构截图如下, src作为source_root,即根目录, 如下图
文章图片
如果src不是根目录,那么可以设置:
4) 配置tomcat
指定名称
文章图片
指定访问的context
文章图片
指定发布的包:配置好上述的全部信息就可以正常启动服务器了!启动成功后,会自动调用 http://localhost:8080/spring/ 。
5) 建立项目遇到的问题以及解决方法
①配置好服务器后启动tomcat报错:not found for the web module
解决方法:添加web包到Facets里,在Aitifacts里添加发布到的war包 第一步,选择projet structures。
文章图片
第二步,选择facets里的web,添加本项目,选择ok即可!
文章图片
第三步,添加Artifacts, 选择本项目ok即可。
第四步,选择刚配置好的artifacts添加到 edit configuration的deployment里:
文章图片
重新启动服务器即可解决上述问题!
②启动报错: Cannot start compilation: the output path is not specified for module "spring-study".Specify the output path in the Project Structure dialog.
解决方法: 在当前项目目录下新建一个out目录,指定out的路径即可,如下图:
文章图片
配置好后,重新启动即可解决问题!
5. 自定义配置 1) 配置servlet
如果上述的东西都准备好了,那么在web.xml文件中定义标签时不会出现红色,类所在的路径一定要正确, 同时定义标签,指定配置文件所在的路径。
spring web applicationDispatcherServlet
com.springframework.demo.servlet.MyDispatcherServlet
contextConfigLocationclasspath*:application.properties
DispatcherServlet
/*
复制代码
2) 自定义配置文件application.properties
application.properties配置文件内容为:
scan-package=com.springframework.demo
复制代码
6. 自定义Spring相关注解
注解名称 |
功能描述 |
作用范围 |
MyAutoWire |
装配bean |
属性FIELD上 |
MyController |
请求控制器 |
类TYPE上 |
MyRequestMapping |
路由请求 |
类TYPE或者方法METHOD上 |
MyRequestParam |
请求参数 |
参数PARAMETER上 |
MyService |
标记处理业务逻辑的类 |
类TYPE上 |
MyAutoWire
package com.springframework.demo.mydefine;
import java.lang.annotation.*;
/**
* @author bingbing
* @date 2020/12/23 0023 15:01
*/
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyAutoWire {
String value() default "";
}复制代码
MyController
package com.springframework.demo.mydefine;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* @author bingbing
* @date 2020/12/23 0023 15:02
*/
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface MyController {
String value() default "";
}复制代码
MyRequestMapping
package com.springframework.demo.mydefine;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* @author bingbing
* @date 2020/12/23 0023 15:08
*/
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface MyRequestMapping {
String value() default "";
}复制代码
MyRequestParam
package com.springframework.demo.mydefine;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* @author bingbing
* @date 2020/12/23 0023 15:09
*/
@Target({ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
public @interface MyRequestParam {
String value() default "";
}复制代码
MyService
package com.springframework.demo.mydefine;
import java.lang.annotation.*;
/**
* @author bingbing
* @date 2020/12/23 0023 15:03
*/
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyService {String value() default "";
}复制代码
ActionController
package com.springframework.demo;
import com.springframework.demo.mydefine.MyAutoWire;
import com.springframework.demo.mydefine.MyController;
import com.springframework.demo.mydefine.MyRequestMapping;
import com.springframework.demo.mydefine.MyRequestParam;
import com.springframework.demo.service.impl.DemoServiceImpl;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/**
* @author bingbing
* @date 2020/12/23 0023 15:10
*/@MyController
@MyRequestMapping("/action")
public class ActionController {@MyAutoWire
private DemoServiceImpl demoService;
@MyRequestMapping("/query")
public void querySomeThing(
HttpServletRequest request,
HttpServletResponse response,
@MyRequestParam("id") Integer id,
@MyRequestParam("username") String username) throws IOException {
String result = demoService.read(id, username);
response.getWriter().println(result);
}
}复制代码
业务逻辑接口
package com.springframework.demo.service;
/**
* @author bingbing
* @date 2020/12/23 0023 15:19
*/
public interface IService {String read(Integer id, String name);
}复制代码
业务逻辑类
package com.springframework.demo.service.impl;
import com.springframework.demo.mydefine.MyService;
import com.springframework.demo.service.IService;
/**
* @author bingbing
* @date 2020/12/23 0023 15:18
*/
@MyService
public class DemoServiceImpl implements IService {
@Override
public String read(Integer id, String name) {
System.out.println("id=" + id + ",username=" + name);
String str = "coding is a good habit!";
System.out.println(str);
return str;
}
}复制代码
7. 核心代码详解 tomcat容器启动时,会调用我们重写的init(ServletConfig config)方法,该方法为GenericServlet抽象类里定义的方法,我们可以在此方法里实现初始化spring容器。
public void init(ServletConfig config) throws ServletException {
this.config = config;
this.init();
}public void init() throws ServletException {
}复制代码
1) 扫描包下所有class
通过扫描com.springframework.demo包,我们可以得到所有类和接口对应的class文件,将这所有的Class文件对应的全限定名(报名.类名)暂存到一个Map里面
//解析.class文件, 通过注解来识别
private void doScanner(String scanPackage) {
System.out.println("开始扫描包!" + mapping);
URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
File classDir = new File(url.getFile());
for (File file : classDir.listFiles()) {
if (file.isDirectory()) {
doScanner(scanPackage + "." + file.getName());
} else if (!file.getName().endsWith(".class")) {
continue;
}
String clazzName = (scanPackage + "." + file.getName().replace(".class", ""));
if (!file.isDirectory()) {
mapping.put(clazzName, null);
}
}
System.out.println("扫描完毕!" + mapping);
}复制代码
2) 遍历所有class, 按照类型来分类处理
mapping.put(url, method);
mapping.put(beanName, instance);
3) 对Controller下装配的Bean进行强制赋予访问权限
如被autowire注解标记的Bean。
Class clazz = obejct.getClass();
if (clazz.isAnnotationPresent(MyController.class)) {
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields) {
if (!field.isAnnotationPresent(MyAutoWire.class)) {
continue;
}
MyAutoWire autoWire = field.getAnnotation(MyAutoWire.class);
String beanName = autoWire.value();
if ("".equals(beanName)) {
beanName = field.getType().getName();
}
//授予权限
field.setAccessible(true);
field.set(mapping.get(clazz.getName()), mapping.get(beanName));
}
}
复制代码
field.set()方法设置的2个object:
文章图片
4) 分析Spring容器里的mapping
通过上述代码我们可以发现,spring容器通过hashmap存储controller、service、url请求等组件信息,初始化完毕后 ,得到的mapping里包含了10个键值对,分别对应的功能如下:
键|路径 |
值|对象 |
描述 |
com.springframework.demo.service.IService |
com.springframework.demo.service.impl.DemoServiceImpl@1b4ecd8b |
装配接口Iservice以及实现类信息 |
com.springframework.demo.service.impl.DemoServiceImpl |
com.springframework.demo.service.impl.DemoServiceImpl@1b4ecd8b |
装配实现类对象信息 |
com.springframework.demo.mydefine.MyController |
null |
装配controller注解组件 |
/action/query |
com.springframework.demo.ActionController.querySomeThing(javax.servlet.http.HttpServletRequest,javax.servlet.http.HttpServletResponse,java.lang.Integer,java.lang.String) throws java.io.IOException |
装配url, 以及url请求所对应的method信息 |
com.springframework.demo.ActionController |
com.springframework.demo.ActionController@39613cd0} |
处理请求的类信息 |
map:{
com.springframework.demo.service.IService=com.springframework.demo.service.impl.DemoServiceImpl@1b4ecd8b, com.springframework.demo.service.impl.DemoServiceImpl=com.springframework.demo.service.impl.DemoServiceImpl@1b4ecd8b,
com.springframework.demo.mydefine.MyController=null, com.springframework.demo.mydefine.MyService=null,
/action/query=public void com.springframework.demo.ActionController.querySomeThing(javax.servlet.http.HttpServletRequest,javax.servlet.http.HttpServletResponse,java.lang.Integer,java.lang.String) throws java.io.IOException,
com.springframework.demo.mydefine.MyAutoWire=null, com.springframework.demo.servlet.MyDispatcherServlet=null, com.springframework.demo.mydefine.MyRequestParam=null, com.springframework.demo.mydefine.MyRequestMapping=null, com.springframework.demo.ActionController=com.springframework.demo.ActionController@39613cd0}复制代码
5)处理请求
访问 http://localhost:8080/spring/action/query?id=1&username=bingbing
将get请求转到doPost()上,从mapping里获取到url对应的method, 通过method反射执行ActionController里的 querySomeThing( HttpServletRequest request, HttpServletResponse response, @MyRequestParam("id") Integer id, @MyRequestParam("username") String username)
方法, 执行完毕后,使用response给出浏览器响应即可!
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
System.out.println("请求成功!");
this.doPost(req, resp);
}@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws RuntimeException {
//处理并且分发请求
try {
doDispatcher(req, resp);
} catch (IOException e) {
throw new RuntimeException(e);
} catch (InvocationTargetException e) {
System.out.println("InvocationTargetException");
} catch (IllegalAccessException e) {
System.out.println("非法访问!");
}
}private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
String url = req.getRequestURI();
String contextPath = req.getContextPath();
url = url.replace(contextPath, "").replaceAll("/+", "/");
System.out.println("url:" + url);
if (!this.mapping.containsKey(url)) {
resp.getWriter().println("404 NOT FOUND");
}
//调用url里对应的方法
Method method = (Method) mapping.get(url);
Map parameterMap = req.getParameterMap();
method.invoke(this.mapping.get(method.getDeclaringClass().getName()), new Object[]{req, resp, new Integer(parameterMap.get("id")[0]), parameterMap.get("username")[0]});
}
复制代码
被执行的方法:
@MyRequestMapping("/query")
public void querySomeThing(
HttpServletRequest request,
HttpServletResponse response,
@MyRequestParam("id") Integer id,
@MyRequestParam("username") String username) throws IOException {
String result = demoService.read(id, username);
response.getWriter().println(result);
}复制代码
6) 中心控制器MyDispatcherServlet 完整代码
package com.springframework.demo.servlet;
import com.springframework.demo.mydefine.MyAutoWire;
import com.springframework.demo.mydefine.MyController;
import com.springframework.demo.mydefine.MyRequestMapping;
import com.springframework.demo.mydefine.MyService;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.*;
import java.io.InputStream;
/**
* 自定义DispatchServlet
*
* @author Administrator
*/
public class MyDispatcherServlet extends HttpServlet {/**
* 用来存放bean或请求的信息
*/
private Map mapping = new HashMap<>();
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
System.out.println("请求成功!");
this.doPost(req, resp);
}@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws RuntimeException {
//处理并且分发请求
try {
doDispatcher(req, resp);
} catch (IOException e) {
throw new RuntimeException(e);
} catch (InvocationTargetException e) {
e.printStackTrace();
System.out.println("InvocationTargetException");
} catch (IllegalAccessException e) {
System.out.println("非法访问!");
}
}//动态加载参数, 可按照参数的顺序来绑定
private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
String url = req.getRequestURI();
String contextPath = req.getContextPath();
url = url.replace(contextPath, "").replaceAll("/+", "/");
System.out.println("url:" + url);
if (!this.mapping.containsKey(url)) {
resp.getWriter().println("404 NOT FOUND");
}
//调用url里对应的方法
Method method = (Method) mapping.get(url);
Map parameterMap = req.getParameterMap();
Object[] params = new Object[parameterMap.size() + 2];
params[0] = req;
params[1] = resp;
//获取方法的参数类型
Class>[] cls = method.getParameterTypes();
int index = 2;
//index=2cls=2 ,index=3 ,cls=3
for (Map.Entry s : parameterMap.entrySet()) {
if (cls[index].getName() == "java.lang.Integer") {
params[index] = new Integer(String.valueOf(s.getValue()[0]));
} else if (cls[index].getName() == "java.lang.String") {
params[index] = String.valueOf(s.getValue()[0]);
}
index++;
}
method.invoke(mapping.get(method.getDeclaringClass().getName()), params);
}@Override
public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException {
super.service(req, res);
}//扫描所有的class
private void doScanner(String scanPackage) {
URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
File rootDir = new File(url.getFile());
for (File file : rootDir.listFiles()) {
if (file.isDirectory()) {
doScanner(scanPackage + "." + file.getName());
} else if (!file.getName().endsWith(".class")) {
continue;
}if (!file.isDirectory()) {
String className = scanPackage + "." + file.getName().replaceAll(".class", "");
mapping.put(className, null);
}
}
}@Override
public void init(ServletConfig config) throws ServletException {
System.out.println("开始初始化容器....");
InputStream is = null;
// 1. 通过classLoader方法getResourceAsStream()获取到配置文件对象
try {
Properties configText = new Properties();
String configName = config.getInitParameter("contextConfigLocation");
configName = configName.substring(configName.indexOf(":") + 1);
is = this.getClass().getClassLoader().getResourceAsStream(configName);
configText.load(is);
String packageName = configText.getProperty("scan-package");
// 2. 扫描包,装配到mapping里, key 为 class的全限定名,value为null
doScanner(packageName);
// 3. 遍历map的key,设置所有的controller和service。
for (String clazzName : mapping.keySet()) {
if (!clazzName.contains(".")) {
continue;
}
Class> clazz = Class.forName(clazzName);
// 判断Class被哪个注解标记
if (clazz.isAnnotationPresent(MyController.class)) {
// 被controller注解标记的类
mapping.put(clazzName, clazz.newInstance());
// 如果被requestMapping注解标记
String baseurl = "";
if (clazz.isAnnotationPresent(MyRequestMapping.class)) {
MyRequestMapping requestMapping = clazz.getAnnotation(MyRequestMapping.class);
baseurl = requestMapping.value();
}
// 获取controller类下的所有methods
Method[] methods = clazz.getMethods();
for (Method method : methods) {
if (!method.isAnnotationPresent(MyRequestMapping.class)) {
continue;
}
// 获取RequestMapping 对象
MyRequestMapping requestMapping = method.getAnnotation(MyRequestMapping.class);
String url = baseurl + requestMapping.value();
// 解释了url不能重复的原因
mapping.put(url, method);
}} else if (clazz.isAnnotationPresent(MyService.class)) {
// 被MyService注解标记的类, 装配接口实现类和clazz,对map的key重新赋值
MyService myService = clazz.getAnnotation(MyService.class);
String beanName = myService.value();
if ("".equals(beanName)) {
beanName = clazz.getName();
}
Object obj = clazz.newInstance();
mapping.put(beanName, obj);
// 重新装配 Service类下的所有接口
for (Class> cls : clazz.getInterfaces()) {
mapping.put(cls.getName(), obj);
}
} else {
continue;
}}
// 对类下的bean 进行授权可访问
for (Object obj : mapping.values()) {
if (obj == null) {
continue;
}
Class> clz = obj.getClass();
Field[] fields = clz.getDeclaredFields();
for (Field field : fields) {
if (!field.isAnnotationPresent(MyAutoWire.class)) {
continue;
}
MyAutoWire myAutoWire = field.getAnnotation(MyAutoWire.class);
String beanName = myAutoWire.value();
if ("".equals(beanName)) {
// bean名为属性的类型的全限定名
beanName = field.getType().getName();
}
field.setAccessible(true);
field.set(mapping.get(clz.getName()), mapping.get(beanName));
}
}
System.out.println("扫描完毕!");
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
} finally {
try {
is.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}复制代码
8. 运行效果展示 访问 http://localhost:8080/spring/action/query?id=1&username=bingbing , 需要带上id和username参数,出现如下效果,表示spring的核心原理精简版就实现了。
文章图片
9. 代码优化 1)method在使用反射执行的时候,不能动态的绑定方法参数以及参数类型。
我们可以从上述doDispatcher()方法内,使用invoke()方法执行的时候 method.invoke(this.mapping.get(method.getDeclaringClass().getName()), new Object[]{req, resp, new Integer(parameterMap.get("id")[0]), parameterMap.get("username")[0]});
需要指定参数id和username,此种传参方式传参就是比较固定死板, 因此我想了一个办法是从method里获取到所有的参数类型,再遍历parameterMap的时候根据下标所在位置对应的元素所在类型进行判断是否为cls里面对应的类型。 优化代码如下:
//动态加载参数, 可按照参数的顺序来绑定
private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
String url = req.getRequestURI();
String contextPath = req.getContextPath();
url = url.replace(contextPath, "").replaceAll("/+", "/");
System.out.println("url:" + url);
if (!this.mapping.containsKey(url)) {
resp.getWriter().println("404 NOT FOUND");
}
//调用url里对应的方法
Method method = (Method) mapping.get(url);
Map parameterMap = req.getParameterMap();
Object[] params = new Object[parameterMap.size() + 2];
params[0] = req;
params[1] = resp;
//获取方法的参数类型
Class>[] cls = method.getParameterTypes();
int index = 2;
//index=2cls=2 ,index=3 ,cls=3
for (Map.Entry s : parameterMap.entrySet()) {
if (cls[index].getName() == "java.lang.Integer") {
params[index] = new Integer(String.valueOf(s.getValue()[0]));
} else if (cls[index].getName() == "java.lang.String") {
params[index] = String.valueOf(s.getValue()[0]);
}
index++;
}
method.invoke(this.mapping.get(method.getDeclaringClass().getName()), params);
}复制代码
【编程|手写Spring核心原理MVC实现】改进点: 与之前方式不同的是在调用方法时不需要知道参数的名称来做指定参数类型的转换,变的更灵活些。
推荐阅读