diff --git a/core/src/main/java/apoc/periodic/Periodic.java b/core/src/main/java/apoc/periodic/Periodic.java index 151548f142..5717e76c4d 100644 --- a/core/src/main/java/apoc/periodic/Periodic.java +++ b/core/src/main/java/apoc/periodic/Periodic.java @@ -3,7 +3,6 @@ import apoc.Pools; import apoc.util.Util; import org.apache.commons.lang3.exception.ExceptionUtils; -import org.apache.commons.lang3.time.DateUtils; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.QueryStatistics; import org.neo4j.graphdb.Result; @@ -17,11 +16,6 @@ import org.neo4j.logging.Log; import org.neo4j.procedure.*; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.time.OffsetTime; -import java.time.temporal.ChronoUnit; import java.time.temporal.Temporal; import java.util.*; import java.util.concurrent.*; @@ -33,11 +27,14 @@ import java.util.regex.Pattern; import java.util.stream.Stream; +import static apoc.periodic.PeriodicUtils.getJobInfo; +import static apoc.periodic.PeriodicUtils.schedule; +import static org.neo4j.graphdb.QueryExecutionType.QueryType; +import static apoc.periodic.PeriodicUtils.submitJob; +import static apoc.periodic.PeriodicUtils.submitProc; import static apoc.util.Util.merge; public class Periodic { - - public static final String ERROR_DATE_BEFORE = "The provided date is before current date"; enum Planner {DEFAULT, COST, IDP, DP } @@ -175,45 +172,7 @@ public Stream cancel(@Name("name") String name) { @Description("apoc.periodic.submit('name',statement,params) - submit a one-off background statement; parameter 'params' is optional and can contain query parameters for Cypher statement") public Stream submit(@Name("name") String name, @Name("statement") String statement, @Name(value = "params", defaultValue = "{}") Map config) { validateQuery(statement); - Map params = (Map)config.getOrDefault("params", Collections.emptyMap()); - - final Temporal atTime = (Temporal) (config.get("atTime")); - - final Runnable task = () -> { - try { - db.executeTransactionally(statement, params); - } catch (Exception e) { - log.warn("in background task via submit", e); - throw new RuntimeException(e); - } - }; - - JobInfo info = atTime != null - ? getJobInfo(name, atTime, task, ScheduleType.DEFAULT) - : submit(name, task); - - return Stream.of(info); - } - - private JobInfo getJobInfo(String name, Temporal atTime, Runnable task, ScheduleType scheduleType) { - if (atTime instanceof LocalDate) { - atTime = ((LocalDate) atTime).atStartOfDay(); - } - final boolean isTime = atTime instanceof OffsetTime || atTime instanceof LocalTime; - Temporal now = isTime - ? LocalTime.now() - : LocalDateTime.now(); - - final long secPerDay = DateUtils.MILLIS_PER_DAY / 1000L; - long delay = now.until(atTime, ChronoUnit.SECONDS); - if (isTime && delay < 0) { - // we consider the day after - delay = delay + secPerDay; - } - if (delay < 0) { - throw new RuntimeException(ERROR_DATE_BEFORE); - } - return schedule(name, task, delay, secPerDay, scheduleType); + return submitProc(name, statement, config, db, log, pools); } @Procedure(mode = Mode.WRITE) @@ -227,9 +186,9 @@ public Stream repeat(@Name("name") String name, @Name("statement") Stri }; final JobInfo info; if (rateOrTime instanceof Long) { - info = schedule(name, runnable,0, (long) rateOrTime); + info = schedule(name, runnable,0, (long) rateOrTime, log, pools); } else if(rateOrTime instanceof Temporal) { - info = getJobInfo(name, (Temporal) rateOrTime, runnable, ScheduleType.FIXED_RATE); + info = getJobInfo(name, (Temporal) rateOrTime, runnable, log, pools, PeriodicUtils.ScheduleType.FIXED_RATE); } else { throw new RuntimeException("invalid type of rateOrTime parameter"); } @@ -238,77 +197,21 @@ public Stream repeat(@Name("name") String name, @Name("statement") Stri } private void validateQuery(String statement) { - Util.validateQuery(db, statement); + Util.validateQuery(db, statement, + Set.of(Mode.WRITE, Mode.READ, Mode.DEFAULT), + QueryType.READ_ONLY, QueryType.WRITE, QueryType.READ_WRITE); } @Procedure(mode = Mode.WRITE) @Description("apoc.periodic.countdown('name',statement,repeat-rate-in-seconds) submit a repeatedly-called background statement until it returns 0") public Stream countdown(@Name("name") String name, @Name("statement") String statement, @Name("rate") long rate) { validateQuery(statement); - JobInfo info = submit(name, new Countdown(name, statement, rate, log)); + JobInfo info = submitJob(name, new Countdown(name, statement, rate, log), log, pools); info.rate = rate; return Stream.of(info); } - /** - * Call from a procedure that gets a @Context GraphDatbaseAPI db; injected and provide that db to the runnable. - */ - public JobInfo submit(String name, Runnable task) { - JobInfo info = new JobInfo(name); - Future future = pools.getJobList().remove(info); - if (future != null && !future.isDone()) future.cancel(false); - Runnable wrappingTask = wrapTask(name, task, log); - Future newFuture = pools.getScheduledExecutorService().submit(wrappingTask); - pools.getJobList().put(info,newFuture); - return info; - } - - private enum ScheduleType { DEFAULT, FIXED_DELAY, FIXED_RATE } - - public JobInfo schedule(String name, Runnable task, long delay, long repeat) { - return schedule(name, task, delay, repeat, ScheduleType.FIXED_DELAY); - } - - /** - * Call from a procedure that gets a @Context GraphDatbaseAPI db; injected and provide that db to the runnable. - */ - public JobInfo schedule(String name, Runnable task, long delay, long repeat, ScheduleType isFixedDelay) { - JobInfo info = new JobInfo(name, delay, isFixedDelay.equals(ScheduleType.DEFAULT) ? 0 : repeat); - Future future = pools.getJobList().remove(info); - if (future != null && !future.isDone()) future.cancel(false); - - Runnable wrappingTask = wrapTask(name, task, log); - ScheduledFuture newFuture = getScheduledFuture(wrappingTask, delay, repeat, isFixedDelay); - pools.getJobList().put(info,newFuture); - return info; - } - - private ScheduledFuture getScheduledFuture(Runnable wrappingTask, long delay, long repeat, ScheduleType isFixedDelay) { - final ScheduledExecutorService service = pools.getScheduledExecutorService(); - final TimeUnit timeUnit = TimeUnit.SECONDS; - switch (isFixedDelay) { - case FIXED_DELAY: - return service.scheduleWithFixedDelay(wrappingTask, delay, repeat, timeUnit); - case FIXED_RATE: - return service.scheduleAtFixedRate(wrappingTask, delay, repeat, timeUnit); - default: - return service.schedule(wrappingTask, delay, timeUnit); - } - } - - private static Runnable wrapTask(String name, Runnable task, Log log) { - return () -> { - log.debug("Executing task " + name); - try { - task.run(); - } catch (Exception e) { - log.error("Error while executing task " + name + " because of the following exception (the task will be killed):", e); - throw e; - } - log.debug("Executed task " + name); - }; - } /** * Invoke cypherAction in batched transactions being fed from cypherIteration running in main thread @@ -520,7 +423,7 @@ public Countdown(String name, String statement, long rate, Log log) { @Override public void run() { if (Periodic.this.executeNumericResultStatement(statement, Collections.emptyMap()) > 0) { - pools.getScheduledExecutorService().schedule(() -> submit(name, this), rate, TimeUnit.SECONDS); + pools.getScheduledExecutorService().schedule(() -> submitJob(name, this, log, pools), rate, TimeUnit.SECONDS); } } } diff --git a/core/src/main/java/apoc/periodic/PeriodicUtils.java b/core/src/main/java/apoc/periodic/PeriodicUtils.java index c3abfeb0f5..ee18c544cf 100644 --- a/core/src/main/java/apoc/periodic/PeriodicUtils.java +++ b/core/src/main/java/apoc/periodic/PeriodicUtils.java @@ -2,6 +2,7 @@ import apoc.Pools; import apoc.util.Util; +import org.apache.commons.lang3.time.DateUtils; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.QueryStatistics; import org.neo4j.graphdb.Transaction; @@ -9,12 +10,22 @@ import org.neo4j.logging.Log; import org.neo4j.procedure.TerminationGuard; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetTime; +import java.time.temporal.ChronoUnit; +import java.time.temporal.Temporal; import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.LockSupport; import java.util.function.BiFunction; @@ -23,11 +34,16 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static apoc.periodic.Periodic.JobInfo; + public class PeriodicUtils { private PeriodicUtils() { } + public enum ScheduleType { DEFAULT, FIXED_DELAY, FIXED_RATE } + + public static final String ERROR_DATE_BEFORE = "The provided date is before current date"; public static Pair prepareInnerStatement(String cypherAction, BatchMode batchMode, List columns, String iteratorVariableName) { String names = columns.stream().map(Util::quote).collect(Collectors.joining("|")); @@ -116,6 +132,106 @@ public static Stream iterateAndExecuteBatchedInSeparateThre } return Stream.of(collector.getResult()); } + + public static Stream submitProc(String name, String statement, Map config, GraphDatabaseService db, Log log, Pools pools) { + Map params = (Map)config.getOrDefault("params", Collections.emptyMap()); + + final Temporal atTime = (Temporal) (config.get("atTime")); + + final Runnable task = () -> { + try { + db.executeTransactionally(statement, params); + } catch (Exception e) { + log.warn("in background task via submit", e); + throw new RuntimeException(e); + } + }; + + JobInfo info = atTime != null + ? getJobInfo(name, atTime, task, log, pools, ScheduleType.DEFAULT) + : submitJob(name, task, log, pools); + + return Stream.of(info); + } + + public static JobInfo getJobInfo(String name, Temporal atTime, Runnable task, Log log, Pools pools, ScheduleType scheduleType) { + if (atTime instanceof LocalDate) { + atTime = ((LocalDate) atTime).atStartOfDay(); + } + final boolean isTime = atTime instanceof OffsetTime || atTime instanceof LocalTime; + Temporal now = isTime + ? LocalTime.now() + : LocalDateTime.now(); + + final long secPerDay = DateUtils.MILLIS_PER_DAY / 1000L; + long delay = now.until(atTime, ChronoUnit.SECONDS); + if (isTime && delay < 0) { + // we consider the day after + delay = delay + secPerDay; + } + if (delay < 0) { + throw new RuntimeException(ERROR_DATE_BEFORE); + } + return schedule(name, task, delay, secPerDay, log, pools, scheduleType); + } + + /** + * Call from a procedure that gets a @Context GraphDatbaseAPI db; injected and provide that db to the runnable. + */ + public static JobInfo submitJob(String name, Runnable task, Log log, Pools pools) { + JobInfo info = new JobInfo(name); + Future future = pools.getJobList().remove(info); + if (future != null && !future.isDone()) future.cancel(false); + + Runnable wrappingTask = wrapTask(name, task, log); + Future newFuture = pools.getScheduledExecutorService().submit(wrappingTask); + pools.getJobList().put(info,newFuture); + return info; + } + + public static JobInfo schedule(String name, Runnable task, long delay, long repeat, Log log, Pools pools) { + return schedule(name, task, delay, repeat, log, pools, ScheduleType.FIXED_DELAY); + } + + /** + * Call from a procedure that gets a @Context GraphDatbaseAPI db; injected and provide that db to the runnable. + */ + public static JobInfo schedule(String name, Runnable task, long delay, long repeat, Log log, Pools pools, ScheduleType isFixedDelay) { + JobInfo info = new JobInfo(name, delay, isFixedDelay.equals(ScheduleType.DEFAULT) ? 0 : repeat); + Future future = pools.getJobList().remove(info); + if (future != null && !future.isDone()) future.cancel(false); + + Runnable wrappingTask = wrapTask(name, task, log); + ScheduledFuture newFuture = getScheduledFuture(wrappingTask, delay, repeat, pools, isFixedDelay); + pools.getJobList().put(info,newFuture); + return info; + } + + private static ScheduledFuture getScheduledFuture(Runnable wrappingTask, long delay, long repeat, Pools pools, ScheduleType isFixedDelay) { + final ScheduledExecutorService service = pools.getScheduledExecutorService(); + final TimeUnit timeUnit = TimeUnit.SECONDS; + switch (isFixedDelay) { + case FIXED_DELAY: + return service.scheduleWithFixedDelay(wrappingTask, delay, repeat, timeUnit); + case FIXED_RATE: + return service.scheduleAtFixedRate(wrappingTask, delay, repeat, timeUnit); + default: + return service.schedule(wrappingTask, delay, timeUnit); + } + } + + public static Runnable wrapTask(String name, Runnable task, Log log) { + return () -> { + log.debug("Executing task " + name); + try { + task.run(); + } catch (Exception e) { + log.error("Error while executing task " + name + " because of the following exception (the task will be killed):", e); + throw e; + } + log.debug("Executed task " + name); + }; + } } /* diff --git a/core/src/main/java/apoc/util/Util.java b/core/src/main/java/apoc/util/Util.java index 24971b0e88..70cb2033a8 100644 --- a/core/src/main/java/apoc/util/Util.java +++ b/core/src/main/java/apoc/util/Util.java @@ -31,7 +31,10 @@ import org.neo4j.internal.kernel.api.security.SecurityContext; import org.neo4j.kernel.internal.GraphDatabaseAPI; import org.neo4j.logging.Log; +import org.neo4j.graphdb.ExecutionPlanDescription; +import org.neo4j.graphdb.Result; import org.neo4j.logging.NullLog; +import org.neo4j.procedure.Mode; import org.neo4j.procedure.TerminationGuard; import org.neo4j.values.storable.CoordinateReferenceSystem; import org.neo4j.values.storable.PointValue; @@ -78,6 +81,7 @@ import java.util.Set; import java.util.Spliterator; import java.util.Spliterators; +import java.util.TreeSet; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -1008,15 +1012,66 @@ public static Set intersection(Collection a, Collection b) { } public static void validateQuery(GraphDatabaseService db, String statement, QueryExecutionType.QueryType... supportedQueryTypes) { - if (!isQueryValid(db, statement, supportedQueryTypes)) { - throw new RuntimeException("Supported query types for the operation are " + Arrays.toString(supportedQueryTypes)); - } + validateQuery(db, statement, Collections.emptySet(), supportedQueryTypes); + } + + public static void validateQuery(GraphDatabaseService db, String statement, Set supportedModes, QueryExecutionType.QueryType... supportedQueryTypes) { + db.executeTransactionally("EXPLAIN " + statement, Collections.emptyMap(), result -> { + + if (!isQueryTypeValid(result, supportedQueryTypes)) { + throw new RuntimeException("Supported query types for the operation are " + Arrays.toString(supportedQueryTypes)); + } + + if (!procsAreValid(db, supportedModes, result)) { + throw new RuntimeException("Supported inner procedure modes for the operation are " + new TreeSet<>(supportedModes)); + } + + return null; + }); } public static boolean isQueryValid(GraphDatabaseService db, String statement, QueryExecutionType.QueryType... supportedQueryTypes) { - return db.executeTransactionally("EXPLAIN " + statement, Collections.emptyMap(), result -> - supportedQueryTypes == null || supportedQueryTypes.length == 0 || Stream.of(supportedQueryTypes) - .anyMatch(sqt -> sqt.equals(result.getQueryExecutionType().queryType()))); + return db.executeTransactionally("EXPLAIN " + statement, Collections.emptyMap(), + res -> isQueryTypeValid(res, supportedQueryTypes)); + } + + private static boolean isQueryTypeValid(Result result, QueryExecutionType.QueryType[] supportedQueryTypes) { + return supportedQueryTypes == null || supportedQueryTypes.length == 0 || Stream.of(supportedQueryTypes) + .anyMatch(sqt -> sqt.equals(result.getQueryExecutionType().queryType())); + } + + private static boolean procsAreValid(GraphDatabaseService db, Set supportedModes, Result result) { + if (supportedModes != null && !supportedModes.isEmpty()) { + final ExecutionPlanDescription executionPlanDescription = result.getExecutionPlanDescription(); + // get procedures used in the query + Set queryProcNames = new HashSet<>(); + getAllQueryProcs(executionPlanDescription, queryProcNames); + + if (!queryProcNames.isEmpty()) { + final Set modes = supportedModes.stream().map(Mode::name).collect(Collectors.toSet()); + // check if sub-procedures have valid mode + final Set procNames = db.executeTransactionally("SHOW PROCEDURES YIELD name, mode where mode in $modes return name", + Map.of("modes", modes), + r -> Iterators.asSet(r.columnAs("name"))); + + return procNames.containsAll(queryProcNames); + } + } + + return true; + } + + public static void getAllQueryProcs(ExecutionPlanDescription executionPlanDescription, Set procs) { + executionPlanDescription.getChildren().forEach(i -> { + // if executionPlanDescription is a ProcedureCall + // we return proc. name from "Details" + // by extracting up to the first `(` char, e.g. apoc.schema.assert(null, null).... + if (i.getName().equals("ProcedureCall")) { + final String procName = ((String) i.getArguments().get("Details")).split("\\(")[0]; + procs.add(procName); + } + getAllQueryProcs(i, procs); + }); } /** diff --git a/core/src/test/java/apoc/periodic/PeriodicTest.java b/core/src/test/java/apoc/periodic/PeriodicTest.java index 88aa7875c7..c68ad32dce 100644 --- a/core/src/test/java/apoc/periodic/PeriodicTest.java +++ b/core/src/test/java/apoc/periodic/PeriodicTest.java @@ -1,5 +1,7 @@ package apoc.periodic; +import apoc.cypher.Cypher; +import apoc.schema.Schemas; import apoc.util.MapUtil; import apoc.util.TestUtil; import org.junit.Before; @@ -31,8 +33,8 @@ import java.util.stream.LongStream; import java.util.stream.Stream; -import static apoc.periodic.Periodic.ERROR_DATE_BEFORE; import static apoc.periodic.Periodic.applyPlanner; +import static apoc.periodic.PeriodicUtils.ERROR_DATE_BEFORE; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; import static apoc.util.Util.map; @@ -59,7 +61,7 @@ public class PeriodicTest { @Before public void initDb() throws Exception { - TestUtil.registerProcedure(db, Periodic.class); + TestUtil.registerProcedure(db, Periodic.class, Schemas.class, Cypher.class); db.executeTransactionally("call apoc.periodic.list() yield name call apoc.periodic.cancel(name) yield name as name2 return count(*)"); } @@ -85,6 +87,42 @@ public void testSubmitStatement() throws Exception { testCall(db, callList, (r) -> assertEquals(true, r.get("done"))); } + @Test + public void testSubmitWithCreateIndexSchemaOperation() { + String errMessage = "Supported query types for the operation are [READ_ONLY, WRITE, READ_WRITE]"; + testSchemaOperationCommon("CREATE INDEX periodicIdx FOR (n:Bar) ON (n.first_name, n.last_name)", errMessage); + } + + @Test + public void testSubmitWithSchemaProcedure() { + String errMessage = "Supported inner procedure modes for the operation are [READ, WRITE, DEFAULT]"; + + // built-in neo4j procedure + final String createCons = "CALL db.createUniquePropertyConstraint('uniqueConsName', ['Alpha', 'Beta'], ['foo', 'bar'], 'lucene-1.0')"; + testSchemaOperationCommon(createCons, errMessage); + + // apoc procedures + testSchemaOperationCommon("CALL apoc.schema.assert({}, {})", errMessage); + testSchemaOperationCommon("CALL apoc.cypher.runSchema('CREATE CONSTRAINT periodicIdx FOR (n:Bar) REQUIRE n.first_name IS UNIQUE', {})", errMessage); + + // inner schema procedure + final String innerSchema = "CALL { WITH 1 AS one CALL apoc.schema.assert({}, {}) YIELD key RETURN key } " + + "RETURN 1"; + testSchemaOperationCommon(innerSchema, errMessage); + } + + private void testSchemaOperationCommon(String query, String errMessage) { + try { + testCall(db, "CALL apoc.periodic.submit('subSchema', $query)", + Map.of("query", query), + (row) -> fail("Should fail because of unsupported schema operation")); + } catch (RuntimeException e) { + final String expected = "Failed to invoke procedure `apoc.periodic.submit`: " + + "Caused by: java.lang.RuntimeException: " + errMessage; + assertEquals(expected, e.getMessage()); + } + } + @Test public void testSubmitStatementAtTime() { String callList = "CALL apoc.periodic.list()";