手撸简易版Tomcat(二)

手撸简易版Tomcat(二)

一、实现ServletContext

在Java Web应用程序中,ServletContext代表应用程序的运行环境,一个Web应用程序对应一个唯一的ServletContext实例,ServletContext可以用于:

  • 提供初始化和全局配置:可以从ServletContext获取Web App配置的初始化参数、资源路径等信息;
  • 共享全局数据:ServletContext存储的数据可以被整个Web App的所有组件读写。

既然ServletContext是一个Web App的全局唯一实例,而Web App又运行在Servlet容器中,我们在实现ServletContext时,完全可以把它当作Servlet容器来实现,它在内部维护一组Servlet实例,并根据Servlet配置的路由信息将请求转发给对应的Servlet处理。假设我们编写了两个Servlet:

  • IndexServlet:映射路径为/
  • HelloServlet:映射路径为/hello
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@WebServlet(urlPatterns = "/")
public class IndexServlet extends HttpServlet {

@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
String html = "<h1>Index Page</h1>";
resp.setContentType("text/html");
PrintWriter pw = resp.getWriter();
pw.write(html);
pw.close();
}
}

@WebServlet(urlPatterns = "/hello")
public class HelloServlet extends HttpServlet {

@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
String name = req.getParameter("name");
String html = "<h1>Hello, " + (name == null ? "world" : name) + ".</h1>";
resp.setContentType("text/html");
PrintWriter pw = resp.getWriter();
pw.write(html);
pw.close();
}

@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
String name = req.getParameter("name");
System.out.println("request body data: " + new String(req.getInputStream().readAllBytes()));
String html = "<h1>Hello, " + (name == null ? "world" : name) + ".</h1>";
resp.setContentType("text/html");
PrintWriter pw = resp.getWriter();
pw.write(html);
pw.close();
}
}

那么,处理HTTP请求的路径如下:

image-20231127171157255

下面,我们来实现ServletContext。首先定义ServletMapping,它包含一个Servlet实例,以及将映射路径编译为正则表达式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
public class AbstractMapping implements Comparable<AbstractMapping> {

final Pattern pattern;
final String url;

public AbstractMapping(String urlPattern) {
this.url = urlPattern;
this.pattern = buildPattern(urlPattern);
}

public boolean matches(String uri) {
return pattern.matcher(uri).matches();
}

Pattern buildPattern(String urlPattern) {
StringBuilder sb = new StringBuilder(urlPattern.length() + 16);
// 正则表达式的开头
sb.append('^');
for (int i = 0; i < urlPattern.length(); i++) {
char ch = urlPattern.charAt(i);
if (ch == '*') {
// 表示匹配任意字符
sb.append(".*");
} else if (ch >= 'a' && ch <= 'z' || ch >= 'A' && ch <= 'Z' || ch >= '0' && ch <= '9') {
// 常规字符保持一致
sb.append(ch);
} else {
// 原字符转义,例如ch为&等价于正则表达式\&
sb.append('\\').append(ch);
}
}
// 正则表达式的结尾
sb.append('$');
return Pattern.compile(sb.toString());
}

@Override
public int compareTo(AbstractMapping o) {
// 1.长路径的优先级小,自然排序是靠前排放,符合最长前缀匹配原则
// 2."/"和"*"的优先级最高,自然排序放在最后,处理都不匹配的情况
int cmp = this.priority() - o.priority();
if (cmp == 0) {
cmp = this.url.compareTo(o.url);
}
return cmp;
}

// 满足URI路径的最长前缀匹配
int priority() {
if (this.url.equals("/")) {
return Integer.MAX_VALUE;
}
if (this.url.startsWith("*")) {
return Integer.MAX_VALUE - 1;
}
return 100000 - this.url.length();
}
}

public class ServletMapping extends AbstractMapping {

public final Servlet servlet;

public ServletMapping(String urlPattern, Servlet servlet) {
super(urlPattern);
this.servlet = servlet;
}
}

接下来实现ServletContext

1
2
3
public class ServletContextImpl implements ServletContext {
final List<ServletMapping> servletMappings = new ArrayList<>();
}

这个数据结构足够能让我们实现根据请求路径路由到某个特定的Servlet:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
public class ServletContextImpl implements ServletContext {
...
// HTTP请求处理入口:
public void process(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
// 请求路径:
String path = request.getRequestURI();
// 搜索Servlet:
Servlet servlet = null;
for (ServletMapping mapping : this.servletMappings) {
if (mapping.matches(path)) {
// 路径匹配:
servlet = mapping.servlet;
break;
}
}
if (servlet == null) {
// 未匹配到任何Servlet显示404 Not Found:
PrintWriter pw = response.getWriter();
pw.write("<h1>404 Not Found</h1><p>No mapping for URL: " + path + "</p>");
pw.close();
return;
}
// 由Servlet继续处理请求:
servlet.service(request, response);
}
}

这样我们就实现了ServletContext

不过,细心的同学会发现,我们编写的两个Servlet:IndexServletHelloServlet,还没有被添加到ServletContext中。那么问题来了:Servlet在什么时候被初始化?

