| import io |
| import numpy as np |
| import gradio as gr |
| import yfinance as yf |
| import pandas as pd |
|
|
|
|
| from PIL import Image |
| from datetime import datetime |
| import plotly.express as px |
| import matplotlib.pyplot as plt |
| import plotly.graph_objects as go |
|
|
|
|
| from prophet import Prophet |
|
|
| import torch |
| import timesfm |
|
|
| |
def _load_timesfm_model():
    """Load and compile the TimesFM 2.5 checkpoint, or return None on failure.

    Any exception raised while downloading or compiling the model is printed
    and swallowed so the rest of the app (Prophet, technical analysis) can
    still start; ``timesfm_forecast`` checks for ``None`` before using it.
    """
    try:
        model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")
        model.compile(timesfm.ForecastConfig(max_context=1024, max_horizon=256))
    except Exception as e:
        print(f"Model load error: {e}")
        return None
    return model


# Module-level model handle shared by every request; None means unavailable.
tfm_model = _load_timesfm_model()
|
|
class StockDataFetcher:
    """Handles fetching and preprocessing stock data."""

    # OHLCV columns returned by ``fetch_stock_data``, in this order.
    COLUMNS = ('Close', 'High', 'Low', 'Open', 'Volume')

    @staticmethod
    def _normalize_columns(stock_data):
        """Flatten yfinance's column index and keep only the OHLCV columns.

        yfinance may return ``(Price, Ticker)`` MultiIndex columns; the ticker
        level (level 1) is dropped. Columns are then selected *by name* rather
        than relabelled by position, so an extra column such as ``Adj Close``
        or a different column order cannot shift labels onto the wrong data.

        Args:
            stock_data: DataFrame as returned by ``yf.download``. Not mutated.

        Returns:
            A new DataFrame with exactly ``StockDataFetcher.COLUMNS``.

        Raises:
            ValueError: if several tickers were downloaded, or a required
                column is missing.
        """
        if isinstance(stock_data.columns, pd.MultiIndex):
            stock_data = stock_data.copy()
            stock_data.columns = stock_data.columns.droplevel(level=1)

        # Duplicated names after dropping the ticker level mean the download
        # covered more than one symbol; the app only handles one at a time.
        if stock_data.columns.duplicated().any():
            raise ValueError("Expected price data for a single ticker.")

        missing = [col for col in StockDataFetcher.COLUMNS if col not in stock_data.columns]
        if missing:
            raise ValueError(f"Downloaded data is missing columns: {missing}")

        return stock_data[list(StockDataFetcher.COLUMNS)].copy()

    @staticmethod
    def fetch_stock_data(ticker, start_date, end_date):
        """Download daily prices for ``ticker`` and normalize the columns.

        Args:
            ticker: Symbol passed to ``yf.download``.
            start_date: Start of the range, forwarded to ``yf.download``.
            end_date: End of the range, forwarded to ``yf.download``.

        Returns:
            DataFrame (index as returned by yfinance) with columns
            ``Close, High, Low, Open, Volume``.

        Raises:
            ValueError: if no rows were returned or the columns are unusable.
        """
        stock_data = yf.download(ticker, start=start_date, end=end_date)

        if stock_data is None or stock_data.empty:
            raise ValueError(
                f"No price data returned for {ticker!r} between {start_date} and {end_date}."
            )

        return StockDataFetcher._normalize_columns(stock_data)
