init commit
This commit is contained in:
262
ultralytics/solutions/streamlit_inference.py
Normal file
262
ultralytics/solutions/streamlit_inference.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.utils import LOGGER
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
|
||||
|
||||
# Workaround: give torch.classes an empty __path__ to sidestep the Torch module `__path__._path` issue
# described at https://github.com/datalab-to/marker/issues/442
# NOTE(review): presumably triggered by Streamlit's module watcher walking torch.classes — confirm
torch.classes.__path__ = []
|
||||
|
||||
|
||||
class Inference:
    """
    A class to perform object detection, image classification, image segmentation and pose estimation inference.

    This class provides functionalities for loading models, configuring settings, uploading video and image files,
    and performing real-time inference using Streamlit and Ultralytics YOLO models.

    Attributes:
        st (module): Streamlit module for UI creation.
        temp_dict (dict): Temporary dictionary to store the model path and other configuration.
        model_path (str): Path to the loaded model.
        model (YOLO): The YOLO model instance.
        source (str): Selected source ("webcam", "video" or "image").
        img_file_names (list[dict]): One ``{"path": ..., "name": ...}`` entry per uploaded image, where ``path`` is
            the temporary copy on disk and ``name`` is the uploaded file name.
        enable_trk (bool): Enable tracking option.
        conf (float): Confidence threshold for detection.
        iou (float): IoU threshold for non-maximum suppression.
        org_frame (Any): Container for the original frame to be displayed.
        ann_frame (Any): Container for the annotated frame to be displayed.
        vid_file_name (str | int): Name of the uploaded video file or webcam index.
        selected_ind (list[int]): List of selected class indices for detection.

    Methods:
        web_ui: Set up the Streamlit web interface with custom HTML elements.
        sidebar: Configure the Streamlit sidebar for model and inference settings.
        source_upload: Handle video and image uploads through the Streamlit interface.
        configure: Configure the model and load selected classes for inference.
        image_inference: Perform inference on uploaded images.
        inference: Build the UI and perform inference on the selected source.

    Examples:
        Create an Inference instance with a custom model
        >>> inf = Inference(model="path/to/model.pt")
        >>> inf.inference()

        Create an Inference instance with default settings
        >>> inf = Inference()
        >>> inf.inference()
    """

    # Model size prefixes and task suffixes, in the order they should appear in the model dropdown
    _MODEL_ORDER = ("yolo11n", "yolo11s", "yolo11m", "yolo11l", "yolo11x")
    _TASK_ORDER = ("", "-seg", "-pose", "-obb", "-cls")

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the Inference class, checking Streamlit requirements and setting up the model path.

        Args:
            **kwargs (Any): Additional keyword arguments for model configuration.
        """
        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
        import streamlit as st

        self.st = st  # Reference to the Streamlit module
        self.source = None  # Source selection (webcam, video or image)
        self.img_file_names = []  # List of {"path", "name"} dicts for uploaded images
        self.enable_trk = False  # Flag to toggle object tracking
        self.conf = 0.25  # Confidence threshold for detection
        self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
        self.org_frame = None  # Container for the original frame display
        self.ann_frame = None  # Container for the annotated frame display
        self.vid_file_name = None  # Video file name or webcam index
        self.selected_ind: list[int] = []  # List of selected class indices for detection
        self.model = None  # YOLO model instance

        self.temp_dict = {"model": None, **kwargs}
        self.model_path = None  # Model file path
        if self.temp_dict["model"] is not None:
            self.model_path = self.temp_dict["model"]

        LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

    @staticmethod
    def _available_models(stems) -> list[str]:
        """
        Build the ordered list of YOLO11 model names offered in the model dropdown.

        Args:
            stems (Iterable[str]): Candidate weight-file stems, e.g. "yolo11n-seg".

        Returns:
            (list[str]): Stems that start with a known YOLO11 size prefix and do not contain "grayscale", with
                "yolo" upper-cased, sorted by model size and then by task suffix.
        """
        sizes, tasks = Inference._MODEL_ORDER, Inference._TASK_ORDER

        def sort_key(name: str) -> tuple[int, int]:
            size, task = name[:7].lower(), name[7:].lower()
            # Unknown task suffixes sort after the known ones instead of raising ValueError
            return sizes.index(size), (tasks.index(task) if task in tasks else len(tasks))

        return sorted(
            (x.replace("yolo", "YOLO") for x in stems if x.startswith(sizes) and "grayscale" not in x),
            key=sort_key,
        )

    @staticmethod
    def _class_indices(names: dict[int, str], selected: list[str]) -> list[int]:
        """
        Map selected class names to their class ids.

        Args:
            names (dict[int, str]): Mapping of class id to class name.
            selected (list[str]): Class names chosen by the user.

        Returns:
            (list[int]): Class ids in the same order as ``selected``.
        """
        name_to_id: dict[str, int] = {}
        for idx, name in names.items():
            name_to_id.setdefault(name, idx)  # first id wins if a name is duplicated
        return [name_to_id[name] for name in selected]

    def _remove_temp_images(self) -> None:
        """Delete every temporary image file still listed in ``img_file_names`` and clear the list."""
        for img_info in self.img_file_names:
            try:
                os.unlink(img_info["path"])
            except FileNotFoundError:
                pass  # Already removed, e.g. by image_inference
        self.img_file_names = []

    def web_ui(self) -> None:
        """Set up the Streamlit web interface with custom HTML elements."""
        # NOTE(review): this selector has no leading "#" or ".", so it targets a <MainMenu> element — confirm that
        # it hides the Streamlit menu as intended
        menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

        # Main title of streamlit application
        main_title_cfg = """<div><h1 style="color:#111F68; text-align:center; font-size:40px; margin-top:-50px;
        font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

        # Subtitle of streamlit application
        sub_title_cfg = """<div><h5 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam, videos, and images
        with the power of Ultralytics YOLO! 🚀</h5></div>"""

        # Set html page configuration and append custom HTML
        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
        self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
        self.st.markdown(main_title_cfg, unsafe_allow_html=True)
        self.st.markdown(sub_title_cfg, unsafe_allow_html=True)

    def sidebar(self) -> None:
        """Configure the Streamlit sidebar for model and inference settings."""
        with self.st.sidebar:  # Add Ultralytics LOGO
            logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
            self.st.image(logo, width=250)

        self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
        self.source = self.st.sidebar.selectbox(
            "Source",
            ("webcam", "video", "image"),
        )  # Add source selection dropdown
        if self.source in ["webcam", "video"]:
            self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes"  # Enable object tracking
        self.conf = float(
            self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
        )  # Slider for confidence
        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold

        if self.source != "image":  # Only create columns for video/webcam
            col1, col2 = self.st.columns(2)  # Create two columns for displaying frames
            self.org_frame = col1.empty()  # Container for original frame
            self.ann_frame = col2.empty()  # Container for annotated frame

    def source_upload(self) -> None:
        """
        Resolve the input for the selected source.

        For "video" the uploaded file is written to "ultralytics.mp4" in the working directory, for "webcam" the
        capture index 0 is used, and for "image" every uploaded file is copied to a temporary file that is recorded
        in ``img_file_names``.
        """
        from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS  # scope import

        self.vid_file_name = ""
        if self.source == "video":
            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=VID_FORMATS)
            if vid_file is not None:
                with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
                    out.write(vid_file.read())  # Write the uploaded bytes straight to disk
                self.vid_file_name = "ultralytics.mp4"
        elif self.source == "webcam":
            self.vid_file_name = 0  # Use webcam index 0
        elif self.source == "image":
            import tempfile  # scope import

            if imgfiles := self.st.sidebar.file_uploader(
                "Upload Image Files", type=IMG_FORMATS, accept_multiple_files=True
            ):
                for imgfile in imgfiles:  # Save each uploaded image to a temporary file
                    # delete=False keeps the file on disk after the handle closes so it can be read back later
                    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(imgfile.name)[1]) as tf:
                        tf.write(imgfile.read())
                        self.img_file_names.append({"path": tf.name, "name": imgfile.name})

    def configure(self) -> None:
        """Configure the model and load selected classes for inference."""
        # Add dropdown menu for model selection
        available_models = self._available_models(GITHUB_ASSETS_STEMS)
        if self.model_path:  # Insert user provided custom model in available_models
            available_models.insert(0, self.model_path)
        selected_model = self.st.sidebar.selectbox("Model", available_models)

        with self.st.spinner("Model is downloading..."):
            if selected_model.endswith((".pt", ".onnx", ".torchscript", ".mlpackage", ".engine")) or any(
                fmt in selected_model for fmt in ("openvino_model", "rknn_model")
            ):
                model_path = selected_model
            else:
                model_path = f"{selected_model.lower()}.pt"  # Default to .pt if no model provided during function call.
            self.model = YOLO(model_path)  # Load the YOLO model
            class_names = list(self.model.names.values())  # Convert dictionary to list of class names
        self.st.success("Model loaded successfully!")

        # Multiselect box with class names and get indices of selected classes
        selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
        self.selected_ind = self._class_indices(self.model.names, selected_classes)

    def image_inference(self) -> None:
        """Perform inference on uploaded images, removing each temporary file once it has been handled."""
        for img_info in self.img_file_names:
            img_path = img_info["path"]
            try:
                image = cv2.imread(img_path)  # Load the original image
                if image is None:
                    self.st.error("Could not load the uploaded image.")
                    continue

                self.st.markdown(f"#### Processed: {img_info['name']}")
                col1, col2 = self.st.columns(2)
                with col1:
                    self.st.image(image, channels="BGR", caption="Original Image")
                results = self.model(image, conf=self.conf, iou=self.iou, classes=self.selected_ind)
                annotated_image = results[0].plot()
                with col2:
                    self.st.image(annotated_image, channels="BGR", caption="Predicted Image")
            finally:  # Clean up the temporary file even if loading or inference failed
                try:
                    os.unlink(img_path)
                except FileNotFoundError:
                    pass  # File doesn't exist, ignore

    def _stream_inference(self) -> None:
        """Run frame-by-frame inference on the webcam or uploaded video and display original/annotated frames."""
        stop_button = self.st.sidebar.button("Stop")  # Button to stop the inference
        cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
        try:
            if not cap.isOpened():
                self.st.error("Could not open webcam or video source.")
                return

            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    self.st.warning("Failed to read frame from webcam. Please verify the webcam is connected properly.")
                    break

                # Process frame with model
                if self.enable_trk:
                    results = self.model.track(
                        frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                    )
                else:
                    results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)

                annotated_frame = results[0].plot()  # Add annotations on frame

                if stop_button:
                    self.st.stop()  # Stop streamlit app; the capture is released in the finally clause

                self.org_frame.image(frame, channels="BGR", caption="Original Frame")  # Display original frame
                self.ann_frame.image(annotated_frame, channels="BGR", caption="Predicted Frame")  # Display processed
        finally:
            cap.release()  # Release the capture on every exit path, including errors

        cv2.destroyAllWindows()  # Destroy all OpenCV windows

    def inference(self) -> None:
        """Build the UI and, once "Start" is pressed, run inference on the selected image, video or webcam source."""
        self.web_ui()  # Initialize the web interface
        self.sidebar()  # Create the sidebar
        self.source_upload()  # Upload the video or image source
        try:
            self.configure()  # Configure the app

            if self.st.sidebar.button("Start"):
                if self.source == "image":
                    if self.img_file_names:
                        self.image_inference()
                    else:
                        self.st.info("Please upload an image file to perform inference.")
                else:
                    self._stream_inference()
        finally:
            # source_upload writes temp copies on every run; remove whatever was not consumed
            self._remove_temp_images()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys  # scope import: only needed when the module is run as a script

    # Everything after the script name; the first entry, when present, is the model name or path
    cli_args = sys.argv[1:]
    model = cli_args[0] if cli_args else None

    # Build the app with the requested model (or the default selection) and start it
    Inference(model=model).inference()
|
||||
Reference in New Issue
Block a user