Skip to content

Commit

Permalink
Generate RestJson event stream implementation
Browse files Browse the repository at this point in the history
This updates generic event stream generation with recently introduced
changes and also introduces the concrete implementation for RestJson.

Testing for all of this will be done via protocol tests, and in the
early days manual testing.

Since a lot of this is effectively throwaway code, I was more liberal
with type ignoring and using Any types than I otherwise would be. The
request pipeline is going to be moving to pure python soon^tm, and
the typing issues will be resolved at that time.
  • Loading branch information
JordonPhillips committed Oct 24, 2024
1 parent 7cec4ee commit a2c2e48
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.traits.DocumentationTrait;
import software.amazon.smithy.model.traits.StringTrait;
import software.amazon.smithy.python.codegen.integration.PythonIntegration;
Expand Down Expand Up @@ -123,6 +124,16 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None):
}

private void generateOperationExecutor(PythonWriter writer) {
writer.pushState();

var hasStreaming = hasEventStream();
writer.putContext("hasEventStream", hasStreaming);
if (hasStreaming) {
writer.addImports("smithy_core.deserializers", Set.of(
"ShapeDeserializer", "DeserializeableShape"));
writer.addStdlibImport("typing", "Any");
}

var transportRequest = context.applicationProtocol().requestType();
var transportResponse = context.applicationProtocol().responseType();
var errorSymbol = CodegenUtils.getServiceError(context.settings());
Expand Down Expand Up @@ -191,10 +202,18 @@ async def _execute_operation(
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
operation_name: str,
${?hasEventStream}
has_input_stream: bool = False,
event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
event_response_deserializer: DeserializeableShape | None = None,
${/hasEventStream}
) -> Output:
try:
return await self._handle_execution(
input, plugins, serialize, deserialize, config, operation_name
input, plugins, serialize, deserialize, config, operation_name,
${?hasEventStream}
has_input_stream, event_deserializer, event_response_deserializer,
${/hasEventStream}
)
except Exception as e:
# Make sure every exception that we throw is an instance of $4T so
Expand All @@ -211,6 +230,11 @@ async def _handle_execution(
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
operation_name: str,
${?hasEventStream}
has_input_stream: bool = False,
event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
event_response_deserializer: DeserializeableShape | None = None,
${/hasEventStream}
) -> Output:
logger.debug(f"Making request for operation {operation_name} with parameters: {input}")
context: InterceptorContext[Input, None, None, None] = InterceptorContext(
Expand Down Expand Up @@ -326,7 +350,16 @@ await sleep(retry_token.retry_delay)
execution_context = cast(
InterceptorContext[Input, Output, $2T | None, $3T | None], context
)
${^hasEventStream}
return await self._finalize_execution(interceptors, execution_context)
${/hasEventStream}
${?hasEventStream}
operation_output = await self._finalize_execution(interceptors, execution_context)
if has_input_stream or event_deserializer is not None:
${6C|}
else:
return operation_output
${/hasEventStream}
async def _handle_attempt(
self,
Expand All @@ -342,7 +375,8 @@ async def _handle_attempt(
for interceptor in interceptors:
interceptor.read_before_attempt(context)
""", pluginSymbol, transportRequest, transportResponse, errorSymbol, configSymbol);
""", pluginSymbol, transportRequest, transportResponse, errorSymbol, configSymbol,
writer.consumer(w -> context.protocolGenerator().wrapEventStream(context, w)));

boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty();
writer.pushState(new ResolveIdentitySection());
Expand Down Expand Up @@ -604,6 +638,18 @@ async def _finalize_execution(
return context.response
""", transportRequest, transportResponse);
writer.dedent();
writer.popState();
}

private boolean hasEventStream() {
var streamIndex = EventStreamIndex.of(context.model());
var topDownIndex = TopDownIndex.of(context.model());
for (OperationShape operation : topDownIndex.getContainedOperations(context.settings().service())) {
if (streamIndex.getInputInfo(operation).isPresent() || streamIndex.getOutputInfo(operation).isPresent()) {
return true;
}
}
return false;
}

private void initializeHttpAuthParameters(PythonWriter writer) {
Expand Down Expand Up @@ -649,40 +695,7 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {

writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:", "",
operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol, () -> {
writer.writeDocs(() -> {
var docs = operation.getTrait(DocumentationTrait.class)
.map(StringTrait::getValue)
.orElse(String.format("Invokes the %s operation.", operation.getId().getName()));

var inputDocs = input.getTrait(DocumentationTrait.class)
.map(StringTrait::getValue)
.orElse("The operation's input.");

writer.write("""
$L
:param input: $L
:param plugins: A list of callables that modify the configuration dynamically.
Changes made by these plugins only apply for the duration of the operation
execution and will not affect any other operation invocations.""", docs, inputDocs);
});

var defaultPlugins = new LinkedHashSet<SymbolReference>();
for (PythonIntegration integration : context.integrations()) {
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) {
if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) {
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
}
}
}
writer.write("""
operation_plugins: list[Plugin] = [
$C
]
if plugins:
operation_plugins.extend(plugins)
""", writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));
writeSharedOperationInit(writer, operation, input);

if (context.protocolGenerator() == null) {
writer.write("raise NotImplementedError()");
Expand All @@ -704,16 +717,55 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {
});
}

private void writeSharedOperationInit(PythonWriter writer, OperationShape operation, Shape input) {
writer.writeDocs(() -> {
var docs = operation.getTrait(DocumentationTrait.class)
.map(StringTrait::getValue)
.orElse(String.format("Invokes the %s operation.", operation.getId().getName()));

var inputDocs = input.getTrait(DocumentationTrait.class)
.map(StringTrait::getValue)
.orElse("The operation's input.");

writer.write("""
$L
:param input: $L
:param plugins: A list of callables that modify the configuration dynamically.
Changes made by these plugins only apply for the duration of the operation
execution and will not affect any other operation invocations.""", docs, inputDocs);
});

var defaultPlugins = new LinkedHashSet<SymbolReference>();
for (PythonIntegration integration : context.integrations()) {
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) {
if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) {
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
}
}
}
writer.write("""
operation_plugins: list[Plugin] = [
$C
]
if plugins:
operation_plugins.extend(plugins)
""", writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));

}

private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) {
writer.pushState();
writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM);
writer.addImports("smithy_event_stream.aio.interfaces", Set.of(
"EventStream", "InputEventStream", "OutputEventStream"));
var operationSymbol = context.symbolProvider().toSymbol(operation);
writer.putContext("operationName", operationSymbol.getName());
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
writer.putContext("plugin", pluginSymbol);

var input = context.model().expectShape(operation.getInputShape());
var inputSymbol = context.symbolProvider().toSymbol(input);
writer.putContext("input", inputSymbol);

var eventStreamIndex = EventStreamIndex.of(context.model());
var inputStreamSymbol = eventStreamIndex.getInputInfo(operation)
Expand All @@ -724,22 +776,107 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op

var output = context.model().expectShape(operation.getOutputShape());
var outputSymbol = context.symbolProvider().toSymbol(output);
writer.putContext("output", outputSymbol);

var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation)
.map(EventStreamInfo::getEventStreamTarget)
.map(target -> context.symbolProvider().toSymbol(target))
.orElse(null);
writer.putContext("outputStream", outputStreamSymbol);

writer.write("""
async def $L(self, input: $T, plugins: list[$T] | None = None) -> EventStream[
${?inputStream}InputEventStream[${inputStream:T}]${/inputStream}\
${^inputStream}None${/inputStream},
${?outputStream}OutputEventStream[${outputStream:T}]${/outputStream}\
${^outputStream}None${/outputStream},
$T
]:
raise NotImplementedError()
""", operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol);
writer.putContext("hasProtocol", context.protocolGenerator() != null);
if (context.protocolGenerator() != null) {
var serSymbol = context.protocolGenerator().getSerializationFunction(context, operation);
writer.putContext("serSymbol", serSymbol);
var deserSymbol = context.protocolGenerator().getDeserializationFunction(context, operation);
writer.putContext("deserSymbol", deserSymbol);
} else {
writer.putContext("serSymbol", null);
writer.putContext("deserSymbol", null);
}

if (inputStreamSymbol != null) {
if (outputStreamSymbol != null) {
writer.addImport("smithy_event_stream.aio.interfaces", "DuplexEventStream");
writer.write("""
async def ${operationName:L}(
self,
input: ${input:T},
plugins: list[${plugin:T}] | None = None
) -> DuplexEventStream[${inputStream:T}, ${outputStream:T}, ${output:T}]:
${C|}
${^hasProtocol}
raise NotImplementedError()
${/hasProtocol}
${?hasProtocol}
return await self._execute_operation(
input=input,
plugins=operation_plugins,
serialize=${serSymbol:T},
deserialize=${deserSymbol:T},
config=self._config,
operation_name=${operationName:S},
has_input_stream=True,
event_deserializer=$T().deserialize,
event_response_deserializer=${output:T},
) # type: ignore
${/hasProtocol}
""",
writer.consumer(w -> writeSharedOperationInit(w, operation, input)),
outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER));
} else {
writer.addImport("smithy_event_stream.aio.interfaces", "InputEventStream");
writer.write("""
async def ${operationName:L}(
self,
input: ${input:T},
plugins: list[${plugin:T}] | None = None
) -> InputEventStream[${inputStream:T}, ${output:T}]:
${C|}
${^hasProtocol}
raise NotImplementedError()
${/hasProtocol}
${?hasProtocol}
return await self._execute_operation(
input=input,
plugins=operation_plugins,
serialize=${serSymbol:T},
deserialize=${deserSymbol:T},
config=self._config,
operation_name=${operationName:S},
has_input_stream=True,
) # type: ignore
${/hasProtocol}
""", writer.consumer(w -> writeSharedOperationInit(w, operation, input)));
}
} else {
writer.addImport("smithy_event_stream.aio.interfaces", "OutputEventStream");
writer.write("""
async def ${operationName:L}(
self,
input: ${input:T},
plugins: list[${plugin:T}] | None = None
) -> OutputEventStream[${outputStream:T}, ${output:T}]:
${C|}
${^hasProtocol}
raise NotImplementedError()
${/hasProtocol}
${?hasProtocol}
return await self._execute_operation(
input=input,
plugins=operation_plugins,
serialize=${serSymbol:T},
deserialize=${deserSymbol:T},
config=self._config,
operation_name=${operationName:S},
event_deserializer=$T().deserialize,
event_response_deserializer=${output:T},
) # type: ignore
${/hasProtocol}
""",
writer.consumer(w -> writeSharedOperationInit(w, operation, input)),
outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER));
}

writer.popState();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,26 @@ public final class SmithyPythonDependency {
false
);

/**
* Core interfaces for event streams.
*/
public static final PythonDependency SMITHY_EVENT_STREAM = new PythonDependency(
"smithy_event_stream",
"==0.0.1",
Type.DEPENDENCY,
false
);

/**
* EventStream implementations for application/vnd.amazon.eventstream.
*/
public static final PythonDependency AWS_EVENT_STREAM = new PythonDependency(
"aws_event_stream",
"==0.0.1",
Type.DEPENDENCY,
false
);

/**
* testing framework used in generated functional tests.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import software.amazon.smithy.model.shapes.ToShapeId;
import software.amazon.smithy.python.codegen.ApplicationProtocol;
import software.amazon.smithy.python.codegen.GenerationContext;
import software.amazon.smithy.python.codegen.PythonWriter;
import software.amazon.smithy.utils.CaseUtils;
import software.amazon.smithy.utils.SmithyUnstableApi;

Expand Down Expand Up @@ -167,4 +168,25 @@ default void generateSharedDeserializerComponents(GenerationContext context) {
*/
default void generateProtocolTests(GenerationContext context) {
}

/**
* Generates the code to wrap an operation output into an event stream.
*
* <p>Important context variables are:
* <ul>
* <li>execution_context - Has the context, including the transport input and output.</li>
* <li>operation_output - The deserialized operation output.</li>
* <li>has_input_stream - Whether or not there is an input stream.</li>
* <li>event_deserializer - The deserialize method for output events, or None for no output stream.</li>
* <li>event_response_deserializer - A DeserializeableShape representing the operation's output shape,
* or None for no output stream. This is used when the operation sends the initial response over the
* event stream.
* </li>
* </ul>
*
* @param context Generation context.
* @param writer The writer to write to.
*/
default void wrapEventStream(GenerationContext context, PythonWriter writer) {
}
}
Loading

0 comments on commit a2c2e48

Please sign in to comment.