|
|
| |
|
|
def timesfm_forecast(ticker, start_date, end_date):
    """Forecast the next 30 trading days of closing prices with TimesFM.

    Args:
        ticker: Symbol passed to ``StockDataFetcher.fetch_stock_data``.
        start_date: Start of the history window, forwarded to the fetcher.
        end_date: End of the history window, forwarded to the fetcher.

    Returns:
        A Plotly figure with the historical closes and the 30-step point
        forecast.

    Raises:
        gr.Error: If the model did not load or any step fails. Gradio shows
            the message to the user; a ``gr.Plot`` output cannot render a
            plain error string.
    """
    horizon = 30

    # Fail fast: there is no point downloading prices without a model.
    if tfm_model is None:
        raise gr.Error("Model failed to load.")

    try:
        stock_data = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date).reset_index()
        df = stock_data[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
        df['ds'] = pd.to_datetime(df['ds'])
        if df.empty:
            raise ValueError(f"No closing prices for {ticker} in the requested range.")

        input_values = df['y'].to_numpy(dtype=np.float32)
        point_forecast, _ = tfm_model.forecast(horizon=horizon, inputs=[input_values])

        # TimesFM predicts the next `horizon` points of the input series, which
        # has one row per trading session (daily bars skip weekends), so place
        # the forecast on business days rather than consecutive calendar days.
        # Approximate: exchange holidays are not skipped.
        forecast_dates = pd.bdate_range(start=df['ds'].max() + pd.Timedelta(days=1), periods=horizon)
        forecast_df = pd.DataFrame({'ds': forecast_dates, 'forecast': point_forecast[0]})

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=df["ds"], y=df["y"],
                                 mode="lines", name="Actual Prices",
                                 line=dict(color="#00FFFF", width=2)))
        fig.add_trace(go.Scatter(x=forecast_df["ds"], y=forecast_df["forecast"],
                                 mode="lines", name="Forecasted Prices",
                                 line=dict(color="#FF00FF", width=2, dash="dash")))

        fig.update_layout(
            title=f"{ticker} Stock Price Forecast (TimesFM)",
            xaxis_title="Date",
            yaxis_title="Price",
            template="plotly_dark",
            hovermode="x unified",
            legend=dict(bgcolor="rgba(0,0,0,0.8)", bordercolor="white", borderwidth=1),
            plot_bgcolor="#111111",
            paper_bgcolor="#111111",
            font=dict(color="white", size=12),
            margin=dict(l=40, r=40, t=50, b=40),
        )
        fig.update_xaxes(showgrid=True, gridcolor="rgba(255,255,255,0.1)")
        fig.update_yaxes(showgrid=True, gridcolor="rgba(255,255,255,0.1)")

        return fig

    except Exception as e:
        raise gr.Error(f"TimesFM forecast failed: {e}") from e
|
|
| |
def _prophet_dark_layout(fig, title, **extra_layout):
    """Apply the dark theme shared by every figure on the Prophet tab."""
    fig.update_layout(
        title=title,
        plot_bgcolor='#111111',
        paper_bgcolor='#111111',
        font=dict(color='white', size=12),
        margin=dict(l=40, r=40, t=50, b=40),
        xaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)"),
        yaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)"),
        legend=dict(bgcolor="rgba(0,0,0,0.8)", bordercolor="white", borderwidth=1),
        **extra_layout,
    )
    return fig


