Skip to content

Commit

Permalink
MutableMapping -> Mapping as modifications are handled by merge().
Browse files Browse the repository at this point in the history
  • Loading branch information
moodyjon committed Nov 11, 2022
1 parent e305881 commit 75f83ef
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 21 deletions.
37 changes: 17 additions & 20 deletions lbry/schema/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os.path
import hashlib
from collections.abc import MutableMapping, Iterable
from collections.abc import Mapping, Iterable
from typing import Tuple, List
from string import ascii_letters
from decimal import Decimal, ROUND_UP
Expand All @@ -24,7 +24,6 @@
Location as LocationMessage,
Language as LanguageMessage,
)
from google.protobuf.struct_pb2 import Struct as StructMessage
from lbry.schema.types.v2.extension_pb2 import Extension as ExtensionMessage

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -714,7 +713,7 @@ def merge(self, ext: 'StreamExtension', delete: bool = False) -> 'StreamExtensio
self.unpacked.merge(ext.unpacked, delete=delete)
return self

class Struct(Metadata, MutableMapping, Iterable):
class Struct(Metadata, Mapping, Iterable):
__slots__ = ()

def to_dict(self) -> dict:
Expand Down Expand Up @@ -756,23 +755,28 @@ def merge(self, other: 'Struct', delete: bool = False) -> 'Struct':
return self

def __getitem__(self, key):
def extract(val):
if not isinstance(val, ProtobufMessage):
return val
kind = val.WhichOneof('kind')
if kind == 'struct_value':
return dict(Struct(val.struct_value))
elif kind == 'list_value':
return list(map(extract, val.list_value.values))
else:
return getattr(val, kind)
if key in self.message.fields:
return self.message.fields[key]
val = self.message.fields[key]
return extract(val)
raise KeyError(key)

def __setitem__(self, key, value):
self.message.fields[key].CopyFrom(value.message)

def __delitem__(self, key):
del self.message.fields[key]

def __iter__(self):
return iter(self.message.fields)

def __len__(self):
return len(self.message.fields)

class StreamExtensionMap(Metadata, MutableMapping, Iterable):
class StreamExtensionMap(Metadata, Mapping, Iterable):
__slots__ = ()
item_class = StreamExtension

Expand All @@ -791,7 +795,7 @@ def merge(self, exts, delete: bool = False) -> 'StreamExtensionMap':
else:
obj.from_value({schema: ext})
if delete and not len(obj.unpacked):
del self[schema]
del self.message[schema]
continue
existing = StreamExtension(schema, self.message[schema])
existing.merge(obj, delete=delete)
Expand All @@ -802,15 +806,8 @@ def __getitem__(self, key):
return StreamExtension(key, self.message[key])
raise KeyError(key)

def __setitem__(self, key, value):
del self.message[key]
self.message[key].CopyFrom(value.message)

def __delitem__(self, key):
del self.message[key]

def __iter__(self):
return self.message.__iter__()
return iter(self.message)

def __len__(self):
return len(self.message)
Expand Down
35 changes: 34 additions & 1 deletion tests/unit/schema/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json

from lbry.schema.claim import Claim, Stream, Collection
from lbry.schema.attrs import StreamExtension
from lbry.schema.attrs import StreamExtension, Struct
from google.protobuf.struct_pb2 import Struct as StructMessage
from lbry.schema.types.v2.extension_pb2 import Extension as ExtensionMessage
from lbry.error import InputValueIsNoneError
Expand Down Expand Up @@ -253,14 +253,17 @@ def setUp(self):

def test_extension_properties(self):
self.maxDiff = None

# Verify schema.
self.assertEqual(self.ext1.schema, 'cad')
self.assertEqual(self.ext2.schema, 'music')
self.assertEqual(self.ext3.schema, 'lit')

# Verify to_dict().
self.assertEqual(self.ext1.to_dict(), self.ext1_dict)
self.assertEqual(self.ext2.to_dict(), self.ext2_dict)
self.assertEqual(self.ext3.to_dict(), self.ext3_dict)

# Decode from dict.
parsed1 = StreamExtension(None, ExtensionMessage())
parsed1.from_value(self.ext1_dict)
Expand All @@ -271,6 +274,7 @@ def test_extension_properties(self):
parsed3 = StreamExtension(None, ExtensionMessage())
parsed3.from_value(self.ext3_dict)
self.assertEqual(parsed3.to_dict(), self.ext3_dict)

# Decode from str (JSON).
parsed1 = StreamExtension(None, ExtensionMessage())
parsed1.from_value(self.ext1_json)
Expand All @@ -282,6 +286,35 @@ def test_extension_properties(self):
parsed3.from_value(self.ext3_json)
self.assertEqual(parsed3.to_dict(), self.ext3_dict)

# Verify Mapping functionality.
self.assertEqual(self.ext1.unpacked['material'], ['PLA1', 'PLA2'])
self.assertEqual(self.ext1.unpacked['cubic_cm'], 5)
self.assertEqual(self.ext2.unpacked['venue'], 'studio')
self.assertEqual(self.ext2.unpacked['genre'], ['metal'])
self.assertEqual(self.ext2.unpacked['instrument'], ['drum', 'cymbal', 'guitar'])
self.assertEqual(self.ext3.unpacked['pages'], 185)
self.assertEqual(self.ext3.unpacked['genre'], ['fiction', 'mystery'])
self.assertEqual(self.ext3.unpacked['format'], 'epub')

# Verify Iterable functionality.
self.assertEqual(len(self.ext1.unpacked), 2)
for k, v in self.ext1.unpacked.items():
self.assertIn(k, self.ext1.unpacked)
self.assertTrue(isinstance(v, (str, list, float)), type(v))
self.assertEqual(v, self.ext1.unpacked[k])
self.assertEqual(len(self.ext2.unpacked), 3)
for k, v in self.ext2.unpacked.items():
self.assertIn(k, self.ext2.unpacked)
self.assertTrue(isinstance(v, (str, list, float)), type(v))
self.assertEqual(v, self.ext2.unpacked[k])
self.assertEqual(len(self.ext3.unpacked), 3)
for k, v in self.ext3.unpacked.items():
self.assertIn(k, self.ext3.unpacked)
self.assertTrue(isinstance(v, (str, list, float)), type(v))
self.assertEqual(v, self.ext3.unpacked[k])



def test_extension_clear_field(self):
self.maxDiff = None
ext = StreamExtension(None, ExtensionMessage())
Expand Down

0 comments on commit 75f83ef

Please sign in to comment.