// Sampler.java 13.5 KB
package com.example.demo;

import lombok.Data;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.io.File;
import java.io.IOException;
import java.util.*;

/**
 * Builds (query, positive, negative) picture triplets for SKNs and appends them to three
 * line-aligned output files.
 *
 * <p>All file locations are plain fields bound from configuration under the {@code info}
 * prefix. The lookup maps are filled by the {@code readAll*} methods and read by the sampling
 * and {@code make*} methods. The class keeps mutable state and performs no synchronization.
 *
 * <p>Created by jack on 2018/4/24.
 */
@Service
@ConfigurationProperties("info")
@Data
public class Sampler {
    public static final Logger LOGGER = LoggerFactory.getLogger(Sampler.class);

    /** Random source shared by the sampling helpers (static, so Lombok ignores it). */
    private static final Random RANDOM = new Random();

    /** Input: one query SKN per line. */
    public String querySknFile;
    /** Input: {@code skn,classPart[,classPart...]} per line. */
    public String skn2ClassFile;
    /** Input: {@code querySkn:skn,skn,...} per line. */
    public String aliSknResFile;
    /** Input: {@code querySkn:skn-score,skn-score,...} per line. */
    public String opencvSknResFile;
    /** Input: one picture path per line; the SKN is the last path segment up to the first {@code -}. */
    public String allValidPicFile;
    /** Input: one picture path per line; the SKN is the last path segment up to the first {@code _}. */
    public String allPaizhaoPicFile;
    /** Input: comma separated lines with exactly three non-empty fields, the first being the query SKN. */
    public String paizhaoResFile;

    /** Output (appended to): the query picture of every triplet, one per line. */
    public String qSimpleFile;
    /** Output (appended to): the positive picture of every triplet, one per line. */
    public String pSimpleFile;
    /** Output (appended to): the negative picture of every triplet, one per line. */
    public String nSimpleFile;

    /** SKN to its category string (category parts joined by and ending with a comma). */
    public Map<String, String> skn2ClassMap = new HashMap<>();
    /** Category string to the SKNs listed under it. */
    public Map<String, List<String>> class2SknMap = new HashMap<>();

    /** Query SKN to its result SKNs from {@link #aliSknResFile}, in file order. */
    public Map<String, List<String>> aliResMap = new HashMap<>();

    /** Query SKN to the result SKNs of {@link #opencvSknResFile}, in file order. */
    public Map<String, List<String>> visnetResMap = new HashMap<>();

    /** Query SKN to (result SKN to score), read from {@link #opencvSknResFile}. */
    public Map<String, Map<String, Double>> opencvResMap = new HashMap<>();

    /** SKN to the picture path line read from {@link #allValidPicFile}. */
    public Map<String, String> allValidPicMap = new HashMap<>();
    /** SKN to the picture path line read from {@link #allPaizhaoPicFile}. */
    public Map<String, String> allPaizhaoPicMap = new HashMap<>();

    /** Callback for one line of an input file; may fail with an I/O error. */
    @FunctionalInterface
    private interface LineHandler {
        void accept(String line) throws IOException;
    }

    /**
     * Runs the whole pipeline: loads the lookup maps, writes the catalog triplets, then loads the
     * photo pictures and writes the photo triplets.
     *
     * <p>The {@code @PostConstruct} annotation below is commented out, so nothing in this class
     * triggers the run automatically.
     *
     * @throws IOException if any input cannot be read or any output cannot be written
     */
//    @PostConstruct
    public void make() throws IOException {
        readAllSknMap();
        readAllAliSknMap();
        readAllOpencvMap();
        readAllValidSknPicMap();
        makeSampleFile();

        readAllPaizhaoPicMap();
        makeYoloSampFile();
    }

    /**
     * Feeds every line of {@code path} to {@code handler} and always closes the iterator, even when
     * the handler throws.
     */
    private static void forEachLine(String path, LineHandler handler) throws IOException {
        LineIterator iterator = FileUtils.lineIterator(new File(path));
        // try/finally instead of try-with-resources: close() is available whether or not the
        // Commons IO version in use declares LineIterator as Closeable.
        try {
            while (iterator.hasNext()) {
                handler.accept(iterator.next());
            }
        } finally {
            iterator.close();
        }
    }

