Compare commits

...

4 Commits

Author SHA1 Message Date
1808837298@qq.com 9b2cc6add7 feat: Enhance ConvertClaudeRequest method to set request model and handle vertex-specific request conversion 2025-03-17 17:13:33 +08:00
1808837298@qq.com 4f6167243f feat: Update RerankerInfo structure and modify GenRelayInfoRerank function to accept RerankRequest 2025-03-17 16:44:53 +08:00
Calcium-Ion eafbfac6a0 Merge pull request #872 from neotf/main
feat: support AWS Model CrossRegion
2025-03-17 16:18:11 +08:00
neotf ac9bd53098 feat: support AWS Model CrossRegion 2025-03-15 01:42:24 +08:00
8 changed files with 100 additions and 16 deletions
+8 -1
View File
@@ -5,11 +5,18 @@ type RerankRequest struct {
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents,omitempty"`
ReturnDocuments *bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`
}
func (r *RerankRequest) GetReturnDocuments() bool {
if r.ReturnDocuments == nil {
return false
}
return *r.ReturnDocuments
}
type RerankResponseResult struct {
Document any `json:"document,omitempty"`
Index int `json:"index"`
+2
View File
@@ -21,6 +21,8 @@ type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
c.Set("request_model", request.Model)
c.Set("converted_request", request)
return request, nil
}
+37
View File
@@ -13,4 +13,41 @@ var awsModelIDMap = map[string]string{
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
}
var awsModelCanCrossRegionMap = map[string]map[string]bool{
"anthropic.claude-3-sonnet-20240229-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-opus-20240229-v1:0": {
"us": true,
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-5-sonnet-20241022-v2:0": {
"us": true,
"ap": true,
},
"anthropic.claude-3-5-haiku-20241022-v1:0": {
"us": true,
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"us": true,
},
}
var awsRegionCrossModelPrefixMap = map[string]string{
"us": "us",
"eu": "eu",
"ap": "apac",
}
var ChannelName = "aws"
+28
View File
@@ -43,6 +43,28 @@ func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
}
}
func awsRegionPrefix(awsRegionId string) string {
parts := strings.Split(awsRegionId, "-")
regionPrefix := ""
if len(parts) > 0 {
regionPrefix = parts[0]
}
return regionPrefix
}
func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
return exists && regionSet[awsRegionPrefix]
}
func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
if !find {
return awsModelId
}
return modelPrefix + "." + awsModelId
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID, nil
@@ -62,6 +84,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
+8 -1
View File
@@ -39,8 +39,15 @@ type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return request, nil
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
c.Set("request_model", v)
} else {
c.Set("request_model", request.Model)
}
vertexClaudeReq := copyRequest(request, anthropicVersion)
return vertexClaudeReq, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
+5 -3
View File
@@ -34,7 +34,8 @@ const (
)
type RerankerInfo struct {
Documents []any
Documents []any
ReturnDocuments bool
}
type RelayInfo struct {
@@ -116,11 +117,12 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
return info
}
func GenRelayInfoRerank(c *gin.Context, documents []any) *RelayInfo {
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeRerank
info.RerankerInfo = &RerankerInfo{
Documents: documents,
Documents: req.Documents,
ReturnDocuments: req.GetReturnDocuments(),
}
return info
}
+11 -10
View File
@@ -32,19 +32,20 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
var document any
if result.Document == "" {
document = info.Documents[result.Index]
} else {
document = result.Document
}
jinaRespResults[i] = dto.RerankResponseResult{
respResult := dto.RerankResponseResult{
Index: result.Index,
RelevanceScore: result.RelevanceScore,
Document: dto.RerankDocument{
Text: document,
},
}
if info.ReturnDocuments {
var document any
if result.Document == "" {
document = info.Documents[result.Index]
} else {
document = result.Document
}
respResult.Document = document
}
jinaRespResults[i] = respResult
}
jinaResp = dto.RerankResponse{
Results: jinaRespResults,
+1 -1
View File
@@ -33,7 +33,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest.Documents)
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
if rerankRequest.Query == "" {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)