Skip to content

Commit

Permalink
Merge pull request #271 from MAYANK12SHARMA/main
Browse files Browse the repository at this point in the history
Add a Frame to change and update the description of ML Models
  • Loading branch information
peterdudfield authored Jan 30, 2025
2 parents d5cdf1b + 2b86555 commit b7c2b8b
Showing 1 changed file with 81 additions and 18 deletions.
99 changes: 81 additions & 18 deletions src/mlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pvsite_datamodel.connection import DatabaseConnection
from pvsite_datamodel.read.model import get_models
from pvsite_datamodel.read.site import get_all_sites
from pvsite_datamodel.sqlmodels import GenerationSQL, SiteSQL
from pvsite_datamodel.sqlmodels import GenerationSQL, SiteSQL, MLModelSQL

def color_survived(val):
now = pd.Timestamp.utcnow()
Expand Down Expand Up @@ -53,17 +53,20 @@ def mlmodel_page():
if site.ml_model is not None:
site_dict["ml_model_name"] = site.ml_model.name

# Get last generation timestamp
last_gen = (
session.query(GenerationSQL)
.filter(GenerationSQL.site_uuid == site.site_uuid)
.order_by(GenerationSQL.created_utc.desc())
.limit(1)
.one()
.one_or_none()
)

if last_gen is not None:
site_dict["last_generation_datetime"] = pd.Timestamp(last_gen.start_utc, tz="UTC")
site_dict["last_generation_datetime"] = pd.Timestamp(
last_gen.start_utc, tz="UTC"
)
else:
site_dict["last_generation_datetime"] = None # Or any placeholder value

all_sites.append(site_dict)

Expand All @@ -72,20 +75,66 @@ def mlmodel_page():
# Order by name
all_sites = all_sites.sort_values(by="client_site_name")

st.table(all_sites.style.applymap(color_survived, subset=["last_generation_datetime"]))
st.table(
all_sites.style.applymap(
color_survived, subset=["last_generation_datetime"]
)
)

# 2. display all models
models = get_models(session)

all_models = pd.DataFrame(
[{"name": m.name, "version": m.version, "description": "todo"} for m in models]
)
try:
models = get_models(session)
if not models:
st.warning("No ML models found in the database.")
else:
all_models_df = pd.DataFrame(
[
{
"name": m.name,
"version": m.version,
"description": m.description,
}
for m in models
]
).sort_values(by="name")

st.subheader("ML Models")
st.dataframe(all_models_df, use_container_width=True)


# Change the Description of ML Models
st.subheader("Change the Description of ML Models")

# Select the model first
model_name = st.selectbox("Select Model", all_models_df["name"].unique().tolist())

# Filter the models dataframe based on the selected model
model_versions_df = all_models_df[all_models_df["name"] == model_name]
model_versions = model_versions_df["version"].tolist()

# Select the version based on the selected model
model_version = st.selectbox("Select Version", model_versions)

# Now retrieve the model object based on the selected model and version
model = session.query(MLModelSQL).filter(
MLModelSQL.name == model_name, MLModelSQL.version == model_version
).one_or_none()

if model is not None:
# Display the current description
new_description = st.text_area("Current Description", model.description)

# Update the description when the button is clicked
if st.button("Update Description"):
model.description = new_description
session.commit()
st.success("Description updated successfully.")
else:
st.error("Model version not found in the database.")
except Exception as e:
st.error(f"Failed to fetch or update ML models: {e}")

# order by name
all_models = all_models.sort_values(by="name")

st.write("ML Models")
st.write(all_models)

# 3. Show site locations on the map
st.subheader("Site Locations on Map")
Expand All @@ -99,9 +148,11 @@ def mlmodel_page():
"longitude": getattr(site, "longitude", None),
"region": site.region,
"capacity_kw": site.capacity_kw,
"asset_type": str(site.asset_type),
"asset_type": str(site.asset_type),
}
if site_dict["latitude"] and site_dict["longitude"]: # Ensure latitude and longitude exist
if (
site_dict["latitude"] and site_dict["longitude"]
): # Ensure latitude and longitude exist
site_details.append(site_dict)

# Convert to DataFrame
Expand Down Expand Up @@ -136,7 +187,10 @@ def mlmodel_page():
"latitude": False,
"longitude": False,
},
color_discrete_map={"SiteAssetType.pv": "orange", "SiteAssetType.wind": "blue"},
color_discrete_map={
"SiteAssetType.pv": "orange",
"SiteAssetType.wind": "blue",
},
zoom=4,
height=600,
)
Expand All @@ -151,7 +205,16 @@ def mlmodel_page():
# Display site details in a table
st.subheader("Site Geographical Details")
st.dataframe(
map_data[["client_site_name", "region", "capacity_kw", "asset_type", "latitude", "longitude"]],
map_data[
[
"client_site_name",
"region",
"capacity_kw",
"asset_type",
"latitude",
"longitude",
]
],
use_container_width=True,
)
else:
Expand Down

0 comments on commit b7c2b8b

Please sign in to comment.