Bring Drag & Drop to Dash with Dashboard Engine. 💫 Learn how at our next webinar!

Can't create subplots the right way

I currently have a chart where I display the price of an asset and add some moving-average traces. Now I want to add other indicators to the chart (OBV, MACD, RSI). Every time the user selects one of these three, I would like my current chart to be split into subplots. Let's say I add OBV: my display will have 2 subplots (the price on row 1 and OBV on row 2). Now if I want to add MACD, it should display the price on row 1, OBV on row 2 and MACD on row 3. Now if the user decides to remove OBV, it should display the price on row 1 and MACD on row 2.

Currently I am only able to display each indicator in a specific, fixed row. How can I manage that?
Here is my code for the moment :

def display(value, timeframes, indicators, start_date, end_date):
    """Render the price chart of one trading pair with the selected indicators.

    The figure always has five fixed rows: candlesticks plus moving averages
    (row 1), volume (row 2), OBV (row 3), RSI (row 4) and MACD (row 5). A row
    simply stays empty when its indicator is not selected.

    Args:
        value: Symbol used for the page heading and the CSV file name.
        timeframes: Timeframe label used for the CSV directory and file name.
        indicators: Selected indicator names such as 'SMA20', 'EMA50', 'OBV',
            'RSI' or 'MACD'. For moving averages the text after the
            three-letter prefix is parsed as the period.
        start_date: Lower bound of the slice taken from the loaded data.
        end_date: Upper bound of the slice; must be after start_date.

    Returns:
        The value returned by st.plotly_chart for the rendered figure.
    """
    st.write(f"""
# {value}
Price chart""")
    df = pd.read_csv(
        f'data/{timeframes}/{value}-{timeframes}-data.csv',
        parse_dates=['timestamp'],
        index_col='timestamp',
    )

    if start_date < end_date:
        df = df[start_date:end_date]
    else:
        # The chart is still drawn below, using the unfiltered data.
        st.error('Error: End date must fall after start date.')

    fig = make_subplots(
        rows=5,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.01,
        row_heights=[0.5, 0.1, 0.2, 0.2, 0.2],
    )

    fig.add_trace(
        go.Candlestick(
            x=df.index,
            open=df['open'],
            high=df['high'],
            low=df['low'],
            close=df['close'],
            name='Candlestick',
        ),
        row=1, col=1,
    )

    # Volume bars: green when open >= close, red otherwise.
    volume_colors = np.where(df['open'] - df['close'] >= 0, 'green', 'red')
    fig.add_trace(
        go.Bar(x=df.index, y=df['volume'], marker_color=volume_colors),
        row=2, col=1,
    )

    for indi in indicators:
        # Row 1: moving averages are drawn on top of the candlesticks.
        if indi.startswith('SMA'):
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=indicator.simple_moving_average(df, int(indi[3:])),
                    line=dict(color=np.random.choice(colors_sma, replace=False), width=1),
                    name=f'SMA{indi[3:]}',
                ),
                row=1, col=1,
            )

        elif indi.startswith('WMA'):
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=indicator.weighted_moving_average(df, int(indi[3:])),
                    line=dict(color=np.random.choice(colors_wma, replace=False), width=1),
                    name=f'WMA{indi[3:]}',
                ),
                row=1, col=1,
            )

        elif indi.startswith('EMA'):
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=indicator.exponential_moving_average(df, int(indi[3:])),
                    line=dict(color=np.random.choice(colors_ema, replace=False), width=1),
                    name=f'EMA{indi[3:]}',
                ),
                row=1, col=1,
            )

        # Row 3: on-balance volume.
        elif indi.startswith('OBV'):
            # A plain 'OBV' selection has no numeric suffix; int('') would
            # raise ValueError, so the period is only passed when present.
            # NOTE(review): assumes on_balance_volume(df) works without a
            # period argument - confirm against indi.py.
            suffix = indi[3:]
            if suffix:
                obv_values = indicator.on_balance_volume(df, int(suffix))
            else:
                obv_values = indicator.on_balance_volume(df)
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=obv_values,
                    line=dict(color='red', width=1),
                    name=f'OBV{suffix}',
                ),
                row=3, col=1,
            )

        # Row 4: RSI with its two reference levels.
        elif indi.startswith('RSI'):
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=indicator.relative_strength_index(df, periods=14),
                    name='RSI', marker_color='blue',
                ),
                row=4, col=1,
            )
            # Dotted horizontal lines at the 70 and 30 levels.
            for level, label in ((70, 'Overbought'), (30, 'Oversold')):
                fig.add_trace(
                    go.Scatter(
                        x=df.index, y=[level] * len(df.index),
                        name=label, marker_color='#109618',
                        line=dict(dash='dot'), showlegend=False,
                    ),
                    row=4, col=1,
                )

        # Row 5: MACD lines and histogram.
        elif indi.startswith('MACD'):
            macd, macd_s, macd_h = indicator.moving_average_convergence_divergence(df)
            # Fast signal
            fig.add_trace(
                go.Scatter(
                    x=df.index, y=macd,
                    line=dict(color='blue', width=1),
                    name='macd',
                ),
                row=5, col=1,
            )
            # Slow signal
            fig.add_trace(
                go.Scatter(
                    x=df.index, y=macd_s,
                    line=dict(color='#000000', width=1),
                    name='macd',
                ),
                row=5, col=1,
            )
            # Histogram bars are red below zero and green otherwise.
            histogram_colors = np.where(macd_h < 0, 'red', 'green')
            fig.add_trace(
                go.Bar(
                    x=df.index, y=macd_h,
                    name='histogram', marker_color=histogram_colors,
                ),
                row=5, col=1,
            )

    fig.update_layout(width=800, height=800, xaxis_rangeslider_visible=False, showlegend=False)

    fig.update_xaxes(gridcolor='#7f7f7f')
    # Set y-axes titles
    fig.update_yaxes(title_text="Price", row=1, col=1)
    fig.update_yaxes(title_text="Volume", row=2, col=1)
    fig.update_yaxes(title_text="OBV", row=3, col=1)
    fig.update_yaxes(title_text="RSI", row=4, col=1)
    fig.update_yaxes(title_text="MACD", row=5, col=1)
    fig.update_yaxes(gridcolor='#7f7f7f')

    return st.plotly_chart(fig, width=800, height=800)

