Skip to content

Commit

Permalink
Only stretch image to resize if format is Stretch (#352)
Browse files Browse the repository at this point in the history
* Only stretch image to resize if format is Stretch

* fix(pre_commit): 🎨 auto format pre-commit hooks

* Format

* Remove unused resize var

* reformat w latest version of ruff

* Initialize should_resize var outside of if

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
SolomonLake and pre-commit-ci[bot] authored Jan 17, 2025
1 parent 594bff7 commit e7654dd
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 25 deletions.
4 changes: 2 additions & 2 deletions roboflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def check_key(api_key, model, notebook, num_retries=0):
num_retries += 1
return check_key(api_key, model, notebook, num_retries)
else:
raise RuntimeError("There was an error validating the api key with Roboflow" " server.")
raise RuntimeError("There was an error validating the api key with Roboflow server.")
else:
r = response.json()
return r
Expand Down Expand Up @@ -71,7 +71,7 @@ def login(workspace=None, force=False):
# default configuration location
conf_location = os.getenv("ROBOFLOW_CONFIG_DIR", default=default_path)
if os.path.isfile(conf_location) and not force:
write_line("You are already logged into Roboflow. To make a different login," "run roboflow.login(force=True).")
write_line("You are already logged into Roboflow. To make a different login,run roboflow.login(force=True).")
return None
# we could eventually return the workspace object here
# return Roboflow().workspace()
Expand Down
2 changes: 1 addition & 1 deletion roboflow/adapters/rfapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def save_annotation(


def _save_annotation_url(api_key, project_url, name, image_id, job_name, is_prediction, overwrite=False):
url = f"{API_URL}/dataset/{project_url}/annotate/{image_id}?api_key={api_key}" f"&name={name}"
url = f"{API_URL}/dataset/{project_url}/annotate/{image_id}?api_key={api_key}&name={name}"
if job_name:
url += f"&jobName={job_name}"
if is_prediction:
Expand Down
4 changes: 2 additions & 2 deletions roboflow/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def generate_version(self, settings):
)

r = requests.post(
f"{API_URL}/{self.__workspace}/{self.__project_name}/" f"generate?api_key={self.__api_key}",
f"{API_URL}/{self.__workspace}/{self.__project_name}/generate?api_key={self.__api_key}",
json=settings,
)

Expand Down Expand Up @@ -426,7 +426,7 @@ def upload(

if not is_image:
raise RuntimeError(
"The image you provided {} is not a supported file format. We" " currently support: {}.".format(
"The image you provided {} is not a supported file format. We currently support: {}.".format(
image_path, ", ".join(ACCEPTED_IMAGE_FORMATS)
)
)
Expand Down
16 changes: 7 additions & 9 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
]

if not any(supported_model in model_type for supported_model in supported_models):
raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}"))
raise (ValueError(f"Model type {model_type} not supported. Supported models are {supported_models}"))

if model_type.startswith(("paligemma", "paligemma2", "florence-2")):
if any(model in model_type for model in ["paligemma", "paligemma2", "florence-2"]):
Expand Down Expand Up @@ -648,7 +648,7 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
)
else:
if file in ["model_artifacts.json", "state_dict.pt"]:
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))
raise (ValueError(f"File {file} not found. Please make sure to provide a valid model path."))

self.upload_zip(model_type, model_path)

Expand Down Expand Up @@ -761,7 +761,7 @@ def deploy_yolonas(self, model_type: str, model_path: str, filename: str = "weig
)
else:
if file in ["model_artifacts.json", filename]:
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))
raise (ValueError(f"File {file} not found. Please make sure to provide a valid model path."))

self.upload_zip(model_type, model_path)

