Model/model_data/plotting/plotting_functions.py
2023-06-27 09:56:14 +01:00

40 lines
1.4 KiB
Python

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def create_heatmap_plots(
    data,
    response_var,
    pivot_var1,
    pivot_var2,
    order1=None,
    order2=None,
    *,
    aggfunc="mean",
    figsize=(10, 6),
    cmap="coolwarm",
    fmt=".2f",
    show=True,
):
    """
    Create a heatmap of ``response_var`` over two categorical variables.

    The records are pivoted so that ``pivot_var1`` forms the rows and
    ``pivot_var2`` the columns; each cell holds ``response_var`` aggregated
    with ``aggfunc`` (by default the mean, so repeated row/column
    combinations are averaged). Cells are annotated with their values.

    :param data: Input records, e.g. a list of dictionaries (anything
        ``pd.DataFrame`` accepts). It is not modified.
    :param response_var: Column whose values are plotted; cast to float.
    :param pivot_var1: Column used for the heatmap rows.
    :param pivot_var2: Column used for the heatmap columns.
    :param order1: Optional list giving the row order. Labels that do not
        occur in the data appear as empty (NaN) rows; rows whose label is
        not listed are dropped.
    :param order2: Optional list giving the column order, handled exactly
        like ``order1``.
    :param aggfunc: Aggregation passed to ``DataFrame.pivot_table``.
        Defaults to ``"mean"``.
    :param figsize: Figure size in inches as ``(width, height)``.
    :param cmap: Colormap passed to ``seaborn.heatmap``.
    :param fmt: Format string for the cell annotations.
    :param show: If True (default), call ``plt.show()`` before returning.
    :returns: The matplotlib Axes containing the heatmap, so callers can
        customise or save it (e.g. ``ax.figure.savefig(...)``).
    :raises KeyError: If ``response_var`` (or a pivot variable) is not a
        column of the data.
    :raises ValueError: If ``response_var`` cannot be converted to float,
        or if the resulting pivot table is empty.
    """
    # astype() with a column mapping returns a new frame, so a DataFrame
    # passed in as ``data`` is never mutated even if pd.DataFrame(data)
    # shares its underlying storage.
    df = pd.DataFrame(data).astype({response_var: float})

    pivot = df.pivot_table(
        index=pivot_var1,
        columns=pivot_var2,
        values=response_var,
        aggfunc=aggfunc,
    )

    # Reorder both axes with reindex so they behave the same way: a label
    # missing from the data becomes an all-NaN row/column. Plain column
    # selection (pivot[order2]) would instead raise KeyError.
    if order1 is not None:
        pivot = pivot.reindex(index=order1)
    if order2 is not None:
        pivot = pivot.reindex(columns=order2)

    if pivot.empty:
        raise ValueError(
            f"No data to plot: pivot of {response_var!r} by "
            f"{pivot_var1!r} and {pivot_var2!r} is empty."
        )

    # Draw on an explicit Axes rather than relying on pyplot's implicit
    # "current figure" state, and hand that Axes back to the caller.
    _, ax = plt.subplots(figsize=figsize)
    sns.heatmap(pivot, annot=True, fmt=fmt, cmap=cmap, ax=ax)
    ax.set_title(f"Heatmap of {response_var} by {pivot_var1} and {pivot_var2}")
    if show:
        plt.show()
    return ax