I didn't spend much time on this, but I was able to bring Gemini into the project and use it to generate Plotly graphs:
import os
import dash
from dash import html, dcc, callback, Input, Output, State, _dash_renderer
import dash_mantine_components as dmc
from dash_chat import ChatComponent
from google import genai
from dotenv import load_dotenv
import json
import re
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import io
import base64
from dash.exceptions import PreventUpdate
# Pin the React version used by the Dash renderer (a private Dash hook).
# NOTE(review): presumably required by dash-mantine-components, which targets React 18 — confirm for the installed Dash version.
_dash_renderer._set_react_version("18.2.0")
# Pull variables from a local .env file into the process environment.
load_dotenv()
# None when the variable is unset; genai.Client then decides how to handle a missing key.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
# One shared Gemini client used by every request handler below.
genai_client = genai.Client(api_key=GEMINI_API_KEY)
# Helper function to create Plotly visualizations based on text prompts
def create_plotly_visualization(prompt):
    """
    Build a Plotly figure of synthetic sample data, choosing the chart type from the prompt.

    The "#Plotly" tag (any case) is removed. The rest of the prompt is
    lower-cased and searched for keywords in this order, and the first match wins:
    "scatter"/"point", "line", "bar", "histogram", "box", "violin", "pie",
    "heatmap", "3d"/"surface". With no match, a sine-wave line chart is used.
    Matching is by plain substring, so e.g. "outline" also selects the line chart.
    Only the chart type depends on the prompt. The data comes from a fixed
    seed and is identical on every call.

    Args:
        prompt (str): The text prompt describing the visualization.

    Returns:
        plotly.graph_objects.Figure | None: The figure with the "plotly_white"
        template applied, or None if building it raised (the error is printed).
    """
    try:
        cleaned_prompt = re.sub(r'#Plotly\s*', '', prompt, flags=re.IGNORECASE)
        prompt_lower = cleaned_prompt.lower()

        # Use a private seeded generator so the sample data stays reproducible
        # without reseeding NumPy's global RNG. Reseeding the global RNG would
        # silently affect every other user of np.random in the process.
        # RandomState(42) produces the same stream as np.random.seed(42)
        # followed by the module-level np.random.* calls, so the data is unchanged.
        rng = np.random.RandomState(42)

        # Data for the default sine-wave chart.
        x = np.linspace(0, 10, 100)
        y = np.sin(x)
        categories = ['A', 'B', 'C', 'D', 'E', 'F']
        values = rng.randint(10, 100, size=len(categories))

        # 100 rows of random sample data shared by most chart types.
        df = pd.DataFrame({
            'x': rng.normal(0, 1, 100),
            'y': rng.normal(0, 1, 100),
            'size': rng.uniform(5, 15, 100),
            'category': rng.choice(categories, 100),
            'values': rng.randint(0, 100, 100),
            'date': pd.date_range(start='2023-01-01', periods=100),
        })

        if 'scatter' in prompt_lower or 'point' in prompt_lower:
            fig = px.scatter(df, x='x', y='y', color='category', size='size',
                             title="Scatter Plot", hover_data=['values'])
        elif 'line' in prompt_lower:
            fig = px.line(df, x='date', y='values', color='category',
                          title="Line Chart")
        elif 'bar' in prompt_lower:
            fig = px.bar(df, x='category', y='values', color='category',
                         title="Bar Chart")
        elif 'histogram' in prompt_lower:
            fig = px.histogram(df, x='x', nbins=20,
                               title="Histogram")
        elif 'box' in prompt_lower:  # also matches "boxplot"
            fig = px.box(df, x='category', y='values',
                         title="Box Plot")
        elif 'violin' in prompt_lower:
            fig = px.violin(df, x='category', y='values',
                            title="Violin Plot")
        elif 'pie' in prompt_lower:
            fig_data = pd.DataFrame({
                'category': categories,
                'values': values,
            })
            fig = px.pie(fig_data, names='category', values='values',
                         title="Pie Chart")
        elif 'heatmap' in prompt_lower:
            # 5x5 correlation matrix of five random series.
            corr_matrix = np.corrcoef(rng.normal(0, 1, (5, 100)))
            fig = px.imshow(corr_matrix,
                            x=categories[:5], y=categories[:5],
                            title="Heatmap")
        elif '3d' in prompt_lower or 'surface' in prompt_lower:
            # 10x10 grid of sin(sqrt(x^2 + y^2)) over integer coordinates -5..4.
            z_data = [[np.sin(np.sqrt(xi ** 2 + yi ** 2)) for xi in range(-5, 5)]
                      for yi in range(-5, 5)]
            fig = go.Figure(data=[go.Surface(z=z_data)])
            fig.update_layout(title="3D Surface Plot", autosize=True)
        else:
            # No recognised keyword: fall back to the sine-wave line chart.
            fig = px.line(x=x, y=y, title="Sample Line Chart")

        fig.update_layout(
            template="plotly_white",
            margin=dict(l=20, r=20, t=50, b=20),
        )
        return fig
    except Exception as e:
        # Best-effort: callers treat None as "could not build the figure".
        print(f"Error creating visualization: {e}")
        return None