    /**
     * Loads {@link #aliResMap} from {@link #aliSknResFile}. Lines that do not split into exactly
     * two parts on {@code :} are skipped; the first line seen for a query SKN wins.
     */
    private void readAllAliSknMap() throws IOException {
        forEachLine(aliSknResFile, line -> {
            String[] items = StringUtils.split(line.trim(), ":");
            if (null == items || 2 != items.length) {
                return;
            }
            String querySkn = items[0];
            if (aliResMap.containsKey(querySkn)) {
                return;
            }
            aliResMap.put(querySkn, new ArrayList<>(Arrays.asList(StringUtils.split(items[1], ","))));
        });
    }

    /**
     * Reads the category of every SKN from {@link #skn2ClassFile} into {@link #skn2ClassMap} and
     * {@link #class2SknMap}.
     *
     * <p>The first field of a line is the SKN; the remaining fields, each followed by a comma and
     * with every {@code .jpeg} removed, form the category string. Lines with fewer than two fields
     * are skipped.
     *
     * @throws IOException if the file cannot be read
     */
    public void readAllSknMap() throws IOException {
        forEachLine(skn2ClassFile, line -> {
            String[] items = StringUtils.split(line.trim(), ",");
            if (null == items || 2 > items.length) {
                return;
            }
            String skn = items[0];
            String joinedClass = String.join(",", Arrays.copyOfRange(items, 1, items.length)) + ",";
            String sknClass = joinedClass.replace(".jpeg", "");
            skn2ClassMap.put(skn, sknClass);
            class2SknMap.computeIfAbsent(sknClass, key -> new ArrayList<>()).add(skn);
        });
    }

    /**
     * Reads the scored results from {@link #opencvSknResFile} into {@link #opencvResMap} (SKN to
     * score) and {@link #visnetResMap} (SKNs in file order).
     *
     * <p>Lines without exactly one {@code :} separated pair, or with an empty result part, are
     * skipped. A result entry that lacks a score, or whose score is not a number, is skipped
     * instead of aborting the whole load.
     *
     * @throws IOException if the file cannot be read
     */
    public void readAllOpencvMap() throws IOException {
        forEachLine(opencvSknResFile, line -> {
            String[] items = StringUtils.split(line.trim(), ":");
            if (null == items || 2 != items.length) {
                return;
            }
            String[] entries = StringUtils.split(items[1].trim(), ",");
            if (null == entries || 0 == entries.length) {
                return;
            }
            String querySkn = items[0];
            Map<String, Double> scoreBySkn = new HashMap<>();
            List<String> sknsInOrder = new ArrayList<>();
            for (String entry : entries) {
                String[] kv = StringUtils.split(entry.trim(), "-");
                // Both the SKN and the score are required; a lone token has no kv[1].
                if (null == kv || 2 > kv.length) {
                    continue;
                }
                double score;
                try {
                    score = Double.parseDouble(kv[1]);
                } catch (NumberFormatException e) {
                    LOGGER.warn("skip result '{}' of query skn {}: score is not a number", entry, querySkn);
                    continue;
                }
                scoreBySkn.put(kv[0], score);
                sknsInOrder.add(kv[0]);
            }
            visnetResMap.put(querySkn, sknsInOrder);
            opencvResMap.put(querySkn, scoreBySkn);
        });
    }

    /**
     * Loads a picture list into {@code target}: the key is the part of the last path segment
     * before the first {@code sknSeparator}, the value is the line as read. Lines from which no
     * SKN can be extracted (blank lines, empty file names) are logged and skipped.
     */
    private static void loadPicMap(String path, String sknSeparator, Map<String, String> target)
            throws IOException {
        forEachLine(path, line -> {
            String[] segments = StringUtils.split(line.trim(), "/");
            if (null == segments || 0 == segments.length) {
                LOGGER.info(line);
                return;
            }
            String[] nameParts = StringUtils.split(segments[segments.length - 1].trim(), sknSeparator);
            if (null == nameParts || 0 == nameParts.length) {
                LOGGER.info(line);
                return;
            }
            target.put(nameParts[0], line);
        });
    }

