feat: add generation source handling for task creation and updates
All checks were successful
Build And Publish Production Image / Build And Publish Production Image (push) Successful in 50s
All checks were successful
Build And Publish Production Image / Build And Publish Production Image (push) Successful in 50s
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.condado.newsletter.dto
|
||||
|
||||
import com.condado.newsletter.model.EntityTask
|
||||
import com.condado.newsletter.model.TaskGenerationSource
|
||||
import jakarta.validation.constraints.NotBlank
|
||||
import jakarta.validation.constraints.NotNull
|
||||
import java.time.LocalDateTime
|
||||
@@ -11,7 +12,8 @@ data class EntityTaskCreateDto(
|
||||
@field:NotBlank val name: String,
|
||||
val prompt: String,
|
||||
@field:NotBlank val scheduleCron: String,
|
||||
@field:NotBlank val emailLookback: String
|
||||
@field:NotBlank val emailLookback: String,
|
||||
val generationSource: TaskGenerationSource = TaskGenerationSource.LLAMA
|
||||
)
|
||||
|
||||
data class EntityTaskUpdateDto(
|
||||
@@ -19,7 +21,8 @@ data class EntityTaskUpdateDto(
|
||||
@field:NotBlank val name: String,
|
||||
@field:NotBlank val prompt: String,
|
||||
@field:NotBlank val scheduleCron: String,
|
||||
@field:NotBlank val emailLookback: String
|
||||
@field:NotBlank val emailLookback: String,
|
||||
val generationSource: TaskGenerationSource? = null
|
||||
)
|
||||
|
||||
data class EntityTaskResponseDto(
|
||||
@@ -29,6 +32,7 @@ data class EntityTaskResponseDto(
|
||||
val prompt: String,
|
||||
val scheduleCron: String,
|
||||
val emailLookback: String,
|
||||
val generationSource: TaskGenerationSource,
|
||||
val active: Boolean,
|
||||
val createdAt: LocalDateTime?
|
||||
) {
|
||||
@@ -41,6 +45,7 @@ data class EntityTaskResponseDto(
|
||||
prompt = task.prompt,
|
||||
scheduleCron = task.scheduleCron,
|
||||
emailLookback = task.emailLookback,
|
||||
generationSource = task.generationSource,
|
||||
active = task.active,
|
||||
createdAt = task.createdAt
|
||||
)
|
||||
|
||||
@@ -3,6 +3,8 @@ package com.condado.newsletter.model
|
||||
import jakarta.persistence.CascadeType
|
||||
import jakarta.persistence.Column
|
||||
import jakarta.persistence.Entity
|
||||
import jakarta.persistence.EnumType
|
||||
import jakarta.persistence.Enumerated
|
||||
import jakarta.persistence.FetchType
|
||||
import jakarta.persistence.GeneratedValue
|
||||
import jakarta.persistence.GenerationType
|
||||
@@ -37,6 +39,10 @@ class EntityTask(
|
||||
@Column(name = "email_lookback", nullable = false)
|
||||
val emailLookback: String,
|
||||
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Column(name = "generation_source", nullable = false)
|
||||
val generationSource: TaskGenerationSource = TaskGenerationSource.LLAMA,
|
||||
|
||||
@Column(nullable = false)
|
||||
val active: Boolean = true,
|
||||
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.condado.newsletter.model
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonCreator
|
||||
import com.fasterxml.jackson.annotation.JsonValue
|
||||
|
||||
enum class TaskGenerationSource(
|
||||
@get:JsonValue val value: String
|
||||
) {
|
||||
OPENAI("openai"),
|
||||
LLAMA("llama");
|
||||
|
||||
companion object {
|
||||
@JvmStatic
|
||||
@JsonCreator
|
||||
fun from(value: String): TaskGenerationSource =
|
||||
entries.firstOrNull { it.value.equals(value, ignoreCase = true) }
|
||||
?: throw IllegalArgumentException("Invalid generationSource: $value")
|
||||
}
|
||||
}
|
||||
@@ -47,6 +47,7 @@ class EntityTaskService(
|
||||
prompt = dto.prompt,
|
||||
scheduleCron = dto.scheduleCron,
|
||||
emailLookback = dto.emailLookback,
|
||||
generationSource = dto.generationSource,
|
||||
active = true
|
||||
)
|
||||
|
||||
@@ -66,6 +67,7 @@ class EntityTaskService(
|
||||
prompt = dto.prompt,
|
||||
scheduleCron = dto.scheduleCron,
|
||||
emailLookback = dto.emailLookback,
|
||||
generationSource = dto.generationSource ?: existing.generationSource,
|
||||
active = existing.active,
|
||||
createdAt = existing.createdAt
|
||||
).apply { id = existing.id }
|
||||
@@ -83,6 +85,7 @@ class EntityTaskService(
|
||||
prompt = existing.prompt,
|
||||
scheduleCron = existing.scheduleCron,
|
||||
emailLookback = existing.emailLookback,
|
||||
generationSource = existing.generationSource,
|
||||
active = false,
|
||||
createdAt = existing.createdAt
|
||||
).apply { id = existing.id }
|
||||
@@ -100,6 +103,7 @@ class EntityTaskService(
|
||||
prompt = existing.prompt,
|
||||
scheduleCron = existing.scheduleCron,
|
||||
emailLookback = existing.emailLookback,
|
||||
generationSource = existing.generationSource,
|
||||
active = true,
|
||||
createdAt = existing.createdAt
|
||||
).apply { id = existing.id }
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.condado.newsletter.service
|
||||
import com.condado.newsletter.dto.GeneratedMessageHistoryResponseDto
|
||||
import com.condado.newsletter.dto.TaskPreviewGenerateRequestDto
|
||||
import com.condado.newsletter.model.GeneratedMessageHistory
|
||||
import com.condado.newsletter.model.TaskGenerationSource
|
||||
import com.condado.newsletter.repository.EntityTaskRepository
|
||||
import com.condado.newsletter.repository.GeneratedMessageHistoryRepository
|
||||
import org.springframework.stereotype.Service
|
||||
@@ -16,7 +17,8 @@ import java.util.UUID
|
||||
class TaskGeneratedMessageService(
|
||||
private val generatedMessageHistoryRepository: GeneratedMessageHistoryRepository,
|
||||
private val entityTaskRepository: EntityTaskRepository,
|
||||
private val llamaPreviewService: LlamaPreviewService
|
||||
private val llamaPreviewService: LlamaPreviewService,
|
||||
private val aiService: AiService
|
||||
) {
|
||||
|
||||
/** Lists persisted generated messages for a task. */
|
||||
@@ -25,15 +27,19 @@ class TaskGeneratedMessageService(
|
||||
.findAllByTask_IdOrderByCreatedAtDesc(taskId)
|
||||
.map { GeneratedMessageHistoryResponseDto.from(it) }
|
||||
|
||||
/**
|
||||
* Generates a new message using local Llama, persists it, and returns it.
|
||||
*/
|
||||
/** Generates a new message with the task-selected provider, persists it, and returns it. */
|
||||
@Transactional
|
||||
fun generateAndSave(taskId: UUID, request: TaskPreviewGenerateRequestDto): GeneratedMessageHistoryResponseDto {
|
||||
val task = entityTaskRepository.findById(taskId)
|
||||
.orElseThrow { IllegalArgumentException("Task not found: $taskId") }
|
||||
val prompt = buildPrompt(request)
|
||||
val generatedContent = llamaPreviewService.generate(prompt)
|
||||
val generatedContent = when (task.generationSource) {
|
||||
TaskGenerationSource.LLAMA -> llamaPreviewService.generate(prompt)
|
||||
TaskGenerationSource.OPENAI -> {
|
||||
val parsed = aiService.generate(prompt)
|
||||
"SUBJECT: ${parsed.subject}\nBODY:\n${parsed.body}"
|
||||
}
|
||||
}
|
||||
val nextLabel = "Message #${generatedMessageHistoryRepository.countByTask_Id(taskId) + 1}"
|
||||
|
||||
val saved = generatedMessageHistoryRepository.save(
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.condado.newsletter.scheduler.EntityScheduler
|
||||
import com.condado.newsletter.service.JwtService
|
||||
import com.ninjasquad.springmockk.MockkBean
|
||||
import jakarta.servlet.http.Cookie
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.jupiter.api.AfterEach
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.springframework.beans.factory.annotation.Autowired
|
||||
@@ -15,6 +16,7 @@ import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMock
|
||||
import org.springframework.boot.test.context.SpringBootTest
|
||||
import org.springframework.http.MediaType
|
||||
import org.springframework.test.web.servlet.MockMvc
|
||||
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put
|
||||
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post
|
||||
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath
|
||||
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status
|
||||
@@ -56,7 +58,8 @@ class EntityTaskControllerTest {
|
||||
"name": "Morning Blast",
|
||||
"prompt": "",
|
||||
"scheduleCron": "0 8 * * 1-5",
|
||||
"emailLookback": "last_week"
|
||||
"emailLookback": "last_week",
|
||||
"generationSource": "openai"
|
||||
}
|
||||
""".trimIndent()
|
||||
|
||||
@@ -70,5 +73,96 @@ class EntityTaskControllerTest {
|
||||
.andExpect(jsonPath("$.entityId").value(entity.id.toString()))
|
||||
.andExpect(jsonPath("$.name").value("Morning Blast"))
|
||||
.andExpect(jsonPath("$.prompt").value(""))
|
||||
.andExpect(jsonPath("$.generationSource").value("openai"))
|
||||
|
||||
val persisted = entityTaskRepository.findAll().first()
|
||||
assertThat(persisted.generationSource.value).isEqualTo("openai")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun should_updateTaskAndPersistGenerationSource_when_validRequestProvided() {
|
||||
val entity = virtualEntityRepository.save(
|
||||
VirtualEntity(
|
||||
name = "Entity B",
|
||||
email = "entity-b@condado.com",
|
||||
jobTitle = "Ops"
|
||||
)
|
||||
)
|
||||
|
||||
val createdPayload = """
|
||||
{
|
||||
"entityId": "${entity.id}",
|
||||
"name": "Task One",
|
||||
"prompt": "Initial prompt",
|
||||
"scheduleCron": "0 8 * * 1-5",
|
||||
"emailLookback": "last_week",
|
||||
"generationSource": "openai"
|
||||
}
|
||||
""".trimIndent()
|
||||
|
||||
val createdResult = mockMvc.perform(
|
||||
post("/api/v1/tasks")
|
||||
.cookie(authCookie())
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content(createdPayload)
|
||||
)
|
||||
.andExpect(status().isCreated)
|
||||
.andReturn()
|
||||
|
||||
val taskId = com.jayway.jsonpath.JsonPath.read<String>(createdResult.response.contentAsString, "$.id")
|
||||
|
||||
val updatePayload = """
|
||||
{
|
||||
"entityId": "${entity.id}",
|
||||
"name": "Task One Updated",
|
||||
"prompt": "Updated prompt",
|
||||
"scheduleCron": "0 10 * * 1-5",
|
||||
"emailLookback": "last_day",
|
||||
"generationSource": "llama"
|
||||
}
|
||||
""".trimIndent()
|
||||
|
||||
mockMvc.perform(
|
||||
put("/api/v1/tasks/$taskId")
|
||||
.cookie(authCookie())
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content(updatePayload)
|
||||
)
|
||||
.andExpect(status().isOk)
|
||||
.andExpect(jsonPath("$.name").value("Task One Updated"))
|
||||
.andExpect(jsonPath("$.generationSource").value("llama"))
|
||||
|
||||
val persisted = entityTaskRepository.findById(java.util.UUID.fromString(taskId)).orElseThrow()
|
||||
assertThat(persisted.generationSource.value).isEqualTo("llama")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun should_returnBadRequest_when_generationSourceIsInvalid() {
|
||||
val entity = virtualEntityRepository.save(
|
||||
VirtualEntity(
|
||||
name = "Entity C",
|
||||
email = "entity-c@condado.com",
|
||||
jobTitle = "Ops"
|
||||
)
|
||||
)
|
||||
|
||||
val payload = """
|
||||
{
|
||||
"entityId": "${entity.id}",
|
||||
"name": "Morning Blast",
|
||||
"prompt": "Prompt",
|
||||
"scheduleCron": "0 8 * * 1-5",
|
||||
"emailLookback": "last_week",
|
||||
"generationSource": "invalid-provider"
|
||||
}
|
||||
""".trimIndent()
|
||||
|
||||
mockMvc.perform(
|
||||
post("/api/v1/tasks")
|
||||
.cookie(authCookie())
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content(payload)
|
||||
)
|
||||
.andExpect(status().isBadRequest)
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import com.condado.newsletter.dto.TaskPreviewGenerateRequestDto
|
||||
import com.condado.newsletter.dto.TaskPreviewTaskDto
|
||||
import com.condado.newsletter.model.EntityTask
|
||||
import com.condado.newsletter.model.GeneratedMessageHistory
|
||||
import com.condado.newsletter.model.ParsedAiResponse
|
||||
import com.condado.newsletter.model.TaskGenerationSource
|
||||
import com.condado.newsletter.model.VirtualEntity
|
||||
import com.condado.newsletter.repository.EntityTaskRepository
|
||||
import com.condado.newsletter.repository.GeneratedMessageHistoryRepository
|
||||
@@ -21,15 +23,17 @@ class TaskGeneratedMessageServiceTest {
|
||||
private val generatedMessageHistoryRepository: GeneratedMessageHistoryRepository = mockk()
|
||||
private val entityTaskRepository: EntityTaskRepository = mockk()
|
||||
private val llamaPreviewService: LlamaPreviewService = mockk()
|
||||
private val aiService: AiService = mockk()
|
||||
|
||||
private val service = TaskGeneratedMessageService(
|
||||
generatedMessageHistoryRepository = generatedMessageHistoryRepository,
|
||||
entityTaskRepository = entityTaskRepository,
|
||||
llamaPreviewService = llamaPreviewService
|
||||
llamaPreviewService = llamaPreviewService,
|
||||
aiService = aiService
|
||||
)
|
||||
|
||||
@Test
|
||||
fun should_generateAndPersistMessage_when_generateAndSaveCalled() {
|
||||
fun should_useLlamaProvider_when_taskGenerationSourceIsLlama() {
|
||||
val taskId = UUID.randomUUID()
|
||||
val entity = VirtualEntity(name = "Entity", email = "e@x.com", jobTitle = "Ops").apply { id = UUID.randomUUID() }
|
||||
val task = EntityTask(
|
||||
@@ -37,7 +41,8 @@ class TaskGeneratedMessageServiceTest {
|
||||
name = "Task",
|
||||
prompt = "Prompt",
|
||||
scheduleCron = "0 9 * * 1",
|
||||
emailLookback = "last_week"
|
||||
emailLookback = "last_week",
|
||||
generationSource = TaskGenerationSource.LLAMA
|
||||
).apply { id = taskId }
|
||||
val captured = slot<GeneratedMessageHistory>()
|
||||
|
||||
@@ -59,9 +64,40 @@ class TaskGeneratedMessageServiceTest {
|
||||
assertThat(captured.captured.task.id).isEqualTo(taskId)
|
||||
|
||||
verify(exactly = 1) { llamaPreviewService.generate(any()) }
|
||||
verify(exactly = 0) { aiService.generate(any()) }
|
||||
verify(exactly = 1) { generatedMessageHistoryRepository.save(any()) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun should_useOpenAiProvider_when_taskGenerationSourceIsOpenai() {
|
||||
val taskId = UUID.randomUUID()
|
||||
val entity = VirtualEntity(name = "Entity", email = "e@x.com", jobTitle = "Ops").apply { id = UUID.randomUUID() }
|
||||
val task = EntityTask(
|
||||
virtualEntity = entity,
|
||||
name = "Task",
|
||||
prompt = "Prompt",
|
||||
scheduleCron = "0 9 * * 1",
|
||||
emailLookback = "last_week",
|
||||
generationSource = TaskGenerationSource.OPENAI
|
||||
).apply { id = taskId }
|
||||
val captured = slot<GeneratedMessageHistory>()
|
||||
|
||||
every { aiService.generate(any()) } returns ParsedAiResponse(subject = "Open Subject", body = "Open Body")
|
||||
every { entityTaskRepository.findById(taskId) } returns java.util.Optional.of(task)
|
||||
every { generatedMessageHistoryRepository.countByTask_Id(taskId) } returns 0
|
||||
every { generatedMessageHistoryRepository.save(capture(captured)) } answers {
|
||||
captured.captured.apply {
|
||||
id = UUID.fromString("00000000-0000-0000-0000-000000000001")
|
||||
}
|
||||
}
|
||||
|
||||
val response = service.generateAndSave(taskId, sampleRequest())
|
||||
|
||||
assertThat(response.content).isEqualTo("SUBJECT: Open Subject\nBODY:\nOpen Body")
|
||||
verify(exactly = 1) { aiService.generate(any()) }
|
||||
verify(exactly = 0) { llamaPreviewService.generate(any()) }
|
||||
}
|
||||
|
||||
private fun sampleRequest() = TaskPreviewGenerateRequestDto(
|
||||
entity = TaskPreviewEntityDto(
|
||||
id = UUID.randomUUID().toString(),
|
||||
|
||||
Reference in New Issue
Block a user