def _gemini_text(response):
    """
    Return the stripped reply text of a Gemini response.

    Raises:
        ValueError: If ``response.text`` is None. The SDK can return None when
            the response carries no text parts (e.g. a blocked prompt) — worth
            confirming for the installed google-genai version.
    """
    text = response.text
    if text is None:
        raise ValueError("Gemini returned no text for this request")
    return text.strip()


# Function to generate chat responses
def generate_chat_response(messages, model_name):
    """
    Generate a reply with Google's Gemini API from the conversation history.

    There are two paths:
    - Visualization: the latest message is from the user and contains the
      literal tag "#Plotly" (case-sensitive). A sample figure is built with
      create_plotly_visualization(), Gemini is asked only for a short
      explanation of it, and a dict is returned.
    - Normal: a single message is sent as plain text, and a longer history is
      sent as role/parts turns ("user" stays "user", every other role becomes
      "model").

    Args:
        messages (list): Message dicts with 'role' and 'content' keys.
        model_name (str): The Gemini model to use.

    Returns:
        str | dict: The reply text, or {"text": str, "visualization": Figure | None}
        on the visualization path. Any exception (including an empty message
        list or an empty model reply) is reported as a string starting with
        "Error generating response: ".
    """
    try:
        latest_message = messages[-1]
        if latest_message["role"] == "user" and "#Plotly" in latest_message["content"]:
            visualization = create_plotly_visualization(latest_message["content"])
            topic = latest_message["content"].replace('#Plotly', '')
            visualization_prompt = (
                "\n"
                f'The user wants to visualize: "{topic}".\n'
                "I've created a visualization based on their request.\n"
                "Provide a brief explanation of what the visualization shows and how to interpret it.\n"
                "Keep your response focused on the visualization and be concise.\n"
            )
            response = genai_client.models.generate_content(
                model=model_name,
                contents=visualization_prompt,
            )
            return {
                "text": _gemini_text(response),
                "visualization": visualization,
            }

        if len(messages) == 1:
            # A single message needs no role structure; send the plain text.
            contents = messages[0]["content"]
        else:
            # Gemini expects "model" as the role of assistant turns.
            contents = [
                {
                    "role": "user" if msg["role"] == "user" else "model",
                    "parts": [{"text": msg["content"]}],
                }
                for msg in messages
            ]
        response = genai_client.models.generate_content(
            model=model_name,
            contents=contents,
        )
        return _gemini_text(response)
    except Exception as e:
        print(f"Error generating response: {e}")
        return f"Error generating response: {str(e)}"
