| /* |
| * Copyright 2000-2012 JetBrains s.r.o. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package com.intellij.compiler.notNullVerification; |
| |
| import com.sun.istack.internal.NotNull; |
| import com.sun.istack.internal.Nullable; |
| import org.jetbrains.org.objectweb.asm.*; |
| |
| import java.util.LinkedHashMap; |
| import java.util.Map; |
| |
| /** |
| * @author ven |
| */ |
| public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcodes { |
| private static final String NOT_NULL_CLASS_NAME = "org/jetbrains/annotations/NotNull"; |
| private static final String NOT_NULL_TYPE = "L"+ NOT_NULL_CLASS_NAME + ";"; |
| private static final String SYNTHETIC_CLASS_NAME = "java/lang/Synthetic"; |
| private static final String SYNTHETIC_TYPE = "L" + SYNTHETIC_CLASS_NAME + ";"; |
| private static final String IAE_CLASS_NAME = "java/lang/IllegalArgumentException"; |
| private static final String ISE_CLASS_NAME = "java/lang/IllegalStateException"; |
| private static final String STRING_CLASS_NAME = "java/lang/String"; |
| private static final String OBJECT_CLASS_NAME = "java/lang/Object"; |
| private static final String CONSTRUCTOR_NAME = "<init>"; |
| private static final String EXCEPTION_INIT_SIGNATURE = "(L" + STRING_CLASS_NAME + ";)V"; |
| |
| private static final String ANNOTATION_DEFAULT_METHOD = "value"; |
| |
| private static final String NULL_ARG_MESSAGE_INDEXED = "Argument %s for @NotNull parameter of %s.%s must not be null"; |
| private static final String NULL_ARG_MESSAGE_NAMED = "Argument for @NotNull parameter '%s' of %s.%s must not be null"; |
| private static final String NULL_RESULT_MESSAGE = "@NotNull method %s.%s must not return null"; |
| @SuppressWarnings("SSBasedInspection") private static final String[] EMPTY_STRING_ARRAY = new String[0]; |
| private final Map<String, Map<Integer, String>> myMethodParamNames; |
| |
| private String myClassName; |
| private boolean myIsModification = false; |
| private RuntimeException myPostponedError; |
| |
| private NotNullVerifyingInstrumenter(final ClassVisitor classVisitor, ClassReader reader) { |
| super(Opcodes.ASM5, classVisitor); |
| myMethodParamNames = getAllParameterNames(reader); |
| } |
| |
| public static boolean processClassFile(final ClassReader reader, final ClassVisitor writer) { |
| final NotNullVerifyingInstrumenter instrumenter = new NotNullVerifyingInstrumenter(writer, reader); |
| reader.accept(instrumenter, 0); |
| return instrumenter.isModification(); |
| } |
| |
| private static Map<String, Map<Integer, String>> getAllParameterNames(ClassReader reader) { |
| final Map<String, Map<Integer, String>> methodParamNames = new LinkedHashMap<String, Map<Integer, String>>(); |
| |
| reader.accept(new ClassVisitor(Opcodes.ASM5) { |
| private String myClassName = null; |
| |
| public void visit(final int version, final int access, final String name, final String signature, final String superName, final String[] interfaces) { |
| myClassName = name; |
| } |
| |
| public MethodVisitor visitMethod(final int access, final String name, final String desc, final String signature, final String[] exceptions) { |
| final String methodName = myClassName + '.' + name + desc; |
| final Map<Integer, String> names = new LinkedHashMap<Integer, String>(); |
| final Type[] args = Type.getArgumentTypes(desc); |
| methodParamNames.put(methodName, names); |
| |
| return new MethodVisitor(api) { |
| @Override |
| public void visitLocalVariable(String name2, String desc, String signature, Label start, Label end, int index) { |
| int parameterIndex = getParameterIndex(index, access, args); |
| if (parameterIndex >= 0) { |
| names.put(parameterIndex, name2); |
| } |
| } |
| }; |
| } |
| }, 0); |
| return methodParamNames; |
| } |
| |
| public boolean isModification() { |
| return myIsModification; |
| } |
| |
| @Override |
| public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { |
| super.visit(version, access, name, signature, superName, interfaces); |
| myClassName = name; |
| } |
| |
| private static class NotNullState { |
| @Nullable String message; |
| @NotNull String exceptionType; |
| |
| NotNullState(String exceptionType) { |
| this.exceptionType = exceptionType; |
| } |
| } |
| |
| @Override |
| public MethodVisitor visitMethod(final int access, final String name, String desc, String signature, String[] exceptions) { |
| final Type[] args = Type.getArgumentTypes(desc); |
| final Type returnType = Type.getReturnType(desc); |
| final MethodVisitor v = cv.visitMethod(access, name, desc, signature, exceptions); |
| final Map<Integer, String> paramNames = myMethodParamNames.get(myClassName + '.' + name + desc); |
| return new MethodVisitor(Opcodes.ASM5, v) { |
| private final Map<Integer, NotNullState> myNotNullParams = new LinkedHashMap<Integer, NotNullState>(); |
| private int mySyntheticCount = 0; |
| private NotNullState myMethodNotNull; |
| private Label myStartGeneratedCodeLabel; |
| |
| private AnnotationVisitor collectNotNullArgs(AnnotationVisitor base, final NotNullState state) { |
| return new AnnotationVisitor(Opcodes.ASM5, base) { |
| @Override |
| public void visit(String methodName, Object o) { |
| if (ANNOTATION_DEFAULT_METHOD.equals(methodName) && !((String) o).isEmpty()) { |
| state.message = (String) o; |
| } |
| else if ("exception".equals(methodName) && o instanceof Type && !((Type)o).getClassName().equals(Exception.class.getName())) { |
| state.exceptionType = ((Type)o).getInternalName(); |
| } |
| super.visit(methodName, o); |
| } |
| }; |
| } |
| |
| public AnnotationVisitor visitParameterAnnotation(final int parameter, final String anno, final boolean visible) { |
| AnnotationVisitor av = mv.visitParameterAnnotation(parameter, anno, visible); |
| if (isReferenceType(args[parameter]) && anno.equals(NOT_NULL_TYPE)) { |
| NotNullState state = new NotNullState(IAE_CLASS_NAME); |
| myNotNullParams.put(new Integer(parameter), state); |
| av = collectNotNullArgs(av, state); |
| } |
| else if (anno.equals(SYNTHETIC_TYPE)) { |
| // see http://forge.ow2.org/tracker/?aid=307392&group_id=23&atid=100023&func=detail |
| mySyntheticCount++; |
| } |
| |
| return av; |
| } |
| |
| @Override |
| public AnnotationVisitor visitAnnotation(String anno, boolean isRuntime) { |
| AnnotationVisitor av = mv.visitAnnotation(anno, isRuntime); |
| if (isReferenceType(returnType) && anno.equals(NOT_NULL_TYPE)) { |
| myMethodNotNull = new NotNullState(ISE_CLASS_NAME); |
| av = collectNotNullArgs(av, myMethodNotNull); |
| } |
| |
| return av; |
| } |
| |
| @Override |
| public void visitCode() { |
| if (myNotNullParams.size() > 0) { |
| myStartGeneratedCodeLabel = new Label(); |
| mv.visitLabel(myStartGeneratedCodeLabel); |
| } |
| for (Map.Entry<Integer, NotNullState> entry : myNotNullParams.entrySet()) { |
| Integer param = entry.getKey(); |
| int var = ((access & ACC_STATIC) == 0) ? 1 : 0; |
| for (int i = 0; i < param; ++i) { |
| var += args[i].getSize(); |
| } |
| mv.visitVarInsn(ALOAD, var); |
| |
| Label end = new Label(); |
| mv.visitJumpInsn(IFNONNULL, end); |
| |
| NotNullState state = entry.getValue(); |
| String paramName = paramNames == null ? null : paramNames.get(param); |
| String descrPattern = state.message != null |
| ? state.message |
| : paramName != null ? NULL_ARG_MESSAGE_NAMED : NULL_ARG_MESSAGE_INDEXED; |
| String[] args = state.message != null |
| ? EMPTY_STRING_ARRAY |
| : new String[]{paramName != null ? paramName : String.valueOf(param - mySyntheticCount), myClassName, name}; |
| generateThrow(state.exceptionType, end, descrPattern, args); |
| } |
| } |
| |
| @Override |
| public void visitLocalVariable(String name, String desc, String signature, Label start, Label end, int index) { |
| final boolean isStatic = (access & ACC_STATIC) != 0; |
| final boolean isParameterOrThisRef = isStatic ? index < args.length : index <= args.length; |
| final Label label = (isParameterOrThisRef && myStartGeneratedCodeLabel != null) ? myStartGeneratedCodeLabel : start; |
| mv.visitLocalVariable(name, desc, signature, label, end, index); |
| } |
| |
| @Override |
| public void visitInsn(int opcode) { |
| if (opcode == ARETURN) { |
| if (myMethodNotNull != null) { |
| mv.visitInsn(DUP); |
| final Label skipLabel = new Label(); |
| mv.visitJumpInsn(IFNONNULL, skipLabel); |
| String descrPattern = myMethodNotNull.message != null ? myMethodNotNull.message : NULL_RESULT_MESSAGE; |
| String[] args = myMethodNotNull.message != null ? EMPTY_STRING_ARRAY : new String[]{myClassName, name}; |
| generateThrow(myMethodNotNull.exceptionType, skipLabel, descrPattern, args); |
| } |
| } |
| |
| mv.visitInsn(opcode); |
| } |
| |
| private void generateThrow(final String exceptionClass, final Label end, final String descrPattern, final String[] args) { |
| mv.visitTypeInsn(NEW, exceptionClass); |
| mv.visitInsn(DUP); |
| |
| mv.visitLdcInsn(descrPattern); |
| |
| mv.visitLdcInsn(args.length); |
| mv.visitTypeInsn(ANEWARRAY, OBJECT_CLASS_NAME); |
| |
| for (int i = 0; i < args.length; i++) { |
| mv.visitInsn(DUP); |
| mv.visitLdcInsn(i); |
| mv.visitLdcInsn(args[i]); |
| mv.visitInsn(AASTORE); |
| } |
| |
| //noinspection SpellCheckingInspection |
| mv.visitMethodInsn(INVOKESTATIC, STRING_CLASS_NAME, "format", "(Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;", false); |
| |
| mv.visitMethodInsn(INVOKESPECIAL, exceptionClass, CONSTRUCTOR_NAME, EXCEPTION_INIT_SIGNATURE, false); |
| mv.visitInsn(ATHROW); |
| mv.visitLabel(end); |
| |
| myIsModification = true; |
| processPostponedErrors(); |
| } |
| |
| @Override |
| public void visitMaxs(final int maxStack, final int maxLocals) { |
| try { |
| super.visitMaxs(maxStack, maxLocals); |
| } |
| catch (Throwable e) { |
| //noinspection SpellCheckingInspection |
| registerError(name, "visitMaxs", e); |
| } |
| } |
| }; |
| } |
| |
| private static int getParameterIndex(int localVarIndex, int methodAccess, Type[] paramTypes) { |
| final boolean isStatic = (methodAccess & ACC_STATIC) != 0; |
| int parameterIndex = isStatic ? localVarIndex : localVarIndex - 1; |
| if (parameterIndex >= paramTypes.length) { |
| parameterIndex = -1; |
| } |
| return parameterIndex; |
| } |
| |
| private static boolean isReferenceType(final Type type) { |
| return type.getSort() == Type.OBJECT || type.getSort() == Type.ARRAY; |
| } |
| |
| private void registerError(String methodName, String operationName, Throwable e) { |
| if (myPostponedError == null) { |
| // throw the first error that occurred |
| Throwable err = e.getCause(); |
| if (err == null) { |
| err = e; |
| } |
| myPostponedError = new RuntimeException("Operation '" + operationName + "' failed for " + myClassName + "." + methodName + "(): " + err.getMessage(), err); |
| } |
| if (myIsModification) { |
| processPostponedErrors(); |
| } |
| } |
| |
| private void processPostponedErrors() { |
| final RuntimeException error = myPostponedError; |
| if (error != null) { |
| throw error; |
| } |
| } |
| } |
| |