Compare commits
72 Commits
nightly
...
v1.0.0-rc.4
| Author | SHA1 | Date | |
|---|---|---|---|
| e8cfb546fa | |||
| d98f0e8ac3 | |||
| 38a3314b9b | |||
| 5c793d7992 | |||
| ee190b6049 | |||
| dede1e2968 | |||
| 446a8420f5 | |||
| dc8deb0c24 | |||
| f8cf9c57c4 | |||
| 0f9f094a48 | |||
| 9acf5fecae | |||
| 8b2b03d276 | |||
| dac55f0fde | |||
| 938dc9522b | |||
| 5114ad0677 | |||
| d46df94f05 | |||
| aa730395f1 | |||
| d2b30dfc95 | |||
| 987b7ecd22 | |||
| 5f86839c7e | |||
| 8f3c41ae77 | |||
| 8bff691089 | |||
| 22fd1741ab | |||
| 9b0ec8ed48 | |||
| 95648353e4 | |||
| 2f8637048e | |||
| b2232f4355 | |||
| b44faec66b | |||
| 3b592895c6 | |||
| e0b6eb3a59 | |||
| 6f57dcd2f5 | |||
| 8ca103342d | |||
| 22ae14f0d7 | |||
| f982544825 | |||
| 438410708f | |||
| 75af3db11f | |||
| db48108d21 | |||
| 22ef5b2f80 | |||
| 28f7e9eb2e | |||
| fc377dae3e | |||
| df14a0bf18 | |||
| c609cb13b2 | |||
| a42b397607 | |||
| 9f8a4ec050 | |||
| bee339d279 | |||
| 4e93148d9e | |||
| e36d191c2e | |||
| 34afe9b426 | |||
| d604f48c06 | |||
| 86cfb3920e | |||
| 097a50ebdc | |||
| f424f906d8 | |||
| cc4ad6c39e | |||
| 4c21c4c43b | |||
| db89b57e1c | |||
| 62d4b63fc3 | |||
| 355307223a | |||
| f2f3410dcf | |||
| 02aacb38a2 | |||
| a7c38ec851 | |||
| 095e1920f1 | |||
| 8993386743 | |||
| 435d7ae0dd | |||
| 3a2138ba61 | |||
| e3d64cb76d | |||
| 2e610e5fb3 | |||
| 05b0041de2 | |||
| ec8f3dceaa | |||
| 63ce2db988 | |||
| df6d862895 | |||
| 69ba18d392 | |||
| 65b1654732 |
@@ -0,0 +1,83 @@
|
||||
---
|
||||
name: classic-to-default-sync
|
||||
description: Inspect a given commit's web/classic changes and sync all features/fixes to web/default. Use when the user provides a commit ID and wants to audit whether web/default already has the same features as web/classic, port missing features, improve suboptimal implementations, fix bugs, and remove redundant code. Trigger phrases include: "/classic-to-default-sync <hash>", "classic-to-default-sync <hash>", "sync classic to default", "port from classic", "compare classic commit", "classic 和 default 对比", "把这次 classic 的修改同步到 default", "查看这次提交 classic 中的修改并同步", or any request supplying a commit hash together with classic/default comparison intent.
|
||||
---
|
||||
|
||||
# Classic-to-Default Sync
|
||||
|
||||
Given a **commit ID**, audit all `web/classic` changes and ensure `web/default` reaches feature parity with the best possible implementation.
|
||||
|
||||
## Input
|
||||
|
||||
The user must supply a `<commit-id>`.
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1 — Extract classic diff
|
||||
|
||||
```bash
|
||||
git show <commit-id> -- web/classic
|
||||
```
|
||||
|
||||
Read every changed file in `web/classic`. Identify the **logical changes** (new features, UI/UX improvements, bug fixes, config tweaks, removed dead code, etc.) — not just line diffs.
|
||||
|
||||
### Step 2 — Map to default counterparts
|
||||
|
||||
For each logical change found in Step 1, locate the equivalent file(s) in `web/default/src/`. Use Glob/Grep/SemanticSearch as needed. Consider that:
|
||||
|
||||
- `web/classic` uses **React 18 + Vite + Semi Design**
|
||||
- `web/default` uses **React 19 + Rsbuild + Base UI + Tailwind CSS**
|
||||
- Component names, file paths, and API shapes may differ; match by **functionality**, not filename.
|
||||
|
||||
### Step 3 — Triage each change
|
||||
|
||||
Classify every logical change as one of:
|
||||
|
||||
| Status | Meaning |
|
||||
|--------|---------|
|
||||
| ✅ Already present & optimal | No action needed |
|
||||
| ⚠️ Present but suboptimal | Improve: logic, layout, style, or code quality |
|
||||
| ❌ Missing | Implement from scratch in default's stack |
|
||||
|
||||
### Step 4 — Implement
|
||||
|
||||
For each **⚠️** or **❌** item:
|
||||
|
||||
1. **Read the target file(s) in `web/default`** before editing (required by project conventions).
|
||||
2. Implement using `web/default` conventions:
|
||||
- React 19 patterns (hooks, Suspense, etc.)
|
||||
- Base UI primitives where applicable
|
||||
- Tailwind CSS for styling (no inline styles or Semi Design imports)
|
||||
- `useTranslation()` + `t('English key')` for all user-visible strings
|
||||
- TypeScript — explicit types, no `any`
|
||||
- No dead code, no redundant comments
|
||||
3. Follow **Rule 6** (pointer types for optional relay DTOs) if touching relay-related TS types.
|
||||
4. After editing, run `ReadLints` on changed files and fix any introduced lint errors.
|
||||
|
||||
### Step 5 — i18n
|
||||
|
||||
If any new user-visible strings were added, run the i18n sync:
|
||||
|
||||
```bash
|
||||
cd web/default && bun run i18n:sync
|
||||
```
|
||||
|
||||
Then add missing translations for all supported locales (en, zh, fr, ja, ru, vi) following the **i18n-translate** skill.
|
||||
|
||||
### Step 6 — Report
|
||||
|
||||
Summarise the work in a concise table:
|
||||
|
||||
| # | Change (from classic commit) | Status | Action taken |
|
||||
|---|------------------------------|--------|--------------|
|
||||
| 1 | … | ✅ / ⚠️ / ❌ | None / Improved / Implemented |
|
||||
|
||||
If every item is ✅ with no action needed, simply reply: **"已完成 — web/default 已具备此次提交的所有功能,且实现质量良好,无需修改。"**
|
||||
|
||||
## Quality bar
|
||||
|
||||
- No unused imports, variables, or components
|
||||
- No commented-out code left behind
|
||||
- Consistent naming with surrounding `web/default` code
|
||||
- All interactive elements accessible (keyboard nav, ARIA labels where Radix doesn't provide them automatically)
|
||||
- No regressions: existing behaviour in `web/default` must not break
|
||||
Executable
+254
@@ -0,0 +1,254 @@
|
||||
---
|
||||
name: i18n-translate
|
||||
description: >-
|
||||
Complete and maintain frontend i18n translations for this project. Covers
|
||||
finding missing translation keys, detecting untranslated entries, and adding
|
||||
translations for all supported locales (en, zh, fr, ja, ru, vi). Use when the
|
||||
user asks to add translations, fix i18n, complete missing translations, or
|
||||
when new UI text needs to be internationalized.
|
||||
---
|
||||
|
||||
# Frontend i18n Translation Workflow
|
||||
|
||||
## Overview
|
||||
|
||||
- Locale files: `web/default/src/i18n/locales/{en,zh,fr,ja,ru,vi}.json`
|
||||
- Format: flat JSON under `"translation"` key, keys are English source strings
|
||||
- Base locale: `en.json` (most keys), fallback: `zh` (Chinese)
|
||||
- Sync script: `bun run i18n:sync` (from `web/default/`)
|
||||
- All `t()` calls must have corresponding keys in every locale file
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Run sync and read report
|
||||
|
||||
```bash
|
||||
cd web/default && bun run i18n:sync
|
||||
```
|
||||
|
||||
Read `web/default/src/i18n/locales/_reports/_sync-report.json` to see per-locale status (missingCount, extrasCount, untranslatedCount).
|
||||
|
||||
### Step 2: Find missing keys (used in code but not in locale files)
|
||||
|
||||
Create and run `web/default/scripts/find-missing-keys.mjs`:
|
||||
|
||||
```javascript
|
||||
import fs from 'node:fs/promises'
|
||||
import path from 'node:path'
|
||||
|
||||
const LOCALES_DIR = path.resolve('src/i18n/locales')
|
||||
const SRC_DIR = path.resolve('src')
|
||||
|
||||
const en = JSON.parse(await fs.readFile(path.join(LOCALES_DIR, 'en.json'), 'utf8'))
|
||||
const enKeys = new Set(Object.keys(en.translation))
|
||||
|
||||
const tCallRegex = /\bt\(\s*['"`]([^'"`\n]+?)['"`]\s*[,)]/g
|
||||
const tCallMultilineRegex = /\bt\(\s*['"`]([^'"`]+?)['"`]\s*\)/g
|
||||
|
||||
async function walkDir(dir) {
|
||||
const files = []
|
||||
const entries = await fs.readdir(dir, { withFileTypes: true })
|
||||
for (const entry of entries) {
|
||||
const fullPath = path.join(dir, entry.name)
|
||||
if (entry.isDirectory()) {
|
||||
if (['node_modules', '.git', 'locales', '_reports', '_extras'].includes(entry.name)) continue
|
||||
files.push(...(await walkDir(fullPath)))
|
||||
} else if (/\.(tsx?|jsx?)$/.test(entry.name)) {
|
||||
files.push(fullPath)
|
||||
}
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
const files = await walkDir(SRC_DIR)
|
||||
const missingKeys = new Map()
|
||||
|
||||
for (const file of files) {
|
||||
const content = await fs.readFile(file, 'utf8')
|
||||
const relPath = path.relative(SRC_DIR, file)
|
||||
for (const regex of [tCallRegex, tCallMultilineRegex]) {
|
||||
regex.lastIndex = 0
|
||||
let match
|
||||
while ((match = regex.exec(content)) !== null) {
|
||||
const key = match[1]
|
||||
if (key.startsWith('{{') || key.includes('${')) continue
|
||||
if (!enKeys.has(key)) {
|
||||
if (!missingKeys.has(key)) missingKeys.set(key, [])
|
||||
missingKeys.get(key).push(relPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (missingKeys.size === 0) {
|
||||
console.log('All t() keys found in en.json!')
|
||||
} else {
|
||||
console.log(`Found ${missingKeys.size} missing keys:\n`)
|
||||
for (const [key, files] of [...missingKeys.entries()].sort(([a], [b]) => a.localeCompare(b))) {
|
||||
console.log(` "${key}"`)
|
||||
for (const f of [...new Set(files)]) console.log(` -> ${f}`)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 3: Find untranslated entries (value equals English)
|
||||
|
||||
Create and run `web/default/scripts/find-untranslated.mjs`:
|
||||
|
||||
```javascript
|
||||
import fs from 'node:fs/promises'
|
||||
import path from 'node:path'
|
||||
|
||||
const LOCALES_DIR = path.resolve('src/i18n/locales')
|
||||
const en = JSON.parse(await fs.readFile(path.join(LOCALES_DIR, 'en.json'), 'utf8'))
|
||||
const enTrans = en.translation
|
||||
|
||||
// Brand names, URLs, technical terms — skip these
|
||||
const skipPatterns = [
|
||||
/^https?:\/\//, /^smtp\./, /^socks5:/, /^name@/, /^noreply@/,
|
||||
/^org-/, /^price_/, /^whsec_/, /^edit_this$/, /^my-status$/,
|
||||
/^_copy$/, /^gpt-/, /^checkout\./, /^footer\./, /^\[?\{/,
|
||||
/^"default/, /^\/status\//, /^\/your\//, /^example\.com/,
|
||||
/^AZURE_/, /^AccessKey/, /^OAuth/, /^Client /, /^Webhook URL/,
|
||||
/^API URL$/, /^Well-Known/, /^Worker URL$/, /^Uptime Kuma/,
|
||||
/^New API/, /^Baidu V2$/, /^Zhipu V4$/, /^Quota:$/,
|
||||
]
|
||||
|
||||
const brandNames = new Set([
|
||||
'AIGC2D','Anthropic','API2GPT','Claude','Cloudflare','Cohere','DeepSeek',
|
||||
'Discord','DoubaoVideo','FastGPT','Gemini','GitHub','Jimeng','JustSong',
|
||||
'LingYiWanWu','LinuxDO','Midjourney','MidjourneyPlus','MiniMax','Mistral',
|
||||
'MokaAI','Moonshot','NewAPI','OhMyGPT','Ollama','OpenAI','OpenAIMax',
|
||||
'OpenRouter','Passkey','Perplexity','QuantumNous','Replicate','SiliconFlow',
|
||||
'Stripe','Submodel','SunoAPI','Telegram','Tencent','Vertex AI','VolcEngine',
|
||||
'WeChat','Xinference','Xunfei','AI Proxy','One API',
|
||||
])
|
||||
|
||||
const locales = ['fr', 'ja', 'ru', 'zh', 'vi']
|
||||
|
||||
for (const locale of locales) {
|
||||
const locFile = JSON.parse(await fs.readFile(path.join(LOCALES_DIR, `${locale}.json`), 'utf8'))
|
||||
const locTrans = locFile.translation
|
||||
const untranslated = {}
|
||||
|
||||
for (const [key, enVal] of Object.entries(enTrans)) {
|
||||
const locVal = locTrans[key]
|
||||
if (locVal === undefined || locVal !== enVal) continue
|
||||
if (brandNames.has(key)) continue
|
||||
if (skipPatterns.some(p => p.test(key))) continue
|
||||
if (typeof enVal === 'string' && enVal.length < 4) continue
|
||||
if (/[a-zA-Z]{3,}/.test(String(enVal))) untranslated[key] = enVal
|
||||
}
|
||||
|
||||
const count = Object.keys(untranslated).length
|
||||
if (count > 0) {
|
||||
console.log(`\n=== ${locale} (${count} untranslated) ===`)
|
||||
for (const [k, v] of Object.entries(untranslated))
|
||||
console.log(` ${JSON.stringify(k)}: ${JSON.stringify(v)}`)
|
||||
} else {
|
||||
console.log(`\n=== ${locale}: all translated ===`)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 4: Add translations
|
||||
|
||||
Create `web/default/scripts/add-missing-keys.mjs` with this structure:
|
||||
|
||||
```javascript
|
||||
import fs from 'node:fs/promises'
|
||||
import path from 'node:path'
|
||||
|
||||
const LOCALES_DIR = path.resolve('src/i18n/locales')
|
||||
|
||||
function stableStringify(obj) {
|
||||
return JSON.stringify(obj, null, 2) + '\n'
|
||||
}
|
||||
|
||||
const newKeys = {
|
||||
en: { /* "key": "English value" */ },
|
||||
zh: { /* "key": "中文翻译" */ },
|
||||
fr: { /* "key": "Traduction française" */ },
|
||||
ja: { /* "key": "日本語翻訳" */ },
|
||||
ru: { /* "key": "Русский перевод" */ },
|
||||
vi: { /* "key": "Bản dịch tiếng Việt" */ },
|
||||
}
|
||||
|
||||
async function main() {
|
||||
let totalAdded = 0
|
||||
|
||||
for (const [locale, trans] of Object.entries(newKeys)) {
|
||||
const filePath = path.join(LOCALES_DIR, `${locale}.json`)
|
||||
const json = JSON.parse(await fs.readFile(filePath, 'utf8'))
|
||||
|
||||
let count = 0
|
||||
for (const [key, value] of Object.entries(trans)) {
|
||||
if (!Object.prototype.hasOwnProperty.call(json.translation, key)) {
|
||||
json.translation[key] = value
|
||||
count++
|
||||
} else if (json.translation[key] !== value) {
|
||||
json.translation[key] = value
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
if (count > 0) {
|
||||
json.translation = Object.fromEntries(
|
||||
Object.entries(json.translation).sort(([a], [b]) => a.localeCompare(b))
|
||||
)
|
||||
await fs.writeFile(filePath, stableStringify(json), 'utf8')
|
||||
}
|
||||
|
||||
console.log(`${locale}: ${count} translations applied`)
|
||||
totalAdded += count
|
||||
}
|
||||
|
||||
console.log(`\nTotal: ${totalAdded} translations applied`)
|
||||
}
|
||||
|
||||
main().catch((err) => { console.error(err); process.exitCode = 1 })
|
||||
```
|
||||
|
||||
Populate the `newKeys` object with actual translations for each locale.
|
||||
|
||||
### Step 5: Verify and clean up
|
||||
|
||||
```bash
|
||||
cd web/default
|
||||
node scripts/add-missing-keys.mjs # apply translations
|
||||
node scripts/find-missing-keys.mjs # verify: should say "All t() keys found"
|
||||
bun run i18n:sync # normalize file order
|
||||
```
|
||||
|
||||
Delete temporary scripts after completion.
|
||||
|
||||
## Translation Guidelines
|
||||
|
||||
| Language | Code | Notes |
|
||||
|----------|------|-------|
|
||||
| English | en | Base locale, key = value |
|
||||
| Chinese | zh | Fallback locale, must be complete |
|
||||
| French | fr | Many English cognates are valid (e.g., "Configuration") |
|
||||
| Japanese | ja | Use katakana for technical loanwords |
|
||||
| Russian | ru | Use formal register |
|
||||
| Vietnamese | vi | Use standard Vietnamese |
|
||||
|
||||
**Keep as English (do not translate):**
|
||||
- Brand/product names (OpenAI, Claude, Gemini, etc.)
|
||||
- URLs and email placeholders
|
||||
- Technical identifiers (JSON keys, API paths, model names)
|
||||
- Code-like strings (gpt-3.5-turbo, price_xxx, etc.)
|
||||
|
||||
**Always translate:**
|
||||
- UI labels, button text, error messages, descriptions
|
||||
- Time units (hours, minutes, months, years)
|
||||
- Action words (Move, Show, Delete, etc.)
|
||||
|
||||
## Key Rules
|
||||
|
||||
1. All scripts run from `web/default/` directory
|
||||
2. Use `node scripts/xxx.mjs` (ESM format with top-level await)
|
||||
3. Sort keys alphabetically when writing locale files
|
||||
4. Always run `bun run i18n:sync` as the final step
|
||||
5. Delete temporary scripts after completion
|
||||
6. The `{{variable}}` placeholders in keys must be preserved in all translations
|
||||
@@ -0,0 +1,105 @@
|
||||
---
|
||||
name: shadcn-ui
|
||||
description: >-
|
||||
Give the assistant project-aware shadcn/ui context: components.json,
|
||||
composition patterns, CLI, registries, theming, and MCP. Use when working on
|
||||
web/default UI, shadcn components, or presets. Overview aligns with
|
||||
https://ui.shadcn.com/docs/skills.md; full upstream skill text is vendored
|
||||
under vendor/shadcn/.
|
||||
---
|
||||
|
||||
<!-- Canonical overview: https://ui.shadcn.com/docs/skills.md -->
|
||||
|
||||
# Skills (shadcn/ui)
|
||||
|
||||
Skills give AI assistants project-aware context about shadcn/ui. When used, the assistant knows how to find, install, compose, and customize components using the correct APIs and patterns for your project.
|
||||
|
||||
For example, you can ask:
|
||||
|
||||
- _"Add a login form with email and password fields."_
|
||||
- _"Create a settings page with a form for updating profile information."_
|
||||
- _"Build a dashboard with a sidebar, stats cards, and a data table."_
|
||||
- _"Switch to --preset [CODE]"_
|
||||
- _"Can you add a hero from @tailark?"_
|
||||
|
||||
The skill reads your project's `components.json` and provides your framework, aliases, installed components, icon library, and base library so it can generate correct code on the first try.
|
||||
|
||||
---
|
||||
|
||||
## Install (ecosystem vs this repo)
|
||||
|
||||
Official install from [Skills — shadcn/ui](https://ui.shadcn.com/docs/skills.md):
|
||||
|
||||
```bash
|
||||
npx skills add shadcn/ui
|
||||
```
|
||||
|
||||
That installs the skill where the `skills` CLI is available. **This repository** keeps the same intent under `.agents/skills/shadcn-ui/` (overview here + **vendored** upstream docs in [`vendor/shadcn/`](./vendor/shadcn/)) and runs the shadcn CLI from the frontend app root:
|
||||
|
||||
```bash
|
||||
cd web/default && bunx shadcn@latest info --json
|
||||
```
|
||||
|
||||
Learn more about skills at [skills.sh](https://skills.sh).
|
||||
|
||||
---
|
||||
|
||||
## What's included (and where)
|
||||
|
||||
### Project context
|
||||
|
||||
Run **`shadcn info --json`** (here: `cd web/default && bunx shadcn@latest info --json`) for framework, Tailwind version, aliases, base (`radix` | `base`), icon library, installed components, and resolved paths.
|
||||
|
||||
### CLI commands
|
||||
|
||||
Full command reference (vendored): [`vendor/shadcn/cli.md`](./vendor/shadcn/cli.md).
|
||||
|
||||
### Theming and customization
|
||||
|
||||
Vendored: [`vendor/shadcn/customization.md`](./vendor/shadcn/customization.md). Live docs: [Theming](https://ui.shadcn.com/docs/theming).
|
||||
|
||||
### Registry authoring
|
||||
|
||||
Not duplicated as a single file in the vendor tree; see [Registry](https://ui.shadcn.com/docs/registry) and `build` in [`vendor/shadcn/cli.md`](./vendor/shadcn/cli.md).
|
||||
|
||||
### MCP server
|
||||
|
||||
Vendored: [`vendor/shadcn/mcp.md`](./vendor/shadcn/mcp.md). Live docs: [MCP Server](https://ui.shadcn.com/docs/mcp).
|
||||
|
||||
---
|
||||
|
||||
## How it works
|
||||
|
||||
1. **Project detection** — Applies when `components.json` exists (here: `web/default/components.json`).
|
||||
2. **Context injection** — Use `shadcn info --json` as ground truth for imports and APIs.
|
||||
3. **Pattern enforcement** — Follow rules in [`vendor/shadcn/SKILL.md`](./vendor/shadcn/SKILL.md) and [`vendor/shadcn/rules/`](./vendor/shadcn/rules/).
|
||||
4. **Component discovery** — `shadcn docs`, `shadcn search`, MCP, or registries — see vendored SKILL + MCP doc.
|
||||
|
||||
---
|
||||
|
||||
## Learn more (web)
|
||||
|
||||
- [CLI](https://ui.shadcn.com/docs/cli) — complements [`vendor/shadcn/cli.md`](./vendor/shadcn/cli.md)
|
||||
- [Theming](https://ui.shadcn.com/docs/theming)
|
||||
- [Registry](https://ui.shadcn.com/docs/registry)
|
||||
- [skills.sh](https://skills.sh)
|
||||
|
||||
---
|
||||
|
||||
## Vendored upstream bundle (deep rules)
|
||||
|
||||
Snapshot from [shadcn-ui/ui `skills/shadcn`](https://github.com/shadcn-ui/ui/tree/main/skills/shadcn); revision note in [`vendor/shadcn/UPSTREAM.txt`](./vendor/shadcn/UPSTREAM.txt).
|
||||
|
||||
| Doc | Path |
|
||||
| --- | --- |
|
||||
| Full official skill body | [`vendor/shadcn/SKILL.md`](./vendor/shadcn/SKILL.md) |
|
||||
| CLI reference | [`vendor/shadcn/cli.md`](./vendor/shadcn/cli.md) |
|
||||
| Theming / customization | [`vendor/shadcn/customization.md`](./vendor/shadcn/customization.md) |
|
||||
| MCP | [`vendor/shadcn/mcp.md`](./vendor/shadcn/mcp.md) |
|
||||
| Forms | [`vendor/shadcn/rules/forms.md`](./vendor/shadcn/rules/forms.md) |
|
||||
| Composition | [`vendor/shadcn/rules/composition.md`](./vendor/shadcn/rules/composition.md) |
|
||||
| Icons | [`vendor/shadcn/rules/icons.md`](./vendor/shadcn/rules/icons.md) |
|
||||
| Styling | [`vendor/shadcn/rules/styling.md`](./vendor/shadcn/rules/styling.md) |
|
||||
| Base vs Radix | [`vendor/shadcn/rules/base-vs-radix.md`](./vendor/shadcn/rules/base-vs-radix.md) |
|
||||
|
||||
**Workflow:** Prefer this **root** `SKILL.md` for repo paths (`web/default`, Bun). Read **`vendor/shadcn/SKILL.md`** for the complete upstream workflow, patterns, and CLI quick reference. Use **`vendor/shadcn/rules/*.md`** when validating concrete markup.
|
||||
+260
@@ -0,0 +1,260 @@
|
||||
---
|
||||
name: shadcn
|
||||
description: Manages shadcn components and projects — adding, searching, fixing, debugging, styling, and composing UI. Provides project context, component docs, and usage examples. Applies when working with shadcn/ui, component registries, presets, --preset codes, or any project with a components.json file. Also triggers for "shadcn init", "create an app with --preset", or "switch to --preset".
|
||||
user-invocable: false
|
||||
allowed-tools: Bash(npx shadcn@latest *), Bash(pnpm dlx shadcn@latest *), Bash(bunx --bun shadcn@latest *)
|
||||
---
|
||||
|
||||
# shadcn/ui
|
||||
|
||||
A framework for building ui, components and design systems. Components are added as source code to the user's project via the CLI.
|
||||
|
||||
> **IMPORTANT:** Run all CLI commands using the project's package runner: `npx shadcn@latest`, `pnpm dlx shadcn@latest`, or `bunx --bun shadcn@latest` — based on the project's `packageManager`. Examples below use `npx shadcn@latest` but substitute the correct runner for the project.
|
||||
|
||||
## Current Project Context
|
||||
|
||||
```json
|
||||
!`npx shadcn@latest info --json`
|
||||
```
|
||||
|
||||
The JSON above contains the project config and installed components. Use `npx shadcn@latest docs <component>` to get documentation and example URLs for any component.
|
||||
|
||||
## Principles
|
||||
|
||||
1. **Use existing components first.** Use `npx shadcn@latest search` to check registries before writing custom UI. Check community registries too.
|
||||
2. **Compose, don't reinvent.** Settings page = Tabs + Card + form controls. Dashboard = Sidebar + Card + Chart + Table.
|
||||
3. **Use built-in variants before custom styles.** `variant="outline"`, `size="sm"`, etc.
|
||||
4. **Use semantic colors.** `bg-primary`, `text-muted-foreground` — never raw values like `bg-blue-500`.
|
||||
|
||||
## Critical Rules
|
||||
|
||||
These rules are **always enforced**. Each links to a file with Incorrect/Correct code pairs.
|
||||
|
||||
### Styling & Tailwind → [styling.md](./rules/styling.md)
|
||||
|
||||
- **`className` for layout, not styling.** Never override component colors or typography.
|
||||
- **No `space-x-*` or `space-y-*`.** Use `flex` with `gap-*`. For vertical stacks, `flex flex-col gap-*`.
|
||||
- **Use `size-*` when width and height are equal.** `size-10` not `w-10 h-10`.
|
||||
- **Use `truncate` shorthand.** Not `overflow-hidden text-ellipsis whitespace-nowrap`.
|
||||
- **No manual `dark:` color overrides.** Use semantic tokens (`bg-background`, `text-muted-foreground`).
|
||||
- **Use `cn()` for conditional classes.** Don't write manual template literal ternaries.
|
||||
- **No manual `z-index` on overlay components.** Dialog, Sheet, Popover, etc. handle their own stacking.
|
||||
|
||||
### Forms & Inputs → [forms.md](./rules/forms.md)
|
||||
|
||||
- **Forms use `FieldGroup` + `Field`.** Never use raw `div` with `space-y-*` or `grid gap-*` for form layout.
|
||||
- **`InputGroup` uses `InputGroupInput`/`InputGroupTextarea`.** Never raw `Input`/`Textarea` inside `InputGroup`.
|
||||
- **Buttons inside inputs use `InputGroup` + `InputGroupAddon`.**
|
||||
- **Option sets (2–7 choices) use `ToggleGroup`.** Don't loop `Button` with manual active state.
|
||||
- **`FieldSet` + `FieldLegend` for grouping related checkboxes/radios.** Don't use a `div` with a heading.
|
||||
- **Field validation uses `data-invalid` + `aria-invalid`.** `data-invalid` on `Field`, `aria-invalid` on the control. For disabled: `data-disabled` on `Field`, `disabled` on the control.
|
||||
|
||||
### Component Structure → [composition.md](./rules/composition.md)
|
||||
|
||||
- **Items always inside their Group.** `SelectItem` → `SelectGroup`. `DropdownMenuItem` → `DropdownMenuGroup`. `CommandItem` → `CommandGroup`.
|
||||
- **Use `asChild` (radix) or `render` (base) for custom triggers.** Check `base` field from `npx shadcn@latest info`. → [base-vs-radix.md](./rules/base-vs-radix.md)
|
||||
- **Dialog, Sheet, and Drawer always need a Title.** `DialogTitle`, `SheetTitle`, `DrawerTitle` required for accessibility. Use `className="sr-only"` if visually hidden.
|
||||
- **Use full Card composition.** `CardHeader`/`CardTitle`/`CardDescription`/`CardContent`/`CardFooter`. Don't dump everything in `CardContent`.
|
||||
- **Button has no `isPending`/`isLoading`.** Compose with `Spinner` + `data-icon` + `disabled`.
|
||||
- **`TabsTrigger` must be inside `TabsList`.** Never render triggers directly in `Tabs`.
|
||||
- **`Avatar` always needs `AvatarFallback`.** For when the image fails to load.
|
||||
|
||||
### Use Components, Not Custom Markup → [composition.md](./rules/composition.md)
|
||||
|
||||
- **Use existing components before custom markup.** Check if a component exists before writing a styled `div`.
|
||||
- **Callouts use `Alert`.** Don't build custom styled divs.
|
||||
- **Empty states use `Empty`.** Don't build custom empty state markup.
|
||||
- **Toast via `sonner`.** Use `toast()` from `sonner`.
|
||||
- **Use `Separator`** instead of `<hr>` or `<div className="border-t">`.
|
||||
- **Use `Skeleton`** for loading placeholders. No custom `animate-pulse` divs.
|
||||
- **Use `Badge`** instead of custom styled spans.
|
||||
|
||||
### Icons → [icons.md](./rules/icons.md)
|
||||
|
||||
- **Icons in `Button` use `data-icon`.** `data-icon="inline-start"` or `data-icon="inline-end"` on the icon.
|
||||
- **No sizing classes on icons inside components.** Components handle icon sizing via CSS. No `size-4` or `w-4 h-4`.
|
||||
- **Pass icons as objects, not string keys.** `icon={CheckIcon}`, not a string lookup.
|
||||
|
||||
### CLI
|
||||
|
||||
- **Never decode preset codes or build preset URLs manually.** Use `npx shadcn@latest preset decode <code>`, `preset url <code>`, or `preset open <code>`. For project-aware preset detection, use `npx shadcn@latest preset resolve`.
|
||||
- **Apply preset codes directly with the CLI.** Use `npx shadcn@latest apply <code>` for existing projects, or `npx shadcn@latest init --preset <code>` when initializing.
|
||||
|
||||
## Key Patterns
|
||||
|
||||
These are the most common patterns that differentiate correct shadcn/ui code. For edge cases, see the linked rule files above.
|
||||
|
||||
```tsx
|
||||
// Form layout: FieldGroup + Field, not div + Label.
|
||||
<FieldGroup>
|
||||
<Field>
|
||||
<FieldLabel htmlFor="email">Email</FieldLabel>
|
||||
<Input id="email" />
|
||||
</Field>
|
||||
</FieldGroup>
|
||||
|
||||
// Validation: data-invalid on Field, aria-invalid on the control.
|
||||
<Field data-invalid>
|
||||
<FieldLabel>Email</FieldLabel>
|
||||
<Input aria-invalid />
|
||||
<FieldDescription>Invalid email.</FieldDescription>
|
||||
</Field>
|
||||
|
||||
// Icons in buttons: data-icon, no sizing classes.
|
||||
<Button>
|
||||
<SearchIcon data-icon="inline-start" />
|
||||
Search
|
||||
</Button>
|
||||
|
||||
// Spacing: gap-*, not space-y-*.
|
||||
<div className="flex flex-col gap-4"> // correct
|
||||
<div className="space-y-4"> // wrong
|
||||
|
||||
// Equal dimensions: size-*, not w-* h-*.
|
||||
<Avatar className="size-10"> // correct
|
||||
<Avatar className="w-10 h-10"> // wrong
|
||||
|
||||
// Status colors: Badge variants or semantic tokens, not raw colors.
|
||||
<Badge variant="secondary">+20.1%</Badge> // correct
|
||||
<span className="text-emerald-600">+20.1%</span> // wrong
|
||||
```
|
||||
|
||||
## Component Selection
|
||||
|
||||
| Need | Use |
|
||||
| -------------------------- | --------------------------------------------------------------------------------------------------- |
|
||||
| Button/action | `Button` with appropriate variant |
|
||||
| Form inputs | `Input`, `Select`, `Combobox`, `Switch`, `Checkbox`, `RadioGroup`, `Textarea`, `InputOTP`, `Slider` |
|
||||
| Toggle between 2–5 options | `ToggleGroup` + `ToggleGroupItem` |
|
||||
| Data display | `Table`, `Card`, `Badge`, `Avatar` |
|
||||
| Navigation | `Sidebar`, `NavigationMenu`, `Breadcrumb`, `Tabs`, `Pagination` |
|
||||
| Overlays | `Dialog` (modal), `Sheet` (side panel), `Drawer` (bottom sheet), `AlertDialog` (confirmation) |
|
||||
| Feedback | `sonner` (toast), `Alert`, `Progress`, `Skeleton`, `Spinner` |
|
||||
| Command palette | `Command` inside `Dialog` |
|
||||
| Charts | `Chart` (wraps Recharts) |
|
||||
| Layout | `Card`, `Separator`, `Resizable`, `ScrollArea`, `Accordion`, `Collapsible` |
|
||||
| Empty states | `Empty` |
|
||||
| Menus | `DropdownMenu`, `ContextMenu`, `Menubar` |
|
||||
| Tooltips/info | `Tooltip`, `HoverCard`, `Popover` |
|
||||
|
||||
## Key Fields
|
||||
|
||||
The injected project context contains these key fields:
|
||||
|
||||
- **`aliases`** → use the actual alias prefix for imports (e.g. `@/`, `~/`), never hardcode.
|
||||
- **`isRSC`** → when `true`, components using `useState`, `useEffect`, event handlers, or browser APIs need `"use client"` at the top of the file. Always reference this field when advising on the directive.
|
||||
- **`tailwindVersion`** → `"v4"` uses `@theme inline` blocks; `"v3"` uses `tailwind.config.js`.
|
||||
- **`tailwindCssFile`** → the global CSS file where custom CSS variables are defined. Always edit this file, never create a new one.
|
||||
- **`style`** → component visual treatment (e.g. `nova`, `vega`).
|
||||
- **`base`** → primitive library (`radix` or `base`). Affects component APIs and available props.
|
||||
- **`iconLibrary`** → determines icon imports. Use `lucide-react` for `lucide`, `@tabler/icons-react` for `tabler`, etc. Never assume `lucide-react`.
|
||||
- **`resolvedPaths`** → exact file-system destinations for components, utils, hooks, etc.
|
||||
- **`framework`** → routing and file conventions (e.g. Next.js App Router vs Vite SPA).
|
||||
- **`packageManager`** → use this for any non-shadcn dependency installs (e.g. `pnpm add date-fns` vs `npm install date-fns`).
|
||||
- **`preset`** → resolved preset code and values for the current project. Use `npx shadcn@latest preset resolve --json` when you only need preset information.
|
||||
|
||||
See [cli.md — `info` command](./cli.md) for the full field reference.
|
||||
|
||||
## Component Docs, Examples, and Usage
|
||||
|
||||
Run `npx shadcn@latest docs <component>` to get the URLs for a component's documentation, examples, and API reference. Fetch these URLs to get the actual content.
|
||||
|
||||
```bash
|
||||
npx shadcn@latest docs button dialog select
|
||||
```
|
||||
|
||||
**When creating, fixing, debugging, or using a component, always run `npx shadcn@latest docs` and fetch the URLs first.** This ensures you're working with the correct API and usage patterns rather than guessing.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Get project context** — already injected above. Run `npx shadcn@latest info` again if you need to refresh.
|
||||
2. **Check installed components first** — before running `add`, always check the `components` list from project context or list the `resolvedPaths.ui` directory. Don't import components that haven't been added, and don't re-add ones already installed.
|
||||
3. **Find components** — `npx shadcn@latest search`.
|
||||
4. **Get docs and examples** — run `npx shadcn@latest docs <component>` to get URLs, then fetch them. Use `npx shadcn@latest view` to browse registry items you haven't installed. To preview changes to installed components, use `npx shadcn@latest add --diff`.
|
||||
5. **Install or update** — `npx shadcn@latest add`. When updating existing components, use `--dry-run` and `--diff` to preview changes first (see [Updating Components](#updating-components) below).
|
||||
6. **Fix imports in third-party components** — After adding components from community registries (e.g. `@bundui`, `@magicui`), check the added non-UI files for hardcoded import paths like `@/components/ui/...`. These won't match the project's actual aliases. Use `npx shadcn@latest info` to get the correct `ui` alias (e.g. `@workspace/ui/components`) and rewrite the imports accordingly. The CLI rewrites imports for its own UI files, but third-party registry components may use default paths that don't match the project.
|
||||
7. **Review added components** — After adding a component or block from any registry, **always read the added files and verify they are correct**. Check for missing sub-components (e.g. `SelectItem` without `SelectGroup`), missing imports, incorrect composition, or violations of the [Critical Rules](#critical-rules). Also replace any icon imports with the project's `iconLibrary` from the project context (e.g. if the registry item uses `lucide-react` but the project uses `hugeicons`, swap the imports and icon names accordingly). Fix all issues before moving on.
|
||||
8. **Registry must be explicit** — When the user asks to add a block or component, **do not guess the registry**. If no registry is specified (e.g. user says "add a login block" without specifying `@shadcn`, `@tailark`, etc.), ask which registry to use. Never default to a registry on behalf of the user.
|
||||
9. **Switching presets** — Ask the user first: **overwrite**, **partial**, **merge**, or **skip**?
|
||||
- **Inspect current preset**: `npx shadcn@latest preset resolve`. Use `--json` when you need structured values.
|
||||
- **Inspect incoming preset**: `npx shadcn@latest preset decode <code>`. Use `preset url <code>` or `preset open <code>` to share or open the preset builder.
|
||||
- **Overwrite**: `npx shadcn@latest apply <code>`. Overwrites detected components, fonts, and CSS variables.
|
||||
- **Partial**: `npx shadcn@latest apply <code> --only theme,font`. Updates only the selected preset parts without reinstalling UI components. Supported values are `theme` and `font`; comma-separated combinations are allowed. `icon` is intentionally not supported, because icon changes may require full component reinstall and transforms.
|
||||
- **Merge**: `npx shadcn@latest init --preset <code> --force --no-reinstall`, then run `npx shadcn@latest info` to list installed components, then for each installed component use `--dry-run` and `--diff` to [smart merge](#updating-components) it individually.
|
||||
- **Skip**: `npx shadcn@latest init --preset <code> --force --no-reinstall`. Only updates config and CSS, leaves components as-is.
|
||||
- **Important**: Always run preset commands inside the user's project directory. `apply` only works in an existing project with a `components.json` file. The CLI automatically preserves the current base (`base` vs `radix`) from `components.json`. If you must use a scratch/temp directory (e.g. for `--dry-run` comparisons), pass `--base <current-base>` explicitly — preset codes do not encode the base.
|
||||
|
||||
## Updating Components
|
||||
|
||||
When the user asks to update a component from upstream while keeping their local changes, use `--dry-run` and `--diff` to intelligently merge. **NEVER fetch raw files from GitHub manually — always use the CLI.**
|
||||
|
||||
1. Run `npx shadcn@latest add <component> --dry-run` to see all files that would be affected.
|
||||
2. For each file, run `npx shadcn@latest add <component> --diff <file>` to see what changed upstream vs local.
|
||||
3. Decide per file based on the diff:
|
||||
- No local changes → safe to overwrite.
|
||||
- Has local changes → read the local file, analyze the diff, and apply upstream updates while preserving local modifications.
|
||||
- User says "just update everything" → use `--overwrite`, but confirm first.
|
||||
4. **Never use `--overwrite` without the user's explicit approval.**
|
||||
|
||||
## Quick Reference
|
||||
|
||||
```bash
|
||||
# Create a new project.
|
||||
npx shadcn@latest init --name my-app --preset base-nova
|
||||
npx shadcn@latest init --name my-app --preset a2r6bw --template vite
|
||||
|
||||
# Create a monorepo project.
|
||||
npx shadcn@latest init --name my-app --preset base-nova --monorepo
|
||||
npx shadcn@latest init --name my-app --preset base-nova --template next --monorepo
|
||||
|
||||
# Initialize existing project.
|
||||
npx shadcn@latest init --preset base-nova
|
||||
npx shadcn@latest init --defaults # shortcut: --template=next --preset=nova (base style implied)
|
||||
|
||||
# Apply a preset to an existing project.
|
||||
npx shadcn@latest apply a2r6bw
|
||||
npx shadcn@latest apply a2r6bw --only theme
|
||||
npx shadcn@latest apply a2r6bw --only font
|
||||
npx shadcn@latest apply a2r6bw --only theme,font
|
||||
|
||||
# Inspect preset codes and project preset state.
|
||||
npx shadcn@latest preset decode a2r6bw
|
||||
npx shadcn@latest preset url a2r6bw
|
||||
npx shadcn@latest preset open a2r6bw
|
||||
npx shadcn@latest preset resolve
|
||||
npx shadcn@latest preset resolve --json
|
||||
|
||||
# Add components.
|
||||
npx shadcn@latest add button card dialog
|
||||
npx shadcn@latest add @magicui/shimmer-button
|
||||
npx shadcn@latest add --all
|
||||
|
||||
# Preview changes before adding/updating.
|
||||
npx shadcn@latest add button --dry-run
|
||||
npx shadcn@latest add button --diff button.tsx
|
||||
npx shadcn@latest add @acme/form --view button.tsx
|
||||
|
||||
# Search registries.
|
||||
npx shadcn@latest search @shadcn -q "sidebar"
|
||||
npx shadcn@latest search @tailark -q "stats"
|
||||
|
||||
# Get component docs and example URLs.
|
||||
npx shadcn@latest docs button dialog select
|
||||
|
||||
# View registry item details (for items not yet installed).
|
||||
npx shadcn@latest view @shadcn/button
|
||||
```
|
||||
|
||||
**Named presets:** `nova`, `vega`, `maia`, `lyra`, `mira`, `luma`
|
||||
**Templates:** `next`, `vite`, `start`, `react-router`, `astro` (all support `--monorepo`) and `laravel` (not supported for monorepo)
|
||||
**Preset codes:** Version-prefixed base62 strings (e.g. `a2r6bw` or `b0`), from [ui.shadcn.com](https://ui.shadcn.com).
|
||||
|
||||
## Detailed References
|
||||
|
||||
- [rules/forms.md](./rules/forms.md) — FieldGroup, Field, InputGroup, ToggleGroup, FieldSet, validation states
|
||||
- [rules/composition.md](./rules/composition.md) — Groups, overlays, Card, Tabs, Avatar, Alert, Empty, Toast, Separator, Skeleton, Badge, Button loading
|
||||
- [rules/icons.md](./rules/icons.md) — data-icon, icon sizing, passing icons as objects
|
||||
- [rules/styling.md](./rules/styling.md) — Semantic colors, variants, className, spacing, size, truncate, dark mode, cn(), z-index
|
||||
- [rules/base-vs-radix.md](./rules/base-vs-radix.md) — asChild vs render, Select, ToggleGroup, Slider, Accordion
|
||||
- [cli.md](./cli.md) — Commands, flags, presets, templates
|
||||
- [customization.md](./customization.md) — Theming, CSS variables, extending components
|
||||
@@ -0,0 +1,3 @@
|
||||
Source: https://github.com/shadcn-ui/ui/tree/56161142f1b83f612462772d18883807b5f0d601/skills/shadcn
|
||||
Branch: main
|
||||
Fetched: 2026-04-29
|
||||
+276
@@ -0,0 +1,276 @@
|
||||
# shadcn CLI Reference
|
||||
|
||||
Configuration is read from `components.json`.
|
||||
|
||||
> **IMPORTANT:** Always run commands using the project's package runner: `npx shadcn@latest`, `pnpm dlx shadcn@latest`, or `bunx --bun shadcn@latest`. Check `packageManager` from project context to choose the right one. Examples below use `npx shadcn@latest` but substitute the correct runner for the project.
|
||||
|
||||
> **IMPORTANT:** Only use the flags documented below. Do not invent or guess flags — if a flag isn't listed here, it doesn't exist. The CLI auto-detects the package manager from the project's lockfile; there is no `--package-manager` flag.
|
||||
|
||||
## Contents
|
||||
|
||||
- Commands: init, apply, add (dry-run, smart merge), search, view, docs, info, build
|
||||
- Templates: next, vite, start, react-router, astro
|
||||
- Presets: named, code, URL formats and fields
|
||||
- Switching presets
|
||||
|
||||
---
|
||||
|
||||
## Commands
|
||||
|
||||
### `init` — Initialize or create a project
|
||||
|
||||
```bash
|
||||
npx shadcn@latest init [components...] [options]
|
||||
```
|
||||
|
||||
Initializes shadcn/ui in an existing project or creates a new project (when `--name` is provided). Optionally installs components in the same step.
|
||||
|
||||
| Flag | Short | Description | Default |
|
||||
| ----------------------- | ----- | --------------------------------------------------------- | ------- |
|
||||
| `--template <template>` | `-t` | Template (next, start, vite, next-monorepo, react-router) | — |
|
||||
| `--preset [name]` | `-p` | Preset configuration (named, code, or URL) | — |
|
||||
| `--yes` | `-y` | Skip confirmation prompt | `true` |
|
||||
| `--defaults` | `-d` | Use defaults (`--template=next --preset=base-nova`) | `false` |
|
||||
| `--force` | `-f` | Force overwrite existing configuration | `false` |
|
||||
| `--cwd <cwd>` | `-c` | Working directory | current |
|
||||
| `--name <name>` | `-n` | Name for new project | — |
|
||||
| `--silent` | `-s` | Mute output | `false` |
|
||||
| `--rtl` | | Enable RTL support | — |
|
||||
| `--reinstall` | | Re-install existing UI components | `false` |
|
||||
| `--monorepo` | | Scaffold a monorepo project | — |
|
||||
| `--no-monorepo` | | Skip the monorepo prompt | — |
|
||||
|
||||
`npx shadcn@latest create` is an alias for `npx shadcn@latest init`.
|
||||
|
||||
### `apply` — Apply a preset to an existing project
|
||||
|
||||
```bash
|
||||
npx shadcn@latest apply [preset] [options]
|
||||
```
|
||||
|
||||
Applies a preset to an existing project, overwriting preset-driven config, fonts, CSS variables, and detected UI components.
|
||||
|
||||
| Flag | Short | Description | Default |
|
||||
| ------------------- | ----- | ------------------------------------------ | ------- |
|
||||
| `--preset <preset>` | — | Preset configuration (named, code, or URL) | — |
|
||||
| `--yes` | `-y` | Skip confirmation prompt | `false` |
|
||||
| `--cwd <cwd>` | `-c` | Working directory | current |
|
||||
| `--silent` | `-s` | Mute output | `false` |
|
||||
|
||||
`[preset]` is a shorthand for `--preset <preset>`. If both are provided, they must match.
|
||||
If no preset is provided, the CLI offers to open the custom preset builder on `ui.shadcn.com/create`.
|
||||
|
||||
### `add` — Add components
|
||||
|
||||
> **IMPORTANT:** To compare local components against upstream or to preview changes, ALWAYS use `npx shadcn@latest add <component> --dry-run`, `--diff`, or `--view`. NEVER fetch raw files from GitHub or other sources manually. The CLI handles registry resolution, file paths, and CSS diffing automatically.
|
||||
|
||||
```bash
|
||||
npx shadcn@latest add [components...] [options]
|
||||
```
|
||||
|
||||
Accepts component names, registry-prefixed names (`@magicui/shimmer-button`), URLs, or local paths.
|
||||
|
||||
| Flag | Short | Description | Default |
|
||||
| --------------- | ----- | -------------------------------------------------------------------------------------------------------------------- | ------- |
|
||||
| `--yes` | `-y` | Skip confirmation prompt | `false` |
|
||||
| `--overwrite` | `-o` | Overwrite existing files | `false` |
|
||||
| `--cwd <cwd>` | `-c` | Working directory | current |
|
||||
| `--all` | `-a` | Add all available components | `false` |
|
||||
| `--path <path>` | `-p` | Target path for the component | — |
|
||||
| `--silent` | `-s` | Mute output | `false` |
|
||||
| `--dry-run` | | Preview all changes without writing files | `false` |
|
||||
| `--diff [path]` | | Show diffs. Without a path, shows the first 5 files. With a path, shows that file only (implies `--dry-run`) | — |
|
||||
| `--view [path]` | | Show file contents. Without a path, shows the first 5 files. With a path, shows that file only (implies `--dry-run`) | — |
|
||||
|
||||
#### Dry-Run Mode
|
||||
|
||||
Use `--dry-run` to preview what `add` would do without writing any files. `--diff` and `--view` both imply `--dry-run`.
|
||||
|
||||
```bash
|
||||
# Preview all changes.
|
||||
npx shadcn@latest add button --dry-run
|
||||
|
||||
# Show diffs for all files (top 5).
|
||||
npx shadcn@latest add button --diff
|
||||
|
||||
# Show the diff for a specific file.
|
||||
npx shadcn@latest add button --diff button.tsx
|
||||
|
||||
# Show contents for all files (top 5).
|
||||
npx shadcn@latest add button --view
|
||||
|
||||
# Show the full content of a specific file.
|
||||
npx shadcn@latest add button --view button.tsx
|
||||
|
||||
# Works with URLs too.
|
||||
npx shadcn@latest add https://api.npoint.io/abc123 --dry-run
|
||||
|
||||
# CSS diffs.
|
||||
npx shadcn@latest add button --diff globals.css
|
||||
```
|
||||
|
||||
**When to use dry-run:**
|
||||
|
||||
- When the user asks "what files will this add?" or "what will this change?" — use `--dry-run`.
|
||||
- Before overwriting existing components — use `--diff` to preview the changes first.
|
||||
- When the user wants to inspect component source code without installing — use `--view`.
|
||||
- When checking what CSS changes would be made to `globals.css` — use `--diff globals.css`.
|
||||
- When the user asks to review or audit third-party registry code before installing — use `--view` to inspect the source.
|
||||
|
||||
> **`npx shadcn@latest add --dry-run` vs `npx shadcn@latest view`:** Prefer `npx shadcn@latest add --dry-run/--diff/--view` over `npx shadcn@latest view` when the user wants to preview changes to their project. `npx shadcn@latest view` only shows raw registry metadata. `npx shadcn@latest add --dry-run` shows exactly what would happen in the user's project: resolved file paths, diffs against existing files, and CSS updates. Use `npx shadcn@latest view` only when the user wants to browse registry info without a project context.
|
||||
|
||||
#### Smart Merge from Upstream
|
||||
|
||||
See [Updating Components in SKILL.md](./SKILL.md#updating-components) for the full workflow.
|
||||
|
||||
### `search` — Search registries
|
||||
|
||||
```bash
|
||||
npx shadcn@latest search <registries...> [options]
|
||||
```
|
||||
|
||||
Fuzzy search across registries. Also aliased as `npx shadcn@latest list`. Without `-q`, lists all items.
|
||||
|
||||
| Flag | Short | Description | Default |
|
||||
| ------------------- | ----- | ---------------------- | ------- |
|
||||
| `--query <query>` | `-q` | Search query | — |
|
||||
| `--limit <number>` | `-l` | Max items per registry | `100` |
|
||||
| `--offset <number>` | `-o` | Items to skip | `0` |
|
||||
| `--cwd <cwd>` | `-c` | Working directory | current |
|
||||
|
||||
### `view` — View item details
|
||||
|
||||
```bash
|
||||
npx shadcn@latest view <items...> [options]
|
||||
```
|
||||
|
||||
Displays item info including file contents. Example: `npx shadcn@latest view @shadcn/button`.
|
||||
|
||||
### `docs` — Get component documentation URLs
|
||||
|
||||
```bash
|
||||
npx shadcn@latest docs <components...> [options]
|
||||
```
|
||||
|
||||
Outputs resolved URLs for component documentation, examples, and API references. Accepts one or more component names. Fetch the URLs to get the actual content.
|
||||
|
||||
Example output for `npx shadcn@latest docs input button`:
|
||||
|
||||
```
|
||||
base radix
|
||||
|
||||
input
|
||||
docs https://ui.shadcn.com/docs/components/radix/input
|
||||
examples https://raw.githubusercontent.com/.../examples/input-example.tsx
|
||||
|
||||
button
|
||||
docs https://ui.shadcn.com/docs/components/radix/button
|
||||
examples https://raw.githubusercontent.com/.../examples/button-example.tsx
|
||||
```
|
||||
|
||||
Some components include an `api` link to the underlying library (e.g. `cmdk` for the command component).
|
||||
|
||||
### `diff` — Check for updates
|
||||
|
||||
Do not use this command. Use `npx shadcn@latest add --diff` instead.
|
||||
|
||||
### `info` — Project information
|
||||
|
||||
```bash
|
||||
npx shadcn@latest info [options]
|
||||
```
|
||||
|
||||
Displays project info and `components.json` configuration. Run this first to discover the project's framework, aliases, Tailwind version, and resolved paths.
|
||||
|
||||
| Flag | Short | Description | Default |
|
||||
| ------------- | ----- | ----------------- | ------- |
|
||||
| `--cwd <cwd>` | `-c` | Working directory | current |
|
||||
|
||||
**Project Info fields:**
|
||||
|
||||
| Field | Type | Meaning |
|
||||
| -------------------- | --------- | ------------------------------------------------------------------ |
|
||||
| `framework` | `string` | Detected framework (`next`, `vite`, `react-router`, `start`, etc.) |
|
||||
| `frameworkVersion` | `string` | Framework version (e.g. `15.2.4`) |
|
||||
| `isSrcDir` | `boolean` | Whether the project uses a `src/` directory |
|
||||
| `isRSC` | `boolean` | Whether React Server Components are enabled |
|
||||
| `isTsx` | `boolean` | Whether the project uses TypeScript |
|
||||
| `tailwindVersion` | `string` | `"v3"` or `"v4"` |
|
||||
| `tailwindConfigFile` | `string` | Path to the Tailwind config file |
|
||||
| `tailwindCssFile` | `string` | Path to the global CSS file |
|
||||
| `aliasPrefix` | `string` | Import alias prefix (e.g. `@`, `~`, `@/`) |
|
||||
| `packageManager` | `string` | Detected package manager (`npm`, `pnpm`, `yarn`, `bun`) |
|
||||
|
||||
**Components.json fields:**
|
||||
|
||||
| Field | Type | Meaning |
|
||||
| -------------------- | --------- | ------------------------------------------------------------------------------------------ |
|
||||
| `base` | `string` | Primitive library (`radix` or `base`) — determines component APIs and available props |
|
||||
| `style` | `string` | Visual style (e.g. `nova`, `vega`) |
|
||||
| `rsc` | `boolean` | RSC flag from config |
|
||||
| `tsx` | `boolean` | TypeScript flag |
|
||||
| `tailwind.config` | `string` | Tailwind config path |
|
||||
| `tailwind.css` | `string` | Global CSS path — this is where custom CSS variables go |
|
||||
| `iconLibrary` | `string` | Icon library — determines icon import package (e.g. `lucide-react`, `@tabler/icons-react`) |
|
||||
| `aliases.components` | `string` | Component import alias (e.g. `@/components`) |
|
||||
| `aliases.utils` | `string` | Utils import alias (e.g. `@/lib/utils`) |
|
||||
| `aliases.ui` | `string` | UI component alias (e.g. `@/components/ui`) |
|
||||
| `aliases.lib` | `string` | Lib alias (e.g. `@/lib`) |
|
||||
| `aliases.hooks` | `string` | Hooks alias (e.g. `@/hooks`) |
|
||||
| `resolvedPaths` | `object` | Absolute file-system paths for each alias |
|
||||
| `registries` | `object` | Configured custom registries |
|
||||
|
||||
**Links fields:**
|
||||
|
||||
The `info` output includes a **Links** section with templated URLs for component docs, source, and examples. For resolved URLs, use `npx shadcn@latest docs <component>` instead.
|
||||
|
||||
### `build` — Build a custom registry
|
||||
|
||||
```bash
|
||||
npx shadcn@latest build [registry] [options]
|
||||
```
|
||||
|
||||
Builds `registry.json` into individual JSON files for distribution. Default input: `./registry.json`, default output: `./public/r`.
|
||||
|
||||
| Flag | Short | Description | Default |
|
||||
| ----------------- | ----- | ----------------- | ------------ |
|
||||
| `--output <path>` | `-o` | Output directory | `./public/r` |
|
||||
| `--cwd <cwd>` | `-c` | Working directory | current |
|
||||
|
||||
---
|
||||
|
||||
## Templates
|
||||
|
||||
| Value | Framework | Monorepo support |
|
||||
| -------------- | -------------- | ---------------- |
|
||||
| `next` | Next.js | Yes |
|
||||
| `vite` | Vite | Yes |
|
||||
| `start` | TanStack Start | Yes |
|
||||
| `react-router` | React Router | Yes |
|
||||
| `astro` | Astro | Yes |
|
||||
| `laravel` | Laravel | No |
|
||||
|
||||
All templates support monorepo scaffolding via the `--monorepo` flag. When passed, the CLI uses a monorepo-specific template directory (e.g. `next-monorepo`, `vite-monorepo`). When neither `--monorepo` nor `--no-monorepo` is passed, the CLI prompts interactively. Laravel does not support monorepo scaffolding.
|
||||
|
||||
---
|
||||
|
||||
## Presets
|
||||
|
||||
Three ways to specify a preset via `--preset`:
|
||||
|
||||
1. **Named:** `--preset nova` or `--preset lyra`
|
||||
2. **Code:** `--preset a2r6bw` (version-prefixed base62 string, e.g. `a2r6bw` or `b0`)
|
||||
3. **URL:** `--preset "https://ui.shadcn.com/init?base=radix&style=nova&..."`
|
||||
|
||||
> **IMPORTANT:** Never try to decode, fetch, or resolve preset codes manually. Preset codes are opaque — pass them directly to `npx shadcn@latest init --preset <code>` and let the CLI handle resolution.
|
||||
> Use `npx shadcn@latest apply --preset <code>` when overwriting an existing project's preset.
|
||||
|
||||
## Switching Presets
|
||||
|
||||
Ask the user first: **overwrite**, **merge**, or **skip** existing components?
|
||||
|
||||
- **Overwrite / Re-install** → `npx shadcn@latest apply --preset <code>`. Overwrites all detected component files with the new preset styles. Use when the user hasn't customized components.
|
||||
- **Merge** → `npx shadcn@latest init --preset <code> --force --no-reinstall`, then run `npx shadcn@latest info` to get the list of installed components and use the [smart merge workflow](./SKILL.md#updating-components) to update them one by one, preserving local changes. Use when the user has customized components.
|
||||
- **Skip** → `npx shadcn@latest init --preset <code> --force --no-reinstall`. Only updates config and CSS variables, leaves existing components as-is.
|
||||
|
||||
Always run preset commands inside the user's project directory. `apply` only works in an existing project with a `components.json` file. The CLI automatically preserves the current base (`base` vs `radix`) from `components.json`. If you must use a scratch/temp directory (e.g. for `--dry-run` comparisons), pass `--base <current-base>` explicitly — preset codes do not encode the base.
|
||||
@@ -0,0 +1,209 @@
|
||||
# Customization & Theming
|
||||
|
||||
Components reference semantic CSS variable tokens. Change the variables to change every component.
|
||||
|
||||
## Contents
|
||||
|
||||
- How it works (CSS variables → Tailwind utilities → components)
|
||||
- Color variables and OKLCH format
|
||||
- Dark mode setup
|
||||
- Changing the theme (presets, CSS variables)
|
||||
- Adding custom colors (Tailwind v3 and v4)
|
||||
- Border radius
|
||||
- Customizing components (variants, className, wrappers)
|
||||
- Checking for updates
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
1. CSS variables defined in `:root` (light) and `.dark` (dark mode).
|
||||
2. Tailwind maps them to utilities: `bg-primary`, `text-muted-foreground`, etc.
|
||||
3. Components use these utilities — changing a variable changes all components that reference it.
|
||||
|
||||
---
|
||||
|
||||
## Color Variables
|
||||
|
||||
Every color follows the `name` / `name-foreground` convention. The base variable is for backgrounds, `-foreground` is for text/icons on that background.
|
||||
|
||||
| Variable | Purpose |
|
||||
| -------------------------------------------- | -------------------------------- |
|
||||
| `--background` / `--foreground` | Page background and default text |
|
||||
| `--card` / `--card-foreground` | Card surfaces |
|
||||
| `--primary` / `--primary-foreground` | Primary buttons and actions |
|
||||
| `--secondary` / `--secondary-foreground` | Secondary actions |
|
||||
| `--muted` / `--muted-foreground` | Muted/disabled states |
|
||||
| `--accent` / `--accent-foreground` | Hover and accent states |
|
||||
| `--destructive` / `--destructive-foreground` | Error and destructive actions |
|
||||
| `--border` | Default border color |
|
||||
| `--input` | Form input borders |
|
||||
| `--ring` | Focus ring color |
|
||||
| `--chart-1` through `--chart-5` | Chart/data visualization |
|
||||
| `--sidebar-*` | Sidebar-specific colors |
|
||||
| `--surface` / `--surface-foreground` | Secondary surface |
|
||||
|
||||
Colors use OKLCH: `--primary: oklch(0.205 0 0)` where values are lightness (0–1), chroma (0 = gray), and hue (0–360).
|
||||
|
||||
---
|
||||
|
||||
## Dark Mode
|
||||
|
||||
Class-based toggle via `.dark` on the root element. In Next.js, use `next-themes`:
|
||||
|
||||
```tsx
|
||||
import { ThemeProvider } from "next-themes"
|
||||
|
||||
<ThemeProvider attribute="class" defaultTheme="system" enableSystem>
|
||||
{children}
|
||||
</ThemeProvider>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Changing the Theme
|
||||
|
||||
```bash
|
||||
# Apply a preset code from ui.shadcn.com.
|
||||
npx shadcn@latest apply --preset a2r6bw
|
||||
|
||||
# Positional shorthand also works.
|
||||
npx shadcn@latest apply a2r6bw
|
||||
|
||||
# Switch to a named preset and overwrite existing components.
|
||||
npx shadcn@latest apply --preset nova
|
||||
|
||||
# Preserve existing components instead.
|
||||
npx shadcn@latest init --preset nova --force --no-reinstall
|
||||
|
||||
# Use a custom theme URL.
|
||||
npx shadcn@latest apply --preset "https://ui.shadcn.com/init?base=radix&style=nova&theme=blue&..."
|
||||
```
|
||||
|
||||
Or edit CSS variables directly in `globals.css`.
|
||||
|
||||
---
|
||||
|
||||
## Adding Custom Colors
|
||||
|
||||
Add variables to the file at `tailwindCssFile` from `npx shadcn@latest info` (typically `globals.css`). Never create a new CSS file for this.
|
||||
|
||||
```css
|
||||
/* 1. Define in the global CSS file. */
|
||||
:root {
|
||||
--warning: oklch(0.84 0.16 84);
|
||||
--warning-foreground: oklch(0.28 0.07 46);
|
||||
}
|
||||
.dark {
|
||||
--warning: oklch(0.41 0.11 46);
|
||||
--warning-foreground: oklch(0.99 0.02 95);
|
||||
}
|
||||
```
|
||||
|
||||
```css
|
||||
/* 2a. Register with Tailwind v4 (@theme inline). */
|
||||
@theme inline {
|
||||
--color-warning: var(--warning);
|
||||
--color-warning-foreground: var(--warning-foreground);
|
||||
}
|
||||
```
|
||||
|
||||
When `tailwindVersion` is `"v3"` (check via `npx shadcn@latest info`), register in `tailwind.config.js` instead:
|
||||
|
||||
```js
|
||||
// 2b. Register with Tailwind v3 (tailwind.config.js).
|
||||
module.exports = {
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
warning: "oklch(var(--warning) / <alpha-value>)",
|
||||
"warning-foreground":
|
||||
"oklch(var(--warning-foreground) / <alpha-value>)",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
```tsx
|
||||
// 3. Use in components.
|
||||
<div className="bg-warning text-warning-foreground">Warning</div>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Border Radius
|
||||
|
||||
`--radius` controls border radius globally. Components derive values from it (`rounded-lg` = `var(--radius)`, `rounded-md` = `calc(var(--radius) - 2px)`).
|
||||
|
||||
---
|
||||
|
||||
## Customizing Components
|
||||
|
||||
See also: [rules/styling.md](./rules/styling.md) for Incorrect/Correct examples.
|
||||
|
||||
Prefer these approaches in order:
|
||||
|
||||
### 1. Built-in variants
|
||||
|
||||
```tsx
|
||||
<Button variant="outline" size="sm">
|
||||
Click
|
||||
</Button>
|
||||
```
|
||||
|
||||
### 2. Tailwind classes via `className`
|
||||
|
||||
```tsx
|
||||
<Card className="mx-auto max-w-md">...</Card>
|
||||
```
|
||||
|
||||
### 3. Add a new variant
|
||||
|
||||
Edit the component source to add a variant via `cva`:
|
||||
|
||||
```tsx
|
||||
// components/ui/button.tsx
|
||||
warning: "bg-warning text-warning-foreground hover:bg-warning/90",
|
||||
```
|
||||
|
||||
### 4. Wrapper components
|
||||
|
||||
Compose shadcn/ui primitives into higher-level components:
|
||||
|
||||
```tsx
|
||||
export function ConfirmDialog({ title, description, onConfirm, children }) {
|
||||
return (
|
||||
<AlertDialog>
|
||||
<AlertDialogTrigger asChild>{children}</AlertDialogTrigger>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>{title}</AlertDialogTitle>
|
||||
<AlertDialogDescription>{description}</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction onClick={onConfirm}>Confirm</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Checking for Updates
|
||||
|
||||
```bash
|
||||
npx shadcn@latest add button --diff
|
||||
```
|
||||
|
||||
To preview exactly what would change before updating, use `--dry-run` and `--diff`:
|
||||
|
||||
```bash
|
||||
npx shadcn@latest add button --dry-run # see all affected files
|
||||
npx shadcn@latest add button --diff button.tsx # see the diff for a specific file
|
||||
```
|
||||
|
||||
See [Updating Components in SKILL.md](./SKILL.md#updating-components) for the full smart merge workflow.
|
||||
+94
@@ -0,0 +1,94 @@
|
||||
# shadcn MCP Server
|
||||
|
||||
The CLI includes an MCP server that lets AI assistants search, browse, view, and install components from registries.
|
||||
|
||||
---
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
shadcn mcp # start the MCP server (stdio)
|
||||
shadcn mcp init # write config for your editor
|
||||
```
|
||||
|
||||
Editor config files:
|
||||
|
||||
| Editor | Config file |
|
||||
|--------|------------|
|
||||
| Claude Code | `.mcp.json` |
|
||||
| Cursor | `.cursor/mcp.json` |
|
||||
| VS Code | `.vscode/mcp.json` |
|
||||
| OpenCode | `opencode.json` |
|
||||
| Codex | `~/.codex/config.toml` (manual) |
|
||||
|
||||
---
|
||||
|
||||
## Tools
|
||||
|
||||
> **Tip:** MCP tools handle registry operations (search, view, install). For project configuration (aliases, framework, Tailwind version), use `npx shadcn@latest info` — there is no MCP equivalent.
|
||||
|
||||
### `shadcn:get_project_registries`
|
||||
|
||||
Returns registry names from `components.json`. Errors if no `components.json` exists.
|
||||
|
||||
**Input:** none
|
||||
|
||||
### `shadcn:list_items_in_registries`
|
||||
|
||||
Lists all items from one or more registries.
|
||||
|
||||
**Input:** `registries` (string[]), `limit` (number, optional), `offset` (number, optional)
|
||||
|
||||
### `shadcn:search_items_in_registries`
|
||||
|
||||
Fuzzy search across registries.
|
||||
|
||||
**Input:** `registries` (string[]), `query` (string), `limit` (number, optional), `offset` (number, optional)
|
||||
|
||||
### `shadcn:view_items_in_registries`
|
||||
|
||||
View item details including full file contents.
|
||||
|
||||
**Input:** `items` (string[]) — e.g. `["@shadcn/button", "@shadcn/card"]`
|
||||
|
||||
### `shadcn:get_item_examples_from_registries`
|
||||
|
||||
Find usage examples and demos with source code.
|
||||
|
||||
**Input:** `registries` (string[]), `query` (string) — e.g. `"accordion-demo"`, `"button example"`
|
||||
|
||||
### `shadcn:get_add_command_for_items`
|
||||
|
||||
Returns the CLI install command.
|
||||
|
||||
**Input:** `items` (string[]) — e.g. `["@shadcn/button"]`
|
||||
|
||||
### `shadcn:get_audit_checklist`
|
||||
|
||||
Returns a checklist for verifying components (imports, deps, lint, TypeScript).
|
||||
|
||||
**Input:** none
|
||||
|
||||
---
|
||||
|
||||
## Configuring Registries
|
||||
|
||||
Registries are set in `components.json`. The `@shadcn` registry is always built-in.
|
||||
|
||||
```json
|
||||
{
|
||||
"registries": {
|
||||
"@acme": "https://acme.com/r/{name}.json",
|
||||
"@private": {
|
||||
"url": "https://private.com/r/{name}.json",
|
||||
"headers": { "Authorization": "Bearer ${MY_TOKEN}" }
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- Names must start with `@`.
|
||||
- URLs must contain `{name}`.
|
||||
- `${VAR}` references are resolved from environment variables.
|
||||
|
||||
Community registry index: `https://ui.shadcn.com/r/registries.json`
|
||||
@@ -0,0 +1,306 @@
|
||||
# Base vs Radix
|
||||
|
||||
API differences between `base` and `radix`. Check the `base` field from `npx shadcn@latest info`.
|
||||
|
||||
## Contents
|
||||
|
||||
- Composition: asChild vs render
|
||||
- Button / trigger as non-button element
|
||||
- Select (items prop, placeholder, positioning, multiple, object values)
|
||||
- ToggleGroup (type vs multiple)
|
||||
- Slider (scalar vs array)
|
||||
- Accordion (type and defaultValue)
|
||||
|
||||
---
|
||||
|
||||
## Composition: asChild (radix) vs render (base)
|
||||
|
||||
Radix uses `asChild` to replace the default element. Base uses `render`. Don't wrap triggers in extra elements.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<DialogTrigger>
|
||||
<div>
|
||||
<Button>Open</Button>
|
||||
</div>
|
||||
</DialogTrigger>
|
||||
```
|
||||
|
||||
**Correct (radix):**
|
||||
|
||||
```tsx
|
||||
<DialogTrigger asChild>
|
||||
<Button>Open</Button>
|
||||
</DialogTrigger>
|
||||
```
|
||||
|
||||
**Correct (base):**
|
||||
|
||||
```tsx
|
||||
<DialogTrigger render={<Button />}>Open</DialogTrigger>
|
||||
```
|
||||
|
||||
This applies to all trigger and close components: `DialogTrigger`, `SheetTrigger`, `AlertDialogTrigger`, `DropdownMenuTrigger`, `PopoverTrigger`, `TooltipTrigger`, `CollapsibleTrigger`, `DialogClose`, `SheetClose`, `NavigationMenuLink`, `BreadcrumbLink`, `SidebarMenuButton`, `Badge`, `Item`.
|
||||
|
||||
---
|
||||
|
||||
## Button / trigger as non-button element (base only)
|
||||
|
||||
When `render` changes an element to a non-button (`<a>`, `<span>`), add `nativeButton={false}`.
|
||||
|
||||
**Incorrect (base):** missing `nativeButton={false}`.
|
||||
|
||||
```tsx
|
||||
<Button render={<a href="/docs" />}>Read the docs</Button>
|
||||
```
|
||||
|
||||
**Correct (base):**
|
||||
|
||||
```tsx
|
||||
<Button render={<a href="/docs" />} nativeButton={false}>
|
||||
Read the docs
|
||||
</Button>
|
||||
```
|
||||
|
||||
**Correct (radix):**
|
||||
|
||||
```tsx
|
||||
<Button asChild>
|
||||
<a href="/docs">Read the docs</a>
|
||||
</Button>
|
||||
```
|
||||
|
||||
Same for triggers whose `render` is not a `Button`:
|
||||
|
||||
```tsx
|
||||
// base.
|
||||
<PopoverTrigger render={<InputGroupAddon />} nativeButton={false}>
|
||||
Pick date
|
||||
</PopoverTrigger>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Select
|
||||
|
||||
**items prop (base only).** Base requires an `items` prop on the root. Radix uses inline JSX only.
|
||||
|
||||
**Incorrect (base):**
|
||||
|
||||
```tsx
|
||||
<Select>
|
||||
<SelectTrigger><SelectValue placeholder="Select a fruit" /></SelectTrigger>
|
||||
</Select>
|
||||
```
|
||||
|
||||
**Correct (base):**
|
||||
|
||||
```tsx
|
||||
const items = [
|
||||
{ label: "Select a fruit", value: null },
|
||||
{ label: "Apple", value: "apple" },
|
||||
{ label: "Banana", value: "banana" },
|
||||
]
|
||||
|
||||
<Select items={items}>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
{items.map((item) => (
|
||||
<SelectItem key={item.value} value={item.value}>{item.label}</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
```
|
||||
|
||||
**Correct (radix):**
|
||||
|
||||
```tsx
|
||||
<Select>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select a fruit" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
<SelectItem value="apple">Apple</SelectItem>
|
||||
<SelectItem value="banana">Banana</SelectItem>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
```
|
||||
|
||||
**Placeholder.** Base uses a `{ value: null }` item in the items array. Radix uses `<SelectValue placeholder="...">`.
|
||||
|
||||
**Content positioning.** Base uses `alignItemWithTrigger`. Radix uses `position`.
|
||||
|
||||
```tsx
|
||||
// base.
|
||||
<SelectContent alignItemWithTrigger={false} side="bottom">
|
||||
|
||||
// radix.
|
||||
<SelectContent position="popper">
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Select — multiple selection and object values (base only)
|
||||
|
||||
Base supports `multiple`, render-function children on `SelectValue`, and object values with `itemToStringValue`. Radix is single-select with string values only.
|
||||
|
||||
**Correct (base — multiple selection):**
|
||||
|
||||
```tsx
|
||||
<Select items={items} multiple defaultValue={[]}>
|
||||
<SelectTrigger>
|
||||
<SelectValue>
|
||||
{(value: string[]) => value.length === 0 ? "Select fruits" : `${value.length} selected`}
|
||||
</SelectValue>
|
||||
</SelectTrigger>
|
||||
...
|
||||
</Select>
|
||||
```
|
||||
|
||||
**Correct (base — object values):**
|
||||
|
||||
```tsx
|
||||
<Select defaultValue={plans[0]} itemToStringValue={(plan) => plan.name}>
|
||||
<SelectTrigger>
|
||||
<SelectValue>{(value) => value.name}</SelectValue>
|
||||
</SelectTrigger>
|
||||
...
|
||||
</Select>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ToggleGroup
|
||||
|
||||
Base uses a `multiple` boolean prop. Radix uses `type="single"` or `type="multiple"`.
|
||||
|
||||
**Incorrect (base):**
|
||||
|
||||
```tsx
|
||||
<ToggleGroup type="single" defaultValue="daily">
|
||||
<ToggleGroupItem value="daily">Daily</ToggleGroupItem>
|
||||
</ToggleGroup>
|
||||
```
|
||||
|
||||
**Correct (base):**
|
||||
|
||||
```tsx
|
||||
// Single (no prop needed), defaultValue is always an array.
|
||||
<ToggleGroup defaultValue={["daily"]} spacing={2}>
|
||||
<ToggleGroupItem value="daily">Daily</ToggleGroupItem>
|
||||
<ToggleGroupItem value="weekly">Weekly</ToggleGroupItem>
|
||||
</ToggleGroup>
|
||||
|
||||
// Multi-selection.
|
||||
<ToggleGroup multiple>
|
||||
<ToggleGroupItem value="bold">Bold</ToggleGroupItem>
|
||||
<ToggleGroupItem value="italic">Italic</ToggleGroupItem>
|
||||
</ToggleGroup>
|
||||
```
|
||||
|
||||
**Correct (radix):**
|
||||
|
||||
```tsx
|
||||
// Single, defaultValue is a string.
|
||||
<ToggleGroup type="single" defaultValue="daily" spacing={2}>
|
||||
<ToggleGroupItem value="daily">Daily</ToggleGroupItem>
|
||||
<ToggleGroupItem value="weekly">Weekly</ToggleGroupItem>
|
||||
</ToggleGroup>
|
||||
|
||||
// Multi-selection.
|
||||
<ToggleGroup type="multiple">
|
||||
<ToggleGroupItem value="bold">Bold</ToggleGroupItem>
|
||||
<ToggleGroupItem value="italic">Italic</ToggleGroupItem>
|
||||
</ToggleGroup>
|
||||
```
|
||||
|
||||
**Controlled single value:**
|
||||
|
||||
```tsx
|
||||
// base — wrap/unwrap arrays.
|
||||
const [value, setValue] = React.useState("normal")
|
||||
<ToggleGroup value={[value]} onValueChange={(v) => setValue(v[0])}>
|
||||
|
||||
// radix — plain string.
|
||||
const [value, setValue] = React.useState("normal")
|
||||
<ToggleGroup type="single" value={value} onValueChange={setValue}>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Slider
|
||||
|
||||
Base accepts a plain number for a single thumb. Radix always requires an array.
|
||||
|
||||
**Incorrect (base):**
|
||||
|
||||
```tsx
|
||||
<Slider defaultValue={[50]} max={100} step={1} />
|
||||
```
|
||||
|
||||
**Correct (base):**
|
||||
|
||||
```tsx
|
||||
<Slider defaultValue={50} max={100} step={1} />
|
||||
```
|
||||
|
||||
**Correct (radix):**
|
||||
|
||||
```tsx
|
||||
<Slider defaultValue={[50]} max={100} step={1} />
|
||||
```
|
||||
|
||||
Both use arrays for range sliders. Controlled `onValueChange` in base may need a cast:
|
||||
|
||||
```tsx
|
||||
// base.
|
||||
const [value, setValue] = React.useState([0.3, 0.7])
|
||||
<Slider value={value} onValueChange={(v) => setValue(v as number[])} />
|
||||
|
||||
// radix.
|
||||
const [value, setValue] = React.useState([0.3, 0.7])
|
||||
<Slider value={value} onValueChange={setValue} />
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Accordion
|
||||
|
||||
Radix requires `type="single"` or `type="multiple"` and supports `collapsible`. `defaultValue` is a string. Base uses no `type` prop, uses `multiple` boolean, and `defaultValue` is always an array.
|
||||
|
||||
**Incorrect (base):**
|
||||
|
||||
```tsx
|
||||
<Accordion type="single" collapsible defaultValue="item-1">
|
||||
<AccordionItem value="item-1">...</AccordionItem>
|
||||
</Accordion>
|
||||
```
|
||||
|
||||
**Correct (base):**
|
||||
|
||||
```tsx
|
||||
<Accordion defaultValue={["item-1"]}>
|
||||
<AccordionItem value="item-1">...</AccordionItem>
|
||||
</Accordion>
|
||||
|
||||
// Multi-select.
|
||||
<Accordion multiple defaultValue={["item-1", "item-2"]}>
|
||||
<AccordionItem value="item-1">...</AccordionItem>
|
||||
<AccordionItem value="item-2">...</AccordionItem>
|
||||
</Accordion>
|
||||
```
|
||||
|
||||
**Correct (radix):**
|
||||
|
||||
```tsx
|
||||
<Accordion type="single" collapsible defaultValue="item-1">
|
||||
<AccordionItem value="item-1">...</AccordionItem>
|
||||
</Accordion>
|
||||
```
|
||||
@@ -0,0 +1,195 @@
|
||||
# Component Composition
|
||||
|
||||
## Contents
|
||||
|
||||
- Items always inside their Group component
|
||||
- Callouts use Alert
|
||||
- Empty states use Empty component
|
||||
- Toast notifications use sonner
|
||||
- Choosing between overlay components
|
||||
- Dialog, Sheet, and Drawer always need a Title
|
||||
- Card structure
|
||||
- Button has no isPending or isLoading prop
|
||||
- TabsTrigger must be inside TabsList
|
||||
- Avatar always needs AvatarFallback
|
||||
- Use Separator instead of raw hr or border divs
|
||||
- Use Skeleton for loading placeholders
|
||||
- Use Badge instead of custom styled spans
|
||||
|
||||
---
|
||||
|
||||
## Items always inside their Group component
|
||||
|
||||
Never render items directly inside the content container.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<SelectContent>
|
||||
<SelectItem value="apple">Apple</SelectItem>
|
||||
<SelectItem value="banana">Banana</SelectItem>
|
||||
</SelectContent>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
<SelectItem value="apple">Apple</SelectItem>
|
||||
<SelectItem value="banana">Banana</SelectItem>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
```
|
||||
|
||||
This applies to all group-based components:
|
||||
|
||||
| Item | Group |
|
||||
|------|-------|
|
||||
| `SelectItem`, `SelectLabel` | `SelectGroup` |
|
||||
| `DropdownMenuItem`, `DropdownMenuLabel`, `DropdownMenuSub` | `DropdownMenuGroup` |
|
||||
| `MenubarItem` | `MenubarGroup` |
|
||||
| `ContextMenuItem` | `ContextMenuGroup` |
|
||||
| `CommandItem` | `CommandGroup` |
|
||||
|
||||
---
|
||||
|
||||
## Callouts use Alert
|
||||
|
||||
```tsx
|
||||
<Alert>
|
||||
<AlertTitle>Warning</AlertTitle>
|
||||
<AlertDescription>Something needs attention.</AlertDescription>
|
||||
</Alert>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Empty states use Empty component
|
||||
|
||||
```tsx
|
||||
<Empty>
|
||||
<EmptyHeader>
|
||||
<EmptyMedia variant="icon"><FolderIcon /></EmptyMedia>
|
||||
<EmptyTitle>No projects yet</EmptyTitle>
|
||||
<EmptyDescription>Get started by creating a new project.</EmptyDescription>
|
||||
</EmptyHeader>
|
||||
<EmptyContent>
|
||||
<Button>Create Project</Button>
|
||||
</EmptyContent>
|
||||
</Empty>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Toast notifications use sonner
|
||||
|
||||
```tsx
|
||||
import { toast } from "sonner"
|
||||
|
||||
toast.success("Changes saved.")
|
||||
toast.error("Something went wrong.")
|
||||
toast("File deleted.", {
|
||||
action: { label: "Undo", onClick: () => undoDelete() },
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Choosing between overlay components
|
||||
|
||||
| Use case | Component |
|
||||
|----------|-----------|
|
||||
| Focused task that requires input | `Dialog` |
|
||||
| Destructive action confirmation | `AlertDialog` |
|
||||
| Side panel with details or filters | `Sheet` |
|
||||
| Mobile-first bottom panel | `Drawer` |
|
||||
| Quick info on hover | `HoverCard` |
|
||||
| Small contextual content on click | `Popover` |
|
||||
|
||||
---
|
||||
|
||||
## Dialog, Sheet, and Drawer always need a Title
|
||||
|
||||
`DialogTitle`, `SheetTitle`, `DrawerTitle` are required for accessibility. Use `className="sr-only"` if visually hidden.
|
||||
|
||||
```tsx
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Edit Profile</DialogTitle>
|
||||
<DialogDescription>Update your profile.</DialogDescription>
|
||||
</DialogHeader>
|
||||
...
|
||||
</DialogContent>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Card structure
|
||||
|
||||
Use full composition — don't dump everything into `CardContent`:
|
||||
|
||||
```tsx
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Team Members</CardTitle>
|
||||
<CardDescription>Manage your team.</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>...</CardContent>
|
||||
<CardFooter>
|
||||
<Button>Invite</Button>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Button has no isPending or isLoading prop
|
||||
|
||||
Compose with `Spinner` + `data-icon` + `disabled`:
|
||||
|
||||
```tsx
|
||||
<Button disabled>
|
||||
<Spinner data-icon="inline-start" />
|
||||
Saving...
|
||||
</Button>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## TabsTrigger must be inside TabsList
|
||||
|
||||
Never render `TabsTrigger` directly inside `Tabs` — always wrap in `TabsList`:
|
||||
|
||||
```tsx
|
||||
<Tabs defaultValue="account">
|
||||
<TabsList>
|
||||
<TabsTrigger value="account">Account</TabsTrigger>
|
||||
<TabsTrigger value="password">Password</TabsTrigger>
|
||||
</TabsList>
|
||||
<TabsContent value="account">...</TabsContent>
|
||||
</Tabs>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Avatar always needs AvatarFallback
|
||||
|
||||
Always include `AvatarFallback` for when the image fails to load:
|
||||
|
||||
```tsx
|
||||
<Avatar>
|
||||
<AvatarImage src="/avatar.png" alt="User" />
|
||||
<AvatarFallback>JD</AvatarFallback>
|
||||
</Avatar>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Use existing components instead of custom markup
|
||||
|
||||
| Instead of | Use |
|
||||
|---|---|
|
||||
| `<hr>` or `<div className="border-t">` | `<Separator />` |
|
||||
| `<div className="animate-pulse">` with styled divs | `<Skeleton className="h-4 w-3/4" />` |
|
||||
| `<span className="rounded-full bg-green-100 ...">` | `<Badge variant="secondary">` |
|
||||
@@ -0,0 +1,192 @@
|
||||
# Forms & Inputs
|
||||
|
||||
## Contents
|
||||
|
||||
- Forms use FieldGroup + Field
|
||||
- InputGroup requires InputGroupInput/InputGroupTextarea
|
||||
- Buttons inside inputs use InputGroup + InputGroupAddon
|
||||
- Option sets (2–7 choices) use ToggleGroup
|
||||
- FieldSet + FieldLegend for grouping related fields
|
||||
- Field validation and disabled states
|
||||
|
||||
---
|
||||
|
||||
## Forms use FieldGroup + Field
|
||||
|
||||
Always use `FieldGroup` + `Field` — never raw `div` with `space-y-*`:
|
||||
|
||||
```tsx
|
||||
<FieldGroup>
|
||||
<Field>
|
||||
<FieldLabel htmlFor="email">Email</FieldLabel>
|
||||
<Input id="email" type="email" />
|
||||
</Field>
|
||||
<Field>
|
||||
<FieldLabel htmlFor="password">Password</FieldLabel>
|
||||
<Input id="password" type="password" />
|
||||
</Field>
|
||||
</FieldGroup>
|
||||
```
|
||||
|
||||
Use `Field orientation="horizontal"` for settings pages. Use `FieldLabel className="sr-only"` for visually hidden labels.
|
||||
|
||||
**Choosing form controls:**
|
||||
|
||||
- Simple text input → `Input`
|
||||
- Dropdown with predefined options → `Select`
|
||||
- Searchable dropdown → `Combobox`
|
||||
- Native HTML select (no JS) → `native-select`
|
||||
- Boolean toggle → `Switch` (for settings) or `Checkbox` (for forms)
|
||||
- Single choice from few options → `RadioGroup`
|
||||
- Toggle between 2–5 options → `ToggleGroup` + `ToggleGroupItem`
|
||||
- OTP/verification code → `InputOTP`
|
||||
- Multi-line text → `Textarea`
|
||||
|
||||
---
|
||||
|
||||
## InputGroup requires InputGroupInput/InputGroupTextarea
|
||||
|
||||
Never use raw `Input` or `Textarea` inside an `InputGroup`.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<InputGroup>
|
||||
<Input placeholder="Search..." />
|
||||
</InputGroup>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
import { InputGroup, InputGroupInput } from "@/components/ui/input-group"
|
||||
|
||||
<InputGroup>
|
||||
<InputGroupInput placeholder="Search..." />
|
||||
</InputGroup>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Buttons inside inputs use InputGroup + InputGroupAddon
|
||||
|
||||
Never place a `Button` directly inside or adjacent to an `Input` with custom positioning.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<div className="relative">
|
||||
<Input placeholder="Search..." className="pr-10" />
|
||||
<Button className="absolute right-0 top-0" size="icon">
|
||||
<SearchIcon />
|
||||
</Button>
|
||||
</div>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
import { InputGroup, InputGroupInput, InputGroupAddon } from "@/components/ui/input-group"
|
||||
|
||||
<InputGroup>
|
||||
<InputGroupInput placeholder="Search..." />
|
||||
<InputGroupAddon>
|
||||
<Button size="icon">
|
||||
<SearchIcon data-icon="inline-start" />
|
||||
</Button>
|
||||
</InputGroupAddon>
|
||||
</InputGroup>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Option sets (2–7 choices) use ToggleGroup
|
||||
|
||||
Don't manually loop `Button` components with active state.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
const [selected, setSelected] = useState("daily")
|
||||
|
||||
<div className="flex gap-2">
|
||||
{["daily", "weekly", "monthly"].map((option) => (
|
||||
<Button
|
||||
key={option}
|
||||
variant={selected === option ? "default" : "outline"}
|
||||
onClick={() => setSelected(option)}
|
||||
>
|
||||
{option}
|
||||
</Button>
|
||||
))}
|
||||
</div>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group"
|
||||
|
||||
<ToggleGroup spacing={2}>
|
||||
<ToggleGroupItem value="daily">Daily</ToggleGroupItem>
|
||||
<ToggleGroupItem value="weekly">Weekly</ToggleGroupItem>
|
||||
<ToggleGroupItem value="monthly">Monthly</ToggleGroupItem>
|
||||
</ToggleGroup>
|
||||
```
|
||||
|
||||
Combine with `Field` for labelled toggle groups:
|
||||
|
||||
```tsx
|
||||
<Field orientation="horizontal">
|
||||
<FieldTitle id="theme-label">Theme</FieldTitle>
|
||||
<ToggleGroup aria-labelledby="theme-label" spacing={2}>
|
||||
<ToggleGroupItem value="light">Light</ToggleGroupItem>
|
||||
<ToggleGroupItem value="dark">Dark</ToggleGroupItem>
|
||||
<ToggleGroupItem value="system">System</ToggleGroupItem>
|
||||
</ToggleGroup>
|
||||
</Field>
|
||||
```
|
||||
|
||||
> **Note:** `defaultValue` and `type`/`multiple` props differ between base and radix. See [base-vs-radix.md](./base-vs-radix.md#togglegroup).
|
||||
|
||||
---
|
||||
|
||||
## FieldSet + FieldLegend for grouping related fields
|
||||
|
||||
Use `FieldSet` + `FieldLegend` for related checkboxes, radios, or switches — not `div` with a heading:
|
||||
|
||||
```tsx
|
||||
<FieldSet>
|
||||
<FieldLegend variant="label">Preferences</FieldLegend>
|
||||
<FieldDescription>Select all that apply.</FieldDescription>
|
||||
<FieldGroup className="gap-3">
|
||||
<Field orientation="horizontal">
|
||||
<Checkbox id="dark" />
|
||||
<FieldLabel htmlFor="dark" className="font-normal">Dark mode</FieldLabel>
|
||||
</Field>
|
||||
</FieldGroup>
|
||||
</FieldSet>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Field validation and disabled states
|
||||
|
||||
Both attributes are needed — `data-invalid`/`data-disabled` styles the field (label, description), while `aria-invalid`/`disabled` styles the control.
|
||||
|
||||
```tsx
|
||||
// Invalid.
|
||||
<Field data-invalid>
|
||||
<FieldLabel htmlFor="email">Email</FieldLabel>
|
||||
<Input id="email" aria-invalid />
|
||||
<FieldDescription>Invalid email address.</FieldDescription>
|
||||
</Field>
|
||||
|
||||
// Disabled.
|
||||
<Field data-disabled>
|
||||
<FieldLabel htmlFor="email">Email</FieldLabel>
|
||||
<Input id="email" disabled />
|
||||
</Field>
|
||||
```
|
||||
|
||||
Works for all controls: `Input`, `Textarea`, `Select`, `Checkbox`, `RadioGroupItem`, `Switch`, `Slider`, `NativeSelect`, `InputOTP`.
|
||||
@@ -0,0 +1,101 @@
|
||||
# Icons
|
||||
|
||||
**Always use the project's configured `iconLibrary` for imports.** Check the `iconLibrary` field from project context: `lucide` → `lucide-react`, `tabler` → `@tabler/icons-react`, etc. Never assume `lucide-react`.
|
||||
|
||||
---
|
||||
|
||||
## Icons in Button use data-icon attribute
|
||||
|
||||
Add `data-icon="inline-start"` (prefix) or `data-icon="inline-end"` (suffix) to the icon. No sizing classes on the icon.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<Button>
|
||||
<SearchIcon className="mr-2 size-4" />
|
||||
Search
|
||||
</Button>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
<Button>
|
||||
<SearchIcon data-icon="inline-start"/>
|
||||
Search
|
||||
</Button>
|
||||
|
||||
<Button>
|
||||
Next
|
||||
<ArrowRightIcon data-icon="inline-end"/>
|
||||
</Button>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## No sizing classes on icons inside components
|
||||
|
||||
Components handle icon sizing via CSS. Don't add `size-4`, `w-4 h-4`, or other sizing classes to icons inside `Button`, `DropdownMenuItem`, `Alert`, `Sidebar*`, or other shadcn components. Unless the user explicitly asks for custom icon sizes.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<Button>
|
||||
<SearchIcon className="size-4" data-icon="inline-start" />
|
||||
Search
|
||||
</Button>
|
||||
|
||||
<DropdownMenuItem>
|
||||
<SettingsIcon className="mr-2 size-4" />
|
||||
Settings
|
||||
</DropdownMenuItem>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
<Button>
|
||||
<SearchIcon data-icon="inline-start" />
|
||||
Search
|
||||
</Button>
|
||||
|
||||
<DropdownMenuItem>
|
||||
<SettingsIcon />
|
||||
Settings
|
||||
</DropdownMenuItem>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pass icons as component objects, not string keys
|
||||
|
||||
Use `icon={CheckIcon}`, not a string key to a lookup map.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
const iconMap = {
|
||||
check: CheckIcon,
|
||||
alert: AlertIcon,
|
||||
}
|
||||
|
||||
function StatusBadge({ icon }: { icon: string }) {
|
||||
const Icon = iconMap[icon]
|
||||
return <Icon />
|
||||
}
|
||||
|
||||
<StatusBadge icon="check" />
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
// Import from the project's configured iconLibrary (e.g. lucide-react, @tabler/icons-react).
|
||||
import { CheckIcon } from "lucide-react"
|
||||
|
||||
function StatusBadge({ icon: Icon }: { icon: React.ComponentType }) {
|
||||
return <Icon />
|
||||
}
|
||||
|
||||
<StatusBadge icon={CheckIcon} />
|
||||
```
|
||||
@@ -0,0 +1,162 @@
|
||||
# Styling & Customization
|
||||
|
||||
See [customization.md](../customization.md) for theming, CSS variables, and adding custom colors.
|
||||
|
||||
## Contents
|
||||
|
||||
- Semantic colors
|
||||
- Built-in variants first
|
||||
- className for layout only
|
||||
- No space-x-* / space-y-*
|
||||
- Prefer size-* over w-* h-* when equal
|
||||
- Prefer truncate shorthand
|
||||
- No manual dark: color overrides
|
||||
- Use cn() for conditional classes
|
||||
- No manual z-index on overlay components
|
||||
|
||||
---
|
||||
|
||||
## Semantic colors
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<div className="bg-blue-500 text-white">
|
||||
<p className="text-gray-600">Secondary text</p>
|
||||
</div>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
<div className="bg-primary text-primary-foreground">
|
||||
<p className="text-muted-foreground">Secondary text</p>
|
||||
</div>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## No raw color values for status/state indicators
|
||||
|
||||
For positive, negative, or status indicators, use Badge variants, semantic tokens like `text-destructive`, or define custom CSS variables — don't reach for raw Tailwind colors.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<span className="text-emerald-600">+20.1%</span>
|
||||
<span className="text-green-500">Active</span>
|
||||
<span className="text-red-600">-3.2%</span>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
<Badge variant="secondary">+20.1%</Badge>
|
||||
<Badge>Active</Badge>
|
||||
<span className="text-destructive">-3.2%</span>
|
||||
```
|
||||
|
||||
If you need a success/positive color that doesn't exist as a semantic token, use a Badge variant or ask the user about adding a custom CSS variable to the theme (see [customization.md](../customization.md)).
|
||||
|
||||
---
|
||||
|
||||
## Built-in variants first
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<Button className="border border-input bg-transparent hover:bg-accent">
|
||||
Click me
|
||||
</Button>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
<Button variant="outline">Click me</Button>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## className for layout only
|
||||
|
||||
Use `className` for layout (e.g. `max-w-md`, `mx-auto`, `mt-4`), **not** for overriding component colors or typography. To change colors, use semantic tokens, built-in variants, or CSS variables.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<Card className="bg-blue-100 text-blue-900 font-bold">
|
||||
<CardContent>Dashboard</CardContent>
|
||||
</Card>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
<Card className="max-w-md mx-auto">
|
||||
<CardContent>Dashboard</CardContent>
|
||||
</Card>
|
||||
```
|
||||
|
||||
To customize a component's appearance, prefer these approaches in order:
|
||||
1. **Built-in variants** — `variant="outline"`, `variant="destructive"`, etc.
|
||||
2. **Semantic color tokens** — `bg-primary`, `text-muted-foreground`.
|
||||
3. **CSS variables** — define custom colors in the global CSS file (see [customization.md](../customization.md)).
|
||||
|
||||
---
|
||||
|
||||
## No space-x-* / space-y-*
|
||||
|
||||
Use `gap-*` instead. `space-y-4` → `flex flex-col gap-4`. `space-x-2` → `flex gap-2`.
|
||||
|
||||
```tsx
|
||||
<div className="flex flex-col gap-4">
|
||||
<Input />
|
||||
<Input />
|
||||
<Button>Submit</Button>
|
||||
</div>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Prefer size-* over w-* h-* when equal
|
||||
|
||||
`size-10` not `w-10 h-10`. Applies to icons, avatars, skeletons, etc.
|
||||
|
||||
---
|
||||
|
||||
## Prefer truncate shorthand
|
||||
|
||||
`truncate` not `overflow-hidden text-ellipsis whitespace-nowrap`.
|
||||
|
||||
---
|
||||
|
||||
## No manual dark: color overrides
|
||||
|
||||
Use semantic tokens — they handle light/dark via CSS variables. `bg-background text-foreground` not `bg-white dark:bg-gray-950`.
|
||||
|
||||
---
|
||||
|
||||
## Use cn() for conditional classes
|
||||
|
||||
Use the `cn()` utility from the project for conditional or merged class names. Don't write manual ternaries in className strings.
|
||||
|
||||
**Incorrect:**
|
||||
|
||||
```tsx
|
||||
<div className={`flex items-center ${isActive ? "bg-primary text-primary-foreground" : "bg-muted"}`}>
|
||||
```
|
||||
|
||||
**Correct:**
|
||||
|
||||
```tsx
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
<div className={cn("flex items-center", isActive ? "bg-primary text-primary-foreground" : "bg-muted")}>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## No manual z-index on overlay components
|
||||
|
||||
`Dialog`, `Sheet`, `Drawer`, `AlertDialog`, `DropdownMenu`, `Popover`, `Tooltip`, `HoverCard` handle their own stacking. Never add `z-50` or `z-[999]`.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,12 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
otechie: # Replace with a single Otechie username
|
||||
custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
+12
-2
@@ -1,14 +1,24 @@
|
||||
# Security Policy
|
||||
|
||||
> [!IMPORTANT]
|
||||
> **Bulk Reporting Policy:** If you need to submit multiple vulnerability reports in bulk, **you must contact us first** ([support@quantumnous.com](mailto:support@quantumnous.com)) to coordinate the submission process. Uncoordinated bulk submissions have caused significant disruption to our team, and we will take the following actions:
|
||||
>
|
||||
> 1. **All uncoordinated bulk reports will be closed without review.**
|
||||
> 2. **Repeated offenders may be blocked** from further submissions.
|
||||
>
|
||||
> We welcome thorough security research, but please reach out before submitting multiple reports.
|
||||
|
||||
## Supported Versions
|
||||
|
||||
We provide security updates for the following versions:
|
||||
|
||||
|
||||
| Version | Supported |
|
||||
| ------- | ------------------ |
|
||||
| Latest | :white_check_mark: |
|
||||
| Older | :x: |
|
||||
|
||||
|
||||
We strongly recommend that users always use the latest version for the best security and features.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
@@ -23,7 +33,7 @@ To report a security issue, please use the GitHub Security Advisories tab to "[O
|
||||
|
||||
Alternatively, you can report via email:
|
||||
|
||||
- **Email:** support@quantumnous.com
|
||||
- **Email:** [support@quantumnous.com](mailto:support@quantumnous.com)
|
||||
- **Subject:** `[SECURITY] Security Vulnerability Report`
|
||||
|
||||
### What to Include
|
||||
@@ -83,4 +93,4 @@ For detailed configuration instructions, please refer to the project documentati
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This project is provided "as is" without any express or implied warranty. Users should assess the security risks of using this software in their environment.
|
||||
This project is provided "as is" without any express or implied warranty. Users should assess the security risks of using this software in their environment.
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Publish Docker image (Multi Registries, native amd64+arm64)
|
||||
name: Publish Docker image (Multi-arch)
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -14,7 +14,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build_single_arch:
|
||||
name: Build & push (${{ matrix.arch }}) [native]
|
||||
name: Build & push (${{ matrix.arch }})
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -26,6 +26,8 @@ jobs:
|
||||
platform: linux/arm64
|
||||
runner: ubuntu-24.04-arm
|
||||
runs-on: ${{ matrix.runner }}
|
||||
outputs:
|
||||
tag: ${{ steps.version.outputs.tag }}
|
||||
|
||||
permissions:
|
||||
packages: write
|
||||
@@ -34,58 +36,46 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: ${{ github.event_name == 'workflow_dispatch' && 0 || 1 }}
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
|
||||
- name: Resolve tag & write VERSION
|
||||
id: version
|
||||
run: |
|
||||
if [ -n "${{ github.event.inputs.tag }}" ]; then
|
||||
TAG="${{ github.event.inputs.tag }}"
|
||||
# Verify tag exists
|
||||
if ! git rev-parse "refs/tags/$TAG" >/dev/null 2>&1; then
|
||||
echo "Error: Tag '$TAG' does not exist in the repository"
|
||||
echo "::error::Tag '$TAG' does not exist"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
TAG=${GITHUB_REF#refs/tags/}
|
||||
fi
|
||||
echo "TAG=$TAG" >> $GITHUB_ENV
|
||||
echo "$TAG" > VERSION
|
||||
echo "Building tag: $TAG for ${{ matrix.arch }}"
|
||||
|
||||
|
||||
# - name: Normalize GHCR repository
|
||||
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
echo "TAG=${TAG}" >> $GITHUB_ENV
|
||||
echo "tag=${TAG}" >> $GITHUB_OUTPUT
|
||||
echo "${TAG}" > VERSION
|
||||
echo "Building tag: ${TAG} for ${{ matrix.arch }}"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# - name: Log in to GHCR
|
||||
# uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
# with:
|
||||
# registry: ghcr.io
|
||||
# username: ${{ github.actor }}
|
||||
# password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (labels)
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
# ghcr.io/${{ env.GHCR_REPOSITORY }}
|
||||
images: calciumion/new-api
|
||||
|
||||
- name: Build & push single-arch (to both registries)
|
||||
- name: Build & push
|
||||
id: build
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -93,8 +83,6 @@ jobs:
|
||||
tags: |
|
||||
calciumion/new-api:${{ env.TAG }}-${{ matrix.arch }}
|
||||
calciumion/new-api:latest-${{ matrix.arch }}
|
||||
# ghcr.io/${{ env.GHCR_REPOSITORY }}:${{ env.TAG }}-${{ matrix.arch }}
|
||||
# ghcr.io/${{ env.GHCR_REPOSITORY }}:latest-${{ matrix.arch }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
@@ -102,81 +90,52 @@ jobs:
|
||||
sbom: true
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@398d4b0eeef1380460a10c8013a76f728fb906ac # v3
|
||||
uses: sigstore/cosign-installer@v3
|
||||
|
||||
- name: Sign image with cosign
|
||||
run: cosign sign --yes calciumion/new-api@${{ steps.build.outputs.digest }}
|
||||
|
||||
- name: Output digest
|
||||
- name: Image summary
|
||||
run: |
|
||||
echo "### Docker Image Digest (${{ matrix.arch }})" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "calciumion/new-api:${{ env.TAG }}-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "calciumion/new-api:${TAG}-${{ matrix.arch }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.build.outputs.digest }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
create_manifests:
|
||||
name: Create multi-arch manifests (Docker Hub)
|
||||
name: Create multi-arch manifests
|
||||
needs: [build_single_arch]
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch'
|
||||
|
||||
steps:
|
||||
- name: Extract tag
|
||||
run: |
|
||||
if [ -n "${{ github.event.inputs.tag }}" ]; then
|
||||
echo "TAG=${{ github.event.inputs.tag }}" >> $GITHUB_ENV
|
||||
else
|
||||
echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
|
||||
fi
|
||||
#
|
||||
# - name: Normalize GHCR repository
|
||||
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
|
||||
- name: Set version
|
||||
run: echo "TAG=${{ needs.build_single_arch.outputs.tag }}" >> $GITHUB_ENV
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Create & push manifest (Docker Hub - version)
|
||||
- name: Create & push manifest (version)
|
||||
run: |
|
||||
docker buildx imagetools create \
|
||||
-t calciumion/new-api:${TAG} \
|
||||
calciumion/new-api:${TAG}-amd64 \
|
||||
calciumion/new-api:${TAG}-arm64
|
||||
|
||||
- name: Create & push manifest (Docker Hub - latest)
|
||||
- name: Create & push manifest (latest)
|
||||
run: |
|
||||
docker buildx imagetools create \
|
||||
-t calciumion/new-api:latest \
|
||||
calciumion/new-api:latest-amd64 \
|
||||
calciumion/new-api:latest-arm64
|
||||
|
||||
- name: Output manifest digest
|
||||
- name: Manifest summary
|
||||
run: |
|
||||
echo "### Multi-arch Manifest" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
docker buildx imagetools inspect calciumion/new-api:${TAG} >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
# ---- GHCR ----
|
||||
# - name: Log in to GHCR
|
||||
# uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
# with:
|
||||
# registry: ghcr.io
|
||||
# username: ${{ github.actor }}
|
||||
# password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# - name: Create & push manifest (GHCR - version)
|
||||
# run: |
|
||||
# docker buildx imagetools create \
|
||||
# -t ghcr.io/${GHCR_REPOSITORY}:${TAG} \
|
||||
# ghcr.io/${GHCR_REPOSITORY}:${TAG}-amd64 \
|
||||
# ghcr.io/${GHCR_REPOSITORY}:${TAG}-arm64
|
||||
#
|
||||
# - name: Create & push manifest (GHCR - latest)
|
||||
# run: |
|
||||
# docker buildx imagetools create \
|
||||
# -t ghcr.io/${GHCR_REPOSITORY}:latest \
|
||||
# ghcr.io/${GHCR_REPOSITORY}:latest-amd64 \
|
||||
# ghcr.io/${GHCR_REPOSITORY}:latest-arm64
|
||||
@@ -29,14 +29,22 @@ jobs:
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
- name: Build Frontend (default)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
cd web/default
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
cd ../..
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/classic
|
||||
bun install
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
@@ -78,15 +86,23 @@ jobs:
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
- name: Build Frontend (default)
|
||||
env:
|
||||
CI: ""
|
||||
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||
run: |
|
||||
cd web
|
||||
cd web/default
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
cd ../..
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/classic
|
||||
bun install
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
@@ -126,14 +142,22 @@ jobs:
|
||||
- uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
- name: Build Frontend (default)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
cd web/default
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
cd ../..
|
||||
- name: Build Frontend (classic)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web/classic
|
||||
bun install
|
||||
VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
|
||||
+4
-1
@@ -8,6 +8,9 @@ upload
|
||||
build
|
||||
*.db-journal
|
||||
logs
|
||||
web/default/dist
|
||||
web/classic/dist
|
||||
web/node_modules
|
||||
web/dist
|
||||
.env
|
||||
one-api
|
||||
@@ -19,9 +22,9 @@ tiktoken_cache
|
||||
.gocache
|
||||
.gomodcache/
|
||||
.cache
|
||||
web/bun.lock
|
||||
plans
|
||||
.claude
|
||||
.cursor
|
||||
|
||||
electron/node_modules
|
||||
electron/dist
|
||||
|
||||
@@ -7,7 +7,7 @@ This is an AI API gateway/proxy built with Go. It aggregates 40+ upstream AI pro
|
||||
## Tech Stack
|
||||
|
||||
- **Backend**: Go 1.22+, Gin web framework, GORM v2 ORM
|
||||
- **Frontend**: React 18, Vite, Semi Design UI (@douyinfe/semi-ui)
|
||||
- **Frontend**: React 19, TypeScript, Rsbuild, Base UI, Tailwind CSS
|
||||
- **Databases**: SQLite, MySQL, PostgreSQL (all three must be supported)
|
||||
- **Cache**: Redis (go-redis) + in-memory cache
|
||||
- **Auth**: JWT, WebAuthn/Passkeys, OAuth (GitHub, Discord, OIDC, etc.)
|
||||
@@ -33,8 +33,10 @@ types/ — Type definitions (relay formats, file sources, errors)
|
||||
i18n/ — Backend internationalization (go-i18n, en/zh)
|
||||
oauth/ — OAuth provider implementations
|
||||
pkg/ — Internal packages (cachex, ionet)
|
||||
web/ — React frontend
|
||||
web/src/i18n/ — Frontend internationalization (i18next, zh/en/fr/ru/ja/vi)
|
||||
web/ — Frontend themes container
|
||||
web/default/ — Default frontend (React 19, Rsbuild, Base UI, Tailwind)
|
||||
web/classic/ — Classic frontend (React 18, Vite, Semi Design)
|
||||
web/default/src/i18n/ — Frontend internationalization (i18next, zh/en/fr/ru/ja/vi)
|
||||
```
|
||||
|
||||
## Internationalization (i18n)
|
||||
@@ -43,13 +45,12 @@ web/ — React frontend
|
||||
- Library: `nicksnyder/go-i18n/v2`
|
||||
- Languages: en, zh
|
||||
|
||||
### Frontend (`web/src/i18n/`)
|
||||
### Frontend (`web/default/src/i18n/`)
|
||||
- Library: `i18next` + `react-i18next` + `i18next-browser-languagedetector`
|
||||
- Languages: zh (fallback), en, fr, ru, ja, vi
|
||||
- Translation files: `web/src/i18n/locales/{lang}.json` — flat JSON, keys are Chinese source strings
|
||||
- Usage: `useTranslation()` hook, call `t('中文key')` in components
|
||||
- Semi UI locale synced via `SemiLocaleWrapper`
|
||||
- CLI tools: `bun run i18n:extract`, `bun run i18n:sync`, `bun run i18n:lint`
|
||||
- Languages: en (base), zh (fallback), fr, ru, ja, vi
|
||||
- Translation files: `web/default/src/i18n/locales/{lang}.json` — flat JSON, keys are English source strings
|
||||
- Usage: `useTranslation()` hook, call `t('English key')` in components
|
||||
- CLI tools: `bun run i18n:sync` (from `web/default/`)
|
||||
|
||||
## Rules
|
||||
|
||||
@@ -93,7 +94,7 @@ All database code MUST be fully compatible with all three databases simultaneous
|
||||
|
||||
### Rule 3: Frontend — Prefer Bun
|
||||
|
||||
Use `bun` as the preferred package manager and script runner for the frontend (`web/` directory):
|
||||
Use `bun` as the preferred package manager and script runner for the frontend (`web/default/` directory):
|
||||
- `bun install` for dependency installation
|
||||
- `bun run dev` for development server
|
||||
- `bun run build` for production build
|
||||
|
||||
@@ -7,7 +7,7 @@ This is an AI API gateway/proxy built with Go. It aggregates 40+ upstream AI pro
|
||||
## Tech Stack
|
||||
|
||||
- **Backend**: Go 1.22+, Gin web framework, GORM v2 ORM
|
||||
- **Frontend**: React 18, Vite, Semi Design UI (@douyinfe/semi-ui)
|
||||
- **Frontend**: React 19, TypeScript, Rsbuild, Base UI, Tailwind CSS
|
||||
- **Databases**: SQLite, MySQL, PostgreSQL (all three must be supported)
|
||||
- **Cache**: Redis (go-redis) + in-memory cache
|
||||
- **Auth**: JWT, WebAuthn/Passkeys, OAuth (GitHub, Discord, OIDC, etc.)
|
||||
@@ -33,8 +33,10 @@ types/ — Type definitions (relay formats, file sources, errors)
|
||||
i18n/ — Backend internationalization (go-i18n, en/zh)
|
||||
oauth/ — OAuth provider implementations
|
||||
pkg/ — Internal packages (cachex, ionet)
|
||||
web/ — React frontend
|
||||
web/src/i18n/ — Frontend internationalization (i18next, zh/en/fr/ru/ja/vi)
|
||||
web/ — Frontend themes container
|
||||
web/default/ — Default frontend (React 19, Rsbuild, Base UI, Tailwind)
|
||||
web/classic/ — Classic frontend (React 18, Vite, Semi Design)
|
||||
web/default/src/i18n/ — Frontend internationalization (i18next, zh/en/fr/ru/ja/vi)
|
||||
```
|
||||
|
||||
## Internationalization (i18n)
|
||||
@@ -43,13 +45,12 @@ web/ — React frontend
|
||||
- Library: `nicksnyder/go-i18n/v2`
|
||||
- Languages: en, zh
|
||||
|
||||
### Frontend (`web/src/i18n/`)
|
||||
### Frontend (`web/default/src/i18n/`)
|
||||
- Library: `i18next` + `react-i18next` + `i18next-browser-languagedetector`
|
||||
- Languages: zh (fallback), en, fr, ru, ja, vi
|
||||
- Translation files: `web/src/i18n/locales/{lang}.json` — flat JSON, keys are Chinese source strings
|
||||
- Usage: `useTranslation()` hook, call `t('中文key')` in components
|
||||
- Semi UI locale synced via `SemiLocaleWrapper`
|
||||
- CLI tools: `bun run i18n:extract`, `bun run i18n:sync`, `bun run i18n:lint`
|
||||
- Languages: en (base), zh (fallback), fr, ru, ja, vi
|
||||
- Translation files: `web/default/src/i18n/locales/{lang}.json` — flat JSON, keys are English source strings
|
||||
- Usage: `useTranslation()` hook, call `t('English key')` in components
|
||||
- CLI tools: `bun run i18n:sync` (from `web/default/`)
|
||||
|
||||
## Rules
|
||||
|
||||
@@ -93,7 +94,7 @@ All database code MUST be fully compatible with all three databases simultaneous
|
||||
|
||||
### Rule 3: Frontend — Prefer Bun
|
||||
|
||||
Use `bun` as the preferred package manager and script runner for the frontend (`web/` directory):
|
||||
Use `bun` as the preferred package manager and script runner for the frontend (`web/default/` directory):
|
||||
- `bun install` for dependency installation
|
||||
- `bun run dev` for development server
|
||||
- `bun run build` for production build
|
||||
|
||||
+15
-4
@@ -1,13 +1,23 @@
|
||||
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
COPY web/bun.lock .
|
||||
COPY web/default/package.json .
|
||||
COPY web/default/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web .
|
||||
COPY ./web/default .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
|
||||
FROM oven/bun:1@sha256:0733e50325078969732ebe3b15ce4c4be5082f18c4ac1a0f0ca4839c2e4e42a7 AS builder-classic
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/classic/package.json .
|
||||
COPY web/classic/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web/classic .
|
||||
COPY ./VERSION .
|
||||
RUN VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
|
||||
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder2
|
||||
ENV GO111MODULE=on CGO_ENABLED=0
|
||||
|
||||
@@ -22,7 +32,8 @@ ADD go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
COPY --from=builder /build/dist ./web/dist
|
||||
COPY --from=builder /build/dist ./web/default/dist
|
||||
COPY --from=builder-classic /build/dist ./web/classic/dist
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM debian:bookworm-slim@sha256:f06537653ac770703bc45b4b113475bd402f451e85223f0f2837acbf89ab020a
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
# Backend-only build for frontend development
|
||||
# Skips frontend build, uses a placeholder for //go:embed web/dist
|
||||
|
||||
FROM golang:1.26.1-alpine AS builder
|
||||
|
||||
ENV GO111MODULE=on CGO_ENABLED=0
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
ENV GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH:-amd64}
|
||||
ENV GOEXPERIMENT=greenteagc
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
ADD go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN mkdir -p web/default/dist web/classic/dist && \
|
||||
echo '<!doctype html><html><head><title>dev</title></head><body>use frontend dev server</body></html>' > web/default/dist/index.html && \
|
||||
echo '<!doctype html><html><head><title>dev</title></head><body>use frontend dev server</body></html>' > web/classic/dist/index.html
|
||||
|
||||
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates tzdata wget \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& update-ca-certificates
|
||||
|
||||
COPY --from=builder /build/new-api /
|
||||
EXPOSE 3000
|
||||
WORKDIR /data
|
||||
ENTRYPOINT ["/new-api"]
|
||||
+459
@@ -0,0 +1,459 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||
|
||||
# New API
|
||||
|
||||
🍥 **Next-Generation Large Model Gateway and AI Asset Management System**
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md">中文</a> |
|
||||
<strong>English</strong> |
|
||||
<a href="./README.fr.md">Français</a> |
|
||||
<a href="./README.ja.md">日本語</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://raw.githubusercontent.com/Calcium-Ion/new-api/main/LICENSE">
|
||||
<img src="https://img.shields.io/github/license/Calcium-Ion/new-api?color=brightgreen" alt="license">
|
||||
</a>
|
||||
<a href="https://github.com/Calcium-Ion/new-api/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/Calcium-Ion/new-api?color=brightgreen&include_prereleases" alt="release">
|
||||
</a>
|
||||
<a href="https://github.com/users/Calcium-Ion/packages/container/package/new-api">
|
||||
<img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/CalciumIon/new-api">
|
||||
<img src="https://img.shields.io/badge/docker-dockerHub-blue" alt="docker">
|
||||
</a>
|
||||
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
|
||||
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="#-quick-start">Quick Start</a> •
|
||||
<a href="#-key-features">Key Features</a> •
|
||||
<a href="#-deployment">Deployment</a> •
|
||||
<a href="#-documentation">Documentation</a> •
|
||||
<a href="#-help-support">Help</a>
|
||||
</p>
|
||||
|
||||
</div>
|
||||
|
||||
## 📝 Project Description
|
||||
|
||||
> [!NOTE]
|
||||
> This is an open-source project developed based on [One API](https://github.com/songquanpeng/one-api)
|
||||
|
||||
> [!IMPORTANT]
|
||||
> - This project is for personal learning purposes only, with no guarantee of stability or technical support
|
||||
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes
|
||||
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Trusted Partners
|
||||
|
||||
<p align="center">
|
||||
<em>No particular order</em>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target="_blank">
|
||||
<img src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="80" />
|
||||
</a>
|
||||
<a href="https://bda.pku.edu.cn/" target="_blank">
|
||||
<img src="./docs/images/pku.png" alt="Peking University" height="80" />
|
||||
</a>
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank">
|
||||
<img src="./docs/images/ucloud.png" alt="UCloud" height="80" />
|
||||
</a>
|
||||
<a href="https://www.aliyun.com/" target="_blank">
|
||||
<img src="./docs/images/aliyun.png" alt="Alibaba Cloud" height="80" />
|
||||
</a>
|
||||
<a href="https://io.net/" target="_blank">
|
||||
<img src="./docs/images/io-net.png" alt="IO.NET" height="80" />
|
||||
</a>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## 🙏 Special Thanks
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.jetbrains.com/?from=new-api" target="_blank">
|
||||
<img src="https://resources.jetbrains.com/storage/products/company/brand/logos/jb_beam.png" alt="JetBrains Logo" width="120" />
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<strong>Thanks to <a href="https://www.jetbrains.com/?from=new-api">JetBrains</a> for providing free open-source development license for this project</strong>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Using Docker Compose (Recommended)
|
||||
|
||||
```bash
|
||||
# Clone the project
|
||||
git clone https://github.com/QuantumNous/new-api.git
|
||||
cd new-api
|
||||
|
||||
# Edit docker-compose.yml configuration
|
||||
nano docker-compose.yml
|
||||
|
||||
# Start the service
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>Using Docker Commands</strong></summary>
|
||||
|
||||
```bash
|
||||
# Pull the latest image
|
||||
docker pull calciumion/new-api:latest
|
||||
|
||||
# Using SQLite (default)
|
||||
docker run --name new-api -d --restart always \
|
||||
-p 3000:3000 \
|
||||
-e TZ=Asia/Shanghai \
|
||||
-v ./data:/data \
|
||||
calciumion/new-api:latest
|
||||
|
||||
# Using MySQL
|
||||
docker run --name new-api -d --restart always \
|
||||
-p 3000:3000 \
|
||||
-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" \
|
||||
-e TZ=Asia/Shanghai \
|
||||
-v ./data:/data \
|
||||
calciumion/new-api:latest
|
||||
```
|
||||
|
||||
> **💡 Tip:** `-v ./data:/data` will save data in the `data` folder of the current directory, you can also change it to an absolute path like `-v /your/custom/path:/data`
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
🎉 After deployment is complete, visit `http://localhost:3000` to start using!
|
||||
|
||||
📖 For more deployment methods, please refer to [Deployment Guide](https://docs.newapi.pro/en/docs/installation)
|
||||
|
||||
---
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
<div align="center">
|
||||
|
||||
### 📖 [Official Documentation](https://docs.newapi.pro/en/docs) | [](https://deepwiki.com/QuantumNous/new-api)
|
||||
|
||||
</div>
|
||||
|
||||
**Quick Navigation:**
|
||||
|
||||
| Category | Link |
|
||||
|------|------|
|
||||
| 🚀 Deployment Guide | [Installation Documentation](https://docs.newapi.pro/en/docs/installation) |
|
||||
| ⚙️ Environment Configuration | [Environment Variables](https://docs.newapi.pro/en/docs/installation/config-maintenance/environment-variables) |
|
||||
| 📡 API Documentation | [API Documentation](https://docs.newapi.pro/en/docs/api) |
|
||||
| ❓ FAQ | [FAQ](https://docs.newapi.pro/en/docs/support/faq) |
|
||||
| 💬 Community Interaction | [Communication Channels](https://docs.newapi.pro/en/docs/support/community-interaction) |
|
||||
|
||||
---
|
||||
|
||||
## ✨ Key Features
|
||||
|
||||
> For detailed features, please refer to [Features Introduction](https://docs.newapi.pro/en/docs/guide/wiki/basic-concepts/features-introduction)
|
||||
|
||||
### 🎨 Core Functions
|
||||
|
||||
| Feature | Description |
|
||||
|------|------|
|
||||
| 🎨 New UI | Modern user interface design |
|
||||
| 🌍 Multi-language | Supports Chinese, English, French, Japanese |
|
||||
| 🔄 Data Compatibility | Fully compatible with the original One API database |
|
||||
| 📈 Data Dashboard | Visual console and statistical analysis |
|
||||
| 🔒 Permission Management | Token grouping, model restrictions, user management |
|
||||
|
||||
### 💰 Payment and Billing
|
||||
|
||||
- ✅ Online recharge (EPay, Stripe)
|
||||
- ✅ Pay-per-use model pricing
|
||||
- ✅ Cache billing support (OpenAI, Azure, DeepSeek, Claude, Qwen and all supported models)
|
||||
- ✅ Flexible billing policy configuration
|
||||
|
||||
### 🔐 Authorization and Security
|
||||
|
||||
- 😈 Discord authorization login
|
||||
- 🤖 LinuxDO authorization login
|
||||
- 📱 Telegram authorization login
|
||||
- 🔑 OIDC unified authentication
|
||||
|
||||
### 🚀 Advanced Features
|
||||
|
||||
**API Format Support:**
|
||||
- ⚡ [OpenAI Responses](https://docs.newapi.pro/en/docs/api/ai-model/chat/openai/create-response)
|
||||
- ⚡ [OpenAI Realtime API](https://docs.newapi.pro/en/docs/api/ai-model/realtime/create-realtime-session) (including Azure)
|
||||
- ⚡ [Claude Messages](https://docs.newapi.pro/en/docs/api/ai-model/chat/create-message)
|
||||
- ⚡ [Google Gemini](https://doc.newapi.pro/en/api/google-gemini-chat)
|
||||
- 🔄 [Rerank Models](https://docs.newapi.pro/en/docs/api/ai-model/rerank/create-rerank) (Cohere, Jina)
|
||||
|
||||
**Intelligent Routing:**
|
||||
- ⚖️ Channel weighted random
|
||||
- 🔄 Automatic retry on failure
|
||||
- 🚦 User-level model rate limiting
|
||||
|
||||
**Format Conversion:**
|
||||
- 🔄 **OpenAI Compatible ⇄ Claude Messages**
|
||||
- 🔄 **OpenAI Compatible → Google Gemini**
|
||||
- 🔄 **Google Gemini → OpenAI Compatible** - Text only, function calling not supported yet
|
||||
- 🚧 **OpenAI Compatible ⇄ OpenAI Responses** - In development
|
||||
- 🔄 **Thinking-to-content functionality**
|
||||
|
||||
**Reasoning Effort Support:**
|
||||
|
||||
<details>
|
||||
<summary>View detailed configuration</summary>
|
||||
|
||||
**OpenAI series models:**
|
||||
- `o3-mini-high` - High reasoning effort
|
||||
- `o3-mini-medium` - Medium reasoning effort
|
||||
- `o3-mini-low` - Low reasoning effort
|
||||
- `gpt-5-high` - High reasoning effort
|
||||
- `gpt-5-medium` - Medium reasoning effort
|
||||
- `gpt-5-low` - Low reasoning effort
|
||||
|
||||
**Claude thinking models:**
|
||||
- `claude-3-7-sonnet-20250219-thinking` - Enable thinking mode
|
||||
|
||||
**Google Gemini series models:**
|
||||
- `gemini-2.5-flash-thinking` - Enable thinking mode
|
||||
- `gemini-2.5-flash-nothinking` - Disable thinking mode
|
||||
- `gemini-2.5-pro-thinking` - Enable thinking mode
|
||||
- `gemini-2.5-pro-thinking-128` - Enable thinking mode with thinking budget of 128 tokens
|
||||
- You can also append `-low`, `-medium`, or `-high` to any Gemini model name to request the corresponding reasoning effort (no extra thinking-budget suffix needed).
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## 🤖 Model Support
|
||||
|
||||
> For details, please refer to [API Documentation - Relay Interface](https://docs.newapi.pro/en/docs/api)
|
||||
|
||||
| Model Type | Description | Documentation |
|
||||
|---------|------|------|
|
||||
| 🤖 OpenAI GPTs | gpt-4-gizmo-* series | - |
|
||||
| 🎨 Midjourney-Proxy | [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) | [Documentation](https://doc.newapi.pro/en/api/midjourney-proxy-image) |
|
||||
| 🎵 Suno-API | [Suno API](https://github.com/Suno-API/Suno-API) | [Documentation](https://doc.newapi.pro/en/api/suno-music) |
|
||||
| 🔄 Rerank | Cohere, Jina | [Documentation](https://docs.newapi.pro/en/docs/api/ai-model/rerank/create-rerank) |
|
||||
| 💬 Claude | Messages format | [Documentation](https://docs.newapi.pro/en/docs/api/ai-model/chat/create-message) |
|
||||
| 🌐 Gemini | Google Gemini format | [Documentation](https://doc.newapi.pro/en/api/google-gemini-chat) |
|
||||
| 🔧 Dify | ChatFlow mode | - |
|
||||
| 🎯 Custom | Supports complete call address | - |
|
||||
|
||||
### 📡 Supported Interfaces
|
||||
|
||||
<details>
|
||||
<summary>View complete interface list</summary>
|
||||
|
||||
- [Chat Interface (Chat Completions)](https://docs.newapi.pro/en/docs/api/ai-model/chat/openai/create-chat-completion)
|
||||
- [Response Interface (Responses)](https://docs.newapi.pro/en/docs/api/ai-model/chat/openai/create-response)
|
||||
- [Image Interface (Image)](https://docs.newapi.pro/en/docs/api/ai-model/images/openai/v1-images-generations--post)
|
||||
- [Audio Interface (Audio)](https://docs.newapi.pro/en/docs/api/ai-model/audio/openai/create-transcription)
|
||||
- [Video Interface (Video)](https://docs.newapi.pro/en/docs/api/ai-model/videos/create-video-generation)
|
||||
- [Embedding Interface (Embeddings)](https://docs.newapi.pro/en/docs/api/ai-model/embeddings/create-embedding)
|
||||
- [Rerank Interface (Rerank)](https://docs.newapi.pro/en/docs/api/ai-model/rerank/create-rerank)
|
||||
- [Realtime Conversation (Realtime)](https://docs.newapi.pro/en/docs/api/ai-model/realtime/create-realtime-session)
|
||||
- [Claude Chat](https://docs.newapi.pro/en/docs/api/ai-model/chat/create-message)
|
||||
- [Google Gemini Chat](https://doc.newapi.pro/en/api/google-gemini-chat)
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## 🚢 Deployment
|
||||
|
||||
> [!TIP]
|
||||
> **Latest Docker image:** `calciumion/new-api:latest`
|
||||
|
||||
### 📋 Deployment Requirements
|
||||
|
||||
| Component | Requirement |
|
||||
|------|------|
|
||||
| **Local database** | SQLite (Docker must mount `/data` directory)|
|
||||
| **Remote database** | MySQL ≥ 5.7.8 or PostgreSQL ≥ 9.6 |
|
||||
| **Container engine** | Docker / Docker Compose |
|
||||
|
||||
### ⚙️ Environment Variable Configuration
|
||||
|
||||
<details>
|
||||
<summary>Common environment variable configuration</summary>
|
||||
|
||||
| Variable Name | Description | Default Value |
|
||||
|--------|------|--------|
|
||||
| `SESSION_SECRET` | Session secret (required for multi-machine deployment) | - |
|
||||
| `CRYPTO_SECRET` | Encryption secret (required for Redis) | - |
|
||||
| `SQL_DSN` | Database connection string | - |
|
||||
| `REDIS_CONN_STRING` | Redis connection string | - |
|
||||
| `STREAMING_TIMEOUT` | Streaming timeout (seconds) | `300` |
|
||||
| `STREAM_SCANNER_MAX_BUFFER_MB` | Max per-line buffer (MB) for the stream scanner; increase when upstream sends huge image/base64 payloads | `64` |
|
||||
| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` |
|
||||
| `AZURE_DEFAULT_API_VERSION` | Azure API version | `2025-04-01-preview` |
|
||||
| `ERROR_LOG_ENABLED` | Error log switch | `false` |
|
||||
| `PYROSCOPE_URL` | Pyroscope server address | - |
|
||||
| `PYROSCOPE_APP_NAME` | Pyroscope application name | `new-api` |
|
||||
| `PYROSCOPE_BASIC_AUTH_USER` | Pyroscope basic auth user | - |
|
||||
| `PYROSCOPE_BASIC_AUTH_PASSWORD` | Pyroscope basic auth password | - |
|
||||
| `PYROSCOPE_MUTEX_RATE` | Pyroscope mutex sampling rate | `5` |
|
||||
| `PYROSCOPE_BLOCK_RATE` | Pyroscope block sampling rate | `5` |
|
||||
| `HOSTNAME` | Hostname tag for Pyroscope | `new-api` |
|
||||
|
||||
📖 **Complete configuration:** [Environment Variables Documentation](https://docs.newapi.pro/en/docs/installation/config-maintenance/environment-variables)
|
||||
|
||||
</details>
|
||||
|
||||
### 🔧 Deployment Methods
|
||||
|
||||
<details>
|
||||
<summary><strong>Method 1: Docker Compose (Recommended)</strong></summary>
|
||||
|
||||
```bash
|
||||
# Clone the project
|
||||
git clone https://github.com/QuantumNous/new-api.git
|
||||
cd new-api
|
||||
|
||||
# Edit configuration
|
||||
nano docker-compose.yml
|
||||
|
||||
# Start service
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>Method 2: Docker Commands</strong></summary>
|
||||
|
||||
**Using SQLite:**
|
||||
```bash
|
||||
docker run --name new-api -d --restart always \
|
||||
-p 3000:3000 \
|
||||
-e TZ=Asia/Shanghai \
|
||||
-v ./data:/data \
|
||||
calciumion/new-api:latest
|
||||
```
|
||||
|
||||
**Using MySQL:**
|
||||
```bash
|
||||
docker run --name new-api -d --restart always \
|
||||
-p 3000:3000 \
|
||||
-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" \
|
||||
-e TZ=Asia/Shanghai \
|
||||
-v ./data:/data \
|
||||
calciumion/new-api:latest
|
||||
```
|
||||
|
||||
> **💡 Path explanation:**
|
||||
> - `./data:/data` - Relative path, data saved in the data folder of the current directory
|
||||
> - You can also use absolute path, e.g.: `/your/custom/path:/data`
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>Method 3: BaoTa Panel</strong></summary>
|
||||
|
||||
1. Install BaoTa Panel (≥ 9.2.0 version)
|
||||
2. Search for **New-API** in the application store
|
||||
3. One-click installation
|
||||
|
||||
📖 [Tutorial with images](./docs/BT.md)
|
||||
|
||||
</details>
|
||||
|
||||
### ⚠️ Multi-machine Deployment Considerations
|
||||
|
||||
> [!WARNING]
|
||||
> - **Must set** `SESSION_SECRET` - Otherwise login status inconsistent
|
||||
> - **Shared Redis must set** `CRYPTO_SECRET` - Otherwise data cannot be decrypted
|
||||
|
||||
### 🔄 Channel Retry and Cache
|
||||
|
||||
**Retry configuration:** `Settings → Operation Settings → General Settings → Failure Retry Count`
|
||||
|
||||
**Cache configuration:**
|
||||
- `REDIS_CONN_STRING`: Redis cache (recommended)
|
||||
- `MEMORY_CACHE_ENABLED`: Memory cache
|
||||
|
||||
---
|
||||
|
||||
## 🔗 Related Projects
|
||||
|
||||
### Upstream Projects
|
||||
|
||||
| Project | Description |
|
||||
|------|------|
|
||||
| [One API](https://github.com/songquanpeng/one-api) | Original project base |
|
||||
| [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy) | Midjourney interface support |
|
||||
|
||||
### Supporting Tools
|
||||
|
||||
| Project | Description |
|
||||
|------|------|
|
||||
| [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) | Key quota query tool |
|
||||
| [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon) | New API high-performance optimized version |
|
||||
|
||||
---
|
||||
|
||||
## 💬 Help Support
|
||||
|
||||
### 📖 Documentation Resources
|
||||
|
||||
| Resource | Link |
|
||||
|------|------|
|
||||
| 📘 FAQ | [FAQ](https://docs.newapi.pro/en/docs/support/faq) |
|
||||
| 💬 Community Interaction | [Communication Channels](https://docs.newapi.pro/en/docs/support/community-interaction) |
|
||||
| 🐛 Issue Feedback | [Issue Feedback](https://docs.newapi.pro/en/docs/support/feedback-issues) |
|
||||
| 📚 Complete Documentation | [Official Documentation](https://docs.newapi.pro/en/docs) |
|
||||
|
||||
### 🤝 Contribution Guide
|
||||
|
||||
Welcome all forms of contribution!
|
||||
|
||||
- 🐛 Report Bugs
|
||||
- 💡 Propose New Features
|
||||
- 📝 Improve Documentation
|
||||
- 🔧 Submit Code
|
||||
|
||||
---
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
### 💖 Thank you for using New API
|
||||
|
||||
If this project is helpful to you, welcome to give us a ⭐️ Star!
|
||||
|
||||
**[Official Documentation](https://docs.newapi.pro/en/docs)** • **[Issue Feedback](https://github.com/Calcium-Ion/new-api/issues)** • **[Latest Release](https://github.com/Calcium-Ion/new-api/releases)**
|
||||
|
||||
<sub>Built with ❤️ by QuantumNous</sub>
|
||||
|
||||
</div>
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||

|
||||
|
||||
# New API
|
||||
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||

|
||||
|
||||
# New API
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||

|
||||
|
||||
# New API
|
||||
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||

|
||||
|
||||
# New API
|
||||
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||

|
||||
|
||||
# New API
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
//"os"
|
||||
//"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -17,6 +18,24 @@ var Footer = ""
|
||||
var Logo = ""
|
||||
var TopUpLink = ""
|
||||
|
||||
var themeValue atomic.Value // stores string; safe for concurrent read/write
|
||||
|
||||
func init() {
|
||||
themeValue.Store("classic")
|
||||
}
|
||||
|
||||
func GetTheme() string {
|
||||
return themeValue.Load().(string)
|
||||
}
|
||||
|
||||
// SetTheme updates the frontend theme atomically.
|
||||
// Only "default" and "classic" are accepted; other values are silently ignored.
|
||||
func SetTheme(t string) {
|
||||
if t == "default" || t == "classic" {
|
||||
themeValue.Store(t)
|
||||
}
|
||||
}
|
||||
|
||||
// var ChatLink = ""
|
||||
// var ChatLink2 = ""
|
||||
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
||||
|
||||
@@ -41,3 +41,29 @@ func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
|
||||
FileSystem: http.FS(efs),
|
||||
}
|
||||
}
|
||||
|
||||
// themeAwareFileSystem delegates to the appropriate embedded FS based on
|
||||
// the current theme (via GetTheme). This enables runtime theme switching
|
||||
// without restarting the server.
|
||||
type themeAwareFileSystem struct {
|
||||
defaultFS static.ServeFileSystem
|
||||
classicFS static.ServeFileSystem
|
||||
}
|
||||
|
||||
func (t *themeAwareFileSystem) Exists(prefix string, path string) bool {
|
||||
if GetTheme() == "classic" {
|
||||
return t.classicFS.Exists(prefix, path)
|
||||
}
|
||||
return t.defaultFS.Exists(prefix, path)
|
||||
}
|
||||
|
||||
func (t *themeAwareFileSystem) Open(name string) (http.File, error) {
|
||||
if GetTheme() == "classic" {
|
||||
return t.classicFS.Open(name)
|
||||
}
|
||||
return t.defaultFS.Open(name)
|
||||
}
|
||||
|
||||
func NewThemeAwareFS(defaultFS, classicFS static.ServeFileSystem) static.ServeFileSystem {
|
||||
return &themeAwareFileSystem{defaultFS: defaultFS, classicFS: classicFS}
|
||||
}
|
||||
|
||||
@@ -43,3 +43,19 @@ func GetJsonType(data json.RawMessage) string {
|
||||
return "number"
|
||||
}
|
||||
}
|
||||
|
||||
// JsonRawMessageToString returns JSON strings as their decoded value and other JSON values as raw text.
|
||||
func JsonRawMessageToString(data json.RawMessage) string {
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) {
|
||||
return ""
|
||||
}
|
||||
if trimmed[0] != '"' {
|
||||
return string(trimmed)
|
||||
}
|
||||
var value string
|
||||
if err := Unmarshal(trimmed, &value); err != nil {
|
||||
return string(trimmed)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJsonRawMessageToString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data json.RawMessage
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "object",
|
||||
data: json.RawMessage(`{"city":"Paris","days":0,"strict":false}`),
|
||||
want: `{"city":"Paris","days":0,"strict":false}`,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
data: json.RawMessage(`"{\"city\":\"Paris\",\"days\":0,\"strict\":false}"`),
|
||||
want: `{"city":"Paris","days":0,"strict":false}`,
|
||||
},
|
||||
{
|
||||
name: "null",
|
||||
data: json.RawMessage(`null`),
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
data: nil,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, JsonRawMessageToString(tt.data))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -72,6 +72,7 @@ func GetAllChannels(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
sortOptions := model.NewChannelSortOptions(c.Query("sort_by"), c.Query("sort_order"), idSort)
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
statusParam := c.Query("status")
|
||||
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
||||
@@ -98,7 +99,7 @@ func GetAllChannels(c *gin.Context) {
|
||||
if tag == nil || *tag == "" {
|
||||
continue
|
||||
}
|
||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
|
||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -131,12 +132,7 @@ func GetAllChannels(c *gin.Context) {
|
||||
|
||||
baseQuery.Count(&total)
|
||||
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
|
||||
err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
||||
err := sortOptions.Apply(baseQuery).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to get channels: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
|
||||
@@ -252,6 +248,7 @@ func SearchChannels(c *gin.Context) {
|
||||
statusParam := c.Query("status")
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
sortOptions := model.NewChannelSortOptions(c.Query("sort_by"), c.Query("sort_order"), idSort)
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
channelData := make([]*model.Channel, 0)
|
||||
if enableTagMode {
|
||||
@@ -265,14 +262,14 @@ func SearchChannels(c *gin.Context) {
|
||||
}
|
||||
for _, tag := range tags {
|
||||
if tag != nil && *tag != "" {
|
||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort, false)
|
||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
|
||||
if err == nil {
|
||||
channelData = append(channelData, tagChannel...)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
|
||||
channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort, sortOptions)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -32,6 +32,26 @@ const (
|
||||
channelUpstreamModelUpdateNotifyMaxFailedChannelIDs = 10
|
||||
)
|
||||
|
||||
var channelUpstreamModelUpdateSelectFields = []string{
|
||||
"id",
|
||||
"name",
|
||||
"type",
|
||||
"key",
|
||||
"status",
|
||||
"base_url",
|
||||
"models",
|
||||
"model_mapping",
|
||||
"settings",
|
||||
"setting",
|
||||
"other",
|
||||
"group",
|
||||
"priority",
|
||||
"weight",
|
||||
"tag",
|
||||
"channel_info",
|
||||
"header_override",
|
||||
}
|
||||
|
||||
var (
|
||||
channelUpstreamModelUpdateTaskOnce sync.Once
|
||||
channelUpstreamModelUpdateTaskRunning atomic.Bool
|
||||
@@ -521,7 +541,7 @@ func runChannelUpstreamModelUpdateTaskOnce() {
|
||||
for {
|
||||
var channels []*model.Channel
|
||||
query := model.DB.
|
||||
Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
|
||||
Select(channelUpstreamModelUpdateSelectFields).
|
||||
Where("status = ?", common.ChannelStatusEnabled).
|
||||
Order("id asc").
|
||||
Limit(channelUpstreamModelUpdateTaskBatchSize)
|
||||
@@ -814,7 +834,7 @@ func collectPendingApplyUpstreamModelChanges(settings dto.ChannelOtherSettings)
|
||||
func findEnabledChannelsAfterID(lastID int, batchSize int) ([]*model.Channel, error) {
|
||||
var channels []*model.Channel
|
||||
query := model.DB.
|
||||
Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override").
|
||||
Select(channelUpstreamModelUpdateSelectFields).
|
||||
Where("status = ?", common.ChannelStatusEnabled).
|
||||
Order("id asc").
|
||||
Limit(batchSize)
|
||||
|
||||
@@ -81,6 +81,10 @@ func TestCollectPendingApplyUpstreamModelChanges(t *testing.T) {
|
||||
require.Equal(t, []string{"old-model"}, pendingRemoveModels)
|
||||
}
|
||||
|
||||
func TestChannelUpstreamModelUpdateSelectFieldsIncludeModelMapping(t *testing.T) {
|
||||
require.Contains(t, channelUpstreamModelUpdateSelectFields, "model_mapping")
|
||||
}
|
||||
|
||||
func TestNormalizeChannelModelMapping(t *testing.T) {
|
||||
modelMapping := `{
|
||||
" alias-model ": " upstream-model ",
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DiscordResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type DiscordUser struct {
|
||||
UID string `json:"id"`
|
||||
ID string `json:"username"`
|
||||
Name string `json:"global_name"`
|
||||
}
|
||||
|
||||
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var discordResponse DiscordResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&discordResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordResponse.AccessToken == "" {
|
||||
common.SysError("Discord 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("Discord 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var discordUser DiscordUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&discordUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if discordUser.UID == "" || discordUser.ID == "" {
|
||||
common.SysError("Discord 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &discordUser, nil
|
||||
}
|
||||
|
||||
func DiscordOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
DiscordBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
err := user.FillUserByDiscordId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if discordUser.ID != "" {
|
||||
user.Username = discordUser.ID
|
||||
} else {
|
||||
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if discordUser.Name != "" {
|
||||
user.DisplayName = discordUser.Name
|
||||
} else {
|
||||
user.DisplayName = "Discord User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func DiscordBind(c *gin.Context) {
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Discord 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.DiscordId = discordUser.UID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GitHubOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type GitHubUser struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse GitHubOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var githubUser GitHubUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&githubUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if githubUser.Login == "" {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
return &githubUser, nil
|
||||
}
|
||||
|
||||
func GitHubOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
GitHubBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
// IsGitHubIdAlreadyTaken is unscoped
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
// FillUserByGitHubId is scoped
|
||||
err := user.FillUserByGitHubId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// if user.Id == 0 , user has been deleted
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
} else {
|
||||
user.DisplayName = "GitHub User"
|
||||
}
|
||||
user.Email = githubUser.Email
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func GitHubBind(c *gin.Context) {
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 GitHub 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.GitHubId = githubUser.Login
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LinuxdoUser struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func LinuxDoBind(c *gin.Context) {
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Linux DO 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
|
||||
func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("invalid code")
|
||||
}
|
||||
|
||||
// Get access token using Basic auth
|
||||
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
|
||||
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
|
||||
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
|
||||
// Get redirect URI from request
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
|
||||
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", basicAuth)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to connect to Linux DO server")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var tokenRes struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenRes.AccessToken == "" {
|
||||
return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
|
||||
}
|
||||
|
||||
// Get user info
|
||||
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
|
||||
req, err = http.NewRequest("GET", userEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get user info from Linux DO")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
|
||||
var linuxdoUser LinuxdoUser
|
||||
if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if linuxdoUser.Id == 0 {
|
||||
return nil, errors.New("invalid user info returned")
|
||||
}
|
||||
|
||||
return &linuxdoUser, nil
|
||||
}
|
||||
|
||||
func LinuxdoOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
|
||||
errorCode := c.Query("error")
|
||||
if errorCode != "" {
|
||||
errorDescription := c.Query("error_description")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": errorDescription,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LinuxDoBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
// Check if user exists
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
err := user.FillUserByLinuxDOId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
+1
-1
@@ -61,6 +61,7 @@ func GetStatus(c *gin.Context) {
|
||||
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"theme": system_setting.GetThemeSettings().Frontend,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
@@ -69,7 +70,6 @@ func GetStatus(c *gin.Context) {
|
||||
"server_address": system_setting.ServerAddress,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
// 兼容旧前端:保留 display_in_currency,同时提供新的 quota_display_type
|
||||
|
||||
+3
-5
@@ -15,9 +15,9 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/channel/minimax"
|
||||
"github.com/QuantumNous/new-api/relay/channel/moonshot"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
@@ -134,8 +134,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
for allowModel, _ := range tokenModelLimit {
|
||||
if !acceptUnsetRatioModel {
|
||||
_, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel)
|
||||
if !exist {
|
||||
if !helper.HasModelBillingConfig(allowModel) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -182,8 +181,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
for _, modelName := range models {
|
||||
if !acceptUnsetRatioModel {
|
||||
_, _, exist := ratio_setting.GetModelRatioOrPrice(modelName)
|
||||
if !exist {
|
||||
if !helper.HasModelBillingConfig(modelName) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/config"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type listModelsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Data []dto.OpenAIModels `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
initModelListColumnNames(t)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
common.UsingSQLite = true
|
||||
common.UsingMySQL = false
|
||||
common.UsingPostgreSQL = false
|
||||
common.RedisEnabled = false
|
||||
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
model.DB = db
|
||||
model.LOG_DB = db
|
||||
|
||||
require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{}))
|
||||
|
||||
t.Cleanup(func() {
|
||||
sqlDB, err := db.DB()
|
||||
if err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func initModelListColumnNames(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
originalIsMasterNode := common.IsMasterNode
|
||||
originalSQLitePath := common.SQLitePath
|
||||
originalUsingSQLite := common.UsingSQLite
|
||||
originalUsingMySQL := common.UsingMySQL
|
||||
originalUsingPostgreSQL := common.UsingPostgreSQL
|
||||
originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN")
|
||||
defer func() {
|
||||
common.IsMasterNode = originalIsMasterNode
|
||||
common.SQLitePath = originalSQLitePath
|
||||
common.UsingSQLite = originalUsingSQLite
|
||||
common.UsingMySQL = originalUsingMySQL
|
||||
common.UsingPostgreSQL = originalUsingPostgreSQL
|
||||
if hadSQLDSN {
|
||||
require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN))
|
||||
} else {
|
||||
require.NoError(t, os.Unsetenv("SQL_DSN"))
|
||||
}
|
||||
}()
|
||||
|
||||
common.IsMasterNode = false
|
||||
common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
common.UsingSQLite = false
|
||||
common.UsingMySQL = false
|
||||
common.UsingPostgreSQL = false
|
||||
require.NoError(t, os.Setenv("SQL_DSN", "local"))
|
||||
|
||||
require.NoError(t, model.InitDB())
|
||||
if model.DB != nil {
|
||||
sqlDB, err := model.DB.DB()
|
||||
if err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) {
|
||||
t.Helper()
|
||||
|
||||
saved := map[string]string{}
|
||||
require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error {
|
||||
if strings.HasPrefix(key, "billing_setting.") {
|
||||
saved[key] = value
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, config.GlobalConfig.LoadFromDB(saved))
|
||||
model.InvalidatePricingCache()
|
||||
})
|
||||
|
||||
modeBytes, err := common.Marshal(modes)
|
||||
require.NoError(t, err)
|
||||
exprBytes, err := common.Marshal(exprs)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{
|
||||
"billing_setting.billing_mode": string(modeBytes),
|
||||
"billing_setting.billing_expr": string(exprBytes),
|
||||
}))
|
||||
model.InvalidatePricingCache()
|
||||
}
|
||||
|
||||
func withSelfUseModeDisabled(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
original := operation_setting.SelfUseModeEnabled
|
||||
operation_setting.SelfUseModeEnabled = false
|
||||
t.Cleanup(func() {
|
||||
operation_setting.SelfUseModeEnabled = original
|
||||
})
|
||||
}
|
||||
|
||||
func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} {
|
||||
t.Helper()
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
var payload listModelsResponse
|
||||
require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||
require.True(t, payload.Success)
|
||||
require.Equal(t, "list", payload.Object)
|
||||
|
||||
ids := make(map[string]struct{}, len(payload.Data))
|
||||
for _, item := range payload.Data {
|
||||
ids[item.Id] = struct{}{}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
|
||||
byName := make(map[string]model.Pricing, len(pricings))
|
||||
for _, pricing := range pricings {
|
||||
byName[pricing.ModelName] = pricing
|
||||
}
|
||||
return byName
|
||||
}
|
||||
|
||||
func TestListModelsIncludesTieredBillingModel(t *testing.T) {
|
||||
withSelfUseModeDisabled(t)
|
||||
withTieredBillingConfig(t, map[string]string{
|
||||
"zz-tiered-visible-model": "tiered_expr",
|
||||
"zz-tiered-empty-expr-model": "tiered_expr",
|
||||
"zz-tiered-missing-expr-model": "tiered_expr",
|
||||
}, map[string]string{
|
||||
"zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
||||
"zz-tiered-empty-expr-model": " ",
|
||||
})
|
||||
|
||||
db := setupModelListControllerTestDB(t)
|
||||
require.NoError(t, db.Create(&model.User{
|
||||
Id: 1001,
|
||||
Username: "model-list-user",
|
||||
Password: "password",
|
||||
Group: "default",
|
||||
Status: common.UserStatusEnabled,
|
||||
}).Error)
|
||||
require.NoError(t, db.Create(&[]model.Ability{
|
||||
{Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true},
|
||||
{Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true},
|
||||
{Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true},
|
||||
{Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true},
|
||||
}).Error)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
ctx.Set("id", 1001)
|
||||
|
||||
ListModels(ctx, constant.ChannelTypeOpenAI)
|
||||
|
||||
ids := decodeListModelsResponse(t, recorder)
|
||||
require.Contains(t, ids, "zz-tiered-visible-model")
|
||||
require.NotContains(t, ids, "zz-tiered-empty-expr-model")
|
||||
require.NotContains(t, ids, "zz-tiered-missing-expr-model")
|
||||
require.NotContains(t, ids, "zz-unpriced-model")
|
||||
|
||||
pricingByName := pricingByModelName(model.GetPricing())
|
||||
visiblePricing, ok := pricingByName["zz-tiered-visible-model"]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "tiered_expr", visiblePricing.BillingMode)
|
||||
require.NotEmpty(t, visiblePricing.BillingExpr)
|
||||
|
||||
emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"]
|
||||
require.True(t, ok)
|
||||
require.Empty(t, emptyExprPricing.BillingMode)
|
||||
require.Empty(t, emptyExprPricing.BillingExpr)
|
||||
|
||||
missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"]
|
||||
require.True(t, ok)
|
||||
require.Empty(t, missingExprPricing.BillingMode)
|
||||
require.Empty(t, missingExprPricing.BillingExpr)
|
||||
}
|
||||
|
||||
func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
|
||||
withSelfUseModeDisabled(t)
|
||||
withTieredBillingConfig(t, map[string]string{
|
||||
"zz-token-tiered-visible-model": "tiered_expr",
|
||||
"zz-token-tiered-empty-expr-model": "tiered_expr",
|
||||
"zz-token-tiered-missing-expr-model": "tiered_expr",
|
||||
}, map[string]string{
|
||||
"zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`,
|
||||
"zz-token-tiered-empty-expr-model": "",
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true)
|
||||
common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{
|
||||
"zz-token-tiered-visible-model": true,
|
||||
"zz-token-tiered-empty-expr-model": true,
|
||||
"zz-token-tiered-missing-expr-model": true,
|
||||
"zz-token-unpriced-model": true,
|
||||
})
|
||||
|
||||
ListModels(ctx, constant.ChannelTypeOpenAI)
|
||||
|
||||
ids := decodeListModelsResponse(t, recorder)
|
||||
require.Contains(t, ids, "zz-token-tiered-visible-model")
|
||||
require.NotContains(t, ids, "zz-token-tiered-empty-expr-model")
|
||||
require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
|
||||
require.NotContains(t, ids, "zz-token-unpriced-model")
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OidcResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oidcResponse OidcResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var oidcUser OidcUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &oidcUser, nil
|
||||
}
|
||||
|
||||
func OidcAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
OidcBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
err := user.FillUserByOidcId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
if oidcUser.PreferredUsername != "" {
|
||||
user.Username = oidcUser.PreferredUsername
|
||||
} else {
|
||||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if oidcUser.Name != "" {
|
||||
user.DisplayName = oidcUser.Name
|
||||
} else {
|
||||
user.DisplayName = "OIDC User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func OidcBind(c *gin.Context) {
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 OIDC 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -198,6 +198,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "theme.frontend":
|
||||
if option.Value != "default" && option.Value != "classic" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的主题值,可选值:default(新版前端)、classic(经典前端)",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = ratio_setting.CheckGroupRatio(option.Value.(string))
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
perfmetrics "github.com/QuantumNous/new-api/pkg/perf_metrics"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetPerfMetricsSummary(c *gin.Context) {
|
||||
hours := 24
|
||||
if rawHours := c.Query("hours"); rawHours != "" {
|
||||
if parsed, err := strconv.Atoi(rawHours); err == nil {
|
||||
hours = parsed
|
||||
}
|
||||
}
|
||||
|
||||
result, err := perfmetrics.QuerySummaryAll(hours)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
func GetPerfMetrics(c *gin.Context) {
|
||||
modelName := c.Query("model")
|
||||
if modelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "model is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
hours := 24
|
||||
if rawHours := c.Query("hours"); rawHours != "" {
|
||||
if parsed, err := strconv.Atoi(rawHours); err == nil {
|
||||
hours = parsed
|
||||
}
|
||||
}
|
||||
|
||||
result, err := perfmetrics.Query(perfmetrics.QueryParams{
|
||||
Model: modelName,
|
||||
Group: c.Query("group"),
|
||||
Hours: hours,
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRankings(c *gin.Context) {
|
||||
result, err := service.GetRankingsSnapshot(c.DefaultQuery("period", "week"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
+161
-46
@@ -21,14 +21,16 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/billing_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
defaultEndpoint = "/api/pricing"
|
||||
maxConcurrentFetches = 8
|
||||
maxRatioConfigBytes = 10 << 20 // 10MB
|
||||
floatEpsilon = 1e-9
|
||||
@@ -59,7 +61,29 @@ func valuesEqual(a, b interface{}) bool {
|
||||
return a == b
|
||||
}
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
var pricingSyncFields = []string{
|
||||
"model_ratio",
|
||||
"completion_ratio",
|
||||
"cache_ratio",
|
||||
"create_cache_ratio",
|
||||
"image_ratio",
|
||||
"audio_ratio",
|
||||
"audio_completion_ratio",
|
||||
"model_price",
|
||||
billing_setting.BillingModeField,
|
||||
billing_setting.BillingExprField,
|
||||
}
|
||||
|
||||
var numericPricingSyncFields = map[string]bool{
|
||||
"model_ratio": true,
|
||||
"completion_ratio": true,
|
||||
"cache_ratio": true,
|
||||
"create_cache_ratio": true,
|
||||
"image_ratio": true,
|
||||
"audio_ratio": true,
|
||||
"audio_completion_ratio": true,
|
||||
"model_price": true,
|
||||
}
|
||||
|
||||
type upstreamResult struct {
|
||||
Name string `json:"name"`
|
||||
@@ -67,6 +91,54 @@ type upstreamResult struct {
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
func valueMap(value any) map[string]any {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
return typed
|
||||
case map[string]float64:
|
||||
return lo.MapValues(typed, func(value float64, _ string) any { return value })
|
||||
case map[string]string:
|
||||
return lo.MapValues(typed, func(value string, _ string) any { return value })
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat64(value any) (float64, bool) {
|
||||
switch typed := value.(type) {
|
||||
case float64:
|
||||
return typed, true
|
||||
case float32:
|
||||
return float64(typed), true
|
||||
case int:
|
||||
return float64(typed), true
|
||||
case int64:
|
||||
return float64(typed), true
|
||||
case json.Number:
|
||||
parsed, err := typed.Float64()
|
||||
return parsed, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSyncValue(field string, value any) any {
|
||||
if numericPricingSyncFields[field] {
|
||||
if parsed, ok := asFloat64(value); ok {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func getLocalPricingSyncData() map[string]any {
|
||||
data := billing_setting.GetPricingSyncData(map[string]any(ratio_setting.GetExposedData()))
|
||||
data["image_ratio"] = ratio_setting.GetImageRatioCopy()
|
||||
data["audio_ratio"] = ratio_setting.GetAudioRatioCopy()
|
||||
data["audio_completion_ratio"] = ratio_setting.GetAudioCompletionRatioCopy()
|
||||
return data
|
||||
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -293,7 +365,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
if err := common.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
for _, rt := range pricingSyncFields {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
@@ -307,11 +379,18 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
CacheRatio *float64 `json:"cache_ratio"`
|
||||
CreateCacheRatio *float64 `json:"create_cache_ratio"`
|
||||
ImageRatio *float64 `json:"image_ratio"`
|
||||
AudioRatio *float64 `json:"audio_ratio"`
|
||||
AudioCompletionRatio *float64 `json:"audio_completion_ratio"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
BillingExpr string `json:"billing_expr"`
|
||||
}
|
||||
if err := common.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
@@ -321,9 +400,23 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
cacheRatioMap := make(map[string]float64)
|
||||
createCacheRatioMap := make(map[string]float64)
|
||||
imageRatioMap := make(map[string]float64)
|
||||
audioRatioMap := make(map[string]float64)
|
||||
audioCompletionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
billingModeMap := make(map[string]string)
|
||||
billingExprMap := make(map[string]string)
|
||||
|
||||
for _, item := range pricingItems {
|
||||
if item.ModelName == "" {
|
||||
continue
|
||||
}
|
||||
if item.BillingMode == billing_setting.BillingModeTieredExpr && strings.TrimSpace(item.BillingExpr) != "" {
|
||||
billingModeMap[item.ModelName] = billing_setting.BillingModeTieredExpr
|
||||
billingExprMap[item.ModelName] = item.BillingExpr
|
||||
}
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
@@ -331,6 +424,21 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
if item.CacheRatio != nil {
|
||||
cacheRatioMap[item.ModelName] = *item.CacheRatio
|
||||
}
|
||||
if item.CreateCacheRatio != nil {
|
||||
createCacheRatioMap[item.ModelName] = *item.CreateCacheRatio
|
||||
}
|
||||
if item.ImageRatio != nil {
|
||||
imageRatioMap[item.ModelName] = *item.ImageRatio
|
||||
}
|
||||
if item.AudioRatio != nil {
|
||||
audioRatioMap[item.ModelName] = *item.AudioRatio
|
||||
}
|
||||
if item.AudioCompletionRatio != nil {
|
||||
audioCompletionRatioMap[item.ModelName] = *item.AudioCompletionRatio
|
||||
}
|
||||
}
|
||||
|
||||
converted := make(map[string]any)
|
||||
@@ -350,6 +458,21 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
if len(cacheRatioMap) > 0 {
|
||||
converted["cache_ratio"] = valueMap(cacheRatioMap)
|
||||
}
|
||||
if len(createCacheRatioMap) > 0 {
|
||||
converted["create_cache_ratio"] = valueMap(createCacheRatioMap)
|
||||
}
|
||||
if len(imageRatioMap) > 0 {
|
||||
converted["image_ratio"] = valueMap(imageRatioMap)
|
||||
}
|
||||
if len(audioRatioMap) > 0 {
|
||||
converted["audio_ratio"] = valueMap(audioRatioMap)
|
||||
}
|
||||
if len(audioCompletionRatioMap) > 0 {
|
||||
converted["audio_completion_ratio"] = valueMap(audioCompletionRatioMap)
|
||||
}
|
||||
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
@@ -358,6 +481,12 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
if len(billingModeMap) > 0 {
|
||||
converted[billing_setting.BillingModeField] = valueMap(billingModeMap)
|
||||
}
|
||||
if len(billingExprMap) > 0 {
|
||||
converted[billing_setting.BillingExprField] = valueMap(billingExprMap)
|
||||
}
|
||||
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
@@ -366,7 +495,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
localData := getLocalPricingSyncData()
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
@@ -412,22 +541,16 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, field := range pricingSyncFields {
|
||||
for modelName := range valueMap(localData[field]) {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
for _, field := range pricingSyncFields {
|
||||
for modelName := range valueMap(channel.data[field]) {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -438,10 +561,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
modelRatios := valueMap(channel.data["model_ratio"])
|
||||
completionRatios := valueMap(channel.data["completion_ratio"])
|
||||
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
if len(modelRatios) > 0 && len(completionRatios) > 0 {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
@@ -451,12 +574,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
modelRatioFloat, modelRatioOK := asFloat64(modelRatioVal)
|
||||
completionRatioFloat, completionRatioOK := asFloat64(completionRatioVal)
|
||||
if modelRatioOK && completionRatioOK && nearlyEqual(modelRatioFloat, 37.5) && nearlyEqual(completionRatioFloat, 1.0) {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -470,14 +591,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
for _, ratioType := range pricingSyncFields {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
if val, exists := valueMap(localData[ratioType])[modelName]; exists {
|
||||
localValue = normalizeSyncValue(ratioType, val)
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
@@ -488,16 +605,14 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
if val, exists := valueMap(channel.data[ratioType])[modelName]; exists {
|
||||
upstreamValue = normalizeSyncValue(ratioType, val)
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && !valuesEqual(localValue, val) {
|
||||
hasDifference = true
|
||||
} else if valuesEqual(localValue, val) {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
if localValue != nil && !valuesEqual(localValue, upstreamValue) {
|
||||
hasDifference = true
|
||||
} else if valuesEqual(localValue, upstreamValue) {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
perfmetrics "github.com/QuantumNous/new-api/pkg/perf_metrics"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
@@ -239,6 +240,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
logger.LogInfo(c, retryLogStr)
|
||||
}
|
||||
if newAPIError != nil {
|
||||
gopool.Go(func() {
|
||||
perfmetrics.RecordRelaySample(relayInfo, false, 0)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
|
||||
@@ -83,13 +83,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) {
|
||||
|
||||
// create pending order first
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
PaymentProvider: model.PaymentProviderCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
|
||||
@@ -82,13 +82,14 @@ func SubscriptionRequestEpay(c *gin.Context) {
|
||||
}
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
PaymentProvider: model.PaymentProviderEpay,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
common.ApiErrorMsg(c, "创建订单失败")
|
||||
@@ -104,7 +105,7 @@ func SubscriptionRequestEpay(c *gin.Context) {
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
if err != nil {
|
||||
_ = model.ExpireSubscriptionOrder(tradeNo, req.PaymentMethod)
|
||||
_ = model.ExpireSubscriptionOrder(tradeNo, model.PaymentProviderEpay)
|
||||
common.ApiErrorMsg(c, "拉起支付失败")
|
||||
return
|
||||
}
|
||||
@@ -156,7 +157,7 @@ func SubscriptionEpayNotify(c *gin.Context) {
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
|
||||
_, _ = c.Writer.Write([]byte("fail"))
|
||||
return
|
||||
}
|
||||
@@ -205,7 +206,7 @@ func SubscriptionEpayReturn(c *gin.Context) {
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil {
|
||||
if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil {
|
||||
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,13 +84,14 @@ func SubscriptionRequestStripePay(c *gin.Context) {
|
||||
}
|
||||
|
||||
order := &model.SubscriptionOrder{
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: userId,
|
||||
PlanId: plan.Id,
|
||||
Money: plan.PriceAmount,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
PaymentProvider: model.PaymentProviderStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := order.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
info := &relaycommon.RelayInfo{}
|
||||
info.ChannelMeta = &relaycommon.ChannelMeta{
|
||||
ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
|
||||
}
|
||||
info.ApiKey = cacheGetChannel.Key
|
||||
adaptor.Init(info)
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
key := channel.Key
|
||||
|
||||
privateData := task.PrivateData
|
||||
if privateData.Key != "" {
|
||||
key = privateData.Key
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if resp.StatusCode != http.StatusOK {
|
||||
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||
//}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
|
||||
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
taskResult.Url = t.FailReason
|
||||
taskResult.Progress = t.Progress
|
||||
taskResult.Reason = t.FailReason
|
||||
task.Data = t.Data
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
//return fmt.Errorf("task %s status is empty", taskId)
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
|
||||
// 记录原本的状态,防止重复退款
|
||||
shouldRefund := false
|
||||
quota := task.Quota
|
||||
preStatus := task.Status
|
||||
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = "10%"
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = "20%"
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = "30%"
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||
task.FailReason = taskResult.Url
|
||||
}
|
||||
|
||||
// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
|
||||
if taskResult.TotalTokens > 0 {
|
||||
// 获取模型名称
|
||||
var taskData map[string]interface{}
|
||||
if err := json.Unmarshal(task.Data, &taskData); err == nil {
|
||||
if modelName, ok := taskData["model"].(string); ok && modelName != "" {
|
||||
// 获取模型价格和倍率
|
||||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
||||
if hasRatioSetting && modelRatio > 0 {
|
||||
// 获取用户和组的倍率信息
|
||||
group := task.Group
|
||||
if group == "" {
|
||||
user, err := model.GetUserById(task.UserId, false)
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
}
|
||||
}
|
||||
if group != "" {
|
||||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
|
||||
|
||||
var finalGroupRatio float64
|
||||
if hasUserGroupRatio {
|
||||
finalGroupRatio = userGroupRatio
|
||||
} else {
|
||||
finalGroupRatio = groupRatio
|
||||
}
|
||||
|
||||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||||
actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
|
||||
|
||||
// 计算差额
|
||||
preConsumedQuota := task.Quota
|
||||
quotaDelta := actualQuota - preConsumedQuota
|
||||
|
||||
if quotaDelta > 0 {
|
||||
// 需要补扣费
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(quotaDelta),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
taskResult.TotalTokens,
|
||||
))
|
||||
if err := model.DecreaseUserQuota(task.UserId, quotaDelta, false); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
||||
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
||||
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||
|
||||
// 记录消费日志
|
||||
logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
|
||||
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
} else if quotaDelta < 0 {
|
||||
// 需要退还多扣的费用
|
||||
refundQuota := -quotaDelta
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(refundQuota),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
taskResult.TotalTokens,
|
||||
))
|
||||
if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
|
||||
} else {
|
||||
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||
|
||||
// 记录退款日志
|
||||
logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
|
||||
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
} else {
|
||||
// quotaDelta == 0, 预扣费刚好准确
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
|
||||
task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
taskResult.Progress = "100%"
|
||||
if quota != 0 {
|
||||
if preStatus != model.TaskStatusFailure {
|
||||
shouldRefund = true
|
||||
} else {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||
shouldRefund = false
|
||||
}
|
||||
|
||||
if shouldRefund {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactVideoResponseBody(body []byte) []byte {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return body
|
||||
}
|
||||
resp, _ := m["response"].(map[string]any)
|
||||
if resp != nil {
|
||||
delete(resp, "bytesBase64Encoded")
|
||||
if v, ok := resp["video"].(string); ok {
|
||||
resp["video"] = truncateBase64(v)
|
||||
}
|
||||
if vs, ok := resp["videos"].([]any); ok {
|
||||
for i := range vs {
|
||||
if vm, ok := vs[i].(map[string]any); ok {
|
||||
delete(vm, "bytesBase64Encoded")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func truncateBase64(s string) string {
|
||||
const maxKeep = 256
|
||||
if len(s) <= maxKeep {
|
||||
return s
|
||||
}
|
||||
return s[:maxKeep] + "..."
|
||||
}
|
||||
+15
-24
@@ -110,6 +110,7 @@ func GetTopUpInfo(c *gin.Context) {
|
||||
"waffo_pancake_min_topup": setting.WaffoPancakeMinTopUp,
|
||||
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||
"topup_link": common.TopUpLink,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
@@ -123,17 +124,6 @@ type AmountRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
}
|
||||
|
||||
var nonEpayPaymentMethodsForCallback = []string{
|
||||
model.PaymentMethodStripe,
|
||||
model.PaymentMethodCreem,
|
||||
model.PaymentMethodWaffo,
|
||||
model.PaymentMethodWaffoPancake,
|
||||
}
|
||||
|
||||
func isNonEpayPaymentMethodForEpayCallback(paymentMethod string) bool {
|
||||
return lo.Contains(nonEpayPaymentMethodsForCallback, paymentMethod)
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
|
||||
return nil
|
||||
@@ -248,13 +238,14 @@ func RequestEpay(c *gin.Context) {
|
||||
amount = dAmount.Div(dQuotaPerUnit).IntPart()
|
||||
}
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: req.PaymentMethod,
|
||||
PaymentProvider: model.PaymentProviderEpay,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
@@ -379,15 +370,15 @@ func EpayNotify(c *gin.Context) {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo)))
|
||||
return
|
||||
}
|
||||
if isNonEpayPaymentMethodForEpayCallback(topUp.PaymentMethod) {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
if topUp.PaymentMethod != verifyInfo.Type {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
if topUp.PaymentProvider != model.PaymentProviderEpay {
|
||||
logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付网关不匹配 trade_no=%s order_provider=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentProvider, verifyInfo.Type, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
if topUp.Status == common.TopUpStatusPending {
|
||||
if topUp.PaymentMethod != verifyInfo.Type {
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 实际支付方式与订单不同 trade_no=%s order_payment_method=%s actual_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP()))
|
||||
topUp.PaymentMethod = verifyInfo.Type
|
||||
}
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
|
||||
@@ -106,13 +106,14 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
|
||||
|
||||
// 先创建订单记录,使用产品配置的金额和充值额度
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: selectedProduct.Quota, // 充值额度
|
||||
Money: selectedProduct.Price, // 支付金额
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: selectedProduct.Quota, // 充值额度
|
||||
Money: selectedProduct.Price, // 支付金额
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodCreem,
|
||||
PaymentProvider: model.PaymentProviderCreem,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
@@ -301,7 +302,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
|
||||
// Try complete subscription order first
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentMethodCreem); err == nil {
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentProviderCreem, ""); err == nil {
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id))
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
)
|
||||
|
||||
func TestIsNonEpayPaymentMethodForEpayCallback(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
paymentMethod string
|
||||
expectedBlocked bool
|
||||
}{
|
||||
{name: "stripe", paymentMethod: model.PaymentMethodStripe, expectedBlocked: true},
|
||||
{name: "creem", paymentMethod: model.PaymentMethodCreem, expectedBlocked: true},
|
||||
{name: "waffo", paymentMethod: model.PaymentMethodWaffo, expectedBlocked: true},
|
||||
{name: "waffo pancake", paymentMethod: model.PaymentMethodWaffoPancake, expectedBlocked: true},
|
||||
{name: "alipay", paymentMethod: "alipay", expectedBlocked: false},
|
||||
{name: "wxpay", paymentMethod: "wxpay", expectedBlocked: false},
|
||||
{name: "custom epay type", paymentMethod: "custom1", expectedBlocked: false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if actual := isNonEpayPaymentMethodForEpayCallback(tc.paymentMethod); actual != tc.expectedBlocked {
|
||||
t.Fatalf("expected blocked=%v, got %v for payment method %q", tc.expectedBlocked, actual, tc.paymentMethod)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+13
-12
@@ -101,13 +101,14 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
PaymentMethod: model.PaymentMethodStripe,
|
||||
PaymentProvider: model.PaymentProviderStripe,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
@@ -237,8 +238,8 @@ func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != model.PaymentMethodStripe {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", referenceId, topUp.PaymentMethod, callerIp))
|
||||
if topUp.PaymentProvider != model.PaymentProviderStripe {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付网关不匹配 trade_no=%s payment_provider=%s client_ip=%s", referenceId, topUp.PaymentProvider, callerIp))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -270,7 +271,7 @@ func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, c
|
||||
"currency": strings.ToUpper(event.GetObjectValue("currency")),
|
||||
"event_type": string(event.Type),
|
||||
}
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentMethodStripe); err == nil {
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentProviderStripe, ""); err == nil {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp))
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
@@ -305,7 +306,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
|
||||
// Subscription order expiration
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentMethodStripe); err == nil {
|
||||
if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentProviderStripe); err == nil {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId))
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
@@ -313,7 +314,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) {
|
||||
return
|
||||
}
|
||||
|
||||
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentMethodStripe, common.TopUpStatusExpired)
|
||||
err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentProviderStripe, common.TopUpStatusExpired)
|
||||
if errors.Is(err, model.ErrTopUpNotFound) {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId))
|
||||
return
|
||||
|
||||
@@ -208,13 +208,14 @@ func RequestWaffoPay(c *gin.Context) {
|
||||
|
||||
// 创建本地订单
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: model.PaymentMethodWaffo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: merchantOrderId,
|
||||
PaymentMethod: model.PaymentMethodWaffo,
|
||||
PaymentProvider: model.PaymentProviderWaffo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error()))
|
||||
@@ -379,7 +380,7 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa
|
||||
logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP()))
|
||||
// 终态失败订单标记为 failed,避免永远停在 pending
|
||||
if result.MerchantOrderID != "" {
|
||||
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentMethodWaffo, common.TopUpStatusFailed); err != nil &&
|
||||
if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentProviderWaffo, common.TopUpStatusFailed); err != nil &&
|
||||
!errors.Is(err, model.ErrTopUpNotFound) &&
|
||||
!errors.Is(err, model.ErrTopUpStatusInvalid) {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error()))
|
||||
|
||||
@@ -159,13 +159,14 @@ func RequestWaffoPancakePay(c *gin.Context) {
|
||||
|
||||
tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6))
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
UserId: id,
|
||||
Amount: normalizeWaffoPancakeTopUpAmount(req.Amount),
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: model.PaymentMethodWaffoPancake,
|
||||
PaymentProvider: model.PaymentProviderWaffoPancake,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
if err := topUp.Insert(); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error()))
|
||||
|
||||
@@ -91,6 +91,7 @@ func Login(c *gin.Context) {
|
||||
|
||||
// setup session & cookies and then return user info
|
||||
func setupLogin(user *model.User, c *gin.Context) {
|
||||
model.UpdateUserLastLoginAt(user.Id)
|
||||
session := sessions.Default(c)
|
||||
session.Set("id", user.Id)
|
||||
session.Set("username", user.Username)
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
# Frontend Development - Backend built from local source
|
||||
#
|
||||
# Usage:
|
||||
# 1. docker compose -f docker-compose.dev.yml up -d
|
||||
# 2. cd web && bun install && bun run dev
|
||||
# 3. Open http://localhost:3001 (Rsbuild dev server, API auto-proxied to :3000)
|
||||
#
|
||||
# Rebuild backend after Go code changes:
|
||||
# docker compose -f docker-compose.dev.yml up -d --build new-api
|
||||
#
|
||||
# Stop:
|
||||
# docker compose -f docker-compose.dev.yml down
|
||||
#
|
||||
# Reset data:
|
||||
# docker compose -f docker-compose.dev.yml down -v
|
||||
|
||||
services:
|
||||
new-api:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.dev
|
||||
image: new-api-dev:local
|
||||
container_name: new-api-dev
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- dev_data:/data
|
||||
environment:
|
||||
- SQL_DSN=postgresql://root:123456@postgres:5432/new-api
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- TZ=Asia/Shanghai
|
||||
- BATCH_UPDATE_ENABLED=true
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_started
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- dev-network
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: new-api-dev-redis
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- dev-network
|
||||
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
container_name: new-api-dev-pg
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
POSTGRES_USER: root
|
||||
POSTGRES_PASSWORD: 123456
|
||||
POSTGRES_DB: new-api
|
||||
volumes:
|
||||
- dev_pg_data:/var/lib/postgresql/data
|
||||
networks:
|
||||
- dev-network
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U root -d new-api"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
volumes:
|
||||
dev_data:
|
||||
dev_pg_data:
|
||||
|
||||
networks:
|
||||
dev-network:
|
||||
driver: bridge
|
||||
+4
-1
@@ -27,7 +27,10 @@ type ImageRequest struct {
|
||||
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||
// Stream bool `json:"stream,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
Images json.RawMessage `json:"images,omitempty"`
|
||||
Mask json.RawMessage `json:"mask,omitempty"`
|
||||
InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
// zhipu 4v
|
||||
WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"`
|
||||
UserId json.RawMessage `json:"user_id,omitempty"`
|
||||
|
||||
+12
-2
@@ -279,8 +279,8 @@ type Message struct {
|
||||
Content any `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Prefix *bool `json:"prefix,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Reasoning string `json:"reasoning,omitempty"`
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
Reasoning *string `json:"reasoning,omitempty"`
|
||||
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
parsedContent []MediaContent
|
||||
@@ -431,6 +431,16 @@ const (
|
||||
//ContentTypeAudioUrl = "audio_url"
|
||||
)
|
||||
|
||||
func (m *Message) GetReasoningContent() string {
|
||||
if m.ReasoningContent == nil && m.Reasoning == nil {
|
||||
return ""
|
||||
}
|
||||
if m.ReasoningContent != nil {
|
||||
return *m.ReasoningContent
|
||||
}
|
||||
return *m.Reasoning
|
||||
}
|
||||
|
||||
func (m *Message) GetPrefix() bool {
|
||||
if m.Prefix == nil {
|
||||
return false
|
||||
|
||||
+15
-1
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
@@ -346,7 +347,20 @@ type ResponsesOutput struct {
|
||||
Size string `json:"size"`
|
||||
CallId string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
Arguments json.RawMessage `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// ArgumentsString returns function call arguments in the string form expected by Chat Completions.
|
||||
func (r *ResponsesOutput) ArgumentsString() string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
return ResponsesArgumentsString(r.Arguments)
|
||||
}
|
||||
|
||||
// ResponsesArgumentsString returns function call arguments in the string form expected by Chat Completions.
|
||||
func ResponsesArgumentsString(arguments json.RawMessage) string {
|
||||
return common.JsonRawMessageToString(arguments)
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
|
||||
+13
-12
@@ -304,18 +304,19 @@ const (
|
||||
|
||||
// Distributor related messages
|
||||
const (
|
||||
MsgDistributorInvalidRequest = "distributor.invalid_request"
|
||||
MsgDistributorInvalidChannelId = "distributor.invalid_channel_id"
|
||||
MsgDistributorChannelDisabled = "distributor.channel_disabled"
|
||||
MsgDistributorTokenNoModelAccess = "distributor.token_no_model_access"
|
||||
MsgDistributorTokenModelForbidden = "distributor.token_model_forbidden"
|
||||
MsgDistributorModelNameRequired = "distributor.model_name_required"
|
||||
MsgDistributorInvalidPlayground = "distributor.invalid_playground_request"
|
||||
MsgDistributorGroupAccessDenied = "distributor.group_access_denied"
|
||||
MsgDistributorGetChannelFailed = "distributor.get_channel_failed"
|
||||
MsgDistributorNoAvailableChannel = "distributor.no_available_channel"
|
||||
MsgDistributorInvalidMidjourney = "distributor.invalid_midjourney_request"
|
||||
MsgDistributorInvalidParseModel = "distributor.invalid_request_parse_model"
|
||||
MsgDistributorInvalidRequest = "distributor.invalid_request"
|
||||
MsgDistributorInvalidChannelId = "distributor.invalid_channel_id"
|
||||
MsgDistributorChannelDisabled = "distributor.channel_disabled"
|
||||
MsgDistributorAffinityChannelDisabled = "distributor.affinity_channel_disabled"
|
||||
MsgDistributorTokenNoModelAccess = "distributor.token_no_model_access"
|
||||
MsgDistributorTokenModelForbidden = "distributor.token_model_forbidden"
|
||||
MsgDistributorModelNameRequired = "distributor.model_name_required"
|
||||
MsgDistributorInvalidPlayground = "distributor.invalid_playground_request"
|
||||
MsgDistributorGroupAccessDenied = "distributor.group_access_denied"
|
||||
MsgDistributorGetChannelFailed = "distributor.get_channel_failed"
|
||||
MsgDistributorNoAvailableChannel = "distributor.no_available_channel"
|
||||
MsgDistributorInvalidMidjourney = "distributor.invalid_midjourney_request"
|
||||
MsgDistributorInvalidParseModel = "distributor.invalid_request_parse_model"
|
||||
)
|
||||
|
||||
// Custom OAuth provider related messages
|
||||
|
||||
@@ -257,6 +257,7 @@ common.invalid_input: "Invalid input"
|
||||
distributor.invalid_request: "Invalid request: {{.Error}}"
|
||||
distributor.invalid_channel_id: "Invalid channel ID"
|
||||
distributor.channel_disabled: "This channel has been disabled"
|
||||
distributor.affinity_channel_disabled: "The channel selected by channel affinity has been disabled, and retry was stopped by rule. Please contact the administrator"
|
||||
distributor.token_no_model_access: "This token has no access to any models"
|
||||
distributor.token_model_forbidden: "This token has no access to model {{.Model}}"
|
||||
distributor.model_name_required: "Model name not specified, model name cannot be empty"
|
||||
|
||||
@@ -258,6 +258,7 @@ common.invalid_input: "输入不合法"
|
||||
distributor.invalid_request: "无效的请求,{{.Error}}"
|
||||
distributor.invalid_channel_id: "无效的渠道 Id"
|
||||
distributor.channel_disabled: "该渠道已被禁用"
|
||||
distributor.affinity_channel_disabled: "渠道亲和性命中的渠道已被禁用,已按规则停止重试,请联系管理员处理"
|
||||
distributor.token_no_model_access: "该令牌无权访问任何模型"
|
||||
distributor.token_model_forbidden: "该令牌无权访问模型 {{.Model}}"
|
||||
distributor.model_name_required: "未指定模型名称,模型名称不能为空"
|
||||
|
||||
@@ -258,6 +258,7 @@ common.invalid_input: "輸入不合法"
|
||||
distributor.invalid_request: "無效的請求,{{.Error}}"
|
||||
distributor.invalid_channel_id: "無效的管道 Id"
|
||||
distributor.channel_disabled: "該管道已被禁用"
|
||||
distributor.affinity_channel_disabled: "管道親和性命中的管道已被禁用,已按規則停止重試,請聯絡管理員處理"
|
||||
distributor.token_no_model_access: "該令牌無權存取任何模型"
|
||||
distributor.token_model_forbidden: "該令牌無權存取模型 {{.Model}}"
|
||||
distributor.model_name_required: "未指定模型名稱,模型名稱不能為空"
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
perfmetrics "github.com/QuantumNous/new-api/pkg/perf_metrics"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
"github.com/QuantumNous/new-api/router"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
@@ -34,12 +35,18 @@ import (
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
//go:embed web/dist
|
||||
//go:embed web/default/dist
|
||||
var buildFS embed.FS
|
||||
|
||||
//go:embed web/dist/index.html
|
||||
//go:embed web/default/dist/index.html
|
||||
var indexPage []byte
|
||||
|
||||
//go:embed web/classic/dist
|
||||
var classicBuildFS embed.FS
|
||||
|
||||
//go:embed web/classic/dist/index.html
|
||||
var classicIndexPage []byte
|
||||
|
||||
func main() {
|
||||
startTime := time.Now()
|
||||
|
||||
@@ -183,7 +190,12 @@ func main() {
|
||||
InjectGoogleAnalytics()
|
||||
|
||||
// 设置路由
|
||||
router.SetRouter(server, buildFS, indexPage)
|
||||
router.SetRouter(server, router.ThemeAssets{
|
||||
DefaultBuildFS: buildFS,
|
||||
DefaultIndexPage: indexPage,
|
||||
ClassicBuildFS: classicBuildFS,
|
||||
ClassicIndexPage: classicIndexPage,
|
||||
})
|
||||
var port = os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = strconv.Itoa(*common.Port)
|
||||
@@ -213,8 +225,10 @@ func InjectUmamiAnalytics() {
|
||||
analyticsInjectBuilder.WriteString("\"></script>")
|
||||
}
|
||||
analyticsInjectBuilder.WriteString("<!--Umami QuantumNous-->\n")
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
|
||||
analyticsInject := []byte(analyticsInjectBuilder.String())
|
||||
placeholder := []byte("<!--umami-->\n")
|
||||
indexPage = bytes.ReplaceAll(indexPage, placeholder, analyticsInject)
|
||||
classicIndexPage = bytes.ReplaceAll(classicIndexPage, placeholder, analyticsInject)
|
||||
}
|
||||
|
||||
func InjectGoogleAnalytics() {
|
||||
@@ -235,8 +249,10 @@ func InjectGoogleAnalytics() {
|
||||
analyticsInjectBuilder.WriteString("</script>")
|
||||
}
|
||||
analyticsInjectBuilder.WriteString("<!--Google Analytics QuantumNous-->\n")
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
|
||||
analyticsInject := []byte(analyticsInjectBuilder.String())
|
||||
placeholder := []byte("<!--Google Analytics-->\n")
|
||||
indexPage = bytes.ReplaceAll(indexPage, placeholder, analyticsInject)
|
||||
classicIndexPage = bytes.ReplaceAll(classicIndexPage, placeholder, analyticsInject)
|
||||
}
|
||||
|
||||
func InitResources() error {
|
||||
@@ -291,6 +307,8 @@ func InitResources() error {
|
||||
return err
|
||||
}
|
||||
|
||||
perfmetrics.Init()
|
||||
|
||||
// 启动系统监控
|
||||
common.StartSystemMonitor()
|
||||
|
||||
|
||||
@@ -1,14 +1,35 @@
|
||||
FRONTEND_DIR = ./web
|
||||
FRONTEND_DIR = ./web/default
|
||||
FRONTEND_CLASSIC_DIR = ./web/classic
|
||||
BACKEND_DIR = .
|
||||
|
||||
.PHONY: all build-frontend start-backend
|
||||
.PHONY: all build-frontend build-frontend-classic build-all-frontends start-backend dev dev-api dev-web dev-web-classic
|
||||
|
||||
all: build-frontend start-backend
|
||||
all: build-all-frontends start-backend
|
||||
|
||||
build-frontend:
|
||||
@echo "Building frontend..."
|
||||
@cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
@echo "Building default frontend..."
|
||||
@cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
|
||||
build-frontend-classic:
|
||||
@echo "Building classic frontend..."
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && bun install && VITE_REACT_APP_VERSION=$(cat ../../VERSION) bun run build
|
||||
|
||||
build-all-frontends: build-frontend build-frontend-classic
|
||||
|
||||
start-backend:
|
||||
@echo "Starting backend dev server..."
|
||||
@cd $(BACKEND_DIR) && go run main.go &
|
||||
|
||||
dev-api:
|
||||
@echo "Starting backend services (docker)..."
|
||||
@docker compose -f docker-compose.dev.yml up -d
|
||||
|
||||
dev-web:
|
||||
@echo "Starting frontend dev server..."
|
||||
@cd $(FRONTEND_DIR) && bun install && bun run dev
|
||||
|
||||
dev-web-classic:
|
||||
@echo "Starting classic frontend dev server..."
|
||||
@cd $(FRONTEND_CLASSIC_DIR) && bun install && bun run dev
|
||||
|
||||
dev: dev-api dev-web
|
||||
|
||||
@@ -104,7 +104,7 @@ func Distribute() func(c *gin.Context) {
|
||||
if err == nil && preferred != nil {
|
||||
if preferred.Status != common.ChannelStatusEnabled {
|
||||
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled))
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorAffinityChannelDisabled))
|
||||
return
|
||||
}
|
||||
} else if usingGroup == "auto" {
|
||||
|
||||
+71
-19
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
@@ -67,6 +68,66 @@ type ChannelInfo struct {
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
type ChannelSortOptions struct {
|
||||
SortBy string
|
||||
SortOrder string
|
||||
IDSort bool
|
||||
}
|
||||
|
||||
var channelSortColumns = map[string]string{
|
||||
"id": "id",
|
||||
"name": "name",
|
||||
"priority": "priority",
|
||||
"balance": "balance",
|
||||
"response_time": "response_time",
|
||||
"test_time": "test_time",
|
||||
}
|
||||
|
||||
func NewChannelSortOptions(sortBy string, sortOrder string, idSort bool) ChannelSortOptions {
|
||||
normalizedSortBy := strings.ToLower(strings.TrimSpace(sortBy))
|
||||
normalizedSortOrder := strings.ToLower(strings.TrimSpace(sortOrder))
|
||||
if _, ok := channelSortColumns[normalizedSortBy]; !ok {
|
||||
normalizedSortBy = ""
|
||||
normalizedSortOrder = ""
|
||||
} else if normalizedSortOrder != "asc" {
|
||||
normalizedSortOrder = "desc"
|
||||
}
|
||||
|
||||
return ChannelSortOptions{
|
||||
SortBy: normalizedSortBy,
|
||||
SortOrder: normalizedSortOrder,
|
||||
IDSort: idSort,
|
||||
}
|
||||
}
|
||||
|
||||
func (options ChannelSortOptions) Apply(query *gorm.DB) *gorm.DB {
|
||||
if columnName, ok := channelSortColumns[options.SortBy]; ok {
|
||||
return query.Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Name: columnName},
|
||||
Desc: options.SortOrder != "asc",
|
||||
})
|
||||
}
|
||||
if options.IDSort {
|
||||
return query.Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Name: "id"},
|
||||
Desc: true,
|
||||
})
|
||||
}
|
||||
return query.Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Name: "priority"},
|
||||
Desc: true,
|
||||
})
|
||||
}
|
||||
|
||||
func resolveChannelSortOptions(idSort bool, sortOptions []ChannelSortOptions) ChannelSortOptions {
|
||||
if len(sortOptions) == 0 {
|
||||
return NewChannelSortOptions("", "", idSort)
|
||||
}
|
||||
options := sortOptions[0]
|
||||
options.IDSort = options.IDSort || idSort
|
||||
return options
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer interface
|
||||
func (c ChannelInfo) Value() (driver.Value, error) {
|
||||
return common.Marshal(&c)
|
||||
@@ -260,28 +321,22 @@ func (channel *Channel) SaveWithoutKey() error {
|
||||
return DB.Omit("key").Save(channel).Error
|
||||
}
|
||||
|
||||
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
||||
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool, sortOptions ...ChannelSortOptions) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
var err error
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
order := resolveChannelSortOptions(idSort, sortOptions)
|
||||
if selectAll {
|
||||
err = DB.Order(order).Find(&channels).Error
|
||||
err = order.Apply(DB).Find(&channels).Error
|
||||
} else {
|
||||
err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||
err = order.Apply(DB).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||
}
|
||||
return channels, err
|
||||
}
|
||||
|
||||
func GetChannelsByTag(tag string, idSort bool, selectAll bool) ([]*Channel, error) {
|
||||
func GetChannelsByTag(tag string, idSort bool, selectAll bool, sortOptions ...ChannelSortOptions) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
query := DB.Where("tag = ?", tag).Order(order)
|
||||
order := resolveChannelSortOptions(idSort, sortOptions)
|
||||
query := order.Apply(DB.Where("tag = ?", tag))
|
||||
if !selectAll {
|
||||
query = query.Omit("key")
|
||||
}
|
||||
@@ -289,7 +344,7 @@ func GetChannelsByTag(tag string, idSort bool, selectAll bool) ([]*Channel, erro
|
||||
return channels, err
|
||||
}
|
||||
|
||||
func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
|
||||
func SearchChannels(keyword string, group string, model string, idSort bool, sortOptions ...ChannelSortOptions) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
modelsCol := "`models`"
|
||||
|
||||
@@ -304,10 +359,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
baseURLCol = `"base_url"`
|
||||
}
|
||||
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
order := resolveChannelSortOptions(idSort, sortOptions)
|
||||
|
||||
// 构造基础查询
|
||||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||
@@ -331,7 +383,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error
|
||||
err := order.Apply(baseQuery.Where(whereClause, args...)).Find(&channels).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -280,6 +280,7 @@ func migrateDB() error {
|
||||
&SubscriptionPreConsumeRecord{},
|
||||
&CustomOAuthProvider{},
|
||||
&UserOAuthBinding{},
|
||||
&PerfMetric{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -328,6 +329,7 @@ func migrateDBFast() error {
|
||||
{&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"},
|
||||
{&CustomOAuthProvider{}, "CustomOAuthProvider"},
|
||||
{&UserOAuthBinding{}, "UserOAuthBinding"},
|
||||
{&PerfMetric{}, "PerfMetric"},
|
||||
}
|
||||
// 动态计算migration数量,确保errChan缓冲区足够大
|
||||
errChan := make(chan error, len(migrations))
|
||||
|
||||
@@ -578,6 +578,11 @@ func handleConfigUpdate(key, value string) bool {
|
||||
performance_setting.UpdateAndSync()
|
||||
} else if configName == "tool_price_setting" {
|
||||
operation_setting.RebuildToolPriceIndex()
|
||||
} else if configName == "billing_setting" {
|
||||
InvalidatePricingCache()
|
||||
ratio_setting.InvalidateExposedDataCache()
|
||||
} else if configName == "theme" {
|
||||
system_setting.UpdateAndSyncTheme()
|
||||
}
|
||||
|
||||
return true // 已处理
|
||||
|
||||
@@ -36,30 +36,32 @@ func insertSubscriptionPlanForPaymentGuardTest(t *testing.T, id int) *Subscripti
|
||||
return plan
|
||||
}
|
||||
|
||||
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentMethod string) {
|
||||
func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentProvider string) {
|
||||
t.Helper()
|
||||
order := &SubscriptionOrder{
|
||||
UserId: userID,
|
||||
PlanId: planID,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentMethod,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
UserId: userID,
|
||||
PlanId: planID,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentProvider,
|
||||
PaymentProvider: paymentProvider,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
}
|
||||
require.NoError(t, order.Insert())
|
||||
}
|
||||
|
||||
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentMethod string) {
|
||||
func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentProvider string) {
|
||||
t.Helper()
|
||||
topUp := &TopUp{
|
||||
UserId: userID,
|
||||
Amount: 2,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentMethod,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
UserId: userID,
|
||||
Amount: 2,
|
||||
Money: 9.99,
|
||||
TradeNo: tradeNo,
|
||||
PaymentMethod: paymentProvider,
|
||||
PaymentProvider: paymentProvider,
|
||||
Status: common.TopUpStatusPending,
|
||||
CreateTime: time.Now().Unix(),
|
||||
}
|
||||
require.NoError(t, topUp.Insert())
|
||||
}
|
||||
@@ -89,7 +91,7 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 101, 0)
|
||||
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentMethodStripe)
|
||||
insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentProviderStripe)
|
||||
|
||||
err := RechargeWaffoPancake("waffo-pancake-guard")
|
||||
require.Error(t, err)
|
||||
@@ -100,27 +102,27 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101))
|
||||
}
|
||||
|
||||
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
tradeNo string
|
||||
storedPaymentMethod string
|
||||
expectedPaymentMethod string
|
||||
targetStatus string
|
||||
name string
|
||||
tradeNo string
|
||||
storedPaymentProvider string
|
||||
expectedPaymentProvider string
|
||||
targetStatus string
|
||||
}{
|
||||
{
|
||||
name: "stripe expire",
|
||||
tradeNo: "stripe-expire-guard",
|
||||
storedPaymentMethod: PaymentMethodCreem,
|
||||
expectedPaymentMethod: PaymentMethodStripe,
|
||||
targetStatus: common.TopUpStatusExpired,
|
||||
name: "stripe expire",
|
||||
tradeNo: "stripe-expire-guard",
|
||||
storedPaymentProvider: PaymentProviderCreem,
|
||||
expectedPaymentProvider: PaymentProviderStripe,
|
||||
targetStatus: common.TopUpStatusExpired,
|
||||
},
|
||||
{
|
||||
name: "waffo failed",
|
||||
tradeNo: "waffo-failed-guard",
|
||||
storedPaymentMethod: PaymentMethodStripe,
|
||||
expectedPaymentMethod: PaymentMethodWaffo,
|
||||
targetStatus: common.TopUpStatusFailed,
|
||||
name: "waffo failed",
|
||||
tradeNo: "waffo-failed-guard",
|
||||
storedPaymentProvider: PaymentProviderStripe,
|
||||
expectedPaymentProvider: PaymentProviderWaffo,
|
||||
targetStatus: common.TopUpStatusFailed,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -128,23 +130,23 @@ func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
truncateTables(t)
|
||||
insertUserForPaymentGuardTest(t, 150, 0)
|
||||
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentMethod)
|
||||
insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentProvider)
|
||||
|
||||
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentMethod, tc.targetStatus)
|
||||
err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentProvider, tc.targetStatus)
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 202, 0)
|
||||
plan := insertSubscriptionPlanForPaymentGuardTest(t, 301)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentMethodStripe)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentProviderStripe)
|
||||
|
||||
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, "alipay")
|
||||
err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, PaymentProviderEpay, "alipay")
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
|
||||
order := GetSubscriptionOrderByTradeNo("sub-guard-order")
|
||||
@@ -156,14 +158,14 @@ func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T)
|
||||
assert.Nil(t, topUp)
|
||||
}
|
||||
|
||||
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) {
|
||||
func TestExpireSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
insertUserForPaymentGuardTest(t, 303, 0)
|
||||
plan := insertSubscriptionPlanForPaymentGuardTest(t, 401)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentMethodStripe)
|
||||
insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentProviderStripe)
|
||||
|
||||
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentMethodCreem)
|
||||
err := ExpireSubscriptionOrder("sub-expire-guard", PaymentProviderCreem)
|
||||
require.ErrorIs(t, err, ErrPaymentMethodMismatch)
|
||||
|
||||
order := GetSubscriptionOrderByTradeNo("sub-expire-guard")
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// PerfMetric stores aggregated relay performance metrics for the model square.
|
||||
type PerfMetric struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
ModelName string `json:"model_name" gorm:"size:128;uniqueIndex:idx_perf_model_group_bucket,priority:1"`
|
||||
Group string `json:"group" gorm:"column:group;size:64;uniqueIndex:idx_perf_model_group_bucket,priority:2"`
|
||||
BucketTs int64 `json:"bucket_ts" gorm:"uniqueIndex:idx_perf_model_group_bucket,priority:3;index:idx_perf_bucket_ts"`
|
||||
RequestCount int64 `json:"-" gorm:"default:0"`
|
||||
SuccessCount int64 `json:"-" gorm:"default:0"`
|
||||
TotalLatencyMs int64 `json:"-" gorm:"default:0"`
|
||||
TtftSumMs int64 `json:"-" gorm:"default:0"`
|
||||
TtftCount int64 `json:"-" gorm:"default:0"`
|
||||
OutputTokens int64 `json:"-" gorm:"default:0"`
|
||||
GenerationMs int64 `json:"-" gorm:"default:0"`
|
||||
}
|
||||
|
||||
func (PerfMetric) TableName() string {
|
||||
return "perf_metrics"
|
||||
}
|
||||
|
||||
func UpsertPerfMetric(metric *PerfMetric) error {
|
||||
if metric == nil || metric.RequestCount == 0 {
|
||||
return nil
|
||||
}
|
||||
return DB.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "model_name"},
|
||||
{Name: "group"},
|
||||
{Name: "bucket_ts"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"request_count": gorm.Expr("request_count + ?", metric.RequestCount),
|
||||
"success_count": gorm.Expr("success_count + ?", metric.SuccessCount),
|
||||
"total_latency_ms": gorm.Expr("total_latency_ms + ?", metric.TotalLatencyMs),
|
||||
"ttft_sum_ms": gorm.Expr("ttft_sum_ms + ?", metric.TtftSumMs),
|
||||
"ttft_count": gorm.Expr("ttft_count + ?", metric.TtftCount),
|
||||
"output_tokens": gorm.Expr("output_tokens + ?", metric.OutputTokens),
|
||||
"generation_ms": gorm.Expr("generation_ms + ?", metric.GenerationMs),
|
||||
}),
|
||||
}).Create(metric).Error
|
||||
}
|
||||
|
||||
func GetPerfMetrics(modelName string, group string, startTs int64, endTs int64) ([]PerfMetric, error) {
|
||||
var metrics []PerfMetric
|
||||
query := DB.Model(&PerfMetric{}).
|
||||
Where("model_name = ? AND bucket_ts >= ? AND bucket_ts <= ?", modelName, startTs, endTs)
|
||||
if group != "" {
|
||||
query = query.Where(commonGroupCol+" = ?", group)
|
||||
}
|
||||
err := query.Order("bucket_ts ASC").Find(&metrics).Error
|
||||
return metrics, err
|
||||
}
|
||||
|
||||
type PerfMetricSummary struct {
|
||||
ModelName string `json:"model_name"`
|
||||
RequestCount int64 `json:"request_count"`
|
||||
SuccessCount int64 `json:"success_count"`
|
||||
TotalLatencyMs int64 `json:"total_latency_ms"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
GenerationMs int64 `json:"generation_ms"`
|
||||
}
|
||||
|
||||
func GetPerfMetricsSummaryAll(startTs int64, endTs int64) ([]PerfMetricSummary, error) {
|
||||
var summaries []PerfMetricSummary
|
||||
err := DB.Model(&PerfMetric{}).
|
||||
Select("model_name, SUM(request_count) as request_count, SUM(success_count) as success_count, SUM(total_latency_ms) as total_latency_ms, SUM(output_tokens) as output_tokens, SUM(generation_ms) as generation_ms").
|
||||
Where("bucket_ts >= ? AND bucket_ts <= ?", startTs, endTs).
|
||||
Group("model_name").
|
||||
Having("SUM(request_count) > 0").
|
||||
Find(&summaries).Error
|
||||
return summaries, err
|
||||
}
|
||||
|
||||
func DeletePerfMetricsBefore(cutoffTs int64) error {
|
||||
if cutoffTs <= 0 {
|
||||
return nil
|
||||
}
|
||||
return DB.Where("bucket_ts < ?", cutoffTs).Delete(&PerfMetric{}).Error
|
||||
}
|
||||
|
||||
func PerfMetricStartTime(hours int) int64 {
|
||||
if hours <= 0 {
|
||||
hours = 24
|
||||
}
|
||||
return time.Now().Add(-time.Duration(hours) * time.Hour).Unix()
|
||||
}
|
||||
+10
-1
@@ -77,6 +77,15 @@ func GetPricing() []Pricing {
|
||||
return pricingMap
|
||||
}
|
||||
|
||||
func InvalidatePricingCache() {
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
|
||||
pricingMap = nil
|
||||
vendorsList = nil
|
||||
lastGetPricingTime = time.Time{}
|
||||
}
|
||||
|
||||
// GetVendors 返回当前定价接口使用到的供应商信息
|
||||
func GetVendors() []PricingVendor {
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
@@ -323,7 +332,7 @@ func updatePricing() {
|
||||
pricing.AudioCompletionRatio = &audioCompletionRatio
|
||||
}
|
||||
if billingMode := billing_setting.GetBillingMode(model); billingMode == "tiered_expr" {
|
||||
if expr, ok := billing_setting.GetBillingExpr(model); ok && expr != "" {
|
||||
if expr, ok := billing_setting.GetBillingExpr(model); ok && strings.TrimSpace(expr) != "" {
|
||||
pricing.BillingMode = billingMode
|
||||
pricing.BillingExpr = expr
|
||||
}
|
||||
|
||||
+15
-9
@@ -198,11 +198,12 @@ type SubscriptionOrder struct {
|
||||
PlanId int `json:"plan_id" gorm:"index"`
|
||||
Money float64 `json:"money"`
|
||||
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
Status string `json:"status"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
|
||||
Status string `json:"status"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
|
||||
ProviderPayload string `json:"provider_payload" gorm:"type:text"`
|
||||
}
|
||||
@@ -505,7 +506,9 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio
|
||||
}
|
||||
|
||||
// Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan.
|
||||
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentMethod string) error {
|
||||
// expectedPaymentProvider guards against cross-gateway callback attacks (empty skips the check).
|
||||
// actualPaymentMethod updates the order's PaymentMethod to reflect the real payment type used (empty skips update).
|
||||
func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentProvider string, actualPaymentMethod string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("tradeNo is empty")
|
||||
}
|
||||
@@ -523,7 +526,7 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if order.Status == common.TopUpStatusSuccess {
|
||||
@@ -552,6 +555,9 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP
|
||||
if providerPayload != "" {
|
||||
order.ProviderPayload = providerPayload
|
||||
}
|
||||
if actualPaymentMethod != "" && order.PaymentMethod != actualPaymentMethod {
|
||||
order.PaymentMethod = actualPaymentMethod
|
||||
}
|
||||
if err := tx.Save(&order).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -610,7 +616,7 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
|
||||
return tx.Save(&topup).Error
|
||||
}
|
||||
|
||||
func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error {
|
||||
func ExpireSubscriptionOrder(tradeNo string, expectedPaymentProvider string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("tradeNo is empty")
|
||||
}
|
||||
@@ -623,7 +629,7 @@ func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if order.Status != common.TopUpStatusPending {
|
||||
|
||||
@@ -416,6 +416,17 @@ func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
// TaskBulkUpdate performs an unconditional bulk UPDATE by upstream task_id strings.
|
||||
// Same caveats as TaskBulkUpdateByID — no CAS guard.
|
||||
func TaskBulkUpdate(taskIds []string, params map[string]any) error {
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
return DB.Model(&Task{}).
|
||||
Where("task_id in (?)", taskIds).
|
||||
Updates(params).Error
|
||||
}
|
||||
|
||||
// TaskBulkUpdateByID performs an unconditional bulk UPDATE by primary key IDs.
|
||||
// WARNING: This function has NO CAS (Compare-And-Swap) guard — it will overwrite
|
||||
// any concurrent status changes. DO NOT use in billing/quota lifecycle flows
|
||||
|
||||
+25
-16
@@ -12,15 +12,16 @@ import (
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
|
||||
PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -30,6 +31,14 @@ const (
|
||||
PaymentMethodWaffoPancake = "waffo_pancake"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentProviderEpay = "epay"
|
||||
PaymentProviderStripe = "stripe"
|
||||
PaymentProviderCreem = "creem"
|
||||
PaymentProviderWaffo = "waffo"
|
||||
PaymentProviderWaffoPancake = "waffo_pancake"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPaymentMethodMismatch = errors.New("payment method mismatch")
|
||||
ErrTopUpNotFound = errors.New("topup not found")
|
||||
@@ -68,7 +77,7 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
|
||||
return topUp
|
||||
}
|
||||
|
||||
func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targetStatus string) error {
|
||||
func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentProvider string, targetStatus string) error {
|
||||
if tradeNo == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
@@ -83,7 +92,7 @@ func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targ
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil {
|
||||
return ErrTopUpNotFound
|
||||
}
|
||||
if expectedPaymentMethod != "" && topUp.PaymentMethod != expectedPaymentMethod {
|
||||
if expectedPaymentProvider != "" && topUp.PaymentProvider != expectedPaymentProvider {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
@@ -114,7 +123,7 @@ func Recharge(referenceId string, customerId string, callerIp string) (err error
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodStripe {
|
||||
if topUp.PaymentProvider != PaymentProviderStripe {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -340,7 +349,7 @@ func ManualCompleteTopUp(tradeNo string, callerIp string) error {
|
||||
// 计算应充值额度:
|
||||
// - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit
|
||||
// - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit
|
||||
if topUp.PaymentMethod == PaymentMethodStripe {
|
||||
if topUp.PaymentProvider == PaymentProviderStripe {
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart())
|
||||
} else {
|
||||
@@ -397,7 +406,7 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodCreem {
|
||||
if topUp.PaymentProvider != PaymentProviderCreem {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -472,7 +481,7 @@ func RechargeWaffo(tradeNo string, callerIp string) (err error) {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodWaffo {
|
||||
if topUp.PaymentProvider != PaymentProviderWaffo {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
@@ -535,7 +544,7 @@ func RechargeWaffoPancake(tradeNo string) (err error) {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.PaymentMethod != PaymentMethodWaffoPancake {
|
||||
if topUp.PaymentProvider != PaymentProviderWaffoPancake {
|
||||
return ErrPaymentMethodMismatch
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type RankingQuotaTotal struct {
|
||||
ModelName string `json:"model_name"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type RankingQuotaBucket struct {
|
||||
ModelName string `json:"model_name"`
|
||||
Bucket int64 `json:"bucket"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
func GetRankingQuotaTotals(startTime int64, endTime int64) ([]RankingQuotaTotal, error) {
|
||||
var rows []RankingQuotaTotal
|
||||
query := DB.Table("quota_data").
|
||||
Select("model_name, sum(token_used) as total_tokens").
|
||||
Where("model_name <> ''").
|
||||
Group("model_name").
|
||||
Having("sum(token_used) > 0").
|
||||
Order("total_tokens DESC")
|
||||
query = applyRankingQuotaTimeRange(query, startTime, endTime)
|
||||
err := query.Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func GetRankingQuotaBuckets(startTime int64, endTime int64, bucketSize int64) ([]RankingQuotaBucket, error) {
|
||||
if bucketSize <= 0 {
|
||||
bucketSize = 3600
|
||||
}
|
||||
bucketExpr := rankingBucketExpr(bucketSize)
|
||||
var rows []RankingQuotaBucket
|
||||
query := DB.Table("quota_data").
|
||||
Select(fmt.Sprintf("model_name, %s as bucket, sum(token_used) as tokens", bucketExpr)).
|
||||
Where("model_name <> ''").
|
||||
Group(fmt.Sprintf("model_name, %s", bucketExpr)).
|
||||
Having("sum(token_used) > 0").
|
||||
Order("bucket ASC")
|
||||
query = applyRankingQuotaTimeRange(query, startTime, endTime)
|
||||
err := query.Find(&rows).Error
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func rankingBucketExpr(bucketSize int64) string {
|
||||
if common.UsingMySQL {
|
||||
return fmt.Sprintf("FLOOR(created_at / %d) * %d", bucketSize, bucketSize)
|
||||
}
|
||||
return fmt.Sprintf("(created_at / %d) * %d", bucketSize, bucketSize)
|
||||
}
|
||||
|
||||
func applyRankingQuotaTimeRange(query *gorm.DB, startTime int64, endTime int64) *gorm.DB {
|
||||
if startTime > 0 {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime > 0 {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
return query
|
||||
}
|
||||
@@ -50,6 +50,8 @@ type User struct {
|
||||
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
||||
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||
StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"autoCreateTime;column:created_at"`
|
||||
LastLoginAt int64 `json:"last_login_at" gorm:"default:0;column:last_login_at"`
|
||||
}
|
||||
|
||||
func (user *User) ToBaseUser() *UserBase {
|
||||
@@ -951,6 +953,12 @@ func GetRootUser() (user *User) {
|
||||
return user
|
||||
}
|
||||
|
||||
func UpdateUserLastLoginAt(id int) {
|
||||
if err := DB.Model(&User{}).Where("id = ?", id).Update("last_login_at", common.GetTimestamp()).Error; err != nil {
|
||||
common.SysLog("failed to update user last_login_at: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
||||
|
||||
@@ -1000,11 +1000,82 @@ func TestImageAudioZero(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// len variable tests — tier conditions based on context length
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const lenTieredExpr = `len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6)`
|
||||
|
||||
func TestLen_StandardTier(t *testing.T) {
|
||||
params := billingexpr.TokenParams{P: 80000, C: 5000, Len: 100000, CR: 20000}
|
||||
cost, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := 80000*3 + 5000*15 + 20000*0.3
|
||||
if math.Abs(cost-want) > 1e-6 {
|
||||
t.Errorf("cost = %f, want %f", cost, want)
|
||||
}
|
||||
if trace.MatchedTier != "standard" {
|
||||
t.Errorf("tier = %q, want standard", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLen_LongContextTier(t *testing.T) {
|
||||
// p is low (cache subtracted), but len is high (full context)
|
||||
params := billingexpr.TokenParams{P: 50000, C: 5000, Len: 300000, CR: 250000}
|
||||
cost, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := 50000*6 + 5000*22.5 + 250000*0.6
|
||||
if math.Abs(cost-want) > 1e-6 {
|
||||
t.Errorf("cost = %f, want %f", cost, want)
|
||||
}
|
||||
if trace.MatchedTier != "long_context" {
|
||||
t.Errorf("tier = %q, want long_context (len=300000 > 200000)", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLen_BoundaryExact(t *testing.T) {
|
||||
params := billingexpr.TokenParams{P: 100000, C: 1000, Len: 200000, CR: 100000}
|
||||
_, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if trace.MatchedTier != "standard" {
|
||||
t.Errorf("tier = %q, want standard (len=200000 <= 200000)", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLen_BoundaryPlusOne(t *testing.T) {
|
||||
params := billingexpr.TokenParams{P: 100000, C: 1000, Len: 200001, CR: 100001}
|
||||
_, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if trace.MatchedTier != "long_context" {
|
||||
t.Errorf("tier = %q, want long_context (len=200001 > 200000)", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLen_ZeroDefaultsToZero(t *testing.T) {
|
||||
// len defaults to 0 when not set
|
||||
params := billingexpr.TokenParams{P: 1000, C: 500}
|
||||
_, trace, err := billingexpr.RunExpr(lenTieredExpr, params)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if trace.MatchedTier != "standard" {
|
||||
t.Errorf("tier = %q, want standard (len=0 <= 200000)", trace.MatchedTier)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Benchmarks: compile vs cached execution
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const benchComplexExpr = `p <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6 + img * 3 + img_o * 30 + ai * 10 + ao * 40) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12 + img * 6 + img_o * 60 + ai * 20 + ao * 80)`
|
||||
const benchComplexExpr = `len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6 + img * 3 + img_o * 30 + ai * 10 + ao * 40) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12 + img * 6 + img_o * 60 + ai * 20 + ao * 80)`
|
||||
|
||||
func BenchmarkExprCompile(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -1015,7 +1086,7 @@ func BenchmarkExprCompile(b *testing.B) {
|
||||
|
||||
func BenchmarkExprRunCached(b *testing.B) {
|
||||
billingexpr.CompileFromCache(benchComplexExpr)
|
||||
params := billingexpr.TokenParams{P: 150000, C: 10000, CR: 30000, CC: 5000, Img: 2000, AI: 1000, AO: 500}
|
||||
params := billingexpr.TokenParams{P: 150000, C: 10000, Len: 188000, CR: 30000, CC: 5000, Img: 2000, AI: 1000, AO: 500}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
billingexpr.RunExpr(benchComplexExpr, params)
|
||||
|
||||
@@ -41,6 +41,7 @@ var (
|
||||
var compileEnvPrototypeV1 = map[string]interface{}{
|
||||
"p": float64(0),
|
||||
"c": float64(0),
|
||||
"len": float64(0),
|
||||
"cr": float64(0),
|
||||
"cc": float64(0),
|
||||
"cc1h": float64(0),
|
||||
|
||||
+16
-3
@@ -30,7 +30,8 @@ Powered by [expr-lang/expr](https://github.com/expr-lang/expr). Expressions are
|
||||
|
||||
| 变量 | 含义 |
|
||||
|------|------|
|
||||
| `p` | 输入 token 数。**自动排除**表达式中单独计价的子类别(见下方说明) |
|
||||
| `p` | 输入 token 数(**计价用**)。**自动排除**表达式中单独计价的子类别(见下方说明) |
|
||||
| `len` | 输入上下文总长度(**条件判断用**)。不受自动排除影响,始终反映完整输入长度。非 Claude:等于原始 `prompt_tokens`;Claude:等于文本输入 + 缓存读取 + 缓存创建 |
|
||||
| `cr` | 缓存命中(读取)token 数 |
|
||||
| `cc` | 缓存创建 token 数(Claude 5分钟 TTL / 通用) |
|
||||
| `cc1h` | 缓存创建 token 数 — 1小时 TTL(Claude 专用) |
|
||||
@@ -51,6 +52,8 @@ Powered by [expr-lang/expr](https://github.com/expr-lang/expr). Expressions are
|
||||
|
||||
**规则:如果表达式使用了某个子类别变量,对应的 token 就从 `p` 或 `c` 中扣除;如果没使用,那些 token 就留在 `p` 或 `c` 里按基础价格计费。**
|
||||
|
||||
> **重要:`len` 不受自动排除影响。** `len` 始终代表完整的输入上下文长度,不管表达式是否单独对缓存/图片/音频定价。因此**阶梯条件应使用 `len` 而非 `p`**,以避免缓存命中导致 `p` 降低而误判档位。
|
||||
|
||||
举例说明(假设上游返回的原始数据:prompt_tokens=1000,其中包含 200 cache read、100 image):
|
||||
|
||||
| 表达式 | `p` 的值 | 说明 |
|
||||
@@ -93,8 +96,8 @@ Powered by [expr-lang/expr](https://github.com/expr-lang/expr). Expressions are
|
||||
# Simple flat pricing
|
||||
tier("base", p * 2.5 + c * 15 + cr * 0.25)
|
||||
|
||||
# Multi-tier (Claude Sonnet style)
|
||||
p <= 200000
|
||||
# Multi-tier (Claude Sonnet style) — use len for tier conditions
|
||||
len <= 200000
|
||||
? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6)
|
||||
: tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12)
|
||||
|
||||
@@ -199,6 +202,16 @@ Example: `p * 2.5 + c * 15 + cr * 0.25`
|
||||
- Expression uses `cr` → cache read tokens subtracted from `p`
|
||||
- Expression doesn't use `img` → image tokens stay in `p`, priced at $2.50
|
||||
|
||||
### `len` — Context Length Variable
|
||||
|
||||
`len` represents the total input context length, designed for **tier condition evaluation** (e.g. `len <= 200000 ? ...`). Unlike `p`, `len` is never reduced by sub-category exclusion.
|
||||
|
||||
**Computation rules:**
|
||||
- **Non-Claude (GPT/OpenAI format)**: `len = prompt_tokens` (the raw total from the upstream response)
|
||||
- **Claude format**: `len = input_tokens + cache_read_tokens + cache_creation_tokens` (since Claude's `input_tokens` is text-only, cache must be added back to reflect full context length)
|
||||
|
||||
This ensures that heavy cache usage doesn't cause the tier condition to incorrectly evaluate to a lower tier. For example, if a request has 300K total context but 250K is cached, `p` with cache subtracted would be only 50K (standard tier), while `len` correctly reports 300K (long-context tier).
|
||||
|
||||
### Quota Conversion
|
||||
|
||||
Expression coefficients are $/1M tokens. Conversion to internal quota:
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
|
||||
// RunExpr compiles (with cache) and executes an expression string.
|
||||
// The environment exposes:
|
||||
// - p, c — prompt / completion tokens
|
||||
// - p, c — prompt / completion tokens (auto-excluding separately-priced sub-categories)
|
||||
// - len — total input context length for tier conditions (never reduced by sub-category exclusion)
|
||||
// - cr, cc, cc1h — cache read / creation / creation-1h tokens
|
||||
// - tier(name, value) — trace callback that records which tier matched
|
||||
// - max, min, abs, ceil, floor — standard math helpers
|
||||
@@ -54,6 +55,7 @@ func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (flo
|
||||
env := map[string]interface{}{
|
||||
"p": params.P,
|
||||
"c": params.C,
|
||||
"len": params.Len,
|
||||
"cr": params.CR,
|
||||
"cc": params.CC,
|
||||
"cc1h": params.CC1h,
|
||||
|
||||
@@ -14,8 +14,9 @@ type RequestInput struct {
|
||||
// Fields beyond P and C are optional — when absent they default to 0,
|
||||
// which means cache-unaware expressions keep working unchanged.
|
||||
type TokenParams struct {
|
||||
P float64 // prompt tokens (text)
|
||||
C float64 // completion tokens (text)
|
||||
P float64 // prompt tokens (text) — auto-excludes sub-categories priced separately
|
||||
C float64 // completion tokens (text) — auto-excludes sub-categories priced separately
|
||||
Len float64 // total input context length for tier conditions (non-Claude: raw prompt_tokens; Claude: text + cache read + cache creation)
|
||||
CR float64 // cache read (hit) tokens
|
||||
CC float64 // cache creation tokens (5-min TTL for Claude, generic for others)
|
||||
CC1h float64 // cache creation tokens — 1-hour TTL (Claude only)
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package perfmetrics
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/perf_metrics_setting"
|
||||
)
|
||||
|
||||
func flushLoop() {
|
||||
for {
|
||||
interval := perf_metrics_setting.GetFlushIntervalMinutes()
|
||||
time.Sleep(time.Duration(interval) * time.Minute)
|
||||
setting := perf_metrics_setting.GetSetting()
|
||||
if !setting.Enabled {
|
||||
continue
|
||||
}
|
||||
flushCompletedBuckets()
|
||||
cleanupExpiredMetrics(setting.RetentionDays)
|
||||
}
|
||||
}
|
||||
|
||||
func flushCompletedBuckets() {
|
||||
currentBucket := bucketStart(time.Now().Unix())
|
||||
hotBuckets.Range(func(key, value any) bool {
|
||||
k := key.(bucketKey)
|
||||
if k.bucketTs >= currentBucket {
|
||||
return true
|
||||
}
|
||||
|
||||
bucket := value.(*atomicBucket)
|
||||
drained := bucket.drain()
|
||||
if drained.requestCount == 0 {
|
||||
deleteOldEmptyBucket(k, key)
|
||||
return true
|
||||
}
|
||||
|
||||
err := model.UpsertPerfMetric(&model.PerfMetric{
|
||||
ModelName: k.model,
|
||||
Group: k.group,
|
||||
BucketTs: k.bucketTs,
|
||||
RequestCount: drained.requestCount,
|
||||
SuccessCount: drained.successCount,
|
||||
TotalLatencyMs: drained.totalLatencyMs,
|
||||
TtftSumMs: drained.ttftSumMs,
|
||||
TtftCount: drained.ttftCount,
|
||||
OutputTokens: drained.outputTokens,
|
||||
GenerationMs: drained.generationMs,
|
||||
})
|
||||
if err != nil {
|
||||
bucket.addCounters(drained)
|
||||
common.SysError(fmt.Sprintf("failed to flush perf metric bucket model=%s group=%s bucket=%d: %s", k.model, k.group, k.bucketTs, err.Error()))
|
||||
return true
|
||||
}
|
||||
|
||||
deleteOldEmptyBucket(k, key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func deleteOldEmptyBucket(k bucketKey, rawKey any) {
|
||||
if k.bucketTs < bucketStart(time.Now().Add(-24*time.Hour).Unix()) {
|
||||
hotBuckets.Delete(rawKey)
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredMetrics(retentionDays int) {
|
||||
if retentionDays <= 0 {
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().Add(-time.Duration(retentionDays) * 24 * time.Hour).Unix()
|
||||
if err := model.DeletePerfMetricsBefore(cutoff); err != nil {
|
||||
common.SysError("failed to cleanup expired perf metrics: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func redisCounters(values map[string]string) counters {
|
||||
return counters{
|
||||
requestCount: parseRedisInt(values["req"]),
|
||||
successCount: parseRedisInt(values["ok"]),
|
||||
totalLatencyMs: parseRedisInt(values["lat"]),
|
||||
ttftSumMs: parseRedisInt(values["ttft"]),
|
||||
ttftCount: parseRedisInt(values["ttft_n"]),
|
||||
outputTokens: parseRedisInt(values["out"]),
|
||||
generationMs: parseRedisInt(values["gen_ms"]),
|
||||
}
|
||||
}
|
||||
|
||||
func parseRedisInt(value string) int64 {
|
||||
if value == "" {
|
||||
return 0
|
||||
}
|
||||
parsed, _ := strconv.ParseInt(value, 10, 64)
|
||||
return parsed
|
||||
}
|
||||
@@ -0,0 +1,358 @@
|
||||
package perfmetrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/perf_metrics_setting"
|
||||
)
|
||||
|
||||
var hotBuckets sync.Map
|
||||
|
||||
// seriesSchema is a stable client cache/schema marker. Do not change it when
|
||||
// hiding fields or making response-only privacy hardening changes.
|
||||
const seriesSchema = "dbcd0a3c01b55203"
|
||||
|
||||
func Init() {
|
||||
go flushLoop()
|
||||
}
|
||||
|
||||
func RecordRelaySample(info *relaycommon.RelayInfo, success bool, outputTokens int64) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
hasTtft := info.IsStream && info.HasSendResponse()
|
||||
ttftMs := int64(0)
|
||||
if hasTtft {
|
||||
ttftMs = info.FirstResponseTime.Sub(info.StartTime).Milliseconds()
|
||||
}
|
||||
latencyMs := now.Sub(info.StartTime).Milliseconds()
|
||||
generationMs := latencyMs
|
||||
if hasTtft {
|
||||
generationMs = now.Sub(info.FirstResponseTime).Milliseconds()
|
||||
}
|
||||
if generationMs <= 0 {
|
||||
generationMs = latencyMs
|
||||
}
|
||||
Record(Sample{
|
||||
Model: info.OriginModelName,
|
||||
Group: info.UsingGroup,
|
||||
LatencyMs: latencyMs,
|
||||
TtftMs: ttftMs,
|
||||
HasTtft: hasTtft,
|
||||
Success: success,
|
||||
OutputTokens: outputTokens,
|
||||
GenerationMs: generationMs,
|
||||
})
|
||||
}
|
||||
|
||||
func Record(sample Sample) {
|
||||
setting := perf_metrics_setting.GetSetting()
|
||||
if !setting.Enabled || sample.Model == "" {
|
||||
return
|
||||
}
|
||||
if sample.Group == "" {
|
||||
sample.Group = "default"
|
||||
}
|
||||
if sample.LatencyMs < 0 {
|
||||
sample.LatencyMs = 0
|
||||
}
|
||||
|
||||
key := bucketKey{
|
||||
model: sample.Model,
|
||||
group: sample.Group,
|
||||
bucketTs: bucketStart(time.Now().Unix()),
|
||||
}
|
||||
actual, _ := hotBuckets.LoadOrStore(key, &atomicBucket{})
|
||||
actual.(*atomicBucket).add(sample)
|
||||
recordRedis(key, sample)
|
||||
}
|
||||
|
||||
func Query(params QueryParams) (QueryResult, error) {
|
||||
if params.Hours <= 0 {
|
||||
params.Hours = 24
|
||||
}
|
||||
if params.Hours > 24*30 {
|
||||
params.Hours = 24 * 30
|
||||
}
|
||||
endTs := time.Now().Unix()
|
||||
startTs := endTs - int64(params.Hours)*3600
|
||||
|
||||
merged := map[bucketKey]counters{}
|
||||
rows, err := model.GetPerfMetrics(params.Model, params.Group, startTs, endTs)
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
for _, row := range rows {
|
||||
mergeCounters(merged, bucketKey{
|
||||
model: row.ModelName,
|
||||
group: row.Group,
|
||||
bucketTs: row.BucketTs,
|
||||
}, counters{
|
||||
requestCount: row.RequestCount,
|
||||
successCount: row.SuccessCount,
|
||||
totalLatencyMs: row.TotalLatencyMs,
|
||||
ttftSumMs: row.TtftSumMs,
|
||||
ttftCount: row.TtftCount,
|
||||
outputTokens: row.OutputTokens,
|
||||
generationMs: row.GenerationMs,
|
||||
})
|
||||
}
|
||||
|
||||
hotBuckets.Range(func(key, value any) bool {
|
||||
k := key.(bucketKey)
|
||||
if k.model != params.Model || k.bucketTs < startTs || k.bucketTs > endTs {
|
||||
return true
|
||||
}
|
||||
if params.Group != "" && k.group != params.Group {
|
||||
return true
|
||||
}
|
||||
mergeCounters(merged, k, value.(*atomicBucket).snapshot())
|
||||
return true
|
||||
})
|
||||
|
||||
return buildQueryResult(params.Model, merged), nil
|
||||
}
|
||||
|
||||
func QuerySummaryAll(hours int) (SummaryAllResult, error) {
|
||||
if hours <= 0 {
|
||||
hours = 24
|
||||
}
|
||||
if hours > 24*30 {
|
||||
hours = 24 * 30
|
||||
}
|
||||
endTs := time.Now().Unix()
|
||||
startTs := endTs - int64(hours)*3600
|
||||
|
||||
rows, err := model.GetPerfMetricsSummaryAll(startTs, endTs)
|
||||
if err != nil {
|
||||
return SummaryAllResult{}, err
|
||||
}
|
||||
|
||||
totals := map[string]counters{}
|
||||
for _, row := range rows {
|
||||
totals[row.ModelName] = counters{
|
||||
requestCount: row.RequestCount,
|
||||
successCount: row.SuccessCount,
|
||||
totalLatencyMs: row.TotalLatencyMs,
|
||||
outputTokens: row.OutputTokens,
|
||||
generationMs: row.GenerationMs,
|
||||
}
|
||||
}
|
||||
|
||||
hotBuckets.Range(func(key, value any) bool {
|
||||
k := key.(bucketKey)
|
||||
if k.bucketTs < startTs || k.bucketTs > endTs {
|
||||
return true
|
||||
}
|
||||
snap := value.(*atomicBucket).snapshot()
|
||||
if snap.requestCount == 0 {
|
||||
return true
|
||||
}
|
||||
cur := totals[k.model]
|
||||
cur.requestCount += snap.requestCount
|
||||
cur.successCount += snap.successCount
|
||||
cur.totalLatencyMs += snap.totalLatencyMs
|
||||
cur.outputTokens += snap.outputTokens
|
||||
cur.generationMs += snap.generationMs
|
||||
totals[k.model] = cur
|
||||
return true
|
||||
})
|
||||
|
||||
models := make([]ModelSummary, 0, len(totals))
|
||||
for name, total := range totals {
|
||||
if total.requestCount == 0 {
|
||||
continue
|
||||
}
|
||||
avgLatency := total.totalLatencyMs / total.requestCount
|
||||
successRate := float64(total.successCount) / float64(total.requestCount) * 100
|
||||
avgTps := 0.0
|
||||
if total.generationMs > 0 {
|
||||
avgTps = float64(total.outputTokens) / (float64(total.generationMs) / 1000.0)
|
||||
}
|
||||
models = append(models, ModelSummary{
|
||||
ModelName: name,
|
||||
AvgLatencyMs: avgLatency,
|
||||
SuccessRate: math.Round(successRate*100) / 100,
|
||||
AvgTps: math.Round(avgTps*100) / 100,
|
||||
RequestCount: total.requestCount,
|
||||
})
|
||||
}
|
||||
sort.Slice(models, func(i, j int) bool {
|
||||
return models[i].ModelName < models[j].ModelName
|
||||
})
|
||||
|
||||
return SummaryAllResult{Models: models}, nil
|
||||
}
|
||||
|
||||
func bucketStart(ts int64) int64 {
|
||||
bucketSeconds := perf_metrics_setting.GetBucketSeconds()
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 3600
|
||||
}
|
||||
return ts - (ts % bucketSeconds)
|
||||
}
|
||||
|
||||
func mergeCounters(merged map[bucketKey]counters, key bucketKey, value counters) {
|
||||
if value.requestCount == 0 {
|
||||
return
|
||||
}
|
||||
current := merged[key]
|
||||
current.requestCount += value.requestCount
|
||||
current.successCount += value.successCount
|
||||
current.totalLatencyMs += value.totalLatencyMs
|
||||
current.ttftSumMs += value.ttftSumMs
|
||||
current.ttftCount += value.ttftCount
|
||||
current.outputTokens += value.outputTokens
|
||||
current.generationMs += value.generationMs
|
||||
merged[key] = current
|
||||
}
|
||||
|
||||
func buildQueryResult(modelName string, merged map[bucketKey]counters) QueryResult {
|
||||
groupBuckets := map[string]map[int64]counters{}
|
||||
for key, value := range merged {
|
||||
if value.requestCount == 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := groupBuckets[key.group]; !ok {
|
||||
groupBuckets[key.group] = map[int64]counters{}
|
||||
}
|
||||
groupBuckets[key.group][key.bucketTs] = value
|
||||
}
|
||||
|
||||
groups := make([]string, 0, len(groupBuckets))
|
||||
for group := range groupBuckets {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
sort.Strings(groups)
|
||||
|
||||
results := make([]GroupResult, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
buckets := groupBuckets[group]
|
||||
timestamps := make([]int64, 0, len(buckets))
|
||||
for ts := range buckets {
|
||||
timestamps = append(timestamps, ts)
|
||||
}
|
||||
sort.Slice(timestamps, func(i, j int) bool {
|
||||
return timestamps[i] < timestamps[j]
|
||||
})
|
||||
|
||||
total := counters{}
|
||||
series := make([]BucketPoint, 0, len(timestamps))
|
||||
for _, ts := range timestamps {
|
||||
value := buckets[ts]
|
||||
total.requestCount += value.requestCount
|
||||
total.successCount += value.successCount
|
||||
total.totalLatencyMs += value.totalLatencyMs
|
||||
total.ttftSumMs += value.ttftSumMs
|
||||
total.ttftCount += value.ttftCount
|
||||
total.outputTokens += value.outputTokens
|
||||
total.generationMs += value.generationMs
|
||||
series = append(series, bucketPoint(ts, value))
|
||||
}
|
||||
|
||||
results = append(results, GroupResult{
|
||||
Group: group,
|
||||
AvgTtftMs: avg(total.ttftSumMs, total.ttftCount),
|
||||
AvgLatencyMs: avg(total.totalLatencyMs, total.requestCount),
|
||||
SuccessRate: successRate(total),
|
||||
AvgTps: avgTps(total),
|
||||
Series: series,
|
||||
})
|
||||
}
|
||||
|
||||
return QueryResult{
|
||||
ModelName: modelName,
|
||||
SeriesSchema: seriesSchema,
|
||||
Groups: results,
|
||||
}
|
||||
}
|
||||
|
||||
func bucketPoint(ts int64, value counters) BucketPoint {
|
||||
return BucketPoint{
|
||||
Ts: ts,
|
||||
AvgTtftMs: avg(value.ttftSumMs, value.ttftCount),
|
||||
AvgLatencyMs: avg(value.totalLatencyMs, value.requestCount),
|
||||
SuccessRate: successRate(value),
|
||||
AvgTps: avgTps(value),
|
||||
}
|
||||
}
|
||||
|
||||
func avg(sum int64, count int64) int64 {
|
||||
if count <= 0 {
|
||||
return 0
|
||||
}
|
||||
return sum / count
|
||||
}
|
||||
|
||||
func successRate(value counters) float64 {
|
||||
if value.requestCount <= 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(value.successCount) / float64(value.requestCount) * 100
|
||||
}
|
||||
|
||||
func avgTps(value counters) float64 {
|
||||
if value.outputTokens <= 0 || value.generationMs <= 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(value.outputTokens) / (float64(value.generationMs) / 1000)
|
||||
}
|
||||
|
||||
func recordRedis(key bucketKey, sample Sample) {
|
||||
if !common.RedisEnabled || common.RDB == nil {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
redisKey := redisBucketKey(key)
|
||||
pipe := common.RDB.TxPipeline()
|
||||
pipe.HIncrBy(ctx, redisKey, "req", 1)
|
||||
if sample.Success {
|
||||
pipe.HIncrBy(ctx, redisKey, "ok", 1)
|
||||
}
|
||||
if sample.LatencyMs > 0 {
|
||||
pipe.HIncrBy(ctx, redisKey, "lat", sample.LatencyMs)
|
||||
}
|
||||
if sample.HasTtft && sample.TtftMs >= 0 {
|
||||
pipe.HIncrBy(ctx, redisKey, "ttft", sample.TtftMs)
|
||||
pipe.HIncrBy(ctx, redisKey, "ttft_n", 1)
|
||||
}
|
||||
if sample.OutputTokens > 0 && sample.GenerationMs > 0 {
|
||||
pipe.HIncrBy(ctx, redisKey, "out", sample.OutputTokens)
|
||||
pipe.HIncrBy(ctx, redisKey, "gen_ms", sample.GenerationMs)
|
||||
}
|
||||
pipe.Expire(ctx, redisKey, time.Hour)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
}
|
||||
|
||||
func mergeRedisActiveBuckets(merged map[bucketKey]counters, params QueryParams, startTs int64, endTs int64) {
|
||||
if !common.RedisEnabled || common.RDB == nil || params.Model == "" || params.Group == "" {
|
||||
return
|
||||
}
|
||||
active := bucketStart(time.Now().Unix())
|
||||
if active < startTs || active > endTs {
|
||||
return
|
||||
}
|
||||
key := bucketKey{model: params.Model, group: params.Group, bucketTs: active}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
values, err := common.RDB.HGetAll(ctx, redisBucketKey(key)).Result()
|
||||
if err != nil || len(values) == 0 {
|
||||
return
|
||||
}
|
||||
mergeCounters(merged, key, redisCounters(values))
|
||||
}
|
||||
|
||||
func redisBucketKey(key bucketKey) string {
|
||||
return fmt.Sprintf("perf:%s:%s:%d", key.model, key.group, key.bucketTs)
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package perfmetrics
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type Store interface {
|
||||
Record(sample Sample)
|
||||
Query(params QueryParams) (QueryResult, error)
|
||||
}
|
||||
|
||||
type Sample struct {
|
||||
Model string
|
||||
Group string
|
||||
LatencyMs int64
|
||||
TtftMs int64
|
||||
HasTtft bool
|
||||
Success bool
|
||||
OutputTokens int64
|
||||
GenerationMs int64
|
||||
}
|
||||
|
||||
type QueryParams struct {
|
||||
Model string
|
||||
Group string
|
||||
Hours int
|
||||
}
|
||||
|
||||
type BucketPoint struct {
|
||||
Ts int64 `json:"ts"`
|
||||
AvgTtftMs int64 `json:"avg_ttft_ms"`
|
||||
AvgLatencyMs int64 `json:"avg_latency_ms"`
|
||||
SuccessRate float64 `json:"success_rate"`
|
||||
AvgTps float64 `json:"avg_tps"`
|
||||
}
|
||||
|
||||
type GroupResult struct {
|
||||
Group string `json:"group"`
|
||||
AvgTtftMs int64 `json:"avg_ttft_ms"`
|
||||
AvgLatencyMs int64 `json:"avg_latency_ms"`
|
||||
SuccessRate float64 `json:"success_rate"`
|
||||
AvgTps float64 `json:"avg_tps"`
|
||||
Series []BucketPoint `json:"series"`
|
||||
}
|
||||
|
||||
type QueryResult struct {
|
||||
ModelName string `json:"model_name"`
|
||||
SeriesSchema string `json:"series_schema"`
|
||||
Groups []GroupResult `json:"groups"`
|
||||
}
|
||||
|
||||
type ModelSummary struct {
|
||||
ModelName string `json:"model_name"`
|
||||
AvgLatencyMs int64 `json:"avg_latency_ms"`
|
||||
SuccessRate float64 `json:"success_rate"`
|
||||
AvgTps float64 `json:"avg_tps"`
|
||||
RequestCount int64 `json:"request_count"`
|
||||
}
|
||||
|
||||
type SummaryAllResult struct {
|
||||
Models []ModelSummary `json:"models"`
|
||||
}
|
||||
|
||||
type bucketKey struct {
|
||||
model string
|
||||
group string
|
||||
bucketTs int64
|
||||
}
|
||||
|
||||
type counters struct {
|
||||
requestCount int64
|
||||
successCount int64
|
||||
totalLatencyMs int64
|
||||
ttftSumMs int64
|
||||
ttftCount int64
|
||||
outputTokens int64
|
||||
generationMs int64
|
||||
}
|
||||
|
||||
type atomicBucket struct {
|
||||
requestCount atomic.Int64
|
||||
successCount atomic.Int64
|
||||
totalLatencyMs atomic.Int64
|
||||
ttftSumMs atomic.Int64
|
||||
ttftCount atomic.Int64
|
||||
outputTokens atomic.Int64
|
||||
generationMs atomic.Int64
|
||||
}
|
||||
|
||||
func (b *atomicBucket) add(sample Sample) {
|
||||
b.requestCount.Add(1)
|
||||
if sample.Success {
|
||||
b.successCount.Add(1)
|
||||
}
|
||||
if sample.LatencyMs > 0 {
|
||||
b.totalLatencyMs.Add(sample.LatencyMs)
|
||||
}
|
||||
if sample.HasTtft && sample.TtftMs >= 0 {
|
||||
b.ttftSumMs.Add(sample.TtftMs)
|
||||
b.ttftCount.Add(1)
|
||||
}
|
||||
if sample.OutputTokens > 0 && sample.GenerationMs > 0 {
|
||||
b.outputTokens.Add(sample.OutputTokens)
|
||||
b.generationMs.Add(sample.GenerationMs)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *atomicBucket) snapshot() counters {
|
||||
return counters{
|
||||
requestCount: b.requestCount.Load(),
|
||||
successCount: b.successCount.Load(),
|
||||
totalLatencyMs: b.totalLatencyMs.Load(),
|
||||
ttftSumMs: b.ttftSumMs.Load(),
|
||||
ttftCount: b.ttftCount.Load(),
|
||||
outputTokens: b.outputTokens.Load(),
|
||||
generationMs: b.generationMs.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *atomicBucket) drain() counters {
|
||||
return counters{
|
||||
requestCount: b.requestCount.Swap(0),
|
||||
successCount: b.successCount.Swap(0),
|
||||
totalLatencyMs: b.totalLatencyMs.Swap(0),
|
||||
ttftSumMs: b.ttftSumMs.Swap(0),
|
||||
ttftCount: b.ttftCount.Swap(0),
|
||||
outputTokens: b.outputTokens.Swap(0),
|
||||
generationMs: b.generationMs.Swap(0),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *atomicBucket) addCounters(c counters) {
|
||||
if c.requestCount != 0 {
|
||||
b.requestCount.Add(c.requestCount)
|
||||
}
|
||||
if c.successCount != 0 {
|
||||
b.successCount.Add(c.successCount)
|
||||
}
|
||||
if c.totalLatencyMs != 0 {
|
||||
b.totalLatencyMs.Add(c.totalLatencyMs)
|
||||
}
|
||||
if c.ttftSumMs != 0 {
|
||||
b.ttftSumMs.Add(c.ttftSumMs)
|
||||
}
|
||||
if c.ttftCount != 0 {
|
||||
b.ttftCount.Add(c.ttftCount)
|
||||
}
|
||||
if c.outputTokens != 0 {
|
||||
b.outputTokens.Add(c.outputTokens)
|
||||
}
|
||||
if c.generationMs != 0 {
|
||||
b.generationMs.Add(c.generationMs)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
@@ -18,12 +19,16 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
IsSyncImageModel bool
|
||||
}
|
||||
|
||||
const aliAnthropicMessagesModelsEnv = "ALI_ANTHROPIC_MESSAGES_MODELS"
|
||||
const defaultAliAnthropicMessagesModels = "qwen,deepseek-v4,kimi,glm,minimax-m"
|
||||
|
||||
/*
|
||||
var syncModels = []string{
|
||||
"z-image",
|
||||
@@ -32,8 +37,22 @@ type Adaptor struct {
|
||||
}
|
||||
*/
|
||||
func supportsAliAnthropicMessages(modelName string) bool {
|
||||
// Only models with the "qwen" designation can use the Claude-compatible interface; others require conversion.
|
||||
return strings.Contains(strings.ToLower(modelName), "qwen")
|
||||
normalizedModelName := strings.ToLower(strings.TrimSpace(modelName))
|
||||
if normalizedModelName == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return lo.SomeBy(aliAnthropicMessagesModelPatterns(), func(pattern string) bool {
|
||||
return strings.Contains(normalizedModelName, pattern)
|
||||
})
|
||||
}
|
||||
|
||||
func aliAnthropicMessagesModelPatterns() []string {
|
||||
configuredModels := common.GetEnvOrDefaultString(aliAnthropicMessagesModelsEnv, defaultAliAnthropicMessagesModels)
|
||||
return lo.FilterMap(strings.Split(configuredModels, ","), func(item string, _ int) (string, bool) {
|
||||
pattern := strings.ToLower(strings.TrimSpace(item))
|
||||
return pattern, pattern != ""
|
||||
})
|
||||
}
|
||||
|
||||
var syncModels = []string{
|
||||
|
||||
@@ -567,12 +567,14 @@ func ResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.OpenAITextRe
|
||||
}
|
||||
choice.SetStringContent(responseText)
|
||||
if len(responseThinking) > 0 {
|
||||
choice.ReasoningContent = responseThinking
|
||||
choice.ReasoningContent = &responseThinking
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
choice.Message.SetToolCalls(tools)
|
||||
}
|
||||
choice.Message.ReasoningContent = thinkingContent
|
||||
if thinkingContent != "" {
|
||||
choice.Message.ReasoningContent = &thinkingContent
|
||||
}
|
||||
fullTextResponse.Model = claudeResponse.Model
|
||||
choices = append(choices, choice)
|
||||
fullTextResponse.Choices = choices
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -27,7 +29,18 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claudeRequest, ok := convertedRequest.(*dto.ClaudeRequest)
|
||||
if !ok {
|
||||
return convertedRequest, nil
|
||||
}
|
||||
if err := applyDeepSeekV4ClaudeThinkingSuffix(info, claudeRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claudeRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -71,9 +84,71 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if err := applyDeepSeekV4OpenAIThinkingSuffix(info, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func applyDeepSeekV4OpenAIThinkingSuffix(info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) error {
|
||||
modelName := request.Model
|
||||
if info != nil && info.ChannelMeta != nil && info.UpstreamModelName != "" {
|
||||
modelName = info.UpstreamModelName
|
||||
}
|
||||
baseModel, thinkingType, effort, ok := reasoning.ParseDeepSeekV4ThinkingSuffix(modelName)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
thinking, err := common.Marshal(map[string]string{
|
||||
"type": thinkingType,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling thinking: %w", err)
|
||||
}
|
||||
request.Model = baseModel
|
||||
request.THINKING = thinking
|
||||
request.ReasoningEffort = effort
|
||||
if info != nil {
|
||||
if info.ChannelMeta != nil {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
info.ReasoningEffort = effort
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyDeepSeekV4ClaudeThinkingSuffix(info *relaycommon.RelayInfo, request *dto.ClaudeRequest) error {
|
||||
modelName := request.Model
|
||||
if info != nil && info.ChannelMeta != nil && info.UpstreamModelName != "" {
|
||||
modelName = info.UpstreamModelName
|
||||
}
|
||||
baseModel, thinkingType, effort, ok := reasoning.ParseDeepSeekV4ThinkingSuffix(modelName)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
request.Model = baseModel
|
||||
request.Thinking = &dto.Thinking{Type: thinkingType}
|
||||
if effort == "" {
|
||||
request.OutputConfig = nil
|
||||
} else {
|
||||
outputConfig, err := common.Marshal(map[string]string{
|
||||
"effort": effort,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling output_config: %w", err)
|
||||
}
|
||||
request.OutputConfig = outputConfig
|
||||
}
|
||||
if info != nil {
|
||||
if info.ChannelMeta != nil {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
info.ReasoningEffort = effort
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package deepseek
|
||||
|
||||
var ModelList = []string{
|
||||
"deepseek-chat", "deepseek-reasoner",
|
||||
"deepseek-v4-flash", "deepseek-v4-flash-none", "deepseek-v4-flash-max",
|
||||
"deepseek-v4-pro", "deepseek-v4-pro-none", "deepseek-v4-pro-max",
|
||||
}
|
||||
|
||||
var ChannelName = "deepseek"
|
||||
|
||||
@@ -1097,7 +1097,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
|
||||
toolCalls = append(toolCalls, *call)
|
||||
}
|
||||
} else if part.Thought {
|
||||
choice.Message.ReasoningContent = part.Text
|
||||
choice.Message.ReasoningContent = &part.Text
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
|
||||
|
||||
@@ -273,7 +273,7 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
|
||||
msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
|
||||
if rc := reasoningBuilder.String(); rc != "" {
|
||||
msg.ReasoningContent = rc
|
||||
msg.ReasoningContent = &rc
|
||||
}
|
||||
full := dto.OpenAITextResponse{
|
||||
Id: common.GetUUID(),
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/samber/lo"
|
||||
|
||||
@@ -39,21 +40,6 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
// parseReasoningEffortFromModelSuffix 从模型名称中解析推理级别
|
||||
// support OAI models: o1-mini/o3-mini/o4-mini/o1/o3 etc...
|
||||
// minimal effort only available in gpt-5
|
||||
func parseReasoningEffortFromModelSuffix(model string) (string, string) {
|
||||
effortSuffixes := []string{"-high", "-minimal", "-low", "-medium", "-none", "-xhigh"}
|
||||
for _, suffix := range effortSuffixes {
|
||||
if strings.HasSuffix(model, suffix) {
|
||||
effort := strings.TrimPrefix(suffix, "-")
|
||||
originModel := strings.TrimSuffix(model, suffix)
|
||||
return effort, originModel
|
||||
}
|
||||
}
|
||||
return "", model
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
// 使用 service.GeminiToOpenAIRequest 转换请求格式
|
||||
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
|
||||
@@ -342,7 +328,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(info.UpstreamModelName)
|
||||
effort, originModel := reasoning.ParseOpenAIReasoningEffortFromModelSuffix(info.UpstreamModelName)
|
||||
if effort != "" {
|
||||
request.ReasoningEffort = effort
|
||||
info.UpstreamModelName = originModel
|
||||
@@ -440,6 +426,9 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeImagesEdits:
|
||||
if isJSONRequest(c) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
@@ -565,6 +554,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
}
|
||||
|
||||
func isJSONRequest(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json")
|
||||
}
|
||||
|
||||
// detectImageMimeType determines the MIME type based on the file extension
|
||||
func detectImageMimeType(filename string) string {
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
@@ -587,7 +583,7 @@ func detectImageMimeType(filename string) string {
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(request.Model)
|
||||
effort, originModel := reasoning.ParseOpenAIReasoningEffortFromModelSuffix(request.Model)
|
||||
if effort != "" {
|
||||
if request.Reasoning == nil {
|
||||
request.Reasoning = &dto.Reasoning{
|
||||
@@ -607,7 +603,7 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
|
||||
info.RelayMode == relayconstant.RelayModeAudioTranslation ||
|
||||
info.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
(info.RelayMode == relayconstant.RelayModeImagesEdits && !isJSONRequest(c)) {
|
||||
return channel.DoFormRequest(a, c, info, requestBody)
|
||||
} else if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
return channel.DoWssRequest(a, c, info, requestBody)
|
||||
|
||||
@@ -408,7 +408,7 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
toolCallNameByID[callID] = name
|
||||
}
|
||||
|
||||
newArgs := streamResp.Item.Arguments
|
||||
newArgs := streamResp.Item.ArgumentsString()
|
||||
prevArgs := toolCallArgsByID[callID]
|
||||
argsDelta := ""
|
||||
if newArgs != "" {
|
||||
|
||||
@@ -245,7 +245,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
completionTokens := simpleResponse.Usage.CompletionTokens
|
||||
if completionTokens == 0 {
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.GetReasoningContent(), info.UpstreamModelName)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,20 +95,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
||||
if strings.TrimSpace(region) == "" {
|
||||
region = "global"
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
), nil
|
||||
return vertexcore.BuildGoogleModelURL(a.baseURL, vertexcore.DefaultAPIVersion, adc.ProjectID, region, modelName, "predictLongRunning"), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
@@ -238,6 +225,22 @@ func (a *TaskAdaptor) GetModelList() []string {
|
||||
}
|
||||
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||
|
||||
func buildFetchOperationURL(baseURL, upstreamName string) (string, error) {
|
||||
region := extractRegionFromOperationName(upstreamName)
|
||||
if region == "" {
|
||||
region = "us-central1"
|
||||
}
|
||||
project := extractProjectFromOperationName(upstreamName)
|
||||
modelName := extractModelFromOperationName(upstreamName)
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
return "", fmt.Errorf("cannot extract model from operation name")
|
||||
}
|
||||
if strings.TrimSpace(project) == "" {
|
||||
return "", fmt.Errorf("cannot extract project from operation name")
|
||||
}
|
||||
return vertexcore.BuildGoogleModelURL(baseURL, vertexcore.DefaultAPIVersion, project, region, modelName, "fetchPredictOperation"), nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
@@ -248,20 +251,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
region := extractRegionFromOperationName(upstreamName)
|
||||
if region == "" {
|
||||
region = "us-central1"
|
||||
}
|
||||
project := extractProjectFromOperationName(upstreamName)
|
||||
modelName := extractModelFromOperationName(upstreamName)
|
||||
if project == "" || modelName == "" {
|
||||
return nil, fmt.Errorf("cannot extract project/model from operation name")
|
||||
}
|
||||
var url string
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||
url, err := buildFetchOperationURL(baseUrl, upstreamName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := fetchOperationPayload{OperationName: upstreamName}
|
||||
data, err := common.Marshal(payload)
|
||||
|
||||
@@ -134,47 +134,11 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s
|
||||
a.AccountCredentials = *adc
|
||||
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil
|
||||
} else if a.RequestMode == RequestModeOpenSource {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
return BuildOpenSourceChatCompletionsURL(info.ChannelBaseUrl, adc.ProjectID, region), nil
|
||||
}
|
||||
} else {
|
||||
var keyPrefix string
|
||||
@@ -183,20 +147,17 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s
|
||||
} else {
|
||||
keyPrefix = "?"
|
||||
}
|
||||
if region == "global" {
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
|
||||
modelName,
|
||||
suffix,
|
||||
"%s%skey=%s",
|
||||
BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix),
|
||||
keyPrefix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
} else {
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
"%s%skey=%s",
|
||||
BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix),
|
||||
keyPrefix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user