Expand Down Expand Up @@ -791,8 +791,7 @@ def upload_zip(self, model_type: str, model_path: str, model_file_name: str = "r

if self.public:
print(
"View the status of your deployment at:"
f" {APP_URL}/{self.workspace}/{self.project}/{self.version}"
f"View the status of your deployment at: {APP_URL}/{self.workspace}/{self.project}/{self.version}"
)
print(
"Share your model with the world at:"
Expand All @@ -801,8 +800,7 @@ def upload_zip(self, model_type: str, model_path: str, model_file_name: str = "r
)
else:
print(
"View the status of your deployment at:"
f" {APP_URL}/{self.workspace}/{self.project}/{self.version}"
f"View the status of your deployment at: {APP_URL}/{self.workspace}/{self.project}/{self.version}"
)

except Exception as e:
Expand All @@ -824,7 +822,7 @@ def bar_progress(current, total, width=80):
progress_message = (
"Downloading Dataset Version Zip in "
f"{location} to {format}: "
f"{current/total*100:.0f}% [{current} / {total}] bytes"
f"{current / total * 100:.0f}% [{current} / {total}] bytes"
)
sys.stdout.write("\r" + progress_message)
sys.stdout.flush()
Expand Down Expand Up @@ -923,7 +921,7 @@ def __get_format_identifier(self, format):

if not format:
raise RuntimeError(
"You must pass a format argument to version.download() or define a" " model in your Roboflow object"
"You must pass a format argument to version.download() or define a model in your Roboflow object"
)

friendly_formats = {"yolov5": "yolov5pytorch", "yolov7": "yolov7pytorch"}
Expand Down
2 changes: 1 addition & 1 deletion roboflow/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def two_stage_ocr(
# capture OCR results from cropped image
results.append(ocr_infer(croppedImg)["results"])
else:
print("please use an object detection model--can only perform two stage with" " bounding box results")
print("please use an object detection model--can only perform two stage with bounding box results")

return results

Expand Down
4 changes: 2 additions & 2 deletions roboflow/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def get_deployment(args):
print(json.dumps(msg, indent=2))
break

print(f'{datetime.now().strftime("%H:%M:%S")} Waiting for deployment {args.deployment_name} to be ready...\n')
print(f"{datetime.now().strftime('%H:%M:%S')} Waiting for deployment {args.deployment_name} to be ready...\n")
time.sleep(30)


Expand Down Expand Up @@ -278,7 +278,7 @@ def get_deployment_log(args):
continue
log_ids.add(log["insert_id"])
last_log_timestamp = log_timestamp
print(f'[{log_timestamp.strftime("%Y-%m-%d %H:%M:%S.%f")}] {log["payload"]}')
print(f"[{log_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}] {log['payload']}")

if not args.follow:
break
Expand Down
11 changes: 7 additions & 4 deletions roboflow/models/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,21 +172,25 @@ def predict( # type: ignore[override]
else:
self.__exception_check(image_path_check=image_path)

resize = False
original_dimensions = None
should_resize = False
# If image is local image
if not hosted:
import cv2
import numpy as np

should_resize = (
"resize" in self.preprocessing.keys() and "Stretch" in self.preprocessing["resize"]["format"]
)

if isinstance(image_path, str):
image = Image.open(image_path).convert("RGB")
dimensions = image.size
original_dimensions = copy.deepcopy(dimensions)

# Here we resize the image to the preprocessing settings
# before sending it over the wire
if "resize" in self.preprocessing.keys():
if should_resize:
if dimensions[0] > int(self.preprocessing["resize"]["width"]) or dimensions[1] > int(
self.preprocessing["resize"]["height"]
):
Expand All @@ -197,7 +201,6 @@ def predict( # type: ignore[override]
)
)
dimensions = image.size
resize = True

# Create buffer
buffered = io.BytesIO()
Expand Down Expand Up @@ -245,7 +248,7 @@ def predict( # type: ignore[override]
if self.format == "json":
resp_json = resp.json()

if resize and original_dimensions is not None:
if should_resize and original_dimensions is not None:
new_preds = []
for p in resp_json["predictions"]:
p["x"] = int(p["x"] * (int(original_dimensions[0]) / int(self.preprocessing["resize"]["width"])))
Expand Down
4 changes: 2 additions & 2 deletions roboflow/roboflowpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def _add_upload_parser(subparsers):
upload_parser.add_argument(
"-w",
dest="workspace",
help="specify a workspace url or id " "(will use default workspace if not specified)",
help="specify a workspace url or id (will use default workspace if not specified)",
)
upload_parser.add_argument(
"-p",
Expand Down Expand Up @@ -307,7 +307,7 @@ def _add_import_parser(subparsers):
import_parser.add_argument(
"-w",
dest="workspace",
help="specify a workspace url or id " "(will use default workspace if not specified)",
help="specify a workspace url or id (will use default workspace if not specified)",
)
import_parser.add_argument(
"-p",
Expand Down
4 changes: 2 additions & 2 deletions roboflow/util/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_wrong_dependencies_versions(
module = import_module(dependency)
module_version = module.__version__
if order not in order_funcs:
raise ValueError(f"order={order} not supported, please use" f" `{', '.join(order_funcs.keys())}`")
raise ValueError(f"order={order} not supported, please use `{', '.join(order_funcs.keys())}`")

is_okay = order_funcs[order](Version(module_version), Version(version))
if not is_okay:
Expand All @@ -53,7 +53,7 @@ def print_warn_for_wrong_dependencies_versions(
f" {dependency}{order}{version}`"
)
if ask_to_continue:
answer = input(f"Would you like to continue with the wrong version of {dependency}?" " y/n: ")
answer = input(f"Would you like to continue with the wrong version of {dependency}? y/n: ")
if answer.lower() != "y":
sys.exit(1)

Expand Down

0 comments on commit e7654dd

Please sign in to comment.