ShareJoin.java 15.3 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455
package io.mycat.catlets;

import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import io.mycat.backend.mysql.nio.handler.MiddlerQueryResultHandler;
import io.mycat.backend.mysql.nio.handler.MiddlerResultHandler;
import io.mycat.cache.LayerCachePool;
import io.mycat.config.ErrorCode;
import io.mycat.config.Fields;
import io.mycat.config.model.SchemaConfig;
import io.mycat.config.model.SystemConfig;
import io.mycat.net.mysql.FieldPacket;
import io.mycat.net.mysql.RowDataPacket;
import io.mycat.route.RouteResultset;
import io.mycat.route.RouteResultsetNode;
import io.mycat.route.factory.RouteStrategyFactory;
import io.mycat.server.NonBlockingSession;
import io.mycat.server.ServerConnection;
import io.mycat.server.parser.ServerParse;
import io.mycat.sqlengine.AllJobFinishedListener;
import io.mycat.sqlengine.EngineCtx;
import io.mycat.sqlengine.SQLJobHandler;
import io.mycat.util.ByteUtil;
import io.mycat.util.ResultSetUtil;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
/**  
 * 功能详细描述:分片join
 * @author sohudo[http://blog.csdn.net/wind520]
 * @create 2015年01月22日 下午6:50:23 
 * @version 0.0.1
 */

public class ShareJoin implements Catlet {
	private EngineCtx ctx;
	private RouteResultset rrs ;
	private JoinParser joinParser;
	
	private Map<String, byte[]> rows = new ConcurrentHashMap<String, byte[]>();
	private Map<String,String> ids = new ConcurrentHashMap<String,String>();
	//private ConcurrentLinkedQueue<String> ids = new ConcurrentLinkedQueue<String>();
	
	private List<byte[]> fields; //主表的字段
	private ArrayList<byte[]> allfields;//所有的字段
	private boolean isMfield=false;
	private int mjob=0;
	private int maxjob=0;
	private int joinindex=0;//关联join表字段的位置
	private int sendField=0;
	private boolean childRoute=false;
	private boolean jointTableIsData=false;
	// join 字段的类型,一般情况都是int, long; 增加该字段为了支持非int,long类型的(一般为varchar)joinkey的sharejoin
 	// 参见:io.mycat.server.packet.FieldPacket 属性: public int type;
 	// 参见:http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition
 	private int joinKeyType = Fields.FIELD_TYPE_LONG; // 默认 join 字段为int型
 	
	//重新路由使用
	private SystemConfig sysConfig; 
	private SchemaConfig schema;
	private int sqltype; 
	private String charset; 
	private ServerConnection sc;	
	private LayerCachePool cachePool;
	public void setRoute(RouteResultset rrs){
		this.rrs =rrs;
	}	
	
