Hands-On with the LeRobot SO-101 Robot Arm


0. Preface

The LeRobot repository changes frequently; always cross-check against the official tutorials and the repository code.

This post documents hands-on work with the LeRobot SO-101 arm, from hardware connection and environment setup through basic calibration, data collection, training (with ACT), and inference (with ACT). Specific ports, paths, and the like are for reference only; adapt them to your own environment.

Hardware Connection

WOWROBO SO-ARM101 dual-arm user manual (pre-assembled version only)

  1. Connection and installation. Power the devices as follows: the leader (control) arm uses a 5V/6A supply and the follower arm uses a 12V/8A supply. For mounting the camera assembly, see the following video; the demonstration starts at 1:06:32: https://www.bilibili.com/video/BV13bLyzKES8/

  2. Read the operating guide first. Servo IDs, baud rates, magnetic-encoder midpoints, and teleoperation were all configured and tested at the factory; do not change servo parameters manually until you are fully familiar with the product. Complete the environment and port setup per the document below, then start from the "Calibrate" step: https://huggingface.co/docs/lerobot/so101

  3. Important notes. A wrong calibration direction can drive a servo toward the wrong limit, overload it, and damage it; the follower arm's servo #3 is especially at risk. During your first teleoperation session, watch carefully whether each servo moves in the expected direction; if anything looks wrong, stop immediately and rerun the calibration. The recommended maximum payload of this arm is 400 g; exceeding it may damage the servos. Servo damage caused by improper operation (overload, burnout) is not covered by the warranty.

  1. Since this is a pre-assembled LeRobot SO-101, there is no need to assemble parts or configure the servos yourself.
  2. Connect the arms to power and to the computer
    • Connect the cameras directly to the computer when possible; the two arm signal cables can go through a USB hub.
  3. Mount the cameras

Environment Setup

After installing miniforge and configuring conda, follow the official installation docs to set up the lerobot environment.

  1. Create the virtual environment: conda create -y -n lerobot python=3.10
  2. Activate it: conda activate lerobot
  3. Install ffmpeg: conda install ffmpeg -c conda-forge
  4. Update pip: python -m pip install -U pip setuptools wheel
  5. Clone the lerobot repository and enter it: git clone https://github.com/huggingface/lerobot.git && cd lerobot
  6. Install lerobot: pip install 'lerobot[all]'
    • Or the minimal Feetech-only install: pip install 'lerobot[feetech]'

  To pull in upstream updates later:

  1. Stash local changes: git stash
  2. Pull the latest code: git pull
  3. Restore local changes: git stash pop
    • Resolve any conflicts manually
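To confirm the environment actually resolves before moving on, a quick import check helps; this is just a sanity-check sketch of my own, not an official step:

    # Run inside the activated conda env; the printed path should point into your clone
    import lerobot

    print(lerobot.__file__)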

Calibration

Since we purchased a pre-assembled LeRobot SO-101, servo IDs, baud rates, and magnetic-encoder midpoints are already set; we can go straight to calibration.

  1. Connect the arms to power and the computer, activate the lerobot environment, and cd into the lerobot repository.

  2. Find the ports of the two arms:

    • Run lerobot-find-port and, when prompted, unplug the corresponding USB cable; the port that disappears is that arm's port
    • Ports found here:
      • Leader arm (LeaderArm): /dev/ttyACM0
      • Follower arm (FollowerArm): /dev/ttyACM1
  3. Set serial device permissions: sudo chmod 666 /dev/ttyACM*

    • If you don't want to set permissions every time you plug the devices in, add yourself to the matching device group:
      1. Check the device group:
        [xiadengma@archlinux ~]$ ls -l /dev/ttyACM*
        crw-rw---- 1 root uucp 166, 0 Oct 16 13:15 /dev/ttyACM0
        crw-rw---- 1 root uucp 166, 1 Oct 16 13:15 /dev/ttyACM1
        
      2. Add yourself to the group (adjust it to the ls -l /dev/ttyACM* output), then verify:
        • Join the uucp group and refresh the session
          sudo usermod -aG uucp "$USER" && exec su -l "$USER"
          
        • Verify membership
          id
          
          The listed groups should include the device group.
  4. Set serial device aliases:

    • If you don't want to run lerobot-find-port every time you reconnect a device, create udev aliases:

      1. Find a unique attribute of the device (adjust /dev/ttyACM0 as needed):

        udevadm info --query=property --name=/dev/ttyACM0 | sed -n 's/^ID_SERIAL_SHORT=\(.*\)$/ATTRS{serial}=="\1"/p'
        

        Example output:

        ATTRS{serial}=="5A7C118736"
        
      2. Create a udev rules file:

        sudo touch /etc/udev/rules.d/99-my-robot.rules
        

        Write the following into it (adjust the serial values to your hardware):

        # Rule for the Leader Arm
        SUBSYSTEM=="tty", ATTRS{serial}=="5A7C118736", SYMLINK+="leader_arm", MODE="0666"
        
        # Rule for the Follower Arm
        SUBSYSTEM=="tty", ATTRS{serial}=="5A7C118880", SYMLINK+="follower_arm", MODE="0666"
        

        A quick explanation of the rule:

        • SUBSYSTEM=="tty": the rule applies to serial (tty) devices.
        • ATTRS{serial}=="...": matches the unique serial number we found above.
        • SYMLINK+="leader_arm": creates a symlink (alias) named /dev/leader_arm.
        • MODE="0666": makes the device file readable and writable by everyone, which avoids most "Permission denied" errors.
      3. Reload the udev rules and trigger them

        sudo udevadm control --reload-rules && sudo udevadm trigger
        
      4. Verify the setup: re-plug your devices, then run ls -l /dev/leader_arm /dev/follower_arm

        [xiadengma@archlinux ~]$ ls -l /dev/leader_arm /dev/follower_arm
        lrwxrwxrwx 1 root root 7 Oct 26 15:30 /dev/leader_arm -> ttyACM0
        lrwxrwxrwx 1 root root 7 Oct 26 15:30 /dev/follower_arm -> ttyACM1
        
  5. Run calibration

    • Note: every command below sets an explicit calibration directory. The default is ~/.cache/huggingface/lerobot/calibration/; to use it, drop --robot.calibration_dir and --teleop.calibration_dir from all of the commands below.
    • Procedure: start the calibration program, move the robot to its middle position, press Enter, then move each joint through its full range of motion.
    1. Calibrate the leader arm:
      • Leader arm name (customizable): my_leader_arm
      • Calibration directory (customizable): ./data/calibration
      lerobot-calibrate \
         --robot.type=so101_leader \
         --robot.port=/dev/leader_arm \
         --robot.id=my_leader_arm \
         --robot.calibration_dir=./data/calibration
      
    2. Calibrate the follower arm:
      • Follower arm name (customizable): my_follower_arm
      • Calibration directory (customizable): ./data/calibration
      lerobot-calibrate \
         --robot.type=so101_follower \
         --robot.port=/dev/follower_arm \
         --robot.id=my_follower_arm \
         --robot.calibration_dir=./data/calibration
      
  6. Reading the calibration data (a small loader sketch follows the JSON below)

    {
      "shoulder_pan": { #肩部旋转关节
        "id": 1, # 电机的唯一标识符,用于在总线通信时精准定位和控制特定电机
        "drive_mode": 0, # 电机的驱动模式,取值为 0 表示特定的驱动模式,不同驱动模式会影响电机的运动特性与控制方式
        "homing_offset": 56, # 归位偏移量,指电机从物理零点位置到校准零点位置的偏移量。此参数能保证电机在每次启动时都能回到准确的零点位置,从而提升运动精度
        "range_min": 829, #电机运动范围的最小值和最大值,以数值形式呈现。这两个参数限定了电机的运动边界,避免因超出范围而导致硬件损坏或者运动异常
        "range_max": 2866
      },
      "shoulder_lift": { #肩部升降关节
        "id": 2,
        "drive_mode": 0,
        "homing_offset": 463,
        "range_min": 836,
        "range_max": 3136
      },
      "elbow_flex": { #肘部弯曲关节
        "id": 3,
        "drive_mode": 0,
        "homing_offset": -100,
        "range_min": 894,
        "range_max": 3100
      },
      "wrist_flex": { #腕部弯曲关节
        "id": 4,
        "drive_mode": 0,
        "homing_offset": -582,
        "range_min": 928,
        "range_max": 3213
      },
      "wrist_roll": { #腕部旋转关节
        "id": 5,
        "drive_mode": 0,
        "homing_offset": -650,
        "range_min": 140,
        "range_max": 3955
      },
      "gripper": { #夹爪关节
        "id": 6,
        "drive_mode": 0,
        "homing_offset": -1229,
        "range_min": 2044,
        "range_max": 3461
      }
    }
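The calibration files are plain JSON, so they are easy to inspect programmatically. A minimal loader sketch; the filename is an assumption (files appear to be named after the arm id under the directory passed via --robot.calibration_dir):

    import json
    from pathlib import Path

    # Assumed location: <calibration_dir>/<arm_id>.json
    calib = json.loads(Path("./data/calibration/my_follower_arm.json").read_text())

    # Print each joint's id, homing offset, and usable travel span
    for joint, p in calib.items():
        span = p["range_max"] - p["range_min"]
        print(f"{joint:14s} id={p['id']} homing_offset={p['homing_offset']:6d} span={span}")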
    

Teleoperation / Demonstration

Teleoperation Without Cameras

  1. Run teleoperation:
    lerobot-teleoperate \
    --robot.type=so101_follower \
    --robot.port=/dev/follower_arm \
    --robot.id=my_follower_arm \
    --robot.calibration_dir=./data/calibration \
    --teleop.type=so101_leader \
    --teleop.port=/dev/leader_arm \
    --teleop.id=my_leader_arm \
    --teleop.calibration_dir=./data/calibration
    
  2. Move the leader arm and the follower will mirror it; press Ctrl+C to exit

Setting Up the Cameras

If you mounted cameras during hardware connection, you can teleoperate with camera feeds.

  1. Connect the cameras to the computer (ideally two: one overlooking the workspace and one mounted on the arm's wrist)

  2. List the camera indices and capture sample output

    lerobot-find-cameras opencv --output-dir ./data/cameras_images
    

    The output looks like this:

    --- Detected Cameras --- # list of detected cameras
    Camera #0: # camera 0
      Name: OpenCV Camera @ /dev/video0 # camera name and device path
      Type: OpenCV # camera type: OpenCV
      Id: /dev/video0 # camera ID, i.e. the device path
      Backend api: V4L2 # backend API: V4L2
      Default stream profile: # default stream profile
        Format: 0.0 # format: 0.0 (default)
        Width: 640 # image width: 640 px
        Height: 480 # image height: 480 px
        Fps: 30.0 # frames per second: 30
    --------------------
    Camera #1:
      Name: OpenCV Camera @ /dev/video2
      Type: OpenCV
      Id: /dev/video2
      Backend api: V4L2
      Default stream profile:
        Format: 0.0
        Width: 640
        Height: 480
        Fps: 30.0
    --------------------
    Camera #2:
      Name: OpenCV Camera @ /dev/video6
      Type: OpenCV
      Id: /dev/video6
      Backend api: V4L2
      Default stream profile:
        Format: 0.0
        Width: 640
        Height: 480
        Fps: 30.0
    --------------------
    Camera #3:
      Name: OpenCV Camera @ /dev/video8
      Type: OpenCV
      Id: /dev/video8
      Backend api: V4L2
      Default stream profile:
        Format: 0.0
        Width: 640
        Height: 480
        Fps: 30.0
    --------------------
    
    Finalizing image saving...
    Image capture finished. Images saved to data/cameras_images
    

    When the run finishes, the captured images are saved under ./data/cameras_images

  3. Record the cameras based on the output (each index can be previewed with the sketch after this list):

    1. Left wrist camera:
      • Name: wrist_left
      • Index: 2
      • Width: 640
      • Height: 480
      • FPS: 30
    2. Camera facing the front of the arm:
      • Name: front_rgb
      • Index: 8
      • Width: 640
      • Height: 480
      • FPS: 30
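Before wiring these indices into the commands below, it can help to preview each one and confirm which physical camera it maps to. A small OpenCV sketch; the index (2) and the 640x480 resolution are assumptions taken from the listing above:

    import cv2

    # Open one camera by index and preview it; press q to quit
    cap = cv2.VideoCapture(2)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            break
        cv2.imshow("camera preview", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    cap.release()
    cv2.destroyAllWindows()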

Teleoperation With Cameras

  1. Run teleoperation (the cameras flag is unpacked in the sketch after this list):
    lerobot-teleoperate \
      --robot.type=so101_follower \
      --robot.port=/dev/follower_arm \
      --robot.id=my_follower_arm \
      --robot.cameras="{ wrist_left: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front_rgb: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \
      --teleop.type=so101_leader \
      --teleop.port=/dev/leader_arm \
      --teleop.id=my_leader_arm \
      --display_data=true
    
  2. A Rerun window opens where you can watch the arm state and the camera feeds
  3. Move the leader arm and the follower will mirror it; press Ctrl+C to exit
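For reference, the --robot.cameras string above is parsed into per-camera config objects. A sketch of the Python equivalent; the field names mirror the CLI keys and are assumptions about this repo snapshot's OpenCVCameraConfig, so check lerobot/cameras/opencv/configuration_opencv.py:

    from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig

    # The same two cameras as in the CLI string above
    cameras = {
        "wrist_left": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
        "front_rgb": OpenCVCameraConfig(index_or_path=8, width=640, height=480, fps=30),
    }
    print(cameras)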
If you are working on a remote or headless machine, the scripts below serve the Rerun viewer in the browser instead of opening a native window:

  1. Create an extra folder under src/lerobot and add two files, web_visualization_utils.py and lerobot_teleoperate_web.py, with the following content:

    • web_visualization_utils.py

      # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
      #
      # Licensed under the Apache License, Version 2.0 (the "License");
      # you may not use this file except in compliance with the License.
      # You may obtain a copy of the License at
      #
      #     http://www.apache.org/licenses/LICENSE-2.0
      #
      # Unless required by applicable law or agreed to in writing, software
      # distributed under the License is distributed on an "AS IS" BASIS,
      # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      # See the License for the specific language governing permissions and
      # limitations under the License.
      
      """Web-based visualization utilities using Rerun web viewer."""
      
      import numbers
      import os
      from typing import Any
      
      import numpy as np
      import rerun as rr
      
      from lerobot.utils.constants import OBS_PREFIX, OBS_STR
      
      
      def init_rerun_web(
          session_name: str = "lerobot_control_loop",
          port: int = 9090,
          open_browser: bool = True,
          memory_limit: str = "10%",
      ) -> None:
          """
          初始化 Rerun SDK 用于 Web 浏览器可视化控制循环。
      
          Initializes the Rerun SDK for visualizing the control loop in a web browser.
      
          参数 Args:
              session_name: Rerun 会话名称 / Name of the Rerun session
              port: Web 服务器端口 / Web server port (default: 9090)
              open_browser: 是否自动打开浏览器 / Whether to automatically open browser (default: True)
              memory_limit: 内存限制 / Memory limit for Rerun (default: "10%")
      
          使用示例 Example:
              ```python
              from lerobot.extra.web_visualization_utils import init_rerun_web
      
              # 在浏览器中启动可视化 / Start visualization in browser
              init_rerun_web(session_name="my_teleoperation", port=9090)
              ```
      
          访问 Access:
              在浏览器中打开 / Open in browser: http://localhost:9090
          """
          batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
          os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
      
          rr.init(session_name)
      
          print("🌐 启动 Rerun Web Viewer / Starting Rerun Web Viewer")
          print(f"📍 访问地址 / Access URL: http://localhost:{port}")
          print(f"💾 内存限制 / Memory limit: {memory_limit}")
      
          rr.serve(
              open_browser=open_browser,
              web_port=port,
              server_memory_limit=memory_limit,
          )
      
      
      def init_rerun_connect(
          addr: str = "127.0.0.1:9876",
          session_name: str = "lerobot_control_loop",
      ) -> None:
          """
          连接到远程 Rerun Viewer (适用于远程服务器)。
      
          Connect to a remote Rerun Viewer (useful for remote servers).
      
          参数 Args:
              addr: 远程 Rerun Viewer 地址 / Remote Rerun Viewer address (default: "127.0.0.1:9876")
              session_name: Rerun 会话名称 / Name of the Rerun session
      
          使用示例 Example:
              ```python
              from lerobot.extra.web_visualization_utils import init_rerun_connect
      
              # 连接到远程 Rerun Viewer / Connect to remote Rerun Viewer
              # 首先在另一个终端运行: rerun --port 9876
              # First run in another terminal: rerun --port 9876
              init_rerun_connect(addr="127.0.0.1:9876")
              ```
          """
          batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
          os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
      
          rr.init(session_name)
      
          print("🔗 连接到远程 Rerun Viewer / Connecting to remote Rerun Viewer")
          print(f"📍 地址 / Address: {addr}")
      
          rr.connect_tcp(addr=addr)
      
      
      def _is_scalar(x):
          return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
              isinstance(x, np.ndarray) and x.ndim == 0
          )
      
      
      def log_rerun_data(
          observation: dict[str, Any] | None = None,
          action: dict[str, Any] | None = None,
      ) -> None:
          """
          将观测和动作数据记录到 Rerun 用于实时可视化。
      
          Logs observation and action data to Rerun for real-time visualization.
      
          This function iterates through the provided observation and action dictionaries and sends their contents
          to the Rerun viewer. It handles different data types appropriately:
          - Scalar values (floats, ints) are logged as `rr.Scalar`.
          - 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
            from CHW to HWC format and logged as `rr.Image`.
          - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
          - Other multi-dimensional arrays are flattened and logged as individual scalars.
      
          Keys are automatically namespaced with "observation." or "action." if not already present.
      
          参数 Args:
              observation: 包含观测数据的可选字典 / An optional dictionary containing observation data to log.
              action: 包含动作数据的可选字典 / An optional dictionary containing action data to log.
          """
          if observation:
              for k, v in observation.items():
                  if v is None:
                      continue
                  key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
      
                  if _is_scalar(v):
                      rr.log(key, rr.Scalar(float(v)))
                  elif isinstance(v, np.ndarray):
                      arr = v
                      # Convert CHW -> HWC when needed
                      if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                          arr = np.transpose(arr, (1, 2, 0))
                      if arr.ndim == 1:
                          for i, vi in enumerate(arr):
                              rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
                      else:
                          rr.log(key, rr.Image(arr), static=True)
      
          if action:
              for k, v in action.items():
                  if v is None:
                      continue
                  key = k if str(k).startswith("action.") else f"action.{k}"
      
                  if _is_scalar(v):
                      rr.log(key, rr.Scalar(float(v)))
                  elif isinstance(v, np.ndarray):
                      if v.ndim == 1:
                          for i, vi in enumerate(v):
                              rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
                      else:
                          # Fall back to flattening higher-dimensional arrays
                          flat = v.flatten()
                          for i, vi in enumerate(flat):
                              rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
      
    • lerobot_teleoperate_web.py

      #!/usr/bin/env python3
      # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
      #
      # Licensed under the Apache License, Version 2.0 (the "License");
      # you may not use this file except in compliance with the License.
      # You may obtain a copy of the License at
      #
      #     http://www.apache.org/licenses/LICENSE-2.0
      #
      # Unless required by applicable law or agreed to in writing, software
      # distributed under the License is distributed on an "AS IS" BASIS,
      # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      # See the License for the specific language governing permissions and
      # limitations under the License.
      
      """
      通过遥操作控制机器人的脚本 (使用 Web 可视化)。
      
      Simple script to control a robot from teleoperation (with Web visualization).
      """
      
      import logging
      import time
      from dataclasses import asdict, dataclass
      from pprint import pformat
      import rerun as rr
      from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
      from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig
      from lerobot.configs import parser
      from lerobot.processor import (
          RobotAction,
          RobotObservation,
          RobotProcessorPipeline,
          make_default_processors,
      )
      from lerobot.robots import (
          Robot,
          RobotConfig,
          bi_so100_follower,
          hope_jr,
          koch_follower,
          make_robot_from_config,
          so100_follower,
          so101_follower,
      )
      from lerobot.teleoperators import (
          Teleoperator,
          TeleoperatorConfig,
          bi_so100_leader,
          gamepad,
          homunculus,
          koch_leader,
          make_teleoperator_from_config,
          so100_leader,
          so101_leader,
      )
      from lerobot.utils.robot_utils import busy_wait
      from lerobot.utils.utils import init_logging, move_cursor_up
      from .web_visualization_utils import init_rerun_web, log_rerun_data
      
      
      @dataclass
      class TeleoperateWebConfig:
          teleop: TeleoperatorConfig
          robot: RobotConfig
          # Limit the maximum frames per second.
          fps: int = 60
          teleop_time_s: float | None = None
          # Display all cameras on screen
          display_data: bool = False
          # Web viewer port
          web_port: int = 9090
          # Auto open browser
          open_browser: bool = True
      
      
      def teleop_loop(
          teleop: Teleoperator,
          robot: Robot,
          fps: int,
          teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction],
          robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction],
          robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation],
          display_data: bool = False,
          duration: float | None = None,
      ):
          """
          This function continuously reads actions from a teleoperation device, processes them through optional
          pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a
          specified frequency until a set duration is reached or it is manually interrupted.
      
          Args:
              teleop: The teleoperator device instance providing control actions.
              robot: The robot instance being controlled.
              fps: The target frequency for the control loop in frames per second.
              display_data: If True, fetches robot observations and displays them in the console and Rerun.
              duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely.
              teleop_action_processor: An optional pipeline to process raw actions from the teleoperator.
              robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
              robot_observation_processor: An optional pipeline to process raw observations from the robot.
          """
      
          display_len = max(len(key) for key in robot.action_features)
          start = time.perf_counter()
      
          while True:
              loop_start = time.perf_counter()
      
              # Get robot observation
              obs = robot.get_observation()

              # Get teleop action
              raw_action = teleop.get_action()

              # Process teleop action through pipeline
              teleop_action = teleop_action_processor((raw_action, obs))

              # Process action for robot through pipeline
              robot_action_to_send = robot_action_processor((teleop_action, obs))

              # Send processed action to robot
              _ = robot.send_action(robot_action_to_send)

              if display_data:
                  # Process robot observation through pipeline
                  obs_transition = robot_observation_processor(obs)
      
                  log_rerun_data(
                      observation=obs_transition,
                      action=teleop_action,
                  )
      
                  print("\n" + "-" * (display_len + 10))  # 分隔线
                  print(f"{'NAME':<{display_len}} | {'NORM':>7}")  # 标题:名称与归一化值
                  # Display the final robot action that was sent
                  # 显示已发送给机器人的最终动作
                  for motor, value in robot_action_to_send.items():
                      print(f"{motor:<{display_len}} | {value:>7.2f}")  # 每个电机的动作值(保留 2 位小数)
                  move_cursor_up(len(robot_action_to_send) + 5)  # 上移光标以覆盖刷新
      
              dt_s = time.perf_counter() - loop_start
              busy_wait(1 / fps - dt_s)  # 忙等待以维持目标帧率(若剩余时间为正)
              loop_s = time.perf_counter() - loop_start
              print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")  # 打印单次循环耗时与频率
      
              if duration is not None and time.perf_counter() - start >= duration:
                  return
      
      
      @parser.wrap()
      def teleoperate_web(cfg: TeleoperateWebConfig):
          init_logging()  # initialize logging
          logging.info(pformat(asdict(cfg)))  # log the current configuration
      
          if cfg.display_data:
              # Use web visualization
              init_rerun_web(
                  session_name="teleoperation_web",
                  port=cfg.web_port,
                  open_browser=cfg.open_browser,
              )
      
          teleop = make_teleoperator_from_config(cfg.teleop)
          robot = make_robot_from_config(cfg.robot)
          teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
      
          teleop.connect()  # connect the teleoperation device
          robot.connect()  # connect the robot
      
          try:
              teleop_loop(
                  teleop=teleop,
                  robot=robot,
                  fps=cfg.fps,
                  display_data=cfg.display_data,
                  duration=cfg.teleop_time_s,
                  teleop_action_processor=teleop_action_processor,
                  robot_action_processor=robot_action_processor,
                  robot_observation_processor=robot_observation_processor,
              )
          except KeyboardInterrupt:
              pass  # swallow the Ctrl+C interrupt and exit cleanly
          finally:
              if cfg.display_data:
                  rr.rerun_shutdown()  # shut down the Rerun session
              teleop.disconnect()  # disconnect the teleoperation device
              robot.disconnect()  # disconnect the robot
      
      
      def main():
          teleoperate_web()
      
      
      if __name__ == "__main__":
          main()
      
  2. Run teleoperation:

    python -m lerobot.extra.lerobot_teleoperate_web \
      --robot.type=so101_follower \
      --robot.port=/dev/follower_arm \
      --robot.id=my_follower_arm \
      --robot.calibration_dir=./data/calibration \
      --robot.cameras="{ wrist_left: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front_rgb: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \
      --teleop.type=so101_leader \
      --teleop.port=/dev/leader_arm \
      --teleop.id=my_leader_arm \
      --teleop.calibration_dir=./data/calibration \
      --display_data=true \
      --web_port=9090
    
  3. By default, the browser opens http://localhost:9090/?url=ws://localhost:9877, where you can watch the arm state and the camera feeds

  4. Move the leader arm and the follower will mirror it; press Ctrl+C to exit
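The utilities in web_visualization_utils.py can also be smoke-tested on their own, without any hardware. A minimal sketch, assuming the extra folder is importable as lerobot.extra:

    import numpy as np

    from lerobot.extra.web_visualization_utils import init_rerun_web, log_rerun_data

    # Serve the viewer at http://localhost:9090, then log one synthetic frame
    init_rerun_web(session_name="smoke_test", port=9090, open_browser=False)
    log_rerun_data(
        observation={"front_rgb": np.zeros((3, 480, 640), dtype=np.uint8), "state": np.zeros(6)},
        action={"shoulder_pan.pos": 0.0},
    )
    input("Viewer running at http://localhost:9090; press Enter to quit")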

Recording a Dataset

  • LeRobot embodied-AI arm hands-on course, part 03: dataset recording and model training

  • Official guide on what makes a good dataset

  • Official dataset-recording tutorial

  • Workflow for each episode: wait for the program to announce the current episode, teleoperate the arm through the grasp with the leader arm, return the arm to its rest position, and end the episode; reset the scene while the program saves the data, then wait for the prompt for the next episode.

  • Keyboard controls:

    Key             | When to use                                        | Effect
    Right arrow (→) | During an episode, once the task has succeeded     | End the current episode early as a success, save its data, and move to the reset phase
    Left arrow (←)  | During an episode, after you made a mistake        | Discard and re-record the current episode; this take's data is thrown away
    ESC             | Any time                                           | Terminate the whole recording session; completed episodes are saved before exiting

Recording a Test Dataset

  • Number of episodes: 5
  • Task description (customize as needed): Put the red pepper toy in the cardboard box
  • Dataset save path: ./data/datasets/xiadengma/record-test-so101
lerobot-record \
      --robot.type=so101_follower \
      --robot.port=/dev/follower_arm \
      --robot.id=my_follower_arm \
      --robot.calibration_dir=./data/calibration \
      --robot.cameras="{ wrist_left: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front_rgb: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \
      --teleop.type=so101_leader \
      --teleop.port=/dev/leader_arm \
      --teleop.id=my_leader_arm \
      --teleop.calibration_dir=./data/calibration \
      --display_data=true \
      --dataset.repo_id=xiadengma/record-test-so101 \
      --dataset.num_episodes=5 \
      --dataset.single_task="Put the red pepper toy in the cardboard box" \
      --dataset.push_to_hub=false \
      --dataset.root=./data/datasets/xiadengma/record-test-so101
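Once recording finishes, the dataset can be opened from Python for a quick sanity check. A sketch using the same LeRobotDataset class the scripts in this post rely on; the repo id and root match the command above:

    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(
        "xiadengma/record-test-so101",
        root="./data/datasets/xiadengma/record-test-so101",
    )
    print(f"episodes: {dataset.num_episodes}, frames: {len(dataset)}, fps: {dataset.fps}")
    print(list(dataset.features))  # feature keys, including camera streams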
For the browser-based variant:

  1. In the extra folder under src/lerobot, add lerobot_record_web.py alongside the web_visualization_utils.py created earlier, with the following content:

    #!/usr/bin/env python3
    # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    """
    录制数据集 (使用 Web 可视化)。机器人动作可由遥操作或策略生成。
    
    Records a dataset (with Web visualization). Actions for the robot can be either generated by teleoperation or by a policy.
    
    然后在浏览器中打开 / Then open in browser: http://localhost:9090
    """
    
    import logging
    import time
    from dataclasses import asdict, dataclass, field
    from pathlib import Path
    from pprint import pformat
    from typing import Any
    
    from lerobot.cameras import (
        CameraConfig,
    )
    from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
    from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig
    from lerobot.configs import parser
    from lerobot.configs.policies import PreTrainedConfig
    from lerobot.datasets.image_writer import safe_stop_image_writer
    from lerobot.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
    from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
    from lerobot.datasets.video_utils import VideoEncodingManager
    from lerobot.policies.factory import make_policy, make_pre_post_processors
    from lerobot.policies.pretrained import PreTrainedPolicy
    from lerobot.processor import (
        PolicyAction,
        PolicyProcessorPipeline,
        RobotAction,
        RobotObservation,
        RobotProcessorPipeline,
        make_default_processors,
    )
    from lerobot.processor.rename_processor import rename_stats
    from lerobot.robots import (
        Robot,
        RobotConfig,
        bi_so100_follower,
        hope_jr,
        koch_follower,
        make_robot_from_config,
        so100_follower,
        so101_follower,
    )
    from lerobot.teleoperators import (
        Teleoperator,
        TeleoperatorConfig,
        bi_so100_leader,
        homunculus,
        koch_leader,
        make_teleoperator_from_config,
        so100_leader,
        so101_leader,
    )
    from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
    from lerobot.utils.constants import ACTION, OBS_STR
    from lerobot.utils.control_utils import (
        init_keyboard_listener,
        is_headless,
        predict_action,
        sanity_check_dataset_name,
        sanity_check_dataset_robot_compatibility,
    )
    from lerobot.utils.robot_utils import busy_wait
    from lerobot.utils.utils import (
        get_safe_torch_device,
        init_logging,
        log_say,
    )
    
    # Import the web visualization utilities
    from .web_visualization_utils import init_rerun_web, log_rerun_data
    
    
    @dataclass
    class DatasetRecordConfig:
        # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
        repo_id: str
        # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
        single_task: str
        # Root directory where the dataset will be stored (e.g. 'dataset/path').
        root: str | Path | None = None
        # Limit the frames per second.
        fps: int = 30
        # Number of seconds for data recording for each episode.
        episode_time_s: int | float = 60
        # Number of seconds for resetting the environment after each episode.
        reset_time_s: int | float = 60
        # Number of episodes to record.
        num_episodes: int = 50
        # Encode frames in the dataset into video
        video: bool = True
        # Upload dataset to Hugging Face hub.
        push_to_hub: bool = True
        # Upload on private repository on the Hugging Face hub.
        private: bool = False
        # Add tags to your dataset on the hub.
        tags: list[str] | None = None
        # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
        # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
        # and threads depends on your system. We recommend 4 threads per camera with 0 processes.
        # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
        num_image_writer_processes: int = 0
        # Number of threads writing the frames as png images on disk, per camera.
        # Too many threads might cause unstable teleoperation fps due to main thread being blocked.
        # Not enough threads might cause low camera fps.
        num_image_writer_threads_per_camera: int = 4
        # Number of episodes to record before batch encoding videos
        # Set to 1 for immediate encoding (default behavior), or higher for batched encoding
        video_encoding_batch_size: int = 1
        # Rename map for the observation to override the image and state keys
        rename_map: dict[str, str] = field(default_factory=dict)
    
        def __post_init__(self):
            if self.single_task is None:
                raise ValueError("You need to provide a task as argument in `single_task`.")
    
    
    @dataclass
    class RecordWebConfig:
        robot: RobotConfig
        dataset: DatasetRecordConfig
        # Whether to control the robot with a teleoperator
        teleop: TeleoperatorConfig | None = None
        # Whether to control the robot with a policy
        policy: PreTrainedConfig | None = None
        # Display all cameras on screen
        display_data: bool = False
        # Use vocal synthesis to read events.
        play_sounds: bool = True
        # Resume recording on an existing dataset.
        resume: bool = False
        # Web viewer port
        web_port: int = 9090
        # Auto open browser
        open_browser: bool = True
    
        def __post_init__(self):
            # HACK: We parse again the cli args here to get the pretrained path if there was one.
            policy_path = parser.get_path_arg("policy")
            if policy_path:
                cli_overrides = parser.get_cli_overrides("policy")
                self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
                self.policy.pretrained_path = policy_path
    
            if self.teleop is None and self.policy is None:
                raise ValueError("Choose a policy, a teleoperator or both to control the robot")
    
        @classmethod
        def __get_path_fields__(cls) -> list[str]:
            """This enables the parser to load config from the policy using `--policy.path=local/dir`"""  # 允许解析器通过 --policy.path 加载配置
            return ["policy"]
    
    
    """ --------------- record_loop() data flow --------------------------
          [ Robot ]
              V
        [ robot.get_observation() ] ---> raw_obs
              V
        [ robot_observation_processor ] ---> processed_obs
              V
        .-----( ACTION LOGIC )------------------.
        V                                       V
        [ From Teleoperator ]                   [ From Policy ]
        |                                       |
        |  [teleop.get_action] -> raw_action    |   [predict_action]
        |          |                            |          |
        |          V                            |          V
        | [teleop_action_processor]             |          |
        |          |                            |          |
        '---> processed_teleop_action           '---> processed_policy_action
        |                                       |
        '-------------------------.-------------'
                                  V
                      [ robot_action_processor ] --> robot_action_to_send
                                  V
                        [ robot.send_action() ] -- (Robot Executes)
                                  V
                        ( Save to Dataset )
                                  V
                      ( Rerun Log / Loop Wait )
    """
    
    
    @safe_stop_image_writer
    def record_loop(
        robot: Robot,
        events: dict,
        fps: int,
        teleop_action_processor: RobotProcessorPipeline[
            tuple[RobotAction, RobotObservation], RobotAction
        ],  # runs after teleop
        robot_action_processor: RobotProcessorPipeline[
            tuple[RobotAction, RobotObservation], RobotAction
        ],  # runs before robot
        robot_observation_processor: RobotProcessorPipeline[
            RobotObservation, RobotObservation
        ],  # runs after robot
        dataset: LeRobotDataset | None = None,
        teleop: Teleoperator | list[Teleoperator] | None = None,
        policy: PreTrainedPolicy | None = None,
        preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None,
        postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None,
        control_time_s: int | None = None,
        single_task: str | None = None,
        display_data: bool = False,
    ):
        if dataset is not None and dataset.fps != fps:
            raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
    
        teleop_arm = teleop_keyboard = None
        if isinstance(teleop, list):
            teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
            teleop_arm = next(
                (
                    t
                    for t in teleop
                    if isinstance(
                        t,
                        (so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
                    )
                ),
                None,
            )
    
            if not (teleop_arm and teleop_keyboard and len(teleop) == 2 and robot.name == "lekiwi_client"):
                raise ValueError(
                    "For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
                )
    
        # Reset policy and processor if they are provided
        if policy is not None and preprocessor is not None and postprocessor is not None:
            policy.reset()
            preprocessor.reset()
            postprocessor.reset()
    
        timestamp = 0
        start_episode_t = time.perf_counter()
        while timestamp < control_time_s:
            start_loop_t = time.perf_counter()
    
            if events["exit_early"]:
                events["exit_early"] = False
                break
    
            # Get robot observation
            obs = robot.get_observation()

            # Applies a pipeline to the raw robot observation, default is IdentityProcessor
            obs_processed = robot_observation_processor(obs)
    
            if policy is not None or dataset is not None:
                observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
    
            # Get action from either policy or teleop
            if policy is not None and preprocessor is not None and postprocessor is not None:
                action_values = predict_action(
                    observation=observation_frame,
                    policy=policy,
                    device=get_safe_torch_device(policy.config.device),
                    preprocessor=preprocessor,
                    postprocessor=postprocessor,
                    use_amp=policy.config.use_amp,
                    task=single_task,
                    robot_type=robot.robot_type,
                )
    
                action_names = dataset.features[ACTION]["names"]
                act_processed_policy: RobotAction = {
                    f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
                }
    
            elif policy is None and isinstance(teleop, Teleoperator):
                act = teleop.get_action()
    
                # Applies a pipeline to the raw teleop action, default is IdentityProcessor
                act_processed_teleop = teleop_action_processor((act, obs))
    
            elif policy is None and isinstance(teleop, list):
                arm_action = teleop_arm.get_action()
                arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
                keyboard_action = teleop_keyboard.get_action()
                base_action = robot._from_keyboard_to_base_action(keyboard_action)
                act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
                act_processed_teleop = teleop_action_processor((act, obs))
            else:
                logging.info(
                    "No policy or teleoperator provided, skipping action generation."
                    "This is likely to happen when resetting the environment without a teleop device."
                    "The robot won't be at its rest position at the start of the next episode."
                )
                continue
    
            # Applies a pipeline to the action, default is IdentityProcessor
            if policy is not None and act_processed_policy is not None:
                action_values = act_processed_policy
                robot_action_to_send = robot_action_processor((act_processed_policy, obs))
            else:
                action_values = act_processed_teleop
                robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
    
            # Send action to robot
            # Action can eventually be clipped using `max_relative_target`,
            # so the action actually sent is saved in the dataset. action = postprocessor.process(action)
            # TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
            _sent_action = robot.send_action(robot_action_to_send)
    
            # Write to dataset
            if dataset is not None:
                action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
                frame = {**observation_frame, **action_frame, "task": single_task}
                dataset.add_frame(frame)
    
            if display_data:
                log_rerun_data(observation=obs_processed, action=action_values)
    
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)
    
            timestamp = time.perf_counter() - start_episode_t
    
    
    @parser.wrap()
    def record_web(cfg: RecordWebConfig) -> LeRobotDataset:
        init_logging()
        logging.info(pformat(asdict(cfg)))
    
        if cfg.display_data:
            # Use web visualization
            init_rerun_web(
                session_name="recording_web",
                port=cfg.web_port,
                open_browser=cfg.open_browser,
            )
    
        robot = make_robot_from_config(cfg.robot)
        teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
    
        teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
    
        dataset_features = combine_feature_dicts(
            aggregate_pipeline_dataset_features(
                pipeline=teleop_action_processor,
                initial_features=create_initial_features(
                    action=robot.action_features
                ),  # TODO(steven, pepijn): in future this should be come from teleop or policy
                use_videos=cfg.dataset.video,
            ),
            aggregate_pipeline_dataset_features(
                pipeline=robot_observation_processor,
                initial_features=create_initial_features(observation=robot.observation_features),
                use_videos=cfg.dataset.video,
            ),
        )
    
        if cfg.resume:
            dataset = LeRobotDataset(
                cfg.dataset.repo_id,
                root=cfg.dataset.root,
                batch_encoding_size=cfg.dataset.video_encoding_batch_size,
            )
    
            if hasattr(robot, "cameras") and len(robot.cameras) > 0:
                dataset.start_image_writer(
                    num_processes=cfg.dataset.num_image_writer_processes,
                    num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
                )
            sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
        else:
            # Create empty dataset or load existing saved episodes
            sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
            dataset = LeRobotDataset.create(
                cfg.dataset.repo_id,
                cfg.dataset.fps,
                root=cfg.dataset.root,
                robot_type=robot.name,
                features=dataset_features,
                use_videos=cfg.dataset.video,
                image_writer_processes=cfg.dataset.num_image_writer_processes,
                image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
                batch_encoding_size=cfg.dataset.video_encoding_batch_size,
            )
    
        # Load pretrained policy
        policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
        preprocessor = None
        postprocessor = None
        if cfg.policy is not None:
            preprocessor, postprocessor = make_pre_post_processors(
                policy_cfg=cfg.policy,
                pretrained_path=cfg.policy.pretrained_path,
                dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
                preprocessor_overrides={
                    "device_processor": {"device": cfg.policy.device},
                    "rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
                },
            )
    
        robot.connect()
        if teleop is not None:
            teleop.connect()
    
        listener, events = init_keyboard_listener()
    
        with VideoEncodingManager(dataset):
            recorded_episodes = 0
            while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
                log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
                record_loop(
                    robot=robot,
                    events=events,
                    fps=cfg.dataset.fps,
                    teleop_action_processor=teleop_action_processor,
                    robot_action_processor=robot_action_processor,
                    robot_observation_processor=robot_observation_processor,
                    teleop=teleop,
                    policy=policy,
                    preprocessor=preprocessor,
                    postprocessor=postprocessor,
                    dataset=dataset,
                    control_time_s=cfg.dataset.episode_time_s,
                    single_task=cfg.dataset.single_task,
                    display_data=cfg.display_data,
                )
    
                # Execute a few seconds without recording to give time to manually reset the environment
                # Skip reset for the last episode to be recorded
                if not events["stop_recording"] and (
                    (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
                ):
                    log_say("Reset the environment", cfg.play_sounds)
                    record_loop(
                        robot=robot,
                        events=events,
                        fps=cfg.dataset.fps,
                        teleop_action_processor=teleop_action_processor,
                        robot_action_processor=robot_action_processor,
                        robot_observation_processor=robot_observation_processor,
                        teleop=teleop,
                        control_time_s=cfg.dataset.reset_time_s,
                        single_task=cfg.dataset.single_task,
                        display_data=cfg.display_data,
                    )
    
                if events["rerecord_episode"]:
                    log_say("Re-record episode", cfg.play_sounds)
                    events["rerecord_episode"] = False
                    events["exit_early"] = False
                    dataset.clear_episode_buffer()
                    continue
    
                dataset.save_episode()
                recorded_episodes += 1
    
        log_say("Stop recording", cfg.play_sounds, blocking=True)
    
        robot.disconnect()
        if teleop is not None:
            teleop.disconnect()
    
        if not is_headless() and listener is not None:
            listener.stop()
    
        if cfg.dataset.push_to_hub:
            dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
    
        log_say("Exiting", cfg.play_sounds)
        return dataset
    
    
    def main():
        record_web()
    
    
    if __name__ == "__main__":
        main()
    
  2. Run recording:

    python -m lerobot.extra.lerobot_record_web \
          --robot.type=so101_follower \
          --robot.port=/dev/follower_arm \
          --robot.id=my_follower_arm \
          --robot.calibration_dir=./data/calibration \
          --robot.cameras="{ wrist_left: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front_rgb: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \
          --teleop.type=so101_leader \
          --teleop.port=/dev/leader_arm \
          --teleop.id=my_leader_arm \
          --teleop.calibration_dir=./data/calibration \
          --display_data=true \
          --dataset.repo_id=xiadengma/record-test-so101 \
          --dataset.num_episodes=3 \
          --dataset.single_task="Put the red pepper toy in the cardboard box" \
          --web_port=9090 \
          --dataset.push_to_hub=false \
          --dataset.root=./data/datasets/xiadengma/record-test-so101
    
  3. By default, the browser opens http://localhost:9090/?url=ws://localhost:9877, where you can watch the arm state and the camera feeds

Visualizing the Dataset

We can visualize the recorded dataset.

  • View a specific episode:
    lerobot-dataset-viz \
      --repo-id xiadengma/record-test-so101 \
      --root ./data/datasets/xiadengma/record-test-so101 \
      --episode-index 0
    
  • View multiple episodes:
    lerobot-dataset-viz \
      --repo-id xiadengma/record-test-so101 \
      --root ./data/datasets/xiadengma/record-test-so101 \
      --episodes 0 1 2 3 4
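An episode can also be sliced programmatically. A minimal sketch using the same metadata fields as the web script below; the repo id and root match the earlier recording commands:

    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(
        "xiadengma/record-test-so101",
        root="./data/datasets/xiadengma/record-test-so101",
    )

    # Episode boundaries come from the dataset metadata, as in the script below
    ep = 0
    from_idx = dataset.meta.episodes["dataset_from_index"][ep]
    to_idx = dataset.meta.episodes["dataset_to_index"][ep]
    print(f"episode {ep}: frames {from_idx}..{to_idx - 1}")
    print(list(dataset[from_idx].keys()))  # keys of the first frame's dict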
    
For a browser-based viewer:

  1. Add lerobot_dataset_viz_web.py to the extra folder under src/lerobot, with the following content:

    • lerobot_dataset_viz_web.py
    #!/usr/bin/env python
    
    # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    """
    使用 Web 浏览器可视化 LeRobotDataset 中任意 episode 的所有帧数据。
    Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset in a web browser.
    
    注意 Note:
        - Episode 的最后一帧不一定对应最终状态 / The last frame doesn't always correspond to a final state
        - 图像可能存在压缩伪影 / Images may show compression artifacts from mp4 encoding
    
    访问 Access:
        浏览器打开 / Open in browser: http://localhost:PORT
    """
    
    import argparse
    import gc
    import logging
    from collections.abc import Iterator
    from pathlib import Path
    
    import numpy as np
    import rerun as rr
    import torch
    import torch.utils.data
    import tqdm
    
    from lerobot.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
    
    
    class EpisodeSampler(torch.utils.data.Sampler):
        """用于采样单个 episode 的所有帧 / Sampler for all frames of a single episode."""
    
        def __init__(self, dataset: LeRobotDataset, episode_index: int):
            from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
            to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
            self.frame_ids = range(from_idx, to_idx)
    
        def __iter__(self) -> Iterator:
            return iter(self.frame_ids)
    
        def __len__(self) -> int:
            return len(self.frame_ids)
    
    
    def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
        """
        将 PyTorch CHW float32 图像转换为 NumPy HWC uint8 格式。
        Convert PyTorch CHW float32 image to NumPy HWC uint8 format.
        """
        assert chw_float32_torch.dtype == torch.float32
        assert chw_float32_torch.ndim == 3
        c, h, w = chw_float32_torch.shape
        assert c < h and c < w, (
            f"期望通道优先格式,但得到 / expect channel first images, but got {chw_float32_torch.shape}"
        )
        hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
        return hwc_uint8_numpy
    
    
    def visualize_episode(
        dataset: LeRobotDataset,
        episode_index: int,
        batch_size: int = 32,
        num_workers: int = 0,
    ) -> None:
        """
        在 Rerun 中可视化单个 episode 的所有帧。
        Visualize all frames of a single episode in Rerun.
    
        Args:
            dataset: LeRobot 数据集 / LeRobot dataset
            episode_index: Episode 索引 / Episode index
            batch_size: 批处理大小 / Batch size for dataloader
            num_workers: 数据加载进程数 / Number of worker processes
        """
        logging.info(f"📊 Loading the dataloader for episode {episode_index}")
    
        episode_sampler = EpisodeSampler(dataset, episode_index)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            sampler=episode_sampler,
        )
    
        total_frames = len(episode_sampler)
        logging.info(f"📈 Episode {episode_index} has {total_frames} frames")
    
        # Log data to Rerun
        for batch in tqdm.tqdm(dataloader, total=len(dataloader), desc=f"Episode {episode_index}"):
            # Iterate over each frame in the batch
            for i in range(len(batch["index"])):
                rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
                rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
    
                # Display camera images
                for key in dataset.meta.camera_keys:
                    rr.log(f"cameras/{key}", rr.Image(to_hwc_uint8_numpy(batch[key][i])))
    
                # Display each dimension of the action space
                if ACTION in batch:
                    for dim_idx, val in enumerate(batch[ACTION][i]):
                        rr.log(f"{ACTION}/dim_{dim_idx}", rr.Scalar(val.item()))
    
                # Display each dimension of the observed state space
                if OBS_STATE in batch:
                    for dim_idx, val in enumerate(batch[OBS_STATE][i]):
                        rr.log(f"state/dim_{dim_idx}", rr.Scalar(val.item()))
    
                # Display the done flag
                if DONE in batch:
                    rr.log(DONE, rr.Scalar(batch[DONE][i].item()))
    
                # Display the reward
                if REWARD in batch:
                    rr.log(REWARD, rr.Scalar(batch[REWARD][i].item()))
    
                # Display the success flag
                if "next.success" in batch:
                    rr.log("success", rr.Scalar(batch["next.success"][i].item()))
    
        logging.info(f"✅ Episode {episode_index} 可视化完成 / Episode {episode_index} visualization complete")
    
    
    def visualize_dataset_web(
        dataset: LeRobotDataset,
        episode_indices: list[int],
        batch_size: int = 32,
        num_workers: int = 0,
        port: int = 9090,
        open_browser: bool = True,
        memory_limit: str = "25%",
    ) -> None:
        """
        使用 Web 界面可视化数据集。
        Visualize dataset using web interface.
    
        Args:
            dataset: LeRobot 数据集 / LeRobot dataset
            episode_indices: 要可视化的 episode 索引列表 / List of episode indices to visualize
            batch_size: 批处理大小 / Batch size for dataloader
            num_workers: 数据加载进程数 / Number of worker processes
            port: Web 服务器端口 / Web server port
            open_browser: 是否自动打开浏览器 / Whether to automatically open browser
            memory_limit: Rerun 内存限制 / Memory limit for Rerun
        """
        repo_id = dataset.repo_id
    
        # Initialize the Rerun web interface
        logging.info("🌐 Starting the Rerun web interface")
        logging.info(f"📍 Access URL: http://localhost:{port}")
        logging.info(f"💾 Memory limit: {memory_limit}")
    
        rr.init(f"{repo_id}_web_viz", spawn=False)
    
        # Manually call the garbage collector to avoid blocking
        gc.collect()
    
        # Start the web server
        rr.serve_web(
            open_browser=open_browser,
            web_port=port,
            server_memory_limit=memory_limit,
        )
    
        # 可视化每个 episode / Visualize each episode
        for episode_idx in episode_indices:
            if episode_idx >= len(dataset.meta.episodes):
                logging.warning(
                    f"⚠️  Episode {episode_idx} 不存在,跳过 / Episode {episode_idx} does not exist, skipping"
                )
                continue
    
            # 为每个 episode 创建记录路径 / Create recording path for each episode
            rr.log(f"episode_{episode_idx}/info", rr.TextLog(f"Episode {episode_idx}"), static=True)
    
            # 设置时间序列标记当前 episode / Set time sequence for current episode
            rr.set_time_sequence("episode", episode_idx)
    
            visualize_episode(
                dataset=dataset,
                episode_index=episode_idx,
                batch_size=batch_size,
                num_workers=num_workers,
            )
    
        logging.info("✨ 所有 episode 可视化完成 / All episodes visualization complete")
        logging.info("🌐 Web 服务器持续运行中,按 Ctrl+C 退出 / Web server running, press Ctrl+C to exit")
    
        # 保持服务器运行 / Keep server running
        try:
            import time
    
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            logging.info("👋 收到 Ctrl-C,正在退出 / Ctrl-C received, exiting")
    
    
    def main():
        parser = argparse.ArgumentParser(
            description="使用 Web 浏览器可视化 LeRobot 数据集 / Visualize LeRobot dataset in web browser",
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )
    
        parser.add_argument(
            "--repo-id",
            type=str,
            required=True,
            help="数据集仓库 ID / Dataset repository ID (e.g. `lerobot/pusht` or `xiadengma/record-test-so101`)",
        )
    
        # Episode 选择参数 / Episode selection arguments
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument(
            "--episode-index",
            type=int,
            help="要可视化的单个 episode 索引 / Single episode index to visualize",
        )
        group.add_argument(
            "--episodes",
            type=int,
            nargs="+",
            help="要可视化的多个 episode 索引 / Multiple episode indices to visualize (e.g. 0 1 2 3)",
        )
    
        parser.add_argument(
            "--root",
            type=Path,
            default=None,
            help="本地数据集根目录 / Root directory for local dataset (e.g. `--root ./data/datasets/xiadengma/record-test-so101`)",
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            default=32,
            help="DataLoader 批处理大小 / Batch size for DataLoader (default: 32)",
        )
        parser.add_argument(
            "--num-workers",
            type=int,
            default=4,
            help="DataLoader 进程数 / Number of DataLoader worker processes (default: 4)",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=9090,
            help="Web 服务器端口 / Web server port (default: 9090)",
        )
        parser.add_argument(
            "--open-browser",
            type=lambda x: str(x).lower() in ("true", "1", "yes"),
            default=True,
            help="是否自动打开浏览器 / Whether to automatically open browser (default: True)",
        )
        parser.add_argument(
            "--memory-limit",
            type=str,
            default="25%",
            help="Rerun 内存限制 / Memory limit for Rerun (default: 25%%)",
        )
        parser.add_argument(
            "--tolerance-s",
            type=float,
            default=1e-4,
            help="时间戳容差(秒)/ Tolerance in seconds for timestamps (default: 1e-4)",
        )
    
        args = parser.parse_args()
    
        # 配置日志 / Configure logging
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    
        # 确定要可视化的 episode 列表 / Determine episode list
        if args.episode_index is not None:
            episode_indices = [args.episode_index]
        else:
            episode_indices = args.episodes
    
        logging.info("=" * 80)
        logging.info("🤖 LeRobot 数据集 Web 可视化工具 / LeRobot Dataset Web Visualizer")
        logging.info("=" * 80)
        logging.info(f"📦 数据集 / Dataset: {args.repo_id}")
        logging.info(f"📂 根目录 / Root: {args.root if args.root else 'HuggingFace Cache'}")
        logging.info(f"📊 Episodes: {episode_indices}")
        logging.info("=" * 80)
    
        # 加载数据集 / Load dataset
        logging.info("🔄 正在加载数据集 / Loading dataset...")
        dataset = LeRobotDataset(
            repo_id=args.repo_id,
            episodes=episode_indices,
            root=args.root,
            tolerance_s=args.tolerance_s,
        )
    
        logging.info("✅ 数据集加载成功 / Dataset loaded successfully")
        logging.info(f"📈 数据集总帧数 / Total frames: {len(dataset)}")
        logging.info(f"📹 相机数量 / Number of cameras: {len(dataset.meta.camera_keys)}")
        logging.info(f"🎥 相机列表 / Camera keys: {dataset.meta.camera_keys}")
    
        # 启动 Web 可视化 / Start web visualization
        visualize_dataset_web(
            dataset=dataset,
            episode_indices=episode_indices,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            port=args.port,
            open_browser=args.open_browser,
            memory_limit=args.memory_limit,
        )
    
    
    if __name__ == "__main__":
        main()
    
  2. View the dataset (for a programmatic alternative, see the sketch after this list):

    • View a single episode:
      python -m lerobot.extra.lerobot_dataset_viz_web \
      --repo-id xiadengma/record-test-so101 \
      --root ./data/datasets/xiadengma/record-test-so101 \
      --episode-index 0 \
      --port 9090
      
    • View multiple episodes:
      python -m lerobot.extra.lerobot_dataset_viz_web \
      --repo-id xiadengma/record-test-so101 \
      --root ./data/datasets/xiadengma/record-test-so101 \
      --episodes 0 1 2 3 4 \
      --port 9090
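
    The web visualizer is a thin wrapper around LeRobotDataset, so you can also load and inspect the same data directly in Python. A minimal sketch (the repo ID and root are the example values used above; the import path follows the module layout used elsewhere in this post):

      from lerobot.datasets.lerobot_dataset import LeRobotDataset

      # Load only episode 0 from the local dataset root.
      dataset = LeRobotDataset(
          repo_id="xiadengma/record-test-so101",
          episodes=[0],
          root="./data/datasets/xiadengma/record-test-so101",
      )
      print(f"Total frames: {len(dataset)}")
      print(f"Camera keys: {dataset.meta.camera_keys}")
      frame = dataset[0]  # a single frame as a dict of tensors
      print(sorted(frame.keys()))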
      

Replay a Dataset
#

This feature is unstable; you can skip it or give it a try.

You can also have the arm replay episodes from a recorded dataset:

  1. Replay the first episode of the test dataset recorded earlier:
    lerobot-replay \
      --robot.type=so101_follower \
      --robot.port=/dev/follower_arm \
      --robot.id=my_follower_arm \
      --dataset.repo_id=xiadengma/record-test-so101 \
      --dataset.root=./data/datasets/xiadengma/record-test-so101 \
      --dataset.episode=0
    
  2. You should see the arm repeat the motions from the dataset.

Record the Full Dataset
#

Now we record the dataset that will be used for training and, after training, for inference.

  • Per-episode flow: wait for the program to prompt for the current episode; teleoperate the follower arm with the leader arm to perform the grasp; after the grasp, return the arm to its rest position and end the episode; reset the scene while the program saves the data; then wait for the prompt for the next episode.
  • Recording requirements:
    1. Quantity: at least 50 episodes, to ensure the data is sufficient
    2. Repetitions: record each position 10 times, to improve diversity and robustness
    3. Positions: use at least 5 different positions, to cover more motion scenarios
  • Keyboard controls:

    | Key | When to use | Effect |
    | --- | --- | --- |
    | Right arrow (→) | During the current episode, after you have completed the task | Ends the current episode early as a success, saves the data, then enters the reset phase. |
    | Left arrow (←) | During the current episode, after you have made a mistake | Discards and restarts the current episode; its recorded data is thrown away. |
    | ESC | At any time | Terminates the entire recording session; the program saves the completed episodes and exits. |
  1. Record the dataset:

    • Number of episodes: 50
    • Task (customize as needed): Put the red pepper toy in the cardboard box
    • Save path (customize as needed): ./data/datasets/xiadengma/so101-red-pepper
    python -m lerobot.extra.lerobot_record_web \
          --robot.type=so101_follower \
          --robot.port=/dev/follower_arm \
          --robot.id=my_follower_arm \
          --robot.calibration_dir=./data/calibration \
          --robot.cameras="{ wrist_left: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front_rgb: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \
          --teleop.type=so101_leader \
          --teleop.port=/dev/leader_arm \
          --teleop.id=my_leader_arm \
          --teleop.calibration_dir=./data/calibration \
          --display_data=true \
          --dataset.repo_id=xiadengma/so101-red-pepper \
          --dataset.num_episodes=50 \
          --dataset.single_task="Put the red pepper toy in the cardboard box" \
          --web_port=9090 \
          --dataset.push_to_hub=false \
          --dataset.root=./data/datasets/xiadengma/so101-red-pepper
    
  2. Resume recording (for reference):

    1. While recording episode 15, the program failed with: Waiting for image writer to terminate...TimeoutError: Timed out waiting for frame from camera OpenCVCamera(8) after 200 ms. Read thread alive: True.

      • Cause: the wrist camera's cable was tugged at its connector while the arm was moving, making the camera connection unstable.

      • Fix: reconnect the camera and leave enough cable slack so the connection stays stable, then run the command below to resume recording (a camera sanity-check sketch follows this section).

        python -m lerobot.extra.lerobot_record_web \
              --robot.type=so101_follower \
              --robot.port=/dev/follower_arm \
              --robot.id=my_follower_arm \
              --robot.calibration_dir=./data/calibration \
              --robot.cameras="{ wrist_left: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front_rgb: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \
              --teleop.type=so101_leader \
              --teleop.port=/dev/leader_arm \
              --teleop.id=my_leader_arm \
              --teleop.calibration_dir=./data/calibration \
              --display_data=true \
              --dataset.repo_id=xiadengma/so101-red-pepper \
              --dataset.num_episodes=35 \
              --dataset.single_task="Put the red pepper toy in the cardboard box" \
              --web_port=9090 \
              --dataset.push_to_hub=false \
              --dataset.root=./data/datasets/xiadengma/so101-red-pepper \
              --resume=true
        
        • Checkpoints are created automatically during recording.
        • If recording is interrupted, you can resume by re-running the same command with --resume=true.
        • ⚠️ Important: when resuming, set --dataset.num_episodes to the number of additional episodes to record, not the dataset's target total.

        After recording finished, the dataset size was about 390 MB.
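
        Before resuming, it also helps to verify that both cameras deliver frames reliably. A minimal sketch using OpenCV (camera indices 2 and 8 match the camera config above; the frame count and timing are illustrative):

          import time

          import cv2

          for index in (2, 8):
              cap = cv2.VideoCapture(index)
              cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
              cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
              ok_frames = 0
              start = time.perf_counter()
              for _ in range(90):  # ~3 seconds at 30 fps
                  ok, _frame = cap.read()
                  ok_frames += int(ok)
              elapsed = time.perf_counter() - start
              print(f"camera {index}: {ok_frames}/90 frames in {elapsed:.1f}s")
              cap.release()

        If either camera drops frames here, reseat its cable before restarting the recording.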

Training
#

Add swanlab_utils.py, train_swanlab.py, and lerobot_train_swanlab.py under the extra folder:

  • swanlab_utils.py

    #!/usr/bin/env python
    
    # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    import logging
    import re
    from glob import glob
    from pathlib import Path
    from typing import TYPE_CHECKING
    
    from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
    from termcolor import colored
    
    from lerobot.utils.constants import PRETRAINED_MODEL_DIR
    
    if TYPE_CHECKING:
        from lerobot.extra.train_swanlab import TrainPipelineSwanLabConfig
    
    
    def cfg_to_group(cfg: "TrainPipelineSwanLabConfig", return_list: bool = False) -> list[str] | str:
        """Return a group name for logging. Optionally returns group name as list."""
        lst = [
            f"policy:{cfg.policy.type}",
            f"seed:{cfg.seed}",
        ]
        if cfg.dataset is not None:
            lst.append(f"dataset:{cfg.dataset.repo_id}")
        if cfg.env is not None:
            lst.append(f"env:{cfg.env.type}")
        return lst if return_list else "-".join(lst)
    
    
    def get_swanlab_run_id_from_filesystem(log_dir: Path) -> str:
        # Get the SwanLab run ID.
        paths = glob(str(log_dir / "swanlab/latest-run/run-*"))
        if len(paths) != 1:
            raise RuntimeError("Couldn't get the previous SwanLab run ID for run resumption.")
        match = re.search(r"run-([^\.]+).swanlab", paths[0].split("/")[-1])
        if match is None:
            raise RuntimeError("Couldn't get the previous SwanLab run ID for run resumption.")
        swanlab_run_id = match.groups(0)[0]
        return swanlab_run_id
    
    
    def get_safe_swanlab_artifact_name(name: str):
        """SwanLab artifacts don't accept ":" or "/" in their name."""
        return name.replace(":", "_").replace("/", "_")
    
    
    class SwanLabLogger:
        """A helper class to log object using swanlab."""
    
        def __init__(self, cfg: "TrainPipelineSwanLabConfig"):
            self.cfg = cfg.swanlab
            self.log_dir = cfg.output_dir
            self.job_name = cfg.job_name
            self.env_fps = cfg.env.fps if cfg.env else None
            self._group = cfg_to_group(cfg)
    
            import swanlab
    
            swanlab_run_id = (
                cfg.swanlab.run_id
                if cfg.swanlab.run_id
                else get_swanlab_run_id_from_filesystem(self.log_dir)
                if cfg.resume
                else None
            )
            self._run = swanlab.init(
                project=self.cfg.project,
                experiment_name=swanlab_run_id,
                description=self.cfg.notes,
                tags=cfg_to_group(cfg, return_list=True),
                logdir=str(self.log_dir),
                config=cfg.to_dict(),
                save_code=False,
                resume=cfg.resume,
                mode=self.cfg.mode if self.cfg.mode in ["cloud", "offline", "local", "disabled"] else "cloud",
            )
            run_id = self._run.public.run_id
            # NOTE: We will override the cfg.swanlab.run_id with the swanlab run id.
            # This is because we want to be able to resume the run from the swanlab run id.
            cfg.swanlab.run_id = run_id
            # Handle custom step key for rl asynchronous training.
            self._swanlab_custom_step_key: set[str] | None = None
            print(colored("Logs will be synced with swanlab.", "blue", attrs=["bold"]))
            logging.info(
                f"Track this run --> {colored(self._run.public.cloud.experiment_url, 'yellow', attrs=['bold'])}"
            )
            self._swanlab = swanlab
    
        def log_policy(self, checkpoint_dir: Path):
            """Checkpoints the policy to swanlab."""
            if self.cfg.disable_artifact:
                return
    
            step_id = checkpoint_dir.name
            artifact_name = f"{self._group}-{step_id}"
            artifact_name = get_safe_swanlab_artifact_name(artifact_name)
            # SwanLab doesn't have direct artifact logging like wandb
            # We'll log the model file path as a text log for now
            model_path = str(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
            # Checkpoint dir names are zero-padded step numbers (e.g. "0100000"); swanlab expects an int step.
            self._swanlab.log({"model_checkpoint": self._swanlab.Text(model_path)}, step=int(step_id))
    
        def log_dict(
            self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
        ):
            if mode not in {"train", "eval"}:
                raise ValueError(mode)
            if step is None and custom_step_key is None:
                raise ValueError("Either step or custom_step_key must be provided.")
    
            # NOTE: This is not simple. SwanLab step must always monotonically increase and it
            # increases with each swanlab.log call, but in the case of asynchronous RL for example,
            # multiple time steps is possible. For example, the interaction step with the environment,
            # the training step, the evaluation step, etc. So we need to define a custom step key
            # to log the correct step for each metric.
            if custom_step_key is not None:
                if self._swanlab_custom_step_key is None:
                    self._swanlab_custom_step_key = set()
                new_custom_key = f"{mode}/{custom_step_key}"
                if new_custom_key not in self._swanlab_custom_step_key:
                    self._swanlab_custom_step_key.add(new_custom_key)
    
            for k, v in d.items():
                if not isinstance(v, (int, float, str)):
                    logging.warning(
                        f'SwanLab logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
                    )
                    continue
    
                # Do not log the custom step key itself.
                if self._swanlab_custom_step_key is not None and k in self._swanlab_custom_step_key:
                    continue
    
                if custom_step_key is not None:
                    value_custom_step = d.get(custom_step_key)
                    if value_custom_step is None:
                        logging.warning(
                            f'Custom step key "{custom_step_key}" not found in the dictionary. Skipping logging for this key.'
                        )
                        continue
                    data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
                    self._swanlab.log(data)
                    continue
    
                self._swanlab.log(data={f"{mode}/{k}": v}, step=step)
    
        def log_video(self, video_path: str, step: int, mode: str = "train"):
            if mode not in {"train", "eval"}:
                raise ValueError(mode)
    
            # SwanLab media logging - using Media.Video for video logging
            swanlab_video = self._swanlab.Video(video_path, fps=self.env_fps)
            self._swanlab.log({f"{mode}/video": swanlab_video}, step=step)
    
  • train_swanlab.py

    #!/usr/bin/env python
    
    # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    """
    SwanLab 训练配置扩展 - 扩展 TrainPipelineConfig 以支持 SwanLab
    
    这个模块为 LeRobot 训练管道提供 SwanLab 日志记录支持。
    它继承了原始的 TrainPipelineConfig,并添加了 tracker 和 swanlab 配置字段。
    """
    
    from dataclasses import dataclass, field
    
    from lerobot.configs.train import TrainPipelineConfig
    from lerobot.extra.default_swanlab import SwanLabConfig
    
    
    @dataclass
    class TrainPipelineSwanLabConfig(TrainPipelineConfig):
        """支持 SwanLab 的训练管道配置
    
        继承自 TrainPipelineConfig,添加了以下字段:
            tracker: 日志跟踪器选择,可选值: 'wandb', 'swanlab', 'both', 'none'
            swanlab: SwanLab 配置对象
    
        使用示例:
            ```python
            config = TrainPipelineSwanLabConfig(
                dataset=DatasetConfig(repo_id="my/dataset"),
                tracker="swanlab",
                swanlab=SwanLabConfig(project="my-project", mode="cloud"),
            )
            ```
        """
    
        # Tracker selection: 'wandb', 'swanlab', 'both', or 'none'
        # 跟踪器选择: 'wandb', 'swanlab', 'both', 或 'none'
        tracker: str = "wandb"
    
        # SwanLab configuration
        # SwanLab 配置
        swanlab: SwanLabConfig = field(default_factory=SwanLabConfig)
    
  • lerobot_train_swanlab.py

    #!/usr/bin/env python
    
    # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    # Training script: builds the dataset/env/policy and optimizer, runs the training loop,
    # and periodically evaluates and saves checkpoints.
    # Example usage of the SwanLab-specific flags:
    #   python -m lerobot.extra.lerobot_train_swanlab \
    #       --tracker=swanlab \
    #       --swanlab.project=my_lerobot \
    #       --swanlab.mode=cloud
    import logging
    import time
    from contextlib import nullcontext
    from pprint import pformat
    from typing import Any
    
    import torch
    from termcolor import colored
    from torch.amp import GradScaler
    from torch.optim import Optimizer
    
    from lerobot.configs import parser
    from lerobot.datasets.factory import make_dataset
    from lerobot.datasets.sampler import EpisodeAwareSampler
    from lerobot.datasets.utils import cycle
    from lerobot.envs.factory import make_env
    from lerobot.envs.utils import close_envs
    from lerobot.extra.swanlab_utils import SwanLabLogger
    from lerobot.extra.train_swanlab import TrainPipelineSwanLabConfig
    from lerobot.optim.factory import make_optimizer_and_scheduler
    from lerobot.policies.factory import make_policy, make_pre_post_processors
    from lerobot.policies.pretrained import PreTrainedPolicy
    from lerobot.policies.utils import get_device_from_parameters
    from lerobot.rl.wandb_utils import WandBLogger
    from lerobot.scripts.lerobot_eval import eval_policy_all
    from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
    from lerobot.utils.random_utils import set_seed
    from lerobot.utils.train_utils import (
        get_step_checkpoint_dir,
        get_step_identifier,
        load_training_state,
        save_checkpoint,
        update_last_checkpoint,
    )
    from lerobot.utils.utils import (
        format_big_number,
        get_safe_torch_device,
        has_method,
        init_logging,
    )
    
    
    def update_policy(
        train_metrics: MetricsTracker,
        policy: PreTrainedPolicy,
        batch: Any,
        optimizer: Optimizer,
        grad_clip_norm: float,
        grad_scaler: GradScaler,
        lr_scheduler=None,
        use_amp: bool = False,
        lock=None,
    ) -> tuple[MetricsTracker, dict]:
        """
        Performs a single training step to update the policy's weights.
    
        This function executes the forward and backward passes, clips gradients, and steps the optimizer and
        learning rate scheduler. It also handles mixed-precision training via a GradScaler.
    
        Args:
            train_metrics: A MetricsTracker instance to record training statistics.
            policy: The policy model to be trained.
            batch: A batch of training data.
            optimizer: The optimizer used to update the policy's parameters.
            grad_clip_norm: The maximum norm for gradient clipping.
            grad_scaler: The GradScaler for automatic mixed-precision training.
            lr_scheduler: An optional learning rate scheduler.
            use_amp: A boolean indicating whether to use automatic mixed precision.
            lock: An optional lock for thread-safe optimizer updates.
    
        Returns:
            A tuple containing:
            - The updated MetricsTracker with new statistics for this step.
            - A dictionary of outputs from the policy's forward pass, for logging purposes.
        """
        start_time = time.perf_counter()
        device = get_device_from_parameters(policy)
        policy.train()
        with torch.autocast(device_type=device.type) if use_amp else nullcontext():
            loss, output_dict = policy.forward(batch)
            # TODO(rcadene): policy.unnormalize_outputs(out_dict)
        grad_scaler.scale(loss).backward()
    
        # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
        # 在裁剪梯度之前反缩放优化器参数的梯度
        grad_scaler.unscale_(optimizer)
    
        grad_norm = torch.nn.utils.clip_grad_norm_(
            policy.parameters(),
            grad_clip_norm,
            error_if_nonfinite=False,
        )
    
        # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # 优化器梯度已反缩放,scaler.step 不再反缩放
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        # 若梯度含 inf/NaN,会跳过 optimizer.step()
        with lock if lock is not None else nullcontext():
            grad_scaler.step(optimizer)
        # Updates the scale for next iteration.
        grad_scaler.update()
    
        optimizer.zero_grad()
    
        # Step through pytorch scheduler at every batch instead of epoch
        # 每个 batch 调度学习率而非每个 epoch
        if lr_scheduler is not None:
            lr_scheduler.step()
    
        if has_method(policy, "update"):
            # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
            policy.update()
    
        train_metrics.loss = loss.item()
        train_metrics.grad_norm = grad_norm.item()
        train_metrics.lr = optimizer.param_groups[0]["lr"]
        train_metrics.update_s = time.perf_counter() - start_time
        return train_metrics, output_dict
    
    
    @parser.wrap()
    def train(cfg: TrainPipelineSwanLabConfig):
        """
        Main function to train a policy with SwanLab support.
    
        This function orchestrates the entire training pipeline, including:
        - Setting up logging, seeding, and device configuration.
        - Creating the dataset, evaluation environment (if applicable), policy, and optimizer.
        - Handling resumption from a checkpoint.
        - Running the main training loop, which involves fetching data batches and calling `update_policy`.
        - Periodically logging metrics, saving model checkpoints, and evaluating the policy.
        - Pushing the final trained model to the Hugging Face Hub if configured.
        - Supporting both WandB and SwanLab for experiment tracking.
    
        Args:
            cfg: A `TrainPipelineSwanLabConfig` object containing all training configurations.
        """
        cfg.validate()
        logging.info(pformat(cfg.to_dict()))
    
        # Initialize loggers based on tracker selection
        wandb_logger = None
        swanlab_logger = None
    
        if cfg.tracker in ["wandb", "both"] and cfg.wandb.project:
            wandb_logger = WandBLogger(cfg)
    
        if cfg.tracker in ["swanlab", "both"] and cfg.swanlab.project:
            swanlab_logger = SwanLabLogger(cfg)
    
        if cfg.tracker == "none" or (not wandb_logger and not swanlab_logger):
            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
    
        if cfg.seed is not None:
            set_seed(cfg.seed)
    
        # Check device is available
        device = get_safe_torch_device(cfg.policy.device, log=True)
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
    
        logging.info("Creating dataset")  # 正在创建数据集
        dataset = make_dataset(cfg)
    
        # Create environment used for evaluating checkpoints during training on simulation data.
        # 用于在仿真训练中评估检查点
        # On real-world data, no need to create an environment as evaluations are done outside train.py,
        # 真实数据上评估在 train.py 外进行
        # using the eval.py instead, with gym_dora environment and dora-rs.
        # 使用 eval.py 与 gym_dora/dora-rs
        eval_env = None
        if cfg.eval_freq > 0 and cfg.env is not None:
            logging.info("Creating env")  # 正在创建环境
            eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
    
        logging.info("Creating policy")  # 正在创建策略
        policy = make_policy(
            cfg=cfg.policy,
            ds_meta=dataset.meta,
        )
    
        # Create processors - only provide dataset_stats if not resuming from saved processors
        processor_kwargs = {}
        postprocessor_kwargs = {}
        if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
            # Only provide dataset_stats when not resuming from saved processor state
            processor_kwargs["dataset_stats"] = dataset.meta.stats
    
        if cfg.policy.pretrained_path is not None:
            processor_kwargs["preprocessor_overrides"] = {
                "device_processor": {"device": device.type},
                "normalizer_processor": {
                    "stats": dataset.meta.stats,
                    "features": {**policy.config.input_features, **policy.config.output_features},
                    "norm_map": policy.config.normalization_mapping,
                },
            }
            postprocessor_kwargs["postprocessor_overrides"] = {
                "unnormalizer_processor": {
                    "stats": dataset.meta.stats,
                    "features": policy.config.output_features,
                    "norm_map": policy.config.normalization_mapping,
                },
            }
    
        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            **processor_kwargs,
            **postprocessor_kwargs,
        )
    
        logging.info("Creating optimizer and scheduler")  # 正在创建优化器与调度器
        optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
        grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
    
        step = 0  # number of policy updates (forward + backward + optim)
    
        if cfg.resume:
            step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
    
        num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
        num_total_params = sum(p.numel() for p in policy.parameters())
    
        logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")  # 输出目录
        if cfg.env is not None:
            logging.info(f"{cfg.env.task=}")
        logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
        logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
        logging.info(f"{dataset.num_episodes=}")
        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
    
        # create dataloader for offline training
        # 为离线训练创建数据加载器
        if hasattr(cfg.policy, "drop_n_last_frames"):
            shuffle = False
            sampler = EpisodeAwareSampler(
                dataset.meta.episodes["dataset_from_index"],
                dataset.meta.episodes["dataset_to_index"],
                drop_n_last_frames=cfg.policy.drop_n_last_frames,
                shuffle=True,
            )
        else:
            shuffle = True
            sampler = None
    
        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=cfg.num_workers,
            batch_size=cfg.batch_size,
            shuffle=shuffle and not cfg.dataset.streaming,
            sampler=sampler,
            pin_memory=device.type == "cuda",
            drop_last=False,
            prefetch_factor=2 if cfg.num_workers > 0 else None,  # prefetch requires num_workers > 0
        )
        dl_iter = cycle(dataloader)
    
        policy.train()
    
        train_metrics = {
            "loss": AverageMeter("loss", ":.3f"),
            "grad_norm": AverageMeter("grdn", ":.3f"),
            "lr": AverageMeter("lr", ":0.1e"),
            "update_s": AverageMeter("updt_s", ":.3f"),
            "dataloading_s": AverageMeter("data_s", ":.3f"),
        }
    
        train_tracker = MetricsTracker(
            cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
        )
    
        logging.info("Start offline training on a fixed dataset")  # 开始在固定数据集上进行离线训练
        for _ in range(step, cfg.steps):
            start_time = time.perf_counter()
            batch = next(dl_iter)
            batch = preprocessor(batch)
            train_tracker.dataloading_s = time.perf_counter() - start_time
    
            train_tracker, output_dict = update_policy(
                train_tracker,
                policy,
                batch,
                optimizer,
                cfg.optimizer.grad_clip_norm,
                grad_scaler=grad_scaler,
                lr_scheduler=lr_scheduler,
                use_amp=cfg.policy.use_amp,
            )
    
            # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
            # 评估与保存发生在完成该步更新之后
            # increment `step` here.
            # 因此此处递增 step
            step += 1
            train_tracker.step()
            is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
            is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
            is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
    
            if is_log_step:
                logging.info(train_tracker)
                if wandb_logger:
                    wandb_log_dict = train_tracker.to_dict()
                    if output_dict:
                        wandb_log_dict.update(output_dict)
                    wandb_logger.log_dict(wandb_log_dict, step)
                if swanlab_logger:
                    swanlab_log_dict = train_tracker.to_dict()
                    if output_dict:
                        swanlab_log_dict.update(output_dict)
                    swanlab_logger.log_dict(swanlab_log_dict, step)
                train_tracker.reset_averages()
    
            if cfg.save_checkpoint and is_saving_step:
                logging.info(f"Checkpoint policy after step {step}")  # 保存检查点
                checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
                save_checkpoint(
                    checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
                )
                update_last_checkpoint(checkpoint_dir)
                if wandb_logger:
                    wandb_logger.log_policy(checkpoint_dir)
                if swanlab_logger:
                    swanlab_logger.log_policy(checkpoint_dir)
    
            if cfg.env and is_eval_step:
                step_id = get_step_identifier(step, cfg.steps)
                logging.info(f"Eval policy at step {step}")  # 评估策略
                with (
                    torch.no_grad(),
                    torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
                ):
                    eval_info = eval_policy_all(
                        envs=eval_env,  # dict[suite][task_id] -> vec_env
                        policy=policy,
                        preprocessor=preprocessor,
                        postprocessor=postprocessor,
                        n_episodes=cfg.eval.n_episodes,
                        videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                        max_episodes_rendered=4,
                        start_seed=cfg.seed,
                        max_parallel_tasks=cfg.env.max_parallel_tasks,
                    )
                # overall metrics (suite-agnostic)
                aggregated = eval_info["overall"]
    
                # optional: per-suite logging
                for suite, suite_info in eval_info.items():
                    logging.info("Suite %s aggregated: %s", suite, suite_info)
    
                # meters/tracker
                eval_metrics = {
                    "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
                    "pc_success": AverageMeter("success", ":.1f"),
                    "eval_s": AverageMeter("eval_s", ":.3f"),
                }
                eval_tracker = MetricsTracker(
                    cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
                )
                eval_tracker.eval_s = aggregated.pop("eval_s")
                eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
                eval_tracker.pc_success = aggregated.pop("pc_success")
                if wandb_logger:
                    wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
                    wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                    wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
    
                if swanlab_logger:
                    swanlab_log_dict = {**eval_tracker.to_dict(), **eval_info}
                    swanlab_logger.log_dict(swanlab_log_dict, step, mode="eval")
                    swanlab_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
        if eval_env:
            close_envs(eval_env)
        logging.info("End of training")  # 训练结束
    
        if cfg.policy.push_to_hub:
            policy.push_model_to_hub(cfg)
            preprocessor.push_to_hub(cfg.policy.repo_id)
            postprocessor.push_to_hub(cfg.policy.repo_id)
    
    
    def main():
        init_logging()
        train()
    
    
    if __name__ == "__main__":
        main()
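
These scripts assume SwanLab is installed and you are logged in (steps 1-2 below). Before launching a long run, you can confirm the setup with a quick smoke test (a sketch; the project name is just an example):

    import swanlab

    # Open a run and log a few dummy scalars, mirroring how SwanLabLogger.log_dict
    # calls swanlab.log during training.
    swanlab.init(project="swanlab-smoke-test", mode="cloud")
    for step in range(10):
        swanlab.log({"train/loss": 1.0 / (step + 1)}, step=step)
    swanlab.finish()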
    
  1. Install SwanLab:
    pip install swanlab
    
  2. Log in on the SwanLab website to obtain an API key, then run swanlab login to authenticate.
  3. Create a folder for training logs:
    mkdir -p ./logs
    
  4. Start training (100,000 steps; a checkpoint sanity check follows this list):
    stdbuf -oL -eL nohup python -m lerobot.extra.lerobot_train_swanlab \
    --dataset.repo_id=xiadengma/so101-red-pepper \
    --dataset.root=./data/datasets/xiadengma/so101-red-pepper \
    --policy.type=act \
    --output_dir=./data/train/act_so101_red_pepper \
    --job_name=act_so101_red_pepper_$(date +%Y%m%d_%H%M%S) \
    --policy.device=cuda \
    --wandb.enable=false \
    --policy.push_to_hub=false \
    --steps=100000 \
    --tracker=swanlab \
    --swanlab.project=so101-red-pepper \
    --swanlab.mode=cloud \
    > ./logs/train_$(date +"%Y-%m-%d-%H-%M-%S").log 2>&1 & echo $! > ./logs/train.pid
    
  5. Resume training (use the same SwanLab entry point so the extra tracker/swanlab fields in the saved train_config.json are parsed):
    stdbuf -oL -eL nohup python -m lerobot.extra.lerobot_train_swanlab \
    --config_path=./data/train/act_so101_red_pepper/checkpoints/last/pretrained_model/train_config.json \
    --resume=true \
    > ./logs/resume_$(date +"%Y-%m-%d-%H-%M-%S").log 2>&1 & echo $! > ./logs/train.pid
    
  6. Tail the latest log:
    tail -f $(ls -t ./logs/*.log | head -n 1)
    
  7. Stop training:
    kill -TERM $(cat ./logs/train.pid) || kill -KILL $(cat ./logs/train.pid)
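
Once training has produced a checkpoint, you can sanity-check the saved weights before running inference. A minimal sketch using safetensors (the path matches the training command above):

    from safetensors.torch import load_file

    path = "./data/train/act_so101_red_pepper/checkpoints/last/pretrained_model/model.safetensors"
    state = load_file(path)  # dict[str, torch.Tensor]
    n_params = sum(t.numel() for t in state.values())
    print(f"{len(state)} tensors, {n_params / 1e6:.1f}M parameters")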
    

Run Inference and Evaluate
#

  1. Complete one task in a single episode (the run is saved as an eval dataset; a review command follows this list):
    python -m lerobot.extra.lerobot_record_web \
    --robot.type=so101_follower \
    --robot.port=/dev/follower_arm \
    --robot.cameras="{ wrist_left: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front_rgb: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \
    --robot.id=my_follower_arm \
    --robot.calibration_dir=./data/calibration \
    --display_data=false \
    --dataset.repo_id=xiadengma/eval_so101-red-pepper \
    --dataset.single_task="Put the red pepper toy in the cardboard box" \
    --policy.path=./data/train/act_so101_red_pepper/checkpoints/last/pretrained_model \
    --policy.device=cuda \
    --dataset.root=./data/datasets/xiadengma/eval_so101-red-pepper \
    --web_port=9090 \
    --dataset.episode_time_s=30 \
    --dataset.episode=0
    
  2. Run a longer task in a single episode:
    python -m lerobot.extra.lerobot_record_web \
    --robot.type=so101_follower \
    --robot.port=/dev/follower_arm \
    --robot.cameras="{ wrist_left: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30}, front_rgb: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \
    --robot.id=my_follower_arm \
    --robot.calibration_dir=./data/calibration \
    --display_data=false \
    --dataset.repo_id=xiadengma/eval_so101-red-pepper \
    --dataset.single_task="Put the red pepper toy in the cardboard box" \
    --policy.path=./data/train/act_so101_red_pepper/checkpoints/last/pretrained_model \
    --policy.device=cuda \
    --dataset.root=./data/datasets/xiadengma/eval_so101-red-pepper \
    --web_port=9090 \
    --dataset.episode_time_s=240
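
The evaluation runs are saved as a regular dataset, so you can review them with the web visualizer from earlier (paths match the eval commands above):

    python -m lerobot.extra.lerobot_dataset_viz_web \
    --repo-id xiadengma/eval_so101-red-pepper \
    --root ./data/datasets/xiadengma/eval_so101-red-pepper \
    --episode-index 0 \
    --port 9090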
    
