362 lines
13 KiB
Go
362 lines
13 KiB
Go
package device
|
||
|
||
import (
|
||
"context"
|
||
"crypto/ecdsa"
|
||
"crypto/md5"
|
||
"crypto/rand"
|
||
"crypto/sha256"
|
||
"crypto/x509"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"encoding/pem"
|
||
"fmt"
|
||
"io/fs"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
|
||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
|
||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
|
||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/repository"
|
||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/transport/proto"
|
||
"git.huangwc.com/pig/pig-farm-controller/internal/infra/utils/file"
|
||
"github.com/gibson042/canonicaljson-go"
|
||
)
|
||
|
||
// Manifest 代表 OTA 升级的清单文件 (manifest.json) 的结构。
|
||
// 它包含了固件的元数据和所有待更新文件的详细信息。
|
||
type Manifest struct {
|
||
Version string `json:"version"` // 新固件的版本号
|
||
Files []ManifestFile `json:"files"` // 待更新的文件列表
|
||
|
||
// Signature 是对 Manifest 内容(不含 Signature 字段本身)的数字签名。
|
||
//
|
||
// **签名生成流程 (平台侧)**:
|
||
// 1. 将此结构体的 Signature 字段设置为空字符串 ""。
|
||
// 2. 使用确定性 JSON 库 (如 canonicaljson) 将结构体序列化。
|
||
// 3. 对序列化后的字节流计算 SHA-256 哈希。
|
||
// 4. 使用平台私钥对哈希进行签名。
|
||
// 5. 将签名结果进行 Base64 编码后,填充到此字段。
|
||
//
|
||
// **签名校验流程 (设备侧)**:
|
||
// 1. 从接收到的 manifest.json 中解析出 Manifest 结构。
|
||
// 2. 暂存 Signature 字段的值。
|
||
// 3. **关键:将结构体中的 Signature 字段置为空字符串 "" (而不是移除该字段)。**
|
||
// 4. 使用确定性 JSON 规则(如 Python 的 json.dumps(sort_keys=True))将修改后的结构体序列化。
|
||
// 5. 对序列化后的字节流计算 SHA-256 哈希。
|
||
// 6. 使用平台公钥、暂存的签名和计算出的哈希进行验签。
|
||
Signature string `json:"signature"`
|
||
}
|
||
|
||
// ManifestFile 定义了清单文件中单个文件的元数据。
|
||
type ManifestFile struct {
|
||
Path string `json:"path"` // 文件在设备上的目标绝对路径
|
||
MD5 string `json:"md5"` // 文件的 MD5 校验和
|
||
Size int64 `json:"size"` // 文件的大小(字节)
|
||
}
|
||
|
||
// OtaConfig 封装了 OTA 服务所需的可配置参数。
|
||
type OtaConfig struct {
|
||
DefaultRetryCount uint32 // 默认的设备端文件下载重试次数
|
||
DefaultRequestTimeoutS uint32 // 默认的设备端文件下载请求超时时间(秒)
|
||
}
|
||
|
||
// otaServiceImpl 是 OtaService 接口的实现。
|
||
type otaServiceImpl struct {
|
||
ctx context.Context
|
||
config OtaConfig
|
||
otaRepo repository.OtaRepository
|
||
generalDeviceService *GeneralDeviceService
|
||
}
|
||
|
||
// NewOtaService 创建一个新的 OtaService 实例。
|
||
func NewOtaService(
|
||
ctx context.Context,
|
||
config OtaConfig,
|
||
otaRepo repository.OtaRepository,
|
||
generalDeviceService *GeneralDeviceService,
|
||
) OtaService {
|
||
return &otaServiceImpl{
|
||
ctx: ctx,
|
||
config: config,
|
||
otaRepo: otaRepo,
|
||
generalDeviceService: generalDeviceService,
|
||
}
|
||
}
|
||
|
||
// upgradeTask 封装了单次升级任务的所有上下文和操作,以提高代码的可读性和模块化。
|
||
type upgradeTask struct {
|
||
service *otaServiceImpl
|
||
ctx context.Context
|
||
logger *logs.Logger
|
||
task *models.OTATask
|
||
firmwarePath string
|
||
tempSubDir string
|
||
}
|
||
|
||
// run 执行核心的升级准备流程。
|
||
// 此方法内的所有操作都处于一个文件锁的保护下。
|
||
func (t *upgradeTask) run() error {
|
||
// 步骤 1: 解压固件
|
||
tempDestPath, err := file.CreateTempDir(t.tempSubDir)
|
||
if err != nil {
|
||
return fmt.Errorf("创建临时目录失败: %w", err)
|
||
}
|
||
if err := file.Decompress(t.firmwarePath, tempDestPath); err != nil {
|
||
return fmt.Errorf("解压固件失败: %w", err)
|
||
}
|
||
t.logger.Infof("为任务 %d 成功解压固件到 %s", t.task.ID, tempDestPath)
|
||
|
||
// 步骤 2: 生成、签名并写入 manifest 文件
|
||
manifest, err := t.service.generateManifest(t.tempSubDir)
|
||
if err != nil {
|
||
return fmt.Errorf("生成 manifest 失败: %w", err)
|
||
}
|
||
if err := t.service.signManifest(manifest); err != nil {
|
||
return fmt.Errorf("签名 manifest 失败: %w", err)
|
||
}
|
||
manifestBytes, err := json.Marshal(manifest)
|
||
if err != nil {
|
||
return fmt.Errorf("序列化 manifest.json 失败: %w", err)
|
||
}
|
||
if _, err := file.WriteTempFile(t.tempSubDir, "manifest.json", manifestBytes); err != nil {
|
||
return fmt.Errorf("写入 manifest.json 失败: %w", err)
|
||
}
|
||
t.logger.Infof("为任务 %d 成功生成并签名 manifest.json", t.task.ID)
|
||
|
||
// 步骤 3: 发送升级指令
|
||
manifestMD5 := fmt.Sprintf("%x", md5.Sum(manifestBytes))
|
||
prepareReq := &proto.PrepareUpdateReq{
|
||
Version: manifest.Version,
|
||
TaskId: t.task.ID,
|
||
ManifestMd5: manifestMD5,
|
||
RetryCount: t.service.config.DefaultRetryCount,
|
||
RequestTimeoutSeconds: t.service.config.DefaultRequestTimeoutS,
|
||
}
|
||
instructionPayload := &proto.Instruction_PrepareUpdateReq{PrepareUpdateReq: prepareReq}
|
||
if err := t.service.generalDeviceService.Send(t.ctx, t.task.AreaControllerID, instructionPayload, WithoutTracking()); err != nil {
|
||
return fmt.Errorf("发送升级指令失败: %w", err)
|
||
}
|
||
t.logger.Infof("为任务 %d 成功发送升级指令", t.task.ID)
|
||
|
||
// 步骤 4: 更新任务状态为“进行中”
|
||
t.task.Status = models.OTATaskStatusInProgress
|
||
t.task.TargetVersion = manifest.Version // 回填目标版本号
|
||
if err := t.service.otaRepo.Update(t.ctx, t.task); err != nil {
|
||
return fmt.Errorf("更新任务状态为 '进行中' 失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// rollback 在 run 方法失败时执行清理和状态更新操作。
|
||
func (t *upgradeTask) rollback(originalErr error) {
|
||
t.logger.Errorf("任务 %d 文件准备阶段失败,执行回滚: %v", t.task.ID, originalErr)
|
||
|
||
// 更新数据库状态为“准备文件失败”
|
||
t.task.Status = models.OTATaskStatusFailedPreparation
|
||
t.task.ErrorMessage = fmt.Sprintf("文件准备阶段失败: %v", originalErr)
|
||
now := time.Now()
|
||
t.task.CompletedAt = &now
|
||
if updateErr := t.service.otaRepo.Update(t.ctx, t.task); updateErr != nil {
|
||
t.logger.DPanicf("CRITICAL: 任务 %d 回滚失败后,更新其状态也失败了: %v", t.task.ID, updateErr)
|
||
}
|
||
|
||
// 清理临时解压目录
|
||
if removeDirErr := file.RemoveTempDir(t.tempSubDir); removeDirErr != nil {
|
||
t.logger.Warnf("回滚操作:清理任务 %d 的临时目录 %s 失败: %v", t.task.ID, t.tempSubDir, removeDirErr)
|
||
}
|
||
// 清理原始固件压缩包
|
||
if removeSrcErr := os.Remove(t.firmwarePath); removeSrcErr != nil {
|
||
t.logger.Warnf("回滚操作:清理任务 %d 的源固件 %s 失败: %v", t.task.ID, t.firmwarePath, removeSrcErr)
|
||
}
|
||
}
|
||
|
||
func (o *otaServiceImpl) StartUpgrade(ctx context.Context, areaControllerID uint32, firmwarePath string) (uint32, error) {
|
||
serviceCtx, logger := logs.Trace(ctx, o.ctx, "StartUpgrade")
|
||
|
||
// 步骤 1: 预创建数据库记录
|
||
task := &models.OTATask{
|
||
AreaControllerID: areaControllerID,
|
||
Status: models.OTATaskStatusPending,
|
||
CreatedAt: time.Now(),
|
||
}
|
||
if err := o.otaRepo.Create(serviceCtx, task); err != nil {
|
||
logger.Errorf("预创建 OTA 任务记录失败: %v", err)
|
||
return 0, fmt.Errorf("预创建 OTA 任务记录失败: %w", err)
|
||
}
|
||
logger.Infof("成功预创建 OTA 任务记录, ID: %d", task.ID)
|
||
|
||
// 步骤 2: 初始化升级任务执行器
|
||
upgrade := &upgradeTask{
|
||
service: o,
|
||
ctx: serviceCtx,
|
||
logger: logger,
|
||
task: task,
|
||
firmwarePath: firmwarePath,
|
||
tempSubDir: filepath.Join(models.OTADir, fmt.Sprintf("%d", task.ID)),
|
||
}
|
||
|
||
// 步骤 3: 在文件锁的保护下,原子化地执行升级准备流程
|
||
if err := file.ExecuteWithLock(upgrade.run, upgrade.rollback); err != nil {
|
||
// 此处的错误已在 rollback 中处理和记录,这里只向调用方返回失败信号
|
||
logger.Errorf("OTA 任务 %d 未能成功启动: %v", task.ID, err)
|
||
return 0, err
|
||
}
|
||
|
||
logger.Infof("OTA 升级任务 %d 已成功启动", task.ID)
|
||
return task.ID, nil
|
||
}
|
||
|
||
func (o *otaServiceImpl) StopUpgrade(ctx context.Context, taskID uint32) error {
|
||
serviceCtx, logger := logs.Trace(ctx, o.ctx, "StopUpgrade")
|
||
|
||
task, err := o.otaRepo.FindByID(serviceCtx, taskID)
|
||
if err != nil {
|
||
logger.Errorf("查找 OTA 任务失败: %v, 任务ID: %d", err, taskID)
|
||
return fmt.Errorf("查找 OTA 任务失败: %w", err)
|
||
}
|
||
|
||
// 幂等性检查:如果任务已处于终态,则直接返回成功
|
||
if task.IsOver() {
|
||
logger.Infof("OTA 任务 %d 已处于终态 %s,无需停止", taskID, task.Status)
|
||
return nil
|
||
}
|
||
|
||
now := time.Now()
|
||
task.Status = models.OTATaskStatusStopped
|
||
task.CompletedAt = &now
|
||
task.ErrorMessage = "任务被用户手动停止"
|
||
|
||
if err := o.otaRepo.Update(serviceCtx, task); err != nil {
|
||
logger.Errorf("更新 OTA 任务状态失败: %v, 任务ID: %d", err, taskID)
|
||
return fmt.Errorf("更新 OTA 任务状态失败: %w", err)
|
||
}
|
||
|
||
// 清理相关文件目录
|
||
dirToRemove := filepath.Join(models.OTADir, fmt.Sprintf("%d", taskID))
|
||
if err := file.RemoveTempDir(dirToRemove); err != nil {
|
||
// 文件清理失败不应阻塞主流程,但需要记录日志
|
||
logger.Warnf("清理 OTA 任务 %d 的文件目录 %s 失败: %v", taskID, dirToRemove, err)
|
||
}
|
||
|
||
logger.Infof("OTA 任务 %d 已被成功标记为手动停止", taskID)
|
||
return nil
|
||
}
|
||
|
||
// generateManifest 遍历指定的固件包子目录,生成一个完整的 Manifest 对象。
|
||
func (o *otaServiceImpl) generateManifest(packageSubDir string) (*Manifest, error) {
|
||
// 1. 读取版本文件
|
||
versionBytes, err := file.ReadTempFile(packageSubDir, "version")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取 version 文件失败: %w", err)
|
||
}
|
||
version := strings.TrimSpace(string(versionBytes))
|
||
|
||
var files []ManifestFile
|
||
|
||
// 2. 使用 WalkTempDir 遍历
|
||
err = file.WalkTempDir(packageSubDir, func(path string, d fs.DirEntry, err error) error {
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if d.Name() == "version" || d.Name() == "manifest.json" {
|
||
return nil
|
||
}
|
||
|
||
// 3. 获取逻辑相对路径
|
||
relPath, err := file.GetRelativePathInTemp(path, packageSubDir)
|
||
if err != nil {
|
||
return fmt.Errorf("无法计算相对路径 '%s': %w", path, err)
|
||
}
|
||
|
||
// 业务转换: 转换为设备端路径
|
||
devicePath := filepath.ToSlash(relPath)
|
||
|
||
// 跳过目录和忽略config目录下的所有文件
|
||
if d.IsDir() {
|
||
if devicePath == "config" {
|
||
return fs.SkipDir
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 4. 读取文件内容用于计算 (直接使用绝对路径,最高效)
|
||
data, err := os.ReadFile(path)
|
||
if err != nil {
|
||
return fmt.Errorf("读取文件失败 '%s': %w", path, err)
|
||
}
|
||
|
||
// 计算元数据
|
||
md5Sum := fmt.Sprintf("%x", md5.Sum(data))
|
||
|
||
// 5. 添加到列表
|
||
files = append(files, ManifestFile{
|
||
Path: devicePath,
|
||
MD5: md5Sum,
|
||
Size: int64(len(data)),
|
||
})
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("生成清单时遍历目录失败: %w", err)
|
||
}
|
||
|
||
// 6. 创建 Manifest 对象
|
||
manifest := &Manifest{
|
||
Version: version,
|
||
Files: files,
|
||
}
|
||
return manifest, nil
|
||
}
|
||
|
||
// --- 数字签名常量 ---
|
||
// TODO:在生产环境中,强烈建议使用更安全的方式(如环境变量或密钥管理服务)来管理私钥。
|
||
const pemEncodedPrivateKey = `-----BEGIN EC PRIVATE KEY-----
|
||
MHcCAQEEIFDRC/3W22Fw1M/v36w8kO/n8a9A8sUnY2zD1bCgR6eBoAoGCCqGSM49
|
||
AwEHoUQDQgAEWbV3aG6g6Fv5a3p4Y5N5a2b3aG6g6Fv5a3p4Y5N5a2b3aG6g6Fv5
|
||
a3p4Y5N5a2b3aG6g6Fv5a3p4Y5N5a2Y=
|
||
-----END EC PRIVATE KEY-----`
|
||
|
||
// signManifest 使用硬编码的 ECDSA 私钥对 manifest 进行签名。
|
||
// 它遵循确定性 JSON 规范,以确保平台和设备之间可以生成完全一致的待签名数据。
|
||
func (o *otaServiceImpl) signManifest(manifest *Manifest) error {
|
||
// 1. 加载私钥
|
||
block, _ := pem.Decode([]byte(pemEncodedPrivateKey))
|
||
if block == nil {
|
||
return fmt.Errorf("无法解码 PEM 格式的私钥")
|
||
}
|
||
privateKey, err := x509.ParseECPrivateKey(block.Bytes)
|
||
if err != nil {
|
||
return fmt.Errorf("无法解析 ECDSA 私钥: %w", err)
|
||
}
|
||
|
||
// 2. 关键:将 Signature 字段置为空字符串,以准备用于签名的“纯净”数据。
|
||
manifest.Signature = ""
|
||
|
||
// 3. 使用 canonicaljson 库将“纯净”数据序列化为确定性的字节流。
|
||
// 这确保了无论执行多少次,只要内容不变,生成的字节流就完全一样。
|
||
signableData, err := canonicaljson.Marshal(manifest)
|
||
if err != nil {
|
||
return fmt.Errorf("无法将 manifest 序列化为确定性 JSON: %w", err)
|
||
}
|
||
|
||
// 4. 对确定性的字节流进行哈希
|
||
hash := sha256.Sum256(signableData)
|
||
|
||
// 5. 使用私钥对哈希进行签名
|
||
signatureBytes, err := ecdsa.SignASN1(rand.Reader, privateKey, hash[:])
|
||
if err != nil {
|
||
return fmt.Errorf("无法对哈希进行签名: %w", err)
|
||
}
|
||
|
||
// 6. 将签名结果进行 Base64 编码后,填充回 manifest 对象。
|
||
// 此时 manifest 对象已包含所有信息,可以被序列化并写入最终的 manifest.json 文件。
|
||
manifest.Signature = base64.StdEncoding.EncodeToString(signatureBytes)
|
||
|
||
return nil
|
||
}
|