Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.a2a.extras.pushnotificationconfigstore.database.jpa;

import java.util.ArrayList;
import java.util.Collections;
import jakarta.persistence.TypedQuery;
import java.time.Instant;
import java.util.List;

import jakarta.annotation.Priority;
Expand All @@ -17,6 +17,7 @@
import io.a2a.spec.ListTaskPushNotificationConfigResult;
import io.a2a.spec.PushNotificationConfig;
import io.a2a.spec.TaskPushNotificationConfig;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -26,7 +27,9 @@
public class JpaDatabasePushNotificationConfigStore implements PushNotificationConfigStore {

private static final Logger LOGGER = LoggerFactory.getLogger(JpaDatabasePushNotificationConfigStore.class);


private static final Instant NULL_TIMESTAMP_SENTINEL = Instant.EPOCH;

@PersistenceContext(unitName = "a2a-java")
EntityManager em;

Expand All @@ -36,6 +39,8 @@ public PushNotificationConfig setInfo(String taskId, PushNotificationConfig noti
// Ensure config has an ID - default to taskId if not provided (mirroring InMemoryPushNotificationConfigStore behavior)
PushNotificationConfig.Builder builder = PushNotificationConfig.builder(notificationConfig);
if (notificationConfig.id() == null || notificationConfig.id().isEmpty()) {
// This means the taskId and configId are same. This will not allow having multiple configs for a single Task.
// The configId is a required field in the spec and should not be empty
Comment on lines +42 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment on lines 42-43 states that configId is a required field in the spec and should not be empty. However, the code defaults notificationConfig.id() to taskId if it's null or empty. This implies that if a client doesn't provide a unique id for a PushNotificationConfig, it will be assigned the taskId as its configId. If the intention is to allow multiple push notification configurations for a single task, each requiring a unique configId, this defaulting behavior could prevent that by creating duplicate configIds (all equal to taskId). Please clarify if this behavior is intentional for a default configuration or if configId should always be explicitly provided and unique.

builder.id(taskId);
}
notificationConfig = builder.build();
Expand Down Expand Up @@ -72,15 +77,61 @@ public PushNotificationConfig setInfo(String taskId, PushNotificationConfig noti
@Override
public ListTaskPushNotificationConfigResult getInfo(ListTaskPushNotificationConfigParams params) {
String taskId = params.id();
LOGGER.debug("Retrieving PushNotificationConfigs for Task '{}'", taskId);
LOGGER.debug("Retrieving PushNotificationConfigs for Task '{}' with params: pageSize={}, pageToken={}",
taskId, params.pageSize(), params.pageToken());
try {
List<JpaPushNotificationConfig> jpaConfigs = em.createQuery(
"SELECT c FROM JpaPushNotificationConfig c WHERE c.id.taskId = :taskId",
JpaPushNotificationConfig.class)
.setParameter("taskId", taskId)
.getResultList();
StringBuilder queryBuilder = new StringBuilder("SELECT c FROM JpaPushNotificationConfig c WHERE c.id.taskId = :taskId");

if (params.pageToken() != null && !params.pageToken().isEmpty()) {
String[] tokenParts = params.pageToken().split(":", 2);
if (tokenParts.length == 2) {
// Keyset pagination: get tasks where timestamp < tokenTimestamp OR (timestamp = tokenTimestamp AND id > tokenId)
// All tasks have timestamps (TaskStatus canonical constructor ensures this)
queryBuilder.append(" AND (COALESCE(c.createdAt, :nullSentinel) < :tokenTimestamp OR (COALESCE(c.createdAt, :nullSentinel) = :tokenTimestamp AND c.id.configId > :tokenId))");
} else {
// Based on the comments in the test case, if the pageToken is invalid start from the beginning.
}
Comment on lines +92 to +93
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment on line 92 suggests that if the pageToken is invalid, the system should "start from the beginning." However, the current implementation for tokenParts.length != 2 (i.e., the pageToken is not in the expected timestamp:id format) silently ignores the pageToken and proceeds to fetch the first page. For better clarity and consistency with the NumberFormatException handling (lines 112-116), it would be more robust to explicitly throw an InvalidParamsError when the pageToken is malformed or does not adhere to the expected structure. This provides clearer feedback to the client about incorrect usage.

                // Keyset pagination: get tasks where timestamp < tokenTimestamp OR (timestamp = tokenTimestamp AND id > tokenId)
                // All tasks have timestamps (TaskStatus canonical constructor ensures this)
                queryBuilder.append(" AND (COALESCE(c.createdAt, :nullSentinel) < :tokenTimestamp OR (COALESCE(c.createdAt, :nullSentinel) = :tokenTimestamp AND c.id.configId > :tokenId))");
              } else {
                throw new io.a2a.spec.InvalidParamsError(null,
                    "Invalid pageToken format: pageToken must be in 'timestamp_millis:configId' format", null);
              }
            }

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ehsavoie any suggestion on this about the default behavior ? I see many common public APIs ignore pagination token if they are invalid ?

}

queryBuilder.append(" ORDER BY COALESCE(c.createdAt, :nullSentinel) DESC, c.id.configId ASC");

TypedQuery<JpaPushNotificationConfig> query = em.createQuery(queryBuilder.toString(), JpaPushNotificationConfig.class);
query.setParameter("taskId", taskId);
query.setParameter("nullSentinel", NULL_TIMESTAMP_SENTINEL);

if (params.pageToken() != null && !params.pageToken().isEmpty()) {
String[] tokenParts = params.pageToken().split(":", 2);
if (tokenParts.length == 2) {
try {
long timestampMillis = Long.parseLong(tokenParts[0]);
String tokenId = tokenParts[1];

Instant tokenTimestamp = Instant.ofEpochMilli(timestampMillis);
query.setParameter("tokenTimestamp", tokenTimestamp);
query.setParameter("tokenId", tokenId);
} catch (NumberFormatException e) {
// Malformed timestamp in pageToken
throw new io.a2a.spec.InvalidParamsError(null,
"Invalid pageToken format: timestamp must be numeric milliseconds", null);
}
}
}

int pageSize = params.getEffectivePageSize();
query.setMaxResults(pageSize + 1);
List<JpaPushNotificationConfig> jpaConfigsPage = query.getResultList();

String nextPageToken = null;
if (jpaConfigsPage.size() > pageSize) {
// There are more results than the page size, and in this case, a nextToken should be created with the last item.
// Format: "timestamp_millis:taskId" for keyset pagination
jpaConfigsPage = jpaConfigsPage.subList(0, pageSize);
JpaPushNotificationConfig lastConfig = jpaConfigsPage.get(jpaConfigsPage.size() - 1);
Instant timestamp = lastConfig.getCreatedAt() != null ? lastConfig.getCreatedAt() : NULL_TIMESTAMP_SENTINEL;
nextPageToken = timestamp.toEpochMilli() + ":" + lastConfig.getId().getConfigId();
}

List<PushNotificationConfig> configs = jpaConfigs.stream()
List<PushNotificationConfig> configs = jpaConfigsPage.stream()
.map(jpaConfig -> {
try {
return jpaConfig.getConfig();
Expand All @@ -95,57 +146,17 @@ public ListTaskPushNotificationConfigResult getInfo(ListTaskPushNotificationConf

LOGGER.debug("Successfully retrieved {} PushNotificationConfigs for Task '{}'", configs.size(), taskId);

// Handle pagination
if (configs.isEmpty()) {
return new ListTaskPushNotificationConfigResult(Collections.emptyList());
}

if (params.pageSize() <= 0) {
return new ListTaskPushNotificationConfigResult(convertPushNotificationConfig(configs, params), null);
}

// Apply pageToken filtering if provided
List<PushNotificationConfig> paginatedConfigs = configs;
if (params.pageToken() != null && !params.pageToken().isBlank()) {
int index = findFirstIndex(configs, params.pageToken());
if (index < configs.size()) {
paginatedConfigs = configs.subList(index, configs.size());
}
}

// Apply page size limit
if (paginatedConfigs.size() <= params.pageSize()) {
return new ListTaskPushNotificationConfigResult(convertPushNotificationConfig(paginatedConfigs, params), null);
}
List<TaskPushNotificationConfig> taskPushNotificationConfigs = configs.stream()
.map(config -> new TaskPushNotificationConfig(params.id(), config, params.tenant()))
.collect(Collectors.toList());

String nextToken = paginatedConfigs.get(params.pageSize()).token();
return new ListTaskPushNotificationConfigResult(
convertPushNotificationConfig(paginatedConfigs.subList(0, params.pageSize()), params),
nextToken);
return new ListTaskPushNotificationConfigResult(taskPushNotificationConfigs, nextPageToken);
} catch (Exception e) {
LOGGER.error("Failed to retrieve PushNotificationConfigs for Task '{}'", taskId, e);
throw e;
}
}

private int findFirstIndex(List<PushNotificationConfig> configs, String token) {
for (int i = 0; i < configs.size(); i++) {
if (token.equals(configs.get(i).token())) {
return i;
}
}
return configs.size();
}

private List<TaskPushNotificationConfig> convertPushNotificationConfig(List<PushNotificationConfig> pushNotificationConfigList, ListTaskPushNotificationConfigParams params) {
List<TaskPushNotificationConfig> taskPushNotificationConfigList = new ArrayList<>(pushNotificationConfigList.size());
for (PushNotificationConfig pushNotificationConfig : pushNotificationConfigList) {
TaskPushNotificationConfig taskPushNotificationConfig = new TaskPushNotificationConfig(params.id(), pushNotificationConfig, params.tenant());
taskPushNotificationConfigList.add(taskPushNotificationConfig);
}
return taskPushNotificationConfigList;
}

@Transactional
@Override
public void deleteInfo(String taskId, String configId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import jakarta.persistence.Column;
import jakarta.persistence.EmbeddedId;
import jakarta.persistence.Entity;
import jakarta.persistence.PrePersist;
import jakarta.persistence.Table;
import jakarta.persistence.Transient;

import io.a2a.jsonrpc.common.json.JsonProcessingException;
import io.a2a.jsonrpc.common.json.JsonUtil;
import io.a2a.spec.PushNotificationConfig;
import java.time.Instant;

@Entity
@Table(name = "a2a_push_notification_configs")
Expand All @@ -19,6 +21,9 @@ public class JpaPushNotificationConfig {
@Column(name = "task_data", columnDefinition = "TEXT", nullable = false)
private String configJson;

@Column(name = "created_at")
private Instant createdAt;

@Transient
private PushNotificationConfig config;

Expand All @@ -31,6 +36,12 @@ public JpaPushNotificationConfig(TaskConfigId id, String configJson) {
this.configJson = configJson;
}

@PrePersist
protected void onCreate() {
if (createdAt == null) {
createdAt = Instant.now();
}
}

public TaskConfigId getId() {
return id;
Expand Down Expand Up @@ -60,6 +71,14 @@ public void setConfig(PushNotificationConfig config) throws JsonProcessingExcept
this.config = config;
}

public Instant getCreatedAt() {
return createdAt;
}

public void setCreatedAt(Instant createdAt) {
this.createdAt = createdAt;
}

static JpaPushNotificationConfig createFromConfig(String taskId, PushNotificationConfig config) throws JsonProcessingException {
String json = JsonUtil.toJson(config);
JpaPushNotificationConfig jpaPushNotificationConfig =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,7 @@ private PushNotificationConfig createSamplePushConfig(String url, String configI
public void testPaginationWithPageSize() {
String taskId = "task_pagination_" + System.currentTimeMillis();
// Create 5 configs
for (int i = 0; i < 5; i++) {
PushNotificationConfig config = createSamplePushConfig(
"http://url" + i + ".com/callback", "cfg" + i, "token" + i);
pushNotificationConfigStore.setInfo(taskId, config);
}

createSamples(taskId, 5);
// Request first page with pageSize=2
ListTaskPushNotificationConfigParams params = new ListTaskPushNotificationConfigParams(taskId, 2, "", "");
ListTaskPushNotificationConfigResult result = pushNotificationConfigStore.getInfo(params);
Expand Down Expand Up @@ -251,11 +246,13 @@ public void testPaginationWithPageToken() {
}

// Also verify the pages are sequential (first page ends before second page starts)
// Since configs are created in order, we can verify the IDs
assertEquals("cfg0", firstPageIds.get(0));
assertEquals("cfg1", firstPageIds.get(1));
// Since configs are created in order, we can verify the IDs.
// There is no spec about pagination for PushNotifications, hence following the Task List
// behavior by which recent notifications are returned first
assertEquals("cfg4", firstPageIds.get(0));
assertEquals("cfg3", firstPageIds.get(1));
assertEquals("cfg2", secondPageIds.get(0));
assertEquals("cfg3", secondPageIds.get(1));
assertEquals("cfg1", secondPageIds.get(1));
}

@Test
Expand Down Expand Up @@ -410,11 +407,20 @@ public void testPaginationFullIteration() {
}

private void createSamples(String taskId, int size) {
// Create 7 configs
// Create configs with slight delays to ensure unique timestamps for deterministic ordering
for (int i = 0; i < size; i++) {
PushNotificationConfig config = createSamplePushConfig(
"http://url" + i + ".com/callback", "cfg" + i, "token" + i);
pushNotificationConfigStore.setInfo(taskId, config);

// Sleep briefly to ensure each config gets a unique timestamp
// This prevents non-deterministic ordering in pagination tests
try {
Thread.sleep(2);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Interrupted while creating test samples", e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ quarkus.datasource.password=
quarkus.hibernate-orm.database.generation=drop-and-create
quarkus.hibernate-orm.log.sql=true
quarkus.hibernate-orm.log.format-sql=true

# Transaction timeout (set to 30 minutes for debugging - 1800 seconds)
# quarkus.transaction-manager.default-transaction-timeout=1800s
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import static io.a2a.client.http.A2AHttpClient.CONTENT_TYPE;
import static io.a2a.common.A2AHeaders.X_A2A_NOTIFICATION_TOKEN;

import io.a2a.spec.TaskPushNotificationConfig;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

import io.a2a.client.http.A2AHttpClient;
import io.a2a.client.http.JdkA2AHttpClient;
Expand All @@ -26,6 +29,7 @@
public class BasePushNotificationSender implements PushNotificationSender {

private static final Logger LOGGER = LoggerFactory.getLogger(BasePushNotificationSender.class);
public static final int DEFAULT_PAGE_SIZE = 100;

private final A2AHttpClient httpClient;
private final PushNotificationConfigStore configStore;
Expand All @@ -43,12 +47,18 @@ public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHt

@Override
public void sendNotification(Task task) {
ListTaskPushNotificationConfigResult pushConfigs = configStore.getInfo(new ListTaskPushNotificationConfigParams(task.id()));
if (pushConfigs == null || pushConfigs.isEmpty()) {
return;
}

List<CompletableFuture<Boolean>> dispatchResults = pushConfigs.configs()
List<TaskPushNotificationConfig> configs = new ArrayList<>();
String nextPageToken = null;
do {
ListTaskPushNotificationConfigResult pageResult = configStore.getInfo(new ListTaskPushNotificationConfigParams(task.id(),
DEFAULT_PAGE_SIZE, nextPageToken, ""));
if (!pageResult.configs().isEmpty()) {
configs.addAll(pageResult.configs());
}
nextPageToken = pageResult.nextPageToken();
} while (nextPageToken != null);

List<CompletableFuture<Boolean>> dispatchResults = configs
.stream()
.map(pushConfig -> dispatch(task, pushConfig.pushNotificationConfig()))
.toList();
Expand All @@ -57,7 +67,7 @@ public void sendNotification(Task task) {
.allMatch(CompletableFuture::join));
try {
boolean allSent = dispatchResult.get();
if (! allSent) {
if (!allSent) {
LOGGER.warn("Some push notifications failed to send for taskId: " + task.id());
}
} catch (InterruptedException | ExecutionException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,16 @@ public record ListTaskPushNotificationConfigParams(String id, int pageSize, Stri
public ListTaskPushNotificationConfigParams(String id) {
this(id, 0, "", "");
}

/**
* Validates and returns the effective page size (between 1 and 100, defaults to 100).
*
* @return the effective page size
*/
public int getEffectivePageSize() {
if (pageSize <= 0 || pageSize > 100) {
return 100;
}
return pageSize;
}
}