# Gemini models offered in the model picker, as {"value", "label"} options for dmc.Select.
MODELS = [
    {"value": model_id, "label": label}
    for model_id, label in (
        ("gemini-2.0-flash", "Gemini 2.0 Flash - Next gen features, speed, and multimodal"),
        ("gemini-2.0-flash-lite", "Gemini 2.0 Flash-Lite - Cost efficiency and low latency"),
        ("gemini-1.5-flash", "Gemini 1.5 Flash - Fast and versatile performance"),
        ("gemini-1.5-flash-8b", "Gemini 1.5 Flash-8B - High volume, lower intelligence tasks"),
        ("gemini-1.5-pro", "Gemini 1.5 Pro - Complex reasoning tasks"),
    )
]
# Create the Dash app with a responsive viewport meta tag and the Inter web font.
app = dash.Dash(
    __name__,
    meta_tags=[{"name": "viewport", "content": "width=device-width, initial-scale=1"}],
    external_stylesheets=[
        "https://fonts.googleapis.com/css2?family=Inter:wght@100;200;300;400;500;600;700;800;900&display=swap",
    ],
)
# The underlying server object, kept at module level so a WSGI server can import it.
server = app.server
app.title = "Gemini Chat"
# Define app layout: header (title + model picker), scrollable main area
# (loader, rendered figures, chat), footer, and two client-side stores.
app.layout = dmc.MantineProvider(
    # NOTE(review): in dash-mantine-components 0.14+ the color scheme is set with
    # MantineProvider's forceColorScheme/defaultColorScheme props, not a theme key.
    # "colorScheme" here is therefore probably ignored — confirm for the installed version.
    theme={
        "colorScheme": "light",
        "fontFamily": "'Inter', sans-serif",
        "primaryColor": "indigo",
    },
    children=[
        dmc.Container(
            fluid=True,
            px=0,
            style={"height": "100vh", "display": "flex", "flexDirection": "column"},
            children=[
                # Header: title on the left, model picker on the right.
                dmc.Paper(
                    h=70,
                    p="md",
                    style={"borderBottom": "1px solid #e9ecef"},
                    children=[
                        dmc.Group(
                            # justify (justify-content) pushes the two children to
                            # opposite ends. "apart" is not a valid align value.
                            justify="space-between",
                            children=[
                                dmc.Title("Gemini Chat", order=3, c="indigo"),
                                dmc.Select(
                                    id="model-select",
                                    data=MODELS,
                                    value="gemini-1.5-flash",
                                    placeholder="Select a model",
                                    style={"width": 300},
                                    searchable=True,
                                ),
                            ],
                        ),
                    ],
                ),
                # Main area: fills the remaining height and scrolls.
                dmc.Box(
                    style={
                        "flexGrow": 1,
                        "display": "flex",
                        "flexDirection": "column",
                        "overflow": "auto",
                    },
                    children=[
                        # Hidden by default; the update_loading callback toggles
                        # "display" from the "loading-state" store.
                        dmc.Loader(
                            id="loading",
                            size="xl",
                            color="indigo",
                            style={
                                "position": "absolute",
                                "top": "50%",
                                "left": "50%",
                                "transform": "translate(-50%, -50%)",
                                "zIndex": 1000,
                                "display": "none",
                            },
                        ),
                        # Filled by update_visualizations from the "visualization-data" store.
                        dmc.Box(
                            id="plotly-visualizations-container",
                            style={"width": "100%", "padding": "1rem", "backgroundColor": "#ffffff"},
                            children=[],
                        ),
                        # Chat component, rendered below the visualizations.
                        ChatComponent(
                            id="chat-component",
                            messages=[],
                            persistence=True,
                            persistence_type="local",
                            input_placeholder="Ask Gemini a question... (Use #Plotly to create visualizations)",
                            theme="light",
                            fill_height=True,
                            user_bubble_style={
                                "backgroundColor": "#6741d9",
                                "color": "white",
                                "marginLeft": "auto",
                                "textAlign": "right",
                                "borderRadius": "1rem 0 1rem 1rem",
                                "padding": "0.75rem 1rem",
                                "maxWidth": "80%",
                            },
                            assistant_bubble_style={
                                "backgroundColor": "#f1f3f5",
                                "color": "#1a1b1e",
                                "marginRight": "auto",
                                "textAlign": "left",
                                "borderRadius": "0 1rem 1rem 1rem",
                                "padding": "0.75rem 1rem",
                                "maxWidth": "80%",
                            },
                            container_style={
                                "padding": "1rem",
                                "backgroundColor": "#ffffff",
                            },
                        ),
                    ],
                ),
                # Footer
                dmc.Paper(
                    h=40,
                    p="xs",
                    style={"borderTop": "1px solid #e9ecef", "textAlign": "center"},
                    children=[
                        dmc.Text("Powered by Google Gemini API", size="sm", c="dimmed"),
                    ],
                ),
                # Serialized Plotly figures keyed by visualization id (written by handle_chat).
                dcc.Store(id="visualization-data", data=None),
                # Loading flag that drives the loader's visibility.
                dcc.Store(id="loading-state", data=False),
            ],
        ),
    ],
)
# Callback to update loading indicator
@callback(
    Output("loading", "style"),
    Input("loading-state", "data"),
    prevent_initial_call=True,
)
def update_loading(loading):
    """
    Return the loader's style: centred over the page, shown only while `loading` is truthy.

    Args:
        loading: Value of the "loading-state" store.

    Returns:
        dict: Absolute-centred style whose "display" is "block" or "none".
    """
    overlay_style = {
        "position": "absolute",
        "top": "50%",
        "left": "50%",
        "transform": "translate(-50%, -50%)",
        "zIndex": 1000,
    }
    overlay_style["display"] = "block" if loading else "none"
    return overlay_style
