Skip to content

Commit

Permalink
Merge pull request #1234 from France-ioi/ask_hint_fixes
Browse files Browse the repository at this point in the history
Handle DB errors during token unmarshalling correctly in itemGetHintToken & saveGrade
  • Loading branch information
zenovich authored Jan 7, 2025
2 parents 05343cc + 7af568a commit cc4ea9a
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 4 deletions.
2 changes: 1 addition & 1 deletion app/api/items/ask_hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func (requestData *AskHintRequest) unmarshalHintToken(wrapper *askHintRequestWra
wrapper.HintRequestedToken.Bytes(),
"hint_requested",
)
if err != nil {
if err != nil && !token.IsUnexpectedError(err) {
return err
}
service.MustNotBeError(err)
Expand Down
31 changes: 29 additions & 2 deletions app/api/items/ask_hint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/France-ioi/AlgoreaBackend/v2/app/payloadstest"
"github.com/France-ioi/AlgoreaBackend/v2/app/token"
"github.com/France-ioi/AlgoreaBackend/v2/app/tokentest"
"github.com/France-ioi/AlgoreaBackend/v2/testhelpers/testoutput"
)

func TestAskHintRequest_UnmarshalJSON(t *testing.T) {
Expand Down Expand Up @@ -127,12 +128,14 @@ func TestAskHintRequest_UnmarshalJSON(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
testoutput.SuppressIfPasses(t)

db, mock := database.NewDBMock()
defer func() { _ = db.Close() }()

if tt.mockDB {
mockQuery := mock.ExpectQuery(regexp.QuoteMeta("SELECT public_key " +
"FROM `platforms` JOIN items ON items.platform_id = platforms.id WHERE (items.id = ?) LIMIT 1")).
mockQuery := mock.ExpectQuery("^" + regexp.QuoteMeta("SELECT public_key "+
"FROM `platforms` JOIN items ON items.platform_id = platforms.id WHERE (items.id = ?) LIMIT 1") + "$").
WithArgs(tt.itemID)

if tt.platform != nil {
Expand Down Expand Up @@ -176,3 +179,27 @@ func TestAskHintRequest_UnmarshalJSON(t *testing.T) {
})
}
}

func TestAskHintRequest_UnmarshalJSON_DBError(t *testing.T) {
testoutput.SuppressIfPasses(t)

db, mock := database.NewDBMock()
defer func() { _ = db.Close() }()

expectedError := errors.New("error")
mock.ExpectQuery("^" + regexp.QuoteMeta("SELECT public_key "+
"FROM `platforms` JOIN items ON items.platform_id = platforms.id WHERE (items.id = ?) LIMIT 1") + "$").
WithArgs(901756573345831409).WillReturnError(expectedError)

r := &AskHintRequest{
store: database.NewDataStore(db),
publicKey: tokentest.AlgoreaPlatformPublicKeyParsed,
}
assert.PanicsWithError(t, expectedError.Error(), func() {
_ = r.UnmarshalJSON([]byte(fmt.Sprintf(`{"task_token": %q, "hint_requested": %q}`,
token.Generate(payloadstest.TaskPayloadFromAlgoreaPlatform, tokentest.AlgoreaPlatformPrivateKeyParsed),
token.Generate(payloadstest.HintPayloadFromTaskPlatform, tokentest.TaskPlatformPrivateKeyParsed),
)))
})
assert.NoError(t, mock.ExpectationsWereMet())
}
3 changes: 2 additions & 1 deletion app/api/items/save_grade.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,10 @@ func (requestData *saveGradeRequestParsed) unmarshalScoreToken(wrapper *saveGrad
wrapper.ScoreToken.Bytes(),
"score_token",
)
if err != nil {
if err != nil && !token.IsUnexpectedError(err) {
return err
}
service.MustNotBeError(err)
}

if !hasScoreToken || !hasPlatformKey {
Expand Down
40 changes: 40 additions & 0 deletions app/api/items/save_grade_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package items

import (
"errors"
"fmt"
"regexp"
"testing"

"github.com/stretchr/testify/assert"

"github.com/France-ioi/AlgoreaBackend/v2/app/database"
"github.com/France-ioi/AlgoreaBackend/v2/app/payloadstest"
"github.com/France-ioi/AlgoreaBackend/v2/app/token"
"github.com/France-ioi/AlgoreaBackend/v2/app/tokentest"
"github.com/France-ioi/AlgoreaBackend/v2/testhelpers/testoutput"
)

func Test_saveGradeRequestParsed_UnmarshalJSON(t *testing.T) {
testoutput.SuppressIfPasses(t)

db, mock := database.NewDBMock()
defer func() { _ = db.Close() }()

expectedError := errors.New("error")
mock.ExpectQuery("^" + regexp.QuoteMeta("SELECT public_key "+
"FROM `platforms` JOIN items ON items.platform_id = platforms.id WHERE (items.id = ?) LIMIT 1") + "$").
WithArgs(901756573345831409).WillReturnError(expectedError)

r := saveGradeRequestParsed{
store: database.NewDataStore(db),
publicKey: tokentest.AlgoreaPlatformPublicKeyParsed,
}
assert.PanicsWithError(t, expectedError.Error(), func() {
_ = r.UnmarshalJSON([]byte(fmt.Sprintf(`{"score_token": %q, "answer_token": %q}`,
token.Generate(payloadstest.ScorePayloadFromGrader, tokentest.AlgoreaPlatformPrivateKeyParsed),
token.Generate(payloadstest.AnswerPayloadFromAlgoreaPlatform, tokentest.AlgoreaPlatformPrivateKeyParsed),
)))
})
assert.NoError(t, mock.ExpectationsWereMet())
}
33 changes: 33 additions & 0 deletions app/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,24 @@ func Generate(payload map[string]interface{}, privateKey *rsa.PrivateKey) []byte
return token
}

// UnexpectedError represents an unexpected error so that we could differentiate it from expected errors.
type UnexpectedError struct {
err error
}

// Error returns a string representation for an unexpected error.
func (ue *UnexpectedError) Error() string {
return ue.err.Error()
}

// IsUnexpectedError returns true if its argument is an unexpected error.
func IsUnexpectedError(err error) bool {
if _, unexpected := err.(*UnexpectedError); unexpected {
return true
}
return false
}

// UnmarshalDependingOnItemPlatform unmarshals a token from JSON representation
// using a platform's public key for given itemID.
// The function returns nil (success) if the platform doesn't use tokens.
Expand All @@ -144,11 +162,14 @@ func UnmarshalDependingOnItemPlatform(
tokenFieldName string,
) (platformHasKey bool, err error) {
targetRefl := reflect.ValueOf(target)
defer recoverPanics(&err)

publicKey, err := store.Platforms().GetPublicKeyByItemID(itemID)
if gorm.IsRecordNotFoundError(err) {
return false, fmt.Errorf("cannot find the platform for item %d", itemID)
}
mustNotBeError(err)

if publicKey == "" {
return false, nil
}
Expand All @@ -175,3 +196,15 @@ func UnmarshalDependingOnItemPlatform(

return true, nil
}

func mustNotBeError(err error) {
if err != nil {
panic(err)
}
}

func recoverPanics(err *error) { // nolint:gocritic
if r := recover(); r != nil {
*err = &UnexpectedError{err: r.(error)}
}
}
17 changes: 17 additions & 0 deletions app/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,20 @@ func createTmpPrivateKeyFile(key []byte) (*os.File, error) {
_, err = tmpFilePrivate.Write(key)
return tmpFilePrivate, err
}

func Test_recoverPanics_and_mustNotBeError(t *testing.T) {
expectedError := errors.New("some error")
err := func() (err error) {
defer recoverPanics(&err)
mustNotBeError(expectedError)
return nil
}()
assert.Equal(t, &UnexpectedError{expectedError}, err)
assert.Equal(t, expectedError.Error(), err.Error())
}

func Test_UnexpectedError(t *testing.T) {
assert.True(t, IsUnexpectedError(&UnexpectedError{err: errors.New("some error")}))
assert.False(t, IsUnexpectedError(errors.New("some error")))
assert.False(t, IsUnexpectedError(nil))
}

0 comments on commit cc4ea9a

Please sign in to comment.