Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it compliant with testkit #7

Open
wants to merge 2 commits into
base: 5.4/home-db-res
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
let name
let address

if (database == null) {
database = this._homeDbCache.get({ impersonatedUser, auth })
}

const context = { database: database || DEFAULT_DB_NAME }
const context = this._createContext({ database, auth, impersonatedUser, onDatabaseNameResolved })

const databaseSpecificErrorHandler = new ConnectionErrorHandler(
SESSION_EXPIRED,
Expand All @@ -164,7 +160,9 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
auth,
onDatabaseNameResolved: (databaseName) => {
context.database = context.database || databaseName
this._homeDbCache.set({ impersonatedUser, auth, databaseName })
if (context.homeDatabaseResolution) {
this._homeDbCache.set({ impersonatedUser, auth, databaseName })
}
if (onDatabaseNameResolved) {
onDatabaseNameResolved(databaseName)
}
Expand Down Expand Up @@ -286,13 +284,16 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
return this._verifyAuthentication({
auth,
getAddress: async () => {
const context = { database: database || DEFAULT_DB_NAME }
const context = this._createContext({ database, auth })

const routingTable = await this._freshRoutingTable({
accessMode,
database: context.database,
auth,
onDatabaseNameResolved: (databaseName) => {
if (context.homeDatabaseResolution) {
this._homeDbCache.set({ auth, databaseName })
}
context.database = context.database || databaseName
}
})
Expand All @@ -312,13 +313,16 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
}

async verifyConnectivityAndGetServerInfo ({ database, accessMode }) {
const context = { database: database || DEFAULT_DB_NAME }
const context = this._createContext({ database })

const routingTable = await this._freshRoutingTable({
accessMode,
database: context.database,
onDatabaseNameResolved: (databaseName) => {
context.database = context.database || databaseName
if (context.homeDatabaseResolution) {
this._homeDbCache.set({ databaseName })
}
}
})

Expand Down Expand Up @@ -675,6 +679,22 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
this._log.info(`Updated routing table ${newRoutingTable}`)
}

_createContext ({ database, auth, impersonatedUser, onDatabaseNameResolved }) {
const inputDatabase = database || DEFAULT_DB_NAME
return {
database: inputDatabase || this._resolveDatabaseNameFromCache({ impersonatedUser, auth, onDatabaseNameResolved }),
homeDatabaseResolution: inputDatabase === DEFAULT_DB_NAME
}
}

_resolveDatabaseNameFromCache ({ impersonatedUser, auth, onDatabaseNameResolved }) {
const database = this._homeDbCache.get({ impersonatedUser, auth })
if (database != null && onDatabaseNameResolved != null) {
onDatabaseNameResolved(database)
}
return database
}

static _forgetRouter (routingTable, routersArray, routerIndex) {
const address = routersArray[routerIndex]
if (routingTable && address) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@
* limitations under the License.
*/

const DEFAULT_KEY = -1

export default class HomeDBCache {
constructor ({ maxHomeDatabaseDelay }) {
this._disabled = maxHomeDatabaseDelay === 0
this._maxHomeDatabaseDelay = maxHomeDatabaseDelay || 5000
this._cache = new Map()
}

set ({ impersonatedUser, auth, databaseName }) {
if (this._disabled) {
return null
}
if (databaseName == null) {
return null
}
Expand All @@ -32,17 +38,20 @@ export default class HomeDBCache {
let key = impersonatedUser || auth

if (key == null) {
key = 'null' // This is for when auth is turned off basically
key = DEFAULT_KEY // This is for when auth is turned off basically
}

this._cache.set(key, { databaseName: databaseName, insertTime: Date.now() })
}
}

get ({ impersonatedUser, auth }) {
if (this._disabled) {
return null
}
let key = impersonatedUser || auth
if (key == null) {
key = 'null' // This is for when auth is turned off basically
key = DEFAULT_KEY // This is for when auth is turned off basically
}

const dbAndCreatedTime = this._cache.get(key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2745,38 +2745,7 @@ describe.each([
expect(pool.has(server1)).toBeTruthy()
})

it.each(usersDataSet)('should call onDatabaseNameResolved with the resolved db acquiring home db [user=%s] then do not call it after calling the cache instead', async (user) => {
const pool = newPool()
const connectionProvider = newRoutingConnectionProvider(
[],
pool,
{
null: {
'server-non-existing-seed-router:7687': newRoutingTableWithUser({
database: null,
routers: [server1, server2, server3],
readers: [server1, server2],
writers: [server3],
user,
routingTableDatabase: 'homedb'
})
}
}
)
const onDatabaseNameResolved = jest.fn()

// Acquire connection once to set up the cache
await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, onDatabaseNameResolved })
expect(onDatabaseNameResolved).toHaveBeenCalledWith('homedb')

const onDatabaseNameResolvedUnCalled = jest.fn()

// Acquire connection again and that should hit the cache meaning onDatabaseNameResoloved will not be hit as it is retrieved from the cache
await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, onDatabaseNameResolved: onDatabaseNameResolvedUnCalled })
expect(onDatabaseNameResolvedUnCalled).not.toHaveBeenCalled()
})

it.each(usersDataSet)('should call onDatabaseNameResolved twice after clearing the home db cache [user=%s]', async (user) => {
it.each(usersDataSet)('should call onDatabaseNameResolved always after resolve the database (independent if it comes from cached or server) [user=%s]', async (user) => {
const pool = newPool()
const connectionProvider = newRoutingConnectionProvider(
[],
Expand All @@ -2800,18 +2769,16 @@ describe.each([
await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, onDatabaseNameResolved: onDatabaseNameResolved })
expect(onDatabaseNameResolved).toHaveBeenCalledWith('homedb')

const onDatabaseNameResolvedUnCalled = jest.fn()

// Acquire connection again and that should hit the cache meaning onDatabaseNameResoloved will not be hit as it is retrieved from the cache
await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, onDatabaseNameResolved: onDatabaseNameResolvedUnCalled })
expect(onDatabaseNameResolvedUnCalled).not.toHaveBeenCalled()
// Acquire connection again and that should hit the cache meaning onDatabaseNameResolved should be hit with cache information
// since the session needs to have this information for next transactions
await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, onDatabaseNameResolved: onDatabaseNameResolved })
expect(onDatabaseNameResolved).toBeCalledTimes(2)

connectionProvider.forceHomeDbResolution()

const onDatabaseNameResolved2 = jest.fn()
// Acquire connection again and that should not hit the cache meaning onDatabaseNameResolved will be hit again as the cache was removed
await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, onDatabaseNameResolved: onDatabaseNameResolved2 })
expect(onDatabaseNameResolved2).toHaveBeenCalledWith('homedb')
await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, onDatabaseNameResolved: onDatabaseNameResolved })
expect(onDatabaseNameResolved).toBeCalledTimes(3)
})