Hi there!

How are the indicators selected? Is it a selection coming from Streamlit?

I imagine the easiest way is to group your indicators by “type” (an ordered dictionary with lists can do the trick), then loop over the groups instead and add a subplot counter for each group if the list is not empty. Something like:

indicator_types = ["OBV", "MACD", "RSI"]

# Map each indicator type to the selected indicators that start with that
# prefix. The order of the groups matters (the original note asks for an
# ordered dict); a plain dict keeps insertion order on Python 3.7+.
indicator_groups = {
    prefix: [name for name in indicators if name.startswith(prefix)]
    for prefix in indicator_types
}

You can also benefit from using a function to create the traces by passing the indicator function and column as parameter:

def create_indicator_trace(df, fct, x):
    """Build a Scatter trace for one indicator.

    `fct` is called as `fct(df, x)` and its result is plotted against
    `df.index`.
    """
    trace = go.Scatter(
        x=df.index,
        y=fct(df, x),
        # ... and so on (remaining trace options go here)
    )
    return trace

Hey thank you for your reply !

Yes, I am using Streamlit :slight_smile:. Sorry, but I can't really understand how using a dict with lists can help me… I am kind of new to Python.

‘’'python
import streamlit as st
import pandas as pd
import pandas_ta as ta
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import plotly.express as px
import datetime
from indi import indicator

# Trading pairs offered in the sidebar selector.
symbols = [
    'BTCUSDT', 'ETHUSDT', 'ADAUSDT', 'BNBUSDT', 'DOTUSDT', 'XRPUSDT', 'UNIUSDT', 'LTCUSDT', 'LINKUSDT',
    'BCHUSDT', 'XLMUSDT', 'LUNAUSDT', 'DOGEUSDT', 'VETUSDT', 'ATOMUSDT', 'AAVEUSDT', 'FILUSDT', 'AVAXUSDT',
    'TRXUSDT', 'EOSUSDT', 'SOLUSDT', 'IOTAUSDT', 'XTZUSDT', 'NEOUSDT', 'CHZUSDT', 'DAIUSDT', 'SNXUSDT',
    'SUSHIUSDT', 'EGLDUSDT', 'ENJUSDT', 'ZILUSDT', 'MATICUSDT', 'MKRUSDT', 'COMPUSDT', 'BATUSDT', 'ZRXUSDT',
    'RSRUSDT',
]

# Timeframes offered in the sidebar selector.
timef = ['12h', '1d', '30m', '1h', '4h']

# CREATE indicators

def concatenate_2_lists_diff_sizes(list1, list2):
    """Return every string of list1 suffixed with every item of list2.

    The result is ordered by list1 first, then list2, e.g.
    (['SMA', 'EMA'], [20, 30]) -> ['SMA20', 'SMA30', 'EMA20', 'EMA30'].
    Items of list2 are converted with str(); the lists may differ in length.
    """
    return [x + str(y) for x in list1 for y in list2]

