Skip to content

Commit

Permalink
Add ability to remediate other XSS code shapes (#481)
Browse files Browse the repository at this point in the history
Took logic specific to Semgrep and generalized.
  • Loading branch information
nahsra authored Dec 6, 2024
1 parent 42914df commit 4eecd14
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import com.contrastsecurity.sarif.Result;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import io.codemodder.Codemod;
import io.codemodder.CodemodExecutionPriority;
import io.codemodder.CodemodFileScanningResult;
Expand All @@ -14,11 +12,9 @@
import io.codemodder.SarifFindingKeyUtil;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.providers.sarif.semgrep.ProvidedSemgrepScan;
import io.codemodder.remediation.FixCandidateSearcher;
import io.codemodder.remediation.GenericRemediationMetadata;
import io.codemodder.remediation.Remediator;
import io.codemodder.remediation.SearcherStrategyRemediator;
import io.codemodder.remediation.javadeserialization.JavaDeserializationFixStrategy;
import io.codemodder.remediation.javadeserialization.JavaDeserializationRemediator;
import java.util.Optional;
import javax.inject.Inject;

Expand All @@ -41,32 +37,7 @@ public SemgrepJavaDeserializationCodemod(
ruleId = "java.lang.security.audit.object-deserialization.object-deserialization")
final RuleSarif sarif) {
super(GenericRemediationMetadata.DESERIALIZATION.reporter(), sarif);
this.remediator =
new SearcherStrategyRemediator.Builder<Result>()
.withSearcherStrategyPair(
// matches declarations
new FixCandidateSearcher.Builder<Result>()
.withMatcher(
n ->
Optional.empty()
.or(
() ->
Optional.of(n)
.map(
m ->
m instanceof VariableDeclarationExpr vde
? vde
: null)
.filter(JavaDeserializationFixStrategy::match))
.or(
() ->
Optional.of(n)
.map(m -> m instanceof MethodCallExpr mce ? mce : null)
.filter(JavaDeserializationFixStrategy::match))
.isPresent())
.build(),
new JavaDeserializationFixStrategy())
.build();
this.remediator = new JavaDeserializationRemediator<>();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.codemodder.remediation.javadeserialization;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import io.codemodder.CodemodFileScanningResult;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.remediation.*;
Expand All @@ -21,8 +23,26 @@ public JavaDeserializationRemediator() {
this.searchStrategyRemediator =
new SearcherStrategyRemediator.Builder<T>()
.withSearcherStrategyPair(
// matches declarations
new FixCandidateSearcher.Builder<T>()
.withMatcher(JavaDeserializationFixStrategy::match)
.withMatcher(
n ->
Optional.empty()
.or(
() ->
Optional.of(n)
.map(
m ->
m instanceof VariableDeclarationExpr vde
? vde
: null)
.filter(JavaDeserializationFixStrategy::match))
.or(
() ->
Optional.of(n)
.map(m -> m instanceof MethodCallExpr mce ? mce : null)
.filter(JavaDeserializationFixStrategy::match))
.isPresent())
.build(),
new JavaDeserializationFixStrategy())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import java.util.Optional;

/**
* Fix strategy for XSS vulnerabilities where a variable/expr is sent to a Spring ResponseEntity.
* Fix strategy for XSS vulnerabilities where a variable/expr is sent to a Spring ResponseEntity
* constructor.
*/
final class ResponseEntityFixStrategy implements RemediationStrategy {
final class ResponseEntityConstructorFixStrategy implements RemediationStrategy {

@Override
public SuccessOrReason fix(final CompilationUnit cu, final Node node) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.codemodder.remediation.xss;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.resolution.types.ResolvedType;
import io.codemodder.remediation.RemediationStrategy;
import io.codemodder.remediation.SuccessOrReason;
import java.util.Optional;

/**
* Fix strategy for XSS vulnerabilities where a variable/expr is sent to a Spring ResponseEntity
* write method like ok().
*/
final class ResponseEntityWriteFixStrategy implements RemediationStrategy {

@Override
public SuccessOrReason fix(final CompilationUnit cu, final Node node) {
var maybeCall =
Optional.of(node).map(n -> n instanceof MethodCallExpr ? (MethodCallExpr) n : null);
if (maybeCall.isEmpty()) {
return SuccessOrReason.reason("Not a method call.");
}

MethodCallExpr call = maybeCall.get();
return EncoderWrapping.fix(call, 0);
}

static boolean match(final Node node) {
return Optional.of(node)
.map(n -> n instanceof MethodCallExpr ? (MethodCallExpr) n : null)
.filter(m -> "ok".equals(m.getNameAsString()))
.filter(m -> !m.getArguments().isEmpty())
.filter(
c -> {
Expression firstArg = c.getArguments().getFirst().get();
try {
ResolvedType resolvedType = firstArg.calculateResolvedType();
return "java.lang.String".equals(resolvedType.describe());
} catch (Exception e) {
// this is expected often, and indicates its a non-String type anyway
return false;
}
})
.isPresent();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ public XSSRemediator() {
new PrintingMethodFixStrategy())
.withSearcherStrategyPair(
new FixCandidateSearcher.Builder<T>()
.withMatcher(ResponseEntityFixStrategy::match)
.withMatcher(ResponseEntityConstructorFixStrategy::match)
.build(),
new ResponseEntityFixStrategy())
new ResponseEntityConstructorFixStrategy())
.withSearcherStrategyPair(
new FixCandidateSearcher.Builder<T>()
.withMatcher(ResponseEntityWriteFixStrategy::match)
.build(),
new ResponseEntityWriteFixStrategy())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

final class ResponseEntityFixStrategyTest {
final class ResponseEntityConstructorFixStrategyTest {

private ResponseEntityFixStrategy fixer;
private ResponseEntityConstructorFixStrategy fixer;
private DetectorRule rule;
private JavaParser parser;

@BeforeEach
void setup() throws IOException {
this.fixer = new ResponseEntityFixStrategy();
this.fixer = new ResponseEntityConstructorFixStrategy();
this.parser = JavaParserFactory.newFactory().create(List.of());
this.rule = new DetectorRule("xss", "XSS", null);
}
Expand Down Expand Up @@ -86,7 +86,7 @@ private CodemodFileScanningResult scanAndFix(final CompilationUnit cu, final int
new SearcherStrategyRemediator.Builder<XSSFinding>()
.withSearcherStrategyPair(
new FixCandidateSearcher.Builder<XSSFinding>()
.withMatcher(ResponseEntityFixStrategy::match)
.withMatcher(ResponseEntityConstructorFixStrategy::match)
.build(),
fixer)
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package io.codemodder.remediation.xss;

import static org.assertj.core.api.Assertions.assertThat;

import com.github.javaparser.JavaParser;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.printer.lexicalpreservation.LexicalPreservingPrinter;
import io.codemodder.CodemodFileScanningResult;
import io.codemodder.codetf.DetectorRule;
import io.codemodder.javaparser.JavaParserFactory;
import io.codemodder.remediation.FixCandidateSearcher;
import io.codemodder.remediation.SearcherStrategyRemediator;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

final class ResponseEntityWriteFixStrategyTest {

private ResponseEntityWriteFixStrategy fixer;
private DetectorRule rule;
private JavaParser parser;

@BeforeEach
void setup() throws IOException {
this.fixer = new ResponseEntityWriteFixStrategy();
this.parser = JavaParserFactory.newFactory().create(List.of());
this.rule = new DetectorRule("xss", "XSS", null);
}

private static Stream<Arguments> fixableSamples() {
return Stream.of(
Arguments.of(
"""
class Samples {
String should_be_fixed(String s) {
return ResponseEntity.ok("Value: " + s);
}
}
""",
"""
import org.owasp.encoder.Encode;
class Samples {
String should_be_fixed(String s) {
return ResponseEntity.ok("Value: " + Encode.forHtml(s));
}
}
"""),
Arguments.of(
"""
class Samples {
String should_be_fixed(Object s) {
return ResponseEntity.ok("Value: " + s.toString());
}
}
""",
"""
import org.owasp.encoder.Encode;
class Samples {
String should_be_fixed(Object s) {
return ResponseEntity.ok("Value: " + Encode.forHtml(s.toString()));
}
}
"""));
}

@ParameterizedTest
@MethodSource("fixableSamples")
void it_fixes_obvious_response_write_methods(final String beforeCode, final String afterCode) {
CompilationUnit cu = parser.parse(beforeCode).getResult().orElseThrow();
LexicalPreservingPrinter.setup(cu);

var result = scanAndFix(cu, 3);
assertThat(result.changes()).isNotEmpty();
String actualCode = LexicalPreservingPrinter.print(cu);
assertThat(actualCode).isEqualToIgnoringWhitespace(afterCode);
}

private CodemodFileScanningResult scanAndFix(final CompilationUnit cu, final int line) {
XSSFinding finding = new XSSFinding("should_be_fixed", line, null);
var remediator =
new SearcherStrategyRemediator.Builder<XSSFinding>()
.withSearcherStrategyPair(
new FixCandidateSearcher.Builder<XSSFinding>()
.withMatcher(ResponseEntityWriteFixStrategy::match)
.build(),
fixer)
.build();
return remediator.remediateAll(
cu,
"path",
rule,
List.of(finding),
XSSFinding::key,
XSSFinding::line,
x -> Optional.empty(),
x -> Optional.ofNullable(x.column()));
}

@ParameterizedTest
@MethodSource("unfixableSamples")
void it_does_not_fix_unfixable_samples(final String beforeCode, final int line) {
CompilationUnit cu = parser.parse(beforeCode).getResult().orElseThrow();
LexicalPreservingPrinter.setup(cu);
var result = scanAndFix(cu, line);
assertThat(result.changes()).isEmpty();
}

private static Stream<Arguments> unfixableSamples() {
return Stream.of(
// this is not a ResponseEntity, shouldn't touch it
Arguments.of(
// this is not a ResponseEntity, shouldn't touch it
"""
class Samples {
String should_be_fixed(String s) {
return ResponseEntity.something_besides_ok("Value: " + s);
}
}
""",
3));
}
}

0 comments on commit 4eecd14

Please sign in to comment.