This commit is contained in:
2025-08-27 19:57:50 +08:00
parent c55985d316
commit 8bd7c89d28

View File

@@ -0,0 +1,191 @@
/*
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.store.qdrant;
import com.agentsflex.core.document.Document;
import com.agentsflex.core.store.DocumentStore;
import com.agentsflex.core.store.SearchWrapper;
import com.agentsflex.core.store.StoreOptions;
import com.agentsflex.core.store.StoreResult;
import com.agentsflex.core.util.CollectionUtil;
import com.agentsflex.core.util.StringUtil;
import com.agentsflex.core.util.VectorUtil;
import io.grpc.Grpc;
import io.grpc.TlsChannelCredentials;
import io.qdrant.client.QdrantClient;
import io.grpc.ManagedChannel;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections;
import io.qdrant.client.grpc.JsonWithInt;
import io.qdrant.client.grpc.Points;
import io.qdrant.client.grpc.Points.Filter;
import io.qdrant.client.grpc.Points.PointId;
import io.qdrant.client.grpc.Points.PointStruct;
import io.qdrant.client.grpc.Points.QueryPoints;
import io.qdrant.client.grpc.Points.ScoredPoint;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import static io.qdrant.client.ConditionFactory.matchKeyword;
import static io.qdrant.client.PointIdFactory.id;
import static io.qdrant.client.QueryFactory.nearest;
import static io.qdrant.client.ValueFactory.value;
import static io.qdrant.client.VectorsFactory.vectors;
import static io.qdrant.client.WithPayloadSelectorFactory.enable;
public class QdrantVectorStore extends DocumentStore {
private final QdrantVectorStoreConfig config;
private final QdrantClient client;
private final String defaultCollectionName;
private boolean isCreateCollection = false;
public QdrantVectorStore(QdrantVectorStoreConfig config) throws IOException {
this.config = config;
this.defaultCollectionName = config.getDefaultCollectionName();
String uri = config.getUri();
int port = 6334;
QdrantGrpcClient.Builder builder;
if (StringUtil.hasText(config.getCaPath())) {
ManagedChannel channel = Grpc.newChannelBuilder(
uri,
TlsChannelCredentials.newBuilder().trustManager(new File(config.getCaPath())).build()
).build();
builder = QdrantGrpcClient.newBuilder(channel, true);
} else {
if (uri.contains(":")) {
uri = uri.split(":")[0];
port = Integer.parseInt(uri.split(":")[1]);
}
builder = QdrantGrpcClient.newBuilder(uri, port, false);
}
if (StringUtil.hasText(config.getApiKey())) {
builder.withApiKey(config.getApiKey());
}
this.client = new QdrantClient(builder.build());
}
@Override
public StoreResult storeInternal(List<Document> documents, StoreOptions options) {
List<PointStruct> points = new ArrayList<>();
int size = 1024;
for (Document doc : documents) {
size = doc.getVector().length;
Map<String, JsonWithInt.Value> payload = new HashMap<>();
payload.put("content", value(doc.getContent()));
points.add(PointStruct.newBuilder()
.setId(id(Long.parseLong(doc.getId().toString())))
.setVectors(vectors(VectorUtil.toFloatArray(doc.getVector())))
.putAllPayload(payload)
.build());
}
try {
String collectionName = options.getCollectionNameOrDefault(defaultCollectionName);
if (config.isAutoCreateCollection() && !isCreateCollection) {
Boolean exists = client.collectionExistsAsync(collectionName).get();
if (!exists) {
client.createCollectionAsync(collectionName, Collections.VectorParams.newBuilder()
.setDistance(Collections.Distance.Cosine)
.setSize(size)
.build())
.get();
}
} else {
isCreateCollection = true;
}
if (CollectionUtil.hasItems(points)) {
client.upsertAsync(collectionName, points).get();
}
return StoreResult.successWithIds(documents);
} catch (Exception e) {
return StoreResult.fail();
}
}
@Override
public StoreResult deleteInternal(Collection<?> ids, StoreOptions options) {
try {
String collectionName = options.getCollectionNameOrDefault(defaultCollectionName);
List<PointId> pointIds = ids.stream()
.map(id -> id((Long) id))
.collect(Collectors.toList());
client.deleteAsync(collectionName, pointIds).get();
return StoreResult.success();
} catch (Exception e) {
return StoreResult.fail();
}
}
@Override
public StoreResult updateInternal(List<Document> documents, StoreOptions options) {
try {
List<PointStruct> points = new ArrayList<>();
for (Document doc : documents) {
Map<String, JsonWithInt.Value> payload = new HashMap<>();
payload.put("content", value(doc.getContent()));
points.add(PointStruct.newBuilder()
.setId(id(Long.parseLong(doc.getId().toString())))
.setVectors(vectors(VectorUtil.toFloatArray(doc.getVector())))
.putAllPayload(payload)
.build());
}
String collectionName = options.getCollectionNameOrDefault(defaultCollectionName);
if (CollectionUtil.hasItems(points)) {
client.upsertAsync(collectionName, points).get();
}
return StoreResult.successWithIds(documents);
} catch (Exception e) {
return StoreResult.fail();
}
}
@Override
public List<Document> searchInternal(SearchWrapper wrapper, StoreOptions options) {
List<Document> documents = new ArrayList<>();
try {
String collectionName = options.getCollectionNameOrDefault(defaultCollectionName);
QueryPoints.Builder query = QueryPoints.newBuilder()
.setCollectionName(collectionName)
.setLimit(wrapper.getMaxResults())
.setWithVectors(Points.WithVectorsSelector.newBuilder().setEnable(true).build())
.setWithPayload(enable(true));
if (wrapper.getVector() != null) {
query.setQuery(nearest(VectorUtil.toFloatArray(wrapper.getVector())));
}
if (StringUtil.hasText(wrapper.getText())) {
query.setFilter(Filter.newBuilder().addMust(matchKeyword("content", wrapper.getText())));
}
List<ScoredPoint> data = client.queryAsync(query.build()).get();
for (ScoredPoint point : data) {
Document doc = new Document();
doc.setId(point.getId().getNum());
doc.setVector(VectorUtil.convertToVector(point.getVectors().getVector().getDataList()));
doc.setContent(point.getPayloadMap().get("content").getStringValue());
documents.add(doc);
}
return documents;
} catch (Exception e) {
return documents;
}
}
public QdrantClient getClient() {
return client;
}
}