# Moving-average indicators are offered once per period, e.g. 'SMA20'.
indi_ma = ['SMA', 'EMA', 'WMA']
range_indi_ma = [20, 30, 40, 50, 100, 150, 200]
indicators_ma = concatenate_2_lists_diff_sizes(indi_ma, range_indi_ma)
# Indicators without a period suffix.
others_indi = ['RSI', 'MACD', 'OBV']
final_lists_of_indicators = indicators_ma + others_indi

## RANDOM COLORS FOR EACH MA INDICATORS
colors_sma = [
    '#1f77b4',  # muted blue
    '#ff7f0e',  # safety orange
    '#2ca02c',  # cooked asparagus green
]
colors_wma = [
    '#d62728',  # brick red
    '#9467bd',  # muted purple
    '#8c564b',  # chestnut brown
]
colors_ema = [
    '#e377c2',  # raspberry yogurt pink
    '#7f7f7f',  # middle gray
    '#bcbd22',  # curry yellow-green
]

# STREAMLIT INPUT AND SIDEBAR

# Sidebar widgets: trading pair, date range, timeframe and indicator selection.
symbol = st.sidebar.selectbox('Search Pairs', symbols)
start_date = st.sidebar.date_input('Start date', datetime.date(2020, 1, 1))
end_date = st.sidebar.date_input('End date', datetime.date.today())
timeframes = st.sidebar.selectbox('Timeframe', timef)
indicators = st.sidebar.multiselect('Indicators', final_lists_of_indicators)

#Display charts with indicators from indi.py

def display(symbol, timeframes, indicators, start_date, end_date):
    """Render the price chart of one trading pair with the selected indicators.

    The figure always has five fixed rows: candlesticks plus moving averages
    (row 1), volume (row 2), OBV (row 3), RSI (row 4) and MACD (row 5). A row
    simply stays empty when its indicator is not selected.

    Args:
        symbol: Symbol used for the page heading and the CSV file name.
        timeframes: Timeframe label used for the CSV directory and file name.
        indicators: Selected indicator names such as 'SMA20', 'EMA50', 'OBV',
            'RSI' or 'MACD'. For moving averages the text after the
            three-letter prefix is parsed as the period.
        start_date: Lower bound of the slice taken from the loaded data.
        end_date: Upper bound of the slice; must be after start_date.

    Returns:
        The value returned by st.plotly_chart for the rendered figure.
    """
    st.write(f"""
# {symbol}
Price chart""")
    df = pd.read_csv(
        f'data/{timeframes}/{symbol}-{timeframes}-data.csv',
        parse_dates=['timestamp'],
        index_col='timestamp',
    )

    if start_date < end_date:
        df = df[start_date:end_date]
    else:
        # The chart is still drawn below, using the unfiltered data.
        st.error('Error: End date must fall after start date.')

    fig = make_subplots(
        rows=5,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.01,
        row_heights=[0.5, 0.1, 0.2, 0.2, 0.2],
    )

    fig.add_trace(
        go.Candlestick(
            x=df.index,
            open=df['open'],
            high=df['high'],
            low=df['low'],
            close=df['close'],
            name='Candlestick',
        ),
        row=1, col=1,
    )

    # Volume bars: green when open >= close, red otherwise.
    volume_colors = np.where(df['open'] - df['close'] >= 0, 'green', 'red')
    fig.add_trace(
        go.Bar(x=df.index, y=df['volume'], marker_color=volume_colors),
        row=2, col=1,
    )

    for indi in indicators:
        # Row 1: moving averages are drawn on top of the candlesticks.
        if indi.startswith('SMA'):
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=indicator.simple_moving_average(df, int(indi[3:])),
                    line=dict(color=np.random.choice(colors_sma, replace=False), width=1),
                    name=f'SMA{indi[3:]}',
                ),
                row=1, col=1,
            )

        elif indi.startswith('WMA'):
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=indicator.weighted_moving_average(df, int(indi[3:])),
                    line=dict(color=np.random.choice(colors_wma, replace=False), width=1),
                    name=f'WMA{indi[3:]}',
                ),
                row=1, col=1,
            )

        elif indi.startswith('EMA'):
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=indicator.exponential_moving_average(df, int(indi[3:])),
                    line=dict(color=np.random.choice(colors_ema, replace=False), width=1),
                    name=f'EMA{indi[3:]}',
                ),
                row=1, col=1,
            )

        # Row 3: on-balance volume.
        elif indi.startswith('OBV'):
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=indicator.on_balance_volume(df),
                    line=dict(color='red', width=1),
                    name='OBV',
                ),
                row=3, col=1,
            )

        # Row 4: RSI with its two reference levels.
        elif indi.startswith('RSI'):
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=indicator.relative_strength_index(df, periods=14),
                    name='RSI', marker_color='blue',
                ),
                row=4, col=1,
            )
            # Dotted horizontal lines at the 70 and 30 levels.
            for level, label in ((70, 'Overbought'), (30, 'Oversold')):
                fig.add_trace(
                    go.Scatter(
                        x=df.index, y=[level] * len(df.index),
                        name=label, marker_color='#109618',
                        line=dict(dash='dot'), showlegend=False,
                    ),
                    row=4, col=1,
                )

        # Row 5: MACD lines and histogram.
        elif indi.startswith('MACD'):
            macd, macd_s, macd_h = indicator.moving_average_convergence_divergence(df)
            # Fast signal
            fig.add_trace(
                go.Scatter(
                    x=df.index, y=macd,
                    line=dict(color='blue', width=1),
                    name='macd',
                ),
                row=5, col=1,
            )
            # Slow signal
            fig.add_trace(
                go.Scatter(
                    x=df.index, y=macd_s,
                    line=dict(color='#000000', width=1),
                    name='macd',
                ),
                row=5, col=1,
            )
            # Histogram bars are red below zero and green otherwise.
            histogram_colors = np.where(macd_h < 0, 'red', 'green')
            fig.add_trace(
                go.Bar(
                    x=df.index, y=macd_h,
                    name='histogram', marker_color=histogram_colors,
                ),
                row=5, col=1,
            )

    fig.update_layout(width=800, height=800, xaxis_rangeslider_visible=False, showlegend=False)

    fig.update_xaxes(gridcolor='#7f7f7f')
    # Set y-axes titles
    fig.update_yaxes(title_text="Price", row=1, col=1)
    fig.update_yaxes(title_text="Volume", row=2, col=1)
    fig.update_yaxes(title_text="OBV", row=3, col=1)
    fig.update_yaxes(title_text="RSI", row=4, col=1)
    fig.update_yaxes(title_text="MACD", row=5, col=1)
    fig.update_yaxes(gridcolor='#7f7f7f')

    return st.plotly_chart(fig, width=800, height=800)

