/*
 * Decompiled with CFR 0.152.
 */
package org.thingsboard.trendz.service.model.prediction.accuracy;

import com.google.common.collect.Sets;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.DoubleSummaryStatistics;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.thingsboard.trendz.domain.base.Point;
import org.thingsboard.trendz.domain.base.TimeSeries;
import org.thingsboard.trendz.domain.definition.view.FieldAggregation;
import org.thingsboard.trendz.exception.TrendzException;
import org.thingsboard.trendz.service.model.prediction.accuracy.AccuracyData;
import org.thingsboard.trendz.service.model.prediction.accuracy.AccuracyMethod;
import org.thingsboard.trendz.service.model.prediction.accuracy.AccuracyMethodConfidenceBand;
import org.thingsboard.trendz.service.model.prediction.accuracy.AccuracyMethodConfidenceBandData;
import org.thingsboard.trendz.service.model.prediction.accuracy.AccuracyMethodData;
import org.thingsboard.trendz.service.model.prediction.segment.SegmentData;
import org.thingsboard.trendz.tools.DateTimeUtils;

@Service
public class AccuracyMethodConfidenceBand
implements AccuracyMethod {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(AccuracyMethodConfidenceBand.class);

    public AccuracyData calculate(Set<SegmentData> segmentData, AccuracyMethodData d, ChronoUnit timeUnit, ZoneId zoneId, boolean autoDefineSettings) {
        log.info("Accuracy method {}: segment data count = {}, time unit = {}, zone = {}", new Object[]{d.getType(), segmentData.size(), timeUnit, zoneId});
        AccuracyMethodConfidenceBandData methodData = (AccuracyMethodConfidenceBandData)d;
        double valueLevelMin = methodData.getValueLevelMin();
        double valueLevelMax = methodData.getValueLevelMax();
        double percentile = methodData.getPercentile();
        Set accuracySet = segmentData.stream().map(data -> {
            TimeSeries historicalTelemetry = data.getHistoricalTelemetry();
            TimeSeries predictionTelemetry = data.getPredictionTelemetry();
            return this.makeAccuracy(historicalTelemetry, predictionTelemetry, valueLevelMin, valueLevelMax);
        }).collect(Collectors.toSet());
        Map pointMap = this.createPointMap(accuracySet, timeUnit, zoneId);
        Map statisticsMap = this.aggregateStatistic(pointMap, percentile);
        TimeSeries max = this.getAggregation(statisticsMap, FieldAggregation.MAX);
        TimeSeries avg = this.getAggregation(statisticsMap, FieldAggregation.AVG);
        TimeSeries min = this.getAggregation(statisticsMap, FieldAggregation.MIN);
        log.debug("Final accuracy result: MAX = {}", (Object)max.getPoints());
        log.debug("Final accuracy result: AVG = {}", (Object)avg.getPoints());
        log.debug("Final accuracy result: MIN = {}", (Object)min.getPoints());
        return new AccuracyData(Map.of("max", max, "avg", avg, "min", min));
    }

    public AccuracyMethodData autoDefineSettings(AccuracyMethodData d, Set<SegmentData> segmentData) {
        AccuracyMethodConfidenceBandData methodData = new AccuracyMethodConfidenceBandData((AccuracyMethodConfidenceBandData)d);
        DoubleSummaryStatistics statistics = segmentData.stream().map(SegmentData::getHistoricalTelemetry).map(TimeSeries::getPoints).flatMap(Collection::stream).mapToDouble(Point::getValue).summaryStatistics();
        methodData.setPercentile(80.0);
        methodData.setValueLevelMin(statistics.getMin());
        methodData.setValueLevelMax(statistics.getMax());
        return methodData;
    }

    public void validate(AccuracyMethodData d) {
        AccuracyMethodConfidenceBandData methodData = (AccuracyMethodConfidenceBandData)d;
        double valueLevelMin = methodData.getValueLevelMin();
        double valueLevelMax = methodData.getValueLevelMax();
        double percentile = methodData.getPercentile();
        log.debug("Used parameters, min level value = {}, max level value = {}, percentile = {}", new Object[]{valueLevelMin, valueLevelMax, percentile});
        if (valueLevelMax < valueLevelMin) {
            throw new TrendzException("Invalid level values");
        }
        if (percentile < 0.0 || 100.0 < percentile) {
            throw new TrendzException("Invalid percentile value: " + percentile);
        }
    }

    private TimeSeries makeAccuracy(TimeSeries historicalTelemetry, TimeSeries predictionTelemetry, double valueLevelMin, double valueLevelMax) {
        Map<Long, Double> historicalDataMap = historicalTelemetry.getPoints().stream().collect(Collectors.toMap(Point::getTs, Point::getValue));
        Map<Long, Double> predictionDataMap = predictionTelemetry.getPoints().stream().collect(Collectors.toMap(Point::getTs, Point::getValue));
        Sets.SetView allTs = Sets.intersection(historicalDataMap.keySet(), predictionDataMap.keySet());
        ArrayList<Point> result = new ArrayList<Point>();
        Iterator iterator = allTs.iterator();
        while (iterator.hasNext()) {
            long ts = (Long)iterator.next();
            double historicalValue = historicalDataMap.get(ts);
            double predictedValue = predictionDataMap.get(ts);
            double accuracyValue = this.basicAccuracyFunction(historicalValue, predictedValue, valueLevelMin, valueLevelMax);
            Point point = new Point(ts, accuracyValue);
            result.add(point);
        }
        Collections.sort(result);
        TimeSeries accuracy = new TimeSeries(result);
        log.debug("Accuracy result: historical = {}", (Object)historicalTelemetry.getPoints());
        log.debug("Accuracy result: prediction = {}", (Object)predictionTelemetry.getPoints());
        log.debug("Accuracy result:   accuracy = {}", result);
        return accuracy;
    }

    private double basicAccuracyFunction(double original, double predicted, double valueLevelMin, double valueLevelMax) {
        double distance = Math.abs(original - predicted);
        if (distance == 0.0) {
            return 100.0;
        }
        double relativeRangeSize = valueLevelMax - valueLevelMin;
        if (relativeRangeSize == 0.0) {
            throw new TrendzException("relativeRangeSize is zero!");
        }
        if (relativeRangeSize < distance) {
            return 0.0;
        }
        return 100.0 - 100.0 * distance / relativeRangeSize;
    }

    private Map<Long, List<Double>> createPointMap(Set<TimeSeries> timeSeriesSet, ChronoUnit timeUnit, ZoneId zoneId) {
        HashMap<Long, List<Double>> pointMap = new HashMap<Long, List<Double>>();
        for (TimeSeries timeseries : timeSeriesSet) {
            if (timeseries.isEmpty()) continue;
            long minTs = timeseries.getPoints().stream().mapToLong(Point::getTs).min().orElseThrow();
            ZonedDateTime minDate = DateTimeUtils.fromTs((long)minTs, (ZoneId)zoneId);
            for (Point currentPoint : timeseries.getPoints()) {
                long ts = currentPoint.getTs();
                ZonedDateTime date = DateTimeUtils.fromTs((long)ts, (ZoneId)zoneId);
                long distance = timeUnit.between(minDate, date);
                double value = currentPoint.getValue();
                pointMap.computeIfAbsent(distance, key -> new ArrayList()).add(value);
            }
        }
        return pointMap;
    }

    private Map<Long, DoubleSummaryStatistics> aggregateStatistic(Map<Long, List<Double>> pointMap, double percentile) {
        HashMap<Long, DoubleSummaryStatistics> statisticsMap = new HashMap<Long, DoubleSummaryStatistics>();
        for (long distance : pointMap.keySet()) {
            DoubleSummaryStatistics statistics;
            List<Double> values = pointMap.get(distance);
            values.sort(Double::compareTo);
            if (values.size() == 1) {
                double singleValue = values.get(0);
                statistics = new DoubleSummaryStatistics();
                statistics.accept(singleValue);
            } else {
                int percentileIndex = (int)Math.max(0.0, Math.ceil((100.0 - percentile) / 100.0 * (double)values.size()) - 1.0);
                percentileIndex = Math.min(percentileIndex, values.size() - 1);
                List<Double> percentileValues = values.subList(percentileIndex, values.size());
                statistics = percentileValues.stream().collect(Collectors.summarizingDouble(Double::doubleValue));
            }
            statisticsMap.put(distance, statistics);
            log.debug("Accuracy aggregation result: values = {}, statistic = {}", values, (Object)statistics);
        }
        return statisticsMap;
    }

    private TimeSeries getAggregation(Map<Long, DoubleSummaryStatistics> statisticsMap, FieldAggregation aggregation) {
        ArrayList<Point> resultPoints = new ArrayList<Point>();
        for (long distance : statisticsMap.keySet()) {
            DoubleSummaryStatistics statistics = statisticsMap.get(distance);
            double value = switch (1.$SwitchMap$org$thingsboard$trendz$domain$definition$view$FieldAggregation[aggregation.ordinal()]) {
                case 1 -> statistics.getMin();
                case 2 -> statistics.getMax();
                case 3 -> statistics.getAverage();
                default -> throw new RuntimeException("Unsupported aggregation: " + String.valueOf(aggregation));
            };
            resultPoints.add(new Point(distance, value));
        }
        resultPoints.sort(Comparator.comparingLong(Point::getTs));
        return new TimeSeries(resultPoints);
    }
}