# Callback to handle chat messages
@callback(
    [
        Output("chat-component", "messages"),
        Output("loading-state", "data"),
        Output("visualization-data", "data"),
    ],
    Input("chat-component", "new_message"),
    [
        State("chat-component", "messages"),
        State("model-select", "value"),
        State("visualization-data", "data"),
    ],
    prevent_initial_call=True,
)
def handle_chat(new_message, messages, model, viz_data):
    """
    Append a newly sent chat message and, for user messages, Gemini's reply.

    The message is handled by role and content:
    - A user message containing "#plotly" (any case) builds a sample figure
      via create_plotly_visualization(). The figure is stored in the
      visualization-data store under "viz-<len(messages)>", and the reply is a
      short Gemini explanation.
    - Any other user message is answered by generate_chat_response() with the
      full history.
    - Non-user messages are appended without a reply.

    Args:
        new_message (dict | None): The message just sent ('role', 'content').
        messages (list | None): Chat history before new_message.
        model (str): Selected Gemini model name.
        viz_data (dict | None): Stored figures keyed by visualization id.

    Returns:
        tuple: (updated messages, loading flag, visualization store data).
    """
    if not new_message:
        return messages, False, viz_data

    # Treat a missing/None history as empty so the list concatenation cannot fail.
    messages = messages or []
    updated_messages = messages + [new_message]
    if new_message["role"] != "user":
        return updated_messages, False, viz_data

    # NOTE(review): the loading flag is always returned as False in this same
    # response, so the loader never becomes visible. Showing it would need a
    # separate callback (or dcc.Loading) that sets it before the Gemini call.
    try:
        content = new_message["content"]
        if "#plotly" not in content.lower():
            response = generate_chat_response(updated_messages, model)
            bot_response = {"role": "assistant", "content": response}
            return updated_messages + [bot_response], False, viz_data

        print("Plotly visualization requested")
        fig = create_plotly_visualization(content)
        if fig is None:
            bot_response = {"role": "assistant",
                            "content": "I couldn't create the visualization you requested. Please try again with different parameters."}
            return updated_messages + [bot_response], False, viz_data

        # Strip the tag case-insensitively. A leftover "#Plotly" (the spelling the
        # input placeholder suggests) would send generate_chat_response down its
        # own visualization branch, which returns a dict instead of text.
        topic = re.sub(r'#plotly\s*', '', content, flags=re.IGNORECASE)
        visualization_prompt = (
            "\n"
            f'The user wants to visualize: "{topic}".\n'
            "Create a brief explanation of what the visualization shows and how to interpret it.\n"
            "Keep your response focused on the visualization and be concise.\n"
        )
        text_response = generate_chat_response(
            [{"role": "user", "content": visualization_prompt}], model
        )
        if isinstance(text_response, dict):
            # Defensive: never interpolate a {"text", "visualization"} dict into the chat.
            text_response = text_response.get("text", "")

        viz_id = f"viz-{len(messages)}"
        if viz_data is None:
            viz_data = {}
        # dcc.Store holds JSON, so keep the figure as a plain dict.
        viz_data[viz_id] = {
            "type": "plotly",
            "figure": fig.to_dict(),
        }
        print(f"Stored visualization with ID: {viz_id}")

        bot_response = {
            "role": "assistant",
            "content": f"{text_response}\n\n*Visualization has been created based on your request.*",
        }
        return updated_messages + [bot_response], False, viz_data
    except Exception as e:
        # Report the failure in the chat instead of breaking the callback.
        print(f"Error in handle_chat: {e}")
        bot_response = {"role": "assistant", "content": f"Error: {str(e)}"}
        return updated_messages + [bot_response], False, viz_data