	public void route(SystemConfig sysConfig, SchemaConfig schema,int sqlType, String realSQL, String charset, ServerConnection sc,	LayerCachePool cachePool) {
		int rs = ServerParse.parse(realSQL);
		this.sqltype = rs & 0xff;
		this.sysConfig=sysConfig; 
		this.schema=schema;
		this.charset=charset; 
		this.sc=sc;	
		this.cachePool=cachePool;		
		try {
		 //  RouteStrategy routes=RouteStrategyFactory.getRouteStrategy();	
		  // rrs =RouteStrategyFactory.getRouteStrategy().route(sysConfig, schema, sqlType2, realSQL,charset, sc, cachePool);		   
			MySqlStatementParser parser = new MySqlStatementParser(realSQL);			
			SQLStatement statement = parser.parseStatement();
			if(statement instanceof SQLSelectStatement) {
			   SQLSelectStatement st=(SQLSelectStatement)statement;
			   SQLSelectQuery sqlSelectQuery =st.getSelect().getQuery();
				if(sqlSelectQuery instanceof MySqlSelectQueryBlock) {
					MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)st.getSelect().getQuery();
					joinParser=new JoinParser(mysqlSelectQuery,realSQL);
					joinParser.parser();
				}	
			}
		   /*	
		   if (routes instanceof DruidMysqlRouteStrategy) {
			   SQLSelectStatement st=((DruidMysqlRouteStrategy) routes).getSQLStatement();
			   SQLSelectQuery sqlSelectQuery =st.getSelect().getQuery();
				if(sqlSelectQuery instanceof MySqlSelectQueryBlock) {
					MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)st.getSelect().getQuery();
					joinParser=new JoinParser(mysqlSelectQuery,realSQL);
					joinParser.parser();
				}
		   }
		   */
		} catch (Exception e) {
		
		}
	}
	private void getRoute(String sql){
		try {
		  if (joinParser!=null){
			rrs =RouteStrategyFactory.getRouteStrategy().route(sysConfig, schema, sqltype,sql,charset, sc, cachePool);
		  }
		} catch (Exception e) {
			
		}
	}
	private String[] getDataNodes(){		
		String[] dataNodes =new String[rrs.getNodes().length] ;
		for (int i=0;i<rrs.getNodes().length;i++){
			dataNodes[i]=rrs.getNodes()[i].getName();
		}
		return dataNodes;
	}
	private String getDataNode(String[] dataNodes){
		String dataNode="";
		for (int i=0;i<dataNodes.length;i++){
			dataNode+=dataNodes[i]+",";
		}
		return dataNode;
	}
	public void processSQL(String sql, EngineCtx ctx) {
		String ssql=joinParser.getSql();
		getRoute(ssql);
		RouteResultsetNode[] nodes = rrs.getNodes();
		if (nodes == null || nodes.length == 0 || nodes[0].getName() == null
				|| nodes[0].getName().equals("")) {
			ctx.getSession().getSource().writeErrMessage(ErrorCode.ER_NO_DB_ERROR,
					"No dataNode found ,please check tables defined in schema:"
							+ ctx.getSession().getSource().getSchema());
			return;
		} 
		this.ctx=ctx;
		String[] dataNodes =getDataNodes();
		maxjob=dataNodes.length;
	 

    	//huangyiming
		ShareDBJoinHandler joinHandler = new ShareDBJoinHandler(this,joinParser.getJoinLkey(),sc.getSession2());		
		ctx.executeNativeSQLSequnceJob(dataNodes, ssql, joinHandler);
    	EngineCtx.LOGGER.info("Catlet exec:"+getDataNode(getDataNodes())+" sql:" +ssql);

		ctx.setAllJobFinishedListener(new AllJobFinishedListener() {
			@Override
			public void onAllJobFinished(EngineCtx ctx) {				
				 if (!jointTableIsData) {
					 ctx.writeHeader(fields);
				 }
				 
				 MiddlerResultHandler middlerResultHandler = sc.getSession2().getMiddlerResultHandler();

					if(  middlerResultHandler !=null ){
						//sc.getSession2().setCanClose(false);
						middlerResultHandler.secondEexcute(); 
					} else{
						ctx.writeEof();
					}
				EngineCtx.LOGGER.info("发送数据OK"); 
			}
		});
	}
	
    public void putDBRow(String id,String nid, byte[] rowData,int findex){
    	rows.put(id, rowData);	
    	ids.put(id, nid);
    	joinindex=findex;
		//ids.offer(nid);
		int batchSize = 999;
		// 满1000条,发送一个查询请求
		if (ids.size() > batchSize) {
			createQryJob(batchSize);
		}            	
    }
    
    public void putDBFields(List<byte[]> mFields){
    	 if (!isMfield){
    		 fields=mFields; 
    	 }    	
    }    

   public void endJobInput(String dataNode, boolean failed){
	   mjob++;
	   if (mjob>=maxjob){
		 createQryJob(Integer.MAX_VALUE);
	     ctx.endJobInput();
	   }
	  // EngineCtx.LOGGER.info("完成"+mjob+":" + dataNode+" failed:"+failed);
   }
   
	//private void createQryJob(String dataNode,int batchSize) {	
	private void createQryJob(int batchSize) {	
		int count = 0;
		Map<String, byte[]> batchRows = new ConcurrentHashMap<String, byte[]>();
		String theId = null;
		StringBuilder sb = new StringBuilder().append('(');
		String svalue="";
		for(Map.Entry<String,String> e: ids.entrySet() ){
			theId=e.getKey();
			byte[] rowbyte = rows.remove(theId);
			if(rowbyte!=null){
				batchRows.put(theId, rowbyte);
			}			
			if (!svalue.equals(e.getValue())){
				if(joinKeyType == Fields.FIELD_TYPE_VAR_STRING 
						|| joinKeyType == Fields.FIELD_TYPE_STRING){ // joinkey 为varchar
						sb.append("'").append(e.getValue()).append("'").append(','); // ('digdeep','yuanfang') 
				}else{ // 默认joinkey为int/long
					sb.append(e.getValue()).append(','); // (1,2,3) 
				}
			}
			svalue=e.getValue();
			if (count++ > batchSize) {
				break;
			}			
		}
		/*
		while ((theId = ids.poll()) != null) {
			batchRows.put(theId, rows.remove(theId));
			sb.append(theId).append(',');
			if (count++ > batchSize) {
				break;
			}
		}
		*/
		if (count == 0) {
			return;
		}
		jointTableIsData=true;
		sb.deleteCharAt(sb.length() - 1).append(')');
		String sql = String.format(joinParser.getChildSQL(), sb);
		//if (!childRoute){
		  getRoute(sql);
		 //childRoute=true;
		//}
		ctx.executeNativeSQLParallJob(getDataNodes(),sql, new ShareRowOutPutDataHandler(this,fields,joinindex,joinParser.getJoinRkey(), batchRows,ctx.getSession()));
		EngineCtx.LOGGER.info("SQLParallJob:"+getDataNode(getDataNodes())+" sql:" + sql);		
	}  
	public void writeHeader(String dataNode,List<byte[]> afields, List<byte[]> bfields) {
		sendField++;
		if (sendField==1){		  	
			//huangyiming add 只是中间过程数据不能发送给客户端
			MiddlerResultHandler middlerResultHandler = sc.getSession2().getMiddlerResultHandler();
 			if(middlerResultHandler ==null ){
				 ctx.writeHeader(afields, bfields);
 			}  
 		  setAllFields(afields, bfields);
		 // EngineCtx.LOGGER.info("发送字段2:" + dataNode);
		}
		
	}
	private void setAllFields(List<byte[]> afields, List<byte[]> bfields){		
		allfields=new ArrayList<byte[]>();
		for (byte[] field : afields) {
			allfields.add(field);
		}
		//EngineCtx.LOGGER.info("所有字段2:" +allfields.size());
		for (int i=1;i<bfields.size();i++){
			allfields.add(bfields.get(i));
		}
		
	}
	public List<byte[]> getAllFields(){		
		return allfields;
	}
	public void writeRow(RowDataPacket rowDataPkg){
		ctx.writeRow(rowDataPkg);
	}
	
	public int getFieldIndex(List<byte[]> fields,String fkey){
		int i=0;
		for (byte[] field :fields) {	
			  FieldPacket fieldPacket = new FieldPacket();
			  fieldPacket.read(field);	
			  if (ByteUtil.getString(fieldPacket.orgName).equals(fkey)){
				  joinKeyType = fieldPacket.type;
				  return i;				  
			  }
			  i++;
			}
		return i;		
	}	
}

