diff --git a/agents-flex-store/agents-flex-store-milvus/src/main/java/com/agentsflex/store/milvus/MilvusVectorStore.java b/agents-flex-store/agents-flex-store-milvus/src/main/java/com/agentsflex/store/milvus/MilvusVectorStore.java new file mode 100644 index 0000000..4fc7a47 --- /dev/null +++ b/agents-flex-store/agents-flex-store-milvus/src/main/java/com/agentsflex/store/milvus/MilvusVectorStore.java @@ -0,0 +1,358 @@ +/* + * Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com). + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.agentsflex.store.milvus; + +import com.agentsflex.core.document.Document; +import com.agentsflex.core.store.DocumentStore; +import com.agentsflex.core.store.SearchWrapper; +import com.agentsflex.core.store.StoreOptions; +import com.agentsflex.core.store.StoreResult; +import com.agentsflex.core.util.Maps; +import com.agentsflex.core.util.StringUtil; +import com.agentsflex.core.util.VectorUtil; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; +import io.milvus.v2.client.ConnectConfig; +import io.milvus.v2.client.MilvusClientV2; +import io.milvus.v2.common.ConsistencyLevel; +import io.milvus.v2.common.DataType; +import io.milvus.v2.common.IndexParam; +import io.milvus.v2.exception.MilvusClientException; +import io.milvus.v2.service.collection.request.CreateCollectionReq; +import io.milvus.v2.service.collection.request.GetLoadStateReq; +import io.milvus.v2.service.vector.request.*; +import io.milvus.v2.service.vector.response.InsertResp; +import io.milvus.v2.service.vector.response.QueryResp; +import io.milvus.v2.service.vector.response.SearchResp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.*; + +/** + * MilvusVectorStore class provides an interface to interact with Milvus Vector Database. + */ +public class MilvusVectorStore extends DocumentStore { + + private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class); + private final MilvusClientV2 client; + private final String defaultCollectionName; + private final MilvusVectorStoreConfig config; + + public MilvusVectorStore(MilvusVectorStoreConfig config) { + ConnectConfig connectConfig = ConnectConfig.builder() + .uri(config.getUri()) + .dbName(config.getDatabaseName()) + .token(config.getToken()) + .username(config.getUsername()) + .password(config.getPassword()) + .build(); + + this.client = new MilvusClientV2(connectConfig); + this.defaultCollectionName = config.getDefaultCollectionName(); + this.config = config; + } + + @Override + public StoreResult storeInternal(List documents, StoreOptions options) { + List data = new ArrayList<>(); + for (Document doc : documents) { + JSONObject dict = new JSONObject(); + dict.put("id", String.valueOf(doc.getId())); + dict.put("content", doc.getContent()); + dict.put("vector", VectorUtil.toFloatList(doc.getVector())); + + Map metadatas = doc.getMetadataMap(); + JSONObject jsonObject = JSON.parseObject(JSON.toJSONBytes(metadatas == null ? Collections.EMPTY_MAP : metadatas)); + dict.put("metadata", jsonObject); + data.add(dict); + } + + String collectionName = options.getCollectionNameOrDefault(defaultCollectionName); + InsertReq.InsertReqBuilder builder = InsertReq.builder(); + if (StringUtil.hasText(options.getPartitionName())) { + builder.partitionName(options.getPartitionName()); + } + InsertReq insertReq = builder + .collectionName(collectionName) + .data(data) + .build(); + try { + InsertResp insertResp = client.insert(insertReq); + } catch (MilvusClientException e) { + if (e.getMessage() != null && e.getMessage().contains("collection not found") + && config.isAutoCreateCollection() + && options.getMetadata("forInternal") == null) { + + Boolean success = createCollection(collectionName); + if (success != null && success) { + //store + options.addMetadata("forInternal", true); + storeInternal(documents, options); + } + } else { + return StoreResult.fail(); + } + } + + return StoreResult.successWithIds(documents); + } + + + private Boolean createCollection(String collectionName) { + List fieldSchemaList = new ArrayList<>(); + + //id + CreateCollectionReq.FieldSchema id = CreateCollectionReq.FieldSchema.builder() + .name("id") + .dataType(DataType.VarChar) + .maxLength(36) + .isPrimaryKey(true) + .autoID(false) + .build(); + fieldSchemaList.add(id); + + //content + CreateCollectionReq.FieldSchema content = CreateCollectionReq.FieldSchema.builder() + .name("content") + .dataType(DataType.VarChar) + .maxLength(65535) + .build(); + fieldSchemaList.add(content); + + //metadata + CreateCollectionReq.FieldSchema metadata = CreateCollectionReq.FieldSchema.builder() + .name("metadata") + .dataType(DataType.JSON) + .build(); + fieldSchemaList.add(metadata); + + //vector + CreateCollectionReq.FieldSchema vector = CreateCollectionReq.FieldSchema.builder() + .name("vector") + .dataType(DataType.FloatVector) + .dimension(this.getEmbeddingModel().dimensions()) + .build(); + fieldSchemaList.add(vector); + + CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema + .builder() + .fieldSchemaList(fieldSchemaList) + .build(); + + + List indexParams = new ArrayList<>(); + IndexParam vectorIndex = IndexParam.builder().fieldName("vector") + .indexType(IndexParam.IndexType.IVF_FLAT) + .metricType(IndexParam.MetricType.COSINE) + .indexName("vector") + .extraParams(Maps.of("nlist", 1024)) + .build(); + indexParams.add(vectorIndex); + + + CreateCollectionReq createCollectionReq = CreateCollectionReq.builder() + .collectionName(collectionName) + .collectionSchema(collectionSchema) + .primaryFieldName("id") + .vectorFieldName("vector") + .description("Agents Flex Vector Store") + .indexParams(indexParams) + .build(); + + client.createCollection(createCollectionReq); + + GetLoadStateReq quickSetupLoadStateReq = GetLoadStateReq.builder() + .collectionName(collectionName) + .build(); + + return client.getLoadState(quickSetupLoadStateReq); + } + + @Override + public StoreResult deleteInternal(Collection ids, StoreOptions options) { + + DeleteReq.DeleteReqBuilder builder = DeleteReq.builder(); + if (StringUtil.hasText(options.getPartitionName())) { + builder.partitionName(options.getPartitionName()); + } + + DeleteReq deleteReq = builder + .collectionName(options.getCollectionNameOrDefault(defaultCollectionName)) + .ids(new ArrayList<>(ids)) + .build(); + + try { + client.delete(deleteReq); + } catch (Exception e) { + logger.error("delete document error: " + e, e); + return StoreResult.fail(); + } + + return StoreResult.success(); + + } + + @Override + public List searchInternal(SearchWrapper searchWrapper, StoreOptions options) { + List outputFields = searchWrapper.isOutputVector() + ? Arrays.asList("id", "vector", "content", "metadata") + : Arrays.asList("id", "content", "metadata"); + + // 判断是否为向量查询 + if (searchWrapper.getVector() != null && searchWrapper.getVector().length > 0) { + // 向量查询 - 使用SearchReq + SearchReq.SearchReqBuilder builder = SearchReq.builder(); + if (StringUtil.hasText(options.getPartitionName())) { + builder.partitionNames(options.getPartitionNamesOrEmpty()); + } + + SearchReq searchReq = builder + .collectionName(options.getCollectionNameOrDefault(defaultCollectionName)) + .consistencyLevel(ConsistencyLevel.STRONG) + .outputFields(outputFields) + .topK(searchWrapper.getMaxResults()) + .annsField("vector") + .data(Collections.singletonList(VectorUtil.toFloatList(searchWrapper.getVector()))) + .filter(searchWrapper.toFilterExpression(MilvusExpressionAdaptor.DEFAULT)) + .build(); + + try { + SearchResp resp = client.search(searchReq); + // Parse and convert search results to Document list + List> results = resp.getSearchResults(); + List documents = new ArrayList<>(); + for (List resultList : results) { + for (SearchResp.SearchResult result : resultList) { + Map entity = result.getEntity(); + if (entity == null || entity.isEmpty()) { + continue; + } + + Document doc = new Document(); + doc.setId(result.getId()); + + Object vectorObj = entity.get("vector"); + if (vectorObj instanceof List) { + //noinspection unchecked + doc.setVector(VectorUtil.convertToVector((List) vectorObj)); + } + + doc.setContent((String) entity.get("content")); + + // 根据 metric 类型计算相似度 + Float distance = result.getDistance(); + if (distance != null) { + // 根据 https://milvus.io/docs/zh/single-vector-search.md#Single-Vector-Search + // 适用的度量类型和相应的距离范围表 + // 当 metricType 类 COSINE 时,数值越大,表示相似度越高。相似度即为 distance 值; + + // distance 的范围是 [-1, 1], 需要统一转换为 [0, 1] + double score = (distance + 1) / 2; + doc.setScore(score); + } + + JSONObject object = (JSONObject) entity.get("metadata"); + doc.addMetadata(object); + documents.add(doc); + } + } + return documents; + } catch (Exception e) { + logger.error("Error searching in Milvus", e); + return Collections.emptyList(); + } + } else { + // 非向量查询 - 使用QueryReq + QueryReq.QueryReqBuilder builder = QueryReq.builder(); + if (StringUtil.hasText(options.getPartitionName())) { + builder.partitionNames(options.getPartitionNamesOrEmpty()); + } + + QueryReq queryReq = builder + .collectionName(options.getCollectionNameOrDefault(defaultCollectionName)) + .consistencyLevel(ConsistencyLevel.STRONG) + .outputFields(outputFields) + .filter(searchWrapper.toFilterExpression(MilvusExpressionAdaptor.DEFAULT)) + .build(); + + try { + QueryResp resp = client.query(queryReq); + List results = resp.getQueryResults(); + List documents = new ArrayList<>(); + + for (QueryResp.QueryResult result : results) { + Map entity = result.getEntity(); + if (entity == null || entity.isEmpty()) { + continue; + } + + Document doc = new Document(); + doc.setId(result.getEntity().get("id")); + + Object vectorObj = entity.get("vector"); + if (vectorObj instanceof List) { + //noinspection unchecked + doc.setVector(VectorUtil.convertToVector((List) vectorObj)); + } + + doc.setContent((String) entity.get("content")); + + JSONObject object = (JSONObject) entity.get("metadata"); + doc.addMetadata(object); + documents.add(doc); + } + return documents; + } catch (Exception e) { + logger.error("Error querying in Milvus", e); + return Collections.emptyList(); + } + } + } + + @Override + public StoreResult updateInternal(List documents, StoreOptions options) { + if (documents == null || documents.isEmpty()) { + return StoreResult.success(); + } + List data = new ArrayList<>(); + for (Document doc : documents) { + JSONObject dict = new JSONObject(); + + dict.put("id", String.valueOf(doc.getId())); + dict.put("content", doc.getContent()); + dict.put("vector", VectorUtil.toFloatList(doc.getVector())); + + Map metadatas = doc.getMetadataMap(); + JSONObject jsonObject = JSON.parseObject(JSON.toJSONBytes(metadatas == null ? Collections.EMPTY_MAP : metadatas)); + dict.put("metadata", jsonObject); + data.add(dict); + } + + UpsertReq upsertReq = UpsertReq.builder() + .collectionName(options.getCollectionNameOrDefault(defaultCollectionName)) + .partitionName(options.getPartitionName()) + .data(data) + .build(); + client.upsert(upsertReq); + return StoreResult.successWithIds(documents); + } + + + public MilvusClientV2 getClient() { + return client; + } +}