def _viz_sort_key(viz_id):
    """
    Sort key that orders "viz-<n>" ids numerically, so viz-2 comes before viz-10.

    Ids without a numeric suffix sort after the numeric ones, by their text.
    """
    suffix = viz_id.rpartition('-')[2]
    if suffix.isdecimal():
        return (0, int(suffix), viz_id)
    return (1, 0, viz_id)


@callback(
    Output("plotly-visualizations-container", "children"),
    Input("visualization-data", "data"),
    prevent_initial_call=True,
)
def update_visualizations(viz_data):
    """
    Render every stored Plotly visualization as a titled card containing a dcc.Graph.

    Args:
        viz_data (dict | None): Maps ids like "viz-4" to
            {"type": "plotly", "figure": <figure dict>}, as written by handle_chat.

    Returns:
        list: One html.Div per "plotly" entry, in numeric id order. If an
        entry's card cannot be built, an error card is shown in its place.

    Raises:
        PreventUpdate: When there is no visualization data yet.
    """
    if not viz_data:
        print("No visualization data, preventing update")
        raise PreventUpdate
    # Log ids only; the stored figure dicts can be very large.
    print(f"update_visualizations called with ids: {list(viz_data)}")

    card_style = {
        'margin': '20px 0',
        'padding': '15px',
        'border': '1px solid #e0e0e0',
        'borderRadius': '8px',
        'backgroundColor': '#f9f9f9',
    }
    viz_components = []
    # Plain string sorting would put "viz-10" before "viz-2"; sort by number instead.
    for viz_id in sorted(viz_data, key=_viz_sort_key):
        viz_info = viz_data[viz_id]
        print(f"Processing visualization {viz_id}: {viz_info['type']}")
        if viz_info["type"] != "plotly":
            continue  # only Plotly figures are rendered
        try:
            graph = dcc.Graph(
                id=f'graph-{viz_id}',
                figure=viz_info["figure"],
                style={'height': '500px', 'width': '100%'},
                config={'displayModeBar': True, 'responsive': True},
            )
            viz_components.append(html.Div(
                children=[
                    html.Hr(),
                    html.H4(f"Visualization {viz_id.split('-')[1]}", style={'textAlign': 'center'}),
                    graph,
                ],
                style=card_style,
                id=f"viz-container-{viz_id}",
            ))
            print(f"Successfully rendered visualization {viz_id}")
        except Exception as e:
            print(f"Error rendering visualization {viz_id}: {e}")
            viz_components.append(html.Div(
                [
                    html.Hr(),
                    html.H4(f"Visualization {viz_id.split('-')[1]} (Error)"),
                    html.P(f"Error rendering visualization: {str(e)}"),
                ],
                style={**card_style, 'backgroundColor': '#ffeeee'},  # light red for errors
            ))
    print(f"Created {len(viz_components)} visualization components")
    return viz_components
if __name__ == "__main__":
    # Dash.run replaces run_server, which was deprecated and then removed in Dash 3.0.
    app.run(debug=True, port=3350)
I wasn't able to get the graph to render inside `Output("chat-component", "messages")`, but when I set up a separate `Output("visualization-data", "data")` I was able to get it to work. @gbolly, check out my `handle_chat` callback.
I think adding another prop to `ChatComponent`, such as `render` or `content`, could be useful, so that when you ask a question the chat updates with the message followed by the `render` or `content` that's being created by the callback.