it.each(usersDataSet)('should call onDatabaseNameResolved with the resolved db acquiring named db [user=%s]', async (user) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
let name
let address

if (database == null) {
database = this._homeDbCache.get({ impersonatedUser, auth })
}

const context = { database: database || DEFAULT_DB_NAME }
const context = this._createContext({ database, auth, impersonatedUser, onDatabaseNameResolved })

const databaseSpecificErrorHandler = new ConnectionErrorHandler(
SESSION_EXPIRED,
Expand All @@ -164,7 +160,9 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
auth,
onDatabaseNameResolved: (databaseName) => {
context.database = context.database || databaseName
this._homeDbCache.set({ impersonatedUser, auth, databaseName })
if (context.homeDatabaseResolution) {
this._homeDbCache.set({ impersonatedUser, auth, databaseName })
}
if (onDatabaseNameResolved) {
onDatabaseNameResolved(databaseName)
}
Expand Down Expand Up @@ -286,13 +284,16 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
return this._verifyAuthentication({
auth,
getAddress: async () => {
const context = { database: database || DEFAULT_DB_NAME }
const context = this._createContext({ database, auth })

const routingTable = await this._freshRoutingTable({
accessMode,
database: context.database,
auth,
onDatabaseNameResolved: (databaseName) => {
if (context.homeDatabaseResolution) {
this._homeDbCache.set({ auth, databaseName })
}
context.database = context.database || databaseName
}
})
Expand All @@ -312,13 +313,16 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
}

async verifyConnectivityAndGetServerInfo ({ database, accessMode }) {
const context = { database: database || DEFAULT_DB_NAME }
const context = this._createContext({ database })

const routingTable = await this._freshRoutingTable({
accessMode,
database: context.database,
onDatabaseNameResolved: (databaseName) => {
context.database = context.database || databaseName
if (context.homeDatabaseResolution) {
this._homeDbCache.set({ databaseName })
}
}
})

Expand Down Expand Up @@ -675,6 +679,22 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
this._log.info(`Updated routing table ${newRoutingTable}`)
}

_createContext ({ database, auth, impersonatedUser, onDatabaseNameResolved }) {
const inputDatabase = database || DEFAULT_DB_NAME
return {
database: inputDatabase || this._resolveDatabaseNameFromCache({ impersonatedUser, auth, onDatabaseNameResolved }),
homeDatabaseResolution: inputDatabase === DEFAULT_DB_NAME
}
}

_resolveDatabaseNameFromCache ({ impersonatedUser, auth, onDatabaseNameResolved }) {
const database = this._homeDbCache.get({ impersonatedUser, auth })
if (database != null && onDatabaseNameResolved != null) {
onDatabaseNameResolved(database)
}
return database
}

static _forgetRouter (routingTable, routersArray, routerIndex) {
const address = routersArray[routerIndex]
if (routingTable && address) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

const DEFAULT_KEY = -1

export default class HomeDBCache {
constructor ({ maxHomeDatabaseDelay }) {
this._disabled = maxHomeDatabaseDelay === 0
this._maxHomeDatabaseDelay = maxHomeDatabaseDelay || 5000
this._cache = new Map()
}

set ({ impersonatedUser, auth, databaseName }) {
if (this._disabled) {
return null
}
if (databaseName == null) {
return null
}

if (this._maxHomeDatabaseDelay > 0) {
let key = impersonatedUser || auth

if (key == null) {
key = DEFAULT_KEY // This is for when auth is turned off basically
}

this._cache.set(key, { databaseName: databaseName, insertTime: Date.now() })
}
}

get ({ impersonatedUser, auth }) {
if (this._disabled) {
return null
}
let key = impersonatedUser || auth
if (key == null) {
key = DEFAULT_KEY // This is for when auth is turned off basically
}

const dbAndCreatedTime = this._cache.get(key)

if (dbAndCreatedTime == null) {
return null
}

if (Date.now() > dbAndCreatedTime.insertTime + this._maxHomeDatabaseDelay) {
this._cache.delete(key)
return null
} else {
return this._cache.get(key).databaseName
}
}

clearCache () {
this._cache = new Map()
}
}
1 change: 1 addition & 0 deletions packages/testkit-backend/src/feature/common.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const features = [
'Feature:API:Driver.VerifyAuthentication',
'Feature:API:Driver.VerifyConnectivity',
'Feature:API:Session:NotificationsConfig',
'Feature:HomeDbCache',
'Optimization:AuthPipelining',
'Optimization:EagerTransactionBegin',
'Optimization:ImplicitDefaultArguments',
Expand Down
3 changes: 2 additions & 1 deletion packages/testkit-backend/src/request-handlers-rx.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ export {
ExpirationBasedAuthTokenProviderCompleted,
FakeTimeInstall,
FakeTimeTick,
FakeTimeUninstall
FakeTimeUninstall,
ForceHomeDatabaseResolution
} from './request-handlers.js'

export function NewSession ({ neo4j }, context, data, wire) {
Expand Down
14 changes: 14 additions & 0 deletions packages/testkit-backend/src/request-handlers.js
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ export function NewDriver ({ neo4j }, context, data, wire) {
disabledCategories: data.notificationsDisabledCategories
}
}
if ('maxHomeDatabaseDelayMs' in data) {
config.maxHomeDatabaseDelay = data.maxHomeDatabaseDelayMs
}
let driver
try {
driver = neo4j.driver(uri, parsedAuthToken, config)
Expand Down Expand Up @@ -600,6 +603,17 @@ export function GetRoutingTable (_, context, { driverId, database }, wire) {
}
}

export function ForceHomeDatabaseResolution (_, context, { driverId }, wire) {
const driver = context.getDriver(driverId)

if (driver) {
driver.forceHomeDbResolution()
wire.writeResponse(responses.Driver({ id: driverId }))
} else {
wire.writeError('Driver not found!')
}
}

export function ForcedRoutingTableUpdate (_, context, { driverId, database, bookmarks }, wire) {
const driver = context.getDriver(driverId)
const provider = driver._getOrCreateConnectionProvider()
Expand Down