diff --git a/internal/core/component_initializers.go b/internal/core/component_initializers.go index 78008c8..ae00392 100644 --- a/internal/core/component_initializers.go +++ b/internal/core/component_initializers.go @@ -147,6 +147,7 @@ type DomainServices struct { pigBatchDomain pig.PigBatchService deviceOperator device.DeviceOperator deviceCommunicator device.DeviceCommunicator + otaService device.OtaService taskFactory plan.TaskFactory planExecutionManager plan.ExecutionManager analysisPlanTaskManager plan.AnalysisPlanTaskManager @@ -191,6 +192,15 @@ func initDomainServices(ctx context.Context, cfg *config.Config, infra *Infrastr infra.repos.pendingCollectionRepo, infra.lora.comm, ) + otaService := device.NewOtaService( + logs.AddCompName(baseCtx, "OtaService"), + device.OtaConfig{ + DefaultRetryCount: uint32(cfg.OTA.DefaultRetryCount), + DefaultRequestTimeoutS: uint32(cfg.OTA.DefaultRequestTimeoutSeconds), + }, + infra.repos.otaRepo, + generalDeviceService, + ) // 告警服务 alarmService := alarm.NewAlarmService( @@ -268,6 +278,7 @@ func initDomainServices(ctx context.Context, cfg *config.Config, infra *Infrastr pigBatchDomain: pigBatchDomain, deviceOperator: generalDeviceService, deviceCommunicator: generalDeviceService, + otaService: otaService, analysisPlanTaskManager: analysisPlanTaskManager, taskFactory: taskFactory, planExecutionManager: planExecutionManager, @@ -344,6 +355,7 @@ func initAppServices(ctx context.Context, infra *Infrastructure, domainServices logs.AddCompName(baseCtx, "AreaControllerService"), infra.repos.areaControllerRepo, thresholdAlarmService, + domainServices.otaService, ) auditService := service.NewAuditService(logs.AddCompName(baseCtx, "AuditService"), infra.repos.userActionLogRepo) diff --git a/internal/domain/device/ota_service.go b/internal/domain/device/ota_service.go index 10b3281..18ec1a2 100644 --- a/internal/domain/device/ota_service.go +++ b/internal/domain/device/ota_service.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/base64" + "encoding/json" "encoding/pem" "fmt" "io/fs" @@ -19,8 +20,8 @@ import ( "git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" "git.huangwc.com/pig/pig-farm-controller/internal/infra/models" "git.huangwc.com/pig/pig-farm-controller/internal/infra/repository" + "git.huangwc.com/pig/pig-farm-controller/internal/infra/transport/proto" "git.huangwc.com/pig/pig-farm-controller/internal/infra/utils/file" - "github.com/gibson042/canonicaljson-go" ) @@ -56,25 +57,157 @@ type ManifestFile struct { Size int64 `json:"size"` // 文件的大小(字节) } +// OtaConfig 封装了 OTA 服务所需的可配置参数。 +type OtaConfig struct { + DefaultRetryCount uint32 // 默认的设备端文件下载重试次数 + DefaultRequestTimeoutS uint32 // 默认的设备端文件下载请求超时时间(秒) +} + // otaServiceImpl 是 OtaService 接口的实现。 type otaServiceImpl struct { - ctx context.Context - otaRepo repository.OtaRepository - deviceRepo repository.DeviceRepository + ctx context.Context + config OtaConfig + otaRepo repository.OtaRepository + generalDeviceService *GeneralDeviceService } // NewOtaService 创建一个新的 OtaService 实例。 -func NewOtaService(ctx context.Context, otaRepo repository.OtaRepository, deviceRepo repository.DeviceRepository) OtaService { +func NewOtaService( + ctx context.Context, + config OtaConfig, + otaRepo repository.OtaRepository, + generalDeviceService *GeneralDeviceService, +) OtaService { return &otaServiceImpl{ - ctx: ctx, - otaRepo: otaRepo, - deviceRepo: deviceRepo, + ctx: ctx, + config: config, + otaRepo: otaRepo, + generalDeviceService: generalDeviceService, + } +} + +// upgradeTask 封装了单次升级任务的所有上下文和操作,以提高代码的可读性和模块化。 +type upgradeTask struct { + service *otaServiceImpl + ctx context.Context + logger *logs.Logger + task *models.OTATask + firmwarePath string + tempSubDir string +} + +// run 执行核心的升级准备流程。 +// 此方法内的所有操作都处于一个文件锁的保护下。 +func (t *upgradeTask) run() error { + // 步骤 1: 解压固件 + tempDestPath, err := file.CreateTempDir(t.tempSubDir) + if err != nil { + return fmt.Errorf("创建临时目录失败: %w", err) + } + if err := file.Decompress(t.firmwarePath, tempDestPath); err != nil { + return fmt.Errorf("解压固件失败: %w", err) + } + t.logger.Infof("为任务 %d 成功解压固件到 %s", t.task.ID, tempDestPath) + + // 步骤 2: 生成、签名并写入 manifest 文件 + manifest, err := t.service.generateManifest(t.tempSubDir) + if err != nil { + return fmt.Errorf("生成 manifest 失败: %w", err) + } + if err := t.service.signManifest(manifest); err != nil { + return fmt.Errorf("签名 manifest 失败: %w", err) + } + manifestBytes, err := json.Marshal(manifest) + if err != nil { + return fmt.Errorf("序列化 manifest.json 失败: %w", err) + } + if _, err := file.WriteTempFile(t.tempSubDir, "manifest.json", manifestBytes); err != nil { + return fmt.Errorf("写入 manifest.json 失败: %w", err) + } + t.logger.Infof("为任务 %d 成功生成并签名 manifest.json", t.task.ID) + + // 步骤 3: 发送升级指令 + manifestMD5 := fmt.Sprintf("%x", md5.Sum(manifestBytes)) + prepareReq := &proto.PrepareUpdateReq{ + Version: manifest.Version, + TaskId: t.task.ID, + ManifestMd5: manifestMD5, + RetryCount: t.service.config.DefaultRetryCount, + RequestTimeoutSeconds: t.service.config.DefaultRequestTimeoutS, + } + instructionPayload := &proto.Instruction_PrepareUpdateReq{PrepareUpdateReq: prepareReq} + if err := t.service.generalDeviceService.Send(t.ctx, t.task.AreaControllerID, instructionPayload); err != nil { + return fmt.Errorf("发送升级指令失败: %w", err) + } + t.logger.Infof("为任务 %d 成功发送升级指令", t.task.ID) + + // 步骤 4: 更新任务状态为“进行中” + t.task.Status = models.OTATaskStatusInProgress + t.task.TargetVersion = manifest.Version // 回填目标版本号 + if err := t.service.otaRepo.Update(t.ctx, t.task); err != nil { + return fmt.Errorf("更新任务状态为 '进行中' 失败: %w", err) + } + + return nil +} + +// rollback 在 run 方法失败时执行清理和状态更新操作。 +func (t *upgradeTask) rollback(originalErr error) { + t.logger.Errorf("任务 %d 文件准备阶段失败,执行回滚: %v", t.task.ID, originalErr) + + // 更新数据库状态为“准备文件失败” + t.task.Status = models.OTATaskStatusFailedPreparation + t.task.ErrorMessage = fmt.Sprintf("文件准备阶段失败: %v", originalErr) + now := time.Now() + t.task.CompletedAt = &now + if updateErr := t.service.otaRepo.Update(t.ctx, t.task); updateErr != nil { + t.logger.DPanicf("CRITICAL: 任务 %d 回滚失败后,更新其状态也失败了: %v", t.task.ID, updateErr) + } + + // 清理临时解压目录 + if removeDirErr := file.RemoveTempDir(t.tempSubDir); removeDirErr != nil { + t.logger.Warnf("回滚操作:清理任务 %d 的临时目录 %s 失败: %v", t.task.ID, t.tempSubDir, removeDirErr) + } + // 清理原始固件压缩包 + if removeSrcErr := os.Remove(t.firmwarePath); removeSrcErr != nil { + t.logger.Warnf("回滚操作:清理任务 %d 的源固件 %s 失败: %v", t.task.ID, t.firmwarePath, removeSrcErr) } } func (o *otaServiceImpl) StartUpgrade(ctx context.Context, areaControllerID uint32, firmwarePath string) (uint32, error) { - //TODO implement me - panic("implement me") + serviceCtx, logger := logs.Trace(ctx, o.ctx, "StartUpgrade") + + // 步骤 1: 预创建数据库记录 + task := &models.OTATask{ + AreaControllerID: areaControllerID, + Status: models.OTATaskStatusPending, + CreatedAt: time.Now(), + } + if err := o.otaRepo.Create(serviceCtx, task); err != nil { + logger.Errorf("预创建 OTA 任务记录失败: %v", err) + return 0, fmt.Errorf("预创建 OTA 任务记录失败: %w", err) + } + logger.Infof("成功预创建 OTA 任务记录, ID: %d", task.ID) + + // 步骤 2: 初始化升级任务执行器 + upgrade := &upgradeTask{ + service: o, + ctx: serviceCtx, + logger: logger, + task: task, + firmwarePath: firmwarePath, + tempSubDir: filepath.Join(models.OTADir, fmt.Sprintf("%d", task.ID)), + } + + // 步骤 3: 在文件锁的保护下,原子化地执行升级准备流程 + if err := file.ExecuteWithLock(upgrade.run, upgrade.rollback); err != nil { + // 此处的错误已在 rollback 中处理和记录,这里只向调用方返回失败信号 + logger.Errorf("OTA 任务 %d 未能成功启动: %v", task.ID, err) + return 0, err + } + + logger.Infof("OTA 升级任务 %d 已成功启动", task.ID) + return task.ID, nil } func (o *otaServiceImpl) GetUpgradeProgress(ctx context.Context, taskID uint32) (executed, total uint32, CurrentStage models.OTATaskStatus, err error) { @@ -134,7 +267,7 @@ func (o *otaServiceImpl) generateManifest(packageSubDir string) (*Manifest, erro if err != nil { return err } - if d.Name() == "version" { + if d.Name() == "version" || d.Name() == "manifest.json" { return nil } diff --git a/internal/infra/models/execution.go b/internal/infra/models/execution.go index c5e9ea7..5b71e2b 100644 --- a/internal/infra/models/execution.go +++ b/internal/infra/models/execution.go @@ -166,12 +166,13 @@ const ( OTATaskStatusSuccess OTATaskStatus = "成功" // 设备报告升级成功,新固件已运行 OTATaskStatusAlreadyUpToDate OTATaskStatus = "版本已是最新" // 设备报告版本已是最新,未执行升级 - OTATaskStatusFailedPreCheck OTATaskStatus = "预检失败" // 设备报告升级前检查失败 (如拒绝降级、准备分区失败) - OTATaskStatusFailedDownload OTATaskStatus = "下载或校验失败" // 设备报告文件下载或校验失败 (包括清单文件和固件文件) - OTATaskStatusFailedRollback OTATaskStatus = "固件回滚" // 新固件启动失败,设备自动回滚 - OTATaskStatusTimedOut OTATaskStatus = "超时" // 平台在超时后仍未收到最终报告 - OTATaskStatusPlatformError OTATaskStatus = "平台内部错误" // 平台处理过程中发生的非设备报告错误 - OTATaskStatusStopped OTATaskStatus = "手动停止" // 手动停止 + OTATaskStatusFailedPreparation OTATaskStatus = "准备升级失败" // 平台在解压、生成清单等文件操作阶段发生错误 + OTATaskStatusFailedPreCheck OTATaskStatus = "预检失败" // 设备报告升级前检查失败 (如拒绝降级、准备分区失败) + OTATaskStatusFailedDownload OTATaskStatus = "下载或校验失败" // 设备报告文件下载或校验失败 (包括清单文件和固件文件) + OTATaskStatusFailedRollback OTATaskStatus = "固件回滚" // 新固件启动失败,设备自动回滚 + OTATaskStatusTimedOut OTATaskStatus = "超时" // 平台在超时后仍未收到最终报告 + OTATaskStatusPlatformError OTATaskStatus = "平台内部错误" // 平台处理过程中发生的非设备报告错误 + OTATaskStatusStopped OTATaskStatus = "手动停止" // 手动停止 ) // OTADir 是 OTA 升级相关的临时文件存储目录 diff --git a/internal/infra/repository/ota_repository.go b/internal/infra/repository/ota_repository.go index bca36d5..5d04ae0 100644 --- a/internal/infra/repository/ota_repository.go +++ b/internal/infra/repository/ota_repository.go @@ -14,12 +14,16 @@ import ( type OtaRepository interface { // Create 创建一个新的 OTA 任务。 Create(ctx context.Context, task *models.OTATask) error + // CreateTx 在指定的事务中创建一个新的 OTA 任务。 + CreateTx(ctx context.Context, tx *gorm.DB, task *models.OTATask) error // FindByID 根据任务 ID 查找任务。 FindByID(ctx context.Context, id uint32) (*models.OTATask, error) // FindTasksByStatusesAndCreationTime 根据状态列表和创建时间查找任务。 FindTasksByStatusesAndCreationTime(ctx context.Context, statuses []models.OTATaskStatus, createdBefore time.Time) ([]*models.OTATask, error) // Update 更新单个 OTA 任务。 Update(ctx context.Context, task *models.OTATask) error + // UpdateTx 在指定的事务中更新单个 OTA 任务。 + UpdateTx(ctx context.Context, tx *gorm.DB, task *models.OTATask) error } // gormOtaRepository 是 OtaRepository 的 GORM 实现 @@ -36,10 +40,17 @@ func NewGormOtaRepository(ctx context.Context, db *gorm.DB) OtaRepository { } } -// Create 实现了创建新 OTA 任务的逻辑。 +// Create 实现了创建新 OTA 任务的逻辑,内部调用 CreateTx 以复用代码。 func (r *gormOtaRepository) Create(ctx context.Context, task *models.OTATask) error { repoCtx := logs.AddFuncName(ctx, r.ctx, "Create") - return r.db.WithContext(repoCtx).Create(task).Error + // 使用 r.db 作为事务对象,调用通用的事务方法 + return r.CreateTx(repoCtx, r.db, task) +} + +// CreateTx 实现了在事务中创建新 OTA 任务的核心逻辑。 +func (r *gormOtaRepository) CreateTx(ctx context.Context, tx *gorm.DB, task *models.OTATask) error { + repoCtx := logs.AddFuncName(ctx, r.ctx, "CreateTx") + return tx.WithContext(repoCtx).Create(task).Error } // FindByID 实现了根据 ID 查找任务的逻辑。 @@ -47,7 +58,10 @@ func (r *gormOtaRepository) FindByID(ctx context.Context, id uint32) (*models.OT repoCtx := logs.AddFuncName(ctx, r.ctx, "FindByID") var task models.OTATask err := r.db.WithContext(repoCtx).First(&task, id).Error - return &task, err + if err != nil { + return nil, err + } + return &task, nil } // FindTasksByStatusesAndCreationTime 实现了根据状态和创建时间查找任务的逻辑。 @@ -63,8 +77,15 @@ func (r *gormOtaRepository) FindTasksByStatusesAndCreationTime(ctx context.Conte return tasks, err } -// Update 实现了更新单个 OTA 任务的逻辑。 +// Update 实现了更新单个 OTA 任务的逻辑,内部调用 UpdateTx 以复用代码。 func (r *gormOtaRepository) Update(ctx context.Context, task *models.OTATask) error { repoCtx := logs.AddFuncName(ctx, r.ctx, "Update") - return r.db.WithContext(repoCtx).Save(task).Error + // 使用 r.db 作为事务对象,调用通用的事务方法 + return r.UpdateTx(repoCtx, r.db, task) +} + +// UpdateTx 实现了在事务中更新单个 OTA 任务的核心逻辑。 +func (r *gormOtaRepository) UpdateTx(ctx context.Context, tx *gorm.DB, task *models.OTATask) error { + repoCtx := logs.AddFuncName(ctx, r.ctx, "UpdateTx") + return tx.WithContext(repoCtx).Save(task).Error } diff --git a/internal/infra/utils/file/decompress.go b/internal/infra/utils/file/decompress.go index 3768f33..1b0eb70 100644 --- a/internal/infra/utils/file/decompress.go +++ b/internal/infra/utils/file/decompress.go @@ -45,7 +45,7 @@ func (ct CompressionType) Matches(filename string) bool { // 此函数假定它拥有对 destPath 的完全控制权,并会在失败时将其彻底删除。 func DecompressAtomic(sourcePath, destPath string) error { action := func() error { - return decompress(sourcePath, destPath) + return Decompress(sourcePath, destPath) } onRollback := func(err error) { @@ -55,9 +55,9 @@ func DecompressAtomic(sourcePath, destPath string) error { return ExecuteWithLock(action, onRollback) } -// decompress 是解压操作的核心实现,它本身不是线程安全的,也不提供回滚。 +// Decompress 是解压操作的核心实现,它本身不是线程安全的,也不提供回滚。 // 它应该被 DecompressAtomic 或其他调用方在 ExecuteWithLock 的回调中执行。 -func decompress(sourcePath, destPath string) error { +func Decompress(sourcePath, destPath string) error { // 确保目标目录存在 if err := os.MkdirAll(destPath, 0755); err != nil { return fmt.Errorf("创建目标目录 %s 失败: %w", destPath, err)