说明
- 用户输入多个“信息”
- 嵌入(Embedding)模型将每条“信息”转换成一个浮点数数组(一维张量,即向量)
- 通过余弦相似度等算法,计算两个向量的相似程度
Ollama接口步骤
- 安装 Ollama: https://ollama.ai/
- 下载模型: ollama pull nomic-embed-text
- Ollama 默认运行在 http://localhost:11434
推荐的嵌入模型:
- nomic-embed-text: 768维,效果好,速度快
- mxbai-embed-large: 1024维,效果更好
- bge-m3: 多语言支持

Spring Boot 中调用本地模型
@Test@Disabled("需要本地运行 Ollama 服务")public void testOllamaEmbedding() {// Ollama API 地址String apiUrl = "http://localhost:11434/api/embeddings";String apiKey = ""; // Ollama 本地不需要 keyString model = "nomic-embed-text"; // 或 mxbai-embed-largeEmbeddingClient client = new EmbeddingClientImpl(apiUrl, apiKey);// 水果库List<Fruit> fruits = Arrays.asList(new Fruit("红富士苹果", "红色 甜 脆 苹果 新鲜"), new Fruit("青苹果", "绿色 酸 脆 苹果 清爽"),new Fruit("金帅苹果", "黄色 甜 软 苹果"), new Fruit("香蕉", "黄色 甜 软 香蕉 热带水果"), new Fruit("草莓", "红色 甜 小 草莓 多汁 浆果"),new Fruit("西瓜", "绿色外皮 红色果肉 甜 大 西瓜 多汁 夏天"), new Fruit("葡萄", "紫色 甜 小 葡萄 多汁 成串"));// 为每个水果生成嵌入向量for (Fruit fruit : fruits) {fruit.embedding = client.getEmbeddingVector(model, fruit.description);}// 用户搜索String query = "红色的甜水果";double[] queryVector = client.getEmbeddingVector(model, query);System.out.println("搜索: \"" + query + "\"");System.out.println("向量维度: " + queryVector.length);System.out.println();// 按相似度排序fruits.sort(Comparator.comparingDouble(f -> -cosineSimilarity(queryVector, f.embedding)));// 输出结果System.out.println("搜索结果(按相似度排序):");for (Fruit f : fruits) {double sim = cosineSimilarity(queryVector, f.embedding);System.out.printf(" %s (%.4f): %s%n", f.name, sim, f.description);}}/*** 计算两个向量的余弦相似度*/public static double cosineSimilarity(double[] vectorA, double[] vectorB) {if (vectorA.length != vectorB.length) {throw new IllegalArgumentException("向量维度必须相同");}double dotProduct = 0;double normA = 0;double normB = 0;for (int i = 0; i < vectorA.length; i++) {dotProduct += vectorA[i] * vectorB[i];normA += vectorA[i] * vectorA[i];normB += vectorB[i] * vectorB[i];}if (normA == 0 || normB == 0) {return 0;}return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));}
核心方法
@Slf4j
public class EmbeddingClientImpl implements EmbeddingClient {private final RestTemplate restTemplate;private final String address;private final String key;public EmbeddingClientImpl(String address, String key) {PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager();connectionManager.setMaxTotal(100);connectionManager.setDefaultMaxPerRoute(20);// 设置请求配置RequestConfig requestConfig = RequestConfig.custom().setConnectionRequestTimeout(Timeout.ofSeconds(30)).setResponseTimeout(Timeout.ofSeconds(300)) // 5分钟响应超时.build();// 使用 HttpClientBuilder 来构建 HttpClientHttpClient httpClient = HttpClientBuilder.create().setConnectionManager(connectionManager).setDefaultRequestConfig(requestConfig).build();// 创建 HttpComponentsClientHttpRequestFactoryHttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient);requestFactory.setConnectTimeout(30000); // 30秒连接超时requestFactory.setConnectionRequestTimeout(30000);// 创建 RestTemplate,只使用 StringHttpMessageConverter 避免 Jackson 依赖问题this.restTemplate = new RestTemplate(requestFactory);// 清除默认的消息转换器,只保留字符串转换器this.restTemplate.setMessageConverters(Collections.singletonList(new StringHttpMessageConverter(StandardCharsets.UTF_8)));this.address = address;this.key = key;}@Overridepublic String embedding(String model, String input) {long start = System.currentTimeMillis();String url = address;HttpHeaders headers = new HttpHeaders();headers.setContentType(MediaType.APPLICATION_JSON);headers.setAcceptCharset(Collections.singletonList(StandardCharsets.UTF_8));if (key != null && !key.isEmpty()) {headers.add("Authorization", "Bearer " + key);}// 将 request 转化为 body 字符串JSONObject jsonObject = new JSONObject();jsonObject.put("input", input);jsonObject.put("model", model);String body = jsonObject.toString();log.debug("Embedding Request Body: {}", body);// 请求HttpEntity<String> req = new HttpEntity<>(body, headers);ResponseEntity<String> result = 
restTemplate.postForEntity(url, req, String.class);if (!result.getStatusCode().equals(HttpStatus.OK)) {throw new RuntimeException("embeddings error, request: " + body + ", response: " + result.getBody());}log.info("embedding cost {} ms", System.currentTimeMillis() - start);return result.getBody();}/*** 获取文本嵌入向量* <p>* 解析 OpenAI 格式的响应,提取 embedding 向量** 响应格式示例: <pre>* {* "object": "list",* "data": [{* "object": "embedding",* "index": 0,* "embedding": [0.0023064255, -0.009327292, ...]* }],* "model": "text-embedding-ada-002",* "usage": {"prompt_tokens": 8, "total_tokens": 8}* }* </pre>* @param model 模型名称* @param input 输入文本* @return 嵌入向量*/@Overridepublic double[] getEmbeddingVector(String model, String input) {String response = embedding(model, input);return parseEmbeddingVector(response);}/*** 解析嵌入向量响应* @param response JSON响应字符串* @return 向量数组*/private double[] parseEmbeddingVector(String response) {try {JSONObject jsonResponse = JSONObject.parseObject(response);// OpenAI 格式if (jsonResponse.containsKey("data")) {JSONArray dataArray = jsonResponse.getJSONArray("data");if (dataArray != null && !dataArray.isEmpty()) {JSONObject firstData = dataArray.getJSONObject(0);JSONArray embeddingArray = firstData.getJSONArray("embedding");return jsonArrayToDoubleArray(embeddingArray);}}// Ollama 格式 (直接返回 embedding 数组)if (jsonResponse.containsKey("embedding")) {JSONArray embeddingArray = jsonResponse.getJSONArray("embedding");return jsonArrayToDoubleArray(embeddingArray);}// 阿里通义格式if (jsonResponse.containsKey("output")) {JSONObject output = jsonResponse.getJSONObject("output");if (output.containsKey("embeddings")) {JSONArray embeddings = output.getJSONArray("embeddings");if (!embeddings.isEmpty()) {JSONObject firstEmbedding = embeddings.getJSONObject(0);JSONArray embeddingArray = firstEmbedding.getJSONArray("embedding");return jsonArrayToDoubleArray(embeddingArray);}}}throw new RuntimeException("无法解析嵌入向量响应: " + response);}catch (Exception e) {log.error("解析嵌入向量失败: {}", response, 
e);throw new RuntimeException("解析嵌入向量失败", e);}}/*** 将 JSONArray 转换为 double 数组*/private double[] jsonArrayToDoubleArray(JSONArray jsonArray) {double[] result = new double[jsonArray.size()];for (int i = 0; i < jsonArray.size(); i++) {result[i] = jsonArray.getDoubleValue(i);}return result;}}
