/*
 * Decompiled with CFR 0.152.
 */
package org.thingsboard.trendz.service.model.prediction.methods;

import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.thingsboard.trendz.service.model.prediction.methods.LinearRegressionMethodParameters;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethod;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethodParameters;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethodType;

@Service
public class LinearRegressionMethod
implements PredictionMethod {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(LinearRegressionMethod.class);

    public PredictionMethodType getType() {
        return PredictionMethodType.LINEAR_REGRESSION;
    }

    public String getMethodDefinition(PredictionMethodParameters methodParameters) {
        LinearRegressionMethodParameters parameters = (LinearRegressionMethodParameters)methodParameters;
        return "#####################################################\n# Prediction Method: Linear Regression\n\nfrom sklearn.preprocessing import MinMaxScaler, StandardScaler\nfrom sklearn.linear_model import LinearRegression\nimport pickle\nimport numpy as np\n\nclass CustomModel(IModel):\n\n    def __init__(self, value_transformer=None, timestamp_transformer=None):\n        self.model = None\n        self.timestamp_transformer = timestamp_transformer if timestamp_transformer else StandardScaler()\n        self.value_transformer = value_transformer if value_transformer else MinMaxScaler()\n        self.sum_x = 0\n        self.sum_y = 0\n        self.sum_xy = 0\n        self.sum_xx = 0\n        self.n = 0\n\n    def init_state(self):\n        self.model = LinearRegression()\n\n    def train(self, data, additionalData=None):\n        # Prepare\n        ts = np.array([point[0] for point in data]).reshape(-1, 1)\n        values = np.array([point[1] for point in data]).reshape(-1, 1)\n        self.timestamp_transformer.fit(ts)\n        self.value_transformer.fit(values)\n        ts_scaled = self.timestamp_transformer.transform(ts)\n        values_scaled = self.value_transformer.transform(values)\n\n        # Fit\n        self.sum_x = np.sum(ts_scaled)\n        self.sum_y = np.sum(values_scaled)\n        self.sum_xy = np.sum(ts_scaled * values_scaled)\n        self.sum_xx = np.sum(ts_scaled ** 2)\n        self.n = len(ts_scaled)\n\n        self.model.fit(ts_scaled, values_scaled)\n\n    def partial_fit(self, data, additionalData=None):\n        # Prepare\n        ts = np.array([point[0] for point in data]).reshape(-1, 1)\n        values = np.array([point[1] for point in data]).reshape(-1, 1)\n        # self.timestamp_transformer.partial_fit(ts)\n        # self.value_transformer.partial_fit(values)\n        ts_scaled = self.timestamp_transformer.transform(ts)\n        values_scaled = self.value_transformer.transform(values)\n\n        # Fit\n        self.sum_x += np.sum(ts_scaled)\n        self.sum_y += np.sum(values_scaled)\n        self.sum_xy += np.sum(ts_scaled * values_scaled)\n        self.sum_xx += np.sum(ts_scaled ** 2)\n        self.n += len(ts_scaled)\n\n        if self.n > 0:\n            mean_x = self.sum_x / self.n\n            mean_y = self.sum_y / self.n\n            slope = (self.sum_xy - self.n * mean_x * mean_y) / (self.sum_xx - self.n * mean_x ** 2)\n            intercept = mean_y - slope * mean_x\n            self.model.coef_ = np.array([[slope]])\n            self.model.intercept_ = np.array([intercept])\n\n    def predict(self, timestamps):\n        ts = np.array(timestamps).reshape(-1, 1)\n        ts_scaled = self.timestamp_transformer.transform(ts)\n        predictions_scaled = self.model.predict(ts_scaled)\n        predictions = self.value_transformer.inverse_transform(predictions_scaled)\n        return list(zip(timestamps, predictions.flatten()))\n\n    def save_state(self, file_path):\n        with open(file_path, 'wb') as file:\n            state = {\n                'model': self.model,\n                'value_transformer': self.value_transformer,\n                'timestamp_transformer': self.timestamp_transformer,\n                'sum_x': self.sum_x,\n                'sum_y': self.sum_y,\n                'sum_xy': self.sum_xy,\n                'sum_xx': self.sum_xx,\n                'n': self.n\n            }\n            pickle.dump(state, file)\n\n    def load_state(self, file_path):\n        with open(file_path, 'rb') as file:\n            state = pickle.load(file)\n            self.model = state['model']\n            self.value_transformer = state['value_transformer']\n            self.timestamp_transformer = state['timestamp_transformer']\n            self.sum_x = state['sum_x']\n            self.sum_y = state['sum_y']\n            self.sum_xy = state['sum_xy']\n            self.sum_xx = state['sum_xx']\n            self.n = state['n']\n\n    def name(self):\n        return \"LinearRegressionModel\"\n\n#####################################################\n";
    }
}

