From f9d6f26f1a674366da3d1adfe414ed66480e710f Mon Sep 17 00:00:00 2001 From: Ryan Schumacher Date: Tue, 5 Nov 2024 12:24:32 -0600 Subject: [PATCH] fix(sdk): reset reader after checking if IsNanoTDF (#1718) Closes #1717 --- sdk/sdk.go | 12 ++++++------ sdk/sdk_test.go | 12 ++++++++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/sdk/sdk.go b/sdk/sdk.go index b57591a5a..7704dcc5b 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -291,7 +291,8 @@ func (t TdfType) String() string { return string(t) } -// String method to make the custom type printable +// GetTdfType returns the type of TDF based on the reader. +// Reader is reset after the check. func GetTdfType(reader io.ReadSeeker) TdfType { isValidNanoTdf, _ := IsValidNanoTdf(reader) @@ -348,13 +349,12 @@ func IsValidTdf(reader io.ReadSeeker) (bool, error) { return true, nil } +// IsValidNanoTdf detects whether, or not the reader is a valid Nano TDF. +// Reader is reset after the check. func IsValidNanoTdf(reader io.ReadSeeker) (bool, error) { _, _, err := NewNanoTDFHeaderFromReader(reader) - if err != nil { - return false, err - } - - return true, nil + _, _ = reader.Seek(0, io.SeekStart) // Ignore the error as we're just checking if it's a valid nano TDF + return err == nil, err } func fetchPlatformConfiguration(platformEndpoint string, dialOptions []grpc.DialOption) (PlatformConfiguration, error) { diff --git a/sdk/sdk_test.go b/sdk/sdk_test.go index 7b578dc7c..35d1c7831 100644 --- a/sdk/sdk_test.go +++ b/sdk/sdk_test.go @@ -145,6 +145,12 @@ func TestNew_ShouldValidateGoodNanoTdf(t *testing.T) { require.NoError(t, err) assert.True(t, isValid) + + // Try again to see if the reader has been reset + isValid, err = sdk.IsValidNanoTdf(in) + require.NoError(t, err) + + assert.True(t, isValid) } func TestNew_ShouldNotValidateBadNanoTdf(t *testing.T) { @@ -169,6 +175,12 @@ func TestNew_ShouldValidateStandardTdf(t *testing.T) { require.NoError(t, err) assert.True(t, isValid) + + // Try again to see if the reader has been reset + isValid, err = sdk.IsValidTdf(in) + require.NoError(t, err) + + assert.True(t, isValid) } func TestNew_ShouldNotValidateBadStandardTdf(t *testing.T) {