diff --git a/internal/sub/service_test.go b/internal/sub/service_test.go index aeea02c40..8fe3a150f 100644 --- a/internal/sub/service_test.go +++ b/internal/sub/service_test.go @@ -95,6 +95,7 @@ func TestListenIsInternalOnly(t *testing.T) { } func TestResolveInboundAddress(t *testing.T) { + initSubDB(t) const reqHost = "sub.example.com" // A routable bind Listen (a real IP or hostname the operator set as the diff --git a/internal/web/service/client_bulk.go b/internal/web/service/client_bulk.go index 3617390f0..288964517 100644 --- a/internal/web/service/client_bulk.go +++ b/internal/web/service/client_bulk.go @@ -598,25 +598,19 @@ func (s *ClientService) bulkAdjustInboundClients( res.needRestart = true } - markDirty := false if oldInbound.NodeID != nil { - rt, push, dirty, perr := inboundSvc.nodePushPlan(oldInbound) + rt, push, _, perr := inboundSvc.nodePushPlan(oldInbound) if perr != nil { for email := range foundEmails { res.perEmailSkipped[email] = perr.Error() delete(foundEmails, email) } } else { - if dirty { - markDirty = true - } if flowChanged { - markDirty = true push = false } // Large batches collapse into one reconcile push rather than M updates. if push && len(foundEmails) > nodeBulkPushThreshold { - markDirty = true push = false } if push { @@ -632,7 +626,6 @@ func (s *ClientService) bulkAdjustInboundClients( updated.UpdatedAt = nowMs if err1 := rt.UpdateUser(context.Background(), oldInbound, email, updated); err1 != nil { logger.Warning("Error in updating client on", rt.Name(), ":", err1) - markDirty = true } } } @@ -649,7 +642,13 @@ func (s *ClientService) bulkAdjustInboundClients( if gcErr != nil { return gcErr } - return s.SyncInbound(tx, inboundId, finalClients) + if err := s.SyncInbound(tx, inboundId, finalClients); err != nil { + return err + } + if oldInbound.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *oldInbound.NodeID) + } + return nil }) if txErr != nil { for email := range foundEmails { @@ -657,10 +656,6 @@ func (s *ClientService) bulkAdjustInboundClients( res.perEmailSkipped[email] = txErr.Error() } } - } else if markDirty && oldInbound.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*oldInbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } } return res @@ -973,7 +968,6 @@ func (s *ClientService) bulkDelInboundClients( } } - markDirty := false if oldInbound.NodeID == nil { rt, rterr := inboundSvc.runtimeFor(oldInbound) if rterr != nil { @@ -995,26 +989,21 @@ func (s *ClientService) bulkDelInboundClients( } } } else { - rt, push, dirty, perr := inboundSvc.nodePushPlan(oldInbound) + rt, push, _, perr := inboundSvc.nodePushPlan(oldInbound) if perr != nil { for email := range foundEmails { res.perEmailSkipped[email] = perr.Error() delete(foundEmails, email) } } else { - if dirty { - markDirty = true - } // Large batches collapse into one reconcile push rather than M deletes. if push && len(foundEmails) > nodeBulkPushThreshold { - markDirty = true push = false } if push { for email := range foundEmails { if err1 := rt.DeleteUser(context.Background(), oldInbound, email); err1 != nil { logger.Warning("Error in deleting client on", rt.Name(), ":", err1) - markDirty = true } } } @@ -1031,7 +1020,13 @@ func (s *ClientService) bulkDelInboundClients( if err != nil { return err } - return s.SyncInbound(tx, inboundId, finalClients) + if err := s.SyncInbound(tx, inboundId, finalClients); err != nil { + return err + } + if oldInbound.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *oldInbound.NodeID) + } + return nil }) if txErr != nil { for email := range foundEmails { @@ -1039,10 +1034,6 @@ func (s *ClientService) bulkDelInboundClients( res.perEmailSkipped[email] = txErr.Error() } } - } else if markDirty && oldInbound.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*oldInbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } } return res @@ -1512,16 +1503,14 @@ func (s *ClientService) bulkSetEnableInboundClients(inboundSvc *InboundService, } oldInbound.Settings = string(newSettings) - rt, push, dirty, perr := inboundSvc.nodePushPlan(oldInbound) + rt, push, _, perr := inboundSvc.nodePushPlan(oldInbound) if perr != nil { for _, ch := range changed { res.perEmailSkipped[ch.email] = perr.Error() } return res } - markDirty := dirty if oldInbound.NodeID != nil && push && len(changed) > nodeBulkPushThreshold { - markDirty = true push = false } @@ -1533,7 +1522,13 @@ func (s *ClientService) bulkSetEnableInboundClients(inboundSvc *InboundService, if gcErr != nil { return gcErr } - return s.SyncInbound(tx, inboundId, finalClients) + if err := s.SyncInbound(tx, inboundId, finalClients); err != nil { + return err + } + if oldInbound.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *oldInbound.NodeID) + } + return nil }) if txErr != nil { for _, ch := range changed { @@ -1576,16 +1571,9 @@ func (s *ClientService) bulkSetEnableInboundClients(inboundSvc *InboundService, updated.UpdatedAt = nowMs if err1 := rt.UpdateUser(context.Background(), oldInbound, ch.email, updated); err1 != nil { logger.Warning("Error in updating client on", rt.Name(), ":", err1) - markDirty = true } } } - if markDirty && oldInbound.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*oldInbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } - } - return res } diff --git a/internal/web/service/client_inbound_apply.go b/internal/web/service/client_inbound_apply.go index d9292e4c9..6170dff2a 100644 --- a/internal/web/service/client_inbound_apply.go +++ b/internal/web/service/client_inbound_apply.go @@ -107,7 +107,6 @@ func (s *ClientService) delInboundClients(inboundSvc *InboundService, inboundId } needRestart := false - markDirty := false // Read each client's live state before the DB write (DelClientStat would // erase the enable flag we need to decide on a runtime removal). @@ -158,7 +157,13 @@ func (s *ClientService) delInboundClients(inboundSvc *InboundService, inboundId if gcErr != nil { return gcErr } - return s.SyncInbound(tx, inboundId, finalClients) + if err := s.SyncInbound(tx, inboundId, finalClients); err != nil { + return err + } + if oldInbound.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *oldInbound.NodeID) + } + return nil }); txErr != nil { return needRestart, txErr } @@ -167,17 +172,13 @@ func (s *ClientService) delInboundClients(inboundSvc *InboundService, inboundId var nodeRt runtime.Runtime nodePush := false if oldInbound.NodeID != nil { - rt, push, dirty, perr := inboundSvc.nodePushPlan(oldInbound) + rt, push, _, perr := inboundSvc.nodePushPlan(oldInbound) if perr != nil { return needRestart, perr } - if dirty { - markDirty = true - } nodeRt, nodePush = rt, push // Large batches collapse into one reconcile push rather than M deletes. if nodePush && len(targets) > nodeBulkPushThreshold { - markDirty = true nodePush = false } } @@ -202,16 +203,10 @@ func (s *ClientService) delInboundClients(inboundSvc *InboundService, inboundId } else if nodePush { if err1 := nodeRt.DeleteUser(context.Background(), oldInbound, t.email); err1 != nil { logger.Warning("Error in deleting client on", nodeRt.Name(), ":", err1) - markDirty = true } } } - if markDirty && oldInbound.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*oldInbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } - } return needRestart, nil } @@ -357,15 +352,11 @@ func (s *ClientService) addInboundClient(inboundSvc *InboundService, data *model oldInbound.Settings = string(newSettings) needRestart := false - markDirty := false - rt, push, dirty, perr := inboundSvc.nodePushPlan(oldInbound) + rt, push, _, perr := inboundSvc.nodePushPlan(oldInbound) if perr != nil { return false, perr } - if dirty { - markDirty = true - } // Persist client stats + inbound atomically, serialized against the traffic // poll to avoid the cross-transaction lock-order deadlock (runSerializedTx). @@ -385,7 +376,13 @@ func (s *ClientService) addInboundClient(inboundSvc *InboundService, data *model if gcErr != nil { return gcErr } - return s.SyncInbound(tx, oldInbound.Id, finalClients) + if err := s.SyncInbound(tx, oldInbound.Id, finalClients); err != nil { + return err + } + if oldInbound.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *oldInbound.NodeID) + } + return nil }); txErr != nil { return false, txErr } @@ -434,25 +431,18 @@ func (s *ClientService) addInboundClient(inboundSvc *InboundService, data *model // settings already hold the final set, so mark dirty and let one reconcile // push converge the node instead. if push && len(clients) > nodeBulkPushThreshold { - markDirty = true push = false } for _, client := range clients { if push { if err1 := rt.AddClient(context.Background(), oldInbound, client); err1 != nil { logger.Warning("Error in adding client on", rt.Name(), ":", err1) - markDirty = true push = false } } } } - if markDirty && oldInbound.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*oldInbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } - } return needRestart, nil } @@ -623,7 +613,6 @@ func (s *ClientService) UpdateInboundClient(inboundSvc *InboundService, data *mo oldInbound.Settings = string(newSettings) needRestart := false - markDirty := false // Resolve the push plan before the DB write so a node-state lookup failure // still aborts the whole update without committing anything (it used to roll @@ -631,15 +620,11 @@ func (s *ClientService) UpdateInboundClient(inboundSvc *InboundService, data *mo var rt runtime.Runtime var push bool if len(oldEmail) > 0 { - var dirty bool var perr error - rt, push, dirty, perr = inboundSvc.nodePushPlan(oldInbound) + rt, push, _, perr = inboundSvc.nodePushPlan(oldInbound) if perr != nil { return false, perr } - if dirty { - markDirty = true - } } // Persist client stats + inbound atomically, serialized against the traffic @@ -705,7 +690,13 @@ func (s *ClientService) UpdateInboundClient(inboundSvc *InboundService, data *mo if gcErr != nil { return gcErr } - return s.SyncInbound(tx, oldInbound.Id, finalClients) + if err := s.SyncInbound(tx, oldInbound.Id, finalClients); err != nil { + return err + } + if oldInbound.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *oldInbound.NodeID) + } + return nil }); txErr != nil { return false, txErr } @@ -757,7 +748,6 @@ func (s *ClientService) UpdateInboundClient(inboundSvc *InboundService, data *mo } else if push { if err1 := rt.UpdateUser(context.Background(), oldInbound, oldEmail, clients[0]); err1 != nil { logger.Warning("Error in updating client on", rt.Name(), ":", err1) - markDirty = true } } } else { @@ -765,11 +755,6 @@ func (s *ClientService) UpdateInboundClient(inboundSvc *InboundService, data *mo needRestart = true } - if markDirty && oldInbound.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*oldInbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } - } return needRestart, nil } @@ -831,7 +816,6 @@ func (s *ClientService) DelInboundClientByEmail(inboundSvc *InboundService, inbo } needRestart := false - markDirty := false // Decide what to delete and the push plan before the serialized DB write — // these are reads, and nodePushPlan failing should abort before committing. @@ -850,14 +834,11 @@ func (s *ClientService) DelInboundClientByEmail(inboundSvc *InboundService, inbo var rt runtime.Runtime var push bool if len(email) > 0 && (oldInbound.NodeID != nil || needApiDel) { - r, p, dirty, perr := inboundSvc.nodePushPlan(oldInbound) + r, p, _, perr := inboundSvc.nodePushPlan(oldInbound) if perr != nil { return false, perr } rt, push = r, p - if dirty { - markDirty = true - } } // Persist the deletion atomically, serialized against the traffic poll to @@ -882,7 +863,13 @@ func (s *ClientService) DelInboundClientByEmail(inboundSvc *InboundService, inbo if gcErr != nil { return gcErr } - return s.SyncInbound(tx, inboundId, finalClients) + if err := s.SyncInbound(tx, inboundId, finalClients); err != nil { + return err + } + if oldInbound.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *oldInbound.NodeID) + } + return nil }); txErr != nil { return false, txErr } @@ -915,17 +902,11 @@ func (s *ClientService) DelInboundClientByEmail(inboundSvc *InboundService, inbo if push { if err1 := rt.DeleteUser(context.Background(), oldInbound, email); err1 != nil { logger.Warning("Error in deleting client on", rt.Name(), ":", err1) - markDirty = true } } } } - if markDirty && oldInbound.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*oldInbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } - } return needRestart, nil } diff --git a/internal/web/service/inbound.go b/internal/web/service/inbound.go index 96e92198e..2c6d86b8b 100644 --- a/internal/web/service/inbound.go +++ b/internal/web/service/inbound.go @@ -659,12 +659,14 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo tx.Rollback() return } - tx.Commit() if markDirty && inbound.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*inbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) + if dErr := (&NodeService{}).MarkNodeDirtyTx(tx, *inbound.NodeID); dErr != nil { + err = dErr + tx.Rollback() + return } } + tx.Commit() }() // Omit the ClientStats has-many association: GORM's cascade would INSERT @@ -809,17 +811,20 @@ func (s *InboundService) DelInbound(id int) (bool, error) { } } - if err := db.Delete(model.Inbound{}, id).Error; err != nil { - return needRestart, err - } - // Hosts have no hard FK; drop the inbound's hosts alongside it. - if err := db.Where("inbound_id = ?", id).Delete(&model.Host{}).Error; err != nil { - return needRestart, err - } - if markDirty && ib.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*ib.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) + if err := db.Transaction(func(tx *gorm.DB) error { + if err := tx.Delete(model.Inbound{}, id).Error; err != nil { + return err } + // Hosts have no hard FK; drop the inbound's hosts alongside it. + if err := tx.Where("inbound_id = ?", id).Delete(&model.Host{}).Error; err != nil { + return err + } + if markDirty && ib.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *ib.NodeID) + } + return nil + }); err != nil { + return needRestart, err } if !database.IsPostgres() { var count int64 @@ -902,14 +907,22 @@ func (s *InboundService) SetInboundEnable(id int, enable bool) (bool, error) { } db := database.GetDB() - if err := db.Model(model.Inbound{}).Where("id = ?", id). - Update("enable", enable).Error; err != nil { + if err := db.Transaction(func(tx *gorm.DB) error { + if err := tx.Model(model.Inbound{}).Where("id = ?", id). + Update("enable", enable).Error; err != nil { + return err + } + if inbound.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *inbound.NodeID) + } + return nil + }); err != nil { return false, err } inbound.Enable = enable needRestart := false - rt, push, dirty, perr := s.nodePushPlan(inbound) + rt, push, _, perr := s.nodePushPlan(inbound) if perr != nil { return false, perr } @@ -923,12 +936,6 @@ func (s *InboundService) SetInboundEnable(id int, enable bool) (bool, error) { if push { if err := rt.UpdateInbound(context.Background(), inbound, inbound); err != nil { logger.Warning("SetInboundEnable: remote UpdateInbound on", rt.Name(), "failed:", err) - dirty = true - } - } - if dirty { - if dErr := (&NodeService{}).MarkNodeDirty(*inbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) } } return false, nil @@ -991,7 +998,6 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, oldTagWasAuto := isAutoGeneratedTag(tag, oldInbound.Port, oldInbound.NodeID, oldBits) needRestart := false - markDirty := false // Persist the client-stat sync, settings munging, runtime push and inbound // save as one transaction routed through the serial traffic writer, so it @@ -1117,13 +1123,10 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, oldInbound.Tag = resolvedTag inbound.Tag = oldInbound.Tag - rt, push, dirty, perr := s.nodePushPlan(oldInbound) + rt, push, _, perr := s.nodePushPlan(oldInbound) if perr != nil { return perr } - if dirty { - markDirty = true - } if oldInbound.NodeID == nil { if !push { needRestart = true @@ -1152,11 +1155,9 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, if !inbound.Enable { if err2 := rt.DelInbound(context.Background(), &oldSnapshot); err2 != nil { logger.Warning("Unable to disable inbound on", rt.Name(), ":", err2) - markDirty = true } } else if err2 := rt.UpdateInbound(context.Background(), &oldSnapshot, oldInbound); err2 != nil { logger.Warning("Unable to update inbound on", rt.Name(), ":", err2) - markDirty = true } } @@ -1179,6 +1180,11 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, if err := s.clientService.SyncInbound(tx, oldInbound.Id, newClients); err != nil { return err } + if oldInbound.NodeID != nil { + if err := (&NodeService{}).MarkNodeDirtyTx(tx, *oldInbound.NodeID); err != nil { + return err + } + } // (Re)generate the Xray config whenever routing was or is now enabled, so // the egress SOCKS bridge is added, moved, or dropped to match the new // settings. @@ -1201,11 +1207,6 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, needRestart = true } } - if markDirty && oldInbound.NodeID != nil { - if dErr := (&NodeService{}).MarkNodeDirty(*oldInbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } - } return inbound, needRestart, nil } diff --git a/internal/web/service/inbound_traffic.go b/internal/web/service/inbound_traffic.go index d77a11e5b..faf73fd80 100644 --- a/internal/web/service/inbound_traffic.go +++ b/internal/web/service/inbound_traffic.go @@ -544,18 +544,12 @@ func (s *InboundService) resetClientTrafficLocked(id int, clientEmail string) (b } for _, client := range clients { if client.Email == clientEmail && client.Enable { - rt, push, dirty, perr := s.nodePushPlan(inbound) + rt, push, _, perr := s.nodePushPlan(inbound) if perr != nil { return false, perr } if !push { - if inbound.NodeID != nil { - if dirty { - if dErr := (&NodeService{}).MarkNodeDirty(*inbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } - } - } else { + if inbound.NodeID == nil { needRestart = true } break @@ -582,9 +576,6 @@ func (s *InboundService) resetClientTrafficLocked(id int, clientEmail string) (b logger.Debug("Client enabled on", rt.Name(), "due to reset traffic:", clientEmail) } else if inbound.NodeID != nil { logger.Warning("Error in enabling client on", rt.Name(), ":", err1) - if dErr := (&NodeService{}).MarkNodeDirty(*inbound.NodeID); dErr != nil { - logger.Warning("mark node dirty failed:", dErr) - } } else { logger.Debug("Error in enabling client on", rt.Name(), ":", err1) needRestart = true @@ -599,24 +590,35 @@ func (s *InboundService) resetClientTrafficLocked(id int, clientEmail string) (b traffic.Enable = true db := database.GetDB() - err = db.Save(traffic).Error + now := time.Now().UnixMilli() + inbound, err := s.GetInbound(id) if err != nil { return false, err } - if err := clearGlobalTraffic(db, clientEmail); err != nil { - return false, err - } - if err := db.Where("email = ?", clientEmail).Delete(&model.NodeClientTraffic{}).Error; err != nil { + if err := db.Transaction(func(tx *gorm.DB) error { + if err := tx.Save(traffic).Error; err != nil { + return err + } + if err := clearGlobalTraffic(tx, clientEmail); err != nil { + return err + } + if err := tx.Where("email = ?", clientEmail).Delete(&model.NodeClientTraffic{}).Error; err != nil { + return err + } + if err := tx.Model(model.Inbound{}). + Where("id = ?", id). + Update("last_traffic_reset_time", now).Error; err != nil { + return err + } + if inbound != nil && inbound.NodeID != nil { + return (&NodeService{}).MarkNodeDirtyTx(tx, *inbound.NodeID) + } + return nil + }); err != nil { return false, err } - now := time.Now().UnixMilli() - _ = db.Model(model.Inbound{}). - Where("id = ?", id). - Update("last_traffic_reset_time", now).Error - - inbound, err := s.GetInbound(id) - if err == nil && inbound != nil && inbound.NodeID != nil { + if inbound != nil && inbound.NodeID != nil { if rt, rterr := s.runtimeFor(inbound); rterr == nil { if e := rt.ResetClientTraffic(context.Background(), inbound, clientEmail); e != nil { logger.Warning("ResetClientTraffic: remote propagation to", rt.Name(), "failed:", e) diff --git a/internal/web/service/node.go b/internal/web/service/node.go index 630eb5993..206b0939e 100644 --- a/internal/web/service/node.go +++ b/internal/web/service/node.go @@ -453,12 +453,14 @@ func (s *NodeService) Update(id int, in *model.Node) error { "inbound_tags": string(inboundTagsJSON), "outbound_tag": in.OutboundTag, } - if err := db.Model(model.Node{}).Where("id = ?", id).Updates(updates).Error; err != nil { + if err := db.Transaction(func(tx *gorm.DB) error { + if err := tx.Model(model.Node{}).Where("id = ?", id).Updates(updates).Error; err != nil { + return err + } + return s.MarkNodeDirtyTx(tx, id) + }); err != nil { return err } - if dErr := s.MarkNodeDirty(id); dErr != nil { - logger.Warning("mark node dirty after update failed:", dErr) - } if mgr := runtime.GetManager(); mgr != nil { mgr.InvalidateNode(id) } @@ -736,10 +738,17 @@ func (s *NodeService) warnOnDuplicateGuid(id int, guid string) { } func (s *NodeService) MarkNodeDirty(id int) error { + return s.MarkNodeDirtyTx(database.GetDB(), id) +} + +func (s *NodeService) MarkNodeDirtyTx(tx *gorm.DB, id int) error { if id <= 0 { return nil } - return database.GetDB().Model(model.Node{}). + if tx == nil { + return errors.New("nil db transaction") + } + return tx.Model(model.Node{}). Where("id = ?", id). Updates(map[string]any{ "config_dirty": true, diff --git a/internal/web/service/node_dirty_test.go b/internal/web/service/node_dirty_test.go index 92ec711f6..2d299a359 100644 --- a/internal/web/service/node_dirty_test.go +++ b/internal/web/service/node_dirty_test.go @@ -1,8 +1,11 @@ package service import ( + "errors" "testing" + "gorm.io/gorm" + "github.com/mhsanaei/3x-ui/v3/internal/database" "github.com/mhsanaei/3x-ui/v3/internal/database/model" "github.com/mhsanaei/3x-ui/v3/internal/web/runtime" @@ -145,6 +148,43 @@ func TestNodeDirty_ClearIsCASOnDirtyAt(t *testing.T) { } } +func TestMarkNodeDirtyTxRollsBackWithTransaction(t *testing.T) { + setupConflictDB(t) + db := database.GetDB() + + node := &model.Node{Name: "n3", Address: "127.0.0.1", Port: 2096, ApiToken: "tok", Enable: true, Status: "online"} + if err := db.Create(node).Error; err != nil { + t.Fatalf("create node: %v", err) + } + + nodeSvc := NodeService{} + rollbackErr := errors.New("force rollback") + if err := db.Transaction(func(tx *gorm.DB) error { + if err := nodeSvc.MarkNodeDirtyTx(tx, node.Id); err != nil { + return err + } + return rollbackErr + }); !errors.Is(err, rollbackErr) { + t.Fatalf("rollback tx: got %v want %v", err, rollbackErr) + } + if _, _, dirty, _, err := nodeSvc.NodeSyncState(node.Id); err != nil { + t.Fatalf("NodeSyncState after rollback: %v", err) + } else if dirty { + t.Fatal("dirty flag escaped a rolled-back transaction") + } + + if err := db.Transaction(func(tx *gorm.DB) error { + return nodeSvc.MarkNodeDirtyTx(tx, node.Id) + }); err != nil { + t.Fatalf("commit tx: %v", err) + } + if _, _, dirty, _, err := nodeSvc.NodeSyncState(node.Id); err != nil { + t.Fatalf("NodeSyncState after commit: %v", err) + } else if !dirty { + t.Fatal("dirty flag should commit with its transaction") + } +} + // Editing a node must mark it config-dirty so the next traffic-sync tick // reconciles (pushes the panel's inbounds to the remote) before pulling a // snapshot. Without the dirty flag, re-pointing a node to a fresh server