/**
 * Job handler for the left-table phase of a {@link ShareJoin}.
 *
 * <p>It hands the left table's field packets to the owning {@code ShareJoin}. Each row is
 * forwarded together with its first-column value and the value of the join column named
 * {@code joinField}.
 */
class ShareDBJoinHandler implements SQLJobHandler {
	private final ShareJoin ctx;
	private final String joinkey;
	private final NonBlockingSession session;
	private List<byte[]> fields;

	public ShareDBJoinHandler(ShareJoin ctx, String joinField, NonBlockingSession session) {
		this.ctx = ctx;
		this.joinkey = joinField;
		this.session = session;
	}

	@Override
	public void onHeader(String dataNode, byte[] header, List<byte[]> fields) {
		this.fields = fields;
		ctx.putDBFields(fields);
	}

	@Override
	public boolean onRowData(String dataNode, byte[] rowData) {
		// Looked up per row; getFieldIndex also records the join column's type on the ShareJoin.
		final int joinPos = ctx.getFieldIndex(fields, joinkey);
		final String rowKey = ResultSetUtil.getColumnValAsString(rowData, fields, 0); // first column, default id
		final String joinValue = ResultSetUtil.getColumnValAsString(rowData, fields, joinPos);
		ctx.putDBRow(rowKey, joinValue, rowData, joinPos);
		return false;
	}

	@Override
	public void finished(String dataNode, boolean failed, String errorMsg) {
		if (!failed) {
			ctx.endJobInput(dataNode, failed);
			return;
		}
		session.getSource().writeErrMessage(ErrorCode.ER_UNKNOWN_ERROR, errorMsg);
	}
}