答案是在创建ServletContext实例后,就立刻初始化所有的Servlet。我们编写一个initialize()方法,用于初始化Servlet:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
public class ServletContextImpl implements ServletContext {
Map<String, ServletRegistrationImpl> servletRegistrations = new HashMap<>();
List<ServletMapping> servletMappings = new ArrayList<>();

public void initialize(List<Class<?>> servletClasses) {
for (Class<?> c : servletClasses) {
WebServlet ws = c.getAnnotation(WebServlet.class);
// @WebServlet注解标注的Servlet会被注册处理
if (ws != null) {
logger.info("auto register @WebServlet: {}", c.getName());
@SuppressWarnings("unchecked")
Class<? extends Servlet> clazz = (Class<? extends Servlet>) c;
// 注册的Servlet的名称默认是类名首字母小写的形式,如果@WebServlet注解设置了name参数则以该参数值为准
ServletRegistration.Dynamic registration = this.addServlet(AnnoUtils.getServletName(clazz), clazz);
// 添加映射的请求路径
registration.addMapping(AnnoUtils.getServletUrlPatterns(clazz));
// 可以省略,因为本项目Servlet实现不支持initParameter参数功能,因为用处不大
registration.setInitParameters(AnnoUtils.getServletInitParams(clazz));
}
}

// init servlets:
for (String name : this.servletRegistrations.keySet()) {
ServletRegistrationImpl registration = this.servletRegistrations.get(name);
try {
// 初始化Servlet
registration.servlet.init(registration.getServletConfig());
// servletMappings中添加Servlet路由信息
for (String urlPattern : registration.getMappings()) {
if (urlPattern.equals("/")) {
// "/"对应的Servlet处理所有匹配失败的情况包括"/"本身
this.servletMappings.add(new ServletMapping("*", registration.servlet));
} else {
this.servletMappings.add(new ServletMapping(urlPattern, registration.servlet));
}
}
// Servlet注册完成
registration.initialized = true;
} catch (ServletException e) {
logger.error("init servlet failed: " + name + " / " + registration.servlet.getClass().getName(), e);
}
}
// important: sort mappings:
// 排序后长路径的Servlet映射靠前,优先最长前缀匹配
// 例如对于/test/hello请求来说:/test/hello的Servlet > /test的Servlet > /的Servlet
Collections.sort(this.servletMappings);
}

@Override
public ServletRegistration.Dynamic addServlet(String name, Class<? extends Servlet> clazz) {
if (clazz == null) {
throw new IllegalArgumentException("class is null.");
}
Servlet servlet = null;
try {
servlet = createInstance(clazz);
} catch (ServletException e) {
throw new RuntimeException(e);
}
return addServlet(name, servlet);
}

@Override
public ServletRegistration.Dynamic addServlet(String name, Servlet servlet) {
if (name == null) {
throw new IllegalArgumentException("name is null.");
}
if (servlet == null) {
throw new IllegalArgumentException("servlet is null.");
}
ServletRegistrationImpl registration = new ServletRegistrationImpl(this, name, servlet);
this.servletRegistrations.put(name, registration);
return registration;
}

@Override
public String getContextPath() {
// only support root context path:
return "";
}

@Override
public ServletContext getContext(String uripath) {
if ("".equals(uripath)) {
return this;
}
// all others are not exist:
return null;
}

// ServletContext也不支持初始化参数,因为也不是很常用,没必要实现
@Override
public String getInitParameter(String name) {
// no init parameters:
return null;
}

@Override
public Enumeration<String> getInitParameterNames() {
// no init parameters:
return Collections.emptyEnumeration();
}

@Override
public boolean setInitParameter(String name, String value) {
throw new UnsupportedOperationException("setInitParameter");
}

// Servlet API version: 6.0.0

@Override
public int getMajorVersion() {
return 6;
}

@Override
public int getMinorVersion() {
return 0;
}

@Override
public int getEffectiveMajorVersion() {
return 6;
}

@Override
public int getEffectiveMinorVersion() {
return 0;
}
}

从Servlet 3.0规范开始,我们必须要提供addServlet()动态添加一个Servlet,并且返回ServletRegistration.Dynamic,因此,我们在initialize()方法中调用addServlet(),完成所有Servlet的创建和初始化。上面的代码中出现了注解工具类AnnoUtils和Servlet注册所需的ServletRegistrationImpl,注解工具类的功能比较简单,就是根据用户Sevlet实现类及其@WebServlet注解完成一些信息解析工作,例如获取Servlet名称、获取Servlet的映射路径集合等等,我们重点看一下ServletRegistrationImpl

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
public class ServletRegistrationImpl implements ServletRegistration.Dynamic {
// 容器上下文
final ServletContext servletContext;
// Servlet名称
final String name;
// Servlet实例
final Servlet servlet;
// 映射路径集合
final List<String> urlPatterns = new ArrayList<>(4);

// 注册尚未结束
boolean initialized = false;

public ServletRegistrationImpl(ServletContext servletContext, String name, Servlet servlet) {
this.servletContext = servletContext;
this.name = name;
this.servlet = servlet;
}

public ServletConfig getServletConfig() {
return new ServletConfig() {
// Servlet的名称
@Override
public String getServletName() {
return ServletRegistrationImpl.this.name;
}

// Servlet的容器上下文
@Override
public ServletContext getServletContext() {
return ServletRegistrationImpl.this.servletContext;
}

// Servlet不支持initParameter参数功能
@Override
public String getInitParameter(String name) {
return null;
}

// Servlet不支持initParameter参数功能
@Override
public Enumeration<String> getInitParameterNames() {
return null;
}
};
}

@Override
public String getName() {
return this.name;
}

@Override
public String getClassName() {
return servlet.getClass().getName();
}

@Override
public Set<String> addMapping(String... urlPatterns) {
if (urlPatterns == null || urlPatterns.length == 0) {
throw new IllegalArgumentException("Missing urlPatterns.");
}
// 添加url路径映射
this.urlPatterns.addAll(Arrays.asList(urlPatterns));
// 返回冲突的url路径,本项目的简单实现不会检测url路径是否冲突,例如IndexServlet的映射路径"/index"与HelloServlet的映射路径"/index"不会被认定为冲突,默认程序员不会写这种代码
return Set.of();
}

@Override
public Collection<String> getMappings() {
return this.urlPatterns;
}

// 未实现的方法...

}

最后我们修改HttpConnector,实例化ServletContextImpl

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public class HttpConnector implements HttpHandler {
// 持有ServletContext实例:
final ServletContextImpl servletContext;
final HttpServer httpServer;

public HttpConnector() throws IOException {
// 创建ServletContext:
this.servletContext = new ServletContextImpl();
// 初始化Servlet:
this.servletContext.initialize(List.of(IndexServlet.class, HelloServlet.class));
...
}

@Override
public void handle(HttpExchange exchange) throws IOException {
var adapter = new HttpExchangeAdapter(exchange);
var request = new HttpServletRequestImpl(adapter);
var response = new HttpServletResponseImpl(adapter);
// process:
this.servletContext.process(request, response);
}
}

image-20231127132735702

运行服务器,输入http://localhost:8080/,查看IndexServlet的输出:

image-20231127132833374

输入http://localhost:8080/hello?name=Bob,查看HelloServlet的输出:

image-20231127132910628

输入错误的路径,存在IndexServlet的路径/默认处理所有不匹配的情况:

image-20231127133206980

可见,我们已经成功完成了ServletContext和所有Servlet的管理,并实现了正确的路由。

有的同学会问:Servlet本身应该是Web App开发人员实现,而不是由服务器实现。我们在服务器中却写死了两个Servlet,这显然是不合理的。正确的方式是从外部war包加载Servlet,但是这个问题我们放到后面解决。

