/*
 * Decompiled with CFR 0.152.
 */
package org.thingsboard.trendz.service.model.prediction.handler;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.Generated;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.complex.Complex;
import org.apache.commons.math3.fitting.PolynomialCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoints;
import org.apache.commons.math3.transform.DftNormalization;
import org.apache.commons.math3.transform.FastFourierTransformer;
import org.apache.commons.math3.transform.TransformType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.thingsboard.trendz.domain.base.Point;
import org.thingsboard.trendz.domain.base.TimeSeries;
import org.thingsboard.trendz.security.entity.JwtSecurityUser;
import org.thingsboard.trendz.service.model.prediction.handler.FftPoint;
import org.thingsboard.trendz.service.model.prediction.handler.FftState;
import org.thingsboard.trendz.service.model.prediction.handler.PredictionModelHandlerInterface;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethodParameters;
import org.thingsboard.trendz.tools.json.JsonUtils;
import reactor.core.publisher.Mono;

/**
 * Prediction model handler that forecasts a time series in pure Java.
 *
 * <p>Fitting removes a least-squares linear trend from the (power-of-two trimmed) training
 * window and stores the full FFT spectrum of the detrended signal. Prediction re-synthesizes
 * the signal from the strongest harmonics of that spectrum and adds the linear trend back.
 *
 * <p>The model state is exchanged as a JSON-serialized {@link FftState}. It holds:
 * <ul>
 *   <li>the iteration counter;</li>
 *   <li>the trend coefficients {@code [intercept, slope]};</li>
 *   <li>the FFT coefficients;</li>
 *   <li>up to {@code MAX_STORED_POINTS_NUMBER} of the most recent raw points, which
 *       {@link #partialFitModel} merges with new data before refitting.</li>
 * </ul>
 */