    /**
     * Loads all YOLO cut-out pictures from {@link #allValidPicFile} into {@link #allValidPicMap}.
     *
     * @throws IOException if the file cannot be read
     */
    public void readAllValidSknPicMap() throws IOException {
        loadPicMap(allValidPicFile, "-", allValidPicMap);
    }

    /**
     * Loads all photo-shopping ("paizhao") pictures from {@link #allPaizhaoPicFile} into
     * {@link #allPaizhaoPicMap}.
     *
     * @throws IOException if the file cannot be read
     */
    public void readAllPaizhaoPicMap() throws IOException {
        loadPicMap(allPaizhaoPicFile, "_", allPaizhaoPicMap);
    }

    /**
     * Picks one random SKN of category {@code oneClass} other than {@code oneSkn}.
     *
     * @param oneSkn   the SKN that must not be returned
     * @param oneClass the category to pick from
     * @return a random other SKN of the category, or {@code ""} when the category is unknown or
     *         holds no other SKN
     */
    public String randomSameClassOne(String oneSkn, String oneClass) {
        List<String> skns = class2SknMap.get(oneClass);
        if (null == skns || skns.isEmpty()) {
            return "";
        }
        // Filter first so the pick always terminates, even when oneSkn is the only member.
        List<String> candidates = new ArrayList<>(skns.size());
        for (String candidate : skns) {
            if (!StringUtils.equals(oneSkn, candidate)) {
                candidates.add(candidate);
            }
        }
        if (candidates.isEmpty()) {
            return "";
        }
        // nextInt's bound is exclusive, so size() covers every index including the last.
        return candidates.get(RANDOM.nextInt(candidates.size()));
    }

    /**
     * Picks one random SKN from a random category other than {@code oneClass}.
     *
     * @param oneSkn   the query SKN (not used for the selection)
     * @param oneClass the category to avoid
     * @return a random SKN of another non-empty category, or {@code ""} when no such category exists
     */
    public String randomDifClassOne(String oneSkn, String oneClass) {
        List<List<String>> otherClassSkns = new ArrayList<>();
        for (Map.Entry<String, List<String>> entry : class2SknMap.entrySet()) {
            List<String> skns = entry.getValue();
            if (StringUtils.equals(entry.getKey(), oneClass) || null == skns || skns.isEmpty()) {
                continue;
            }
            otherClassSkns.add(skns);
        }
        if (otherClassSkns.isEmpty()) {
            return "";
        }
        // First choose the category uniformly, then an SKN inside it.
        List<String> chosen = otherClassSkns.get(RANDOM.nextInt(otherClassSkns.size()));
        return chosen.get(RANDOM.nextInt(chosen.size()));
    }

    /**
     * Returns the last three SKNs of the ali result list of {@code querySkn}.
     *
     * @param querySkn the query SKN
     * @return a new list with the last three results (all of them when there are fewer than
     *         three); empty when the query has no results
     */
    public List<String> aliLastThree(String querySkn) {
        List<String> res = aliResMap.get(querySkn);
        if (null == res || res.isEmpty()) {
            return new ArrayList<>();
        }
        // subList's end index is exclusive; clamp the start for short lists.
        int from = Math.max(0, res.size() - 3);
        return new ArrayList<>(res.subList(from, res.size()));
    }

    /**
     * Returns the visnet results of {@code querySkn} whose category differs from {@code queryClass}.
     * A result SKN without a known category is compared as category {@code ""}.
     *
     * @param querySkn   the query SKN
     * @param queryClass the category of the query SKN
     * @return the matching result SKNs in result order; empty when the query has no results
     */
    public List<String> vistnetDifClassOnes(String querySkn, String queryClass) {
        List<String> res = new ArrayList<>();
        List<String> visRes = visnetResMap.get(querySkn);
        if (null == visRes) {
            return res;
        }
        for (String one : visRes) {
            if (!StringUtils.equals(queryClass, skn2ClassMap.getOrDefault(one, ""))) {
                res.add(one);
            }
        }
        return res;
    }

