Skip to content

Commit

Permalink
Version 0.3.0 (#5)
Browse files Browse the repository at this point in the history
* SubjectAtlName support
  • Loading branch information
yannbouteiller authored Jan 6, 2024
1 parent 743b8f4 commit 3af67ab
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 43 deletions.
6 changes: 2 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from setuptools import setup, find_packages
import sys

from pathlib import Path


if sys.version_info < (3, 7):
sys.exit('Sorry, Python < 3.7 is not supported, upgrade your python installation to use tlspyo.')
Expand All @@ -14,8 +12,8 @@

setup(name='tlspyo',
packages=[package for package in find_packages()],
version='0.2.5',
download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.2.5.tar.gz',
version='0.3.0',
download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.3.0.tar.gz',
license='MIT',
description='Secure transport of python objects using TLS encryption',
long_description=long_description,
Expand Down
60 changes: 36 additions & 24 deletions tlspyo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(self,

assert accepted_groups is None or isinstance(accepted_groups, dict), "Invalid format for accepted_groups."

self._stopped = False
self._header_size = header_size
self._local_com_port = local_com_port
self._local_com_srv = socket(AF_INET, SOCK_STREAM)
Expand All @@ -93,6 +92,9 @@ def __init__(self,
self._local_com_conn, self._local_com_addr = self._local_com_srv.accept()
self._send_local('TEST')

self._stop_lock = Lock()
self._stopped = False

def __del__(self):
self.stop()

Expand All @@ -105,14 +107,19 @@ def stop(self):
"""
Stop the Relay.
"""
if not self._stopped:
self._stopped = True
self._send_local('STOP')
try:
with self._stop_lock:
if not self._stopped:
self._send_local('STOP')

self._p.join()
self._local_com_conn.close()
self._local_com_srv.close()
self._local_com_addr = None
self._p.join()
self._local_com_conn.close()
self._local_com_srv.close()
self._local_com_addr = None
self._stopped = True
except KeyboardInterrupt as e:
self.stop()
raise e


class Endpoint:
Expand Down Expand Up @@ -172,8 +179,6 @@ def __init__(self,
elif security == "SSL":
security = "TLS"

self._stopped = False

# threading for local object receiving
self.__obj_buffer = queue.Queue()
self.__socket_closed_lock = Lock()
Expand Down Expand Up @@ -221,6 +226,9 @@ def __init__(self,
self._t_manage_received_objects = Thread(target=self._manage_received_objects, daemon=True)
self._t_manage_received_objects.start()

self._stop_lock = Lock()
self._stopped = False

def __del__(self):
self.stop()

Expand All @@ -236,7 +244,6 @@ def _manage_received_objects(self):
# Check if socket is still open
with self.__socket_closed_lock:
if self.__socket_closed_flag:
self._local_com_conn.close()
return

buf += self._local_com_conn.recv(self._max_buf_len)
Expand Down Expand Up @@ -357,22 +364,27 @@ def stop(self):
"""
Stop the Endpoint.
"""
if not self._stopped:
self._stopped = True
# send STOP to the local server
self._send_local(cmd='STOP', dest=None, obj=None)
try:
with self._stop_lock:
if not self._stopped:
# send STOP to the local server
self._send_local(cmd='STOP', dest=None, obj=None)

# Join the message reading thread
with self.__socket_closed_lock:
self.__socket_closed_flag = True
self._t_manage_received_objects.join()
# Join the message reading thread
with self.__socket_closed_lock:
self.__socket_closed_flag = True
self._t_manage_received_objects.join()

# join Twisted process and stop local server
self._p.join()
# join Twisted process and stop local server
self._p.join()

self._local_com_conn.close()
self._local_com_srv.close()
self._local_com_addr = None
self._local_com_conn.close()
self._local_com_srv.close()
self._local_com_addr = None
self._stopped = True
except KeyboardInterrupt as e:
self.stop()
raise e

def _process_received_list(self, received_list):
if self._deserialize_locally:
Expand Down
2 changes: 1 addition & 1 deletion tlspyo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def dataReceived(self, data):
stamp, cmd, obj = self._client.deserializer(self._buffer[i:j])
if cmd == 'ACK':
try:
logger.info(f"ACK received after {time.monotonic() - self._client.pending_acks[stamp][0]}s.")
logger.debug(f"ACK received after {time.monotonic() - self._client.pending_acks[stamp][0]}s.")
del self._client.pending_acks[stamp] # delete pending ACK
except KeyError:
logger.warning(f"Received ACK for stamp {stamp} not present in pending ACKs.")
Expand Down
40 changes: 31 additions & 9 deletions tlspyo/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def generate_tls_credentials(
folder_path,
email_address="emailAddress",
common_name="default",
subject_alt_name=('DNS:default',),
country_name="CA",
locality_name="localityName",
state_or_province_name="stateOrProvinceName",
Expand All @@ -42,6 +43,7 @@ def generate_tls_credentials(
folder_path (path-like object): path were the files will be created
email_address (str): your email address
common_name (str): your hostname
subject_alt_name (tuple of str): your subject alt name list
country_name (str): your country code
locality_name (str): your locality name
state_or_province_name (str): your state name
Expand All @@ -57,18 +59,26 @@ def generate_tls_credentials(
k = crypto.PKey()
k.generate_key(crypto.TYPE_RSA, 4096)
cert = crypto.X509()
cert.get_subject().C = country_name
cert.get_subject().ST = state_or_province_name
cert.get_subject().L = locality_name
cert.get_subject().O = organization_name
cert.get_subject().OU = organization_unit_name
cert.get_subject().CN = common_name
cert.get_subject().emailAddress = email_address
cert.set_serial_number(serial_number)

subject = cert.get_subject()
subject.commonName = common_name
subject.emailAddress = email_address
subject.organizationName = organization_name
subject.organizationalUnitName = organization_unit_name
subject.localityName = locality_name
subject.stateOrProvinceName = state_or_province_name
subject.countryName = country_name

cert.set_issuer(subject)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(validity_end_in_seconds)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(k)
cert.set_serial_number(serial_number)
cert.set_version(2) # for SAN
cert.add_extensions([
crypto.X509Extension(b'subjectAltName', False, ','.join(subject_alt_name).encode())
])

cert.sign(k, 'sha512')
with open(cert_file, "wt") as f:
f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
Expand All @@ -87,6 +97,7 @@ def credentials_generator_tool(custom=False):
folder_path = get_default_keys_folder()
email_address = "emailAddress"
common_name = "default"
subject_alt_name = ["DNS:" + common_name]
country_name = "CA"
locality_name = "localityName"
state_or_province_name = "stateOrProvinceName"
Expand Down Expand Up @@ -118,6 +129,16 @@ def credentials_generator_tool(custom=False):
common_name = inp
print(common_name)

subject_alt_name = ["DNS:" + common_name]
print(f"\nSubject alternative name (hostnames, leave empty to stop adding) {subject_alt_name}:")
inp = input()
if inp != "":
subject_alt_name = []
while inp != "":
subject_alt_name.append(inp)
inp = input()
print(subject_alt_name)

print(f"\nCountry code [{country_name}]:")
inp = input()
if inp != "":
Expand Down Expand Up @@ -163,6 +184,7 @@ def credentials_generator_tool(custom=False):
generate_tls_credentials(folder_path=folder_path,
email_address=email_address,
common_name=common_name,
subject_alt_name=tuple(subject_alt_name),
country_name=country_name,
locality_name=locality_name,
state_or_province_name=state_or_province_name,
Expand Down
10 changes: 5 additions & 5 deletions tlspyo/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import signal
# import signal
import queue


try:
signal.signal(signal.SIGINT, signal.SIG_DFL)
except Exception as e:
pass
# try:
# signal.signal(signal.SIGINT, signal.SIG_DFL)
# except Exception as e:
# pass


def wait_event(event):
Expand Down

0 comments on commit 3af67ab

Please sign in to comment.