/**
 * Job handler for the right-table phase of a {@link ShareJoin}.
 *
 * <p>Each right-table row is matched on the join column against one batch of buffered left-table
 * rows. For every match it builds a merged row: all left columns, then the right columns from
 * index 1 on. Without a {@link MiddlerResultHandler} the merged row is written to the client.
 * With a {@link MiddlerQueryResultHandler}, the merged row's first column is passed to that
 * handler instead.
 */
class ShareRowOutPutDataHandler implements SQLJobHandler {
	private final List<byte[]> afields;
	private List<byte[]> bfields;
	private final ShareJoin ctx;
	// Left-table rows of this batch grouped by their join value. The map is built once in the
	// constructor and only read afterwards, so parallel child jobs can share it.
	private final Map<String, List<byte[]>> arowsByJoinValue;
	private int joinR; // position of the join column in the right (B) table
	private String joinRkey; // name of the join column in the right (B) table
	public NonBlockingSession session;

	/**
	 * @param afields    left-table field packets
	 * @param joini      position of the join column in the left (A) table
	 * @param joinField  name of the join column in the right (B) table
	 * @param arows      buffered left-table rows of this batch (not modified)
	 */
	public ShareRowOutPutDataHandler(ShareJoin ctx, List<byte[]> afields, int joini, String joinField,
			Map<String, byte[]> arows, NonBlockingSession session) {
		this.afields = afields;
		this.ctx = ctx;
		this.joinRkey = joinField;
		this.session = session;
		this.arowsByJoinValue = indexByJoinValue(arows, afields, joini);
	}

	/** Groups raw rows by the string value of column {@code index}; rows with a NULL value are skipped. */
	private static Map<String, List<byte[]>> indexByJoinValue(Map<String, byte[]> rows, List<byte[]> fields,
			int index) {
		Map<String, List<byte[]>> byValue = new HashMap<String, List<byte[]>>();
		for (byte[] row : rows.values()) {
			String key = joinValueOf(ResultSetUtil.parseRowData(row, fields), index);
			if (key == null) {
				continue; // NULL never equals anything in a join
			}
			List<byte[]> bucket = byValue.get(key);
			if (bucket == null) {
				bucket = new ArrayList<byte[]>();
				byValue.put(key, bucket);
			}
			bucket.add(row);
		}
		return byValue;
	}

	/** Returns column {@code index} of {@code row} as a string, or {@code null} for SQL NULL. */
	private static String joinValueOf(RowDataPacket row, int index) {
		byte[] raw = row.fieldValues.get(index);
		return raw == null ? null : ByteUtil.getString(raw);
	}

	@Override
	public void onHeader(String dataNode, byte[] header, List<byte[]> bfields) {
		this.bfields = bfields;
		joinR = this.ctx.getFieldIndex(bfields, joinRkey);
		MiddlerResultHandler middlerResultHandler = session.getMiddlerResultHandler();
		if (middlerResultHandler == null) {
			ctx.writeHeader(dataNode, afields, bfields);
		}
	}

	@Override
	public boolean onRowData(String dataNode, byte[] rowData) {
		RowDataPacket rightRow = ResultSetUtil.parseRowData(rowData, bfields);
		String joinValue = joinValueOf(rightRow, joinR);
		List<byte[]> matches = joinValue == null ? null : arowsByJoinValue.get(joinValue);
		if (matches == null) {
			return false;
		}
		MiddlerResultHandler middlerResultHandler = session.getMiddlerResultHandler();
		for (byte[] arow : matches) {
			// Parse afresh for every output row: the packet is mutated by appending right columns.
			RowDataPacket merged = ResultSetUtil.parseRowData(arow, afields);
			for (int i = 1; i < rightRow.fieldCount; i++) {
				merged.add(rightRow.fieldValues.get(i));
				merged.addFieldCount(1);
			}
			if (middlerResultHandler == null) {
				ctx.writeRow(merged);
			} else if (middlerResultHandler instanceof MiddlerQueryResultHandler) {
				byte[] columnData = merged.fieldValues.get(0);
				if (columnData != null && columnData.length > 0) {
					// NOTE(review): decodes with the platform default charset, as before.
					middlerResultHandler.add(new String(columnData));
				}
			}
		}
		return false;
	}

	@Override
	public void finished(String dataNode, boolean failed, String errorMsg) {
		if (failed) {
			session.getSource().writeErrMessage(ErrorCode.ER_UNKNOWN_ERROR, errorMsg);
		}
	}
}