Skip to content

Commit

Permalink
Merge commit '1039f03ca728b3b74cc75fdadaf6f11796176433' into nwp-update
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Jan 27, 2025
2 parents bd02210 + 1039f03 commit d28b529
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 96 deletions.
53 changes: 1 addition & 52 deletions src/get_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,59 +108,8 @@ def get_metric_value(
return metric_values


# get all users
def get_all_users(session: Session) -> List[UserSQL]:
"""Get all users from the database.
:param session: database session
"""
query = session.query(UserSQL)

query = query.order_by(UserSQL.email.asc())

users = query.all()

return users


# get all site groups
def get_all_site_groups(session: Session) -> List[SiteGroupSQL]:
"""Get all users from the database.
:param session: database session
"""
query = session.query(SiteGroupSQL)

query = query.order_by(SiteGroupSQL.site_group_name.asc())

site_groups = query.all()

return site_groups


# update user site group; users only belong to one site group
def update_user_site_group(
session: Session, email: str, site_group_name: str
) -> UserSQL:
"""Change site group for user.
:param session: database session
:param email: email of user
:param site_group_name: name of site group
"""
site_group = (
session.query(SiteGroupSQL)
.filter(SiteGroupSQL.site_group_name == site_group_name)
.first()
)

user = session.query(UserSQL).filter(UserSQL.email == email)

user = user.update({"site_group_uuid": site_group.site_group_uuid})

session.commit()

return user


# get site group by name
def get_site_by_client_site_id(session: Session, client_site_id: str) -> List[SiteSQL]:
"""Get site by client site id.
:param session: database session
Expand All @@ -172,4 +121,4 @@ def get_site_by_client_site_id(session: Session, client_site_id: str) -> List[Si

site = query.first()

return site
return site
89 changes: 60 additions & 29 deletions src/pvsite_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import plotly.graph_objects as go
import pytz
from zoneinfo import ZoneInfo

# Penalty Calculator
def calculate_penalty(df, region, asset_type, capacity_kw):
Expand All @@ -25,16 +24,24 @@ def calculate_penalty(df, region, asset_type, capacity_kw):
# Define penalty bands for combinations of region and asset type
penalty_bands = {
("Rajasthan", "solar"): [
(10, 15, 0.1), # Band (lowest bound of the band range, highest bound of the band range, penalty that particular band carries)
(15, None, 1.0), # Band (lowest bound of the band range, no highest bound of the band range, penalty that particular band carries)
(
10,
15,
0.1,
), # Band (lowest bound of the band range, highest bound of the band range, penalty that particular band carries)
(
15,
None,
1.0,
), # Band (lowest bound of the band range, no highest bound of the band range, penalty that particular band carries)
],
("Madhya Pradesh", "wind"): [
(10, 20, 0.25),
(10, 20, 0.25),
(20, 30, 0.5),
(30, None, 0.75),
],
("Gujarat", "solar"): [
(7, 15, 0.25),
(7, 15, 0.25),
(15, 23, 0.5),
(23, None, 0.75),
],
Expand Down Expand Up @@ -71,7 +78,7 @@ def calculate_penalty(df, region, asset_type, capacity_kw):
for lower, upper, rate in bands:
mask = (deviation_percentage >= lower) if lower is not None else True
if upper is not None:
mask &= (deviation_percentage < upper)
mask &= deviation_percentage < upper
penalty[mask] += abs(deviation[mask]) * rate

# Calculate total penalty
Expand All @@ -80,7 +87,6 @@ def calculate_penalty(df, region, asset_type, capacity_kw):
return penalty, total_penalty



# Internal Dashboard
def pvsite_forecast_page():
"""Main page for pvsite forecast"""
Expand All @@ -97,7 +103,9 @@ def pvsite_forecast_page():
site_uuids = [sites.site_uuid for sites in sites if sites.site_uuid is not None]

# streamlit toggle between site_uuid and client_site_name
query_method = st.sidebar.radio("Select site by", ("site_uuid", "client_site_name"))
query_method = st.sidebar.radio(
"Select site by", ("site_uuid", "client_site_name")
)

if query_method == "site_uuid":
site_selection_uuid = st.sidebar.selectbox(
Expand All @@ -111,15 +119,21 @@ def pvsite_forecast_page():
sorted([sites.client_site_name for sites in sites]),
)
site_selection_uuid = [
sites.site_uuid for sites in sites if sites.client_site_name == client_site_name
sites.site_uuid
for sites in sites
if sites.client_site_name == client_site_name
][0]

timezone_selected = st.sidebar.selectbox("Select timezone", ["UTC", "Asia/Calcutta"])
timezone_selected = ZoneInfo(timezone_selected)
timezone_selected = st.sidebar.selectbox(
"Select timezone", ["UTC", "Asia/Calcutta"]
)
timezone_selected = pytz.timezone(timezone_selected)

day_after_tomorrow = datetime.today() + timedelta(days=3)
starttime = st.sidebar.date_input(
"Start Date", min_value=datetime.today() - timedelta(days=365), max_value=datetime.today()
"Start Date",
min_value=datetime.today() - timedelta(days=365),
max_value=datetime.today(),
)
endtime = st.sidebar.date_input("End Date", day_after_tomorrow)

Expand All @@ -139,7 +153,6 @@ def pvsite_forecast_page():
asset_type = site.asset_type # Assume site object has an 'asset_type' attribute
capacity_kw = site.capacity_kw # Extract capacity dynamically


if forecast_type == "Latest":
created = pd.Timestamp.utcnow().ceil("15min")
created = created.astimezone(timezone.utc)
Expand All @@ -158,7 +171,9 @@ def pvsite_forecast_page():
created = None

if forecast_type == "Forecast_horizon":
forecast_horizon = st.sidebar.selectbox("Select Forecast Horizon", range(0, 2880, 15), 6)
forecast_horizon = st.sidebar.selectbox(
"Select Forecast Horizon", range(0, 2880, 15), 6
)
else:
forecast_horizon = None

Expand Down Expand Up @@ -206,16 +221,18 @@ def pvsite_forecast_page():
endtime = datetime.combine(endtime, time.min)

# change to the correct timezone
starttime = starttime.replace(tzinfo=timezone_selected)
endtime = endtime.replace(tzinfo=timezone_selected)
# starttime = starttime.replace(tzinfo=timezone_selected)
# endtime = endtime.replace(tzinfo=timezone_selected)
starttime = timezone_selected.localize(starttime)
endtime = timezone_selected.localize(endtime)

# change to utc
starttime = starttime.astimezone(ZoneInfo("UTC"))
endtime = endtime.astimezone(ZoneInfo("UTC"))
starttime = starttime.astimezone(pytz.utc)
endtime = endtime.astimezone(pytz.utc)

if created is not None:
created = created.replace(tzinfo=timezone_selected) # Add timezone information to created
created = created.astimezone(ZoneInfo("UTC"))
created = timezone_selected.localize(created)
created = created.astimezone(pytz.utc)

# great ml model names for this site

Expand All @@ -231,8 +248,10 @@ def pvsite_forecast_page():
)

if len(ml_models) == 0:

class Models:
name = None

ml_models = [Models()]

ys = {}
Expand All @@ -257,7 +276,7 @@ class Models:
y = [i.forecast_power_kw for i in forecast]

# convert to timezone
x = [i.replace(tzinfo=ZoneInfo("UTC")) for i in x]
x = [i.replace(tzinfo=pytz.utc) for i in x]
x = [i.astimezone(timezone_selected) for i in x]

ys[model.name] = y
Expand All @@ -273,12 +292,16 @@ class Models:
)

yy = [
generation.generation_power_kw for generation in generations if generation is not None
generation.generation_power_kw
for generation in generations
if generation is not None
]
xx = [
generation.start_utc for generation in generations if generation is not None
]
xx = [generation.start_utc for generation in generations if generation is not None]

# convert to timezone
xx = [i.replace(tzinfo=ZoneInfo("UTC")) for i in xx]
xx = [i.replace(tzinfo=pytz.utc) for i in xx]
xx = [i.astimezone(timezone_selected) for i in xx]

df_forecast = []
Expand All @@ -304,7 +327,9 @@ class Models:
df_generation = df_generation.resample(resample).mean()

# merge together
df_all = df_forecast.merge(df_generation, left_index=True, right_index=True, how="outer")
df_all = df_forecast.merge(
df_generation, left_index=True, right_index=True, how="outer"
)

# select variables
xx = df_all.index
Expand All @@ -313,7 +338,9 @@ class Models:
fig = go.Figure(
layout=go.Layout(
title=go.layout.Title(text="Latest Forecast for Selected Site"),
xaxis=go.layout.XAxis(title=go.layout.xaxis.Title(text=f"Time [{timezone_selected}]")),
xaxis=go.layout.XAxis(
title=go.layout.xaxis.Title(text=f"Time [{timezone_selected}]")
),
yaxis=go.layout.YAxis(title=go.layout.yaxis.Title(text="KW")),
legend=go.layout.Legend(title=go.layout.legend.Title(text="Chart Legend")),
)
Expand Down Expand Up @@ -366,7 +393,9 @@ def convert_df(df: pd.DataFrame):

# MAE and NMAE Calculator
mae_kw = (df["generation_power_kw"] - df[forecast_column]).abs().mean()
mae_mw = (df["generation_power_kw"] - df[forecast_column]).abs().mean() / 1000
mae_mw = (
df["generation_power_kw"] - df[forecast_column]
).abs().mean() / 1000
me_kw = (df["generation_power_kw"] - df[forecast_column]).mean()
mean_generation = df["generation_power_kw"].mean()
nmae = mae_kw / mean_generation * 100
Expand All @@ -389,10 +418,12 @@ def convert_df(df: pd.DataFrame):
"capacity": capacity,
"pearson_corr": pearson_corr,
}

if country == "india":
df["forecast_power_kw"] = df[forecast_column]
penalties, total_penalty = calculate_penalty(df, str(region), str(asset_type), capacity_kw)
penalties, total_penalty = calculate_penalty(
df, str(region), str(asset_type), capacity_kw
)
one_metric_data["total_penalty [INR]"] = total_penalty

metrics.append(one_metric_data)
Expand Down
3 changes: 2 additions & 1 deletion src/site_toolbox/site_group_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)

from get_data import get_site_by_client_site_id

from pvsite_datamodel.write.user_and_site import (
add_site_to_site_group,
update_user_site_group,
Expand All @@ -28,7 +29,7 @@ def select_site_id(dbsession, query_method: str):
]
client_site_id = st.selectbox("Sites by client_site_id", client_site_ids)
site = get_site_by_client_site_id(
session=dbsession, client_site_id=client_site_id
session=dbsession, client_site_id=client_site_id,
)
selected_uuid = str(site.site_uuid)
elif query_method not in ["site_uuid", "client_site_id"]:
Expand Down
4 changes: 2 additions & 2 deletions src/sites_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import streamlit as st
from sqlalchemy import func
from pvsite_datamodel.connection import DatabaseConnection
# from pvsite_datamodel.write.user_and_site import create_site_group
from pvsite_datamodel.read import (
get_all_sites,
)
from pvsite_datamodel.read.model import get_models
from pvsite_datamodel.sqlmodels import SiteAssetType

from get_data import get_all_users, get_all_site_groups
from pvsite_datamodel.write.user_and_site import (
assign_model_name_to_site,
create_site,
Expand All @@ -21,6 +19,8 @@
create_site_group
)

from pvsite_datamodel.read.user import get_all_users, get_all_site_groups

from site_toolbox.get_details import (
get_user_details,
get_site_details,
Expand Down
16 changes: 8 additions & 8 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import streamlit as st
import json
import json
import requests
from datetime import datetime
from zoneinfo import ZoneInfo # Replacing pytz with zoneinfo
import pytz


def load_css(css_file):
"""Load CSS from a file."""
Expand All @@ -13,7 +14,6 @@ def load_css(css_file):
st.error(f"CSS file not found: {css_file}")



def parse_timestamp(status):
"""Parse the timestamp from the status object and return local time"""
timestamp = str(status.created_utc)
Expand All @@ -22,17 +22,17 @@ def parse_timestamp(status):
parsed_time = datetime.fromisoformat(timestamp)
except ValueError as e:
raise ValueError(f"Invalid timestamp format: {e}")

if parsed_time.tzinfo is not None:
utc_time = parsed_time.astimezone(ZoneInfo("UTC"))
utc_time = parsed_time.astimezone(pytz.utc)
else:
# If no timezone is specified, assume it's UTC
utc_time = parsed_time.replace(tzinfo=ZoneInfo("UTC"))
utc_time = parsed_time.replace(tzinfo=pytz.utc)

# Convert to specific timezone (Asia/Kolkata)
local_timezone = ZoneInfo("Asia/Kolkata")
local_timezone = pytz.timezone("Asia/Kolkata")
local_time = parsed_time.astimezone(local_timezone)

return local_time


Expand Down
6 changes: 2 additions & 4 deletions tests/test_get_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""tests for get_data.py"""
from get_data import (
get_all_users,
get_all_site_groups,
)

from pvsite_datamodel.read.user import get_all_users, get_all_site_groups
from pvsite_datamodel.read import get_all_sites


Expand Down

0 comments on commit d28b529

Please sign in to comment.