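# NOTE: the lines below are leftover text from the repository web page this
# file was copied from; they are not part of the program.
#   You cannot select more than 25 topics. Topics must start with a letter or
#   number, can include dashes ('-') and can be up to 35 characters long.
#   428 lines
#   15 KiB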

import os
import sys
import ctypes
import re
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

# The directory below is added to sys.path so that the UStitcher import that
# follows can be resolved; it must stay ahead of that import.
sys.path.append("/media/wang/WORK/wangchongwu_gitea_2023/StitchVideo/Bin")
# Import the stitcher bindings.
from UStitcher import API_UnderStitch, FrameInfo, UPanInfo, UPanConfig

# cv2 is intentionally kept after UStitcher, preserving the original import order.
import cv2
@dataclass
class FrameData:
    """Telemetry parsed from one record of a DJI SRT subtitle file.

    Every field defaults to a zero/empty value, so ``FrameData()`` yields a
    blank record that the parser fills in field by field.
    """

    # Values taken from the record's header lines.
    frame_number: int = 0      # integer on the record's first line
    time_range: str = ""       # second line of the record, stored verbatim
    frame_cnt: int = 0         # second whitespace-separated token of line 3
    timestamp: str = ""        # third whitespace-separated token of line 3

    # Values extracted by key name from the record's metadata line.
    focal_len: float = 0.0
    dzoom_ratio: float = 0.0
    latitude: float = 0.0
    longitude: float = 0.0
    rel_alt: float = 0.0
    abs_alt: float = 0.0
    gb_yaw: float = 0.0
    gb_pitch: float = 0.0
    gb_roll: float = 0.0

    # Not read from the SRT; filled in by infer_camera_params_h30 from the
    # file name (units presumably millimetres / micrometres, per the names).
    real_focal_mm: float = 0.0
    pixel_size_um: float = 0.0
def extract_value_from_brackets(line: str, key: str) -> Optional[float]:
    """Extract the number stored under ``key`` inside a ``[...]`` group.

    Two layouts are recognised:

    * ``[key: value]`` or ``[key: value more stuff]`` - ``key`` is the first
      item in the bracket; the first number found in the bracket's remaining
      text is returned.
    * ``[other: 1 key: value ...]`` - ``key`` appears later inside a bracket;
      the whitespace-delimited token after it must be a complete number.

    Args:
        line: Text to search.
        key: Field name, matched literally (regex metacharacters are escaped).

    Returns:
        The value as a float, or ``None`` when the key is absent or its value
        is not numeric.
    """
    escaped_key = re.escape(key)
    # Signed decimal with optional fraction and exponent, so "1.5e-3" is read
    # as a whole instead of being cut off at "1.5".
    number_pattern = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'

    # Layout 1: the key directly follows the opening bracket.
    leading = re.search(rf'\[{escaped_key}:\s*([^\]]+?)\]', line)
    if leading:
        number = re.search(number_pattern, leading.group(1).strip())
        if number:
            try:
                return float(number.group())
            except ValueError:
                pass

    # Layout 2: the key sits somewhere inside a bracket. The look-behind
    # rejects matches where the key is only the tail of a longer identifier
    # (e.g. "alt" inside "rel_alt").
    embedded = re.search(
        rf'\[[^\]]*?(?<!\w){escaped_key}:\s*([^\s\]]+?)(?:\s|\])', line
    )
    if embedded:
        try:
            return float(embedded.group(1).strip())
        except ValueError:
            pass

    return None  # None signals "not found"
def extract_value_after(line: str, key: str) -> float:
    """Return the number that follows ``key`` in ``line``.

    The bracketed layouts handled by ``extract_value_from_brackets`` are tried
    first. If they yield nothing, the first plain ``key value`` /
    ``key: value`` occurrence is used.

    Args:
        line: Text to search.
        key: Field name, matched literally.

    Returns:
        The parsed value, or ``0.0`` when the key is missing, has no value, or
        the value is not a valid float. A genuine zero and "not found" are
        therefore indistinguishable to the caller.
    """
    # Bracketed form first.
    bracketed = extract_value_from_brackets(line, key)
    if bracketed is not None:
        return bracketed

    # Plain form: the key, any run of spaces/colons, then a token that ends
    # at a space, comma, line break, tab or closing bracket. The look-behind
    # keeps the key from matching inside a longer identifier (e.g. "alt"
    # inside "rel_alt"), which a plain substring search would accept.
    hit = re.search(rf'(?<!\w){re.escape(key)}[ :]*([^ ,\n\r\t\]]*)', line)
    if hit is None:
        return 0.0

    token = hit.group(1)
    if not token:
        return 0.0
    try:
        return float(token)
    except ValueError:
        return 0.0
def infer_camera_params_h30(frame: "FrameData", filename: str) -> None:
    """Set ``real_focal_mm`` / ``pixel_size_um`` on ``frame`` (H30 variant).

    The values are chosen from a marker in the file's base name: ``_W``,
    ``_Z`` or ``_T``, checked in that order. When no marker is present the
    frame is left untouched.

    Args:
        frame: Record to update in place.
        filename: Path or name of the video/SRT file the record came from.
    """
    # Look only at the last path component (either separator style), so that
    # a directory such as ".../flights_Wide/" cannot be mistaken for a marker.
    base_name = re.split(r'[\\/]', filename)[-1]

    if "_W" in base_name:
        frame.real_focal_mm = 6.72
        frame.pixel_size_um = 2.4
    elif "_Z" in base_name:
        frame.real_focal_mm = 6.72 * 2
        frame.pixel_size_um = 2.4
    elif "_T" in base_name:
        frame.real_focal_mm = 24.0
        frame.pixel_size_um = 12.0
    # otherwise: keep whatever values the frame already has
def _parse_srt_record(lines: List[str], filename: str) -> Optional[FrameData]:
    """Turn the lines of one SRT record into a ``FrameData``.

    ``lines`` must hold at least four entries: frame number, time range, the
    frame-count/timestamp line and the metadata line. Returns ``None`` when
    the first line is not an integer.
    """
    # Line 1: frame number
    try:
        frame_number = int(lines[0])
    except ValueError:
        return None

    frame = FrameData()
    frame.frame_number = frame_number

    # Line 2: time range, stored verbatim
    frame.time_range = lines[1]

    # Line 3: second token is the frame count, third token the timestamp.
    # If the count is not an integer both fields keep their defaults.
    parts = lines[2].split()
    if len(parts) >= 3:
        try:
            frame.frame_cnt = int(parts[1])
            frame.timestamp = parts[2]
        except ValueError:
            pass

    # Line 4: metadata; each field is looked up by its own name.
    meta = lines[3]
    for field_name in ("focal_len", "dzoom_ratio", "latitude", "longitude",
                       "rel_alt", "abs_alt", "gb_yaw", "gb_pitch", "gb_roll"):
        setattr(frame, field_name, extract_value_after(meta, field_name))

    # Debug aid: show the raw metadata and parsed values for the first frames.
    if frame.frame_number <= 3:
        print(f"Frame {frame.frame_number} meta: {meta}")
        print(f" 解析结果: lat={frame.latitude}, lon={frame.longitude}, "
              f"alt={frame.abs_alt}, yaw={frame.gb_yaw}, pitch={frame.gb_pitch}, roll={frame.gb_roll}")

    infer_camera_params_h30(frame, filename)
    return frame


def parse_dji_srt(filename: str) -> List[FrameData]:
    """Parse a DJI SRT file into a list of ``FrameData`` records.

    The file is consumed in groups of five lines (four content lines plus the
    separating blank line). A group whose first line is not an integer is
    discarded. A final group of four lines is also accepted, so the last
    record is kept even when the file does not end with a blank line.

    Args:
        filename: Path of the SRT file. It is also passed on to
            ``infer_camera_params_h30`` to pick the camera parameters.

    Returns:
        The records parsed so far. On an I/O or decoding error a message is
        printed and the (possibly empty) partial list is returned.
    """
    frames: List[FrameData] = []
    try:
        # utf-8-sig also accepts plain UTF-8, but strips a leading BOM that
        # would otherwise make int() fail on the very first line and shift
        # every later record by one position in the returned list.
        with open(filename, 'r', encoding='utf-8-sig') as file:
            group: List[str] = []
            for raw_line in file:
                group.append(raw_line.rstrip('\n\r'))
                # Five lines per group, blank separator included.
                if len(group) < 5:
                    continue
                frame = _parse_srt_record(group, filename)
                if frame is not None:
                    frames.append(frame)
                group = []

            # Trailing record without the closing blank line.
            if len(group) >= 4:
                frame = _parse_srt_record(group, filename)
                if frame is not None:
                    frames.append(frame)
    except (OSError, UnicodeError) as e:
        print(f"错误: 无法打开文件 {filename}: {e}")
    return frames
def _build_frame_info(meta, frame_id, width, height, n_down_sample, ev_height):
    """Create a ``FrameInfo`` populated from one parsed SRT record."""
    info = FrameInfo()
    info.nFrmID = frame_id
    info.camInfo.nFocus = meta.real_focal_mm
    # Pixel size is multiplied by the down-sampling factor applied to the image.
    info.camInfo.fPixelSize = meta.pixel_size_um * n_down_sample
    info.craft.stAtt.fYaw = meta.gb_yaw
    info.craft.stAtt.fPitch = 0.0
    info.craft.stAtt.fRoll = meta.gb_roll
    info.craft.stPos.B = meta.latitude
    info.craft.stPos.L = meta.longitude
    info.craft.stPos.H = meta.abs_alt
    info.nEvHeight = int(ev_height)
    info.servoInfo.fServoAz = 0.0
    # The gimbal pitch is passed as the servo pitch; craft pitch stays 0.
    info.servoInfo.fServoPt = meta.gb_pitch
    info.nWidth = width
    info.nHeight = height
    return info


def _open_video_writer(output_path, frame_size):
    """Open a 5 fps colour ``cv2.VideoWriter``, trying several codecs in turn.

    Returns the opened writer, or ``None`` (after printing diagnostics) when
    no codec could be used.
    """
    # Codecs in order of preference.
    codecs_to_try = [
        ('H264', 'H264'),
        ('mp4v', 'mp4v'),
        ('XVID', 'XVID'),
        ('MJPG', 'MJPG'),
    ]
    for codec_name, fourcc_str in codecs_to_try:
        writer = None
        try:
            fourcc = cv2.VideoWriter_fourcc(*fourcc_str)
            writer = cv2.VideoWriter(output_path, fourcc, 5.0, frame_size, True)
            if writer.isOpened():
                print(f"成功创建输出视频,使用编码器: {codec_name}")
                return writer
            writer.release()
        except Exception as e:  # best effort: report and try the next codec
            if writer is not None:
                writer.release()
            print(f"尝试编码器 {codec_name} 失败: {e}")

    print(f"错误: 无法创建输出视频 {output_path}")
    print(f"尝试了所有编码器: {[c[0] for c in codecs_to_try]}")
    print(f"输出尺寸: {frame_size[0]}x{frame_size[1]}")
    print("提示: 可能需要安装视频编码库(如 ffmpeg或使用其他编码器")
    return None


def _pan_to_output_frame(mat_pan, frame_size):
    """Drop the alpha channel of a 4-channel panorama and resize it."""
    if len(mat_pan.shape) == 3 and mat_pan.shape[2] == 4:
        pan_bgr = cv2.cvtColor(mat_pan, cv2.COLOR_BGRA2BGR)
    else:
        pan_bgr = mat_pan
    return cv2.resize(pan_bgr, frame_size)


def proc_dj_video(video_path_list: List[str], srt_path_list: List[str],
                  cache_dir: str = "./cache", output_dir: str = "./google_tiles",
                  *, output_path: str = "DJ_stitchVL.mp4",
                  start_frame: int = 2000, frame_step: int = 10):
    """Stitch DJI videos into a panorama using their SRT telemetry.

    The stitcher is initialised from the first video and its SRT file. Each
    video is then read from ``start_frame`` onward, and every
    ``frame_step``-th frame is fed to the stitcher. After each processed
    frame the current panorama is scaled to 1/8 size, appended to a preview
    video and shown in a window; pressing ESC stops processing. Finally the
    stitcher optimises and outputs the panorama, and the last preview frame
    stays on screen until a key is pressed.

    Errors are reported with ``print`` and end the function (or skip the
    affected video); nothing is raised deliberately.

    Args:
        video_path_list: Video files to process, in order.
        srt_path_list: SRT file for each video, at the same index.
        cache_dir: Directory passed to ``API_UnderStitch.Create``.
        output_dir: Directory passed to ``stitcher.SetOutput``.
        output_path: File name of the preview video.
        start_frame: Index of the first frame to use; clamped to the range of
            the first SRT file.
        frame_step: Only frames whose index is a multiple of this are stitched.
    """
    import time

    if not video_path_list or len(srt_path_list) < len(video_path_list):
        print("错误: 视频列表为空或缺少对应的SRT文件")
        return
    frame_step = max(1, int(frame_step))

    # Create the stitcher.
    stitcher = API_UnderStitch.Create(cache_dir)
    stitcher.SetOutput("DJI", output_dir)

    # Read the frame geometry from the first video.
    cap = cv2.VideoCapture(video_path_list[0])
    if not cap.isOpened():
        print(f"错误: 无法打开视频 {video_path_list[0]}")
        cap.release()
        return
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()

    # Wide frames are halved before stitching.
    n_down_sample = 2 if width > 3000 else 1
    ds_size = (width // n_down_sample, height // n_down_sample)

    # Parse the first SRT file to obtain the initial pose.
    srt_init = parse_dji_srt(srt_path_list[0])
    if not srt_init:
        print("错误: SRT文件解析失败")
        return
    n_start = max(0, min(start_frame, len(srt_init) - 1))

    # Initial FrameInfo: nEvHeight comes from the relative altitude here.
    first = srt_init[n_start]
    frame_info = _build_frame_info(first, n_start, ds_size[0], ds_size[1],
                                   n_down_sample, ev_height=first.rel_alt)

    # Initialise the stitcher.
    print("初始化拼接器...")
    pan_info = stitcher.Init(frame_info)
    print(f"初始化成功,全景图尺寸: {pan_info.m_pan_width} x {pan_info.m_pan_height}")

    # Configuration.
    config = UPanConfig()
    config.bOutFrameTile = False
    config.bOutGoogleTile = True
    config.bUseBA = True
    stitcher.SetConfig(config)

    # The preview video is 1/8 of the panorama in each dimension.
    mat_pan = stitcher.ExportPanMat()
    if mat_pan is None:
        print("错误: 无法获取全景图")
        return
    output_size = (mat_pan.shape[1] // 8, mat_pan.shape[0] // 8)

    output = _open_video_writer(output_path, output_size)
    if output is None:
        return

    try:
        aborted = False
        for video_path, srt_path in zip(video_path_list, srt_path_list):
            print(f"处理 {video_path}")
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                print(f"错误: 无法打开视频 {video_path}")
                cap.release()
                continue
            try:
                srt = parse_dji_srt(srt_path)
                if not srt:
                    print(f"错误: 无法解析SRT文件 {srt_path}")
                    continue

                # Progress is reported against this video's own length.
                frame_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                cap.set(cv2.CAP_PROP_POS_FRAMES, n_start)
                # Start one below so the first decoded frame gets index
                # n_start and is paired with srt[n_start].
                frm_id = n_start - 1

                while True:
                    ret, mat = cap.read()
                    if not ret or mat is None or mat.size == 0:
                        print("视频结束")
                        break
                    frm_id += 1
                    if frm_id >= len(srt):
                        print(f"警告: 帧ID {frm_id} 超出SRT数据范围")
                        break
                    # Skip early, before any per-frame work is done.
                    if frm_id % frame_step != 0:
                        continue

                    # Down-sample.
                    mat_ds2 = cv2.resize(mat, ds_size)

                    # NOTE(review): nEvHeight uses abs_alt here but rel_alt at
                    # initialisation; kept as in the original - confirm which
                    # one the stitcher expects.
                    meta = srt[frm_id]
                    frame_info = _build_frame_info(
                        meta, frm_id, mat_ds2.shape[1], mat_ds2.shape[0],
                        n_down_sample, ev_height=meta.abs_alt)

                    # Some containers report a frame count of 0.
                    progress = float(frm_id) / frame_total * 100 if frame_total > 0 else 0.0
                    print(f"{progress:.1f}% B={frame_info.craft.stPos.B:.6f} L={frame_info.craft.stPos.L:.6f} "
                          f"H={frame_info.craft.stPos.H:.2f} Yaw={frame_info.craft.stAtt.fYaw:.2f} "
                          f"Pitch={frame_info.craft.stAtt.fPitch:.2f} Roll={frame_info.craft.stAtt.fRoll:.2f} "
                          f"ServoAz={frame_info.servoInfo.fServoAz:.2f} ServoPt={frame_info.servoInfo.fServoPt:.2f}")

                    # Run the stitcher on this frame.
                    start_time = time.perf_counter()
                    stitcher.Run(mat_ds2, frame_info)
                    cost_time = time.perf_counter() - start_time
                    print(f"处理时间: {cost_time:.3f}")

                    # Append the current panorama to the preview video.
                    mat_pan = stitcher.ExportPanMat()
                    if mat_pan is not None and mat_pan.size > 0:
                        pan_rgb_ds = _pan_to_output_frame(mat_pan, output_size)
                        output.write(pan_rgb_ds)
                        # Live preview; ESC stops all further processing.
                        cv2.imshow("pan_rgb", pan_rgb_ds)
                        if cv2.waitKey(1) & 0xFF == 27:
                            aborted = True
                            break
            finally:
                # Released on every exit path: end of video, SRT overrun, ESC.
                cap.release()
            if aborted:
                break

        # Optimise and output the current panorama.
        print("优化并输出当前全景图...")
        start_time = time.perf_counter()
        stitcher.OptAndOutCurrPan()
        stitcher.Stop()
        opt_time = time.perf_counter() - start_time
        print(f"优化时间: {opt_time:.3f}")

        # Final frame: write it and keep it on screen until a key is pressed.
        mat_pan = stitcher.ExportPanMat()
        if mat_pan is not None and mat_pan.size > 0:
            pan_rgb_ds = _pan_to_output_frame(mat_pan, output_size)
            output.write(pan_rgb_ds)
            cv2.imshow("pan_rgb", pan_rgb_ds)
            cv2.waitKey(0)
    finally:
        output.release()
        cv2.destroyAllWindows()
    print("处理完成")
def main():
    """Run the stitcher on the hard-coded sample recording."""
    # Change this to the location of your data.
    folder = "/media/wang/data/K2D_data/"

    # Each video is paired with the SRT file at the same list index; append
    # further names here to process several recordings in one run.
    clip_names = ["DJI_20250418153043_0006_W"]
    video_path_list = [folder + name + ".MP4" for name in clip_names]
    srt_path_list = [folder + name + ".srt" for name in clip_names]

    proc_dj_video(video_path_list, srt_path_list)


if __name__ == "__main__":
    main()