/*
 * Decompiled with CFR 0.152.
 */
package com.android.tradefed.invoker.shard;

import com.android.tradefed.config.IConfiguration;
import com.android.tradefed.error.HarnessRuntimeException;
import com.android.tradefed.invoker.IRescheduler;
import com.android.tradefed.invoker.TestInformation;
import com.android.tradefed.invoker.shard.DynamicShardHelper;
import com.android.tradefed.invoker.shard.ShardHelper;
import com.android.tradefed.log.ITestLogger;
import com.android.tradefed.log.LogUtil;
import com.android.tradefed.result.ITestLoggerReceiver;
import com.android.tradefed.result.error.InfraErrorIdentifier;
import com.android.tradefed.testtype.IBuildReceiver;
import com.android.tradefed.testtype.IDeviceTest;
import com.android.tradefed.testtype.IInvocationContextReceiver;
import com.android.tradefed.testtype.IRemoteTest;
import com.android.tradefed.testtype.IRuntimeHintProvider;
import com.android.tradefed.testtype.IShardableTest;
import com.android.tradefed.testtype.suite.ITestSuite;
import com.android.tradefed.testtype.suite.ModuleMerger;
import com.android.tradefed.util.TimeUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class StrictShardHelper
extends ShardHelper {
    @Override
    public boolean shardConfig(IConfiguration config, TestInformation testInfo, IRescheduler rescheduler, ITestLogger logger) {
        if (config.getCommandOptions().shouldRemoteDynamicShard()) {
            DynamicShardHelper helper = new DynamicShardHelper();
            return helper.shardConfig(config, testInfo, rescheduler, logger);
        }
        Integer shardCount = config.getCommandOptions().getShardCount();
        Integer shardIndex = config.getCommandOptions().getShardIndex();
        boolean optimizeMainline = config.getCommandOptions().getOptimizeMainlineTest();
        if (shardIndex == null) {
            return super.shardConfig(config, testInfo, rescheduler, logger);
        }
        if (shardCount == null) {
            throw new RuntimeException("shard-count is null while shard-index is " + shardIndex);
        }
        if (shardCount == 1) {
            return false;
        }
        List<IRemoteTest> listAllTests = this.getAllTests(config, shardCount, testInfo, logger);
        this.normalizeDistribution(listAllTests, shardCount);
        List<IRemoteTest> splitList = shardCount == 1 ? listAllTests : this.splitTests(listAllTests, shardCount, config.getCommandOptions().shouldUseEvenModuleSharding()).get(shardIndex);
        this.aggregateSuiteModules(splitList);
        if (optimizeMainline) {
            LogUtil.CLog.i("Reordering the test modules list for index: %s", shardIndex);
            this.reorderTestModules(splitList);
        }
        config.setTests(splitList);
        return false;
    }

    private void reorderTestModules(List<IRemoteTest> tests) {
        Collections.sort(tests, new Comparator<IRemoteTest>(){

            @Override
            public int compare(IRemoteTest o1, IRemoteTest o2) {
                String moduleId1 = ((ITestSuite)o1).getDirectModule().getId();
                String moduleId2 = ((ITestSuite)o2).getDirectModule().getId();
                return StrictShardHelper.this.getMainlineId(moduleId1).compareTo(StrictShardHelper.this.getMainlineId(moduleId2));
            }
        });
    }

    private String getMainlineId(String id) {
        Pattern parameterizedMainlineRegex = Pattern.compile("\\[(.*(\\.apk|.apex|.apks))\\]$");
        Matcher m = parameterizedMainlineRegex.matcher(id);
        if (m.find()) {
            return m.group(1);
        }
        throw new HarnessRuntimeException(String.format("Module: %s doesn't match the pattern for mainline modules. The pattern should end with apk/apex/apks.", id), InfraErrorIdentifier.OPTION_CONFIGURATION_ERROR);
    }

    private List<IRemoteTest> getAllTests(IConfiguration config, Integer shardCount, TestInformation testInfo, ITestLogger logger) {
        ArrayList<IRemoteTest> allTests = new ArrayList<IRemoteTest>();
        for (IRemoteTest test : config.getTests()) {
            if (test instanceof IShardableTest) {
                Collection<IRemoteTest> subTests;
                if (test instanceof IBuildReceiver) {
                    ((IBuildReceiver)((Object)test)).setBuild(testInfo.getBuildInfo());
                }
                if (test instanceof IDeviceTest) {
                    ((IDeviceTest)((Object)test)).setDevice(testInfo.getDevice());
                }
                if (test instanceof IInvocationContextReceiver) {
                    ((IInvocationContextReceiver)((Object)test)).setInvocationContext(testInfo.getContext());
                }
                if (test instanceof ITestLoggerReceiver) {
                    ((ITestLoggerReceiver)((Object)test)).setTestLogger(logger);
                }
                if (test instanceof ITestSuite) {
                    ((ITestSuite)test).setShouldMakeDynamicModule(false);
                }
                if ((subTests = ((IShardableTest)test).split(shardCount, testInfo)) == null) {
                    allTests.add(test);
                    continue;
                }
                allTests.addAll(subTests);
                continue;
            }
            allTests.add(test);
        }
        return allTests;
    }

    protected List<List<IRemoteTest>> splitTests(List<IRemoteTest> fullList, int shardCount, boolean useEvenModuleSharding) {
        List<Object> shards;
        if (useEvenModuleSharding) {
            LogUtil.CLog.d("Using the sharding strategy to distribute number of modules more evenly.");
            shards = StrictShardHelper.shardList(fullList, shardCount);
        } else {
            shards = new ArrayList();
            int numPerShard = (int)Math.ceil((float)fullList.size() / (float)shardCount);
            boolean needsCorrection = false;
            float correctionRatio = 0.0f;
            if (fullList.size() > shardCount) {
                needsCorrection = numPerShard * (shardCount - 1) > fullList.size();
                correctionRatio = (float)numPerShard - (float)fullList.size() / (float)shardCount;
            }
            numPerShard = (int)Math.floor((float)numPerShard - correctionRatio);
            shards = this.balancedDistrib(fullList, shardCount, numPerShard, needsCorrection);
        }
        this.topBottom(shards, shardCount);
        return shards;
    }

    private List<List<IRemoteTest>> balancedDistrib(List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection) {
        int i;
        ArrayList<List<IRemoteTest>> shards = new ArrayList<List<IRemoteTest>>();
        List<Object> correctionList = new ArrayList();
        int correctionSize = 0;
        for (i = 0; i < shardCount; ++i) {
            List<Object> shardList;
            if (i >= fullList.size()) {
                shardList = new ArrayList();
                shards.add(shardList);
                continue;
            }
            if (i == shardCount - 1) {
                if (needsCorrection) {
                    correctionSize = fullList.size() - (numPerShard + i * numPerShard);
                    correctionList = fullList.subList(fullList.size() - correctionSize, fullList.size());
                }
                shardList = fullList.subList(i * numPerShard, fullList.size() - correctionSize);
                shards.add(new ArrayList(shardList));
                continue;
            }
            shardList = fullList.subList(i * numPerShard, numPerShard + i * numPerShard);
            shards.add(new ArrayList(shardList));
        }
        for (i = 0; i < shardCount && i < correctionList.size(); ++i) {
            ((List)shards.get(i)).add((IRemoteTest)correctionList.get(i));
        }
        return shards;
    }

    static <T> List<List<T>> shardList(List<T> fullList, int shardCount) {
        int i;
        int totalSize = fullList.size();
        int smallShardSize = totalSize / shardCount;
        int bigShardSize = smallShardSize + 1;
        int bigShardCount = totalSize % shardCount;
        ArrayList<List<T>> shards = new ArrayList<List<T>>();
        for (i = 0; i < bigShardCount * bigShardSize; i += bigShardSize) {
            shards.add(fullList.subList(i, i + bigShardSize));
        }
        while (i < totalSize) {
            shards.add(fullList.subList(i, i + smallShardSize));
            i += smallShardSize;
        }
        while (shards.size() < shardCount) {
            shards.add(new ArrayList());
        }
        return shards;
    }

    private void normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount) {
        int numRound = shardCount;
        int distance = shardCount - 1;
        for (int i = 0; i < numRound; ++i) {
            for (int j = 0; j < listAllTests.size(); j += distance) {
                IRemoteTest push = listAllTests.remove(j);
                listAllTests.add(push);
            }
        }
    }

    private void aggregateSuiteModules(List<IRemoteTest> tests) {
        ArrayList<IRemoteTest> dupList = new ArrayList<IRemoteTest>(tests);
        for (int i = 0; i < dupList.size(); ++i) {
            if (!(dupList.get(i) instanceof ITestSuite)) continue;
            for (int j = i + 1; j < dupList.size(); ++j) {
                if (!tests.contains(dupList.get(j)) || !(dupList.get(j) instanceof ITestSuite) || !ModuleMerger.arePartOfSameSuite((ITestSuite)dupList.get(i), (ITestSuite)dupList.get(j))) continue;
                ModuleMerger.mergeSplittedITestSuite((ITestSuite)dupList.get(i), (ITestSuite)dupList.get(j));
                tests.remove(dupList.get(j));
            }
        }
    }

    private void topBottom(List<List<IRemoteTest>> allShards, int shardCount) {
        if (shardCount < 4) {
            return;
        }
        int index = 0;
        ArrayList<SortShardObj> shardTimes = new ArrayList<SortShardObj>();
        for (List<IRemoteTest> shard : allShards) {
            long aggTime = 0L;
            LogUtil.CLog.d("++++++++++++++++++ SHARD %s +++++++++++++++", index);
            for (IRemoteTest test : shard) {
                if (!(test instanceof IRuntimeHintProvider)) continue;
                aggTime += ((IRuntimeHintProvider)((Object)test)).getRuntimeHint();
            }
            LogUtil.CLog.d("Shard %s approximate time: %s", index, TimeUtil.formatElapsedTime(aggTime));
            shardTimes.add(new SortShardObj(index, aggTime));
            ++index;
            LogUtil.CLog.d("+++++++++++++++++++++++++++++++++++++++++++");
        }
        Collections.sort(shardTimes);
        if (((SortShardObj)shardTimes.get((int)0)).mAggTime - ((SortShardObj)shardTimes.get((int)(shardTimes.size() - 1))).mAggTime < 3600000L) {
            return;
        }
        int i = 0;
        while ((double)i < (double)shardCount * 0.3) {
            LogUtil.CLog.d("Top shard %s is index %s with %s", i, ((SortShardObj)shardTimes.get((int)i)).mIndex, TimeUtil.formatElapsedTime(((SortShardObj)shardTimes.get((int)i)).mAggTime));
            int give = ((SortShardObj)shardTimes.get((int)i)).mIndex;
            int receive = ((SortShardObj)shardTimes.get((int)(shardTimes.size() - 1 - i))).mIndex;
            LogUtil.CLog.d("Giving from shard %s to shard %s", give, receive);
            int j = 0;
            while ((float)j < (float)allShards.get(give).size() * (0.2f / (float)(i + 1))) {
                IRemoteTest givetest = allShards.get(give).remove(0);
                allShards.get(receive).add(givetest);
                ++j;
            }
            ++i;
        }
    }

    private class SortShardObj
    implements Comparable<SortShardObj> {
        public final int mIndex;
        public final Long mAggTime;

        public SortShardObj(int index, long aggTime) {
            this.mIndex = index;
            this.mAggTime = aggTime;
        }

        @Override
        public int compareTo(SortShardObj obj) {
            return obj.mAggTime.compareTo(this.mAggTime);
        }
    }
}

