From 0ba0d3702d3bd99c86c7d93bca0e521c4c795b59 Mon Sep 17 00:00:00 2001 From: Sik Yoon Date: Thu, 6 Jun 2024 02:19:03 +0900 Subject: [PATCH] Implement linear regression --- src/strategy_team/mod.rs | 1 + .../indicators/linear_regression.rs | 80 +++++++++++++++++++ src/value_estimation_team/indicators/mod.rs | 1 + 3 files changed, 82 insertions(+) create mode 100644 src/value_estimation_team/indicators/linear_regression.rs diff --git a/src/strategy_team/mod.rs b/src/strategy_team/mod.rs index 429c13a..fdd9bb4 100644 --- a/src/strategy_team/mod.rs +++ b/src/strategy_team/mod.rs @@ -37,6 +37,7 @@ use crate::value_estimation_team::indicators::tema::{tema, TemaData}; use crate::value_estimation_team::indicators::wiliams_percent_r::{ wiliams_percent_r, WiliamsPercentR, }; +use crate::value_estimation_team::indicators::linear_regression::{LrData, linear_regression}; use crate::future::Position; use futures::future::try_join_all; use reqwest::{Client, ClientBuilder}; diff --git a/src/value_estimation_team/indicators/linear_regression.rs b/src/value_estimation_team/indicators/linear_regression.rs new file mode 100644 index 0000000..03b70ba --- /dev/null +++ b/src/value_estimation_team/indicators/linear_regression.rs @@ -0,0 +1,80 @@ +#![allow(unused)] +#![allow(warnings)] + +use super::HashMap; +use crate::database_control::*; +use crate::strategy_team::FilteredDataValue; +use crate::value_estimation_team::datapoints::price_data::RealtimePriceData; +use futures::future::try_join_all; +use serde::Deserialize; +use sqlx::FromRow; +use std::sync::Arc; +use tokio::{fs::*, io::AsyncWriteExt, sync::Mutex, time::*}; + +#[derive(Clone, Debug)] +pub struct LrData { + pub lr_value: f64, // linear regression value + pub close_time: i64, +} +impl LrData { + fn new() -> LrData { + let a = LrData { + lr_value: 0.0, + close_time: 0, + }; + a + } +} + +// Binance MA (closeprice) +pub async fn linear_regression( + length: usize, + offset: usize, + input_rt_data: &HashMap>, + filtered_symbols: &HashMap, +) -> Result>, Box> { + if filtered_symbols.is_empty() { + Err("Err")?; + } + + let mut lr_data_wrapper: HashMap> = HashMap::new(); + let mut lr_data_wrapper_arc = Arc::new(Mutex::new(lr_data_wrapper)); + + let mut task_vec = Vec::new(); + for (symbol, filtered_data) in filtered_symbols { + if let Some(vec) = input_rt_data.get(symbol) { + let lr_data_wrapper_arc_c = Arc::clone(&lr_data_wrapper_arc); + let symbol_c = symbol.clone(); + let rt_price_data = vec.clone(); + if rt_price_data.len() >= length { + task_vec.push(tokio::spawn(async move { + // Calculate prediction of linear regression + let mut lr_data_vec: Vec = Vec::new(); + + for window in rt_price_data.windows(length) { + let mut lr_data = LrData::new(); + let x: Vec = (0..length).map(|x| x as f64).collect(); + let y: Vec = window.iter().map(|x| x.close_price).collect(); + + let x_mean: f64 = x.iter().sum::() / x.len() as f64; + let y_mean: f64 = y.iter().sum::() / y.len() as f64; + + let numerator: f64 = x.iter().zip(y.iter()).map(|(x_i, y_i)| (x_i - x_mean) * (y_i - y_mean)).sum(); + let denominator: f64 = x.iter().map(|x_i| (x_i - x_mean).powi(2)).sum(); + + let slope = numerator / denominator; + let intercept = y_mean - slope * x_mean; + + let linreg = intercept + slope * (length as f64 - 1.0 - offset as f64); + lr_data.lr_value = linreg; + lr_data.close_time = window.last().unwrap().close_time; + lr_data_vec.push(lr_data.clone()); + } + })); + } + } + } + try_join_all(task_vec).await?; + let a = lr_data_wrapper_arc.lock().await.to_owned(); + Ok(a) +} diff --git a/src/value_estimation_team/indicators/mod.rs b/src/value_estimation_team/indicators/mod.rs index ba2252e..019e1ff 100644 --- a/src/value_estimation_team/indicators/mod.rs +++ b/src/value_estimation_team/indicators/mod.rs @@ -10,6 +10,7 @@ pub mod stoch_rsi; pub mod supertrend; pub mod tema; pub mod wiliams_percent_r; +pub mod linear_regression; use crate::strategy_team::FilteredDataValue; use crate::value_estimation_team::datapoints::price_data::RealtimePriceData;