def _prophet_history_figure(history, forecast, ticker):
    """Plot the observed closes together with Prophet's fitted trend line."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=history['ds'],
        y=history['y'],
        mode='lines',
        name='Actual Price',
        line=dict(color='#36D7B7', width=2),
    ))
    fig.add_trace(go.Scatter(
        x=forecast['ds'],
        y=forecast['trend'],
        mode='lines',
        name='Trend',
        line=dict(color='#FF6B6B', width=2),
    ))
    return _prophet_dark_layout(fig, f'{ticker} Price and Trend')


def _prophet_forecast_figure(forecast, ticker, tail=40):
    """Plot the last ``tail`` rows of ``yhat`` with its yhat_lower/yhat_upper band."""
    recent = forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(tail)
    dates = recent['ds'].tolist()

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=recent['ds'],
        y=recent['yhat'],
        mode='lines',
        name='Forecast',
        line=dict(color='#FF6B6B', width=2),
    ))
    # Closed polygon: upper edge left-to-right, then lower edge right-to-left.
    fig.add_trace(go.Scatter(
        x=dates + dates[::-1],
        y=recent["yhat_upper"].tolist() + recent["yhat_lower"].tolist()[::-1],
        fill="toself",
        fillcolor="rgba(78, 205, 196, 0.2)",
        line=dict(color="rgba(255,255,255,0)"),
        name="Confidence Interval",
    ))
    return _prophet_dark_layout(fig, f'{ticker} 30 Days Forecast (Prophet)')


def _prophet_components_figure(forecast, ticker):
    """Plot the yearly seasonal component over the last 365 rows, if present.

    The figure is left without traces when ``forecast`` has no ``yearly``
    column (i.e. Prophet did not fit a yearly seasonality for this history).
    """
    fig = go.Figure()
    if 'yearly' in forecast.columns:
        # iloc[-365:] already returns every row when there are fewer than 365.
        yearly_pattern = forecast.iloc[-365:]
        fig.add_trace(go.Scatter(
            x=yearly_pattern['ds'],
            y=yearly_pattern['yearly'],
            mode='lines',
            name='Yearly Pattern',
            line=dict(color='#4ECDC4', width=2),
        ))
    return _prophet_dark_layout(
        fig, f'{ticker} Forecast Components', xaxis_title='Date', yaxis_title='Value'
    )


def prophet_forecast(ticker, start_date, end_date):
    """Fit Prophet on daily closes and build the three Prophet-tab figures.

    Args:
        ticker: Symbol passed to ``StockDataFetcher.fetch_stock_data``.
        start_date: Start of the history window, forwarded to the fetcher.
        end_date: End of the history window, forwarded to the fetcher.

    Returns:
        ``(price_and_trend_fig, forecast_fig, components_fig)`` — Plotly
        figures for the recent-prices, 30-day-forecast and components plots.

    Raises:
        gr.Error: If fetching, fitting or plotting fails, so Gradio shows the
            message (``gr.Plot`` outputs cannot render plain error strings).
    """
    try:
        df = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date)
        history = df.reset_index()[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})

        # Prophet rejects timezone-aware ``ds`` values; drop the tz if present.
        history['ds'] = pd.to_datetime(history['ds'])
        if history['ds'].dt.tz is not None:
            history['ds'] = history['ds'].dt.tz_localize(None)

        m = Prophet()
        m.fit(history)
        future = m.make_future_dataframe(periods=30, freq='D')
        forecast = m.predict(future)

        return (
            _prophet_history_figure(history, forecast, ticker),
            _prophet_forecast_figure(forecast, ticker),
            _prophet_components_figure(forecast, ticker),
        )

    except Exception as e:
        raise gr.Error(f"Prophet forecast failed: {e}") from e
| |
| |
|
|
|
|
def smooth_moving_average(series: pd.Series, window: int) -> pd.Series:
    """Recursively smooth ``series`` (Wilder-style running average).

    The first ``window`` outputs all equal the mean of the first ``window``
    inputs; every later output is
    ``(previous_output * (window - 1) + current_input) / window``.

    A constant series of the overall mean is returned when ``window`` is not
    positive or the series is shorter than ``window``. Any NaNs left in the
    result are forward/back-filled, then replaced by the overall mean.
    """
    fallback = series.mean()
    if window <= 0 or len(series) < window:
        return pd.Series(fallback, index=series.index)

    values = series.to_numpy(dtype=float, na_value=np.nan)
    smoothed = np.empty(len(values))
    smoothed[:window] = series.iloc[:window].mean()
    for pos in range(window, len(values)):
        smoothed[pos] = (smoothed[pos - 1] * (window - 1) + values[pos]) / window

    out = pd.Series(smoothed, index=series.index)
    return out.ffill().bfill().fillna(fallback)
|
|
def calculate_rsi(close: pd.Series, window: int = 14) -> pd.Series:
    """Relative Strength Index of ``close`` on a 0-100 scale.

    Up-moves and down-moves are each smoothed with ``smooth_moving_average``.
    When the average loss is exactly zero the ratio is treated as infinite,
    giving RSI = 100 (this includes flat stretches with no gains either).
    Series no longer than ``window`` get a neutral 50 everywhere.
    """
    if len(close) <= window:
        return pd.Series(50.0, index=close.index)

    change = close.diff()
    # where(...) also maps the leading NaN from diff() to 0.0.
    up_moves = change.where(change > 0, 0.0)
    down_moves = -change.where(change < 0, 0.0)

    mean_up = smooth_moving_average(up_moves, window)
    mean_down = smooth_moving_average(down_moves, window)

    strength = np.where(mean_down != 0, mean_up / mean_down, np.inf)
    index_values = 100.0 - (100.0 / (1.0 + strength))

    result = pd.Series(index_values, index=close.index)
    result = result.replace([np.inf, -np.inf], np.nan)
    return result.ffill().bfill().fillna(50.0)
|
|
def calculate_stochastic(high: pd.Series, low: pd.Series, close: pd.Series, k_window=14, d_window=3):
    """Stochastic oscillator ``(%K, %D)``, both clipped/filled into 0-100.

    %K is where ``close`` sits inside the rolling ``k_window`` low-high range
    (a tiny epsilon keeps a zero-width range from dividing by zero); %D is
    the ``d_window`` rolling mean of %K. Series shorter than ``k_window``
    return neutral 50s for both lines.
    """
    if len(close) < k_window:
        return pd.Series(50.0, index=close.index), pd.Series(50.0, index=close.index)

    window_low = low.rolling(k_window, min_periods=1).min()
    window_high = high.rolling(k_window, min_periods=1).max()
    range_width = window_high - window_low + 1e-10

    percent_k = (((close - window_low) / range_width) * 100).clip(0, 100)
    percent_d = percent_k.rolling(d_window, min_periods=1).mean()

    return tuple(line.ffill().bfill().fillna(50.0) for line in (percent_k, percent_d))
|
|
def calculate_cci(high: pd.Series, low: pd.Series, close: pd.Series, window=20):
    """Commodity Channel Index (Lambert).

    ``CCI_t = (TP_t - SMA_t) / (0.015 * MD_t)`` where ``TP = (high+low+close)/3``,
    ``SMA_t`` is the mean of the last ``window`` TP values and ``MD_t`` is the
    mean absolute deviation of those *same* TP values from ``SMA_t``.
    Leading partial windows are allowed (``min_periods=1``), and a tiny
    epsilon keeps a zero deviation from dividing by zero.

    Returns a zero series when there are fewer than ``window`` rows.
    """
    if len(close) < window:
        return pd.Series(0.0, index=close.index)

    typical_price = (high + low + close) / 3.0
    rolling_tp = typical_price.rolling(window, min_periods=1)
    sma = rolling_tp.mean()
    # Each value in the window is compared with *that window's* mean. The
    # previous rolling mean of |TP_i - SMA_i| measured every bar against its
    # own SMA, which is not the CCI mean deviation.
    mean_deviation = rolling_tp.apply(
        lambda values: np.nanmean(np.abs(values - np.nanmean(values))), raw=True
    )
    cci = (typical_price - sma) / (0.015 * mean_deviation + 1e-10)
    return cci.ffill().bfill().fillna(0.0)
|
|
| |
def calculate_sma_robust(series: pd.Series, window: int) -> pd.Series:
    """Simple moving average with no NaN gaps.

    Uses full windows only, then forward/back-fills the warm-up rows and
    falls back to the overall mean. A non-positive ``window``, or one longer
    than the series, yields a constant series of the overall mean.
    """
    overall_mean = series.mean()
    if window <= 0 or len(series) < window:
        return pd.Series(overall_mean, index=series.index)
    rolled = series.rolling(window=window, min_periods=window).mean()
    return rolled.ffill().bfill().fillna(overall_mean)
|
|
def calculate_ema_robust(series: pd.Series, span: int) -> pd.Series:
    """Exponential moving average (``adjust=False``) with no NaN gaps.

    Values before ``span`` observations are available are back-filled from
    the first defined EMA; a non-positive ``span``, or one longer than the
    series, yields a constant series of the overall mean.
    """
    overall_mean = series.mean()
    if span <= 0 or len(series) < span:
        return pd.Series(overall_mean, index=series.index)
    smoothed = series.ewm(span=span, adjust=False, min_periods=span).mean()
    return smoothed.ffill().bfill().fillna(overall_mean)
|
|
def calculate_macd_robust(close: pd.Series):
    """Return ``(macd_line, signal_line)`` for ``close``.

    The MACD line is EMA(12) - EMA(26) and the signal line is EMA(9) of the
    MACD line, all computed with ``calculate_ema_robust``.
    """
    fast_ema = calculate_ema_robust(close, 12)
    slow_ema = calculate_ema_robust(close, 26)
    macd = fast_ema - slow_ema
    signal = calculate_ema_robust(macd, 9)
    return macd, signal
|
|
def calculate_bollinger_bands_robust(close: pd.Series, window=20, num_std=2.0):
    """Bollinger Bands ``(middle, upper, lower)`` with no NaN gaps.

    The middle band is ``calculate_sma_robust(close, window)``; the outer
    bands are ``middle ± num_std * rolling_std``. Warm-up rows (fewer than
    ``window`` observations) are back-filled from the first complete bands,
    mirroring how the middle band is filled. Series shorter than ``window``
    return the overall mean for all three bands.
    """
    if len(close) < window:
        mid = pd.Series(close.mean(), index=close.index)
        return mid, mid, mid

    sma = calculate_sma_robust(close, window)
    # Keep the warm-up std as NaN so the bands are back-filled below. Filling
    # it with ~0 instead collapsed both bands onto the middle band, so nearly
    # every early close registered as touching a band.
    std = close.rolling(window=window, min_periods=window).std()
    upper = sma + num_std * std
    lower = sma - num_std * std

    # If std is never defined (e.g. all-NaN input), fall back to the middle.
    return (
        sma.ffill().bfill(),
        upper.ffill().bfill().fillna(sma),
        lower.ffill().bfill().fillna(sma),
    )
|
|
| |
def generate_trading_signals(df: pd.DataFrame) -> pd.DataFrame:
    """
    Generates trading signals using strict thresholds to minimize false positives.
    Output columns match the expected names for the plotting functions.

    Args:
        df: Price frame with a 'Close' column. 'High'/'Low' are used when both
            are present; otherwise 'Close' stands in for them.

    Returns:
        A copy of ``df`` with integer columns RSI_Signal, BB_Signal,
        Stochastic_Signal and CCI_Signal (+1 buy, -1 sell, 0 none) and
        Combined_Signal, the row-wise sum of those four.
    """
    df = df.copy()
    close = df['Close']
    has_hl = 'High' in df.columns and 'Low' in df.columns
    high = df['High'] if has_hl else close
    low = df['Low'] if has_hl else close

    rsi = calculate_rsi(close, window=14)
    stoch_k, stoch_d = calculate_stochastic(high, low, close, k_window=14, d_window=3)
    cci = calculate_cci(high, low, close, window=20)
    _, bb_upper, bb_lower = calculate_bollinger_bands_robust(close, window=20, num_std=3.0)

    # (buy_mask, sell_mask) per output column, using deliberately extreme thresholds.
    rules = {
        'RSI_Signal': (rsi < 20, rsi > 80),
        'BB_Signal': (close <= bb_lower, close >= bb_upper),
        'Stochastic_Signal': ((stoch_k < 5) & (stoch_d < 5), (stoch_k > 95) & (stoch_d > 95)),
        'CCI_Signal': (cci < -250, cci > 250),
    }

    for column in rules:
        df[column] = 0
    for column, (buy_mask, sell_mask) in rules.items():
        df.loc[buy_mask, column] = 1
        # Applied after the buy mask, so a sell wins when both fire
        # (possible for BB_Signal when the two bands coincide).
        df.loc[sell_mask, column] = -1

    df['Combined_Signal'] = df[list(rules)].sum(axis=1)
    return df
|
|
|
|
def plot_combined_signals(df, ticker):
    """
    Creates a focused plot of JUST the combined signal strength.

    Each row of ``df['Combined_Signal']`` becomes one bar at its index value:
    green when the score is >= 0, red otherwise. The figure title is the
    ticker and the legend is hidden.
    """
    strength = df['Combined_Signal']
    bar_colors = ['#2ECC71' if score >= 0 else '#E74C3C' for score in strength]

    bars = go.Bar(
        x=df.index,
        y=strength,
        name='Signal Strength',
        marker_color=bar_colors,
        hovertemplate='<b>Date</b>: %{x}<br><b>Signal</b>: %{y}<extra></extra>',
    )
    fig = go.Figure(data=[bars])

    fig.update_layout(
        title=f'{ticker}',
        template='plotly_dark',
        xaxis_title='Date',
        yaxis_title='Signal Strength Score',
        yaxis=dict(zeroline=True, zerolinewidth=2, zerolinecolor='gray'),
        showlegend=False,
    )
    return fig
|
|
def plot_individual_signals(df, ticker, x_range=None):
    """Plot closing prices with buy/sell markers for each indicator.

    For each of RSI_Signal, BB_Signal, Stochastic_Signal and CCI_Signal, rows
    equal to 1 get an up-triangle and rows equal to -1 a down-triangle, drawn
    at that row's 'Close'. Marker traces are added even when empty, so the
    legend always lists all eight.

    Args:
        df: Frame with 'Close' and the four '*_Signal' columns; its index is
            used as the x axis.
        ticker: Used as the figure title.
        x_range: Passed straight through as ``xaxis.range``.
    """
    buy_color = '#39FF14'
    sell_color = '#FF073A'
    grid_color = "rgba(255,255,255,0.1)"

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df.index,
        y=df['Close'],
        mode='lines',
        name='Closing Price',
        line=dict(color='#36A2EB', width=2),
    ))

    for column in ('RSI_Signal', 'BB_Signal', 'Stochastic_Signal', 'CCI_Signal'):
        flags = df[column]
        buys = df[flags == 1]
        sells = df[flags == -1]

        fig.add_trace(go.Scatter(
            x=buys.index,
            y=buys['Close'],
            mode='markers',
            marker=dict(symbol='triangle-up', size=12, color=buy_color),
            name=f'{column} Buy',
        ))
        fig.add_trace(go.Scatter(
            x=sells.index,
            y=sells['Close'],
            mode='markers',
            marker=dict(symbol='triangle-down', size=12, color=sell_color),
            name=f'{column} Sell',
        ))

    x_axis = dict(title='Date', showgrid=True, gridcolor=grid_color, range=x_range)
    y_axis = dict(title='Price', side='left', showgrid=True, gridcolor=grid_color)
    legend = dict(
        orientation='h',
        yanchor='bottom',
        y=1.02,
        xanchor='right',
        x=1,
        bgcolor="rgba(0,0,0,0.8)",
        bordercolor="white",
        borderwidth=1,
    )

    fig.update_layout(
        title=f'{ticker}',
        xaxis=x_axis,
        yaxis=y_axis,
        plot_bgcolor='#111111',
        paper_bgcolor='#111111',
        font=dict(color='white', size=12),
        legend=legend,
        margin=dict(l=40, r=40, t=80, b=40),
    )
    return fig
|
|
def technical_analysis(ticker, start_date, end_date):
    """Build the technical-analysis figures for the last 120 rows of data.

    Args:
        ticker: Symbol passed to ``StockDataFetcher.fetch_stock_data``.
        start_date: Start of the history window, forwarded to the fetcher.
        end_date: End of the history window, forwarded to the fetcher.

    Returns:
        ``(combined_signal_fig, individual_signals_fig)``. Signals are
        computed on the full history and only the most recent 120 rows are
        plotted, so indicator warm-up happens off-screen when enough data
        exists.

    Raises:
        gr.Error: If any step fails, so Gradio shows the message
            (``gr.Plot`` outputs cannot render plain error strings).
    """
    try:
        df = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date)
        recent = generate_trading_signals(df).tail(120)
        return plot_combined_signals(recent, ticker), plot_individual_signals(recent, ticker)
    except Exception as e:
        raise gr.Error(f"Technical analysis failed: {e}") from e
|
|
|
|
| |
# Extra stylesheet passed to gr.Blocks(css=...).
# NOTE(review): no component in this file sets elem_id="analyze-btn", so the
# button#analyze-btn rule presumably matches nothing until one does — confirm.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.container {
    max-width: 1200px;
    margin: auto;
}
button#analyze-btn {
    background-color: #003366;
    color: white;
    border: none;
}
"""
|
|
| |
# Gradio UI: shared ticker/date inputs plus one tab per analysis.
with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
    gr.Markdown("# Advanced Stock Analysis & Forecasting App")
    gr.Markdown("Enter a stock ticker, start date, and end date to analyze and forecast stock prices.")

    # Inputs shared by every tab's button.
    with gr.Row():
        ticker_input = gr.Textbox(label="Enter Stock Ticker", value="NVDA")
        start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2025-01-01")
        end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2027-01-01")

    with gr.Tabs():
        with gr.TabItem("TimesFM Forecast"):
            timesfm_button = gr.Button("Generate TimesFM Forecast")
            timesfm_plot = gr.Plot(label="TimesFM Stock Price Forecast")

            timesfm_button.click(
                timesfm_forecast,
                inputs=[ticker_input, start_date_input, end_date_input],
                outputs=timesfm_plot,
            )

        with gr.TabItem("Prophet Forecast"):
            prophet_button = gr.Button("Generate Prophet Forecast")
            prophet_recent_plot = gr.Plot(label="Recent Stock Prices")
            prophet_forecast_plot = gr.Plot(label="Prophet 30-Day Forecast")
            prophet_components = gr.Plot(label="Forecast Components")

        with gr.TabItem("Technical Analysis"):
            analysis_button = gr.Button("Generate Technical Analysis")
            individual_signals = gr.Plot(label="Individual Trading Signals")
            combined_signals = gr.Plot(label="Combined Trading Signals")

            # technical_analysis returns (combined, individual), in this order.
            analysis_button.click(
                technical_analysis,
                inputs=[ticker_input, start_date_input, end_date_input],
                outputs=[combined_signals, individual_signals],
            )

    prophet_button.click(
        prophet_forecast,
        inputs=[ticker_input, start_date_input, end_date_input],
        outputs=[prophet_recent_plot, prophet_forecast_plot, prophet_components],
    )


# Launch only when run as a script, so importing the module (tests, the
# `gradio` reload CLI, other apps) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()