|
|
1
|
+package com.yoho.unions.interceptor;
|
|
|
2
|
+
|
|
|
3
|
+import com.google.common.collect.ImmutableList;
|
|
|
4
|
+import com.google.common.collect.UnmodifiableListIterator;
|
|
|
5
|
+import com.yoho.core.common.utils.MD5;
|
|
|
6
|
+import com.yoho.error.exception.ServiceException;
|
|
|
7
|
+import com.yoho.unions.exception.RequestHeaderInvalidateException;
|
|
|
8
|
+import com.yoho.unions.exception.SecurityNotMatchException;
|
|
|
9
|
+import org.apache.commons.lang.StringUtils;
|
|
|
10
|
+import org.slf4j.Logger;
|
|
|
11
|
+import org.slf4j.LoggerFactory;
|
|
|
12
|
+import org.springframework.web.servlet.HandlerInterceptor;
|
|
|
13
|
+import org.springframework.web.servlet.ModelAndView;
|
|
|
14
|
+
|
|
|
15
|
+import javax.servlet.http.HttpServletRequest;
|
|
|
16
|
+import javax.servlet.http.HttpServletResponse;
|
|
|
17
|
+import java.util.*;
|
|
|
18
|
+
|
|
|
19
|
+/**
|
|
|
20
|
+ * Created by xinfei on 16/2/20.
|
|
|
21
|
+ */
|
|
|
22
|
+public class SecurityInterceptor implements HandlerInterceptor {
|
|
|
23
|
+
|
|
|
24
|
+ private final Logger logger = LoggerFactory.getLogger(SecurityInterceptor.class);
|
|
|
25
|
+
|
|
|
26
|
+
|
|
|
27
|
+ //读取配置文件中的private key 配置
|
|
|
28
|
+ private final Map<String, String> privateKeyMap = new HashMap<>();
|
|
|
29
|
+
|
|
|
30
|
+ // 这些url不会进行client-security校验。 例如 "/notify"
|
|
|
31
|
+ private List<String> excludeUrls;
|
|
|
32
|
+
|
|
|
33
|
+ //这些方法不用校验
|
|
|
34
|
+ private List<String> excludeMethods;
|
|
|
35
|
+
|
|
|
36
|
+ //这些IP下的这些方法不校验
|
|
|
37
|
+ private Map<String /*method*/, String /*ip*/> excludeMethodsBySubnet;
|
|
|
38
|
+ //是否启用
|
|
|
39
|
+ private boolean isDebugEnable = false;
|
|
|
40
|
+
|
|
|
41
|
+ @Override
|
|
|
42
|
+ public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
|
|
|
43
|
+ //(1)排除掉exclude和debug模式
|
|
|
44
|
+ if (this.isIgnore(request)) {
|
|
|
45
|
+ return true;
|
|
|
46
|
+ }
|
|
|
47
|
+
|
|
|
48
|
+ //(2)获取请求参数的信息
|
|
|
49
|
+ Map<String, String> params = this.getRequestInfo(request);
|
|
|
50
|
+
|
|
|
51
|
+ //(3)验证请求参数中是否包含必填参数, 如果不包含则请求失败(联盟暂时只做 client_secret 的校验)
|
|
|
52
|
+ this.validateReqParams(params);
|
|
|
53
|
+
|
|
|
54
|
+ //(4)校验安全码是否正确
|
|
|
55
|
+ this.validateSecurity(params);
|
|
|
56
|
+
|
|
|
57
|
+ return true;
|
|
|
58
|
+ }
|
|
|
59
|
+
|
|
|
60
|
+ @Override
|
|
|
61
|
+ public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
|
|
|
62
|
+
|
|
|
63
|
+ }
|
|
|
64
|
+
|
|
|
65
|
+ @Override
|
|
|
66
|
+ public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
|
|
|
67
|
+
|
|
|
68
|
+ }
|
|
|
69
|
+
|
|
|
70
|
+
|
|
|
71
|
+ /**
|
|
|
72
|
+ * 验证请求参数中的必填参数
|
|
|
73
|
+ *
|
|
|
74
|
+ * @param params 请求参数
|
|
|
75
|
+ * @throws ServiceException
|
|
|
76
|
+ */
|
|
|
77
|
+ private void validateReqParams(Map<String, String> params) throws RequestHeaderInvalidateException {
|
|
|
78
|
+ ImmutableList must_exist_params = ImmutableList.of("client_secret");
|
|
|
79
|
+ UnmodifiableListIterator<String> it = must_exist_params.listIterator();
|
|
|
80
|
+ while (it.hasNext()) {
|
|
|
81
|
+ String k = it.next();
|
|
|
82
|
+ String headerValue = params.get(k);
|
|
|
83
|
+ if (StringUtils.isEmpty(headerValue)) {
|
|
|
84
|
+ logger.warn("header {} not exist or empty", k);
|
|
|
85
|
+ throw new RequestHeaderInvalidateException(k);
|
|
|
86
|
+ }
|
|
|
87
|
+ }
|
|
|
88
|
+ }
|
|
|
89
|
+
|
|
|
90
|
+ /**
|
|
|
91
|
+ * 验证请求参数的client_secret是否正确
|
|
|
92
|
+ *
|
|
|
93
|
+ * @param reqParams
|
|
|
94
|
+ * @throws SecurityNotMatchException
|
|
|
95
|
+ */
|
|
|
96
|
+ private void validateSecurity(Map<String, String> reqParams) throws SecurityNotMatchException {
|
|
|
97
|
+ //根据请求参数生成加密码
|
|
|
98
|
+ String caculated_sign = this.getSign(reqParams);
|
|
|
99
|
+ //获取前台传入的client_secret, 比较参数是否一致
|
|
|
100
|
+ String request_sign = reqParams.get("client_secret");
|
|
|
101
|
+ if (!request_sign.equalsIgnoreCase(caculated_sign)) {
|
|
|
102
|
+ logger.warn("client security not match. request_sign:{}, caculate_sign:{}", request_sign, caculated_sign);
|
|
|
103
|
+ throw new SecurityNotMatchException();
|
|
|
104
|
+ }
|
|
|
105
|
+ }
|
|
|
106
|
+
|
|
|
107
|
+
|
|
|
108
|
+ /**
|
|
|
109
|
+ * 获取请求参数信息: requestParam
|
|
|
110
|
+ *
|
|
|
111
|
+ * @param httpServletRequest
|
|
|
112
|
+ * @return Map<String, String> 请求参数的键-值
|
|
|
113
|
+ */
|
|
|
114
|
+ private Map<String, String> getRequestInfo(HttpServletRequest httpServletRequest) {
|
|
|
115
|
+ Map<String, String> map = new HashMap<>();
|
|
|
116
|
+ Enumeration paramNames = httpServletRequest.getParameterNames();
|
|
|
117
|
+ while (paramNames.hasMoreElements()) {
|
|
|
118
|
+ String key = (String) paramNames.nextElement();
|
|
|
119
|
+ String value = httpServletRequest.getParameter(key);
|
|
|
120
|
+ map.put(key, value);
|
|
|
121
|
+ }
|
|
|
122
|
+ return map;
|
|
|
123
|
+ }
|
|
|
124
|
+
|
|
|
125
|
+ /**
|
|
|
126
|
+ * 根据请求的参数生成client_secret
|
|
|
127
|
+ *
|
|
|
128
|
+ * @param reqParams
|
|
|
129
|
+ * @return
|
|
|
130
|
+ */
|
|
|
131
|
+ private String getSign(Map<String, String> reqParams) {
|
|
|
132
|
+ //(1)删除不需要生成加密内容的参数
|
|
|
133
|
+ ImmutableList list = ImmutableList.of("/api", "client_secret", "q", "debug_data");
|
|
|
134
|
+ SortedMap<String, String> filtedMap = new TreeMap<>();
|
|
|
135
|
+ for (Map.Entry<String, String> entry : reqParams.entrySet()) {
|
|
|
136
|
+ String k = entry.getKey();
|
|
|
137
|
+ if (!list.contains(k)) {
|
|
|
138
|
+ filtedMap.put(k, entry.getValue());
|
|
|
139
|
+ }
|
|
|
140
|
+ }
|
|
|
141
|
+
|
|
|
142
|
+ //(2)根据客户端类型生成相映的客户端类型的KEY
|
|
|
143
|
+ String clientType = reqParams.get("client_type");
|
|
|
144
|
+ String privateKey = privateKeyMap.get(clientType);
|
|
|
145
|
+ filtedMap.put("private_key", privateKey);
|
|
|
146
|
+
|
|
|
147
|
+
|
|
|
148
|
+ //(3)将参数组装起来, 用=相连, 多个参数直接使用&连接, 如: string: k1=v1&k2=v2
|
|
|
149
|
+ List<String> array = new LinkedList<>();
|
|
|
150
|
+ for (Map.Entry<String, String> entry : filtedMap.entrySet()) {
|
|
|
151
|
+ String pair = entry.getKey() + "=" + entry.getValue();
|
|
|
152
|
+ array.add(pair.trim());
|
|
|
153
|
+ }
|
|
|
154
|
+ String signStr = String.join("&", array);
|
|
|
155
|
+ //sign md5
|
|
|
156
|
+ String sign = MD5.md5(signStr);
|
|
|
157
|
+ return sign.toLowerCase();
|
|
|
158
|
+ }
|
|
|
159
|
+
|
|
|
160
|
+
|
|
|
161
|
+ //排除掉不需要校验的
|
|
|
162
|
+ private boolean isIgnore(HttpServletRequest request) {
|
|
|
163
|
+ //如果请求url包含在过滤的url,则直接返回. 请求url可能是 "/gateway/xxx”这种包含了context的。
|
|
|
164
|
+ if (excludeUrls != null) {
|
|
|
165
|
+ final String requestUri = request.getRequestURI();
|
|
|
166
|
+ for (String excludeUri : excludeUrls) {
|
|
|
167
|
+ if (requestUri.equals(excludeUri) || requestUri.endsWith(excludeUri)) {
|
|
|
168
|
+ logger.info("excludeUri uri: {} for client-security check success.", requestUri);
|
|
|
169
|
+ return true;
|
|
|
170
|
+ }
|
|
|
171
|
+ }
|
|
|
172
|
+ }
|
|
|
173
|
+
|
|
|
174
|
+ //如果请求method包含这些,则不校验
|
|
|
175
|
+ if (excludeMethods != null) {
|
|
|
176
|
+ final String method = request.getParameter("method");
|
|
|
177
|
+ if (StringUtils.isNotEmpty(method) && excludeMethods.contains(method)) {
|
|
|
178
|
+ return true;
|
|
|
179
|
+ }
|
|
|
180
|
+ }
|
|
|
181
|
+
|
|
|
182
|
+ //配置文件配置为 is_debug_enable 为true,并且请求携带参数debug为XYZ,就放行
|
|
|
183
|
+ if (isDebugEnable && "XYZ".equals(request.getParameter("debug"))) {
|
|
|
184
|
+ return true;
|
|
|
185
|
+ }
|
|
|
186
|
+ return false;
|
|
|
187
|
+ }
|
|
|
188
|
+
|
|
|
189
|
+ /**
|
|
|
190
|
+ * spring setter from security-keyyml
|
|
|
191
|
+ *
|
|
|
192
|
+ * @param keyConfigMap security key 的配置
|
|
|
193
|
+ */
|
|
|
194
|
+ public void setKeyConfigMap(Map<String, Object> keyConfigMap) {
|
|
|
195
|
+ List<Map<String, Object>> keys = (List<Map<String, Object>>) keyConfigMap.get("client_keys");
|
|
|
196
|
+
|
|
|
197
|
+ for (Map<String, Object> one : keys) {
|
|
|
198
|
+ privateKeyMap.put((String) one.get("type"), (String) one.get("key"));
|
|
|
199
|
+ }
|
|
|
200
|
+ }
|
|
|
201
|
+
|
|
|
202
|
+ /**
|
|
|
203
|
+ * 设置不校验的url地址
|
|
|
204
|
+ */
|
|
|
205
|
+ public void setExcludeUrls(List<String> excludeUrls) {
|
|
|
206
|
+ this.excludeUrls = excludeUrls;
|
|
|
207
|
+ }
|
|
|
208
|
+
|
|
|
209
|
+ public void setExcludeMethods(List<String> excludeMethods) {
|
|
|
210
|
+ this.excludeMethods = excludeMethods;
|
|
|
211
|
+ }
|
|
|
212
|
+
|
|
|
213
|
+ public void setIsDebugEnable(boolean isDebugEnable) {
|
|
|
214
|
+ this.isDebugEnable = isDebugEnable;
|
|
|
215
|
+ }
|
|
|
216
|
+} |