二、实现FilterChain

上一节我们实现了ServletContext,并且能够管理所有的Servlet组件。本节我们继续增加对Filter组件的支持。

Filter是Servlet规范中的一个重要组件,它的作用是在HTTP请求到达Servlet之前进行预处理。它可以被一个或多个Filter按照一定的顺序组成一个处理链(FilterChain),用来处理一些公共逻辑,比如打印日志、登录检查等。

Filter还可以有针对性地拦截或者放行HTTP请求,本质上一个FilterChain就是一个责任链模式。在Servlet容器中,处理流程如下:

image-20231127133339508

这里有几点需要注意:

  1. 最终处理请求的Servlet是根据请求路径选择的;
  2. Filter链上的Filter是根据请求路径匹配的,可能匹配0个或多个Filter;
  3. 匹配的Filter将组成FilterChain进行调用。

下面,我们首先将Filter纳入ServletContext中管理。和ServletMapping类似,先定义FilterMapping,它包含一个Filter实例,以及将映射路径编译为正则表达式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
public class AbstractMapping implements Comparable<AbstractMapping> {

final Pattern pattern;
final String url;

public AbstractMapping(String urlPattern) {
this.url = urlPattern;
this.pattern = buildPattern(urlPattern);
}

public boolean matches(String uri) {
return pattern.matcher(uri).matches();
}

Pattern buildPattern(String urlPattern) {
StringBuilder sb = new StringBuilder(urlPattern.length() + 16);
// 正则表达式的开头
sb.append('^');
for (int i = 0; i < urlPattern.length(); i++) {
char ch = urlPattern.charAt(i);
if (ch == '*') {
// 表示匹配任意字符
sb.append(".*");
} else if (ch >= 'a' && ch <= 'z' || ch >= 'A' && ch <= 'Z' || ch >= '0' && ch <= '9') {
// 常规字符保持一致
sb.append(ch);
} else {
// 原字符转义,例如ch为&等价于正则表达式\&
sb.append('\\').append(ch);
}
}
// 正则表达式的结尾
sb.append('$');
return Pattern.compile(sb.toString());
}

@Override
public int compareTo(AbstractMapping o) {
int cmp = this.priority() - o.priority();
if (cmp == 0) {
cmp = this.url.compareTo(o.url);
}
return cmp;
}

// 满足URI路径的最长前缀匹配
int priority() {
if (this.url.equals("/")) {
return Integer.MAX_VALUE;
}
if (this.url.startsWith("*")) {
return Integer.MAX_VALUE - 1;
}
return 100000 - this.url.length();
}
}

public class FilterMapping extends AbstractMapping {

public final Filter filter;

public FilterMapping(String urlPattern, Filter filter) {
super(urlPattern);
this.filter = filter;
}
}

接着,根据Servlet规范,我们需要提供addFilter()动态添加一个Filter,并且返回FilterRegistration.Dynamic,所以需要在ServletContext中实现相关方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public class ServletContextImpl implements ServletContext {
Map<String, FilterRegistrationImpl> filterRegistrations = new HashMap<>();
List<FilterMapping> filterMappings = new ArrayList<>();

// 根据Class Name添加Filter:
@Override
public FilterRegistration.Dynamic addFilter(String name, String className) {
return addFilter(name, Class.forName(className));
}

// 根据Class添加Filter:
@Override
public FilterRegistration.Dynamic addFilter(String name, Class<? extends Filter> clazz) {
return addFilter(name, clazz.newInstance());
}

// 根据Filter实例添加Filter:
@Override
public FilterRegistration.Dynamic addFilter(String name, Filter filter) {
var registration = new FilterRegistrationImpl(this, name, filter);
this.filterRegistrations.put(name, registration);
return registration;
}
...
}

再添加一个initFilters()方法用于向容器添加Filter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
public class ServletContextImpl implements ServletContext {
...
public void initFilters(List<Class<?>> filterClasses) {
for (Class<?> c : filterClasses) {
// 获取@WebFilter注解:
WebFilter wf = c.getAnnotation(WebFilter.class);
// 添加Filter:
FilterRegistration.Dynamic registration = this.addFilter(AnnoUtils.getFilterName(clazz), clazz);
// 添加URL映射:
registration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, AnnoUtils.getFilterUrlPatterns(clazz));
// 设置初始化参数:
registration.setInitParameters(AnnoUtils.getFilterInitParams(clazz));
}
for (String name : this.filterRegistrations.keySet()) {
// 依次处理每个FilterRegistration.Dynamic:
var registration = this.filterRegistrations.get(name);
// 调用Filter.init()方法:
registration.filter.init(registration.getFilterConfig());
this.nameToFilters.put(name, registration.filter);
// 将Filter定义的每个URL映射编译为正则表达式:
for (String urlPattern : registration.getUrlPatternMappings()) {
this.filterMappings.add(new FilterMapping(urlPattern, registration.filter));
}
}
}
...
}

这样,我们就完成了对Filter组件的管理。

