Dash Chat Component

Thanks a lot for the component, it looks very good!

I found a bug where the component stops working when I send the same message twice in a row:

Steps to reproduce:

  1. Start Dash app
  2. Send “h”
  3. Send “h” again
  4. Result: the chat shows its loading indicator indefinitely and the new_message callback is never triggered
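
For reference, here is a minimal app that reproduces it for me. This is just a sketch (the callback simply echoes the user's message back), but the hang happens the same way:

import dash
from dash import Input, Output, State, callback
from dash_chat import ChatComponent

app = dash.Dash(__name__)
app.layout = ChatComponent(id="chat-component", messages=[])

@callback(
    Output("chat-component", "messages"),
    Input("chat-component", "new_message"),
    State("chat-component", "messages"),
    prevent_initial_call=True,
)
def echo(new_message, messages):
    # Echo the user's message back as the assistant reply
    if not new_message or new_message["role"] != "user":
        return messages
    reply = {"role": "assistant", "content": f"You said: {new_message['content']}"}
    return messages + [new_message, reply]

if __name__ == "__main__":
    app.run(debug=True)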

Hi everyone!

I have introduced Dash Chat in my Dash app, but I've run into a limitation.

Is it possible to show artifacts (like tables or images) in the chat, or is only text supported?
I mean, I have cases where my bot assistant replies with both text and a table, for example, and the component does not seem to handle this use case (the attached screenshot shows this).

Has anybody solved a similar issue, or does anyone know a workaround for it?

Thanks a lot guys 🙂

Yes, I found the same problem but was not able to solve it.

Thanks for catching this. I'll try to reproduce it. I have also created an issue on GitHub for easy tracking.

Thanks for the observation. Just to clarify, are you looking for a way to style the table? As long as the data is in Markdown table format, it will render as a table. If you want to style it, you can target the .markdown-content class, for example .markdown-content table {...}. I plan to make an update later this month to add configurable styling for some of these components.
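
For example, a quick sketch you could drop into a CSS file in your assets/ folder (adjust the values to taste):

/* assets/chat.css */
.markdown-content table {
    border-collapse: collapse;
    width: 100%;
}
.markdown-content th,
.markdown-content td {
    border: 1px solid #dee2e6;
    padding: 0.5rem;
    text-align: left;
}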

I didn't spend much time on this, but I was able to bring Gemini into the project and use it to generate Plotly graphs:

import os
import dash
from dash import html, dcc, callback, Input, Output, State, _dash_renderer
import dash_mantine_components as dmc
from dash_chat import ChatComponent
from google import genai
from dotenv import load_dotenv
import re
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from dash.exceptions import PreventUpdate

_dash_renderer._set_react_version("18.2.0")

# Load environment variables
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Initialize Gemini client
genai_client = genai.Client(api_key=GEMINI_API_KEY)


# Helper function to create Plotly visualizations based on text prompts
def create_plotly_visualization(prompt):
    """
    Create a Plotly visualization based on the text prompt.

    Args:
        prompt (str): The text prompt describing the visualization

    Returns:
        dict: A dictionary with the figure data and layout for Plotly
    """
    try:
        # Remove the #Plotly tag
        cleaned_prompt = re.sub(r'#Plotly\s*', '', prompt, flags=re.IGNORECASE)

        # Generate some sample data
        np.random.seed(42)  # For reproducibility

        # Default data
        x = np.linspace(0, 10, 100)
        y = np.sin(x)
        categories = ['A', 'B', 'C', 'D', 'E', 'F']  # Sample categories shared by several chart types
        values = np.random.randint(10, 100, size=len(categories))

        # Create a pandas DataFrame with sample data
        df = pd.DataFrame({
            'x': np.random.normal(0, 1, 100),
            'y': np.random.normal(0, 1, 100),
            'size': np.random.uniform(5, 15, 100),
            'category': np.random.choice(['A', 'B', 'C', 'D', 'E', 'F'], 100),
            'values': np.random.randint(0, 100, 100),
            'date': pd.date_range(start='2023-01-01', periods=100),
        })

        # Attempt to parse what type of chart is needed based on keywords
        prompt_lower = cleaned_prompt.lower()

        # Initialize figure with a default plot
        fig = None

        # Check for various chart types in the prompt
        if 'scatter' in prompt_lower or 'point' in prompt_lower:
            fig = px.scatter(df, x='x', y='y', color='category', size='size',
                             title="Scatter Plot", hover_data=['values'])

        elif 'line' in prompt_lower:
            fig = px.line(df, x='date', y='values', color='category',
                          title="Line Chart")

        elif 'bar' in prompt_lower:
            fig = px.bar(df, x='category', y='values', color='category',
                         title="Bar Chart")

        elif 'histogram' in prompt_lower:
            fig = px.histogram(df, x='x', nbins=20,
                               title="Histogram")

        elif 'box' in prompt_lower or 'boxplot' in prompt_lower:
            fig = px.box(df, x='category', y='values',
                         title="Box Plot")

        elif 'violin' in prompt_lower:
            fig = px.violin(df, x='category', y='values',
                            title="Violin Plot")

        elif 'pie' in prompt_lower:
            fig_data = pd.DataFrame({
                'category': categories,
                'values': values
            })
            fig = px.pie(fig_data, names='category', values='values',
                         title="Pie Chart")

        elif 'heatmap' in prompt_lower:
            # Create a sample correlation matrix
            corr_matrix = np.corrcoef(np.random.normal(0, 1, (5, 100)))
            fig = px.imshow(corr_matrix,
                            x=categories[:5], y=categories[:5],
                            title="Heatmap")

        elif '3d' in prompt_lower or 'surface' in prompt_lower:
            # Create 3D surface data
            z_data = [[np.sin(np.sqrt(x ** 2 + y ** 2)) for x in range(-5, 5)] for y in range(-5, 5)]
            fig = go.Figure(data=[go.Surface(z=z_data)])
            fig.update_layout(title="3D Surface Plot", autosize=True)

        else:
            # Default to a basic line chart if no specific chart type is detected
            fig = px.line(x=x, y=y, title="Sample Line Chart")

        # Update layout for better appearance
        fig.update_layout(
            template="plotly_white",
            margin=dict(l=20, r=20, t=50, b=20),
        )

        # Return the figure itself instead of a component
        return fig

    except Exception as e:
        print(f"Error creating visualization: {e}")
        # Return an error message
        return None


# Function to generate chat responses
def generate_chat_response(messages, model_name):
    """
    Generate a response using Google's Gemini API based on the conversation history.

    Args:
        messages (list): List of message dictionaries containing 'role' and 'content'
        model_name (str): The Gemini model to use

    Returns:
        str or dict: Either the generated response text or a dictionary with text and visualization
    """
    try:
        # Check if the latest user message contains #Plotly
        latest_message = messages[-1]

        if latest_message["role"] == "user" and "#Plotly" in latest_message["content"]:
            # Generate a visualization based on the prompt
            visualization = create_plotly_visualization(latest_message["content"])

            # Generate a text explanation of what was visualized
            visualization_prompt = f"""
            The user wants to visualize: "{latest_message["content"].replace('#Plotly', '')}". 
            I've created a visualization based on their request.
            Provide a brief explanation of what the visualization shows and how to interpret it.
            Keep your response focused on the visualization and be concise.
            """

            # Generate the text response using Gemini
            response = genai_client.models.generate_content(
                model=model_name,
                contents=visualization_prompt
            )

            # Return both the text and visualization
            return {
                "text": response.text.strip(),
                "visualization": visualization
            }

        # If not a visualization request, proceed with normal response generation
        # Format the conversation history for Gemini
        conversation = []

        for msg in messages:
            if msg["role"] == "user":
                conversation.append({
                    "role": "user",
                    "parts": [{"text": msg["content"]}]
                })
            else:  # assistant messages
                conversation.append({
                    "role": "model",
                    "parts": [{"text": msg["content"]}]
                })

        # If there's no conversation history (first message), use a simpler approach
        if len(messages) == 1:
            response = genai_client.models.generate_content(
                model=model_name,
                contents=messages[0]["content"]
            )
        else:
            # Use the conversation history for context
            response = genai_client.models.generate_content(
                model=model_name,
                contents=conversation
            )

        # Clean up the response - remove any special formatting if needed
        ai_response = response.text.strip()
        return ai_response

    except Exception as e:
        print(f"Error generating response: {e}")
        return f"Error generating response: {str(e)}"


# Available Gemini models
MODELS = [
    {"value": "gemini-2.0-flash", "label": "Gemini 2.0 Flash - Next gen features, speed, and multimodal"},
    {"value": "gemini-2.0-flash-lite", "label": "Gemini 2.0 Flash-Lite - Cost efficiency and low latency"},
    {"value": "gemini-1.5-flash", "label": "Gemini 1.5 Flash - Fast and versatile performance"},
    {"value": "gemini-1.5-flash-8b", "label": "Gemini 1.5 Flash-8B - High volume, lower intelligence tasks"},
    {"value": "gemini-1.5-pro", "label": "Gemini 1.5 Pro - Complex reasoning tasks"}
]

# Initialize Dash app
app = dash.Dash(
    __name__,
    external_stylesheets=[
        "https://fonts.googleapis.com/css2?family=Inter:wght@100;200;300;400;500;600;700;800;900&display=swap"
    ],
    meta_tags=[{"name": "viewport", "content": "width=device-width, initial-scale=1"}],
)

app.title = "Gemini Chat"
server = app.server

# Define app layout
app.layout = dmc.MantineProvider(
    theme={
        "colorScheme": "light",
        "fontFamily": "'Inter', sans-serif",
        "primaryColor": "indigo",
    },
    children=[
        dmc.Container(
            fluid=True,
            px=0,
            style={"height": "100vh", "display": "flex", "flexDirection": "column"},
            children=[
                # Header
                dmc.Paper(
                    h=70,
                    p="md",
                    style={"borderBottom": "1px solid #e9ecef"},
                    children=[
                        dmc.Group(
                            align="apart",
                            children=[
                                dmc.Title("Gemini Chat", order=3, c="indigo"),
                                dmc.Select(
                                    id="model-select",
                                    data=MODELS,
                                    value="gemini-1.5-flash",
                                    placeholder="Select a model",
                                    style={"width": 300},
                                    searchable=True,
                                ),
                            ],
                        ),
                    ],
                ),
                # Chat Component
                dmc.Box(
                    style={
                        "flexGrow": 1,
                        "display": "flex",
                        "flexDirection": "column",
                        "overflow": "auto",
                    },
                    children=[
                        dmc.Loader(
                            id="loading",
                            size="xl",
                            color="indigo",
                            style={
                                "position": "absolute",
                                "top": "50%",
                                "left": "50%",
                                "transform": "translate(-50%, -50%)",
                                "zIndex": 1000,
                                "display": "none",
                            },
                        ),
                        dmc.Box(
                            id="plotly-visualizations-container",
                            style={"width": "100%", "padding": "1rem", "backgroundColor": "#ffffff"},
                            children=[]
                        ),
                        # Chat component first
                        ChatComponent(
                            id="chat-component",
                            messages=[],
                            persistence=True,
                            persistence_type="local",
                            input_placeholder="Ask Gemini a question... (Use #Plotly to create visualizations)",
                            theme="light",
                            fill_height=True,
                            user_bubble_style={
                                "backgroundColor": "#6741d9",
                                "color": "white",
                                "marginLeft": "auto",
                                "textAlign": "right",
                                "borderRadius": "1rem 0 1rem 1rem",
                                "padding": "0.75rem 1rem",
                                "maxWidth": "80%",
                            },
                            assistant_bubble_style={
                                "backgroundColor": "#f1f3f5",
                                "color": "#1a1b1e",
                                "marginRight": "auto",
                                "textAlign": "left",
                                "borderRadius": "0 1rem 1rem 1rem",
                                "padding": "0.75rem 1rem",
                                "maxWidth": "80%",
                            },
                            container_style={
                                "padding": "1rem",
                                "backgroundColor": "#ffffff",
                            },
                        ),
                    ],
                ),
                # Footer
                dmc.Paper(
                    h=40,
                    p="xs",
                    style={"borderTop": "1px solid #e9ecef", "textAlign": "center"},
                    children=[
                        dmc.Text("Powered by Google Gemini API", size="sm", c="dimmed"),
                    ],
                ),
                dcc.Store(id="visualization-data", data=None),
                # Store for loading state
                dcc.Store(id="loading-state", data=False),
            ],
        ),
    ],
)


# Callback to update loading indicator
@callback(
    Output("loading", "style"),
    Input("loading-state", "data"),
    prevent_initial_call=True,
)
def update_loading(loading):
    if loading:
        return {
            "position": "absolute",
            "top": "50%",
            "left": "50%",
            "transform": "translate(-50%, -50%)",
            "zIndex": 1000,
            "display": "block",
        }
    else:
        return {
            "position": "absolute",
            "top": "50%",
            "left": "50%",
            "transform": "translate(-50%, -50%)",
            "zIndex": 1000,
            "display": "none",
        }


# Callback to handle chat messages
@callback(
    [
        Output("chat-component", "messages"),
        Output("loading-state", "data"),
        Output("visualization-data", "data"),
    ],
    Input("chat-component", "new_message"),
    [
        State("chat-component", "messages"),
        State("model-select", "value"),
        State("visualization-data", "data"),
    ],
    prevent_initial_call=True,
)
def handle_chat(new_message, messages, model, viz_data):
    if not new_message:
        return messages, False, viz_data

    updated_messages = messages + [new_message]

    if new_message["role"] == "user":
        # Set loading state to True
        loading = True

        try:
            # Check if this is a Plotly visualization request
            if "#plotly" in new_message["content"].lower():
                print("Plotly visualization requested")

                # Generate the visualization
                fig = create_plotly_visualization(new_message["content"])

                if fig is None:
                    # If visualization creation failed, return an error message
                    bot_response = {"role": "assistant",
                                    "content": "I couldn't create the visualization you requested. Please try again with different parameters."}
                    return updated_messages + [bot_response], False, viz_data

                # Generate a text explanation using Gemini
                visualization_prompt = f"""
                The user wants to visualize: "{new_message["content"].replace('#plotly', '')}". 
                Create a brief explanation of what the visualization shows and how to interpret it.
                Keep your response focused on the visualization and be concise.
                """

                # Generate the text response using Gemini
                text_response = generate_chat_response([{"role": "user", "content": visualization_prompt}], model)

                # Create a unique ID for this visualization
                viz_id = f"viz-{len(messages)}"

                # Update the visualization data store
                if viz_data is None:
                    viz_data = {}

                # Store the figure data in the viz_data
                viz_data[viz_id] = {
                    "type": "plotly",
                    "figure": fig.to_dict()  # Convert the figure to a dict to make it serializable
                }

                print(f"Stored visualization with ID: {viz_id}")

                # Create a bot response with reference to the visualization
                bot_response = {
                    "role": "assistant",
                    "content": f"{text_response}\n\n*Visualization has been created based on your request.*"
                }

                return updated_messages + [bot_response], False, viz_data

            else:
                # For non-visualization requests, use the standard response generation
                response = generate_chat_response(messages + [new_message], model)
                bot_response = {"role": "assistant", "content": response}
                return updated_messages + [bot_response], False, viz_data

        except Exception as e:
            # Handle errors
            error_message = f"Error: {str(e)}"
            print(f"Error in handle_chat: {e}")
            bot_response = {"role": "assistant", "content": error_message}
            return updated_messages + [bot_response], False, viz_data

    return updated_messages, False, viz_data


@callback(
    Output("plotly-visualizations-container", "children"),
    Input("visualization-data", "data"),
    prevent_initial_call=True,
)
def update_visualizations(viz_data):
    print(f"update_visualizations called with data: {viz_data}")

    if viz_data is None or not viz_data:
        print("No visualization data, preventing update")
        raise PreventUpdate

    # Create visualization components from the data
    viz_components = []

    # Sort by visualization ID to maintain order
    for viz_id in sorted(viz_data.keys()):
        viz_info = viz_data[viz_id]
        print(f"Processing visualization {viz_id}: {viz_info['type']}")

        if viz_info["type"] == "plotly":
            try:
                # Create a new figure from the stored dictionary
                fig_dict = viz_info["figure"]

                # Create a Graph component with the figure
                graph = dcc.Graph(
                    id=f'graph-{viz_id}',
                    figure=fig_dict,
                    style={'height': '500px', 'width': '100%'},
                    config={'displayModeBar': True, 'responsive': True}
                )

                # Create a container for the visualization
                viz_container = html.Div(
                    children=[
                        html.Hr(),
                        html.H4(f"Visualization {viz_id.split('-')[1]}", style={'textAlign': 'center'}),
                        graph
                    ],
                    style={
                        'margin': '20px 0',
                        'padding': '15px',
                        'border': '1px solid #e0e0e0',
                        'borderRadius': '8px',
                        'backgroundColor': '#f9f9f9'
                    },
                    id=f"viz-container-{viz_id}"
                )
                viz_components.append(viz_container)
                print(f"Successfully rendered visualization {viz_id}")
            except Exception as e:
                print(f"Error rendering visualization {viz_id}: {e}")
                # Create an error display if rendering fails
                viz_components.append(html.Div(
                    [
                        html.Hr(),
                        html.H4(f"Visualization {viz_id.split('-')[1]} (Error)"),
                        html.P(f"Error rendering visualization: {str(e)}"),
                    ],
                    style={
                        'margin': '20px 0',
                        'padding': '15px',
                        'border': '1px solid #e0e0e0',
                        'borderRadius': '8px',
                        'backgroundColor': '#ffeeee'  # Light red for errors
                    }
                ))

    print(f"Created {len(viz_components)} visualization components")
    return viz_components


if __name__ == "__main__":
    app.run(debug=True, port=3350)

I wasn't able to get the graph to render within Output("chat-component", "messages"), but by setting up a separate Output("visualization-data", "data") I was able to get it to work. @gbolly, check out my handle_chat callback. I think adding another prop to ChatComponent, something like render or content, could be useful, so that when you ask a question the chat updates with the message followed by the render/content produced by the callback.

Hi @PipInstallPython,

I really like your solution, and I think I’ll use it as inspiration for a quick implementation in my project.

That said, what I truly need is what you mentioned at the end: the ability to handle artifacts (such as tables or graphs) directly within ChatComponent.

It would be particularly useful if the handle_chat callback could return something like:

bot_response = {
    "role": "assistant",
    "content": "Hello John Doe.",
    "artifact": dash_table.DataTable(...)
}
return updated_messages + [bot_response]

This way, we could work with native Dash components (tables, graphs, etc.) instead of relying on Markdown formatting.
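
In the meantime, a workaround based on the Markdown support mentioned above could be to serialize the table to Markdown before putting it in the message content. A sketch (df is a made-up pandas DataFrame here, and to_markdown requires the tabulate package):

import pandas as pd

df = pd.DataFrame({"name": ["Alice", "Bob"], "score": [91, 84]})

bot_response = {
    "role": "assistant",
    # The chat renders Markdown tables, so this shows up as a real
    # table in the bubble; to_markdown() needs tabulate installed
    "content": "Here are the results:\n\n" + df.to_markdown(index=False),
}
return updated_messages + [bot_response]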

@gbolly, do you think something like this is currently possible, or is it planned for future updates?

It would also be great if it were possible to send multiple lines of text in the input field, just like with other LLM chat UIs. That would probably require using a textarea element instead of a plain input.
Thanks again for the great work!