mirror of
https://github.com/certd/certd.git
synced 2026-04-21 02:20:52 +08:00
perf: 支持同步域名过期时间
This commit is contained in:
@@ -34,3 +34,29 @@ export class Pager {
|
||||
this.pageNo = Math.ceil(offset / (this.pageSize ?? 50)) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
export async function doPageTurn<T>(req: { pager: Pager; getPage: (pager: Pager) => Promise<PageRes<T>>; itemHandle?: (item: T) => Promise<void>; batchHandle?: (pageRes: PageRes<T>) => Promise<void> }) {
|
||||
let count = 0;
|
||||
const { pager, getPage, itemHandle, batchHandle } = req;
|
||||
while (true) {
|
||||
const pageRes = await getPage(pager);
|
||||
if (!pageRes || !pageRes.list || pageRes.list.length === 0) {
|
||||
break;
|
||||
}
|
||||
count += pageRes.list.length;
|
||||
if (batchHandle) {
|
||||
await batchHandle(pageRes);
|
||||
}
|
||||
if (itemHandle) {
|
||||
for (const item of pageRes.list) {
|
||||
await itemHandle(item);
|
||||
}
|
||||
}
|
||||
if (pageRes.total && pageRes.total >= 0 && count >= pageRes.total) {
|
||||
//遍历完成
|
||||
break;
|
||||
}
|
||||
pager.pageNo++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,14 @@ import psl from "psl";
|
||||
import { ILogger, utils, logger as globalLogger } from "@certd/basic";
|
||||
import { resolveDomainBySoaRecord } from "@certd/acme-client";
|
||||
|
||||
export function parseDomainByPsl(fullDomain: string) {
|
||||
const parsed = psl.parse(fullDomain) as psl.ParsedDomain;
|
||||
if (parsed.error) {
|
||||
throw new Error(`解析${fullDomain}域名失败:` + JSON.stringify(parsed.error));
|
||||
}
|
||||
return parsed;
|
||||
}
|
||||
|
||||
export class DomainParser implements IDomainParser {
|
||||
subDomainsGetter: ISubDomainsGetter;
|
||||
logger: ILogger;
|
||||
@@ -13,11 +21,7 @@ export class DomainParser implements IDomainParser {
|
||||
}
|
||||
|
||||
parseDomainByPsl(fullDomain: string) {
|
||||
const parsed = psl.parse(fullDomain) as psl.ParsedDomain;
|
||||
if (parsed.error) {
|
||||
throw new Error(`解析${fullDomain}域名失败:` + JSON.stringify(parsed.error));
|
||||
}
|
||||
return parsed.domain as string;
|
||||
return parseDomainByPsl(fullDomain).domain as string;
|
||||
}
|
||||
|
||||
async parse(fullDomain: string) {
|
||||
|
||||
@@ -65,3 +65,10 @@ export async function SyncSubmit(body: any) {
|
||||
data: body,
|
||||
});
|
||||
}
|
||||
|
||||
export async function SyncDomainsExpiration() {
|
||||
return await request({
|
||||
url: apiPrefix + "/sync/expiration",
|
||||
method: "post",
|
||||
});
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import { useUserStore } from "/@/store/user";
|
||||
import { useSettingStore } from "/@/store/settings";
|
||||
import { Dicts } from "/@/components/plugins/lib/dicts";
|
||||
import { createAccessApi } from "/@/views/certd/access/api";
|
||||
import { Modal } from "ant-design-vue";
|
||||
import { Modal, notification } from "ant-design-vue";
|
||||
import { useDomainImport } from "./use";
|
||||
|
||||
export default function ({ crudExpose, context }: CreateCrudOptionsProps): CreateCrudOptionsRet {
|
||||
@@ -98,7 +98,27 @@ export default function ({ crudExpose, context }: CreateCrudOptionsProps): Creat
|
||||
type: "primary",
|
||||
text: "从域名提供商导入",
|
||||
click: () => {
|
||||
openDomainImportDialog();
|
||||
openDomainImportDialog({
|
||||
afterSubmit: () => {
|
||||
setTimeout(() => {
|
||||
crudExpose.doRefresh();
|
||||
}, 2000);
|
||||
},
|
||||
});
|
||||
},
|
||||
},
|
||||
syncExpirationDate: {
|
||||
title: "同步域名过期时间",
|
||||
type: "primary",
|
||||
text: "同步域名过期时间",
|
||||
click: async () => {
|
||||
await api.SyncDomainsExpiration();
|
||||
notification.success({
|
||||
message: "同步任务已提交",
|
||||
});
|
||||
setTimeout(() => {
|
||||
crudExpose.doRefresh();
|
||||
}, 2000);
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -48,7 +48,7 @@ export function useDomainImport() {
|
||||
},
|
||||
};
|
||||
|
||||
return function openDomainImportDialog() {
|
||||
return function openDomainImportDialog(req: { afterSubmit?: () => void }) {
|
||||
openFormDialog({
|
||||
title: "从域名提供商导入域名",
|
||||
columns: columns,
|
||||
@@ -58,6 +58,9 @@ export function useDomainImport() {
|
||||
dnsProviderAccessId: form.dnsProviderAccessId,
|
||||
});
|
||||
message.success("导入任务已提交");
|
||||
if (req.afterSubmit) {
|
||||
req.afterSubmit();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -80,12 +80,20 @@ export class DomainController extends CrudController<DomainService> {
|
||||
|
||||
|
||||
@Post('/sync/submit', { summary: Constants.per.authOnly })
|
||||
async sync(@Body(ALL) body: any) {
|
||||
async syncSubmit(@Body(ALL) body: any) {
|
||||
const { dnsProviderType, dnsProviderAccessId } = body;
|
||||
const req = {
|
||||
dnsProviderType, dnsProviderAccessId, userId: this.getUserId(),
|
||||
}
|
||||
await this.service.syncFromProvider(req);
|
||||
await this.service.doSyncFromProvider(req);
|
||||
return this.ok();
|
||||
}
|
||||
|
||||
@Post('/sync/expiration', { summary: Constants.per.authOnly })
|
||||
async syncExpiration(@Body(ALL) body: any) {
|
||||
await this.service.doSyncDomainsExpirationDate({
|
||||
userId: this.getUserId(),
|
||||
})
|
||||
return this.ok();
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ import {SiteInfoService} from '../monitor/index.js';
|
||||
import {Cron} from '../cron/cron.js';
|
||||
import {UserSettingsService} from "../mine/service/user-settings-service.js";
|
||||
import {UserSiteMonitorSetting} from "../mine/service/models.js";
|
||||
import {getPlusInfo} from "@certd/plus-core";
|
||||
import {getPlusInfo, isPlus} from "@certd/plus-core";
|
||||
import dayjs from "dayjs";
|
||||
import {NotificationService} from "../pipeline/service/notification-service.js";
|
||||
import {UserService} from "../sys/authority/service/user-service.js";
|
||||
import {Between} from "typeorm";
|
||||
import { DomainService } from '../cert/service/domain-service.js';
|
||||
|
||||
@Autoload()
|
||||
@Scope(ScopeEnum.Request, { allowDowngrade: true })
|
||||
@@ -44,6 +45,9 @@ export class AutoCRegisterCron {
|
||||
@Inject()
|
||||
userService: UserService;
|
||||
|
||||
@Inject()
|
||||
domainService: DomainService;
|
||||
|
||||
|
||||
@Init()
|
||||
async init() {
|
||||
@@ -60,7 +64,9 @@ export class AutoCRegisterCron {
|
||||
|
||||
await this.registerPlusExpireCheckCron();
|
||||
|
||||
await this.registerUserExpireCheckCron()
|
||||
await this.registerUserExpireCheckCron();
|
||||
|
||||
await this.registerDomainExpireCheckCron();
|
||||
}
|
||||
|
||||
async registerSiteMonitorCron() {
|
||||
@@ -199,4 +205,23 @@ export class AutoCRegisterCron {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
registerDomainExpireCheckCron(){
|
||||
if (!isPlus()){
|
||||
return
|
||||
}
|
||||
// 添加域名即将到期检查任务
|
||||
const randomWeek = Math.floor(Math.random() * 7) + 1
|
||||
const randomHour = Math.floor(Math.random() * 24)
|
||||
const randomMinute = Math.floor(Math.random() * 60)
|
||||
logger.info(`注册域名注册过期时间检查任务,每周${randomWeek} ${randomHour}:${randomMinute}检查一次`)
|
||||
this.cron.register({
|
||||
name: 'domain-expire-check',
|
||||
cron: `0 ${randomMinute} ${randomHour} ? * ${randomWeek}`, // 每周随机一天检查一次
|
||||
job: async () => {
|
||||
await this.domainService.doSyncDomainsExpirationDate({})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,29 +1,33 @@
|
||||
import {Inject, Provide, Scope, ScopeEnum} from '@midwayjs/core';
|
||||
import {InjectEntityModel} from '@midwayjs/typeorm';
|
||||
import {In, Not, Repository} from 'typeorm';
|
||||
import {AccessService, BaseService} from '@certd/lib-server';
|
||||
import {DomainEntity} from '../entity/domain.js';
|
||||
import {SubDomainService} from "../../pipeline/service/sub-domain-service.js";
|
||||
import {createDnsProvider, DomainParser} from "@certd/plugin-lib";
|
||||
import {DomainVerifiers} from "@certd/plugin-cert";
|
||||
import { SubDomainsGetter } from '../../pipeline/service/getter/sub-domain-getter.js';
|
||||
import { CnameRecordService } from '../../cname/service/cname-record-service.js';
|
||||
import { CnameRecordEntity } from "../../cname/entity/cname-record.js";
|
||||
import { http, logger, utils } from '@certd/basic';
|
||||
import { AccessService, BaseService } from '@certd/lib-server';
|
||||
import { doPageTurn, Pager, PageRes } from '@certd/pipeline';
|
||||
import { DomainVerifiers } from "@certd/plugin-cert";
|
||||
import { createDnsProvider, DomainParser, parseDomainByPsl } from "@certd/plugin-lib";
|
||||
import { Inject, Provide, Scope, ScopeEnum } from '@midwayjs/core';
|
||||
import { InjectEntityModel } from '@midwayjs/typeorm';
|
||||
import dayjs from 'dayjs';
|
||||
import { In, Not, Repository } from 'typeorm';
|
||||
import { CnameRecordEntity } from "../../cname/entity/cname-record.js";
|
||||
import { CnameRecordService } from '../../cname/service/cname-record-service.js';
|
||||
import { SubDomainsGetter } from '../../pipeline/service/getter/sub-domain-getter.js';
|
||||
import { TaskServiceBuilder } from '../../pipeline/service/getter/task-service-getter.js';
|
||||
import { Pager } from '@certd/pipeline';
|
||||
import { SubDomainService } from "../../pipeline/service/sub-domain-service.js";
|
||||
import { DomainEntity } from '../entity/domain.js';
|
||||
import { BackTask, taskExecutor } from './task-executor.js';
|
||||
|
||||
export interface SyncFromProviderReq {
|
||||
export interface SyncFromProviderReq {
|
||||
userId: number;
|
||||
dnsProviderType: string;
|
||||
dnsProviderAccessId: string;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
@Provide()
|
||||
@Scope(ScopeEnum.Request, {allowDowngrade: true})
|
||||
@Scope(ScopeEnum.Request, { allowDowngrade: true })
|
||||
export class DomainService extends BaseService<DomainEntity> {
|
||||
@InjectEntityModel(DomainEntity)
|
||||
repository: Repository<DomainEntity>;
|
||||
@@ -45,7 +49,7 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
}
|
||||
|
||||
async add(param) {
|
||||
if (param.userId == null ){
|
||||
if (param.userId == null) {
|
||||
throw new Error('userId 不能为空');
|
||||
}
|
||||
if (!param.domain) {
|
||||
@@ -97,9 +101,9 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
* @param userId
|
||||
* @param domains //去除* 且去重之后的域名列表
|
||||
*/
|
||||
async getDomainVerifiers(userId: number, domains: string[]):Promise<DomainVerifiers> {
|
||||
async getDomainVerifiers(userId: number, domains: string[]): Promise<DomainVerifiers> {
|
||||
|
||||
const mainDomainMap:Record<string, string> = {}
|
||||
const mainDomainMap: Record<string, string> = {}
|
||||
const subDomainGetter = new SubDomainsGetter(userId, this.subDomainService)
|
||||
const domainParser = new DomainParser(subDomainGetter)
|
||||
|
||||
@@ -111,7 +115,7 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
}
|
||||
|
||||
//匹配DNS记录
|
||||
let allDomains = [...domains,...mainDomains]
|
||||
let allDomains = [...domains, ...mainDomains]
|
||||
//去重
|
||||
allDomains = [...new Set(allDomains)]
|
||||
|
||||
@@ -120,16 +124,16 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
where: {
|
||||
domain: In(allDomains),
|
||||
userId,
|
||||
disabled:false,
|
||||
disabled: false,
|
||||
}
|
||||
})
|
||||
|
||||
const dnsMap = domainRecords.filter(item=>item.challengeType === 'dns').reduce((pre, item) => {
|
||||
const dnsMap = domainRecords.filter(item => item.challengeType === 'dns').reduce((pre, item) => {
|
||||
pre[item.domain] = item
|
||||
return pre
|
||||
}, {})
|
||||
|
||||
const httpMap = domainRecords.filter(item=>item.challengeType === 'http').reduce((pre, item) => {
|
||||
const httpMap = domainRecords.filter(item => item.challengeType === 'http').reduce((pre, item) => {
|
||||
pre[item.domain] = item
|
||||
return pre
|
||||
}, {})
|
||||
@@ -150,7 +154,7 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
}, {})
|
||||
|
||||
//构建域名验证计划
|
||||
const domainVerifiers:DomainVerifiers = {}
|
||||
const domainVerifiers: DomainVerifiers = {}
|
||||
|
||||
for (const domain of domains) {
|
||||
const mainDomain = mainDomainMap[domain]
|
||||
@@ -168,7 +172,7 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
}
|
||||
continue
|
||||
}
|
||||
const cnameRecord:CnameRecordEntity = cnameMap[domain]
|
||||
const cnameRecord: CnameRecordEntity = cnameMap[domain]
|
||||
if (cnameRecord) {
|
||||
domainVerifiers[domain] = {
|
||||
domain,
|
||||
@@ -194,7 +198,7 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
httpUploadRootDir: httpRecord.httpUploadRootDir
|
||||
}
|
||||
}
|
||||
continue
|
||||
continue
|
||||
}
|
||||
domainVerifiers[domain] = null
|
||||
}
|
||||
@@ -202,9 +206,18 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
return domainVerifiers;
|
||||
}
|
||||
|
||||
|
||||
|
||||
async syncFromProvider(req: SyncFromProviderReq) {
|
||||
async doSyncFromProvider(req: SyncFromProviderReq) {
|
||||
taskExecutor.start('syncFromProviderTask', new BackTask({
|
||||
key: `user_${req.userId}`,
|
||||
title: `同步用户${req.userId}从域名提供商导入域名`,
|
||||
run: async (task: BackTask) => {
|
||||
await this._syncFromProvider(req, task)
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
private async _syncFromProvider(req: SyncFromProviderReq, task: BackTask) {
|
||||
const { userId, dnsProviderType, dnsProviderAccessId } = req;
|
||||
const subDomainGetter = new SubDomainsGetter(userId, this.subDomainService)
|
||||
const domainParser = new DomainParser(subDomainGetter)
|
||||
@@ -212,20 +225,17 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
const access = await this.accessService.getById(dnsProviderAccessId, userId);
|
||||
const context = { access, logger, http, utils, domainParser, serviceGetter };
|
||||
// 翻页查询dns的记录
|
||||
const dnsProvider = await createDnsProvider({dnsProviderType,context})
|
||||
|
||||
const dnsProvider = await createDnsProvider({ dnsProviderType, context })
|
||||
|
||||
const pager = new Pager({
|
||||
pageNo: 1,
|
||||
pageSize: 100,
|
||||
})
|
||||
const challengeType = "dns"
|
||||
|
||||
const importDomain = async(domainRecord: any) =>{
|
||||
const importDomain = async (domainRecord: any) => {
|
||||
task.incrementCurrent()
|
||||
const domain = domainRecord.domain
|
||||
const certProps :any={
|
||||
registrationDate: domainRecord.registrationDate,
|
||||
expirationDate: domainRecord.expirationDate,
|
||||
}
|
||||
|
||||
const old = await this.findOne({
|
||||
where: {
|
||||
@@ -234,15 +244,15 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
}
|
||||
})
|
||||
if (old) {
|
||||
const updateObj :any={
|
||||
id: old.id,
|
||||
...certProps
|
||||
if (old.fromType !== 'auto') {
|
||||
//如果是手动的,跳过更新校验配置
|
||||
return
|
||||
}
|
||||
if (old.fromType !== 'manual'){
|
||||
//如果不是手动的,更新校验配置
|
||||
updateObj.dnsProviderType = dnsProviderType
|
||||
updateObj.dnsProviderAccess = dnsProviderAccessId
|
||||
updateObj.challengeType = challengeType
|
||||
const updateObj: any = {
|
||||
id: old.id,
|
||||
dnsProviderType,
|
||||
dnsProviderAccess: dnsProviderAccessId,
|
||||
challengeType,
|
||||
}
|
||||
//更新
|
||||
await super.update(updateObj)
|
||||
@@ -256,33 +266,132 @@ export class DomainService extends BaseService<DomainEntity> {
|
||||
challengeType,
|
||||
disabled: false,
|
||||
fromType: 'auto',
|
||||
...certProps
|
||||
})
|
||||
}
|
||||
}
|
||||
const start = async ()=>{
|
||||
let count = 0
|
||||
while(true){
|
||||
const pageRes = await dnsProvider.getDomainListPage(pager)
|
||||
if(!pageRes || !pageRes.list || pageRes.list.length === 0){
|
||||
//遍历完成
|
||||
break
|
||||
}
|
||||
//处理
|
||||
for (const domainRecord of pageRes.list) {
|
||||
await importDomain(domainRecord)
|
||||
}
|
||||
|
||||
count += pageRes.list.length
|
||||
if(pageRes.total>0 && count >= pageRes.total){
|
||||
//遍历完成
|
||||
break
|
||||
}
|
||||
pager.pageNo++
|
||||
}
|
||||
const batchHandle = async (pageRes: PageRes<any>) => {
|
||||
task.setTotal(pageRes.total || 0)
|
||||
}
|
||||
const start = async () => {
|
||||
await doPageTurn({ pager, getPage: dnsProvider.getDomainListPage, itemHandle: importDomain, batchHandle })
|
||||
}
|
||||
|
||||
start()
|
||||
|
||||
}
|
||||
|
||||
async doSyncDomainsExpirationDate(req: { userId?: number }) {
|
||||
const userId = req.userId
|
||||
taskExecutor.start('syncDomainsExpirationDateTask', new BackTask({
|
||||
key: `user_${userId}`,
|
||||
title: `同步用户(${userId ?? '全部'})注册域名过期时间`,
|
||||
run: async (task: BackTask) => {
|
||||
await this._syncDomainsExpirationDate({ userId, task })
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
private async _syncDomainsExpirationDate(req: { userId?: number, task: BackTask }) {
|
||||
//同步所有域名的过期时间
|
||||
const pager = new Pager({
|
||||
pageNo: 1,
|
||||
pageSize: 100,
|
||||
})
|
||||
|
||||
const dnsJson = await http.request({
|
||||
url: "https://data.iana.org/rdap/dns.json",
|
||||
method: "GET",
|
||||
})
|
||||
const rdapMap: Record<string, string> = {}
|
||||
for (const item of dnsJson.services) {
|
||||
// [["store","work"], ["https://rdap.centralnic.com/store/"]],
|
||||
const suffixes = item[0]
|
||||
const urls = item[1]
|
||||
for (const suffix of suffixes) {
|
||||
rdapMap[suffix] = urls[0]
|
||||
}
|
||||
}
|
||||
|
||||
const getDomainExpirationDate = async (domain: string) => {
|
||||
const parsed = parseDomainByPsl(domain)
|
||||
const mainDomain = parsed.domain || ''
|
||||
if (mainDomain !== domain) {
|
||||
logger.warn(`${domain}为子域名,跳过同步`)
|
||||
return
|
||||
}
|
||||
const suffix = parsed.tld || ''
|
||||
const rdapUrl = rdapMap[suffix]
|
||||
if (!rdapUrl) {
|
||||
throw new Error(`未找到${suffix}的rdap地址`)
|
||||
}
|
||||
// https://rdap.nic.work/domain/handsfree.work
|
||||
const rdap = await http.request({
|
||||
url: `${rdapUrl}domain/${domain}`,
|
||||
method: "GET",
|
||||
})
|
||||
|
||||
let res: any = {}
|
||||
const events = rdap.events || []
|
||||
for (const item of events) {
|
||||
if (item.eventAction === 'expiration') {
|
||||
res.expirationDate = dayjs(item.eventDate).valueOf()
|
||||
} else if (item.eventAction === 'registration') {
|
||||
res.registrationDate = dayjs(item.eventDate).valueOf()
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
const query: any = {
|
||||
challengeType: "dns",
|
||||
}
|
||||
if (req.userId!=null) {
|
||||
query.userId = req.userId
|
||||
}
|
||||
const getDomainPage = async (pager: Pager) => {
|
||||
const pageRes = await this.page({
|
||||
query: query,
|
||||
buildQuery(bq) {
|
||||
bq.andWhere(" (expiration_date is null or expiration_date < :now) ", { now: dayjs().add(1, 'month').valueOf() })
|
||||
},
|
||||
page: {
|
||||
offset: pager.getOffset(),
|
||||
limit: pager.pageSize,
|
||||
}
|
||||
})
|
||||
req.task.total = pageRes.total
|
||||
return {
|
||||
list: pageRes.records,
|
||||
total: pageRes.total,
|
||||
}
|
||||
}
|
||||
|
||||
const itemHandle = async (item: any) => {
|
||||
req.task.incrementCurrent()
|
||||
try {
|
||||
const res = await getDomainExpirationDate(item.domain)
|
||||
if (!res) {
|
||||
return
|
||||
}
|
||||
const { expirationDate, registrationDate } = res
|
||||
if (!expirationDate) {
|
||||
logger.error(`获取域名${item.domain}过期时间失败`)
|
||||
return
|
||||
}
|
||||
logger.info(`更新域名${item.domain}过期时间:${dayjs(expirationDate).format('YYYY-MM-DD')}`)
|
||||
const updateObj: any = {
|
||||
id: item.id,
|
||||
expirationDate: expirationDate,
|
||||
registrationDate: registrationDate,
|
||||
}
|
||||
//更新
|
||||
await super.update(updateObj)
|
||||
} catch (error) {
|
||||
logger.error(`更新域名${item.domain}过期时间失败:${error}`)
|
||||
} finally {
|
||||
await utils.sleep(1000)
|
||||
}
|
||||
}
|
||||
|
||||
await doPageTurn({ pager, getPage: getDomainPage, itemHandle: itemHandle })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,27 +1,35 @@
|
||||
import { logger } from "@certd/basic"
|
||||
|
||||
export class BackTaskExecutor{
|
||||
tasks :Record<string,Record<string,BackTask>> = {}
|
||||
export class BackTaskExecutor {
|
||||
tasks: Record<string, Record<string, BackTask>> = {}
|
||||
|
||||
add(type:string,task: BackTask){
|
||||
start(type: string, task: BackTask) {
|
||||
if (!this.tasks[type]) {
|
||||
this.tasks[type] = {}
|
||||
}
|
||||
const oldTask = this.tasks[type][task.key]
|
||||
if (oldTask && oldTask.status === "running") {
|
||||
throw new Error(`任务 ${task.key} 正在运行中`)
|
||||
}
|
||||
this.tasks[type][task.key] = task
|
||||
this.run(type, task);
|
||||
}
|
||||
|
||||
get(type: string,key: string){
|
||||
get(type: string, key: string) {
|
||||
if (!this.tasks[type]) {
|
||||
this.tasks[type] = {}
|
||||
}
|
||||
return this.tasks[type][key]
|
||||
}
|
||||
|
||||
removeIsEnd(type: string,key: string){
|
||||
removeIsEnd(type: string, key: string) {
|
||||
const task = this.tasks[type]?.[key]
|
||||
if (task && task.status !== "running") {
|
||||
this.clear(type,key);
|
||||
this.clear(type, key);
|
||||
}
|
||||
}
|
||||
|
||||
clear(type: string,key: string){
|
||||
clear(type: string, key: string) {
|
||||
const task = this.tasks[type]?.[key]
|
||||
if (task) {
|
||||
task.clearTimeout();
|
||||
@@ -29,33 +37,31 @@ export class BackTaskExecutor{
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
async run(type:string,key: string){
|
||||
const task = this.tasks[type]?.[key]
|
||||
if (!task) {
|
||||
throw new Error(`任务 ${key} 不存在`)
|
||||
private async run(type: string, task: any) {
|
||||
if (task.status === "running") {
|
||||
throw new Error(`任务 ${task.key} 正在运行中`)
|
||||
}
|
||||
task.startTime = Date.now();
|
||||
task.clearTimeout();
|
||||
try{
|
||||
try {
|
||||
task.status = "running";
|
||||
return await task.run();
|
||||
}catch(e){
|
||||
return await task.run(task);
|
||||
} catch (e) {
|
||||
logger.error(`任务 ${task.title}[${task.key}] 执行失败`, e.message);
|
||||
task.status = "failed";
|
||||
task.error = e.message;
|
||||
}finally{
|
||||
} finally {
|
||||
task.endTime = Date.now();
|
||||
task.status = "done";
|
||||
task.timeoutId = setTimeout(() => {
|
||||
this.clear(type,task.key);
|
||||
}, 60*60*1000);
|
||||
this.clear(type, task.key);
|
||||
}, 24 * 60 * 60 * 1000);
|
||||
delete task.run;
|
||||
}
|
||||
}
|
||||
}
|
||||
export class BackTask{
|
||||
key:string;
|
||||
export class BackTask {
|
||||
key: string;
|
||||
title: string;
|
||||
total: number = 0;
|
||||
current: number = 0;
|
||||
@@ -66,9 +72,12 @@ export class BackTask{
|
||||
timeoutId?: NodeJS.Timeout;
|
||||
|
||||
|
||||
run: () => Promise<void>;
|
||||
run: (task: BackTask) => Promise<void>;
|
||||
|
||||
constructor(key:string,title: string,run: () => Promise<void>){
|
||||
constructor(opts:{
|
||||
key: string, title: string, run: (task: BackTask) => Promise<void>
|
||||
}) {
|
||||
const {key, title, run} = opts
|
||||
this.key = key;
|
||||
this.title = title;
|
||||
Object.defineProperty(this, 'run', {
|
||||
@@ -79,10 +88,19 @@ export class BackTask{
|
||||
});
|
||||
}
|
||||
|
||||
clearTimeout(){
|
||||
clearTimeout() {
|
||||
if (this.timeoutId) {
|
||||
clearTimeout(this.timeoutId);
|
||||
this.timeoutId = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setTotal(total: number) {
|
||||
this.total = total;
|
||||
}
|
||||
incrementCurrent() {
|
||||
this.current++
|
||||
}
|
||||
}
|
||||
|
||||
export const taskExecutor = new BackTaskExecutor();
|
||||
Reference in New Issue
Block a user