init commit

169 ultralytics/hub/__init__.py (new file)
@@ -0,0 +1,169 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.data.utils import HUBDatasetStats
from ultralytics.hub.auth import Auth
from ultralytics.hub.session import HUBTrainingSession
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
from ultralytics.utils import LOGGER, SETTINGS, checks

__all__ = (
    "PREFIX",
    "HUB_WEB_ROOT",
    "HUBTrainingSession",
    "login",
    "logout",
    "reset_model",
    "export_fmts_hub",
    "export_model",
    "get_export",
    "check_dataset",
)


def login(api_key: str = None, save: bool = True) -> bool:
    """
    Log in to the Ultralytics HUB API using the provided API key.

    The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
    environment variable if successfully authenticated.

    Args:
        api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from
            SETTINGS or the HUB_API_KEY environment variable.
        save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.

    Returns:
        (bool): True if authentication is successful, False otherwise.
    """
    checks.check_requirements("hub-sdk>=0.0.12")
    from hub_sdk import HUBClient

    api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # set the redirect URL
    saved_key = SETTINGS.get("api_key")
    active_key = api_key or saved_key
    credentials = {"api_key": active_key} if active_key and active_key != "" else None  # set credentials

    client = HUBClient(credentials)  # initialize HUBClient

    if client.authenticated:
        # Successfully authenticated with HUB
        if save and client.api_key != saved_key:
            SETTINGS.update({"api_key": client.api_key})  # update settings with valid API key

        # Set message based on whether key was provided or retrieved from settings
        log_message = (
            "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
        )
        LOGGER.info(f"{PREFIX}{log_message}")

        return True
    else:
        # Failed to authenticate with HUB
        LOGGER.info(f"{PREFIX}Get API key from {api_key_url} and then run 'yolo login API_KEY'")
        return False


def logout():
    """Log out of Ultralytics HUB by removing the API key from the settings file."""
    SETTINGS["api_key"] = ""
    LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo login'.")


def reset_model(model_id: str = ""):
    """Reset a trained model to an untrained state."""
    import requests  # scoped as slow import

    r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
    if r.status_code == 200:
        LOGGER.info(f"{PREFIX}Model reset successfully")
        return
    LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")


def export_fmts_hub():
    """Return a list of HUB-supported export formats."""
    from ultralytics.engine.exporter import export_formats

    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]


def export_model(model_id: str = "", format: str = "torchscript"):
    """
    Export a model to a specified format for deployment via the Ultralytics HUB API.

    Args:
        model_id (str): The ID of the model to export. An empty string will use the default model.
        format (str): The format to export the model to. Must be one of the supported formats returned by
            export_fmts_hub().

    Raises:
        AssertionError: If the specified format is not supported or if the export request fails.

    Examples:
        >>> from ultralytics import hub
        >>> hub.export_model(model_id="your_model_id", format="torchscript")
    """
    import requests  # scoped as slow import

    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
    r = requests.post(
        f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
    )
    assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
    LOGGER.info(f"{PREFIX}{format} export started ✅")


def get_export(model_id: str = "", format: str = "torchscript"):
    """
    Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.

    Args:
        model_id (str): The ID of the model to retrieve from Ultralytics HUB.
        format (str): The export format to retrieve. Must be one of the supported formats returned by
            export_fmts_hub().

    Returns:
        (dict): JSON response containing the exported model information.

    Raises:
        AssertionError: If the specified format is not supported or if the API request fails.

    Examples:
        >>> from ultralytics import hub
        >>> result = hub.get_export(model_id="your_model_id", format="torchscript")
    """
    import requests  # scoped as slow import

    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
    r = requests.post(
        f"{HUB_API_ROOT}/get-export",
        json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
        headers={"x-api-key": Auth().api_key},
    )
    assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
    return r.json()


def check_dataset(path: str, task: str) -> None:
    """
    Check HUB dataset Zip file for errors before upload.

    Args:
        path (str): Path to data.zip (with data.yaml inside data.zip).
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.

    Examples:
        >>> from ultralytics.hub import check_dataset
        >>> check_dataset("path/to/coco8.zip", task="detect")  # detect dataset
        >>> check_dataset("path/to/coco8-seg.zip", task="segment")  # segment dataset
        >>> check_dataset("path/to/coco8-pose.zip", task="pose")  # pose dataset
        >>> check_dataset("path/to/dota8.zip", task="obb")  # OBB dataset
        >>> check_dataset("path/to/imagenet10.zip", task="classify")  # classification dataset

    Notes:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets,
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
    """
    HUBDatasetStats(path=path, task=task).get_json()
    LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
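Taken together, these helpers cover the typical HUB workflow: authenticate once, then queue and retrieve exports. A minimal sketch of that flow, assuming a valid HUB account (the API key and model ID below are placeholders, not real values):

# Hypothetical end-to-end usage of the module above; placeholder credentials
from ultralytics import hub

hub.login("YOUR_API_KEY")  # placeholder key; falls back to a previously saved key in SETTINGS
print(hub.export_fmts_hub())  # formats accepted by export_model(), e.g. 'torchscript', 'onnx'
hub.export_model(model_id="YOUR_MODEL_ID", format="onnx")  # placeholder model ID; queues an export
result = hub.get_export(model_id="YOUR_MODEL_ID", format="onnx")  # JSON with the exported model info
hub.logout()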
BIN ultralytics/hub/__pycache__/__init__.cpython-310.pyc (new file, binary file not shown)
BIN ultralytics/hub/__pycache__/auth.cpython-310.pyc (new file, binary file not shown)
BIN ultralytics/hub/__pycache__/session.cpython-310.pyc (new file, binary file not shown)
BIN ultralytics/hub/__pycache__/utils.cpython-310.pyc (new file, binary file not shown)
157 ultralytics/hub/auth.py (new file)
@@ -0,0 +1,157 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, emojis

API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"


class Auth:
    """
    Manages authentication processes including API key handling, cookie-based authentication, and header generation.

    The class supports different methods of authentication:
    1. Directly using an API key.
    2. Authenticating using browser cookies (specifically in Google Colab).
    3. Prompting the user to enter an API key.

    Attributes:
        id_token (str | bool): Token used for identity verification, initialized as False.
        api_key (str | bool): API key for authentication, initialized as False.
        model_key (bool): Placeholder for model key, initialized as False.

    Methods:
        authenticate: Attempt to authenticate with the server using either id_token or API key.
        auth_with_cookies: Attempt to fetch authentication via cookies and set id_token.
        get_auth_header: Get the authentication header for making API requests.
        request_api_key: Prompt the user to input their API key.

    Examples:
        Initialize Auth with an API key
        >>> auth = Auth(api_key="your_api_key_here")

        Initialize Auth without an API key (will prompt for input)
        >>> auth = Auth()
    """

    id_token = api_key = model_key = False

    def __init__(self, api_key: str = "", verbose: bool = False):
        """
        Initialize Auth class and authenticate user.

        Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon
        successful authentication.

        Args:
            api_key (str): API key or combined key_id format.
            verbose (bool): Enable verbose logging.
        """
        # Split the input API key in case it contains a combined key_model and keep only the API key part
        api_key = api_key.split("_", 1)[0]

        # Set API key attribute as value passed or SETTINGS API key if none passed
        self.api_key = api_key or SETTINGS.get("api_key", "")

        # If an API key is provided
        if self.api_key:
            # If the provided API key matches the API key in the SETTINGS
            if self.api_key == SETTINGS.get("api_key"):
                # Log that the user is already logged in
                if verbose:
                    LOGGER.info(f"{PREFIX}Authenticated ✅")
                return
            else:
                # Attempt to authenticate with the provided API key
                success = self.authenticate()
        # If the API key is not provided and the environment is a Google Colab notebook
        elif IS_COLAB:
            # Attempt to authenticate using browser cookies
            success = self.auth_with_cookies()
        else:
            # Request an API key
            success = self.request_api_key()

        # Update SETTINGS with the new API key after successful authentication
        if success:
            SETTINGS.update({"api_key": self.api_key})
            # Log that the new login was successful
            if verbose:
                LOGGER.info(f"{PREFIX}New authentication successful ✅")
        elif verbose:
            LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")

    def request_api_key(self, max_attempts: int = 3) -> bool:
        """
        Prompt the user to input their API key.

        Args:
            max_attempts (int): Maximum number of authentication attempts.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        import getpass

        for attempts in range(max_attempts):
            LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
            input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
            self.api_key = input_key.split("_", 1)[0]  # remove model id if present
            if self.authenticate():
                return True
        raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

    def authenticate(self) -> bool:
        """
        Attempt to authenticate with the server using either id_token or API key.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        import requests  # scoped as slow import

        try:
            if header := self.get_auth_header():
                r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
                if not r.json().get("success", False):
                    raise ConnectionError("Unable to authenticate.")
                return True
            raise ConnectionError("User has not authenticated locally.")
        except ConnectionError:
            self.id_token = self.api_key = False  # reset invalid
            LOGGER.warning(f"{PREFIX}Invalid API key")
            return False

    def auth_with_cookies(self) -> bool:
        """
        Attempt to fetch authentication via cookies and set id_token.

        User must be logged in to HUB and running in a supported browser.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        if not IS_COLAB:
            return False  # currently only works with Colab
        try:
            authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
            if authn.get("success", False):
                self.id_token = authn.get("data", {}).get("idToken", None)
                self.authenticate()
                return True
            raise ConnectionError("Unable to fetch browser authentication details.")
        except ConnectionError:
            self.id_token = False  # reset invalid
            return False

    def get_auth_header(self):
        """
        Get the authentication header for making API requests.

        Returns:
            (dict | None): The authentication header if id_token or API key is set, None otherwise.
        """
        if self.id_token:
            return {"authorization": f"Bearer {self.id_token}"}
        elif self.api_key:
            return {"x-api-key": self.api_key}
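For reference, the header produced by get_auth_header() plugs straight into requests. A minimal sketch, assuming a valid key (the key below is a placeholder), using the same /v1/auth endpoint that authenticate() calls:

import requests

from ultralytics.hub.auth import Auth
from ultralytics.hub.utils import HUB_API_ROOT

auth = Auth(api_key="YOUR_API_KEY", verbose=True)  # placeholder key; authenticates during __init__
if header := auth.get_auth_header():  # {"x-api-key": ...} or {"authorization": "Bearer ..."}
    r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)  # same endpoint authenticate() uses
    print(r.json().get("success", False))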
177 ultralytics/hub/google/__init__.py (new file)
@@ -0,0 +1,177 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import concurrent.futures
import statistics
import time


class GCPRegions:
    """
    A class for managing and analyzing Google Cloud Platform (GCP) regions.

    This class provides functionality to initialize, categorize, and analyze GCP regions based on their
    geographical location, tier classification, and network latency.

    Attributes:
        regions (dict[str, tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.

    Methods:
        tier1: Returns a list of tier 1 GCP regions.
        tier2: Returns a list of tier 2 GCP regions.
        lowest_latency: Determines the GCP region(s) with the lowest network latency.

    Examples:
        >>> from ultralytics.hub.google import GCPRegions
        >>> regions = GCPRegions()
        >>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3)
        >>> print(f"Lowest latency region: {lowest_latency_region[0][0]}")
    """

    def __init__(self):
        """Initialize the GCPRegions class with predefined Google Cloud Platform regions and their details."""
        self.regions = {
            "asia-east1": (1, "Taiwan", "China"),
            "asia-east2": (2, "Hong Kong", "China"),
            "asia-northeast1": (1, "Tokyo", "Japan"),
            "asia-northeast2": (1, "Osaka", "Japan"),
            "asia-northeast3": (2, "Seoul", "South Korea"),
            "asia-south1": (2, "Mumbai", "India"),
            "asia-south2": (2, "Delhi", "India"),
            "asia-southeast1": (2, "Jurong West", "Singapore"),
            "asia-southeast2": (2, "Jakarta", "Indonesia"),
            "australia-southeast1": (2, "Sydney", "Australia"),
            "australia-southeast2": (2, "Melbourne", "Australia"),
            "europe-central2": (2, "Warsaw", "Poland"),
            "europe-north1": (1, "Hamina", "Finland"),
            "europe-southwest1": (1, "Madrid", "Spain"),
            "europe-west1": (1, "St. Ghislain", "Belgium"),
            "europe-west10": (2, "Berlin", "Germany"),
            "europe-west12": (2, "Turin", "Italy"),
            "europe-west2": (2, "London", "United Kingdom"),
            "europe-west3": (2, "Frankfurt", "Germany"),
            "europe-west4": (1, "Eemshaven", "Netherlands"),
            "europe-west6": (2, "Zurich", "Switzerland"),
            "europe-west8": (1, "Milan", "Italy"),
            "europe-west9": (1, "Paris", "France"),
            "me-central1": (2, "Doha", "Qatar"),
            "me-west1": (1, "Tel Aviv", "Israel"),
            "northamerica-northeast1": (2, "Montreal", "Canada"),
            "northamerica-northeast2": (2, "Toronto", "Canada"),
            "southamerica-east1": (2, "São Paulo", "Brazil"),
            "southamerica-west1": (2, "Santiago", "Chile"),
            "us-central1": (1, "Iowa", "United States"),
            "us-east1": (1, "South Carolina", "United States"),
            "us-east4": (1, "Northern Virginia", "United States"),
            "us-east5": (1, "Columbus", "United States"),
            "us-south1": (1, "Dallas", "United States"),
            "us-west1": (1, "Oregon", "United States"),
            "us-west2": (2, "Los Angeles", "United States"),
            "us-west3": (2, "Salt Lake City", "United States"),
            "us-west4": (2, "Las Vegas", "United States"),
        }

    def tier1(self) -> list[str]:
        """Return a list of GCP regions classified as tier 1 based on predefined criteria."""
        return [region for region, info in self.regions.items() if info[0] == 1]

    def tier2(self) -> list[str]:
        """Return a list of GCP regions classified as tier 2 based on predefined criteria."""
        return [region for region, info in self.regions.items() if info[0] == 2]

    @staticmethod
    def _ping_region(region: str, attempts: int = 1) -> tuple[str, float, float, float, float]:
        """
        Ping a specified GCP region and measure network latency statistics.

        Args:
            region (str): The GCP region identifier to ping (e.g., 'us-central1').
            attempts (int, optional): Number of ping attempts to make for calculating statistics.

        Returns:
            region (str): The GCP region identifier that was pinged.
            mean_latency (float): Mean latency in milliseconds, or infinity if all pings failed.
            std_dev (float): Standard deviation of latencies in milliseconds, or infinity if all pings failed.
            min_latency (float): Minimum latency in milliseconds, or infinity if all pings failed.
            max_latency (float): Maximum latency in milliseconds, or infinity if all pings failed.

        Examples:
            >>> region, mean, std, min_lat, max_lat = GCPRegions._ping_region("us-central1", attempts=3)
            >>> print(f"Region {region} has mean latency: {mean:.2f}ms")
        """
        import requests  # scoped as slow import

        url = f"https://{region}-docker.pkg.dev"
        latencies = []
        for _ in range(attempts):
            try:
                start_time = time.time()
                _ = requests.head(url, timeout=5)
                latency = (time.time() - start_time) * 1000  # convert latency to milliseconds
                if latency != float("inf"):
                    latencies.append(latency)
            except requests.RequestException:
                pass
        if not latencies:
            return region, float("inf"), float("inf"), float("inf"), float("inf")

        std_dev = statistics.stdev(latencies) if len(latencies) > 1 else 0
        return region, statistics.mean(latencies), std_dev, min(latencies), max(latencies)

    def lowest_latency(
        self,
        top: int = 1,
        verbose: bool = False,
        tier: int | None = None,
        attempts: int = 1,
    ) -> list[tuple[str, float, float, float, float]]:
        """
        Determine the GCP regions with the lowest latency based on ping tests.

        Args:
            top (int, optional): Number of top regions to return.
            verbose (bool, optional): If True, prints detailed latency information for all tested regions.
            tier (int | None, optional): Filter regions by tier (1 or 2). If None, all regions are tested.
            attempts (int, optional): Number of ping attempts per region.

        Returns:
            (list[tuple[str, float, float, float, float]]): List of tuples containing region information and
                latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).

        Examples:
            >>> regions = GCPRegions()
            >>> results = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=2)
            >>> print(results[0][0])  # Print the name of the lowest latency region
        """
        if verbose:
            print(f"Testing GCP regions for latency (with {attempts} {'attempt' if attempts == 1 else 'attempts'})...")

        regions_to_test = [k for k, v in self.regions.items() if v[0] == tier] if tier else list(self.regions.keys())
        with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
            results = list(executor.map(lambda r: self._ping_region(r, attempts), regions_to_test))

        sorted_results = sorted(results, key=lambda x: x[1])

        if verbose:
            print(f"{'Region':<25} {'Location':<35} {'Tier':<5} Latency (ms)")
            for region, mean, std, min_, max_ in sorted_results:
                tier, city, country = self.regions[region]
                location = f"{city}, {country}"
                if mean == float("inf"):
                    print(f"{region:<25} {location:<35} {tier:<5} Timeout")
                else:
                    print(f"{region:<25} {location:<35} {tier:<5} {mean:.0f} ± {std:.0f} ({min_:.0f} - {max_:.0f})")
            print(f"\nLowest latency region{'s' if top > 1 else ''}:")
            for region, mean, std, min_, max_ in sorted_results[:top]:
                tier, city, country = self.regions[region]
                location = f"{city}, {country}"
                print(f"{region} ({location}, {mean:.0f} ± {std:.0f} ms ({min_:.0f} - {max_:.0f}))")

        return sorted_results[:top]


# Usage example
if __name__ == "__main__":
    regions = GCPRegions()
    top_3_latency_tier1 = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=3)
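Beyond the __main__ demo above, the tier helpers can be used on their own to size the candidate pool before any network calls are made. A small sketch, assuming network access for the final probe:

from ultralytics.hub.google import GCPRegions

regions = GCPRegions()
print(len(regions.tier1()), "tier-1 regions")  # e.g. us-central1, europe-west4
print(len(regions.tier2()), "tier-2 regions")
best = regions.lowest_latency(top=1, tier=2, attempts=1)  # quiet single-attempt probe of tier-2 only
print(best[0][0], f"{best[0][1]:.0f} ms mean")  # (region, mean, std, min, max) tuples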
432 ultralytics/hub/session.py (new file)
@@ -0,0 +1,432 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import shutil
import threading
import time
from http import HTTPStatus
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlparse

from ultralytics import __version__
from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, TQDM, checks, emojis
from ultralytics.utils.errors import HUBModelError

AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version__}-local"


class HUBTrainingSession:
    """
    HUB training session for Ultralytics HUB YOLO models.

    This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including
    model creation, metrics tracking, and checkpoint uploading.

    Attributes:
        model_id (str): Identifier for the YOLO model being trained.
        model_url (str): URL for the model in Ultralytics HUB.
        rate_limits (dict[str, int]): Rate limits for different API calls in seconds.
        timers (dict[str, Any]): Timers for rate limiting.
        metrics_queue (dict[str, Any]): Queue for the model's metrics.
        metrics_upload_failed_queue (dict[str, Any]): Queue for metrics that failed to upload.
        model (Any): Model data fetched from Ultralytics HUB.
        model_file (str): Path to the model file.
        train_args (dict[str, Any]): Arguments for training the model.
        client (Any): Client for interacting with Ultralytics HUB.
        filename (str): Filename of the model.

    Examples:
        Create a training session with a model URL
        >>> session = HUBTrainingSession("https://hub.ultralytics.com/models/example-model")
        >>> session.upload_metrics()
    """

    def __init__(self, identifier: str):
        """
        Initialize the HUBTrainingSession with the provided model identifier.

        Args:
            identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string
                or a model key with a specific format.

        Raises:
            ValueError: If the provided model identifier is invalid.
            ConnectionError: If connecting with a global API key is not supported.
            ModuleNotFoundError: If the hub-sdk package is not installed.
        """
        from hub_sdk import HUBClient

        self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300}  # rate limits (seconds)
        self.metrics_queue = {}  # holds metrics for each epoch until upload
        self.metrics_upload_failed_queue = {}  # holds metrics for each epoch if upload failed
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py
        self.model = None
        self.model_url = None
        self.model_file = None
        self.train_args = None

        # Parse input
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials
        active_key = api_key or SETTINGS.get("api_key")
        credentials = {"api_key": active_key} if active_key else None  # set credentials

        # Initialize client
        self.client = HUBClient(credentials)

        # Load models
        try:
            if model_id:
                self.load_model(model_id)  # load existing model
            else:
                self.model = self.client.model()  # load empty model
        except Exception:
            if identifier.startswith(f"{HUB_WEB_ROOT}/models/") and not self.client.authenticated:
                LOGGER.warning(
                    f"{PREFIX}Please log in using 'yolo login API_KEY'. "
                    "You can find your API Key at: https://hub.ultralytics.com/settings?tab=api+keys."
                )

    @classmethod
    def create_session(cls, identifier: str, args: dict[str, Any] | None = None):
        """
        Create an authenticated HUBTrainingSession or return None.

        Args:
            identifier (str): Model identifier used to initialize the HUB training session.
            args (dict[str, Any], optional): Arguments for creating a new model if identifier is not a HUB model URL.

        Returns:
            session (HUBTrainingSession | None): An authenticated session or None if creation fails.
        """
        try:
            session = cls(identifier)
            if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"):  # not a HUB model URL
                session.create_model(args)
                assert session.model.id, "HUB model not loaded correctly"
            return session
        # PermissionError and ModuleNotFoundError indicate hub-sdk not installed
        except (PermissionError, ModuleNotFoundError, AssertionError):
            return None

    def load_model(self, model_id: str):
        """
        Load an existing model from Ultralytics HUB using the provided model identifier.

        Args:
            model_id (str): The identifier of the model to load.

        Raises:
            ValueError: If the specified HUB model does not exist.
        """
        self.model = self.client.model(model_id)
        if not self.model.data:  # then model does not exist
            raise ValueError(emojis("❌ The specified HUB model does not exist"))  # TODO: improve error handling

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
        if self.model.is_trained():
            LOGGER.info(f"Loading trained HUB model {self.model_url} 🚀")
            url = self.model.get_weights_url("best")  # download URL with auth
            self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
            return

        # Set training args and start heartbeats for HUB to monitor agent
        self._set_train_args()
        self.model.start_heartbeat(self.rate_limits["heartbeat"])
        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def create_model(self, model_args: dict[str, Any]):
        """
        Initialize a HUB training session with the specified model arguments.

        Args:
            model_args (dict[str, Any]): Arguments for creating the model, including batch size, epochs, image size,
                etc.

        Returns:
            (None): If the model could not be created.
        """
        payload = {
            "config": {
                "batchSize": model_args.get("batch", -1),
                "epochs": model_args.get("epochs", 300),
                "imageSize": model_args.get("imgsz", 640),
                "patience": model_args.get("patience", 100),
                "device": str(model_args.get("device", "")),  # convert None to string
                "cache": str(model_args.get("cache", "ram")),  # convert True, False, None to string
            },
            "dataset": {"name": model_args.get("data")},
            "lineage": {
                "architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")},
                "parent": {},
            },
            "meta": {"name": self.filename},
        }

        if self.filename.endswith(".pt"):
            payload["lineage"]["parent"]["name"] = self.filename

        self.model.create_model(payload)

        # Model could not be created
        # TODO: improve error handling
        if not self.model.id:
            return None

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])

        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    @staticmethod
    def _parse_identifier(identifier: str):
        """
        Parse the given identifier to determine the type and extract relevant components.

        The method supports different identifier formats:
        - A HUB model URL https://hub.ultralytics.com/models/MODEL
        - A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
        - A local filename that ends with '.pt' or '.yaml'

        Args:
            identifier (str): The identifier string to be parsed.

        Returns:
            api_key (str | None): Extracted API key if present.
            model_id (str | None): Extracted model ID if present.
            filename (str | None): Extracted filename if present.

        Raises:
            HUBModelError: If the identifier format is not recognized.
        """
        api_key, model_id, filename = None, None, None
        if identifier.endswith((".pt", ".yaml")):
            filename = identifier
        elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
            parsed_url = urlparse(identifier)
            model_id = Path(parsed_url.path).stem  # handle a possible trailing slash robustly
            query_params = parse_qs(parsed_url.query)  # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
            api_key = query_params.get("api_key", [None])[0]
        else:
            raise HUBModelError(f"model='{identifier}' invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
        return api_key, model_id, filename

    def _set_train_args(self):
        """
        Initialize training arguments and create a model entry on the Ultralytics HUB.

        This method sets up training arguments based on the model's state and updates them with any additional
        arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
        or requires specific file setup.

        Raises:
            ValueError: If the model is already trained, if required dataset information is missing, or if there are
                issues with the provided training arguments.
        """
        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
            self.model_file = self.model.get_weights_url("last")
        else:
            # Model has no saved weights
            self.train_args = self.model.data.get("train_args")  # new response

            # Set the model file as either a *.pt or *.yaml file
            self.model_file = (
                self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
            )

        if "data" not in self.train_args:
            # RF bug - datasets are sometimes not exported
            raise ValueError("Dataset may still be processing. Please wait a minute and try again.")

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        self.model_id = self.model.id

    def request_queue(
        self,
        request_func,
        retry: int = 3,
        timeout: int = 30,
        thread: bool = True,
        verbose: bool = True,
        progress_total: int | None = None,
        stream_response: bool | None = None,
        *args,
        **kwargs,
    ):
        """
        Execute request_func with retries, timeout handling, optional threading, and progress tracking.

        Args:
            request_func (callable): The function to execute.
            retry (int): Number of retry attempts.
            timeout (int): Maximum time to wait for the request to complete.
            thread (bool): Whether to run the request in a separate thread.
            verbose (bool): Whether to log detailed messages.
            progress_total (int, optional): Total size for progress tracking.
            stream_response (bool, optional): Whether to stream the response.
            *args (Any): Additional positional arguments for request_func.
            **kwargs (Any): Additional keyword arguments for request_func.

        Returns:
            (requests.Response | None): The response object if thread=False, otherwise None.
        """

        def retry_request():
            """Attempt to call request_func with retries, timeout, and optional threading."""
            t0 = time.time()  # record the start time for the timeout
            response = None
            for i in range(retry + 1):
                if (time.time() - t0) > timeout:
                    LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                    break  # timeout reached, exit loop

                response = request_func(*args, **kwargs)
                if response is None:
                    LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
                    time.sleep(2**i)  # exponential backoff before retrying
                    continue  # skip further processing and retry

                if progress_total:
                    self._show_upload_progress(progress_total, response)
                elif stream_response:
                    self._iterate_content(response)

                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                    # If request related to metrics upload
                    if kwargs.get("metrics"):
                        self.metrics_upload_failed_queue = {}
                    return response  # success, no need to retry

                if i == 0:
                    # Initial attempt, check status code and provide messages
                    message = self._get_failure_message(response, retry, timeout)

                    if verbose:
                        LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

                if not self._should_retry(response.status_code):
                    LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code})")
                    break  # not an error that should be retried, exit loop

                time.sleep(2**i)  # exponential backoff for retries

            # If request related to metrics upload and retries were exceeded
            if response is None and kwargs.get("metrics"):
                self.metrics_upload_failed_queue.update(kwargs.get("metrics"))

            return response

        if thread:
            # Start a new thread to run the retry_request function
            threading.Thread(target=retry_request, daemon=True).start()
        else:
            # If running in the main thread, call retry_request directly
            return retry_request()

    @staticmethod
    def _should_retry(status_code: int) -> bool:
        """Determine if a request should be retried based on the HTTP status code."""
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.GATEWAY_TIMEOUT,
        }
        return status_code in retry_codes

    def _get_failure_message(self, response, retry: int, timeout: int) -> str:
        """
        Generate a retry message based on the response status code.

        Args:
            response (requests.Response): The HTTP response object.
            retry (int): The number of retry attempts allowed.
            timeout (int): The maximum timeout duration.

        Returns:
            (str): The retry message.
        """
        if self._should_retry(response.status_code):
            return f"Retrying {retry}x for {timeout}s." if retry else ""
        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
            headers = response.headers
            return (
                f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
                f"Please retry after {headers['Retry-After']}s."
            )
        else:
            try:
                return response.json().get("message", "No JSON message.")
            except AttributeError:
                return "Unable to read JSON."

    def upload_metrics(self):
        """Upload model metrics to Ultralytics HUB."""
        return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)

    def upload_model(
        self,
        epoch: int,
        weights: str,
        is_best: bool = False,
        map: float = 0.0,
        final: bool = False,
    ) -> None:
        """
        Upload a model checkpoint to Ultralytics HUB.

        Args:
            epoch (int): The current training epoch.
            weights (str): Path to the model weights file.
            is_best (bool): Indicates if the current model is the best one so far.
            map (float): Mean average precision of the model.
            final (bool): Indicates if the model is the final model after training.
        """
        weights = Path(weights)
        if not weights.is_file():
            last = weights.with_name(f"last{weights.suffix}")
            if final and last.is_file():
                LOGGER.warning(
                    f"{PREFIX} Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. "
                    "This often happens when resuming training in transient environments like Google Colab. "
                    "For more reliable training, consider using Ultralytics HUB Cloud. "
                    "Learn more at https://docs.ultralytics.com/hub/cloud-training."
                )
                shutil.copy(last, weights)  # copy last.pt to best.pt
            else:
                LOGGER.warning(f"{PREFIX} Model upload issue. Missing model {weights}.")
                return

        self.request_queue(
            self.model.upload_model,
            epoch=epoch,
            weights=str(weights),
            is_best=is_best,
            map=map,
            final=final,
            retry=10,
            timeout=3600,
            thread=not final,
            progress_total=weights.stat().st_size if final else None,  # only show progress if final
            stream_response=True,
        )

    @staticmethod
    def _show_upload_progress(content_length: int, response) -> None:
        """Display a progress bar to track the progress of a streamed upload response."""
        with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
            for data in response.iter_content(chunk_size=1024):
                pbar.update(len(data))

    @staticmethod
    def _iterate_content(response) -> None:
        """Process the streamed HTTP response data."""
        for _ in response.iter_content(chunk_size=1024):
            pass  # do nothing with data chunks
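A minimal sketch of driving a session by hand; in practice the Ultralytics trainers do this through callbacks, and the model URL, metrics payload, and weights path below are placeholders:

from ultralytics.hub.session import HUBTrainingSession

session = HUBTrainingSession.create_session("https://hub.ultralytics.com/models/MODEL_ID")  # placeholder ID
if session:  # None if hub-sdk is missing or the model could not be loaded
    session.metrics_queue[0] = '{"loss": 1.23}'  # hypothetical per-epoch metrics JSON
    session.upload_metrics()  # drains the queue in a background thread
    session.upload_model(epoch=0, weights="runs/train/weights/best.pt", is_best=True, map=0.5, final=True)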
165 ultralytics/hub/utils.py (new file)
@@ -0,0 +1,165 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import os
import threading
import time
from typing import Any

from ultralytics.utils import (
    IS_COLAB,
    LOGGER,
    TQDM,
    TryExcept,
    colorstr,
)

HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")

PREFIX = colorstr("Ultralytics HUB: ")
HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."


def request_with_credentials(url: str) -> Any:
    """
    Make an AJAX request with cookies attached in a Google Colab environment.

    Args:
        url (str): The URL to make the request to.

    Returns:
        (Any): The response data from the AJAX request.

    Raises:
        OSError: If the function is not run in a Google Colab environment.
    """
    if not IS_COLAB:
        raise OSError("request_with_credentials() must run in a Colab environment")
    from google.colab import output  # noqa
    from IPython import display  # noqa

    display.display(
        display.Javascript(
            f"""
            window._hub_tmp = new Promise((resolve, reject) => {{
                const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000);
                fetch("{url}", {{
                    method: 'POST',
                    credentials: 'include'
                }})
                    .then((response) => resolve(response.json()))
                    .then((json) => {{
                        clearTimeout(timeout);
                    }})
                    .catch((err) => {{
                        clearTimeout(timeout);
                        reject(err);
                    }});
            }});
            """
        )
    )
    return output.eval_js("_hub_tmp")


def requests_with_progress(method: str, url: str, **kwargs):
    """
    Make an HTTP request using the specified method and URL, with an optional progress bar.

    Args:
        method (str): The HTTP method to use (e.g. 'GET', 'POST').
        url (str): The URL to send the request to.
        **kwargs (Any): Additional keyword arguments to pass to the underlying `requests.request` function.

    Returns:
        (requests.Response): The response object from the HTTP request.

    Notes:
        - If 'progress' is set to True, the progress bar will display the download progress for responses with a
          known content length.
        - If 'progress' is a number, the progress bar will display assuming content length = progress.
    """
    import requests  # scoped as slow import

    progress = kwargs.pop("progress", False)
    if not progress:
        return requests.request(method, url, **kwargs)
    response = requests.request(method, url, stream=True, **kwargs)
    total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress)  # total size
    try:
        pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
        for data in response.iter_content(chunk_size=1024):
            pbar.update(len(data))
        pbar.close()
    except requests.exceptions.ChunkedEncodingError:  # avoid 'Connection broken: IncompleteRead' warnings
        response.close()
    return response


def smart_request(
    method: str,
    url: str,
    retry: int = 3,
    timeout: int = 30,
    thread: bool = True,
    code: int = -1,
    verbose: bool = True,
    progress: bool = False,
    **kwargs,
):
    """
    Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.

    Args:
        method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
        url (str): The URL to make the request to.
        retry (int, optional): Number of retries to attempt before giving up.
        timeout (int, optional): Timeout in seconds after which the function will give up retrying.
        thread (bool, optional): Whether to execute the request in a separate daemon thread.
        code (int, optional): An identifier for the request, used for logging purposes.
        verbose (bool, optional): A flag to determine whether to print out to console or not.
        progress (bool, optional): Whether to show a progress bar during the request.
        **kwargs (Any): Keyword arguments to be passed to the requests function specified in method.

    Returns:
        (requests.Response | None): The HTTP response object. If the request is executed in a separate thread,
            returns None.
    """
    retry_codes = (408, 500)  # retry only these codes

    @TryExcept(verbose=verbose)
    def func(func_method, func_url, **func_kwargs):
        """Make HTTP requests with retries and timeouts, with optional progress tracking."""
        r = None  # response
        t0 = time.time()  # initial time for timer
        for i in range(retry + 1):
            if (time.time() - t0) > timeout:
                break
            r = requests_with_progress(func_method, func_url, **func_kwargs)  # i.e. get(url, data, json, files)
            if r.status_code < 300:  # return codes in the 2xx range are generally considered "good" or "successful"
                break
            try:
                m = r.json().get("message", "No JSON message.")
            except AttributeError:
                m = "Unable to read JSON."
            if i == 0:
                if r.status_code in retry_codes:
                    m += f" Retrying {retry}x for {timeout}s." if retry else ""
                elif r.status_code == 429:  # rate limit
                    h = r.headers  # response headers
                    m = (
                        f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
                        f"Please retry after {h['Retry-After']}s."
                    )
                if verbose:
                    LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
                if r.status_code not in retry_codes:
                    return r
            time.sleep(2**i)  # exponential backoff
        return r

    args = method, url
    kwargs["progress"] = progress
    if thread:
        threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
    else:
        return func(*args, **kwargs)
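A standalone sketch of smart_request in blocking mode, assuming network access (HUB_API_ROOT is this module's default endpoint; the code identifier is arbitrary):

from ultralytics.hub.utils import HUB_API_ROOT, smart_request

# thread=False returns the Response; thread=True fires a daemon thread and returns None
r = smart_request("get", HUB_API_ROOT, retry=3, timeout=30, thread=False, code=0, progress=False)
if r is not None:
    print(r.status_code)  # retried with exponential backoff on 408/500 until timeout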