@Service
public class JavaPredictionModelHandler
implements PredictionModelHandlerInterface {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(JavaPredictionModelHandler.class);
    /** Maximum number of most recent raw points kept in the model state. */
    private static final int MAX_STORED_POINTS_NUMBER = 2048;
    /** Fraction of the spectrum length used as the number of harmonics in a prediction. */
    private static final double FFT_HARMONIC_REDUCTION_COEFFICIENT = 0.1;
    /** Upper bound on the number of harmonics used in a prediction. */
    private static final int MAX_HARMONICS = 100;

    /**
     * Creates the initial (untrained) model state.
     *
     * @param user       the requesting user (not used by this implementation)
     * @param parameters method parameters (not used by this implementation)
     * @return the JSON form of {@link FftState#getDefault()}
     */
    public Mono<String> createModelState(JwtSecurityUser user, PredictionMethodParameters parameters) {
        return Mono.just(JsonUtils.toJson(FftState.getDefault()));
    }

    /**
     * Fits the model from scratch on {@code inputTelemetry}.
     *
     * <p>Only the most recent power-of-two-sized suffix of the values is used for the trend fit
     * and the FFT. The last {@code MAX_STORED_POINTS_NUMBER} raw points are kept in the new state.
     *
     * @param user                   the requesting user (not used by this implementation)
     * @param parameters             method parameters (not used by this implementation)
     * @param state                  JSON-serialized previous {@link FftState}; only its iteration is read
     * @param inputTelemetry         training series
     * @param additionalTelemetryMap extra series (not used by this implementation)
     * @return the JSON-serialized new {@link FftState} with the iteration incremented by one
     * @throws IllegalArgumentException if {@code inputTelemetry} contains no points
     */
    public Mono<String> fitModel(JwtSecurityUser user, PredictionMethodParameters parameters, String state, TimeSeries inputTelemetry, Map<UUID, TimeSeries> additionalTelemetryMap) {
        FftState oldFftState = readState(state);
        List<Point> telemetryPoints = inputTelemetry.getPoints();
        double[] dataRaw = telemetryPoints.stream().mapToDouble(Point::getValue).toArray();
        double[] data = this.trimToPowerOfTwo(dataRaw);
        if (data.length == 0) {
            // Without this check the curve fitter / FFT fail later with an obscure error.
            throw new IllegalArgumentException("Cannot fit FFT prediction model: input telemetry has no points");
        }
        double[] lrCoefficients = fitLinearTrend(data);
        double lrIntercept = lrCoefficients[0];
        double lrSlope = lrCoefficients[1];
        double[] detrended = new double[data.length];
        for (int i = 0; i < data.length; ++i) {
            detrended[i] = data[i] - (lrSlope * (double) i + lrIntercept);
        }
        // data.length is a power of two, as required by FastFourierTransformer.
        FastFourierTransformer transformer = new FastFourierTransformer(DftNormalization.STANDARD);
        Complex[] spectrum = transformer.transform(detrended, TransformType.FORWARD);
        List<FftPoint> fftCoefficients = Arrays.stream(spectrum)
                .map(c -> new FftPoint(c.getReal(), c.getImaginary()))
                .toList();
        int from = Math.max(0, telemetryPoints.size() - MAX_STORED_POINTS_NUMBER);
        // Copy so the state does not hold a view backed by the caller's list.
        List<Point> lastPoints = new ArrayList<>(telemetryPoints.subList(from, telemetryPoints.size()));
        FftState newFftState = new FftState(oldFftState.getIteration() + 1, lrCoefficients, fftCoefficients, lastPoints);
        return Mono.just(JsonUtils.toJson(newFftState));
    }

    /**
     * Refits the model on the points stored in {@code state} plus the newer points of
     * {@code inputTelemetry}.
     *
     * @param user                   the requesting user (forwarded to {@link #fitModel})
     * @param parameters             method parameters (forwarded to {@link #fitModel})
     * @param state                  JSON-serialized current {@link FftState}
     * @param inputTelemetry         new points; only those later than the latest stored point are used
     * @param additionalTelemetryMap extra series (forwarded to {@link #fitModel})
     * @return the JSON-serialized refitted {@link FftState}
     */
    public Mono<String> partialFitModel(JwtSecurityUser user, PredictionMethodParameters parameters, String state, TimeSeries inputTelemetry, Map<UUID, TimeSeries> additionalTelemetryMap) {
        FftState fftState = readState(state);
        TimeSeries combinedTimeseries = JavaPredictionModelHandler.mergePoints(fftState.getLastPoints(), inputTelemetry.getPoints());
        return this.fitModel(user, parameters, state, combinedTimeseries, additionalTelemetryMap);
    }

    /**
     * Produces one predicted value per timestamp in {@code outputX}.
     *
     * <p>The strongest {@code min(size * FFT_HARMONIC_REDUCTION_COEFFICIENT, MAX_HARMONICS)}
     * harmonics are summed as cosines. The stored linear trend is added back.
     *
     * <p>NOTE(review): both the harmonics and the trend are evaluated at positions
     * {@code 0..outputX.size()-1}, i.e. starting at the first point of the training window.
     * Confirm that callers pass {@code outputX} aligned with that window rather than only
     * future timestamps.
     *
     * @param user       the requesting user (not used by this implementation)
     * @param parameters method parameters (not used by this implementation)
     * @param state      JSON-serialized fitted {@link FftState}
     * @param outputX    timestamps of the points to produce
     * @return the predicted series, one point per element of {@code outputX}
     */
    public Mono<TimeSeries> predict(JwtSecurityUser user, PredictionMethodParameters parameters, String state, List<Long> outputX) {
        FftState fftState = readState(state);
        double[] lrCoefficients = fftState.getLrCoefficients();
        double lrIntercept = lrCoefficients[0];
        double lrSlope = lrCoefficients[1];
        Complex[] spectrum = fftState.getFftCoefficients().stream()
                .map(fftPoint -> new Complex(fftPoint.getReal(), fftPoint.getImaginary()))
                .toArray(Complex[]::new);
        int learnTsSize = spectrum.length;
        int harmonics = Math.min((int) ((double) learnTsSize * FFT_HARMONIC_REDUCTION_COEFFICIENT), MAX_HARMONICS);
        int[] indexes = this.getSortedIndexes(spectrum, harmonics);
        double[] freq = this.fftFreq(learnTsSize);
        double[] restored = new double[outputX.size()];
        for (int k : indexes) {
            double frequency = freq[k];
            double amplitude = spectrum[k].abs() / (double) learnTsSize;
            double phase = Math.atan2(spectrum[k].getImaginary(), spectrum[k].getReal());
            for (int j = 0; j < restored.length; ++j) {
                restored[j] += amplitude * Math.cos(Math.PI * 2 * frequency * (double) j + phase);
            }
        }
        List<Point> prediction = new ArrayList<>(restored.length);
        for (int i = 0; i < restored.length; ++i) {
            double v = restored[i] + (lrIntercept + lrSlope * (double) i);
            prediction.add(new Point(outputX.get(i).longValue(), v));
        }
        return Mono.just(new TimeSeries(prediction));
    }

    /** Deserializes a JSON model state into an {@link FftState}. */
    private static FftState readState(String state) {
        return (FftState) JsonUtils.fromJson(state, FftState.class);
    }

    /**
     * Fits {@code value = intercept + slope * index} by least squares.
     *
     * @return {@code [intercept, slope]}
     */
    private static double[] fitLinearTrend(double[] data) {
        WeightedObservedPoints observations = new WeightedObservedPoints();
        for (int i = 0; i < data.length; ++i) {
            observations.add((double) i, data[i]);
        }
        return PolynomialCurveFitter.create(1).fit(observations.toList());
    }

    /**
     * Keeps the longest suffix of {@code original} whose length is a power of two.
     *
     * <p>Returns the same array when its length already is a power of two (or is zero).
     * {@link Integer#highestOneBit} avoids the overflow that a doubling loop hits above 2^30.
     */
    private double[] trimToPowerOfTwo(double[] original) {
        int size = Integer.highestOneBit(original.length);
        if (size == original.length) {
            return original;
        }
        return Arrays.copyOfRange(original, original.length - size, original.length);
    }

    /**
     * Concatenates {@code oldPoints} with the points of {@code newPoints} that are strictly
     * later than the latest old timestamp.
     *
     * <p>When there are no stored points (e.g. a never-fitted state), all of {@code newPoints}
     * are used instead of failing.
     */
    private static TimeSeries mergePoints(List<Point> oldPoints, List<Point> newPoints) {
        if (oldPoints == null || oldPoints.isEmpty()) {
            return new TimeSeries(new ArrayList<>(newPoints));
        }
        long latestOldTs = oldPoints.stream().mapToLong(Point::getTs).max().getAsLong();
        List<Point> unitedPoints = Stream.concat(
                        oldPoints.stream(),
                        newPoints.stream().filter(p -> latestOldTs < p.getTs()))
                .collect(Collectors.toList());
        return new TimeSeries(unitedPoints);
    }

    /**
     * Returns the indexes of the {@code harmonics} coefficients with the largest magnitude,
     * in descending magnitude order. Ties keep ascending index order, because object-array
     * sorting is stable.
     */
    private int[] getSortedIndexes(Complex[] fftCoefficients, int harmonics) {
        int n = fftCoefficients.length;
        double[] magnitudes = new double[n];
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; ++i) {
            magnitudes[i] = fftCoefficients[i].abs();
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Double.compare(magnitudes[b], magnitudes[a]));
        int[] result = new int[Math.min(harmonics, n)];
        for (int i = 0; i < result.length; ++i) {
            result[i] = order[i];
        }
        return result;
    }

    /**
     * Sample frequencies (cycles per sample) for an FFT of length {@code n}.
     *
     * <p>Indexes up to {@code n / 2} map to non-negative frequencies; the rest map to negative ones.
     */
    private double[] fftFreq(int n) {
        double[] frequencies = new double[n];
        double deltaF = 1.0 / (double) n;
        for (int i = 0; i < n; ++i) {
            frequencies[i] = i <= n / 2 ? (double) i * deltaF : (double) (i - n) * deltaF;
        }
        return frequencies;
    }
}

