소스 검색

Generate stock info by LLM inference (#6)

Daniel Bohry 8 달 전
부모
커밋
d1e76a07ee

+ 18 - 5
src/main/java/com/danielbohry/stocks/api/stock/StockController.java

@@ -1,6 +1,8 @@
 package com.danielbohry.stocks.api.stock;
 
+import com.danielbohry.stocks.context.UserContextHolder;
 import com.danielbohry.stocks.domain.Quote;
+import com.danielbohry.stocks.service.StockInfoService;
 import com.danielbohry.stocks.service.StockService;
 import io.swagger.v3.oas.annotations.Hidden;
 import lombok.AllArgsConstructor;
@@ -26,6 +28,7 @@ import static java.time.LocalDateTime.now;
 public class StockController {
 
     private final StockService service;
+    private final StockInfoService stockInfoService;
 
     @GetMapping
     public ResponseEntity<List<Quote>> find(@RequestParam(value = "q", required = false) String query) {
@@ -51,9 +54,9 @@ public class StockController {
 
         try (BufferedReader reader = new BufferedReader(new InputStreamReader(file.getInputStream(), UTF_8))) {
             List<Quote> quotes = reader.lines()
-                    .map(r -> convert(r, currency))
-                    .filter(Objects::nonNull)
-                    .toList();
+                .map(r -> convert(r, currency))
+                .filter(Objects::nonNull)
+                .toList();
 
             List<Quote> response = service.update(quotes);
 
@@ -64,6 +67,16 @@ public class StockController {
         }
     }
 
+    @Hidden
+    @PostMapping("/generate")
+    public ResponseEntity<Void> generateStockInfo() {
+        if (UserContextHolder.isAdmin()) {
+            stockInfoService.generate();
+        }
+
+        return ResponseEntity.ok().build();
+    }
+
     private Quote convert(String input, String currency) {
         String[] value = input.split(",");
 
@@ -72,8 +85,8 @@ public class StockController {
         if (value[3] != null && !value[3].equals("#N/A")) {
             BigDecimal price = new BigDecimal(value[3]);
             return price.compareTo(BigDecimal.ZERO) > 0
-                    ? new Quote(value[0], value[1], currency, price, now())
-                    : null;
+                ? new Quote(value[0], value[1], currency, price, now())
+                : null;
         } else if (value[0] != null && Objects.equals(value[3], "#N/A")) {
             return null;
         } else {

+ 2 - 0
src/main/java/com/danielbohry/stocks/client/AuthClient.java

@@ -13,6 +13,7 @@ import org.springframework.web.client.RestTemplate;
 
 import java.time.Instant;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 @Component
@@ -108,6 +109,7 @@ public class AuthClient {
     @NoArgsConstructor
     public static class CurrentUser {
         private String username;
+        private List<String> roles;
     }
 
 }

+ 129 - 0
src/main/java/com/danielbohry/stocks/client/InferenceClient.java

@@ -0,0 +1,129 @@
+package com.danielbohry.stocks.client;
+
+import com.danielbohry.stocks.domain.Quote;
+import com.danielbohry.stocks.domain.StockInfo;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.http.HttpEntity;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.http.ResponseEntity;
+import org.springframework.stereotype.Component;
+import org.springframework.web.client.RestTemplate;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static java.time.Instant.now;
+
+@Slf4j
+@Component
+@RequiredArgsConstructor
+public class InferenceClient {
+
+    private final RestTemplate rest;
+
+    @Value("${clients.inference.url}")
+    private String baseUrl;
+
+    @Value("${clients.inference.model}")
+    private String model;
+
+    private static final String PROMPT = """
+        You are an API. Respond only with a **valid JSON object** using the exact format and keys below. 
+        Do not include any explanation, markdown, or extra text.
+        
+        Respond with JSON in the following format:
+        {
+          "founded": "<4-digit year or 'unknown'>",
+          "ipo": "<4-digit year or 'unknown'>",
+          "exchange": "<stock exchange name or 'unknown'>",
+          "headquarters": "<city and state/country or 'unknown'>",
+          "industry": "<industry sector or 'unknown'>",
+          "description": "<description or 'unknown'>"
+        }
+        
+        Make sure the output is a valid JSON string with no extra text or markdown.
+        
+        Company stock code: 
+        """;
+
+    private static final ObjectMapper mapper = new ObjectMapper();
+
+    public StockInfo infer(Quote quote) {
+        String inference = infer(PROMPT + quote.getCode());
+
+        return buildStockInfo(quote, inference);
+    }
+
+    private String infer(String prompt) {
+        try {
+            Map<String, Object> requestBody = new HashMap<>();
+            requestBody.put("model", model);
+            requestBody.put("prompt", prompt);
+            requestBody.put("stream", false);
+
+            HttpHeaders headers = new HttpHeaders();
+            headers.setContentType(MediaType.APPLICATION_JSON);
+
+            HttpEntity<Map<String, Object>> request = new HttpEntity<>(requestBody, headers);
+            ResponseEntity<InferenceResponse> response = rest.postForEntity(baseUrl + "/api/generate", request, InferenceResponse.class);
+
+            if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) {
+                return response.getBody().getResponse();
+            } else {
+                return null;
+            }
+
+        } catch (Exception e) {
+            log.error("Error during inference", e);
+            return null;
+        }
+    }
+
+    private StockInfo buildStockInfo(Quote stock, String inference) {
+        try {
+            String cleanInference = inference.replaceAll("(?i)^[^{]*\\{", "{").replaceAll("[^}]*$", "}").trim();
+            Inference parsed = mapper.readValue(cleanInference, Inference.class);
+
+            return StockInfo.builder()
+                .code(stock.getCode())
+                .name(stock.getName())
+                .description(parsed.getDescription())
+                .foundation(parsed.getFounded())
+                .ipo(parsed.getIpo())
+                .exchange(parsed.getExchange())
+                .headquarters(parsed.getHeadquarters())
+                .industry(parsed.getIndustry())
+                .updatedAt(now())
+                .build();
+        } catch (Exception e) {
+            log.error("Failed to parse inference response [{}]", inference, e);
+            return null;
+        }
+    }
+
+    @NoArgsConstructor
+    @AllArgsConstructor
+    @Data
+    public static class Inference {
+        private String founded;
+        private String ipo;
+        private String exchange;
+        private String headquarters;
+        private String industry;
+        private String description;
+    }
+
+    @NoArgsConstructor
+    @AllArgsConstructor
+    @Data
+    public static class InferenceResponse {
+        private String response;
+    }
+}

+ 24 - 0
src/main/java/com/danielbohry/stocks/config/AsyncConfig.java

@@ -0,0 +1,24 @@
+package com.danielbohry.stocks.config;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.scheduling.annotation.EnableAsync;
+import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
+
+import java.util.concurrent.Executor;
+
+@EnableAsync
+@Configuration
+public class AsyncConfig {
+
+    @Bean(name = "stockInfoExecutor")
+    public Executor taskExecutor() {
+        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
+        executor.setCorePoolSize(5);
+        executor.setMaxPoolSize(10);
+        executor.setQueueCapacity(10);
+        executor.setThreadNamePrefix("StockInfoExecutor-");
+        executor.initialize();
+        return executor;
+    }
+}

+ 8 - 5
src/main/java/com/danielbohry/stocks/config/CacheConfig.java

@@ -22,14 +22,17 @@ public class CacheConfig {
         CaffeineCache exchangeRates = new CaffeineCache("exchangeRates",
             Caffeine.newBuilder().expireAfterWrite(6, HOURS).build());
 
-        CaffeineCache stockQuotes = new CaffeineCache("stockQuotes",
-            Caffeine.newBuilder().expireAfterWrite(10, MINUTES).build());
+        CaffeineCache allStockQuotes = new CaffeineCache("allStockQuotes",
+            Caffeine.newBuilder().expireAfterWrite(5, MINUTES).build());
+
+        CaffeineCache stockQuotesQuery = new CaffeineCache("stockQuotesQuery",
+            Caffeine.newBuilder().expireAfterWrite(5, MINUTES).build());
 
-        CaffeineCache portfolio = new CaffeineCache("portfolio",
-            Caffeine.newBuilder().expireAfterWrite(10, MINUTES).build());
+        CaffeineCache stockQuotes = new CaffeineCache("stockQuotes",
+            Caffeine.newBuilder().expireAfterWrite(5, MINUTES).build());
 
         SimpleCacheManager manager = new SimpleCacheManager();
-        manager.setCaches(List.of(exchangeRates, stockQuotes, portfolio));
+        manager.setCaches(List.of(exchangeRates, allStockQuotes, stockQuotesQuery, stockQuotes));
         return manager;
     }
 

+ 3 - 2
src/main/java/com/danielbohry/stocks/context/ServiceContextFilter.java

@@ -11,6 +11,7 @@ import org.springframework.stereotype.Component;
 import org.springframework.web.filter.OncePerRequestFilter;
 
 import java.io.IOException;
+import java.util.List;
 
 @Component
 @AllArgsConstructor
@@ -39,7 +40,7 @@ public class ServiceContextFilter extends OncePerRequestFilter {
             CurrentUser user = extractCurrentUser(request);
 
             if (user != null) {
-                UserContextHolder.set(new UserContext(user.getUsername()));
+                UserContextHolder.set(new UserContext(user.getUsername(), user.getRoles()));
 
                 try {
                     filterChain.doFilter(request, response);
@@ -59,7 +60,7 @@ public class ServiceContextFilter extends OncePerRequestFilter {
 
         return token != null ?
             authClient.getCurrent(token)
-            : new CurrentUser("anonymous");
+            : new CurrentUser("anonymous", List.of("USER"));
     }
 
 }

+ 3 - 0
src/main/java/com/danielbohry/stocks/context/UserContext.java

@@ -3,10 +3,13 @@ package com.danielbohry.stocks.context;
 import lombok.AllArgsConstructor;
 import lombok.Getter;
 
+import java.util.List;
+
 @Getter
 @AllArgsConstructor
 public class UserContext {
 
     private String username;
+    private List<String> roles;
 
 }

+ 4 - 0
src/main/java/com/danielbohry/stocks/context/UserContextHolder.java

@@ -12,6 +12,10 @@ public class UserContextHolder {
         return CONTEXT.get();
     }
 
+    public static Boolean isAdmin() {
+        return CONTEXT.get().getRoles().contains("ADMIN");
+    }
+
     public static void clear() {
         CONTEXT.remove();
     }

+ 26 - 0
src/main/java/com/danielbohry/stocks/domain/StockInfo.java

@@ -0,0 +1,26 @@
+package com.danielbohry.stocks.domain;
+
+import lombok.Builder;
+import lombok.Data;
+import org.springframework.data.annotation.Id;
+import org.springframework.data.mongodb.core.mapping.Document;
+
+import java.time.Instant;
+
+@Builder
+@Data
+@Document("stock-infos")
+public class StockInfo {
+
+    @Id
+    private String code;
+    private String name;
+    private String description;
+    private String foundation;
+    private String ipo;
+    private String exchange;
+    private String headquarters;
+    private String industry;
+    private Instant updatedAt;
+
+}

+ 9 - 0
src/main/java/com/danielbohry/stocks/repository/StockInfoRepository.java

@@ -0,0 +1,9 @@
+package com.danielbohry.stocks.repository;
+
+import com.danielbohry.stocks.domain.StockInfo;
+import org.springframework.data.mongodb.repository.MongoRepository;
+import org.springframework.stereotype.Repository;
+
+@Repository
+public interface StockInfoRepository extends MongoRepository<StockInfo, String> {
+}

+ 53 - 0
src/main/java/com/danielbohry/stocks/service/StockInfoService.java

@@ -0,0 +1,53 @@
+package com.danielbohry.stocks.service;
+
+import com.danielbohry.stocks.client.InferenceClient;
+import com.danielbohry.stocks.domain.Quote;
+import com.danielbohry.stocks.domain.StockInfo;
+import com.danielbohry.stocks.repository.StockInfoRepository;
+import com.danielbohry.stocks.repository.StockRepository;
+import lombok.AllArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.scheduling.annotation.Async;
+import org.springframework.stereotype.Service;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+
+import static java.util.stream.Collectors.toSet;
+
+@Slf4j
+@Service
+@AllArgsConstructor
+public class StockInfoService {
+
+    private InferenceClient client;
+    private StockRepository stockRepository;
+    private StockInfoRepository infoRepository;
+
+    @Async("stockInfoExecutor")
+    public void generate() {
+        List<Quote> stocks = stockRepository.findAll();
+
+        Set<String> existingIds = infoRepository.findAllById(
+                stocks.stream()
+                    .map(Quote::getCode)
+                    .collect(toSet())
+            ).stream()
+            .map(StockInfo::getCode)
+            .collect(toSet());
+
+        stocks.stream()
+            .filter(stock -> !existingIds.contains(stock.getCode()))
+            .forEach(stock -> {
+                try {
+                    log.info("Generating stock info for {}", stock.getCode());
+                    Optional.ofNullable(client.infer(stock))
+                        .ifPresent(info -> infoRepository.save(info));
+                } catch (Exception e) {
+                    log.error("Failed to infer stock info for code: {}", stock.getCode(), e);
+                }
+            });
+    }
+
+}

+ 3 - 3
src/main/java/com/danielbohry/stocks/service/StockService.java

@@ -9,10 +9,8 @@ import org.springframework.stereotype.Service;
 
 import java.util.List;
 import java.util.Objects;
-import java.util.Set;
 
 import static java.util.Collections.emptyList;
-import static java.util.Collections.emptySet;
 
 @Slf4j
 @Service
@@ -21,11 +19,12 @@ public class StockService {
 
     private StockRepository repository;
 
-    @Cacheable("stockQuotes")
+    @Cacheable("allStockQuotes")
     public List<Quote> getAll() {
         return repository.findAll();
     }
 
+    @Cacheable(value = "stockQuotesQuery", key = "#query")
     public List<Quote> get(String query) {
         if (Objects.equals(query, "")) {
             return emptyList();
@@ -35,6 +34,7 @@ public class StockService {
         return repository.findLike(query);
     }
 
+    @Cacheable(value = "stockQuotes", key = "#code")
     public Quote getByCode(String code) {
         log.debug("Getting stock by code [{}]", code);
         return repository.findByCode(code);

+ 3 - 0
src/main/resources/application.yml

@@ -8,6 +8,9 @@ clients:
   exchange:
     url: ${exchange_provider:https://v6.exchangerate-api.com/v6}
     key: ${exchange_key:}
+  inference:
+    url: ${inference_client:}
+    model: ${inference_model:}
 
 auth:
   api: ${auth_api:}

+ 3 - 2
src/test/java/service/StockServiceTest.java

@@ -1,10 +1,12 @@
 package service;
 
 import com.danielbohry.stocks.App;
+import com.danielbohry.stocks.client.InferenceClient;
 import com.danielbohry.stocks.domain.Quote;
+import com.danielbohry.stocks.repository.StockInfoRepository;
 import com.danielbohry.stocks.repository.StockRepository;
 import com.danielbohry.stocks.service.StockService;
-import org.junit.jupiter.api.AfterEach;
+import org.checkerframework.checker.units.qual.A;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
@@ -13,7 +15,6 @@ import org.springframework.boot.test.context.SpringBootTest;
 import org.springframework.test.context.ContextConfiguration;
 
 import java.util.List;
-import java.util.Set;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;