    /**
     * Returns the result SKNs of {@code querySkn} whose score is below 60% of the mean score of
     * all its results.
     *
     * @param querySkn the query SKN
     * @return the low-scoring result SKNs; empty when the query has no scores
     */
    public List<String> opencvLast60PercentSkns(String querySkn) {
        List<String> last60Skns = new ArrayList<>();
        Map<String, Double> scoreBySkn = opencvResMap.get(querySkn);
        if (null == scoreBySkn) {
            return last60Skns;
        }
        double total = 0d;
        for (double score : scoreBySkn.values()) {
            total += score;
        }
        // For an empty map this is NaN, every comparison below is false and nothing is selected.
        double threshold = total / scoreBySkn.size() * 0.6;

        for (Map.Entry<String, Double> entry : scoreBySkn.entrySet()) {
            if (entry.getValue() < threshold) {
                last60Skns.add(entry.getKey());
            }
        }
        return last60Skns;
    }

    /**
     * Returns the first three ali results of {@code querySkn} that share its category.
     *
     * @param querySkn   the query SKN
     * @param queryClass the category of the query SKN
     * @return up to three same-category result SKNs in result order
     */
    public List<String> aliTopThreeSameClass(String querySkn, String queryClass) {
        List<String> topThree = new ArrayList<>();
        List<String> res = aliResMap.get(querySkn);
        if (null == res || res.isEmpty()) {
            return topThree;
        }
        for (String one : res) {
            if (StringUtils.equals(queryClass, skn2ClassMap.getOrDefault(one, ""))) {
                topThree.add(one);
                if (topThree.size() >= 3) {
                    break;
                }
            }
        }
        return topThree;
    }

    /**
     * Collects the negative SKNs of a query: one random SKN of another category, one random other
     * SKN of the same category, the last three ali results, the visnet results of other
     * categories and the results scoring below 60% of the mean.
     */
    private Set<String> collectNegatives(String querySkn, String queryClass) {
        Set<String> negatives = new HashSet<>();
        negatives.add(randomDifClassOne(querySkn, queryClass));
        negatives.add(randomSameClassOne(querySkn, queryClass));
        negatives.addAll(aliLastThree(querySkn));
        negatives.addAll(vistnetDifClassOnes(querySkn, queryClass));
        negatives.addAll(opencvLast60PercentSkns(querySkn));
        return negatives;
    }

    /**
     * Writes catalog triplets for every SKN listed in {@link #querySknFile}. Blank lines and SKNs
     * without a known category are skipped; positives are the top three same-category ali results.
     *
     * @throws IOException if the query file cannot be read or an output file cannot be written
     */
    public void makeSampleFile() throws IOException {
        forEachLine(querySknFile, line -> {
            String queryOneSkn = line.trim();
            if (StringUtils.isBlank(queryOneSkn)) {
                return;
            }
            String queryOneClass = skn2ClassMap.get(queryOneSkn);
            if (StringUtils.isBlank(queryOneClass)) {
                return;
            }
            Set<String> nagSampleSet = collectNegatives(queryOneSkn, queryOneClass);
            Set<String> posSampleSet = new HashSet<>(aliTopThreeSameClass(queryOneSkn, queryOneClass));

            buildSimple(queryOneSkn, posSampleSet, nagSampleSet);
        });
    }

    /**
     * Writes photo triplets for every query SKN in {@link #paizhaoResFile} that has both a photo
     * picture and a catalog picture and whose category is known. Positives are the top three
     * same-category ali results plus the query SKN itself (its catalog picture).
     *
     * @throws IOException if the result file cannot be read or an output file cannot be written
     */
    public void makeYoloSampFile() throws IOException {
        forEachLine(paizhaoResFile, line -> {
            String[] items = StringUtils.split(line.trim(), ",");
            if (null == items || 3 != items.length) {
                return;
            }
            String queryOneSkn = items[0];
            if (!allPaizhaoPicMap.containsKey(queryOneSkn)) {
                return;
            }
            if (!allValidPicMap.containsKey(queryOneSkn)) {
                return;
            }
            String queryOneClass = skn2ClassMap.get(queryOneSkn);
            // Same rule as makeSampleFile: without a category the class-based sampling is meaningless.
            if (StringUtils.isBlank(queryOneClass)) {
                return;
            }
            Set<String> nagSampleSet = collectNegatives(queryOneSkn, queryOneClass);
            Set<String> posSampleSet = new HashSet<>(aliTopThreeSameClass(queryOneSkn, queryOneClass));
            // The positive set holds SKNs, so the query's own SKN goes in here, not its category.
            posSampleSet.add(queryOneSkn);

            buildPaizhaoSimple(queryOneSkn, posSampleSet, nagSampleSet);
        });
    }

