From 33d8f213f9e03b5b445d4dea04e19e06d8a8671e Mon Sep 17 00:00:00 2001 From: wangchongwu <759291707@qq.com> Date: Sat, 29 Nov 2025 17:54:09 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B7=A5=E7=A8=8B=E6=9E=B6=E6=9E=84=E5=B1=82?= =?UTF-8?q?=E9=9D=A2=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 + CMakeLists.txt | 90 ++-- stitch/CMakeLists.txt | 18 +- stitch/py/API_UnderStitch_binding.cpp | 53 ++ stitch/py/UStitcher.pyi | 158 ++++++ stitch/py/cvbind.hpp | 232 +++++++++ stitch/py/example.cpp | 118 ----- stitch/py/public_binding.cpp | 31 ++ stitch/py/simple.cpp | 38 -- stitch/src/API_GeoCorrect.h | 31 ++ stitch/src/API_UnderStitch.h | 2 + stitch/src/Arith_BATask.cpp | 92 +--- stitch/src/Arith_GeoCorrect.cpp | 88 ++++ stitch/src/Arith_GeoCorrect.h | 21 + stitch/src/Arith_GeoSolver.cpp | 12 +- stitch/src/Arith_GeoSolver.h | 4 +- stitch/src/Arith_UnderStitch.cpp | 47 +- stitch/src/Arith_UnderStitch.h | 4 +- stitch/src/Version.h | 2 +- tests/1.py | 56 -- tests/{ => cpp}/DJ/ProcDJ.cpp | 2 +- tests/{ => cpp}/NeoArithStandardDll.h | 0 tests/{ => cpp}/S7215/Arith_zhryp.cpp | 0 tests/{ => cpp}/S7215/Arith_zhryp.h | 0 tests/{ => cpp}/S7215/TsDecoder.hpp | 0 tests/{ => cpp}/S7215/TsPacker.hpp | 0 tests/{ => cpp}/S7215/commondefine.h | 0 tests/{ => cpp}/S7215/stitch_S7215.cpp | 0 tests/{ => cpp}/S729.h | 0 tests/{ => cpp}/S732/H264_SEI_typedef.h | 0 tests/cpp/S732/NeoArithStandardDll.h | 655 ++++++++++++++++++++++++ tests/{ => cpp}/S732/S732.h | 0 tests/{ => cpp}/S732/decodedata.cpp | 0 tests/{ => cpp}/S732/decodedata.h | 0 tests/{ => cpp}/S732/hi_type.h | 0 tests/{ => cpp}/S732/stitch_S732.cpp | 0 tests/{ => cpp}/S732/stitch_udp.cpp | 0 tests/cpp/Test_GeoCorrect.cpp | 47 ++ tests/{ => cpp}/feaStitchTest.cpp | 0 tests/{ => cpp}/main.cpp | 0 tests/{ => cpp}/stitch_Genaral.cpp | 0 tests/{ => cpp}/utils.cpp | 0 tests/{ => cpp}/utils.h | 0 tests/python/1.py | 78 +++ tests/python/ProcDJ.py | 427 +++++++++++++++ tests/python/UStitcher.pyi | 158 ++++++ tests/python/sim_scan.py | 565 ++++++++++++++++++++ 47 files changed, 2678 insertions(+), 354 deletions(-) create mode 100644 stitch/py/API_UnderStitch_binding.cpp create mode 100644 stitch/py/UStitcher.pyi create mode 100644 stitch/py/cvbind.hpp delete mode 100644 stitch/py/example.cpp delete mode 100644 stitch/py/simple.cpp create mode 100755 stitch/src/API_GeoCorrect.h create mode 100644 stitch/src/Arith_GeoCorrect.cpp create mode 100644 stitch/src/Arith_GeoCorrect.h delete mode 100644 tests/1.py rename tests/{ => cpp}/DJ/ProcDJ.cpp (99%) rename tests/{ => cpp}/NeoArithStandardDll.h (100%) rename tests/{ => cpp}/S7215/Arith_zhryp.cpp (100%) rename tests/{ => cpp}/S7215/Arith_zhryp.h (100%) rename tests/{ => cpp}/S7215/TsDecoder.hpp (100%) rename tests/{ => cpp}/S7215/TsPacker.hpp (100%) rename tests/{ => cpp}/S7215/commondefine.h (100%) rename tests/{ => cpp}/S7215/stitch_S7215.cpp (100%) rename tests/{ => cpp}/S729.h (100%) rename tests/{ => cpp}/S732/H264_SEI_typedef.h (100%) create mode 100644 tests/cpp/S732/NeoArithStandardDll.h rename tests/{ => cpp}/S732/S732.h (100%) rename tests/{ => cpp}/S732/decodedata.cpp (100%) rename tests/{ => cpp}/S732/decodedata.h (100%) rename tests/{ => cpp}/S732/hi_type.h (100%) rename tests/{ => cpp}/S732/stitch_S732.cpp (100%) rename tests/{ => cpp}/S732/stitch_udp.cpp (100%) create mode 100644 tests/cpp/Test_GeoCorrect.cpp rename tests/{ => cpp}/feaStitchTest.cpp (100%) rename tests/{ => cpp}/main.cpp (100%) rename tests/{ => cpp}/stitch_Genaral.cpp (100%) rename tests/{ => cpp}/utils.cpp (100%) rename tests/{ => cpp}/utils.h (100%) create mode 100644 tests/python/1.py create mode 100644 tests/python/ProcDJ.py create mode 100644 tests/python/UStitcher.pyi create mode 100644 tests/python/sim_scan.py diff --git a/.gitignore b/.gitignore index 4febf8db..c9ea6cc5 100644 --- a/.gitignore +++ b/.gitignore @@ -61,6 +61,9 @@ install *.png *.kml +cache +StitchLog + # 保留 3rdparty 整个文件夹内容 diff --git a/CMakeLists.txt b/CMakeLists.txt index 0c26be1e..c50b9872 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,24 +64,24 @@ include_directories(tests) # 测试用例1 -add_executable(stitch tests/main.cpp "tests/S7215/Arith_zhryp.cpp" "tests/S7215/Arith_zhryp.h") +add_executable(stitch tests/cpp/main.cpp "tests/cpp/S7215/Arith_zhryp.cpp" "tests/cpp/S7215/Arith_zhryp.h") target_link_libraries(stitch ${LIB_STITCH}) # 测试用例2 -add_executable(stitch_DJ "tests/DJ/ProcDJ.cpp" "tests/S7215/Arith_zhryp.cpp" "tests/S7215/Arith_zhryp.h") +add_executable(stitch_DJ "tests/cpp/DJ/ProcDJ.cpp" "tests/cpp/S7215/Arith_zhryp.cpp" "tests/cpp/S7215/Arith_zhryp.h") target_link_libraries(stitch_DJ ${LIB_STITCH}) # 测试用例3 -add_executable(stitch_Fea "tests/feaStitchTest.cpp" "tests/S7215/Arith_zhryp.cpp" "tests/S7215/Arith_zhryp.h") +add_executable(stitch_Fea "tests/cpp/feaStitchTest.cpp" "tests/cpp/S7215/Arith_zhryp.cpp" "tests/cpp/S7215/Arith_zhryp.h") target_link_libraries(stitch_Fea ${LIB_STITCH}) # 测试用例4 -add_executable(stitch_Genaral "tests/stitch_Genaral.cpp" "tests/utils.cpp" "tests/S7215/Arith_zhryp.cpp" "tests/S7215/Arith_zhryp.h") +add_executable(stitch_Genaral "tests/cpp/stitch_Genaral.cpp" "tests/cpp/utils.cpp" "tests/cpp/S7215/Arith_zhryp.cpp" "tests/cpp/S7215/Arith_zhryp.h") target_link_libraries(stitch_Genaral ${LIB_STITCH}) #S732 -add_executable(stitch_S732_VL "tests/S732/stitch_S732.cpp" "tests/S732/S732.h" "tests/S7215/Arith_zhryp.cpp" "tests/S7215/Arith_zhryp.h") +add_executable(stitch_S732_VL "tests/cpp/S732/stitch_S732.cpp" "tests/cpp/S732/S732.h" "tests/cpp/S7215/Arith_zhryp.cpp" "tests/cpp/S7215/Arith_zhryp.h") target_link_libraries(stitch_S732_VL ${LIB_STITCH} ${OpenCV_LIBS}) @@ -98,7 +98,7 @@ IF(WIN32) set(FFMPEG_LIBRARIES "avcodec" "avformat" "avutil" "swscale" "swresample") - add_executable(stitch_udp "tests/S732/stitch_udp.cpp" "tests/S732/decodedata.h" "tests/S732/decodedata.cpp" "tests/S7215/Arith_zhryp.cpp" "tests/S7215/Arith_zhryp.h") + add_executable(stitch_udp "tests/cpp/S732/stitch_udp.cpp" "tests/cpp/S732/decodedata.h" "tests/cpp/S732/decodedata.cpp" "tests/cpp/S7215/Arith_zhryp.cpp" "tests/cpp/S7215/Arith_zhryp.h") target_link_directories(stitch_udp PUBLIC ${FFMPEG_LIBS_DIR}) target_include_directories(stitch_udp PUBLIC ${FFMPEG_INCLUDE_DIRS}) target_link_libraries(stitch_udp ${LIB_STITCH} ${FFMPEG_LIBRARIES} ${OpenCV_LIBS}) @@ -130,39 +130,71 @@ IF(WIN32) set(CMAKE_AUTOMOC ON) set(CMAKE_AUTORCC ON) - add_executable(stitch_S7215ts "tests/S7215/TsDecoder.hpp" "tests/S7215/TsPacker.hpp" "tests/S7215/stitch_S7215.cpp" "tests/S7215/Arith_zhryp.cpp" "tests/S7215/Arith_zhryp.h") + add_executable(stitch_S7215ts "tests/cpp/S7215/TsDecoder.hpp" "tests/cpp/S7215/TsPacker.hpp" "tests/cpp/S7215/stitch_S7215.cpp" "tests/cpp/S7215/Arith_zhryp.cpp" "tests/cpp/S7215/Arith_zhryp.h") target_link_directories(stitch_S7215ts PUBLIC ${FFMPEG_LIBS_DIR}) target_include_directories(stitch_S7215ts PUBLIC ${FFMPEG_INCLUDE_DIRS}) target_link_libraries(stitch_S7215ts ${LIB_STITCH} ${FFMPEG_LIBRARIES} ${OpenCV_LIBS} Qt${QT_VERSION_MAJOR}::Core) ELSE() - # ======== FFMPEG配置 ======== - set(FFMPEG_INCLUDE_DIRS "/usr/include/x86_64-linux-gnu") - set(FFMPEG_LIBS_DIR "/usr/lib/x86_64-linux-gnu") - include_directories(${FFMPEG_INCLUDE_DIRS}) - set(FFMPEG_LIBRARIES "avdevice" "avcodec" "avformat" "avutil" "swscale" "swresample") + # # ======== Qt5 配置 ======== + # find_package(QT NAMES Qt5 REQUIRED COMPONENTS Core Widgets) + # find_package(Qt${QT_VERSION_MAJOR} REQUIRED COMPONENTS Core Widgets) - # == Qt - find_package(QT NAMES Qt5 REQUIRED COMPONENTS Widgets ) - find_package(Qt${QT_VERSION_MAJOR} REQUIRED COMPONENTS Widgets ) - find_package(Qt5 REQUIRED COMPONENTS PrintSupport) - set(TS_FILES QGuideArith_zh_CN.ts) - - set(CMAKE_AUTOUIC ON) - set(CMAKE_AUTOMOC ON) - set(CMAKE_AUTORCC ON) - - add_executable(stitch_S7215ts "tests/S7215/TsDecoder.hpp" "tests/S7215/TsPacker.hpp" "tests/S7215/stitch_S7215.cpp" "tests/S7215/Arith_zhryp.cpp" "tests/S7215/Arith_zhryp.h") - target_link_directories(stitch_S7215ts PUBLIC ${FFMPEG_LIBS_DIR}) - target_include_directories(stitch_S7215ts PUBLIC ${FFMPEG_INCLUDE_DIRS}) - target_link_libraries(stitch_S7215ts ${LIB_STITCH} ${FFMPEG_LIBRARIES} ${OpenCV_LIBS} Qt${QT_VERSION_MAJOR}::Core) - + # message(STATUS "QT_VERSION_MAJOR: ${QT_VERSION_MAJOR}") + # message(STATUS "Qt5_FOUND: ${Qt5_FOUND}") + # message(STATUS "Qt5_VERSION: ${Qt5_VERSION}") + + # # 启用Qt的MOC、UIC、RCC自动处理 + # set(CMAKE_AUTOUIC ON) + # set(CMAKE_AUTOMOC ON) + # set(CMAKE_AUTORCC ON) + + # # ======== FFMPEG配置 ======== + # # 尝试多个可能的FFmpeg安装路径 + # set(FFMPEG_POSSIBLE_PATHS + # "/usr" + # ) + + # set(FFMPEG_INCLUDE_DIRS ${FFMPEG_DIR}/include) + # set(FFMPEG_LIBS_DIR ${FFMPEG_DIR}/lib) + # set(FFMPEG_LIBRARIES "avdevice" "avcodec" "avformat" "avutil" "swscale" "swresample") + # include_directories(${FFMPEG_INCLUDE_DIRS}) + # message(STATUS "FFmpeg found at: ${FFMPEG_DIR}") + + + + # # ======== 构建 stitch_S7215ts 可执行文件 ======== + # add_executable(stitch_S7215ts + # "tests/cpp/S7215/TsDecoder.hpp" + # "tests/cpp/S7215/TsPacker.hpp" + # "tests/cpp/S7215/stitch_S7215.cpp" + # "tests/cpp/S7215/Arith_zhryp.cpp" + # "tests/cpp/S7215/Arith_zhryp.h" + # ) + + # # 设置链接目录 + # if(NOT FFMPEG_LIBS_DIR STREQUAL "") + # target_link_directories(stitch_S7215ts PUBLIC ${FFMPEG_LIBS_DIR}) + # endif() + + # # 设置包含目录 + # if(FFMPEG_INCLUDE_DIRS) + # target_include_directories(stitch_S7215ts PUBLIC ${FFMPEG_INCLUDE_DIRS}) + # endif() + + # # 链接所有库 + # target_link_libraries(stitch_S7215ts + # ${LIB_STITCH} + # ${FFMPEG_LIBRARIES} + # ${OpenCV_LIBS} + # Qt${QT_VERSION_MAJOR}::Core + # ) ENDIF() - - +add_executable(stitch_GeoCorrect "tests/cpp/Test_GeoCorrect.cpp") +target_link_libraries(stitch_GeoCorrect ${LIB_STITCH}) # 可执行文件输出路径 set(EXECUTABLE_OUTPUT_PATH ${CMAKE_SOURCE_DIR}/Bin) diff --git a/stitch/CMakeLists.txt b/stitch/CMakeLists.txt index 94fc1a8a..c0fae554 100644 --- a/stitch/CMakeLists.txt +++ b/stitch/CMakeLists.txt @@ -98,20 +98,20 @@ add_subdirectory(py/pybind11) pybind11_add_module( UStitcher # 模块名 - py/example.cpp # 您的 C++ 源文件 - py/public_binding.cpp - # 也可以指定更多源文件... + py/API_UnderStitch_binding.cpp # 主绑定文件 + py/public_binding.cpp # 公共结构绑定 + # cvbind.h 是头文件,不需要单独编译 ) -target_include_directories(UStitcher PUBLIC src public_include) +target_include_directories(UStitcher PUBLIC src public_include py) target_link_libraries(UStitcher PUBLIC ${LIB_STITCH} ) -set_target_properties(UStitcher PROPERTIES - OUTPUT_NAME "UStitcher" # 强制输出文件名 - PREFIX "" # Windows 默认无前缀,这里显式写更安全 - SUFFIX ".pyd" # 禁止生成带 ABI tag 的名字 -) +# set_target_properties(UStitcher PROPERTIES +# OUTPUT_NAME "UStitcher" # 强制输出文件名 +# PREFIX "" # Windows 默认无前缀,这里显式写更安全 +# SUFFIX ".pyd" # 禁止生成带 ABI tag 的名字 +# ) diff --git a/stitch/py/API_UnderStitch_binding.cpp b/stitch/py/API_UnderStitch_binding.cpp new file mode 100644 index 00000000..d30081e2 --- /dev/null +++ b/stitch/py/API_UnderStitch_binding.cpp @@ -0,0 +1,53 @@ +#include +#include +#include "cvbind.hpp" // 包含 cv::Mat 类型转换支持 +#include "API_UnderStitch.h" +#include "StitchStruct.h" + +namespace py = pybind11; + +// 前向声明 +void bind_public_structures(py::module_ &m); + +// 绑定 API_UnderStitch 类 +void bind_API_UnderStitch(py::module_ &m) { + py::class_(m, "API_UnderStitch", "视频帧下视地理拼接") + .def("Init", &API_UnderStitch::Init, + py::arg("info"), + "初始化拼接") + .def("SetOutput", &API_UnderStitch::SetOutput, + py::arg("name"), py::arg("outdir"), + "设置输出标识和路径") + .def("Run", py::overload_cast(&API_UnderStitch::Run), + py::arg("img"), py::arg("para"), + "运行拼接流程 (cv::Mat版本)") + .def("Stop", &API_UnderStitch::Stop, + "中止拼接流程") + .def("SetConfig", &API_UnderStitch::SetConfig, + py::arg("config"), + "运行参数配置") + .def("OptAndOutCurrPan", &API_UnderStitch::OptAndOutCurrPan, + "立即优化并输出当前全景图") + .def("ExportPanMat", &API_UnderStitch::ExportPanMat, + "获取全景图mat") + .def("getHomography", &API_UnderStitch::getHomography, + py::arg("info"), + "获取单应性矩阵") + .def_static("Create", &API_UnderStitch::Create, + py::arg("cachedir") = "./cache", + "创建 API_UnderStitch 实例") + .def_static("Destroy", &API_UnderStitch::Destroy, + py::arg("obj"), + "销毁 API_UnderStitch 实例"); +} + +PYBIND11_MODULE(UStitcher, m) { + m.doc() = "下视拼接算法 Python 绑定"; + + // 先绑定公共结构体 + bind_public_structures(m); + + // 再绑定 API_UnderStitch 类 + bind_API_UnderStitch(m); +} + diff --git a/stitch/py/UStitcher.pyi b/stitch/py/UStitcher.pyi new file mode 100644 index 00000000..c9258651 --- /dev/null +++ b/stitch/py/UStitcher.pyi @@ -0,0 +1,158 @@ +from __future__ import annotations + +from typing import Any, TypeAlias + +from numpy.typing import NDArray + +ImageLike: TypeAlias = NDArray[Any] + + +class PointBLH: + """地理坐标系(单位:度).""" + + B: float + L: float + H: float + + def __init__(self) -> None: ... + + +class EulerRPY: + """RPY 姿态角(单位:度).""" + + fRoll: float + fPitch: float + fYaw: float + + def __init__(self) -> None: ... + + +class AirCraftInfo: + """载体信息.""" + + nPlaneID: int + stPos: PointBLH + stAtt: EulerRPY + + def __init__(self) -> None: ... + + +class CamInfo: + """相机信息.""" + + nFocus: int + fPixelSize: float + unVideoType: int + dCamx: float + dCamy: float + fAglReso: float + + def __init__(self) -> None: ... + + +class ServoInfo: + """伺服状态.""" + + fServoAz: float + fServoPt: float + fServoAzSpeed: float + fServoPtSpeed: float + + def __init__(self) -> None: ... + + +class FrameInfo: + """帧内外方位元素.""" + + nFrmID: int + craft: AirCraftInfo + camInfo: CamInfo + servoInfo: ServoInfo + nEvHeight: int + nWidth: int + nHeight: int + + def __init__(self) -> None: ... + + +class UPanInfo: + """下视全景图配置.""" + + m_pan_width: int + m_pan_height: int + scale: float + map_shiftX: float + map_shiftY: float + + def __init__(self) -> None: ... + + +class UPanConfig: + """下视拼接参数控制.""" + + bUseBA: bool + bOutFrameTile: bool + bOutGoogleTile: bool + + def __init__(self) -> None: ... + + +class API_UnderStitch: + """视频帧下视地理拼接.""" + + def Init(self, info: FrameInfo) -> UPanInfo: + """初始化拼接,返回全景图配置.""" + ... + + def SetOutput(self, name: str, outdir: str) -> None: + """配置输出标识和目录.""" + ... + + def Run(self, img: ImageLike, para: FrameInfo) -> int: + """运行拼接流程(cv::Mat/numpy.ndarray).""" + ... + + def Stop(self) -> None: + """中止拼接流程.""" + ... + + def SetConfig(self, config: UPanConfig) -> None: + """更新运行参数.""" + ... + + def OptAndOutCurrPan(self) -> int: + """立即优化并输出当前全景图.""" + ... + + def ExportPanMat(self) -> ImageLike: + """获取当前全景图像.""" + ... + + def getHomography(self, info: FrameInfo) -> ImageLike: + """根据帧信息返回单应性矩阵.""" + ... + + @staticmethod + def Create(cachedir: str = "./cache") -> API_UnderStitch: + """创建 API_UnderStitch 实例.""" + ... + + @staticmethod + def Destroy(obj: API_UnderStitch) -> None: + """销毁 API_UnderStitch 实例.""" + ... + + +__all__ = [ + "API_UnderStitch", + "AirCraftInfo", + "CamInfo", + "EulerRPY", + "FrameInfo", + "ImageLike", + "PointBLH", + "ServoInfo", + "UPanConfig", + "UPanInfo", +] + diff --git a/stitch/py/cvbind.hpp b/stitch/py/cvbind.hpp new file mode 100644 index 00000000..83dcbf0d --- /dev/null +++ b/stitch/py/cvbind.hpp @@ -0,0 +1,232 @@ +// Created by ausk @ 2019.11.23 +// Based on +// https://github.com/ausk/keras-unet-deploy/tree/master/cpp/libunet/cvbind.h + +#pragma once + +#include +#include +#include +#include +#include +#include + +// Convert cv::Point, cv::Rest, cv::Mat +namespace pybind11 +{ +namespace detail +{ +//! cv::Point <=> tuple(x,y) +template <> +struct type_caster +{ + PYBIND11_TYPE_CASTER(cv::Point, _("tuple_xy")); + + // Convert from Python to C++. + // Convert the Python tuple object to C++ cv::Point type, and return false + // if the conversion fails. + // The second argument indicates whether implicit conversions should be + // applied. + bool load(handle obj, bool) + { + // Ensure that the passed parameter is of tuple type + if (!pybind11::isinstance(obj)) + { + std::logic_error("Point(x,y) should be a tuple!"); + return false; + } + + // Extract the tuple object from the handle and ensure its length is 2. + pybind11::tuple pt = reinterpret_borrow(obj); + if (pt.size() != 2) + { + std::logic_error("Point(x,y) tuple should be size of 2"); + return false; + } + + // Convert a tuple of length 2 to cv::Point. + value = cv::Point(pt[0].cast(), pt[1].cast()); + return true; + } + + // Convert from C++ to Python. Convert C++ cv::Mat object to tuple, + // parameter 2 and parameter 3 are ignored + static handle cast(const cv::Point& pt, return_value_policy, handle) + { + return pybind11::make_tuple(pt.x, pt.y).release(); + } +}; + +// cv::Rect <=> tuple(x,y,w,h) +template <> +struct type_caster +{ + PYBIND11_TYPE_CASTER(cv::Rect, _("tuple_xywh")); + + bool load(handle obj, bool) + { + if (!pybind11::isinstance(obj)) + { + std::logic_error("Rect should be a tuple!"); + return false; + } + + pybind11::tuple rect = reinterpret_borrow(obj); + if (rect.size() != 4) + { + std::logic_error("Rect (x,y,w,h) tuple should be size of 4"); + return false; + } + + value = cv::Rect(rect[0].cast(), + rect[1].cast(), + rect[2].cast(), + rect[3].cast()); + return true; + } + + static handle cast(const cv::Rect& rect, return_value_policy, handle) + { + return pybind11::make_tuple(rect.x, rect.y, rect.width, rect.height) + .release(); + } +}; + +// Convert between cv::Mat and numpy.ndarray. +// +// Python supports a general buffer protocol for data exchange between plugins. +// Let the type expose a buffer view, this allows direct access to the original +// internal data, often used in matrix types. +// +// Pybind11 provides the pybind11::buffer_info type to map the Python buffer +// protocol (buffer protocol). +// +// struct buffer_info { +// void* ptr; /* Pointer to buffer */ +// ssize_t itemsize; /* Size of one scalar */ +// std::string format; /* Python struct-style format descriptor +// */ ssize_t ndim; /* Number of dimensions */ +// std::vector shape; /* Buffer dimensions */ +// std::vector strides; /* Strides (in bytes) for each index */ +//}; + +template <> +struct type_caster +{ +public: + PYBIND11_TYPE_CASTER(cv::Mat, _("numpy.ndarray")); + + //! 1. cast numpy.ndarray to cv::Mat + bool load(handle obj, bool) + { + array b = reinterpret_borrow(obj); + buffer_info info = b.request(); + + int nh = 1; + int nw = 1; + int nc = 1; + int ndims = info.ndim; + if (ndims == 2) + { + nh = info.shape[0]; + nw = info.shape[1]; + } + else if (ndims == 3) + { + nh = info.shape[0]; + nw = info.shape[1]; + nc = info.shape[2]; + } + else + { + throw std::logic_error("Only support 2d, 2d matrix"); + return false; + } + + int dtype; + if (info.format == format_descriptor::format()) + { + dtype = CV_8UC(nc); + } + else if (info.format == format_descriptor::format()) + { + dtype = CV_32SC(nc); + } + else if (info.format == format_descriptor::format()) + { + dtype = CV_32FC(nc); + } + else if (info.format == format_descriptor::format()) + { + dtype = CV_64FC(nc); + } + else + { + throw std::logic_error( + "Unsupported type, only support uchar, int32, float, double"); + return false; + } + value = cv::Mat(nh, nw, dtype, info.ptr); + return true; + } + + //! Cast cv::Mat to numpy.ndarray + static handle cast(const cv::Mat& mat, + return_value_policy, + handle /*defval*/) + { + std::string format = format_descriptor::format(); + size_t elemsize = sizeof(unsigned char); + int nw = mat.cols; + int nh = mat.rows; + int nc = mat.channels(); + int depth = mat.depth(); + int type = mat.type(); + int dim = (depth == type) ? 2 : 3; + if (depth == CV_8U) + { + format = format_descriptor::format(); + elemsize = sizeof(unsigned char); + } + else if (depth == CV_32S) + { + format = format_descriptor::format(); + elemsize = sizeof(int); + } + else if (depth == CV_32F) + { + format = format_descriptor::format(); + elemsize = sizeof(float); + } + else if (depth == CV_64F) + { + format = format_descriptor::format(); + elemsize = sizeof(double); + } + else + { + throw std::logic_error( + "Unsupport type, only support uchar, int32, float, double"); + } + + std::vector bufferdim; + std::vector strides; + if (dim == 2) + { + bufferdim = {(size_t)nh, (size_t)nw}; + strides = {elemsize * (size_t)nw, elemsize}; + } + else if (dim == 3) + { + bufferdim = {(size_t)nh, (size_t)nw, (size_t)nc}; + strides = {(size_t)elemsize * nw * nc, + (size_t)elemsize * nc, + (size_t)elemsize}; + } + return array(buffer_info( + mat.data, elemsize, format, dim, bufferdim, strides)) + .release(); + } +}; +} // namespace detail +} // namespace pybind11 diff --git a/stitch/py/example.cpp b/stitch/py/example.cpp deleted file mode 100644 index f5719dcf..00000000 --- a/stitch/py/example.cpp +++ /dev/null @@ -1,118 +0,0 @@ -#include -#include // 用于绑定 std::string 等 -#include -#include -#include "StitchStruct.h" // 包含 API_UnderStitch 的定义 -#include "Arith_UnderStitch.h" -#include "Arith_GeoSolver.h" -// 可能还需要包含其他依赖项,如 OpenCV 的绑定头文件 - -namespace py = pybind11; - -// 声明外部的辅助绑定函数 -void bind_public_structures(py::module_ &m); -void bind_api_understitch(py::module_ &m); - - -PYBIND11_MODULE(UStitcher, m) { - - m.doc() = "pybind11 bindings for UnderStitcher."; - - // 绑定通用数据结构 - bind_public_structures(m); - - // 绑定下视模块接口 - bind_api_understitch(m); - -} - - - - - -void bind_api_understitch(py::module_& m) { - - // UPanConfig - py::class_(m, "UPanConfig", "下视拼接参数控制") - .def(py::init<>()) - .def_readwrite("bUseBA", &UPanConfig::bUseBA, "开启BA") - .def_readwrite("bOutFrameTile", &UPanConfig::bOutFrameTile, "输出单帧正射图") - .def_readwrite("bOutGoogleTile", &UPanConfig::bOutGoogleTile, "输出谷歌瓦片"); - - // FrameInfo - py::class_(m, "FrameInfo", "帧内外方位元素") - .def(py::init<>()) - .def_readwrite("nFrmID", &FrameInfo::nFrmID, "帧编号,唯一ID") - .def_readwrite("craft", &FrameInfo::craft, "载体信息") - .def_readwrite("camInfo", &FrameInfo::camInfo, "相机信息") - .def_readwrite("servoInfo", &FrameInfo::servoInfo, "伺服状态") - .def_readwrite("nEvHeight", &FrameInfo::nEvHeight, "相对高差") - .def_readwrite("nWidth", &FrameInfo::nWidth) - .def_readwrite("nHeight", &FrameInfo::nHeight); - - // UPanInfo - py::class_(m, "UPanInfo", "下视全景图配置") - .def(py::init<>()) - .def_readwrite("m_pan_width", &UPanInfo::m_pan_width) - .def_readwrite("m_pan_height", &UPanInfo::m_pan_height) - .def_readwrite("scale", &UPanInfo::scale, "比例尺") - .def_readwrite("map_shiftX", &UPanInfo::map_shiftX, "平移X") - .def_readwrite("map_shiftY", &UPanInfo::map_shiftY, "平移Y"); - // ---------------------------------------------------- - // 2. 绑定 API_UnderStitch 类 - // ---------------------------------------------------- - py::class_(m, "API_UnderStitch") - // py::init() 是不必要的,因为它是抽象类 - - - // 绑定纯虚函数 (这些函数需要在 C++ 派生类中实现) - .def("Init", &API_UnderStitch::Init, - py::arg("info"), "初始化拼接") - - .def("SetOutput", &API_UnderStitch::SetOutput, - py::arg("name"), py::arg("outdir"), "设置输出标识和路径") - - .def("Run", &API_UnderStitch::Run, - py::arg("img"), py::arg("para"), "运行拼接流程") - - .def("Stop", &API_UnderStitch::Stop, - "中止拼接流程") - - .def("SetConfig", &API_UnderStitch::SetConfig, - py::arg("config"), "运行参数配置") - - - .def("ExportPanAddr", &API_UnderStitch::ExportPanAddr, - py::return_value_policy::reference_internal, // 返回引用,确保 C++ 对象不被析构 - "获取全景图") - - .def("ExportPanMat", &API_UnderStitch::ExportPanMat, - py::return_value_policy::copy, // 通常 cv::Mat 返回时是深拷贝 - "获取全景图cv::Mat") - - //getHomography - .def("getHomography", &API_UnderStitch::getHomography, - py::return_value_policy::copy, // 通常 cv::Mat 返回时是深拷贝 - "获取H映射矩阵:从像方到物方") - - //---------------------------------------------------- - //3. 绑定静态工厂方法 (Factory Methods) - //---------------------------------------------------- - - // 绑定 Create 方法 - .def_static("Create", &API_UnderStitch::Create, - py::arg("cachedir") = "./cache", - py::return_value_policy::take_ownership, // 返回指针,需要 Python 管理其生命周期 - "创建 API_UnderStitch 实例") - - // 绑定 Destroy 方法 - // 注意:这里需要 pybind11::arg::none() 来明确指出该函数没有返回值 - .def_static("Destroy", &API_UnderStitch::Destroy, - py::arg("obj"), - "销毁 API_UnderStitch 实例") - - ; - - -} - diff --git a/stitch/py/public_binding.cpp b/stitch/py/public_binding.cpp index bd7fe966..3bd8eefa 100644 --- a/stitch/py/public_binding.cpp +++ b/stitch/py/public_binding.cpp @@ -53,6 +53,37 @@ void bind_public_structures(py::module_ &m) { .def_readwrite("fServoPt", &ServoInfo::fServoPt, "当前帧伺服俯仰角") .def_readwrite("fServoAzSpeed", &ServoInfo::fServoAzSpeed, "当前帧伺服方位角速度") .def_readwrite("fServoPtSpeed", &ServoInfo::fServoPtSpeed, "当前帧伺服俯仰角速度"); + + // ---------------------------------------------------- + // C. FrameInfo 和 UPanInfo + // ---------------------------------------------------- + + // FrameInfo + py::class_(m, "FrameInfo", "帧内外方位元素") + .def(py::init<>()) + .def_readwrite("nFrmID", &FrameInfo::nFrmID, "帧编号,唯一ID") + .def_readwrite("craft", &FrameInfo::craft, "载体信息 (AirCraftInfo)") + .def_readwrite("camInfo", &FrameInfo::camInfo, "相机信息 (CamInfo)") + .def_readwrite("servoInfo", &FrameInfo::servoInfo, "伺服状态 (ServoInfo)") + .def_readwrite("nEvHeight", &FrameInfo::nEvHeight, "相对高差") + .def_readwrite("nWidth", &FrameInfo::nWidth, "图像宽度") + .def_readwrite("nHeight", &FrameInfo::nHeight, "图像高度"); + + // UPanInfo + py::class_(m, "UPanInfo", "下视全景图配置") + .def(py::init<>()) + .def_readwrite("m_pan_width", &UPanInfo::m_pan_width, "全景图宽度") + .def_readwrite("m_pan_height", &UPanInfo::m_pan_height, "全景图高度") + .def_readwrite("scale", &UPanInfo::scale, "比例尺") + .def_readwrite("map_shiftX", &UPanInfo::map_shiftX, "平移X") + .def_readwrite("map_shiftY", &UPanInfo::map_shiftY, "平移Y"); + + // UPanConfig + py::class_(m, "UPanConfig", "下视拼接参数控制") + .def(py::init<>()) + .def_readwrite("bUseBA", &UPanConfig::bUseBA, "开启BA") + .def_readwrite("bOutFrameTile", &UPanConfig::bOutFrameTile, "输出单帧正射图") + .def_readwrite("bOutGoogleTile", &UPanConfig::bOutGoogleTile, "输出谷歌瓦片"); } diff --git a/stitch/py/simple.cpp b/stitch/py/simple.cpp deleted file mode 100644 index 384ea90f..00000000 --- a/stitch/py/simple.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include -#include -#include -#include "StitchStruct.h" // 包含 API_UnderStitch 的定义 -#include "API_UnderStitch.h" -using namespace cv; - - -namespace py = pybind11; - -// 导出的 C++ 函数 -std::string greet(const std::string &name) { - return "Hello, " + name + "! pybind11 module loaded successfully!"; -} - -cv::Mat getMat() -{ - return cv::Mat::zeros(3,3,CV_8UC1); -} - -cv::Mat getH(FrameInfo info) -{ - printf("run get H"); - auto module = API_UnderStitch::Create(); - return module->getHomography(info); -} - - -// PYBIND11 模块入口点 -PYBIND11_MODULE(simple_test, m) { - m.doc() = "***************A minimal pybind11 test module.*****************"; - - // 绑定 greet 函数 - m.def("greet", &greet, "A function that returns a greeting string."); - - m.def("getMat", &getMat, "A function that returns a greeting string."); - m.def("getH",&getH,"getH"); -} \ No newline at end of file diff --git a/stitch/src/API_GeoCorrect.h b/stitch/src/API_GeoCorrect.h new file mode 100755 index 00000000..ceed4218 --- /dev/null +++ b/stitch/src/API_GeoCorrect.h @@ -0,0 +1,31 @@ +#ifdef _WIN32 +#define STD_STITCH_API __declspec(dllexport) +#else +#define STD_STITCH_API __attribute__ ((visibility("default"))) +#endif + +#include "StitchStruct.h" + +// 地理校正模块 +class API_GeoCorrect +{ +public: + virtual ~API_GeoCorrect() = default; + + virtual void Correct(cv::Mat img, FrameInfo info) = 0; + + virtual void Init(FrameInfo info) = 0; + + +public: + static API_GeoCorrect* Create(); + static void Destroy(API_GeoCorrect* obj); +}; + + + + + + + + diff --git a/stitch/src/API_UnderStitch.h b/stitch/src/API_UnderStitch.h index b17cefef..e0d9c722 100644 --- a/stitch/src/API_UnderStitch.h +++ b/stitch/src/API_UnderStitch.h @@ -32,6 +32,8 @@ public: // 运行拼接流程 virtual SINT32 Run(GD_VIDEO_FRAME_S img, FrameInfo para) = 0; + virtual SINT32 Run(cv::Mat img, FrameInfo para) = 0; + // 中止拼接流程 virtual void Stop() = 0; diff --git a/stitch/src/Arith_BATask.cpp b/stitch/src/Arith_BATask.cpp index e09340f6..44de418f 100644 --- a/stitch/src/Arith_BATask.cpp +++ b/stitch/src/Arith_BATask.cpp @@ -18,60 +18,6 @@ using namespace ceres; #define STABLE_X3 1920 #define STABLE_Y3 1080 -struct HomographyResidual_1 -{ - - HomographyResidual_1(const cv::KeyPoint& keypoint_i, const cv::KeyPoint& keypoint_j, const cv::Mat H1, const cv::Mat H2) - : keypoint_i_(keypoint_i), keypoint_j_(keypoint_j),Hi0_(H1), Hj0_(H2) - { - } - - template - bool operator()(const T* const h_i, const T* const h_j, T* residual) const - { - typedef Eigen::Matrix EPoint; - typedef Eigen::Matrix EMat; - - EMat Mat_H_i; - Mat_H_i << T(h_i[0]), T(h_i[1]), T(h_i[2]), - T(h_i[3]), T(h_i[4]), T(h_i[5]), - T(h_i[6]), T(h_i[7]), T(1.0); - - EMat Mat_H_j; - Mat_H_j << T(h_j[0]), T(h_j[1]), T(h_j[2]), - T(h_j[3]), T(h_j[4]), T(h_j[5]), - T(h_j[6]), T(h_j[7]), T(1.0); - - EPoint p_i(T(keypoint_i_.pt.x), T(keypoint_i_.pt.y), T(1.0)); - EPoint p_j(T(keypoint_j_.pt.x), T(keypoint_j_.pt.y), T(1.0)); - - - EPoint warp_i = Mat_H_i * p_i; - EPoint img_j = Mat_H_j.inverse() * warp_i; - EPoint img_j_1 = img_j / img_j[2]; - - // 计算残差向量 - residual[0] = (img_j_1 - p_j).squaredNorm(); - - - - return true; - } - - -private: - const cv::KeyPoint keypoint_i_; // 第 i 帧图像中的特征点 - const cv::KeyPoint keypoint_j_; // 第 j 帧图像中的特征点 - - const cv::Mat Hi0_; - const cv::Mat Hj0_; - -}; - - - -//// H投影残差:将左图投影到右图,最小化同名点全景图上的投影误差。 -//// 问题:缺少尺度约束,H矩阵优化后失去正交性 优点:形式简单 struct HomographyResidual { @@ -108,34 +54,6 @@ struct HomographyResidual residual[0] = (img_j_1 - p_j).squaredNorm(); - // 2. 不动点位置约束 - cv::Point2f pS1 = warpPointWithH(Hi0_, cv::Point2f(STABLE_X1, STABLE_Y1)); - cv::Point2f pS2 = warpPointWithH(Hi0_, cv::Point2f(STABLE_X2, STABLE_Y2)); - cv::Point2f pS3 = warpPointWithH(Hi0_, cv::Point2f(STABLE_X3, STABLE_Y3)); - - EPoint pS1_E{ (T)pS1.x,(T)pS1.y ,T(1.0) }; - EPoint pS2_E{ (T)pS2.x,(T)pS2.y,T(1.0) }; - EPoint pS3_E{ (T)pS3.x,(T)pS3.y ,T(1.0) }; - - EPoint p_1{ T(STABLE_X1), T(STABLE_Y1), T(1.0) }; - EPoint p_2{ T(STABLE_X2), T(STABLE_Y2), T(1.0) }; - EPoint p_3{ T(STABLE_X3), T(STABLE_Y3), T(1.0) }; - - - EPoint P_r1 = Mat_H_i * p_1; - EPoint P_r2 = Mat_H_i * p_2; - EPoint P_r3 = Mat_H_i * p_3; - - P_r1 /= P_r1[2]; - P_r2 /= P_r2[2]; - P_r3 /= P_r3[2]; - - - // 约束投影位置不变 - residual[1] = (P_r1 - pS1_E).squaredNorm(); - residual[2] = (P_r2 - pS2_E).squaredNorm(); - residual[3] = (P_r3 - pS3_E).squaredNorm(); - return true; } @@ -322,7 +240,7 @@ void BA_Task::OptFrame(vector frameInd,cv::Mat H_map, std::unordered_ma // 创建 Ceres 问题 ceres::Problem problemH; ceres::Problem problemSE3; - ceres::Problem problemH2; + // 添加残差块 int nParaCnt = 0;//参数组数 @@ -387,9 +305,9 @@ void BA_Task::OptFrame(vector frameInd,cv::Mat H_map, std::unordered_ma // problemH2.AddResidualBlock(cost_function, nullptr, h_list[i], h_list[j]); ceres::CostFunction* cost_function = - new ceres::AutoDiffCostFunction( - new HomographyResidual_1(keypoint_i, keypoint_j, Hi0, Hj0)); - problemH2.AddResidualBlock(cost_function, nullptr, h_list[i], h_list[j]); + new ceres::AutoDiffCostFunction( + new HomographyResidual(keypoint_i, keypoint_j, Hi0, Hj0)); + problemH.AddResidualBlock(cost_function, nullptr, h_list[i], h_list[j]); #endif #ifdef OPT_SE3 @@ -421,7 +339,7 @@ void BA_Task::OptFrame(vector frameInd,cv::Mat H_map, std::unordered_ma // 求解 #ifdef OPT_H - ceres::Solve(options, &problemH2, &summary); + ceres::Solve(options, &problemH, &summary); #endif #ifdef OPT_SE3 diff --git a/stitch/src/Arith_GeoCorrect.cpp b/stitch/src/Arith_GeoCorrect.cpp new file mode 100644 index 00000000..2f614d87 --- /dev/null +++ b/stitch/src/Arith_GeoCorrect.cpp @@ -0,0 +1,88 @@ +#include "Arith_GeoCorrect.h" +#include "Arith_GeoSolver.h" +#include "Arith_CoordModule.h" + +API_GeoCorrect* API_GeoCorrect::Create() +{ + return new GeoCorrect(); +} + + +void API_GeoCorrect::Destroy(API_GeoCorrect* obj) +{ + delete obj; +} + + +GeoCorrect::GeoCorrect() +{ + _GeoSolver = new GeoSolver(); +} + +GeoCorrect::~GeoCorrect() +{ +} + +void GeoCorrect::Correct(cv::Mat src, FrameInfo info) +{ + info.servoInfo.fServoPt += 90; + + _GeoSolver->SetOriginPoint(info); + +// // GT +// //_GeoSolver->SetOriginPoint(info); +// cv::Mat R = _GeoSolver->Mat_TransENG2uv(info); + +// cv::Mat R_1 = R.inv(); + +// // 计算图像中心点的地面坐标 +// cv::Point2f centerPtInGeo = warpPointWithH(R_1, cv::Point2f(info.nWidth / 2.0f, info.nHeight / 2.0f)); + +// PointXYZ ptCurr = { centerPtInGeo.y, -info.nEvHeight, centerPtInGeo.x}; +// PointXYZ originCGCSPoint = getXYZFromBLH(info.craft.stPos); +// //PointXYZ diff = getNUEXYZFromCGCSXYZ(ptCurr, originCGCSPoint); + +// PointBLH centerPtInBLH = getBLHFromXYZ(getCGCSXYZFromNUEXYZ(ptCurr, originCGCSPoint)); +// std::cout << "centerPtInBLH: " << centerPtInBLH.B << ", " << centerPtInBLH.L << ", " << centerPtInBLH.H << std::endl; + +// // +// cv::Mat H = _GeoSolver->findHomography(info); + +// cv::Point2f pt_Geo = warpPointWithH(H, cv::Point2f(info.nWidth / 2.0f, info.nHeight / 2.0f)); + +// PointXYZ ptNUE = { 0 }; +// // 本模块使用的地理系是东(x)-北(y)-地(z),需要转换一下 +// ptNUE.X = pt_Geo.y; +// ptNUE.Y = -info.nEvHeight; +// ptNUE.Z = pt_Geo.x; +// PointBLH centerPtInBLH_ = getBLHFromXYZ(getCGCSXYZFromNUEXYZ(ptNUE, _GeoSolver->originCGCSPoint)); +// std::cout << "centerPtInBLH_: " << centerPtInBLH_.B << ", " << centerPtInBLH_.L << ", " << centerPtInBLH_.H << std::endl; + + cv::Mat H = _GeoSolver->findHomography(info); + + cv::Rect2f roi = warpRectWithH_2Rect(H, src.size()); + + float scale = src.cols / MAX(roi.width, roi.height); + float shiftX = -roi.x * scale; + float shiftY = (roi.y + roi.height)* scale; + + Mat H_proj = (Mat_(3, 3) << scale, 0, shiftX, + 0, -scale, shiftY, + 0, 0, 1 + ); + cv::Rect2f roi2 = warpRectWithH_2Rect(H_proj * H, src.size()); + + cv::Mat dst; + warpPerspective(src, dst, H_proj * H, cv::Size(roi2.width, roi2.height)); + + + // 计算图像四至的经纬高 + + + return; +} + +void GeoCorrect::Init(FrameInfo info) +{ + _GeoSolver->SetOriginPoint(info); +} \ No newline at end of file diff --git a/stitch/src/Arith_GeoCorrect.h b/stitch/src/Arith_GeoCorrect.h new file mode 100644 index 00000000..11a3b098 --- /dev/null +++ b/stitch/src/Arith_GeoCorrect.h @@ -0,0 +1,21 @@ +#pragma once +#include "API_GeoCorrect.h" +#include "opencv2/opencv.hpp" +#include "Arith_GeoSolver.h" +#include "StitchStruct.h" +#include "FileCache.h" +#include +#include "Logger.h" + +class GeoCorrect : public API_GeoCorrect +{ +public: + GeoCorrect(); + ~GeoCorrect(); + + void Correct(cv::Mat img, FrameInfo info); + void Init(FrameInfo info); + +private: + GeoSolver* _GeoSolver; +}; \ No newline at end of file diff --git a/stitch/src/Arith_GeoSolver.cpp b/stitch/src/Arith_GeoSolver.cpp index 8c5c4a67..440fbe4f 100644 --- a/stitch/src/Arith_GeoSolver.cpp +++ b/stitch/src/Arith_GeoSolver.cpp @@ -139,20 +139,12 @@ Mat GeoSolver::Mat_TransENG2uv(FrameInfo info) 0, 0, info.nEvHeight ); - - // 内参 - FLOAT32 fd = info.camInfo.nFocus / info.camInfo.fPixelSize * 1000; - - Mat M_cam = (Mat_(3, 3) << fd, 0, info.nWidth / 2, - 0, -fd, info.nHeight / 2, - 0, 0, 1 - ); - - Mat M = M_cam * Mat_TransENG2Cam(info) * M_het; + Mat M = Mat_GetCamK(info) * Mat_TransENG2Cam(info) * M_het; return M; } +using namespace std; cv::Mat GeoSolver::Mat_TransENG2Cam(FrameInfo info) { diff --git a/stitch/src/Arith_GeoSolver.h b/stitch/src/Arith_GeoSolver.h index e0b112ce..b29b1d0f 100644 --- a/stitch/src/Arith_GeoSolver.h +++ b/stitch/src/Arith_GeoSolver.h @@ -62,7 +62,7 @@ public: // 经纬度转换为局部地理系坐标 cv::Point2f getGeoFromBLH(PointBLH ptPos); -private: +public: // 计算当前帧李群(SE3)和投影K RtK AnlayseRtK(FrameInfo info); @@ -76,7 +76,7 @@ private: // 机体ENG(东北地)到像方的 旋转矩阵 cv::Mat Mat_TransENG2uv(FrameInfo info); -private: +public: // 平移矩阵,以初始化点为基准,计算当前位置在初始点的地理坐标,那么当前帧所有点的坐标做此平移 cv::Mat Mat_TransENGMove(FrameInfo info); diff --git a/stitch/src/Arith_UnderStitch.cpp b/stitch/src/Arith_UnderStitch.cpp index fa4153c6..b982f4fa 100644 --- a/stitch/src/Arith_UnderStitch.cpp +++ b/stitch/src/Arith_UnderStitch.cpp @@ -8,7 +8,7 @@ #include #include "utils/Arith_ThreadPool.h" #include "utils/Arith_timer.h" - + #ifdef _WIN32 #include #include @@ -132,7 +132,7 @@ UPanInfo UnderStitch::InitMap(FrameInfo para) auto geo_pro_rect = warpRectWithH_2Rect(H0,cv::Size(para.nWidth, para.nHeight)); // 计算目标比例(单帧投影后占全景图的0.25) - float target_ratio = 0.25; + float target_ratio = 0.08; float current_ratio = MAX(geo_pro_rect.width / (target_ratio*panPara.m_pan_width), geo_pro_rect.height / (target_ratio*panPara.m_pan_height)); // 调整scale参数 @@ -468,6 +468,49 @@ cv::Mat gammaCorrect(const cv::Mat& src, double gamma) } +GD_VIDEO_FRAME_S mat_to_gd_frame(const cv::Mat& mat) +{ + GD_VIDEO_FRAME_S frame = {0}; + + if (mat.empty()) { + return frame; + } + + frame.u32Width = mat.cols; + frame.u32Height = mat.rows; + + // 根据 Mat 的类型和通道数确定像素格式 + int channels = mat.channels(); + int depth = mat.depth(); + + if (channels == 1 && depth == CV_8U) { + frame.enPixelFormat = GD_PIXEL_FORMAT_GRAY_Y8; + } else if (channels == 3 && depth == CV_8U) { + // 假设是 BGR(OpenCV 默认) + frame.enPixelFormat = GD_PIXEL_FORMAT_BGR_PACKED; + } else if (channels == 4 && depth == CV_8U) { + // BGRA + frame.enPixelFormat = GD_PIXEL_FORMAT_BGR_PACKED; // 使用 BGR,忽略 alpha + } else { + throw std::runtime_error("Unsupported cv::Mat format for GD_VIDEO_FRAME_S"); + } + + // 设置 stride 和虚拟地址 + frame.u32Stride[0] = static_cast(mat.step[0]); + frame.u64VirAddr[0] = mat.data; + + return frame; +} + + +SINT32 UnderStitch::Run(cv::Mat img, FrameInfo para) +{ + // 将cv::Mat转换为GD_VIDEO_FRAME_S + GD_VIDEO_FRAME_S frame = mat_to_gd_frame(img); + return Run(frame, para); +} + + SINT32 UnderStitch::Run(GD_VIDEO_FRAME_S frame, FrameInfo para) { stopWatch sw; diff --git a/stitch/src/Arith_UnderStitch.h b/stitch/src/Arith_UnderStitch.h index 7fc9c7d8..23caa8dc 100644 --- a/stitch/src/Arith_UnderStitch.h +++ b/stitch/src/Arith_UnderStitch.h @@ -11,7 +11,7 @@ #include "utils/Arith_ThreadPool.h" #include #include "Logger.h" - + // 定义扫描模式,使用扫描专用的拼接地图策略 #define SCAN_MODE @@ -39,6 +39,8 @@ public: SINT32 Run(GD_VIDEO_FRAME_S img, FrameInfo para); + SINT32 Run(cv::Mat img, FrameInfo para); + void Stop(); SINT32 OptAndOutCurrPan(); diff --git a/stitch/src/Version.h b/stitch/src/Version.h index ef6c9c6f..8cd6c620 100644 --- a/stitch/src/Version.h +++ b/stitch/src/Version.h @@ -2,5 +2,5 @@ #pragma once #include -std::string BUILD_TIME = "BUILD_TIME 2025_11_27-19.34.59"; +std::string BUILD_TIME = "BUILD_TIME 2025_11_29-17.52.06"; std::string VERSION = "BUILD_VERSION 1.0.1"; diff --git a/tests/1.py b/tests/1.py deleted file mode 100644 index ec27a054..00000000 --- a/tests/1.py +++ /dev/null @@ -1,56 +0,0 @@ -import os -import sys - -# 1. 定义您的 DLL 目录 -DLL_DIR = r"D:\wangchongwu_gitea_2023\StitchVideo\Bin" -CV_DIR = r"C:\Lib\opencv455\build\x64\vc14\bin" -CUDA_DIR = r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin" - -# path -os.add_dll_directory(DLL_DIR) -os.add_dll_directory(CV_DIR) -os.add_dll_directory(CUDA_DIR) - -sys.path.append(DLL_DIR) -sys.path.append(CV_DIR) -sys.path.append(CUDA_DIR) - -# 4. 导入 -from UStitcher import API_UnderStitch,FrameInfo - -input("sssss") - - -frame_info = FrameInfo() - - - -frame_info.nFrmID = 1 -frame_info.craft.stPos.B = 39 -frame_info.craft.stPos.L = 120 -frame_info.craft.stPos.H = 1000 -frame_info.camInfo.nFocus = 40 -frame_info.camInfo.fPixelSize = 12 - -frame_info.servoInfo.fServoAz = 90 -frame_info.servoInfo.fServoPt = -45 - -frame_info.nEvHeight = 1200 - -frame_info.nWidth = 1280 -frame_info.nHeight = 1024 - - - -stitcher = API_UnderStitch.Create() -#H = stitcher.getHomography(frame_info) - - -print("done") - - - - - - - diff --git a/tests/DJ/ProcDJ.cpp b/tests/cpp/DJ/ProcDJ.cpp similarity index 99% rename from tests/DJ/ProcDJ.cpp rename to tests/cpp/DJ/ProcDJ.cpp index 645c6e0b..5fd39986 100644 --- a/tests/DJ/ProcDJ.cpp +++ b/tests/cpp/DJ/ProcDJ.cpp @@ -371,7 +371,7 @@ int main() vector videoPathList; vector srtPathList; - string folder = "F:/K2D_data/"; + string folder = "/media/wang/data/K2D_data/"; videoPathList.push_back(folder + "DJI_20251024144932_0001_Z.MP4"); diff --git a/tests/NeoArithStandardDll.h b/tests/cpp/NeoArithStandardDll.h similarity index 100% rename from tests/NeoArithStandardDll.h rename to tests/cpp/NeoArithStandardDll.h diff --git a/tests/S7215/Arith_zhryp.cpp b/tests/cpp/S7215/Arith_zhryp.cpp similarity index 100% rename from tests/S7215/Arith_zhryp.cpp rename to tests/cpp/S7215/Arith_zhryp.cpp diff --git a/tests/S7215/Arith_zhryp.h b/tests/cpp/S7215/Arith_zhryp.h similarity index 100% rename from tests/S7215/Arith_zhryp.h rename to tests/cpp/S7215/Arith_zhryp.h diff --git a/tests/S7215/TsDecoder.hpp b/tests/cpp/S7215/TsDecoder.hpp similarity index 100% rename from tests/S7215/TsDecoder.hpp rename to tests/cpp/S7215/TsDecoder.hpp diff --git a/tests/S7215/TsPacker.hpp b/tests/cpp/S7215/TsPacker.hpp similarity index 100% rename from tests/S7215/TsPacker.hpp rename to tests/cpp/S7215/TsPacker.hpp diff --git a/tests/S7215/commondefine.h b/tests/cpp/S7215/commondefine.h similarity index 100% rename from tests/S7215/commondefine.h rename to tests/cpp/S7215/commondefine.h diff --git a/tests/S7215/stitch_S7215.cpp b/tests/cpp/S7215/stitch_S7215.cpp similarity index 100% rename from tests/S7215/stitch_S7215.cpp rename to tests/cpp/S7215/stitch_S7215.cpp diff --git a/tests/S729.h b/tests/cpp/S729.h similarity index 100% rename from tests/S729.h rename to tests/cpp/S729.h diff --git a/tests/S732/H264_SEI_typedef.h b/tests/cpp/S732/H264_SEI_typedef.h similarity index 100% rename from tests/S732/H264_SEI_typedef.h rename to tests/cpp/S732/H264_SEI_typedef.h diff --git a/tests/cpp/S732/NeoArithStandardDll.h b/tests/cpp/S732/NeoArithStandardDll.h new file mode 100644 index 00000000..894d7726 --- /dev/null +++ b/tests/cpp/S732/NeoArithStandardDll.h @@ -0,0 +1,655 @@ +#pragma once +/*********版权所有(C)2024, 武汉高德红外股份有限公司*************** +* 文件名称:ArithStandardDll.h +* 文件标识:高德光电搜索跟踪算法SDK +* 内容摘要: +* 其它说明:算法动态链接库(Arith DLL)的函数、全局变量、宏定义,统一前缀为简写"ARIDLL" +* 当前版本:V2.0 +* 创建作者:04046wcw +* 创建日期:2023-11-01 +****************************************************************/ +#ifndef __NEO_ARTTHSTANDARDDLL_H__ +#define __NEO_ARTTHSTANDARDDLL_H__ + +#include "PlatformDefine.h" +#include "Arith_CommonDef.h" + + + +#ifdef _WIN32 +#define STD_TRACKER_API extern "C" __declspec(dllexport) +#else +#define STD_TRACKER_API __attribute__ ((visibility("default"))) +#endif + + + +#ifdef __cplusplus +extern "C" { +#endif + + + +//+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +//单个目标的结构体[兼容输入/输出,检测/跟踪] +typedef struct tagARIDLL_OBJINFO +{ + //*****1.目标状态信息***** + int nFrameId; //目标当前信息所对应的帧编号 + unsigned char unObjStatus; //目标搜索状态信息,更新/新增/删除 + unsigned char bMainTracked; //目标是否为主跟踪目标 + TrackingStatus unTrackingStatus;//目标跟踪状态 + + + //*****2.目标管道信息***** + int nOutputID; //输出告警目标 + int nInPipesID; //目标在管道数组中的编号 + int nPipeLostCnt; //目标当前管道连续丢失计数 + int nTotalCnt; //目标当前管道总帧数 + unsigned char bInCurrFov; //目标是否在当前视场 + int nAntiJamming; //抗干扰状态 + + //*****3.目标核心信息***** + float nX; //目标中心点图像坐标x + float nY; //目标中心点图像坐标y + float nObjW; //目标宽度 + float nObjH; //目标高度 + float fAz; //目标当前方位角 + float fPt; //目标当前俯仰角 + + // 目标预测位置 + float fPredAz; + float fPredPt; + + //*****4.其他属性信息***** + int nObjGray; //目标灰度 + int nObjMaxGray; //目标极值灰度 + int nMaxPosX; //目标极大值点X + int nMaxPosY; //目标极大值点Y + int nPixCnts; //目标像素个数 + unsigned char ubSizeType; //目标尺寸类型: + float fProb; //目标识别置信度 + float fSNR; //目标信噪比值 + float fTgEntropy; //目标信息熵值 + float fBgEntropy; //目标背景信息熵 + float fSaliency; //目标显著性值 + + // + bool nJammingSucess; //目标成功干扰 + + int unClsType; //目标类别 + float fReIDSim; //当前目标与主目标的ReID相似度 + + // 如果处于跟踪状态,则输出下列值 + RECT32S SA_SrBox; //小面目标跟踪波门 + SizeType SA_SizeType; //尺度信息 + RECT32S KCF_SrBox; //KCF波门 + RECT32S TLD_SrBox; //TLD波门 + FLOAT32 fConf; //跟踪置信度 + ObjSrc ArithSrc; //跟踪算法来源,决策后 + + unsigned char byte[20]; //预留 + +}ARIDLL_OBJINFO; + + +//输入【系统参数】结构体 +typedef struct tagARIDLL_INPUTPARA +{ + int nTimeStamp; //当前帧采集时刻时间戳,单位毫秒 + int unFrmId; //当前帧图像帧编号 + short unFreq; //输入图像帧频 + ServoInfo stServoInfo; //传感器伺服信息 + CamInfo stCameraInfo; //相机信息 + AirCraftInfo stAirCraftInfo; //载体信息 + GuideInfo stGuideInfo; //外部引导信息 + + AIT_OUTPUT stAITrackerInfo; //AI跟踪器结果 + + int nServoDelatCnt; //伺服角度延迟帧数 + // 其他输入 + bool bImageRataSys; //像旋系统标记,S731实物样机1,数字样机未模拟像旋 -0 + int nElevationDiff; //机载设备挂飞高程差 +}ARIDLL_INPUTPARA; + +//调试信息 +typedef struct tagARIDLL_DEBUG_OUTPUT +{ + unsigned short nDetectObjsNum; + // 管道资源 + unsigned short nMaxPipeNum; //当前系统管道资源数量 + unsigned short nUsePipeNum; //当前非空管道数量 + float Arith_time; //算法运行耗时 + unsigned int unFrmID; //算法执行的帧编号 + + //ARM发指令信息 + unsigned char nSysMode; //外部系统状态 + unsigned char nScenMode; //场景模式 + unsigned char nStatus; //待命/检测/跟踪/丢失状态信息等 + unsigned char nPixelType; //图像数据类型 + unsigned short nWidth; //图像宽 + unsigned short nHeight; //图像高 + unsigned char nLockType; //1-拉框吸附 2-点选吸附 3-ID锁定 4-修正攻击点 5-解锁 + unsigned char nLockID; //锁定id号 + unsigned short nLockX; //锁定波门或者修改攻击点 + unsigned short nLockY; + unsigned short nLockW; + unsigned short nLockH; + unsigned short nPredictX; //轨迹预测点 + unsigned short nPredictY; + unsigned short nForceMemFrm; //强制记忆帧数 + unsigned char unFreq; //帧频 + float fServoAz; //伺服方位角 + float fServoPt; //伺服俯仰角 + float nFocus; //焦距 + float fPixelSize; //像元尺寸 + unsigned char unVideoType; //视频类型 + + //算法参数(公用) + unsigned short nX; //决策输出中心点X + unsigned short nY; //决策输出中心点Y + unsigned short nW; //决策输出宽度 + unsigned short nH; //决策输出高度 + + unsigned short nRecapX; //重捕区域中心X + unsigned short nRecapY; //重捕区域中心Y + unsigned short nRecapW; //重捕区域宽度 + unsigned short nRecapH; //重捕区域高度 + + //对地跟踪信息调试 + //对地参数 + unsigned char nDecisionStatus; //决策状态 + unsigned char nKcfStatus; //kcf状态 + unsigned char bAIDStatus; //AI识别状态 + float fKCFRes; //KCF响应 + float fLargeResTH; //KCF重捕阈值 + float fArrKCFRes; + unsigned char nOccKCFStatus; + unsigned char nArrestKCFStatus; + short nAIDBestId; + short nAIDLostCnt; + short unContiTrackedCnt; + short nAIDJamCnt; + unsigned char nOccAIDStatus; + unsigned char nArrestAIDStatus; + unsigned short nTLDNum; //TLD聚类检测个数 + unsigned short nLearnCnt; //TLD学习计数 + float fMaxNNConf; //TLD检测最大响应 + + //对空参数 + unsigned char sky_bComplexEnv; // 复杂背景标志位 + unsigned char sky_bInterferenceMem; // 干扰近记忆标志位 + unsigned char sky_bResetByAIDet; // 跟踪被AI重置标记 + unsigned char sky_nClassSource; // 目标类别输出来源 + unsigned char sky_nDecisionStatus; // 决策状态,输出来源 + unsigned char sky_TrkDownRatio; // 对空跟踪降采样倍数 + unsigned short sky_TrkMemFrm; // 跟踪目标进记忆帧数 + unsigned short sky_nTrkCX; // 决策目标信息中心点X + unsigned short sky_nTrkCY; // 决策目标信息中心点Y + unsigned short sky_nTrkW; // 决策目标信息宽度 + unsigned short sky_nTrkH; // 决策目标信息高度 + unsigned short sky_nTrkPxlsCnt; // 跟踪目标像素数 + unsigned short sky_fTrkConf; // 跟踪目标置信度 + unsigned char sky_bGuideInFov; // 导引目标是否在视场 + unsigned short sky_nGuideCX; // 导引区域中心X + unsigned short sky_nGuideCY; // 导引区域中心Y + unsigned short sky_nGuideW; // 导引区域宽度 + unsigned short sky_nGuideH; // 导引区域高度 + + unsigned char resv[48]; //预留 +}ARIDLL_DEBUG_OUTPUT; + + + +//跟踪目标输出结构体 +typedef struct tagARIDLL_OUTPUT +{ + int nTimeStamp;//当前帧时间戳(透传),单位:毫秒 + + // 系统工作模式(透传)// by wcw04046 @ 2021/12/06 + GLB_SYS_MODE nSysMode; + int nFrmNum;//处理帧计数 + + // 场景模式 + GLB_SCEN_MODE nScenMode; + + //*****工作状态***** + GLB_STATUS nStatus; //待命/检测/跟踪/丢失状态信息等 + + //*****目标检测*****(短时航迹点,用于用户指示) + int nAlarmObjCnts; //当前帧告警目标总个数 + ARIDLL_OBJINFO stAlarmObjs[ST_OBJ_NUM]; //检测目标信息数组 + + //*****目标跟踪*****(长时航迹点,第0个为主目标送伺服跟踪) + int nTrackObjCnts; //跟踪目标个数 + ARIDLL_OBJINFO stTrackers[LT_OBJ_NUM]; //跟踪器输出数组 + + // AI跟踪器协同控制指令输出,用于控制端侧NPU程序 + AIT_Command stAI_TkCmd; + + //搜索区域(对空场景为导引区域) + RECT32S rsRecaptureRion; + + //指示轨迹预测模块是否相信当前跟踪点 + BBOOL bPredictJam; + + //调试信息 + ARIDLL_DEBUG_OUTPUT stDebugInfo; + +}ARIDLL_OUTPUT; + +/************************************* +* 函数名称: STD_CreatEOArithHandle() +* 功能描述: 算法句柄创建 +* 创建日期: 2025/6/24 +* 输出参数: 无 +* 返回值: STD_TRACKER_API ArithHandle:算法句柄 +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API ArithHandle STD_CreatEOArithHandle(); + + +/************************************* +* 函数名称: STD_CreatEOArithNamedHandle() +* 功能描述: 指定算法句柄名称的句柄创建,每个句柄创建都有自己独立的配置文件路径,避免设置重复参数,比如句柄名称 + 实际上也可以允许句柄名称重复,不建议这样做,因为不能起到区分作用,参数设置请参考文件 +* 创建日期: 2025/8/8 +* 输入参数: const char * configPath:算法创建相关配置json名称 +* 输出参数: +* 返回值: STD_TRACKER_API ArithHandle:算法句柄 +* 调用关系: +* 其它说明: +*************************************/ +STD_TRACKER_API ArithHandle STD_CreatEOArithNamedHandle(const char* configPath); + + +/************************************* +* 函数名称: STD_DeleteEOArithHandle() +* 功能描述: 释放算法句柄 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArith:算法句柄 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void STD_DeleteEOArithHandle(ArithHandle hArith); + + +/************************************* +* 函数名称: ARIDLL_EOArithInit() +* 功能描述: 执行算法模块初始化,默认为对空凝视场景 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArith:算法句柄 +* 输入参数: int nWidth:输入图像宽度 +* 输入参数: int nHeight:输入图像高度 +* 输入参数: GD_PIXEL_FORMAT_E nPixelType:图像的像素类型 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: GLB_PT_TYPE nPixelType 暂未使用的参数 +*************************************/ +STD_TRACKER_API void ARIDLL_EOArithInit(ArithHandle hArith, int nWidth, int nHeight, GD_PIXEL_FORMAT_E nPixelType); + + +/************************************* +* 函数名称: ARIDLL_EOArithInitWithMode() +* 功能描述: 执行算法模块初始化2 - 带模式的初始化,可指定系统模式和场景模式 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArith:算法句柄 +* 输入参数: int nWidth:输入图像宽度 +* 输入参数: int nHeight:输入图像高度 +* 输入参数: GD_PIXEL_FORMAT_E nPixelType:图像的像素类型 +* 输入参数: GLB_SYS_MODE nSysMode:系统模式 +* 输入参数: GLB_SCEN_MODE nScenMode:场景模式 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: GLB_PT_TYPE nPixelType 暂未使用的参数 +*************************************/ +STD_TRACKER_API void ARIDLL_EOArithInitWithMode(ArithHandle hArith, int nWidth, int nHeight, GD_PIXEL_FORMAT_E nPixelType, + GLB_SYS_MODE nSysMode, GLB_SCEN_MODE nScenMode); + + +/************************************* +* 函数名称: ARIDLL_CreateAITracker() +* 功能描述: 创建AI跟踪器模块 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArith:算法句柄 +* 输入参数: int nWidth:输入图像宽度 +* 输入参数: int nHeight:输入图像高度 +* 输入参数: const char * configPath:AI跟踪参数配置文件名称 +* 输出参数: 无 +* 返回值: STD_TRACKER_API bool +* 调用关系: 无 +* 其它说明: 创建AI跟踪器后,NeoTracker仅在切换为AI跟踪场景模式时才会使用AI跟踪 +*************************************/ +STD_TRACKER_API bool ARIDLL_CreateAITracker(ArithHandle hArith, int nWidth, int nHeight, const char* configPath); + + +/************************************* +* 函数名称: ARIDLL_RunController() +* 功能描述: 目标搜跟流程,算法主处理逻辑 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: GD_VIDEO_FRAME_S img :输入图像 +* 输入参数: ARIDLL_INPUTPARA stInputPara :输入参数 +* 输入参数: ARIDLL_OUTPUT * pstOutput :输出结果 +* 输出参数: 无 +* 返回值: STD_TRACKER_API int :主处理逻辑状态码 +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API int ARIDLL_RunController(ArithHandle hArithSrc, GD_VIDEO_FRAME_S img, ARIDLL_INPUTPARA stInputPara, ARIDLL_OUTPUT* pstOutput); + + +/************************************* +* 函数名称: ARIDLL_SearchFrameTargets() +* 功能描述: 执行单帧小面目标检测算法 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: GD_VIDEO_FRAME_S img :输入图像 +* 输出参数: 无 +* 返回值: STD_TRACKER_API int :单帧小面检测算法个数 +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API int ARIDLL_SearchFrameTargets(ArithHandle hArithSrc, GD_VIDEO_FRAME_S img); + + +/************************************* +* 函数名称: ARIDLL_MergeAITargets() +* 功能描述: 接收外部AI识别目标,在主循环调用前调用以传入外部目标 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: AI_Target * aiDetectArray:输入AI目标检测数组 +* 输入参数: int aiNum:输入AI目标检测个数 +* 输出参数: 无 +* 返回值: STD_TRACKER_API int +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API int ARIDLL_MergeAITargets(ArithHandle hArithSrc, AI_Target* aiDetectArray,int aiNum); + + +/************************************* +* 函数名称: ARIDLL_SendReIDToTargets() +* 功能描述: 接收外部目标reID特征,将特征传入到算法内部 +* 创建日期: 2025/6/25 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: ReIDFeature * pReidFeatures :ReID特征数组 +* 输入参数: int ReIDNums :ReID特征个数 +* 输出参数: 无 +* 返回值: STD_TRACKER_API int +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +//STD_TRACKER_API int ARIDLL_SendReIDToTargets(ArithHandle hArithSrc, ReIDFeature* pReidFeatures, int ReIDNums); + + +/************************************* +* 函数名称: ARIDLL_LockCommand() +* 功能描述: 视场内自适应锁定,智能锁定,有目标锁目标,无目标根据项目需求执行锁定方式,中心点XY一定要传,宽高可选择 +* + 点选吸附 +* + 框选距离管道目标近距离吸附,远距离强锁框选区域 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: int nLockX:锁定点目标图像坐标中心点X +* 输入参数: int nLockY:锁定点目标图像坐标中心点Y +* 输入参数: int nLockW:锁定波门宽度, 目标宽度 +* 输入参数: int nLockH:锁定波门高度, 目标高度 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 下发指令在下一帧转入锁定,宽高送0则算法根据识别结果自动选择波门 +*************************************/ +STD_TRACKER_API void ARIDLL_LockCommand(ArithHandle hArithSrc, int nLockX, int nLockY, int nLockW, int nLockH); + + +/************************************* +* 函数名称: ARIDLL_LockCommand_DefaultSize() +* 功能描述: 框选吸附 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: int nLockX:锁定点目标图像坐标中心点X +* 输入参数: int nLockY:锁定点目标图像坐标中心点Y +* 输入参数: int nLockW:锁定波门宽度, 目标宽度 +* 输入参数: int nLockH:锁定波门高度, 目标高度 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_LockCommand_DefaultSize(ArithHandle hArithSrc, int nLockX, int nLockY, int nLockW, int nLockH); + + +/************************************* +* 函数名称: ARIDLL_LockTargetByID() +* 功能描述: 视场内根据批号锁定,当前帧处理 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: int nBatchID:锁定管道批号 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_LockTargetByID(ArithHandle hArithSrc, int nBatchID); + + +/************************************* +* 函数名称: ARIDLL_GuideLockMultiCommand() +* 功能描述: 视场外引导锁定-支持批量锁定 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: TargetGuide * guideList +* SINT32 ID; //目标批号,传递给锁定后目标(必填,0为无效值,外部指定跟踪目标输出批号) + UBYTE8 bIsCoLocate; //协同定位标记,直接透传到目标(非必须,默认0) + Pole stTargetPole; //目标极坐标(伺服系目标方位、俯仰、目标测距) + PointBLH stTargetPos; //目标GPS坐标(大地系纬经高) + FLOAT32 fGuideAzSpeed; //实际锁定点方位角速度(非必须,默认0) + FLOAT32 fGuidePtSpeed; //实际锁定点俯仰角速度(非必须,默认0) + SINT32 nGuideFocus; //引导时的焦距值(必填) + SINT32 nMaxFocus; //最大焦距值(必填) + SINT32 nLockX; //锁定点当前图像坐标X(图像系X) + SINT32 nLockY; //锁定点当前图像坐标Y(图像系Y) + BBOOL bInFOV; //在视场判断(图像系目标在视场判断) + stTargetPole、stTargetPos、nLockX(nLockY,bInFOV),三选一,当前建议选择stTargetPole +* 输入参数: int num:引导目标个数 +* 输入参数: int nGuideAge:引导生命周期,帧数衡量,如50HZ平台,1s填50,外部设定 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 对空引导在焦距没有拉满的情况下,不要出现是小目标的情况 +*************************************/ +STD_TRACKER_API void ARIDLL_GuideLockMultiCommand(ArithHandle hArithSrc, TargetGuide* guideList, int num, int nGuideAge); + + +/************************************* +* 函数名称: ARIDLL_unLockCommand() +* 功能描述: 解锁 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_unLockCommand(ArithHandle hArithSrc); + + +/************************************* +* 函数名称: ARIDLL_AdjustTrackRect() +* 功能描述: 微调对地主跟踪器跟踪框 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: int dx:X 方向的偏移量,向左为负;向右为正 +* 输入参数: int dy:Y 方向的偏移量,向上为负;向下为正 +* 输入参数: int dw :目标宽度扩大,缩减大小 +* 输入参数: int dh :目标高度扩大,缩减大小 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_AdjustTrackRect(ArithHandle hArithSrc,int dx,int dy,int dw,int dh); + + +/************************************* +* 函数名称: ARIDLL_ReadSetParamFile() +* 功能描述: 读取序列化参数:从文件 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: const char * configFilePath:保存了算法参数的文件名 +* 输出参数: 无 +* 返回值: STD_TRACKER_API bool +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API bool ARIDLL_ReadSetParamFile(ArithHandle hArithSrc, const char* configFilePath); + + +/************************************* +* 函数名称: ARIDLL_ReadSetParamStream() +* 功能描述: 读取序列化参数:从buffer +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: const char * configsstream:保存了算法参数的文件流 +* 输出参数: 无 +* 返回值: STD_TRACKER_API bool : 设置成功标记 +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API bool ARIDLL_ReadSetParamStream(ArithHandle hArithSrc,const char* configsstream); + + +/************************************* +* 函数名称: ARIDLL_SetSysMode() +* 功能描述: 设置外部系统工作模式 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: GLB_SYS_MODE nSysMode:系统模式 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_SetSysMode(ArithHandle hArithSrc, GLB_SYS_MODE nSysMode); + + +/************************************* +* 函数名称: ARIDLL_SetScenMode() +* 功能描述: 设置工作场景,对空,对地,对海等 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: GLB_SCEN_MODE nScenMode:工作场景 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_SetScenMode(ArithHandle hArithSrc, GLB_SCEN_MODE nScenMode); + + +/************************************* +* 函数名称: ARIDLL_SetForceMemTrack() +* 功能描述: 设置强制记忆跟踪 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: int nMemCnt:强制进记忆帧数,超出强制记忆帧数则退出记忆 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_SetForceMemTrack(ArithHandle hArithSrc, int nMemCnt); + + +/************************************* +* 函数名称: ARIDLL_SetSkyLineCaliPoints() +* 功能描述: 接收外部传入的天地线标定点,最多支持360个点 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: ANGLE32F * SkyLinePoints:天地线标定点数组首地址 +* 输入参数: int N:天地线标定点个数,最多360个 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_SetSkyLineCaliPoints(ArithHandle hArithSrc, ANGLE32F* SkyLinePoints, int N); + + +/************************************* +* 函数名称: ARIDLL_ExportParamFile() +* 功能描述: 输出算法配置文件 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: const char * configFilePath:输出算法配置的文件名 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_ExportParamFile(ArithHandle hArithSrc, const char* configFilePath); + + + +/************************************* +* 函数名称: ARIDLL_ExportOSDJson() +* 功能描述: 输出算法调试json字符串流,算法OSD叠加功能 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: char * Buffer:接收OSD流的缓冲区 +* 输入参数: int bufferSize:接收OSD流的缓冲区的大小 +* 输出参数: 无 +* 返回值: STD_TRACKER_API int:成功返回0 +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API int ARIDLL_ExportOSDJson(ArithHandle hArithSrc, char* Buffer, int bufferSize); + + +/************************************* +* 函数名称: ARIDLL_SetSOTRect() +* 功能描述: 按照固定大小调整波门 +* 创建日期: 2025/6/24 +* 输入参数: ArithHandle hArithSrc +* 输入参数: int objX:新的锁定X中心 +* 输入参数: int objY:新的锁定Y中心 +* 输入参数: int objW:新的锁定目标宽度 +* 输入参数: int objH:新的锁定目标高度 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_SetSOTRect(ArithHandle hArithSrc, int objX, int objY, int objW, int objH); + + +/************************************* +* 函数名称: ARIDLL_SetPredictInfo() +* 功能描述: 外部预测的主跟踪位置点 +* 创建日期: 2025/7/7 +* 输入参数: ArithHandle hArithSrc:算法句柄 +* 输入参数: int X1:当前帧预测x位置 +* 输入参数: int Y1:当前帧预测y位置 +* 输入参数: int X2:8s后预测的x位置 +* 输入参数: int Y2:8s后预测的y位置 +* 输出参数: 无 +* 返回值: STD_TRACKER_API void +* 调用关系: 无 +* 其它说明: 无 +*************************************/ +STD_TRACKER_API void ARIDLL_SetPredictInfo(ArithHandle hArithSrc, int X1, int Y1, int X2, int Y2); + + +#ifdef __cplusplus +} +#endif + + + + +#endif diff --git a/tests/S732/S732.h b/tests/cpp/S732/S732.h similarity index 100% rename from tests/S732/S732.h rename to tests/cpp/S732/S732.h diff --git a/tests/S732/decodedata.cpp b/tests/cpp/S732/decodedata.cpp similarity index 100% rename from tests/S732/decodedata.cpp rename to tests/cpp/S732/decodedata.cpp diff --git a/tests/S732/decodedata.h b/tests/cpp/S732/decodedata.h similarity index 100% rename from tests/S732/decodedata.h rename to tests/cpp/S732/decodedata.h diff --git a/tests/S732/hi_type.h b/tests/cpp/S732/hi_type.h similarity index 100% rename from tests/S732/hi_type.h rename to tests/cpp/S732/hi_type.h diff --git a/tests/S732/stitch_S732.cpp b/tests/cpp/S732/stitch_S732.cpp similarity index 100% rename from tests/S732/stitch_S732.cpp rename to tests/cpp/S732/stitch_S732.cpp diff --git a/tests/S732/stitch_udp.cpp b/tests/cpp/S732/stitch_udp.cpp similarity index 100% rename from tests/S732/stitch_udp.cpp rename to tests/cpp/S732/stitch_udp.cpp diff --git a/tests/cpp/Test_GeoCorrect.cpp b/tests/cpp/Test_GeoCorrect.cpp new file mode 100644 index 00000000..e972b4ac --- /dev/null +++ b/tests/cpp/Test_GeoCorrect.cpp @@ -0,0 +1,47 @@ +#include "API_GeoCorrect.h" +#include "StitchStruct.h" +#include "opencv2/opencv.hpp" + + +int main(int argc, char** argv) +{ + printf("Test_GeoCorrect\n"); + API_GeoCorrect* geoCorrect = API_GeoCorrect::Create(); + + FrameInfo info; + info.camInfo.fPixelSize = 4; + info.camInfo.nFocus = 48; + + + info.servoInfo.fServoAz = 1; + info.servoInfo.fServoPt = -78; + + info.nWidth = 1920; + info.nHeight = 1080; + + info.craft.stPos.B = 30.0; + info.craft.stPos.L = 114.0; + info.craft.stPos.H = 1000.0; + + info.craft.stAtt.fRoll = 1.0; + info.craft.stAtt.fPitch = -2.0; + info.craft.stAtt.fYaw = 134.0; + + info.nEvHeight = 1000; + + + + geoCorrect->Init(info); + + info.craft.stPos.B += 0.1; + info.craft.stPos.L += 0.1; + info.craft.stPos.H += 100; + + + + + + + geoCorrect->Correct(cv::Mat(), info); + return 0; +} \ No newline at end of file diff --git a/tests/feaStitchTest.cpp b/tests/cpp/feaStitchTest.cpp similarity index 100% rename from tests/feaStitchTest.cpp rename to tests/cpp/feaStitchTest.cpp diff --git a/tests/main.cpp b/tests/cpp/main.cpp similarity index 100% rename from tests/main.cpp rename to tests/cpp/main.cpp diff --git a/tests/stitch_Genaral.cpp b/tests/cpp/stitch_Genaral.cpp similarity index 100% rename from tests/stitch_Genaral.cpp rename to tests/cpp/stitch_Genaral.cpp diff --git a/tests/utils.cpp b/tests/cpp/utils.cpp similarity index 100% rename from tests/utils.cpp rename to tests/cpp/utils.cpp diff --git a/tests/utils.h b/tests/cpp/utils.h similarity index 100% rename from tests/utils.h rename to tests/cpp/utils.h diff --git a/tests/python/1.py b/tests/python/1.py new file mode 100644 index 00000000..ef387e90 --- /dev/null +++ b/tests/python/1.py @@ -0,0 +1,78 @@ +import os +import sys +import ctypes +from pathlib import Path +import numpy as np + +# 获取项目根目录(当前文件所在目录的父目录) +project_root = Path(__file__).parent.parent.parent +bin_dir = project_root / "Bin" + +# 添加Bin目录到Python模块搜索路径 +if str(bin_dir) not in sys.path: + sys.path.insert(0, str(bin_dir)) + +# 预加载依赖库,确保动态链接器能找到它们 +lib_guide_stitch = bin_dir / "libGuideStitch.so" +if lib_guide_stitch.exists(): + try: + ctypes.CDLL(str(lib_guide_stitch), mode=ctypes.RTLD_GLOBAL) + except Exception as e: + print(f"警告: 预加载libGuideStitch.so失败: {e}") + +# 导入模块 +from UStitcher import API_UnderStitch, FrameInfo + +import cv2 + + +frame_info = FrameInfo() + + + +frame_info.nFrmID = 1 +frame_info.craft.stPos.B = 39 +frame_info.craft.stPos.L = 120 +frame_info.craft.stPos.H = 1000 +frame_info.camInfo.nFocus = 40 +frame_info.camInfo.fPixelSize = 12 + +frame_info.servoInfo.fServoAz = 90 +frame_info.servoInfo.fServoPt = -45 + +frame_info.nEvHeight = 1200 + +frame_info.nWidth = 1280 +frame_info.nHeight = 1024 + + + +stitcher = API_UnderStitch.Create() + + +# 先初始化(设置原点) +pan_info = stitcher.Init(frame_info) +print(f"初始化成功,全景图尺寸: {pan_info.m_pan_width} x {pan_info.m_pan_height}") + +def warpPointWithH(H, pt): + wp = H @ np.array([pt[0],pt[1],1]).T + return wp / wp[2] + +for i in range(100): + frame_info.nFrmID = i + frame_info.craft.stPos.B = 39 + frame_info.craft.stPos.L = 120 + frame_info.craft.stPos.H = 1000 + frame_info.camInfo.nFocus = 40 + frame_info.camInfo.fPixelSize = 12 + + frame_info.servoInfo.fServoAz = 90 + frame_info.servoInfo.fServoPt = -45 + + H = stitcher.getHomography(frame_info) + print(f"单应性矩阵 H:\n{H}") + + wp = warpPointWithH(H, np.array([100,111])) + print(f"物方坐标:\n{wp}") + +print("done") diff --git a/tests/python/ProcDJ.py b/tests/python/ProcDJ.py new file mode 100644 index 00000000..d25a732a --- /dev/null +++ b/tests/python/ProcDJ.py @@ -0,0 +1,427 @@ +import os +import sys +import ctypes +import re +from pathlib import Path + +import numpy as np +from dataclasses import dataclass +from typing import List, Tuple + + +sys.path.append("/media/wang/WORK/wangchongwu_gitea_2023/StitchVideo/Bin") + + +# 导入模块 +from UStitcher import API_UnderStitch, FrameInfo, UPanInfo, UPanConfig + + +import cv2 + +@dataclass +class FrameData: + """帧数据""" + frame_number: int = 0 + time_range: str = "" + frame_cnt: int = 0 + timestamp: str = "" + focal_len: float = 0.0 + dzoom_ratio: float = 0.0 + latitude: float = 0.0 + longitude: float = 0.0 + rel_alt: float = 0.0 + abs_alt: float = 0.0 + gb_yaw: float = 0.0 + gb_pitch: float = 0.0 + gb_roll: float = 0.0 + real_focal_mm: float = 0.0 + pixel_size_um: float = 0.0 + + +def extract_value_from_brackets(line: str, key: str) -> float: + """从方括号格式中提取值,例如 [key: value] 或 [key: value other]""" + # 匹配模式1:单独的方括号 [key: value] + pattern1 = rf'\[{re.escape(key)}:\s*([^\]]+?)\]' + match = re.search(pattern1, line) + if match: + value_str = match.group(1).strip() + # 提取第一个数值(可能后面有其他数据) + num_match = re.search(r'-?\d+\.?\d*', value_str) + if num_match: + try: + return float(num_match.group()) + except ValueError: + pass + + # 匹配模式2:在方括号内的键值对,例如 [rel_alt: 300.030 abs_alt: 314.064] + # 匹配 [xxx key: value xxx] 格式 + pattern2 = rf'\[[^\]]*?{re.escape(key)}:\s*([^\s\]]+?)(?:\s|])' + match = re.search(pattern2, line) + if match: + value_str = match.group(1).strip() + try: + return float(value_str) + except ValueError: + pass + + return None # 使用None表示未找到 + +def extract_value_after(line: str, key: str) -> float: + """从字符串中提取指定键后面的值(兼容方括号和非方括号格式)""" + # 先尝试方括号格式 + value = extract_value_from_brackets(line, key) + if value is not None: + return value + + # 如果没有方括号,尝试直接查找格式 + pos = line.find(key) + if pos == -1: + return 0.0 + pos += len(key) + # 跳过可能的空格和冒号 + while pos < len(line) and line[pos] in ' :': + pos += 1 + # 找到值的结束位置(空格、逗号、换行等) + end = pos + while end < len(line) and line[end] not in ' ,\n\r\t]': + end += 1 + if end == pos: + return 0.0 + try: + return float(line[pos:end]) + except ValueError: + return 0.0 + + +def infer_camera_params_h30(frame: FrameData, filename: str) -> None: + """获取真实焦距和像元尺寸 - H30版本""" + if "_W" in filename: + frame.real_focal_mm = 6.72 + frame.pixel_size_um = 2.4 + elif "_Z" in filename: + frame.real_focal_mm = 6.72 * 2 + frame.pixel_size_um = 2.4 + elif "_T" in filename: + frame.real_focal_mm = 24.0 + frame.pixel_size_um = 12.0 + # else: 保持默认值 + + +def parse_dji_srt(filename: str) -> List[FrameData]: + """解析DJI SRT文件""" + frames = [] + + try: + with open(filename, 'r', encoding='utf-8') as file: + lines = [] + for line in file: + line_stripped = line.rstrip('\n\r') + lines.append(line_stripped) + + # 每5行一组(包括空行) + if len(lines) >= 5: + frame = FrameData() + # Line 1: Frame number + try: + frame.frame_number = int(lines[0]) + except ValueError: + lines = [] + continue + + # Line 2: Time range + frame.time_range = lines[1] + + # Line 3: FrameCnt and timestamp + parts = lines[2].split() + if len(parts) >= 3: + try: + frame.frame_cnt = int(parts[1]) + frame.timestamp = parts[2] + except (ValueError, IndexError): + pass + + # Line 4: Metadata + meta = lines[3] + # 使用改进的解析函数,支持方括号格式 + frame.focal_len = extract_value_after(meta, "focal_len") + frame.dzoom_ratio = extract_value_after(meta, "dzoom_ratio") + frame.latitude = extract_value_after(meta, "latitude") + frame.longitude = extract_value_after(meta, "longitude") + frame.rel_alt = extract_value_after(meta, "rel_alt") + frame.abs_alt = extract_value_after(meta, "abs_alt") + frame.gb_yaw = extract_value_after(meta, "gb_yaw") + frame.gb_pitch = extract_value_after(meta, "gb_pitch") + frame.gb_roll = extract_value_after(meta, "gb_roll") + + # 调试:打印前几帧的meta信息和解析结果 + if frame.frame_number <= 3: + print(f"Frame {frame.frame_number} meta: {meta}") + print(f" 解析结果: lat={frame.latitude}, lon={frame.longitude}, " + f"alt={frame.abs_alt}, yaw={frame.gb_yaw}, pitch={frame.gb_pitch}, roll={frame.gb_roll}") + + infer_camera_params_h30(frame, filename) + frames.append(frame) + + # 清空lines,准备下一组 + lines = [] + + except Exception as e: + print(f"错误: 无法打开文件 {filename}: {e}") + return frames + + return frames + + +def proc_dj_video(video_path_list: List[str], srt_path_list: List[str], + cache_dir: str = "./cache", output_dir: str = "./google_tiles"): + """处理DJI视频""" + # 创建拼接器 + stitcher = API_UnderStitch.Create(cache_dir) + stitcher.SetOutput("DJI", output_dir) + + # 打开第一个视频获取属性 + cap = cv2.VideoCapture(video_path_list[0]) + if not cap.isOpened(): + print(f"错误: 无法打开视频 {video_path_list[0]}") + return + + fps = cap.get(cv2.CAP_PROP_FPS) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + # 降采样 + n_down_sample = 1 + if width > 3000: + n_down_sample = 2 + + # 解析SRT文件 + srt_init = parse_dji_srt(srt_path_list[0]) + if not srt_init: + print("错误: SRT文件解析失败") + cap.release() + return + + n_start = min(2000, len(srt_init) - 1) + + # 初始化FrameInfo + frame_info = FrameInfo() + frame_info.nFrmID = n_start + frame_info.camInfo.nFocus = srt_init[n_start].real_focal_mm + frame_info.camInfo.fPixelSize = srt_init[n_start].pixel_size_um * n_down_sample + + frame_info.craft.stAtt.fYaw = srt_init[n_start].gb_yaw + frame_info.craft.stAtt.fPitch = 0.0 + frame_info.craft.stAtt.fRoll = srt_init[n_start].gb_roll + + frame_info.craft.stPos.B = srt_init[n_start].latitude + frame_info.craft.stPos.L = srt_init[n_start].longitude + frame_info.craft.stPos.H = srt_init[n_start].abs_alt + + frame_info.nEvHeight = int(srt_init[n_start].rel_alt) + + frame_info.servoInfo.fServoAz = 0.0 + frame_info.servoInfo.fServoPt = srt_init[n_start].gb_pitch + + frame_info.nWidth = width // n_down_sample + frame_info.nHeight = height // n_down_sample + + # 初始化拼接器 + print("初始化拼接器...") + pan_info = stitcher.Init(frame_info) + print(f"初始化成功,全景图尺寸: {pan_info.m_pan_width} x {pan_info.m_pan_height}") + + # 设置配置 + config = UPanConfig() + config.bOutFrameTile = False + config.bOutGoogleTile = True + config.bUseBA = True + stitcher.SetConfig(config) + + # 获取全景图 + mat_pan = stitcher.ExportPanMat() + + cap.release() + + # 创建输出视频 + output_width = mat_pan.shape[1] // 8 + output_height = mat_pan.shape[0] // 8 + output_path = "DJ_stitchVL.mp4" + + # 尝试多个编码器,按优先级顺序 + codecs_to_try = [ + ('H264', 'H264'), + ('mp4v', 'mp4v'), + ('XVID', 'XVID'), + ('MJPG', 'MJPG'), + ] + + output = None + used_codec = None + for codec_name, fourcc_str in codecs_to_try: + try: + fourcc = cv2.VideoWriter_fourcc(*fourcc_str) + output = cv2.VideoWriter(output_path, fourcc, 5.0, (output_width, output_height), True) + if output.isOpened(): + used_codec = codec_name + print(f"成功创建输出视频,使用编码器: {codec_name}") + break + else: + output.release() + output = None + except Exception as e: + if output: + output.release() + output = None + print(f"尝试编码器 {codec_name} 失败: {e}") + continue + + if output is None or not output.isOpened(): + print(f"错误: 无法创建输出视频 {output_path}") + print(f"尝试了所有编码器: {[c[0] for c in codecs_to_try]}") + print(f"输出尺寸: {output_width}x{output_height}") + print("提示: 可能需要安装视频编码库(如 ffmpeg)或使用其他编码器") + return + + # 处理每个视频 + for vid_idx, video_path in enumerate(video_path_list): + print(f"处理 {video_path}") + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print(f"错误: 无法打开视频 {video_path}") + continue + + srt = parse_dji_srt(srt_path_list[vid_idx]) + if not srt: + print(f"错误: 无法解析SRT文件 {srt_path_list[vid_idx]}") + cap.release() + continue + + cap.set(cv2.CAP_PROP_POS_FRAMES, n_start) + frm_id = n_start + + while True: + ret, mat = cap.read() + if not ret or mat is None or mat.size == 0: + print("视频结束") + cap.release() + break + + frm_id += 1 + + if frm_id < n_start: + continue + + if frm_id >= len(srt): + print(f"警告: 帧ID {frm_id} 超出SRT数据范围") + break + + # 降采样 + mat_ds2 = cv2.resize(mat, (width // n_down_sample, height // n_down_sample)) + + # 更新FrameInfo + frame_info = FrameInfo() + frame_info.nFrmID = frm_id + frame_info.camInfo.nFocus = srt[frm_id].real_focal_mm + frame_info.camInfo.fPixelSize = srt[frm_id].pixel_size_um * n_down_sample + + frame_info.craft.stAtt.fYaw = srt[frm_id].gb_yaw + frame_info.craft.stAtt.fPitch = 0.0 + frame_info.craft.stAtt.fRoll = srt[frm_id].gb_roll + + frame_info.craft.stPos.B = srt[frm_id].latitude + frame_info.craft.stPos.L = srt[frm_id].longitude + frame_info.craft.stPos.H = srt[frm_id].abs_alt + + frame_info.nEvHeight = int(srt[frm_id].abs_alt) + + frame_info.servoInfo.fServoAz = 0.0 + frame_info.servoInfo.fServoPt = srt[frm_id].gb_pitch + + frame_info.nWidth = mat_ds2.shape[1] + frame_info.nHeight = mat_ds2.shape[0] + + # 每10帧处理一次 + if frm_id % 10 != 0: + continue + + progress = float(frm_id) / frame_count * 100 + + print(f"{progress:.1f}% B={frame_info.craft.stPos.B:.6f} L={frame_info.craft.stPos.L:.6f} " + f"H={frame_info.craft.stPos.H:.2f} Yaw={frame_info.craft.stAtt.fYaw:.2f} " + f"Pitch={frame_info.craft.stAtt.fPitch:.2f} Roll={frame_info.craft.stAtt.fRoll:.2f} " + f"ServoAz={frame_info.servoInfo.fServoAz:.2f} ServoPt={frame_info.servoInfo.fServoPt:.2f}") + + # 运行拼接 + import time + start_time = time.time() + stitcher.Run(mat_ds2, frame_info) + cost_time = time.time() - start_time + print(f"处理时间: {cost_time:.3f}秒") + + # 获取全景图并写入输出视频 + mat_pan = stitcher.ExportPanMat() + if mat_pan is not None and mat_pan.size > 0: + # 转换为BGR + if len(mat_pan.shape) == 3 and mat_pan.shape[2] == 4: + pan_rgb = cv2.cvtColor(mat_pan, cv2.COLOR_BGRA2BGR) + else: + pan_rgb = mat_pan + + # 降采样 + pan_rgb_ds = cv2.resize(pan_rgb, (output_width, output_height)) + + # 写入视频 + output.write(pan_rgb_ds) + + # 显示(可选) + cv2.imshow("pan_rgb", pan_rgb_ds) + if cv2.waitKey(1) & 0xFF == 27: # ESC键退出 + break + + # 优化并输出当前全景图 + print("优化并输出当前全景图...") + import time + start_time = time.time() + stitcher.OptAndOutCurrPan() + stitcher.Stop() + opt_time = time.time() - start_time + print(f"优化时间: {opt_time:.3f}秒") + + # 最终输出 + mat_pan = stitcher.ExportPanMat() + if mat_pan is not None and mat_pan.size > 0: + if len(mat_pan.shape) == 3 and mat_pan.shape[2] == 4: + pan_rgb = cv2.cvtColor(mat_pan, cv2.COLOR_BGRA2BGR) + else: + pan_rgb = mat_pan + pan_rgb_ds = cv2.resize(pan_rgb, (output_width, output_height)) + output.write(pan_rgb_ds) + cv2.imshow("pan_rgb", pan_rgb_ds) + cv2.waitKey(0) + + output.release() + cv2.destroyAllWindows() + print("处理完成") + + +def main(): + """主函数""" + video_path_list = [] + srt_path_list = [] + + # 修改为你的数据路径 + folder = "/media/wang/data/K2D_data/" + + video_path_list.append(folder + "DJI_20250418153043_0006_W.MP4") + srt_path_list.append(folder + "DJI_20250418153043_0006_W.srt") + + # video_path_list.append(folder + "DJI_20250418153043_0006_W.MP4") + # srt_path_list.append(folder + "DJI_20250418153043_0006_W.srt") + + proc_dj_video(video_path_list, srt_path_list) + + +if __name__ == "__main__": + main() + diff --git a/tests/python/UStitcher.pyi b/tests/python/UStitcher.pyi new file mode 100644 index 00000000..c9258651 --- /dev/null +++ b/tests/python/UStitcher.pyi @@ -0,0 +1,158 @@ +from __future__ import annotations + +from typing import Any, TypeAlias + +from numpy.typing import NDArray + +ImageLike: TypeAlias = NDArray[Any] + + +class PointBLH: + """地理坐标系(单位:度).""" + + B: float + L: float + H: float + + def __init__(self) -> None: ... + + +class EulerRPY: + """RPY 姿态角(单位:度).""" + + fRoll: float + fPitch: float + fYaw: float + + def __init__(self) -> None: ... + + +class AirCraftInfo: + """载体信息.""" + + nPlaneID: int + stPos: PointBLH + stAtt: EulerRPY + + def __init__(self) -> None: ... + + +class CamInfo: + """相机信息.""" + + nFocus: int + fPixelSize: float + unVideoType: int + dCamx: float + dCamy: float + fAglReso: float + + def __init__(self) -> None: ... + + +class ServoInfo: + """伺服状态.""" + + fServoAz: float + fServoPt: float + fServoAzSpeed: float + fServoPtSpeed: float + + def __init__(self) -> None: ... + + +class FrameInfo: + """帧内外方位元素.""" + + nFrmID: int + craft: AirCraftInfo + camInfo: CamInfo + servoInfo: ServoInfo + nEvHeight: int + nWidth: int + nHeight: int + + def __init__(self) -> None: ... + + +class UPanInfo: + """下视全景图配置.""" + + m_pan_width: int + m_pan_height: int + scale: float + map_shiftX: float + map_shiftY: float + + def __init__(self) -> None: ... + + +class UPanConfig: + """下视拼接参数控制.""" + + bUseBA: bool + bOutFrameTile: bool + bOutGoogleTile: bool + + def __init__(self) -> None: ... + + +class API_UnderStitch: + """视频帧下视地理拼接.""" + + def Init(self, info: FrameInfo) -> UPanInfo: + """初始化拼接,返回全景图配置.""" + ... + + def SetOutput(self, name: str, outdir: str) -> None: + """配置输出标识和目录.""" + ... + + def Run(self, img: ImageLike, para: FrameInfo) -> int: + """运行拼接流程(cv::Mat/numpy.ndarray).""" + ... + + def Stop(self) -> None: + """中止拼接流程.""" + ... + + def SetConfig(self, config: UPanConfig) -> None: + """更新运行参数.""" + ... + + def OptAndOutCurrPan(self) -> int: + """立即优化并输出当前全景图.""" + ... + + def ExportPanMat(self) -> ImageLike: + """获取当前全景图像.""" + ... + + def getHomography(self, info: FrameInfo) -> ImageLike: + """根据帧信息返回单应性矩阵.""" + ... + + @staticmethod + def Create(cachedir: str = "./cache") -> API_UnderStitch: + """创建 API_UnderStitch 实例.""" + ... + + @staticmethod + def Destroy(obj: API_UnderStitch) -> None: + """销毁 API_UnderStitch 实例.""" + ... + + +__all__ = [ + "API_UnderStitch", + "AirCraftInfo", + "CamInfo", + "EulerRPY", + "FrameInfo", + "ImageLike", + "PointBLH", + "ServoInfo", + "UPanConfig", + "UPanInfo", +] + diff --git a/tests/python/sim_scan.py b/tests/python/sim_scan.py new file mode 100644 index 00000000..8315ad86 --- /dev/null +++ b/tests/python/sim_scan.py @@ -0,0 +1,565 @@ +#!/usr/bin/env python3 +""" +Downward scan simulator that visualises the projected camera footprint +on top of a 2D map background. Camera, flight and scan parameters are +interactive so engineers can inspect coverage overlap in real time. +""" + +from __future__ import annotations + +import argparse +import ctypes +import math +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, List, Optional, Tuple + + +import numpy as np + +# ----------------------------------------------------------------------------- +# Environment bootstrap (reuse logic from tests/python/1.py) +# ----------------------------------------------------------------------------- +PROJECT_ROOT = Path(__file__).resolve().parents[2] +BIN_DIR = PROJECT_ROOT / "Bin" + +if str(BIN_DIR) not in sys.path: + sys.path.insert(0, str(BIN_DIR)) + +LIB_GUIDE = BIN_DIR / "libGuideStitch.so" +if LIB_GUIDE.exists(): + try: + ctypes.CDLL(str(LIB_GUIDE), mode=ctypes.RTLD_GLOBAL) + except OSError as exc: + print(f"[WARN] Failed to preload {LIB_GUIDE.name}: {exc}") + +from UStitcher import API_UnderStitch, FrameInfo, UPanConfig +import cv2 + +# ----------------------------------------------------------------------------- +# Utility dataclasses +# ----------------------------------------------------------------------------- +@dataclass +class CameraProfile: + width: int = 1920 + height: int = 1080 + focus_mm: float = 40.0 + pixel_size_um: float = 4 + + +@dataclass +class FlightProfile: + start_lat: float = 39.0 + start_lon: float = 120.0 + altitude_m: float = 1000.0 + speed_mps: float = 70.0 + heading_deg: float = 90.0 # 0 -> east, 90 -> north + + +@dataclass +class ScannerProfile: + base_az_deg: float = 90.0 + base_pitch_deg: float = -45.0 + az_min_deg: float = 60.0 + az_max_deg: float = 120.0 + pitch_min_deg: float = -80.0 + pitch_max_deg: float = -10.0 + az_max_speed_degps: float = 25.0 + az_max_acc_degps2: float = 60.0 + pitch_max_speed_degps: float = 20.0 + pitch_max_acc_degps2: float = 50.0 + + +@dataclass +class MapView: + background: Optional[Path] + meters_per_pixel: float = 2.0 + trail_length: int = 200 + + +@dataclass +class SimulationOptions: + duration_s: float = 60.0 + frame_rate: float = 25.0 + overlay_alpha: float = 0.25 + display_scale: float = 0.03 # 直接作用在全景尺度,控制画布大小 + + +@dataclass +class AxisState: + angle: float + velocity: float + phase: float = 0.0 + + +@dataclass +class AxisProfile: + low: float + high: float + mid: float + amplitude: float + omega: float + + +# ----------------------------------------------------------------------------- +# Helper utilities +# ----------------------------------------------------------------------------- +def clamp(value: float, low: float, high: float) -> float: + if low > high: + low, high = high, low + return max(low, min(high, value)) + + +def make_aircraft_polygon(center: np.ndarray, heading_deg: float, size: float = 18.0) -> np.ndarray: + """Construct a simple aircraft-shaped polygon oriented by heading.""" + base = np.array( + [ + [size, 0.0], # nose + [-0.4 * size, 0.35 * size], + [-0.1 * size, 0.0], + [-0.4 * size, -0.35 * size], + ], + dtype=np.float32, + ) + rad = math.radians(heading_deg) + rot = np.array( + [ + [math.cos(rad), -math.sin(rad)], + [math.sin(rad), math.cos(rad)], + ], + dtype=np.float32, + ) + pts = base @ rot.T + pts[:, 1] *= -1 # screen Y axis points downward + pts += center.astype(np.float32) + return pts + + +# ----------------------------------------------------------------------------- +# Map drawing helper +# ----------------------------------------------------------------------------- +def _blank_canvas(height: int = 900, width: int = 900) -> np.ndarray: + canvas = np.full((height, width, 3), 75, dtype=np.uint8) + step = 100 + for x in range(0, width, step): + cv2.line(canvas, (x, 0), (x, height - 1), (45, 45, 45), 1) + for y in range(0, height, step): + cv2.line(canvas, (0, y), (width - 1, y), (45, 45, 45), 1) + return canvas + + +class MapCanvas: + def __init__(self, view: MapView): + self.view = view + if view.background and view.background.exists(): + image = cv2.imread(str(view.background), cv2.IMREAD_COLOR) + if image is None: + print(f"[WARN] Failed to load map '{view.background}', using blank canvas.") + image = _blank_canvas() + else: + image = _blank_canvas() + self.base = image + self.origin = np.array([image.shape[1] // 2, image.shape[0] // 2], dtype=np.float32) + self.history: List[np.ndarray] = [] + + def geo_to_px(self, pts_en: np.ndarray) -> np.ndarray: + """Convert local east/north meters to pixel coordinates.""" + scale = 1.0 / self.view.meters_per_pixel + px = np.empty_like(pts_en) + px[:, 0] = self.origin[0] + pts_en[:, 0] * scale + px[:, 1] = self.origin[1] - pts_en[:, 1] * scale + return px + + def push_trail(self, polygon_px: np.ndarray) -> None: + self.history.append(polygon_px.astype(np.int32)) + if len(self.history) > self.view.trail_length: + self.history = self.history[-self.view.trail_length :] + + def draw(self, live_poly: np.ndarray, aircraft_px: Tuple[int, int], info: Iterable[str]) -> np.ndarray: + frame = self.base.copy() + overlay = frame.copy() + for idx, poly in enumerate(self.history): + alpha = max(0.05, 1.0 - (len(self.history) - idx) / len(self.history)) + cv2.fillPoly(overlay, [poly], (30, 90, 180), lineType=cv2.LINE_AA) + cv2.addWeighted(overlay, alpha * 0.2, frame, 1 - alpha * 0.2, 0, frame) + cv2.polylines(frame, [live_poly.astype(np.int32)], True, (0, 255, 0), 2, cv2.LINE_AA) + cv2.circle(frame, aircraft_px, 6, (0, 0, 255), -1, cv2.LINE_AA) + + y = 20 + for line in info: + cv2.putText(frame, line, (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA) + y += 20 + return frame + + +# ----------------------------------------------------------------------------- +# Simulation core +# ----------------------------------------------------------------------------- +def warpPointWithH(H: np.ndarray, pt: np.ndarray) -> np.ndarray: + """使用H矩阵投影点坐标(参考1.py的实现)""" + wp = H @ np.array([pt[0], pt[1], 1.0]) + return wp / wp[2] + + +class UnderStitchSimulator: + def __init__( + self, + camera: CameraProfile, + flight: FlightProfile, + scanner: ScannerProfile, + map_view: MapView, + options: SimulationOptions, + ): + self.camera = camera + self.flight = flight + self.scanner = scanner + self.map = MapCanvas(map_view) + self.options = options + + if options.display_scale <= 0: + raise ValueError("display_scale must be positive.") + + self.time_step = 1.0 / options.frame_rate + self.total_frames = int(options.duration_s * options.frame_rate) + self.stitcher = API_UnderStitch.Create(str((PROJECT_ROOT / "cache").resolve())) + self.frame_info = FrameInfo() + self._init_frame_info() + + # 初始化拼接器并获取全景图信息 + self.pan_info = self.stitcher.Init(self.frame_info) + print(f"[INFO] 初始化成功,全景图尺寸: {self.pan_info.m_pan_width} x {self.pan_info.m_pan_height}") + + # 根据 display_scale 缩放全景参数,减小绘制画布 + scale = self.options.display_scale + self.pan_info.scale = self.pan_info.scale * 0.5 + self.scaled_pan_width = max(1, int(self.pan_info.m_pan_width * scale)) + self.scaled_pan_height = max(1, int(self.pan_info.m_pan_height * scale)) + self.scaled_pan_scale = self.pan_info.scale * scale + self.scaled_shift_x = self.pan_info.map_shiftX * scale + self.scaled_shift_y = self.pan_info.map_shiftY * scale + + # 构建从地理坐标到全景图的变换矩阵(参考Arith_UnderStitch.cpp的getAffineFromGeo2Pan) + self.H_geo2pan = np.array([ + [self.scaled_pan_scale, 0, self.scaled_shift_x], + [0, -self.scaled_pan_scale, self.scaled_pan_height + self.scaled_shift_y], + [0, 0, 1] + ], dtype=np.float64) + + # 初始化伺服扫描状态(正弦扇扫:边界速度最小、中心最大) + self.az_profile = self._build_axis_profile( + self.scanner.az_min_deg, + self.scanner.az_max_deg, + self.scanner.az_max_speed_degps, + self.scanner.az_max_acc_degps2, + ) + self.pitch_profile = self._build_axis_profile( + self.scanner.pitch_min_deg, + self.scanner.pitch_max_deg, + self.scanner.pitch_max_speed_degps, + self.scanner.pitch_max_acc_degps2, + ) + self.az_state = self._init_axis_state(self.az_profile, self.scanner.base_az_deg) + self.pitch_state = self._init_axis_state(self.pitch_profile, self.scanner.base_pitch_deg) + + self.frame_info.servoInfo.fServoAz = self.az_state.angle + self.frame_info.servoInfo.fServoPt = self.pitch_state.angle + + def _build_axis_profile( + self, + min_deg: float, + max_deg: float, + max_speed: float, + max_acc: float, + ) -> AxisProfile: + low, high = (min_deg, max_deg) if min_deg <= max_deg else (max_deg, min_deg) + mid = (low + high) * 0.5 + amplitude = (high - low) * 0.5 + if amplitude <= 1e-6: + return AxisProfile(low, high, mid, 0.0, 0.0) + + omega_candidates = [] + if max_speed > 0: + omega_candidates.append(max_speed / amplitude) + if max_acc > 0: + omega_candidates.append(math.sqrt(max_acc / amplitude)) + + omega = min(omega_candidates) if omega_candidates else 0.0 + return AxisProfile(low, high, mid, amplitude, omega) + + def _init_axis_state(self, profile: AxisProfile, base_angle: float) -> AxisState: + angle = clamp(base_angle, profile.low, profile.high) + if profile.amplitude <= 1e-6 or profile.omega <= 0.0: + return AxisState(angle=angle, velocity=0.0, phase=0.0) + + rel = clamp((angle - profile.mid) / profile.amplitude, -1.0, 1.0) + phase = math.asin(rel) + velocity = profile.amplitude * profile.omega * math.cos(phase) + return AxisState(angle=angle, velocity=velocity, phase=phase) + + def _init_frame_info(self) -> None: + fi = self.frame_info + fi.camInfo.nFocus = int(self.camera.focus_mm) + fi.camInfo.fPixelSize = float(self.camera.pixel_size_um) + fi.nWidth = self.camera.width + fi.nHeight = self.camera.height + fi.nEvHeight = int(self.flight.altitude_m) + fi.craft.stPos.B = self.flight.start_lat + fi.craft.stPos.L = self.flight.start_lon + fi.craft.stPos.H = self.flight.altitude_m + fi.servoInfo.fServoAz = self.scanner.base_az_deg + fi.servoInfo.fServoPt = self.scanner.base_pitch_deg + + def _update_pose(self, step_idx: int) -> Tuple[float, float]: + t = step_idx * self.time_step + distance = self.flight.speed_mps * t + heading_rad = math.radians(self.flight.heading_deg) + east = distance * math.cos(heading_rad) + north = distance * math.sin(heading_rad) + + lat = self.flight.start_lat + north / 111320.0 + lon = self.flight.start_lon + east / (111320.0 * math.cos(math.radians(self.flight.start_lat))) + self.frame_info.craft.stPos.B = lat + self.frame_info.craft.stPos.L = lon + self.frame_info.nFrmID = step_idx + + self.frame_info.craft.stAtt.fYaw = self.flight.heading_deg + self.frame_info.craft.stAtt.fPitch = 0.0 + self.frame_info.craft.stAtt.fRoll = 0.0 + return east, north + + def _step_axis(self, state: AxisState, profile: AxisProfile, dt: float) -> None: + if profile.amplitude <= 1e-6 or profile.omega <= 0.0: + state.angle = profile.mid + state.velocity = 0.0 + state.phase = 0.0 + return + + state.phase += profile.omega * dt + # Wrap phase to [-pi, pi] for numerical stability + state.phase = (state.phase + math.pi) % (2 * math.pi) - math.pi + state.angle = profile.mid + profile.amplitude * math.sin(state.phase) + state.velocity = profile.amplitude * profile.omega * math.cos(state.phase) + + def _update_scanner(self, dt: float) -> None: + self._step_axis(self.az_state, self.az_profile, dt) + self._step_axis(self.pitch_state, self.pitch_profile, dt) + + self.frame_info.servoInfo.fServoAz = self.az_state.angle + self.frame_info.servoInfo.fServoPt = self.pitch_state.angle + + def _image_corners_to_geo(self, H_img2geo: np.ndarray) -> Optional[np.ndarray]: + if H_img2geo is None or H_img2geo.size != 9: + return None + + img_corners = np.array( + [ + [0.0, 0.0], + [self.camera.width, 0.0], + [self.camera.width, self.camera.height], + [0.0, self.camera.height], + ], + dtype=np.float64, + ) + + geo_corners = [warpPointWithH(H_img2geo, pt)[:2] for pt in img_corners] + return np.array(geo_corners, dtype=np.float64) + + def _image_point_to_geo(self, H_img2geo: np.ndarray, pt: np.ndarray) -> Optional[np.ndarray]: + if H_img2geo is None or H_img2geo.size != 9: + return None + return warpPointWithH(H_img2geo, pt)[:2] + + def _geo_to_pan(self, geo_pts: np.ndarray) -> np.ndarray: + pan_pts = [warpPointWithH(self.H_geo2pan, pt)[:2] for pt in geo_pts] + return np.array(pan_pts, dtype=np.float64) + + def _compute_aircraft_geo_downward(self) -> Optional[np.ndarray]: + """ + 将俯仰角临时设置为 -90°,用于计算机体正下方在大地坐标中的投影。 + """ + original_pitch = self.frame_info.servoInfo.fServoPt + try: + self.frame_info.servoInfo.fServoPt = -90.0 + H_down = self.stitcher.getHomography(self.frame_info) + finally: + self.frame_info.servoInfo.fServoPt = original_pitch + + if H_down is None or H_down.size != 9: + return None + + center_px = np.array([self.camera.width * 0.5, self.camera.height * 0.5], dtype=np.float64) + return self._image_point_to_geo(H_down, center_px) + + def run(self) -> None: + print(f"[INFO] Running simulator for {self.total_frames} frames.") + t0 = time.time() + + # 创建全景图画布 + pan_canvas = np.zeros((self.scaled_pan_height, self.scaled_pan_width, 3), dtype=np.uint8) + + try: + for idx in range(self.total_frames): + + start = time.perf_counter() + sim_time = idx * self.time_step + east, north = self._update_pose(idx) + self._update_scanner(self.time_step) + + H_img2geo = self.stitcher.getHomography(self.frame_info) + geo_footprint = self._image_corners_to_geo(H_img2geo) + if geo_footprint is None: + print("[WARN] Image->geo projection failed; skipping frame.") + continue + + plane_pan = None + plane_geo_down = self._compute_aircraft_geo_downward() + if plane_geo_down is not None: + plane_pan = self._geo_to_pan(np.array([plane_geo_down]))[0] + + # 将当前覆盖区域投影到全景图 + footprint_pan = self._geo_to_pan(geo_footprint) + + # 累积绘制到原始全景图上,实现自然叠加效果 + cv2.fillPoly(pan_canvas, [footprint_pan.astype(np.int32)], (30, 90, 180), lineType=cv2.LINE_AA) + + # 基于当前累积结果生成可渲染帧 + frame = pan_canvas.copy() + + # 绘制当前覆盖框(绿色边框) + cv2.polylines(frame, [footprint_pan.astype(np.int32)], True, (0, 255, 0), 2, cv2.LINE_AA) + + # 飞机位置画红色三角形 + if plane_pan is not None: + # 根据全景图比例计算三角形大小 + triangle_size = max(15, int(50 / max(self.scaled_pan_scale, 0.01))) + heading_rad = math.radians(self.flight.heading_deg) + + # 定义指向上方的三角形(航向方向) + # 顶点在航向方向,底边垂直于航向 + top = np.array([0, -triangle_size], dtype=np.float32) + left = np.array([-triangle_size * 0.6, triangle_size * 0.4], dtype=np.float32) + right = np.array([triangle_size * 0.6, triangle_size * 0.4], dtype=np.float32) + + # 旋转三角形以匹配航向 + cos_h = math.cos(heading_rad) + sin_h = math.sin(heading_rad) + rot_mat = np.array([[cos_h, -sin_h], [sin_h, cos_h]], dtype=np.float32) + + top_rot = top @ rot_mat.T + left_rot = left @ rot_mat.T + right_rot = right @ rot_mat.T + + # 转换到像素坐标 + center = plane_pan.astype(np.float32) + triangle_pts = np.array([ + (center + top_rot).astype(np.int32), + (center + left_rot).astype(np.int32), + (center + right_rot).astype(np.int32) + ]) + + cv2.fillPoly(frame, [triangle_pts], (0, 0, 255), lineType=cv2.LINE_AA) + cv2.polylines(frame, [triangle_pts], True, (255, 255, 255), 1, cv2.LINE_AA) + + + # 添加HUD信息 + hud = [ + f"Frame: {idx}/{self.total_frames}", + f"Pan Size (orig): {self.pan_info.m_pan_width} x {self.pan_info.m_pan_height}", + f"Pan Size (scaled): {self.scaled_pan_width} x {self.scaled_pan_height}", + f"Pan Scale (orig/scaled): {self.pan_info.scale:.4f}/{self.scaled_pan_scale:.4f}", + f"Lat/Lon: {self.frame_info.craft.stPos.B:.6f}, {self.frame_info.craft.stPos.L:.6f}", + f"Altitude: {self.flight.altitude_m:.1f} m", + f"Speed: {self.flight.speed_mps:.1f} m/s", + f"Az/Pt: {self.frame_info.servoInfo.fServoAz:.1f}/{self.frame_info.servoInfo.fServoPt:.1f}", + ] + + y = 20 + for line in hud: + cv2.putText(frame, line, (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA) + y += 20 + + display_frame = frame + if not math.isclose(self.options.display_scale, 1.0): + display_frame = cv2.resize( + frame, + (0, 0), + fx=self.options.display_scale, + fy=self.options.display_scale, + interpolation=cv2.INTER_AREA, + ) + + cv2.imshow("UnderStitch Scan Simulator - Panorama View", display_frame) + key = cv2.waitKey(1) + if key == 27 or key == ord("q"): + break + end = time.perf_counter() + print(f"frame {idx} took {(end-start)*1000:.1f} ms") + # remaining = max(0.0, t0 + sim_time + self.time_step - time.time()) + # time.sleep(remaining) + finally: + API_UnderStitch.Destroy(self.stitcher) + cv2.destroyAllWindows() + + +# ----------------------------------------------------------------------------- +# CLI +# ----------------------------------------------------------------------------- +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="UnderStitch downward scan footprint simulator.") + parser.add_argument("--duration", type=float, default=60.0, help="Simulation duration in seconds.") + parser.add_argument("--fps", type=float, default=50.0, help="Output/solver frame rate.") + parser.add_argument("--speed", type=float, default=50.0, help="Aircraft speed (m/s).") + parser.add_argument("--heading", type=float, default=36.0, help="Aircraft heading (deg, 0=E).") + parser.add_argument("--alt", type=float, default=1200.0, help="Altitude above ground (m).") + parser.add_argument("--map", type=Path, default=None, help="Optional background map image.") + parser.add_argument("--scale", type=float, default=2.0, help="Meters per pixel for map rendering.") + parser.add_argument("--trail", type=int, default=200, help="Stored footprint polygons.") + parser.add_argument("--base_az", type=float, default=90.0, help="Center azimuth (deg).") + parser.add_argument("--base_pitch", type=float, default=-90.0, help="Base pitch (deg).") + parser.add_argument("--az_min", type=float, default=90, help="Minimum azimuth limit (deg).") + parser.add_argument("--az_max", type=float, default=90, help="Maximum azimuth limit (deg).") + parser.add_argument("--az_speed", type=float, default=0, help="Azimuth max speed (deg/s).") + parser.add_argument("--az_acc", type=float, default=0, help="Azimuth max acceleration (deg/s^2).") + parser.add_argument("--pitch_min", type=float, default=-45.0, help="Minimum pitch limit (deg).") + parser.add_argument("--pitch_max", type=float, default=45.0, help="Maximum pitch limit (deg).") + parser.add_argument("--pitch_speed", type=float, default=45.0, help="Pitch max speed (deg/s).") + parser.add_argument("--pitch_acc", type=float, default=20.0, help="Pitch max acceleration (deg/s^2).") + parser.add_argument("--display_scale", type=float, default=0.3, help="Panorama display scale factor.") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + camera = CameraProfile() + flight = FlightProfile( + altitude_m=args.alt, + speed_mps=args.speed, + heading_deg=args.heading, + ) + scanner = ScannerProfile( + base_az_deg=args.base_az, + base_pitch_deg=args.base_pitch, + az_min_deg=args.az_min, + az_max_deg=args.az_max, + pitch_min_deg=args.pitch_min, + pitch_max_deg=args.pitch_max, + az_max_speed_degps=args.az_speed, + az_max_acc_degps2=args.az_acc, + pitch_max_speed_degps=args.pitch_speed, + pitch_max_acc_degps2=args.pitch_acc, + ) + map_view = MapView(background=args.map, meters_per_pixel=args.scale, trail_length=args.trail) + options = SimulationOptions( + duration_s=args.duration, + frame_rate=args.fps, + display_scale=args.display_scale, + ) + simulator = UnderStitchSimulator(camera, flight, scanner, map_view, options) + simulator.run() + + +if __name__ == "__main__": + main() +