同样,我们介绍一下上面代码中出现的FilterRegistrationImpl类,与之前说过的ServeletRegistrationImpl非常相似:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
public class FilterRegistrationImpl implements FilterRegistration.Dynamic {
// 容器上下文
final ServletContext servletContext;
// Filter的名称
final String name;
// Filter的实例
public final Filter filter;
// Filter的初始化参数
final InitParameters initParameters = new InitParameters();
// 匹配路径
final List<String> urlPatterns = new ArrayList<>(4);
// Filter注册尚未结束
public boolean initialized = false;

public FilterRegistrationImpl(ServletContext servletContext, String name, Filter filter) {
this.servletContext = servletContext;
this.name = name;
this.filter = filter;
}

public FilterConfig getFilterConfig() {
return new FilterConfig() {
@Override
public String getFilterName() {
return FilterRegistrationImpl.this.name;
}

@Override
public ServletContext getServletContext() {
return FilterRegistrationImpl.this.servletContext;
}

// Filter组件支持initParameter参数功能
@Override
public String getInitParameter(String name) {
return FilterRegistrationImpl.this.initParameters.getInitParameter(name);
}

// Filter组件支持initParameter参数功能
@Override
public Enumeration<String> getInitParameterNames() {
return FilterRegistrationImpl.this.initParameters.getInitParameterNames();
}
};
}

@Override
public String getName() {
return this.name;
}

@Override
public String getClassName() {
return filter.getClass().getName();
}

// proxy to InitParameters:

@Override
public boolean setInitParameter(String name, String value) {
checkNotInitialized("setInitParameter");
return this.initParameters.setInitParameter(name, value);
}

@Override
public String getInitParameter(String name) {
return this.initParameters.getInitParameter(name);
}

@Override
public Set<String> setInitParameters(Map<String, String> initParameters) {
checkNotInitialized("setInitParameter");
return this.initParameters.setInitParameters(initParameters);
}

@Override
public Map<String, String> getInitParameters() {
return this.initParameters.getInitParameters();
}

@Override
public void setAsyncSupported(boolean isAsyncSupported) {
checkNotInitialized("setInitParameter");
if (isAsyncSupported) {
throw new UnsupportedOperationException("Async is not supported.");
}
}

@Override
public void addMappingForServletNames(EnumSet<DispatcherType> dispatcherTypes, boolean isMatchAfter, String... servletNames) {
throw new UnsupportedOperationException("addMappingForServletNames");
}

@Override
public void addMappingForUrlPatterns(EnumSet<DispatcherType> dispatcherTypes, boolean isMatchAfter, String... urlPatterns) {
checkNotInitialized("addMappingForUrlPatterns");
if (!dispatcherTypes.contains(DispatcherType.REQUEST) || dispatcherTypes.size() != 1) {
throw new IllegalArgumentException("Only support DispatcherType.REQUEST.");
}
if (urlPatterns == null || urlPatterns.length == 0) {
throw new IllegalArgumentException("Missing urlPatterns.");
}
for (String urlPattern : urlPatterns) {
this.urlPatterns.add(urlPattern);
}
}

@Override
public Collection<String> getServletNameMappings() {
return List.of();
}

@Override
public Collection<String> getUrlPatternMappings() {
return this.urlPatterns;
}

private void checkNotInitialized(String name) {
if (this.initialized) {
throw new IllegalStateException("Cannot call " + name + " after initialization.");
}
}
}

下一步,是改造process()方法,把原来直接把请求扔给Servlet处理,改成先匹配Filter,处理后再扔给最终的Servlet

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
public class ServletContextImpl implements ServletContext {
...
public void process(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
// 获取请求路径:
String path = request.getRequestURI();
// 查找Servlet:
Servlet servlet = null;
for (ServletMapping mapping : this.servletMappings) {
if (mapping.matches(path)) {
servlet = mapping.servlet;
break;
}
}
if (servlet == null) {
// 404错误:
PrintWriter pw = response.getWriter();
pw.write("<h1>404 Not Found</h1><p>No mapping for URL: " + path + "</p>");
pw.close();
return;
}
// 查找Filter:
List<Filter> enabledFilters = new ArrayList<>();
for (FilterMapping mapping : this.filterMappings) {
if (mapping.matches(path)) {
enabledFilters.add(mapping.filter);
}
}
Filter[] filters = enabledFilters.toArray(Filter[]::new);
// 构造FilterChain实例:
FilterChain chain = new FilterChainImpl(filters, servlet);
// 由FilterChain处理:
chain.doFilter(request, response);
}
...
}

注意上述FilterChain不仅包含一个Filter[]数组,还包含一个Servlet,这样我们调用chain.doFilter()时,在FilterChain中最后一个处理请求的就是Servlet,这样设计可以简化我们实现FilterChain的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public class FilterChainImpl implements FilterChain {
final Filter[] filters;
final Servlet servlet;
final int total; // Filter总数量
int index = 0; // 下一个要处理的Filter[index]

public FilterChainImpl(Filter[] filters, Servlet servlet) {
this.filters = filters;
this.servlet = servlet;
this.total = filters.length;
}

@Override
public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
if (index < total) {
int current = index;
index++;
// 调用下一个Filter处理:
filters[current].doFilter(request, response, this);
} else {
// 调用Servlet处理:
servlet.service(request, response);
}
}
}

注意FilterChain是一个递归调用,因为在执行Filter.doFilter()时,需要把FilterChain自身传进去,在执行Filter.doFilter()之前,就要把index调整到正确的值。

