SceneRecallService.java 7.81 KB
package com.yoho.search.recall.scene;

import com.alibaba.fastjson.JSONObject;
import com.yoho.search.models.SearchApiResult;
import com.yoho.search.recall.scene.models.*;
import com.yoho.search.recall.scene.persional.PersionalFactor;
import com.yoho.search.recall.scene.persional.RecallPersionalService;
import com.yoho.search.recall.scene.request.BrandRecallRequestBuilder;
import com.yoho.search.recall.scene.request.CommonRecallRequestBuilder;
import com.yoho.search.recall.scene.request.SortPriceRecallRequestBuilder;
import com.yoho.search.recall.scene.strategy.StrategyNameEnum;
import com.yoho.search.recall.sort.helper.RecallServiceHelper;
import com.yoho.search.service.helper.SearchCommonHelper;
import com.yoho.search.service.helper.SearchServiceHelper;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang.StringUtils;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.*;

@Component
public class SceneRecallService {

    /** Page size used when the "viewNum" parameter is absent or blank. */
    private static final int DEFAULT_PAGE_SIZE = 10;
    /** Page number used when the "page" parameter is absent or blank. */
    private static final int DEFAULT_PAGE = 1;
    /** Largest accepted value of page * pageSize. */
    private static final long MAX_RECALL_WINDOW = 1000000L;

    @Autowired
    private SearchServiceHelper searchServiceHelper;
    @Autowired
    private RecallPersionalService recallPersionalService;
    @Autowired
    private CommonRecallRequestBuilder commonRequestBuilder;
    @Autowired
    private BrandRecallRequestBuilder brandRequestBuilder;
    @Autowired
    private SortPriceRecallRequestBuilder sortPriceRequestBuilder;
    @Autowired
    private RecallCommonService recallCommonService;
    @Autowired
    private RecallServiceHelper recallServiceHelper;
    @Autowired
    private SearchCommonHelper searchCommonHelper;

    /**
     * Runs a scene recall for the given request parameters and packages the outcome as a paged result.
     *
     * @param paramMap request parameters. This method reads "viewNum" (page size, default 10) and
     *                 "page" (default 1). "uid" and "udid" are read while building recall params.
     *                 The whole map is also handed to the query/filter builders.
     * @return code 400 with message "分页参数不合法" when the paging parameters are not valid
     *         integers, are out of range, or exceed the recall window. Code 500 when anything else
     *         throws. Otherwise a data map with total, page, page_size, page_total, product_list
     *         and product_list_size.
     */
    public SearchApiResult sceneRecall(Map<String, String> paramMap) {
        try {
            // 1. Validate paging parameters
            int pageSize;
            int page;
            try {
                pageSize = parseIntOrDefault(paramMap.get("viewNum"), DEFAULT_PAGE_SIZE);
                page = parseIntOrDefault(paramMap.get("page"), DEFAULT_PAGE);
            } catch (NumberFormatException e) {
                // Non-numeric paging input is a client error, not a server failure.
                return new SearchApiResult().setCode(400).setMessage("分页参数不合法");
            }
            // A zero page size cannot describe a page. The product is computed in long so that large
            // inputs cannot wrap around and slip past the window check.
            if (page < 1 || pageSize <= 0 || (long) page * pageSize > MAX_RECALL_WINDOW) {
                return new SearchApiResult().setCode(400).setMessage("分页参数不合法");
            }
            // 2. Build recall parameters, reusing the already-validated page size
            RecallParams recallParams = this.buildRecallParams(paramMap, pageSize);
            // 3. Execute the recall
            RecallResponseBatch recallResponseBatch = this.doBatchRecall(recallParams);
            // 4. Build the response payload
            JSONObject dataMap = new JSONObject();
            dataMap.put("total", recallResponseBatch.getTotal());
            dataMap.put("page", page);
            dataMap.put("page_size", pageSize);
            dataMap.put("page_total", searchCommonHelper.getTotalPage(recallResponseBatch.getTotal(), pageSize));
            dataMap.put("product_list", recallResponseBatch.getSknList());
            dataMap.put("product_list_size", recallResponseBatch.getSknList().size());
            return new SearchApiResult().setData(dataMap);
        } catch (Exception e) {
            // NOTE(review): consider routing this through the project's logger instead of stderr.
            e.printStackTrace();
            return new SearchApiResult().setData(null).setCode(500).setMessage("Exception");
        }
    }

    /**
     * Parses {@code value} as a decimal int.
     *
     * @param value        raw parameter value, possibly null or blank
     * @param defaultValue value returned when {@code value} is null or blank
     * @return the parsed value or {@code defaultValue}
     * @throws NumberFormatException if {@code value} is non-blank but not a valid int
     */
    private static int parseIntOrDefault(String value, int defaultValue) {
        return StringUtils.isBlank(value) ? defaultValue : Integer.parseInt(value);
    }

    /**
     * Builds the recall parameters from the request.
     *
     * @param paramMap request parameters. They are passed to the search helpers to build the query
     *                 and filter. "uid" defaults to 1 and "udid" defaults to "".
     * @param pageSize validated page size from {@link #sceneRecall(Map)}
     * @return the assembled {@link RecallParams}
     * @throws Exception propagated from the query/filter builders
     */
    private RecallParams buildRecallParams(Map<String, String> paramMap, int pageSize) throws Exception {
        QueryBuilder query = searchServiceHelper.constructQueryBuilder(paramMap);
        BoolQueryBuilder filter = searchServiceHelper.constructFilterBuilder(paramMap, null);
        // NOTE(review): presumably SKNs the caller wants surfaced first — confirm in RecallServiceHelper.
        List<String> firstProductSkns = recallServiceHelper.getFirstProductSkns(paramMap);
        int uid = MapUtils.getIntValue(paramMap, "uid", 1);
        String udid = MapUtils.getString(paramMap, "udid", "");
        return new RecallParams(query, filter, firstProductSkns, pageSize, uid, udid);
    }