    /**
     * Builds the cross product of positives and negatives that have a catalog picture, each
     * combined with {@code queryPic}.
     */
    private List<Simple> buildTriplets(String queryPic, Set<String> p_sknSet, Set<String> n_sknSet) {
        List<Simple> simples = new ArrayList<>();
        for (String p : p_sknSet) {
            if (!allValidPicMap.containsKey(p)) {
                continue;
            }
            for (String n : n_sknSet) {
                if (!allValidPicMap.containsKey(n)) {
                    continue;
                }
                simples.add(new Simple(queryPic, allValidPicMap.get(p), allValidPicMap.get(n)));
            }
        }
        return simples;
    }

    /**
     * Writes one triplet per (positive, negative) pair, using catalog pictures for all three
     * roles. Does nothing when the query SKN has no catalog picture; SKNs without a catalog
     * picture are left out.
     *
     * @param q_skn    the query SKN
     * @param p_sknSet the positive SKNs
     * @param n_sknSet the negative SKNs
     * @throws IOException if an output file cannot be written
     */
    public void buildSimple(String q_skn, Set<String> p_sknSet, Set<String> n_sknSet) throws IOException {
        if (!allValidPicMap.containsKey(q_skn)) {
            return;
        }
        writeSimple(buildTriplets(allValidPicMap.get(q_skn), p_sknSet, n_sknSet));
    }

    /**
     * Writes one triplet per (positive, negative) pair, using the photo picture for the query and
     * catalog pictures for positives and negatives. Does nothing when the query SKN has no photo
     * picture; SKNs without a catalog picture are left out.
     *
     * @param q_skn    the query SKN
     * @param p_sknSet the positive SKNs
     * @param n_sknSet the negative SKNs
     * @throws IOException if an output file cannot be written
     */
    public void buildPaizhaoSimple(String q_skn, Set<String> p_sknSet, Set<String> n_sknSet) throws IOException {
        if (!allPaizhaoPicMap.containsKey(q_skn)) {
            return;
        }
        writeSimple(buildTriplets(allPaizhaoPicMap.get(q_skn), p_sknSet, n_sknSet));
    }

    /** One training triplet: the query picture, a positive picture and a negative picture. */
    @Data
    public static class Simple {
        String q;
        String p;
        String n;

        public Simple(String q, String p, String n) {
            this.q = q;
            this.p = p;
            this.n = n;
        }
    }

    /**
     * Appends the triplets to the three output files, one value per line, so that line {@code i}
     * of each file belongs to the same triplet. Nothing is written for an empty list.
     *
     * @param simpleList the triplets to append
     * @throws IOException if an output file cannot be written
     */
    public void writeSimple(List<Simple> simpleList) throws IOException {
        if (null == simpleList || simpleList.isEmpty()) {
            return;
        }
        // Batch the lines so each output file is opened once per call instead of once per triplet.
        StringBuilder qLines = new StringBuilder();
        StringBuilder pLines = new StringBuilder();
        StringBuilder nLines = new StringBuilder();
        for (Simple simple : simpleList) {
            qLines.append(simple.getQ()).append("\n");
            pLines.append(simple.getP()).append("\n");
            nLines.append(simple.getN()).append("\n");
        }
        FileUtils.writeStringToFile(new File(qSimpleFile), qLines.toString(), "utf-8", true);
        FileUtils.writeStringToFile(new File(pSimpleFile), pLines.toString(), "utf-8", true);
        FileUtils.writeStringToFile(new File(nSimpleFile), nLines.toString(), "utf-8", true);
    }

}