我们编写两个测试用的Filter:

  • LogFilter:匹配/*,打印请求方法、路径等信息;
  • HelloFilter:匹配/hello,根据请求参数决定放行还是返回403错误。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
@WebFilter(urlPatterns = "/*")
public class LogFilter implements Filter {

final Logger logger = LoggerFactory.getLogger(getClass());

@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
HttpServletRequest req = (HttpServletRequest) request;
logger.info("{}: {}", req.getMethod(), req.getRequestURI());
chain.doFilter(request, response);
}
}

@WebFilter(urlPatterns = "/hello")
public class HelloFilter implements Filter {

final Logger logger = LoggerFactory.getLogger(getClass());
Set<String> names = Set.of("Bob", "Alice", "Tom", "Jerry");

@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
HttpServletRequest req = (HttpServletRequest) request;
String name = req.getParameter("name");
logger.info("Check parameter name = {}", name);
if (name != null && names.contains(name)) {
chain.doFilter(request, response);
} else {
logger.warn("Access denied: name = {}", name);
HttpServletResponse resp = (HttpServletResponse) response;
resp.sendError(403, "Forbidden");
}
}
}

image-20231127140913723

在初始化ServletContextImpl时将Filter加进去,先测试http://localhost:8080/

image-20231127140954473

观察后台输出,LogFilter应该起作用:

image-20231127141041731

再测试http://localhost:8080/hello?name=Bob

image-20231127141115190

观察后台输出,HelloFilterLogFilter应该起作用:

image-20231127141201376

最后测试http://localhost:8080/hello?name=Jim

image-20231127141308161

可以看到,HelloFilter拦截了请求,返回403错误,最终的HelloServlet并没有处理该请求。

现在,我们就成功地在ServletContext中实现了对Filter的管理,以及根据每个请求,构造对应的FilterChain来处理请求。目前还有几个小问题:

一是和Servlet一样,Filter本身应该是Web App开发人员实现,而不是由服务器实现。我们在在服务器中写死了两个Filter,这个问题后续解决;

二是Servlet规范并没有规定多个Filter应该如何排序,我们在实现时也没有对Filter进行排序。如果要按固定顺序给Filter排序,从Servlet规范来说怎么排序都可以,通常是按@WebFilter定义的filterName进行排序,Spring Boot提供的一个FilterRegistrationBean允许开发人员自己定义Filter的顺序。

三、实现HttpSession

HttpSession是Java Web App的一种机制,用于在客户端和服务器之间维护会话状态信息。

当客户端第一次请求Web应用程序时,服务器会为该客户端创建一个唯一的Session ID,该ID本质上是一个随机字符串,然后,将该ID存储在客户端的一个名为JSESSIONID的Cookie中。与此同时,服务器会在内存中创建一个HttpSession对象,与Session ID关联,用于存储与该客户端相关的状态信息。

当客户端发送后续请求时,服务器根据客户端发送的名为JSESSIONID的Cookie中获得Session ID,然后查找对应的HttpSession对象,并从中读取或继续写入状态信息。

Session主要用于维护一个客户端的会话状态。通常,用户成功登录后,可以通过如下代码创建一个新的HttpSession,并将用户ID、用户名等信息放入HttpSession

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@WebServlet(urlPatterns = "/login")
public class LoginServlet extends HttpServlet {
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
String username = req.getParameter("username");
String password = req.getParameter("password");
if (loginOk(username, password)) {
// 登录成功,获取Session:
HttpSession session = req.getSession();
// 将用户名放入Session:
session.setAttribute("username", username);
// 返回首页:
resp.sendRedirect("/");
} else {
// 登录失败:
resp.sendRedirect("/error");
}
}
}

在其他页面,可以随时获取HttpSession并取出用户信息,然后在页面展示给用户:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@WebServlet(urlPatterns = "/")
public class IndexServlet extends HttpServlet {
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
// 获取Session:
HttpSession session = req.getSession();
// 从Session中取出用户名:
String username = (String) session.getAttribute("username");
if (username == null) {
// 未获取到用户名,说明未登录:
resp.sendRedirect("/login");
} else {
// 获取到用户名,说明已登录:
String html = "<p>Welcome, " + username + "!</p>";
resp.setContentType("text/html");
PrintWriter pw = resp.getWriter();
pw.write(html);
pw.close();
}
}
}

当用户登出时,需要调用HttpSessioninvalidate()方法,让会话失效,这样,用户将重新回到未登录状态,因为后续调用req.getSession()将返回一个新的HttpSession,从这个新的HttpSession取出的username将是null

HttpSession的生命周期如下所示:

  • 第一次调用req.getSession()时,服务器会为该客户端创建一个新的HttpSession对象;

  • 后续调用req.getSession()时,服务器会返回与之关联的HttpSession对象;

  • 调用req.getSession().invalidate()时,服务器会销毁该客户端对应的HttpSession对象;

  • 当客户端一段时间内没有新的请求,服务器会根据Session超时自动销毁超时的HttpSession对象。

HttpSession是一个接口,Java的Web应用调用HttpServletRequestgetSession()方法时,需要返回一个HttpSession的实现类。

了解了以上关于HttpSession的相关规范后,我们就可以开始实现对HttpSession的支持。

首先,我们需要一个SessionManager,用来管理所有的Session:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
public class SessionManager implements Runnable {
final Logger logger = LoggerFactory.getLogger(getClass());
// 引用ServletContext:
final ServletContextImpl servletContext;
// 持有SessionID -> Session:
final Map<String, HttpSessionImpl> sessions = new ConcurrentHashMap<>();
// Session默认过期时间(秒):
final int inactiveInterval;

public SessionManager(ServletContextImpl servletContext, int interval) {
this.servletContext = servletContext;
this.inactiveInterval = interval;
Thread t = new Thread(this, "Session-Cleanup-Thread");
// Session过期清理线程是守护线程
t.setDaemon(true);
t.start();
}

// 根据SessionID获取一个Session:
public HttpSession getSession(String sessionId) {
HttpSessionImpl session = sessions.get(sessionId);
if (session == null) {
// Session未找到,创建一个新的Session:
session = new HttpSessionImpl(this.servletContext, sessionId, inactiveInterval);
sessions.put(sessionId, session);
} else {
// Session已存在,更新最后访问时间:
session.lastAccessedTime = System.currentTimeMillis();
}
return session;
}

// 删除Session:
public void remove(HttpSession session) {
this.sessions.remove(session.getId());
}

@Override
public void run() {
for (;;) {
try {
// 每隔一分钟扫描一次,清理过期session
Thread.sleep(60_000L);
} catch (InterruptedException e) {
break;
}
long now = System.currentTimeMillis();
for (String sessionId : sessions.keySet()) {
HttpSession session = sessions.get(sessionId);
if (session.getLastAccessedTime() + session.getMaxInactiveInterval() * 1000L < now) {
logger.warn("remove expired session: {}, last access time: {}", sessionId, DateUtils.formatDateTimeGMT(session.getLastAccessedTime()));
session.invalidate();
}
}
}
}
}

SessionManagerServletContextImpl持有唯一实例。

再编写一个HttpSession的实现类HttpSessionImpl

1
2
3
4
5
6
7
8
public class HttpSessionImpl implements HttpSession {
ServletContextImpl servletContext; // ServletContext
String sessionId; // SessionID
int maxInactiveInterval; // 过期时间(s)
long creationTime; // 创建时间(ms)
long lastAccessedTime; // 最后一次访问时间(ms)
Attributes attributes; // getAttribute/setAttribute
}

然后,我们分析一下用户调用Session的代码:

1
2
HttpSession session = request.getSession();
session.invalidate();

由于HttpSession是从HttpServletRequest获得的,因此,必须在HttpServletRequestImpl中引用ServletContextImpl,才能访问SessionManager

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
public class HttpServletRequestImpl implements HttpServletRequest {
// 引用ServletContextImpl:
ServletContextImpl servletContext;
// 引用HttpServletResponse:
HttpServletResponse response;
// 请求头信息:
final HttpHeaders headers;
// 请求参数信息:
final Parameters parameters;
// 记录request对象的getInputStream方法和getReader方法的使用情况:
Boolean inputCalled = null;

public HttpServletRequestImpl(ServletContextImpl servletContext, HttpExchangeRequest exchangeRequest, HttpServletResponse response) {
this.servletContext = servletContext;
this.exchangeRequest = exchangeRequest;
this.response = response;
this.headers = new HttpHeaders(exchangeRequest.getRequestHeaders());
this.parameters = new Parameters(exchangeRequest, "UTF-8");
}

@Override
public ServletInputStream getInputStream() throws IOException {
// 字节输入流只能调用一次,不能重复打开
if (this.inputCalled == null) {
this.inputCalled = Boolean.TRUE;
// ServletInputStreamImpl内部维持了输入字节缓冲区和读取位置
return new ServletInputStreamImpl(this.exchangeRequest.getRequestBody());
}
throw new IllegalStateException("Cannot reopen input stream after " + (this.inputCalled ? "getInputStream()" : "getReader()") + " was called.");
}

@Override
public BufferedReader getReader() throws IOException {
// 字节输入流只能调用一次,不能重复打开
if (this.inputCalled == null) {
this.inputCalled = Boolean.FALSE;
// 返回请求的字符输入流
return new BufferedReader(new InputStreamReader(new ByteArrayInputStream(this.exchangeRequest.getRequestBody()), StandardCharsets.UTF_8));
}
throw new IllegalStateException("Cannot reopen input stream after " + (this.inputCalled ? "getInputStream()" : "getReader()") + " was called.");
}

@Override
public HttpSession getSession(boolean create) {
String sessionId = null;
// 获取所有Cookie:
Cookie[] cookies = getCookies();
if (cookies != null) {
// 查找JSESSIONID:
for (Cookie cookie : cookies) {
if ("JSESSIONID".equals(cookie.getName())) {
// 拿到Session ID:
sessionId = cookie.getValue();
break;
}
}
}
// 未获取到SessionID,且create=false,返回null:
if (sessionId == null && !create) {
return null;
}
// 未获取到SessionID,但create=true,创建新的Session:
if (sessionId == null) {
// 如果Header已经发送,则无法创建Session,因为无法添加Cookie:
if (this.response.isCommitted()) {
throw new IllegalStateException("Cannot create session for response is commited.");
}
// 创建随机字符串作为SessionID:
sessionId = UUID.randomUUID().toString();
// 构造一个名为JSESSIONID的Cookie:
String cookieValue = "JSESSIONID=" + sessionId + "; Path=/; SameSite=Strict; HttpOnly";
// 添加到HttpServletResponse的Header:
this.response.addHeader("Set-Cookie", cookieValue);
}
// 返回一个Session对象:
return this.servletContext.sessionManager.getSession(sessionId);
}

@Override
public HttpSession getSession() {
return getSession(true);
}

@Override
public Cookie[] getCookies() {
// 获取请求头中Cookie对应的字符串信息
String cookieValue = this.getHeader("Cookie");
// 解析字符串
return HttpUtils.parseCookies(cookieValue);
}

@Override
public String getHeader(String name) {
return this.headers.getHeader(name);
}


...
}

HttpServletRequestImpl的改造主要是加入了ServletContextImplHttpServletResponse的引用:可以通过前者访问到SessionManager,而创建的新的SessionID需要通过后者把Cookie发送到客户端,因此,在HttpConnector中,做相应的修改如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public class HttpConnector implements HttpHandler {
// 过期时间设置为600秒
final SessionManager sessionManager = new SessionManager(this, 600);
...
@Override
public void handle(HttpExchange exchange) throws IOException {
var adapter = new HttpExchangeAdapter(exchange);
var response = new HttpServletResponseImpl(adapter);
// 创建Request时,需要引用servletContext和response:
var request = new HttpServletRequestImpl(this.servletContext, adapter, response);
// process:
try {
this.servletContext.process(request, response);
} catch (Exception e) {
logger.error(e.getMessage(), e);
}
}
}

当用户调用session.invalidate()时,要让Session失效,就需要从SessionManager中移除:

1
2
3
4
5
6
7
8
9
10
public class HttpSessionImpl implements HttpSession {
...
@Override
public void invalidate() {
// 从SessionManager中移除:
this.servletContext.sessionManager.remove(this);
this.sessionId = null;
}
...
}

最后,我们还需要实现Session的自动过期。由于我们管理的Session实际上是以Map<String, HttpSession>存储的,所以,让Session自动过期就是定期扫描所有的Session,然后根据最后一次访问时间将过期的Session自动删除。给SessionManager加一个Runnable接口,并启动一个Daemon线程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
public class SessionManager implements Runnable {
...
public SessionManager(ServletContextImpl servletContext, int interval) {
...
// 启动Daemon线程:
Thread t = new Thread(this);
t.setDaemon(true);
t.start();
}

// 扫描线程:
@Override
public void run() {
for (;;) {
// 每60秒扫描一次:
try {
Thread.sleep(60_000L);
} catch (InterruptedException e) {
break;
}
// 当前时间:
long now = System.currentTimeMillis();
// 遍历Session:
for (String sessionId : sessions.keySet()) {
HttpSession session = sessions.get(sessionId);
// 判断是否过期:
if (session.getLastAccessedTime() + session.getMaxInactiveInterval() * 1000L < now) {
// 删除过期的Session:
logger.warn("remove expired session: {}, last access time: {}", sessionId, DateUtils.formatDateTimeGMT(session.getLastAccessedTime()));
session.invalidate();
}
}
}
}
}

HttpServletRequestHttpServletResponse与Cookie相关的实现方法补全,我们就得到了一个基于Cookie的HttpSession实现!

前面介绍过HttpServletRequest,那我们还是说一下HttpServletResponse补全后的内容吧,总体内容跟Request类似,不过Response的字节输出流和字符输出流是可以多次打开的,而且在发送响应前JSESSION对应的Cookie必须先设置好,不能一旦发送响应开始后便不能修改响应头部信息了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
public class HttpServletResponseImpl implements HttpServletResponse {
final HttpExchangeResponse exchangeResponse;
// 响应头,包含Cookie信息
final HttpHeaders headers;
// 默认的响应码
int status = 200;
int bufferSize = 1024;
// 记录response的getOutputStream方法和getWriter方法的使用情况
Boolean callOutput = null;
// 字节输出流
ServletOutputStream output;
// 字符输出流
PrintWriter writer;
// 响应内容类型,默认是text/html
String contentType;
// 响应内容大小
long contentLength = 0;
// 是否已发送响应的标记
boolean committed = false;

public HttpServletResponseImpl(HttpExchangeResponse exchangeResponse) {
this.exchangeResponse = exchangeResponse;
this.headers = new HttpHeaders(exchangeResponse.getResponseHeaders());
this.setContentType("text/html");
}

@Override
public ServletOutputStream getOutputStream() throws IOException {
// 第一次调用
if (callOutput == null) {
// 开始发生响应,采用分块传输
commitHeaders(0);
// ServletOutputStreamImpl维护一个字节输出流
this.output = new ServletOutputStreamImpl(this.exchangeResponse.getResponseBody());
this.callOutput = Boolean.TRUE;
return this.output;
}
// 下一次调用直接使用之前已经创建好的输出流
if (callOutput) {
return this.output;
}
throw new IllegalStateException("Cannot open output stream when writer is opened.");
}

@Override
public PrintWriter getWriter() throws IOException {
// 第一次调用
if (callOutput == null) {
// 开始发生响应,采用分块传输
commitHeaders(0);
// 字符输出流
this.writer = new PrintWriter(this.exchangeResponse.getResponseBody(), true, StandardCharsets.UTF_8);
this.callOutput = Boolean.FALSE;
return this.writer;
}
// 下一次调用直接使用之前已经创建好的输出流
if (!callOutput) {
return this.writer;
}
throw new IllegalStateException("Cannot open writer when output stream is opened.");
}

void commitHeaders(long length) throws IOException {
this.exchangeResponse.sendResponseHeaders(this.status, length);
this.committed = true;
}

void checkNotCommitted() {
if (this.committed) {
throw new IllegalStateException("Response is committed.");
}
}

@Override
public void sendError(int sc, String msg) throws IOException {
checkNotCommitted();
this.status = sc;
// 发送错误响应,不采用分块传输
commitHeaders(-1);
}

@Override
public void sendError(int sc) throws IOException {
sendError(sc, "Error");
}

@Override
public void sendRedirect(String location) throws IOException {
checkNotCommitted();
// 设置临时重定向
this.status = 302;
this.headers.setHeader("Location", location);
// 发送错误响应,不采用分块传输
commitHeaders(-1);
}

@Override
public void addHeader(String name, String value) {
checkNotCommitted();
this.headers.addHeader(name, value);
}
}

最后需要注意的一点是,和HttpServletRequest不同,访问HttpServletRequest实例的一定是一个线程,因此,HttpServletRequestgetAttribute()setAttribute()不需要同步,底层存储用HashMap即可。但是,访问HttpSession实例的可能是多线程,所以,HttpSessiongetAttribute()setAttribute()需要实现并发访问,底层存储用ConcurrentHashMap即可。

image-20231127152238495

访问IndexServlet,第一次访问时,将获取到新的HttpSession,此时,HttpSession没有用户信息,因此显示登录表单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@WebServlet(urlPatterns = "/")
public class IndexServlet extends HttpServlet {

@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
HttpSession session = req.getSession();
String username = (String) session.getAttribute("username");
String html;
if (username == null) {
html = "<h1>Index Page</h1><form method=\"post\" action=\"/login\"><legend>Please Login</legend><p>User Name: <input type=\"text\" name=\"username\"></p><p>Password: <input type=\"password\" name=\"password\"></p><p><button type=\"submit\">Login</button></p></form>";
} else {
html = "<h1>Index Page</h1><p>Welcome, {username}!</p><p><a href=\"/logout\">Logout</a></p>".replace("{username}", username);
}
resp.setContentType("text/html");
PrintWriter pw = resp.getWriter();
pw.write(html);
pw.close();
}
}

@WebServlet(urlPatterns = "/login")
public class LoginServlet extends HttpServlet {

Map<String, String> users = Map.of( // user database
"bob", "bob123", //
"alice", "alice123", //
"root", "admin123" //
);

@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
String username = req.getParameter("username");
String password = req.getParameter("password");
String expectedPassword = users.get(username.toLowerCase());
if (expectedPassword == null || !expectedPassword.equals(password)) {
PrintWriter pw = resp.getWriter();
pw.write("<h1>Login Failed</h1><p>Invalid username or password.</p><p><a href=\"/\">Try again</a></p>");
pw.close();
} else {
req.getSession().setAttribute("username", username);
resp.sendRedirect("/");
}
}
}

@WebServlet(urlPatterns = "/logout")
public class LogoutServlet extends HttpServlet {

@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
HttpSession session = req.getSession();
session.invalidate();
resp.sendRedirect("/");
}
}

image-20231127152338332

登录成功后,可以看到用户名已放入HttpSessionIndexServletHttpSession获取到用户名后将用户名显示出来:

image-20231127152509584

刷新页面,IndexServlet仍将显示登录的用户名,因为根据Cookie拿到相同的SessionID后,获取的HttpSession是同一个实例。由于我们设定的HttpSession过期时间是10分钟,等待至少10分钟,观察控制台输出:

image-20231127154211169

大约在15:25:33时清理了过期的Session,最后一次访问时间是15:36:14(注意时间需要经过时区调整),再次刷新页面将显示登录表单:

image-20231127154049413

由于没有对HttpSession进行持久化处理,重启服务器后,将丢失所有用户的Session。如果希望重启服务器后保留用户的Session,则需要将Session数据持久化到文件或数据库,此功能要求用户放入HttpSession的Java对象必须是可序列化的;

因为Session不容易扩展,因此,大规模集群的Web App通常自己管理Cookie来实现登录功能,这样,将用户状态完全保存在浏览器端,不使用Session,服务器就可以做到无状态集群。

四、实现Listener

在Java Web App中,除了Servlet、Filter和HttpSession外,还有一种Listener组件,用于事件监听。

Listener是Java Web App中的一种事件监听机制,用于监听Web应用程序中产生的事件,例如,在ServletContext初始化完成后,会触发contextInitialized事件,实现了ServletContextListener接口的Listener就可以接收到事件通知,可以在内部做一些初始化工作,如加载配置文件,初始化数据库连接池等。实现了HttpSessionListener接口的Listener可以接收到Session的创建和消耗事件,这样就可以统计在线用户数。

Listener机制是基于观察者模式实现的,即当某个事件发生时,Listener会接收到通知并执行相应的操作。

Servlet规范定义了很多种Listener接口,常用的Listener包括:

  • ServletContextListener:用于监听ServletContext的创建和销毁事件;
  • HttpSessionListener:用于监听HttpSession的创建和销毁事件;
  • ServletRequestListener:用于监听ServletRequest的创建和销毁事件;
  • ServletContextAttributeListener:用于监听ServletContext属性的添加、修改和删除事件;
  • HttpSessionAttributeListener:用于监听HttpSession属性的添加、修改和删除事件;
  • ServletRequestAttributeListener:用于监听ServletRequest属性的添加、修改和删除事件。

下面我们就来实现上述常用的Listener。

首先我们需要在ServletContextImpl中注册并管理所有的Listener,所以用不同的List持有注册的Listener:

1
2
3
4
5
6
7
8
9
10
public class ServletContextImpl implements ServletContext {
...
private List<ServletContextListener> servletContextListeners = null;
private List<ServletContextAttributeListener> servletContextAttributeListeners = null;
private List<ServletRequestListener> servletRequestListeners = null;
private List<ServletRequestAttributeListener> servletRequestAttributeListeners = null;
private List<HttpSessionAttributeListener> httpSessionAttributeListeners = null;
private List<HttpSessionListener> httpSessionListeners = null;
...
}

然后,实现ServletContextaddListener()接口,用于注册Listener:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public class ServletContextImpl implements ServletContext {
...
@Override
public void addListener(String className) {
addListener(Class.forName(className));
}

@Override
public void addListener(Class<? extends EventListener> clazz) {
addListener(clazz.newInstance());
}

@Override
public <T extends EventListener> void addListener(T t) {
// 根据Listener类型放入不同的List:
if (t instanceof ServletContextListener listener) {
if (this.servletContextListeners == null) {
this.servletContextListeners = new ArrayList<>();
}
this.servletContextListeners.add(listener);
} else if (t instanceof ServletContextAttributeListener listener) {
if (this.servletContextAttributeListeners == null) {
this.servletContextAttributeListeners = new ArrayList<>();
}
this.servletContextAttributeListeners.add(listener);
} else if ...
...代码略...
} else {
throw new IllegalArgumentException("Unsupported listener: " + t.getClass().getName());
}
}
...
}

接下来,就是在合适的时机,触发这些Listener。以ServletContextAttributeListener为例,统一触发的方法放在ServletContextImpl中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public class ServletContextImpl implements ServletContext {
...
void invokeServletContextAttributeAdded(String name, Object value) {
logger.info("invoke ServletContextAttributeAdded: {} = {}", name, value);
if (this.servletContextAttributeListeners != null) {
var event = new ServletContextAttributeEvent(this, name, value);
for (var listener : this.servletContextAttributeListeners) {
listener.attributeAdded(event);
}
}
}

void invokeServletContextAttributeRemoved(String name, Object value) {
logger.info("invoke ServletContextAttributeRemoved: {} = {}", name, value);
if (this.servletContextAttributeListeners != null) {
var event = new ServletContextAttributeEvent(this, name, value);
for (var listener : this.servletContextAttributeListeners) {
listener.attributeRemoved(event);
}
}
}

void invokeServletContextAttributeReplaced(String name, Object value) {
logger.info("invoke ServletContextAttributeReplaced: {} = {}", name, value);
if (this.servletContextAttributeListeners != null) {
var event = new ServletContextAttributeEvent(this, name, value);
for (var listener : this.servletContextAttributeListeners) {
listener.attributeReplaced(event);
}
}
}
...
}

当Web App的任何组件调用ServletContextsetAttribute()removeAttribute()时,就可以触发事件通知:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
public class ServletContextImpl implements ServletContext {
...
@Override
public void setAttribute(String name, Object value) {
if (value == null) {
removeAttribute(name);
} else {
Object old = this.attributes.setAttribute(name, value);
if (old == null) {
// 触发attributeAdded:
this.invokeServletContextAttributeAdded(name, value);
} else {
// 触发attributeReplaced:
this.invokeServletContextAttributeReplaced(name, value);
}
}
}

@Override
public void removeAttribute(String name) {
Object old = this.attributes.removeAttribute(name);
// 触发attributeRemoved:
this.invokeServletContextAttributeRemoved(name, old);
}
...
}

其他事件触发也是类似的写法,此处不再重复。Servlet规范定义了各种Listener组件,我们支持了其中常用的大部分EventListener组件;Listener组件由ServletContext统一管理,并提供统一调度入口方法;通知机制允许多线程同时调用,如果要防止并发调用Listener的回调方法,需要Listener组件本身在内部做好同步。

image-20231127155242546

为了测试Listener机制是否生效,我们还需要先编写不同类型的Listener,例如,HelloHttpSessionAttributeListener实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
@WebListener
public class HelloHelloHttpSessionAttributeListener implements HttpSessionAttributeListener {

final Logger logger = LoggerFactory.getLogger(getClass());

@Override
public void attributeAdded(HttpSessionBindingEvent event) {
logger.info(">>> HttpSession attribute added: {} = {}", event.getName(), event.getValue());
}

@Override
public void attributeRemoved(HttpSessionBindingEvent event) {
logger.info(">>> HttpSession attribute removed: {} = {}", event.getName(), event.getValue());
}

@Override
public void attributeReplaced(HttpSessionBindingEvent event) {
logger.info(">>> HttpSession attribute replaced: {} = {}", event.getName(), event.getValue());
}
}

然后在HttpConnector中注册所有的Listener:

1
2
3
4
List<Class<? extends EventListener>> listenerClasses = List.of(HelloHttpSessionAttributeListener.class, ...);
for (Class<? extends EventListener> listenerClass : listenerClasses) {
this.servletContext.addListener(listenerClass);
}

启动服务器,在浏览器中登录或登出,观察日志输出,在每个请求处理前后,可以看到ServletRequestListener的创建和销毁事件:

1
2
3
08:58:23.944 [HTTP-Dispatcher] INFO  c.i.j.e.l.HelloServletRequestListener -- >>> ServletRequest initialized: HttpServletRequestImpl@71a49a97[GET:/]
...
08:58:24.008 [HTTP-Dispatcher] INFO c.i.j.e.l.HelloServletRequestListener -- >>> ServletRequest destroyed: HttpServletRequestImpl@71a49a97[GET:/]

在第一次访问页面和登出时,可以看到HttpSessionListener的创建和销毁事件:

1
2
3
08:58:23.947 [HTTP-Dispatcher] INFO  c.i.j.e.l.HelloHttpSessionListener -- >>> HttpSession created: com.itranswarp.jerrymouse.engine.HttpSessionImpl@15037a31
...
08:58:36.766 [HTTP-Dispatcher] INFO c.i.j.e.l.HelloHttpSessionListener -- >>> HttpSession destroyed: com.itranswarp.jerrymouse.engine.HttpSessionImpl@15037a31

其他事件的触发也可以在日志中找到,这说明我们成功地实现了Servlet规范的Listener机制。