    /**
     * Builds every recall request, runs them as one batch through
     * {@link RecallCommonService#batchRecallAndCache}, and merges the responses.
     *
     * @param param recall parameters
     * @return the total taken from the COMMON request, plus the de-duplicated SKN results
     */
    private RecallResponseBatch doBatchRecall(RecallParams param) {
        // 1. Build the recall requests
        List<RecallRequest> allRequests = this.buildRecallRequests(param);
        // 2. Recall in one batch
        List<RecallRequestResponse> requestResponses = recallCommonService.batchRecallAndCache(allRequests);
        // 3. Take the total count from the COMMON (fallback) request
        long total = this.getTotalCount(requestResponses);
        // 4. Collect the recalled SKNs, de-duplicated
        List<RecallResponseBatch.SknResult> sknResults = this.distinctRecallSkn(requestResponses);
        // 5. Assemble the result
        return new RecallResponseBatch(total, sknResults);
    }

    /**
     * Builds the full list of recall requests in this order:
     * the non-personalized requests, then the personalized brand requests, then the
     * personalized sort/price-area requests.
     *
     * @param param recall parameters
     * @return all requests, in the order described above
     */
    private List<RecallRequest> buildRecallRequests(RecallParams param) {
        List<RecallRequest> allRequests = new ArrayList<>();
        // 1. Non-personalized requests
        List<RecallRequest> commonRequests = commonRequestBuilder.buildCommonRecallRequests(param.getQuery(), param.getFilter(), param.getFirstProductSkns(), param.getPageSize());
        allRequests.addAll(commonRequests);
        // 2. Fetch the personalization factors for this user/device
        PersionalFactor persionalFactor = recallPersionalService.queryPersionalFactor(param.getQuery(), param.getFilter(), param.getUid(), param.getUdid());
        // 3. Personalized brand recall requests
        List<RecallRequest> brandRequests = brandRequestBuilder.buildBrandRecallRequests(param.getQuery(), param.getFilter(), persionalFactor.getBrandIds());
        allRequests.addAll(brandRequests);
        // 4. Personalized sort/price-area recall requests
        List<RecallRequest> sortPriceRequests = sortPriceRequestBuilder.buildSortPriceRecallRequests(param.getQuery(), param.getFilter(), persionalFactor.getSortPriceAreas());
        allRequests.addAll(sortPriceRequests);
        return allRequests;
    }

    /**
     * Takes the total hit count from the fallback request.
     * The fallback request is the first one whose request type equals
     * {@link StrategyNameEnum#COMMON}'s name, ignoring case.
     *
     * @param requestResponses batch recall results; entries with a null request or response are skipped
     * @return that response's total, or 0 if no COMMON request produced a response
     */
    private long getTotalCount(List<RecallRequestResponse> requestResponses) {
        for (RecallRequestResponse requestResponse : requestResponses) {
            RecallRequest request = requestResponse.getRequest();
            RecallResponse response = requestResponse.getResponse();
            if (request == null || response == null) {
                continue;
            }
            if (StrategyNameEnum.COMMON.name().equalsIgnoreCase(request.requestType())) {
                return response.getTotal();
            }
        }
        return 0;
    }

    /**
     * De-duplicates recalled SKNs across all requests.
     *
     * Each SKN keeps the position of its first occurrence. Responses are visited in list order,
     * and SKNs within each response in their own order. Each result is tagged with the request
     * type of every request that recalled it, in encounter order.
     *
     * @param requestResponses batch recall results; entries with a null request, response or SKN list
     *                         are skipped
     * @return one {@link RecallResponseBatch.SknResult} per distinct SKN
     */
    private List<RecallResponseBatch.SknResult> distinctRecallSkn(List<RecallRequestResponse> requestResponses) {
        List<RecallResponseBatch.SknResult> sknResults = new ArrayList<>();
        Map<Integer, List<String>> sknRequestMaps = new HashMap<>();
        for (RecallRequestResponse requestResponse : requestResponses) {
            RecallRequest request = requestResponse.getRequest();
            RecallResponse response = requestResponse.getResponse();
            if (request == null || response == null || response.getSkns() == null) {
                continue;
            }
            for (RecallResponse.RecallSkn recallSkn : response.getSkns()) {
                List<String> requestTypes = sknRequestMaps.get(recallSkn.getSkn());
                if (requestTypes == null) {
                    // First time this SKN is seen: it keeps this position in the output.
                    sknResults.add(new RecallResponseBatch.SknResult(recallSkn));
                    requestTypes = new ArrayList<>();
                    sknRequestMaps.put(recallSkn.getSkn(), requestTypes);
                }
                requestTypes.add(request.requestType());
            }
        }
        for (RecallResponseBatch.SknResult sknResult : sknResults) {
            sknResult.setRequestTypes(sknRequestMaps.get(sknResult.getProductSkn()));
        }
        return sknResults;
    }

}