# Draw the chart for the values currently chosen in the sidebar.
display(symbol, timeframes, indicators, start_date, end_date)

‘’’

The code looks like this; the indicator functions come from another file. Can you explain to me how I should restructure my code to be able to plot in the desired way?

Let me rephrase a little bit to explain the approach better…

The problem in your code is that you are hardcoding the number of rows in your figure and hardcoding the rows in which each one of your indicators will be plotted. What you want is to create a dynamic number of rows depending on how many of your “indicator groups” have a non-empty selection from Streamlit and you want to adjust the row in which each group is plotted to skip empty selections.

Can you explain to me how I should restructure my code to be able to plot in the desired way?

It is probably easier for me to give you a better example:

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import streamlit as st

# Iris sample dataset from plotly.express, used as stand-in data.
df = px.data.iris()

# Column-name prefixes; each prefix with a selection gets its own subplot row.
indicator_types = ["petal", "sepal"]

# Numeric cols only
valid_indicators = [col for col in df.columns if col not in ["species", "species_id"]]

indicators = st.sidebar.multiselect('Indicators', valid_indicators, default=valid_indicators)

indicator_groups = {k: [i for i in indicators if i.startswith(k)] for k in indicator_types}

# Keep only the groups that have at least one selected indicator.
# Their count is the number of rows in the subplots...
active_groups = [group for group in indicator_groups.values() if group]
n_rows = len(active_groups)

if n_rows == 0:
    # make_subplots needs at least one row, so there is nothing to draw
    # when every indicator has been deselected.
    st.warning("Select at least one indicator to display the chart.")
else:
    fig = make_subplots(rows=n_rows, cols=1)

    # Empty groups were filtered out above, so enumerating the remaining
    # ones yields consecutive row numbers with no gaps.
    for row_idx, group in enumerate(active_groups, start=1):

        # add one trace for each element in the group
        for col in group:
            fig.add_trace(
                go.Scatter(
                    x=df.index,
                    y=df[col],
                    name=col,
                ),
                row=row_idx,
                col=1,
            )

    st.plotly_chart(fig, width=800, height=400 * n_rows)

Give this app a try and remove all items starting with “petal” in the multiselect component… You will see that the figure will have a single plot (with “sepal*” columns) positioned in the right place. This is basically what you want to do in terms of subplot.

Please let me know if this helps!

Thank you for all these explanations, my code works fine now !

1 Like